package com.mindspore.flclient.model;

import com.mindspore.Graph;
import com.mindspore.MSTensor;
import com.mindspore.Model;
import com.mindspore.config.MSContext;
import com.mindspore.config.TrainCfg;
import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.flclient.common.FLLoggerGenerater;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import mindspore.fl.schema.FeatureMap;

/* loaded from: input_file:com/mindspore/flclient/model/ModelProxy.class */
/**
 * Wraps a MindSpore Lite {@link Model} for the federated-learning client.
 *
 * <p>Responsibilities:
 * <ul>
 *   <li>building the model from a file path;</li>
 *   <li>feeding batches from a {@link DataSet};</li>
 *   <li>running training steps;</li>
 *   <li>reading and updating the tensors returned by {@code Model.getFeatureMaps()}.</li>
 * </ul>
 *
 * <p>Lifecycle: call {@link #initModel}, use the model, then call {@link #free}.
 */
public class ModelProxy {
    private static final Logger logger = FLLoggerGenerater.getModelLogger(Client.class.toString());

    /** Native model; {@code null} before a successful {@link #initModel} and after {@link #free}. */
    private Model model;

    /** Input tensors of {@link #model}, in the order returned by {@code Model.getInputs()}. */
    private List<MSTensor> inputs;

    // NOTE(review): not read or written anywhere in this class; kept in case it is accessed reflectively.
    private Map<RunType, DataSet> dataSets = new HashMap<>();

    /** One direct, native-byte-order staging buffer per model input, index-aligned with {@link #inputs}. */
    private final List<ByteBuffer> inputsBuffer = new ArrayList<>();

    /** Tensors returned by {@code Model.getFeatureMaps()}, keyed by tensor name. */
    private final Map<String, MSTensor> featureTensors = new HashMap<>();

    /** Loss reported by a {@link LossCallback} at the end of the last epoch of {@link #runModel}. */
    private float uploadLoss = 0.0f;

    /**
     * Returns the underlying native model.
     *
     * @return the model, or {@code null} if it is not initialized or has been freed
     */
    public Model getModel() {
        return this.model;
    }

    /**
     * Returns the loss recorded at the end of the last {@link #runModel} epoch.
     *
     * @return the recorded loss
     */
    public float getUploadLoss() {
        return this.uploadLoss;
    }

    /**
     * Overrides the loss value returned by {@link #getUploadLoss()}.
     *
     * @param f the loss to store
     */
    public void setUploadLoss(float f) {
        this.uploadLoss = f;
    }

    /**
     * Releases the native model together with its input and feature-map tensors.
     *
     * <p>All cached per-model state is reset, so a later {@link #initModel} starts clean. Calling
     * this more than once, or after a failed {@code initModel}, is a no-op.
     */
    public void free() {
        if (model == null) {
            return;
        }
        if (inputs != null) {
            inputs.forEach(MSTensor::free);
            inputs = null;
        }
        featureTensors.values().forEach(MSTensor::free);
        featureTensors.clear();
        // Drop staging buffers so a re-init does not index stale buffers sized for the old model.
        inputsBuffer.clear();
        model.free();
        model = null;
    }

    /**
     * Creates an {@link MSContext} from the thread, bind-mode, device and fp16 settings in
     * {@link LocalFLParameter}.
     *
     * @return an initialized context, or {@code null} if initialization failed (the partially
     *     created context is freed in that case)
     */
    private MSContext getMsContext() {
        LocalFLParameter params = LocalFLParameter.getInstance();
        int deviceType = params.getDeviceType();
        int threadNum = params.getThreadNum();
        int cpuBindMode = params.getCpuBindMode();
        boolean isEnableFp16 = params.isEnableFp16();
        MSContext msContext = new MSContext();
        if (!msContext.init(threadNum, cpuBindMode)) {
            logger.severe("Call msContext.init failed, threadNum " + threadNum + ", cpuBindMode " + cpuBindMode);
            msContext.free();
            return null;
        }
        if (!msContext.addDeviceInfo(deviceType, isEnableFp16, 0)) {
            logger.severe("Call msContext.addDeviceInfo failed, deviceType " + deviceType + ", enableFp16 " + isEnableFp16);
            msContext.free();
            return null;
        }
        return msContext;
    }

    /**
     * Builds a trainable model from a graph file, keeping the shapes stored in the file.
     *
     * <p>Input staging buffers are sized from each input tensor's {@code size()}.
     *
     * @param modelPath path passed to {@code Graph.load}
     * @param msContext context to build with; freed here on the pre-build failure paths
     * @return {@code true} on success
     */
    private boolean initModelWithoutShape(String modelPath, MSContext msContext) {
        TrainCfg trainCfg = new TrainCfg();
        if (!trainCfg.init()) {
            logger.severe("Call trainCfg.init failed ...");
            msContext.free();
            trainCfg.free();
            return false;
        }
        Graph graph = new Graph();
        if (!graph.load(modelPath)) {
            logger.severe("Call graph.load failed, modelPath: " + modelPath);
            graph.free();
            trainCfg.free();
            msContext.free();
            return false;
        }
        model = new Model();
        if (!model.build(graph, msContext, trainCfg)) {
            // NOTE(review): msContext/trainCfg are not freed after build() is attempted, presumably
            // because build() takes ownership of them — confirm against the JNI layer.
            logger.severe("Call model.build failed ... ");
            graph.free();
            discardModel();
            return false;
        }
        graph.free();
        inputs = model.getInputs();
        for (MSTensor input : inputs) {
            ByteBuffer buffer = ByteBuffer.allocateDirect(Math.toIntExact(input.size()));
            buffer.order(ByteOrder.nativeOrder());
            inputsBuffer.add(buffer);
        }
        cacheFeatureTensors();
        return true;
    }

    /**
     * Builds a model from a file and resizes its inputs to the given shapes.
     *
     * <p>Input staging buffers are sized as element-count x 4 bytes per shape.
     *
     * @param modelPath path passed to {@code Model.build}
     * @param msContext context to build with
     * @param inputShapes one shape per model input, passed to {@code Model.resize}
     * @return {@code true} on success
     * @throws ArithmeticException if a shape's byte size does not fit in an {@code int}
     */
    private boolean initModelWithShape(String modelPath, MSContext msContext, int[][] inputShapes) {
        model = new Model();
        if (!model.build(modelPath, 0, msContext)) {
            logger.severe("Call model.build failed ... ");
            discardModel();
            return false;
        }
        inputs = model.getInputs();
        if (!model.resize(inputs, inputShapes)) {
            logger.severe("session resize failed");
            discardModel();
            return false;
        }
        for (int[] shape : inputShapes) {
            ByteBuffer buffer = ByteBuffer.allocateDirect(fourByteElementBufferSize(shape));
            buffer.order(ByteOrder.nativeOrder());
            inputsBuffer.add(buffer);
        }
        cacheFeatureTensors();
        return true;
    }

    /**
     * Computes the byte size of a tensor with 4-byte elements.
     *
     * <p>An empty shape is treated as a scalar, i.e. one element.
     *
     * @param shape tensor dimensions
     * @return element count x 4
     * @throws ArithmeticException on {@code int} overflow, instead of silently wrapping
     */
    private static int fourByteElementBufferSize(int[] shape) {
        int elements = IntStream.of(shape).reduce(1, Math::multiplyExact);
        return Math.multiplyExact(elements, Float.BYTES);
    }

    /** Indexes the current model's feature-map tensors by tensor name. */
    private void cacheFeatureTensors() {
        for (MSTensor tensor : model.getFeatureMaps()) {
            featureTensors.put(tensor.tensorName(), tensor);
        }
    }

    /**
     * Frees a model whose initialization failed and forgets it.
     *
     * <p>Afterwards {@link #free} will neither double-free the model nor dereference {@code inputs}.
     */
    private void discardModel() {
        model.free();
        model = null;
        inputs = null;
    }

    /**
     * Creates and initializes the native model.
     *
     * @param str path of the model file
     * @param iArr per-input shapes to resize to, or {@code null} to build a trainable model with
     *     the file's own shapes
     * @return {@link Status#SUCCESS} on success, otherwise {@link Status#FAILED}
     */
    public Status initModel(String str, int[][] iArr) {
        if (str == null) {
            logger.severe("session init failed");
            return Status.FAILED;
        }
        MSContext msContext = getMsContext();
        if (msContext == null) {
            return Status.FAILED;
        }
        // Start from empty caches so buffer indices line up with the new model's inputs.
        inputsBuffer.clear();
        featureTensors.clear();
        // Parenthesised on purpose: the original unparenthesised nested ternary returned a raw
        // boolean from the no-shape branch.
        boolean initialized = (iArr == null)
                ? initModelWithoutShape(str, msContext)
                : initModelWithShape(str, msContext, iArr);
        return initialized ? Status.SUCCESS : Status.FAILED;
    }

    /**
     * Fills the staging buffers with one batch and binds them to the model inputs.
     *
     * @param dataSet source of the batch
     * @param batchIndex index of the batch to load
     * @return {@code false} if any input tensor rejected its data
     */
    private boolean fillModelInput(DataSet dataSet, int batchIndex) {
        dataSet.fillInputBuffer(inputsBuffer, batchIndex);
        for (int k = 0; k < inputs.size(); k++) {
            if (!inputs.get(k).setData(inputsBuffer.get(k))) {
                logger.severe("Set input tensor data failed, index:" + k);
                return false;
            }
        }
        return true;
    }

    /**
     * Runs {@code i} epochs over every batch of {@code dataSet}.
     *
     * <p>Callbacks receive {@code stepEnd()} after each step and {@code epochEnd()} after each
     * epoch. After the last epoch, the loss of any {@link LossCallback} is stored as the upload
     * loss. The job stops early if {@link LocalFLParameter#isStopJobFlag()} becomes true.
     *
     * @param i number of epochs
     * @param list callbacks to notify
     * @param dataSet training data
     * @return {@link Status#SUCCESS} when all epochs complete, {@link Status#NULLPTR} if the model
     *     is not initialized, otherwise {@link Status#FAILED}
     */
    public Status runModel(int i, List<Callback> list, DataSet dataSet) {
        if (model == null) {
            logger.severe("model is null, please call initModel first");
            return Status.NULLPTR;
        }
        LocalFLParameter flParameter = LocalFLParameter.getInstance();
        long startMillis = System.currentTimeMillis();
        for (int epoch = 0; epoch < i; epoch++) {
            for (int batch = 0; batch < dataSet.batchNum; batch++) {
                if (flParameter.isStopJobFlag()) {
                    logger.info("the stopJObFlag is set to true, the job will be stop");
                    return Status.FAILED;
                }
                if (!fillModelInput(dataSet, batch)) {
                    return Status.FAILED;
                }
                if (!model.runStep()) {
                    logger.severe("run graph failed");
                    return Status.FAILED;
                }
                for (Callback callback : list) {
                    callback.stepEnd();
                }
            }
            boolean lastEpoch = epoch == i - 1;
            for (Callback callback : list) {
                callback.epochEnd();
                if (lastEpoch && callback instanceof LossCallback) {
                    setUploadLoss(((LossCallback) callback).getUploadLoss());
                }
            }
        }
        logger.info("total run time:" + (System.currentTimeMillis() - startMillis) + "ms");
        return Status.SUCCESS;
    }

    /**
     * Snapshots every feature-map tensor as float data.
     *
     * @return a new mutable map of tensor name to {@code getFloatData()}; empty if no model is loaded
     */
    public Map<String, float[]> getFeatureMap() {
        Map<String, float[]> snapshot = new HashMap<>(featureTensors.size());
        featureTensors.forEach((name, tensor) -> snapshot.put(name, tensor.getFloatData()));
        return snapshot;
    }

    /**
     * Reads one feature-map tensor as float data.
     *
     * @param str tensor name
     * @return the tensor's float data, or {@code null} if no tensor has that name
     */
    public float[] getFeature(String str) {
        MSTensor tensor = featureTensors.get(str);
        return tensor == null ? null : tensor.getFloatData();
    }

    /**
     * Overwrites several feature-map tensors, then exports the model via {@code Model.export}.
     *
     * <p>Processing stops at the first feature that cannot be applied; earlier features stay applied.
     *
     * @param str destination passed to {@code Model.export}
     * @param list new tensor values
     * @return {@link Status#SUCCESS}; {@link Status#NULLPTR} for missing arguments, model or
     *     tensor; {@link Status#FAILED} if setting data or exporting fails
     */
    public Status updateFeatures(String str, List<FeatureMap> list) {
        if (this.model == null || list == null || str == null || str.isEmpty()) {
            logger.severe("trainSession,featureMaps modelName cannot be null");
            return Status.NULLPTR;
        }
        for (FeatureMap newFeature : list) {
            Status status = updateFeature(newFeature);
            if (status != Status.SUCCESS) {
                return status;
            }
        }
        if (!this.model.export(str, 0, false, (List<String>) null)) {
            logger.severe("Export model failed, path:" + str);
            return Status.FAILED;
        }
        return Status.SUCCESS;
    }

    /**
     * Overwrites one feature-map tensor with the payload of {@code featureMap}.
     *
     * <p>The target tensor is the one whose name equals {@code weightFullname()}. The payload is
     * copied into a fresh direct buffer before being handed to the tensor.
     *
     * @param featureMap new value for one tensor
     * @return {@link Status#SUCCESS}; {@link Status#NULLPTR} if the feature, its name, its data or
     *     the named tensor is missing; {@link Status#FAILED} if the tensor rejects the data
     */
    public Status updateFeature(FeatureMap featureMap) {
        if (featureMap == null) {
            logger.severe("newFeature cannot be null");
            return Status.NULLPTR;
        }
        String name = featureMap.weightFullname();
        // Flatbuffers accessors return null for absent fields, so guard before isEmpty().
        if (name == null || name.isEmpty() || !featureTensors.containsKey(name)) {
            logger.severe("Can't get feature for name:" + name);
            return Status.NULLPTR;
        }
        ByteBuffer source = featureMap.dataAsByteBuffer();
        if (source == null) {
            logger.severe("Feature data is empty, name:" + name);
            return Status.NULLPTR;
        }
        MSTensor tensor = featureTensors.get(name);
        ByteBuffer payload = ByteBuffer.allocateDirect(source.remaining());
        payload.order(ByteOrder.nativeOrder());
        payload.put(source);
        // After put() the position sits at the end; flip so position/limit span the copied bytes.
        payload.flip();
        if (!tensor.setData(payload)) {
            logger.severe("Set tensor value failed, name:" + tensor.tensorName());
            return Status.FAILED;
        }
        return Status.SUCCESS;
    }
}
