package com.mindspore.flclient.model;

import com.mindspore.Model;
import com.mindspore.flclient.FLParameter;
import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.flclient.common.FLLoggerGenerater;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.CopyOption;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import mindspore.fl.schema.FeatureMap;

/* loaded from: input_file:com/mindspore/flclient/model/Client.class */
/**
 * Base class that drives a train model and/or an inference model through {@link ModelProxy}.
 *
 * <p>Typical call order visible in this class: {@link #initModel(FLParameter)} creates the
 * proxies, {@link #EnableTrain(boolean)} selects which proxy is "current", and then
 * {@link #trainModel(int)}, {@link #evalModel()} or {@link #inferModel()} run the current proxy
 * on the dataset registered in {@link #dataSets} for the matching {@link RunType}.
 *
 * <p>Subclasses provide the dataset loading, the callbacks and the interpretation of the
 * callback results.
 */
public abstract class Client {
    private static final Logger logger = FLLoggerGenerater.getModelLogger(Client.class.toString());

    /** File-name prefix of the backup copy created next to the train model file. */
    private static final String BACKUP_PREFIX = "bak_";

    /** File-name prefix of the temporary file a model is exported to before being renamed. */
    private static final String TMP_PREFIX = "tmp_";

    /** Model of the currently selected proxy; set by {@link #EnableTrain(boolean)}. */
    public Model model;
    private ModelProxy trainModelProxy = null;
    private ModelProxy inferModelProxy = null;
    /** Proxy selected by {@link #EnableTrain(boolean)}; {@code null} until then and after {@link #free()}. */
    private ModelProxy curProxy = null;
    /** Datasets keyed by run type; read by the train/eval/infer entry points. */
    public Map<RunType, DataSet> dataSets = new HashMap<>();
    // NOTE(review): inputsBuffer and uploadLoss are not read anywhere in this class.
    private final List<ByteBuffer> inputsBuffer = new ArrayList<>();
    private float uploadLoss = 0.0f;
    /** Snapshot of the current proxy's feature map taken at the start of {@link #trainModel(int)}. */
    private Map<String, float[]> preFeatures = null;

    /**
     * Creates the callbacks used for one run. Called by {@link #trainModel(int)},
     * {@link #evalModel()} and {@link #inferModel()} right before the proxy runs the model.
     *
     * @param runType the kind of run about to start
     * @param dataSet the dataset registered for {@code runType}
     * @return the callbacks handed to {@code ModelProxy.runModel}
     */
    public abstract List<Callback> initCallbacks(RunType runType, DataSet dataSet);

    /**
     * Initializes the datasets from the given files. Not called from within this class.
     *
     * @param map file lists keyed by run type
     * @return a value per run type; presumably the sample count of each dataset — confirm
     *     against the implementations
     */
    public abstract Map<RunType, Integer> initDataSets(Map<RunType, List<String>> map);

    /**
     * Extracts the evaluation accuracy from the callbacks of a finished evaluation run.
     *
     * @param list the callbacks that were passed to the evaluation run
     * @return the accuracy reported by {@link #evalModel()}
     */
    public abstract float getEvalAccuracy(List<Callback> list);

    /**
     * Extracts the inference result from the callbacks of a finished inference run.
     *
     * @param list the callbacks that were passed to the inference run
     * @return the result reported by {@link #inferModel()}
     */
    public abstract List<Object> getInferResult(List<Callback> list);

    /**
     * Builds the path of a file that lives in the same directory as {@code file} and whose
     * name is {@code prefix + file.getName()}.
     *
     * <p>Uses the {@code File(File, String)} constructor so that a path without a directory
     * part (where {@code getParent()} is {@code null}) resolves to a plain relative name
     * instead of the literal {@code "null/..."}.
     */
    private static File siblingWithPrefix(File file, String prefix) {
        return new File(file.getParentFile(), prefix + file.getName());
    }

    /**
     * Copies the model file to its {@code bak_} sibling, unless the backup already exists or
     * the model file is missing. A failed copy is logged and otherwise ignored.
     *
     * @param str path of the model file to back up
     */
    private void backupModelFile(String str) {
        File origin = new File(str);
        File backup = siblingWithPrefix(origin, BACKUP_PREFIX);
        if (backup.exists() || !origin.exists()) {
            return;
        }
        logger.info("Backup model file:" + str + " to :" + backup.getPath());
        try {
            Files.copy(origin.toPath(), backup.toPath());
        } catch (IOException e) {
            // Best effort: the backup is optional, so log and continue.
            logger.log(Level.SEVERE, "Backup model file failed:" + str, e);
        }
    }

    /**
     * Replaces the model file with its {@code bak_} sibling created by the backup step.
     * Does nothing (besides logging) when the backup is missing, when the existing model file
     * cannot be deleted, or when the copy fails.
     *
     * @param str path of the model file to restore
     */
    public void restoreModelFile(String str) {
        File origin = new File(str);
        File backup = siblingWithPrefix(origin, BACKUP_PREFIX);
        if (!backup.exists()) {
            logger.severe("Restore failed, backup file:" + backup.getName() + " not exist.");
            return;
        }
        if (origin.exists()) {
            logger.severe("Delete the origin file:" + origin.getName());
            if (!origin.delete()) {
                // The copy below would fail on an existing target, so stop here.
                logger.severe("Restore failed, can't delete the origin file:" + origin.getName());
                return;
            }
        }
        logger.info("Restore model file:" + origin.getName() + " from :" + backup.getName());
        try {
            Files.copy(backup.toPath(), origin.toPath());
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Restore model file failed:" + origin.getName(), e);
        }
    }

    /**
     * Creates and initializes the train and inference proxies from the configured model paths.
     *
     * <p>The train model file is backed up before it is loaded. When both paths are equal the
     * inference proxy is the train proxy itself. A failed train-model initialization is
     * returned immediately so that a later successful inference-model initialization cannot
     * hide it.
     *
     * @param fLParameter source of the train/infer model paths and the input shape
     * @return the status of the last initialization performed, the first failure, or
     *     {@code Status.FAILED} when neither path is set
     */
    public Status initModel(FLParameter fLParameter) {
        String trainModelPath = fLParameter.getTrainModelPath();
        String inferModelPath = fLParameter.getInferModelPath();
        int[][] inputShape = fLParameter.getInputShape();
        Status status = Status.FAILED;
        if (trainModelPath != null && !trainModelPath.isEmpty()) {
            backupModelFile(trainModelPath);
            this.trainModelProxy = new ModelProxy();
            status = this.trainModelProxy.initModel(trainModelPath, inputShape);
            if (status != Status.SUCCESS) {
                logger.severe("Init train model failed, path:" + trainModelPath);
                return status;
            }
        }
        if (inferModelPath == null || inferModelPath.isEmpty()) {
            return status;
        }
        if (inferModelPath.equals(trainModelPath)) {
            // One file serves both purposes: share the proxy instead of loading it twice.
            this.inferModelProxy = this.trainModelProxy;
        } else {
            this.inferModelProxy = new ModelProxy();
            status = this.inferModelProxy.initModel(inferModelPath, inputShape);
        }
        return status;
    }

    /**
     * Selects the proxy used by subsequent run/feature calls and publishes its model in
     * {@link #model}.
     *
     * @param z {@code true} to select the train proxy, {@code false} for the inference proxy
     * @return {@code true} if the requested proxy exists and was selected
     */
    public boolean EnableTrain(boolean z) {
        ModelProxy requested = z ? this.trainModelProxy : this.inferModelProxy;
        if (requested == null) {
            logger.severe("Can't get Model proxy for " + (z ? "train" : "infer"));
            return false;
        }
        this.curProxy = requested;
        this.model = requested.getModel();
        return true;
    }

    /**
     * Trains the current proxy on the {@code RunType.TRAINMODE} dataset.
     *
     * <p>Before training, the current feature map is stored so that
     * {@link #getDpWeightNorm(ArrayList)} and {@link #getPreFeature(String)} can compare
     * against it afterwards.
     *
     * @param i number of epochs; must be positive
     * @return {@code Status.INVALID} for a non-positive epoch count, {@code Status.NULLPTR}
     *     when no proxy is selected or no train dataset is registered, otherwise the status
     *     returned by the proxy
     */
    public Status trainModel(int i) {
        if (i <= 0) {
            logger.severe("epochs cannot smaller than 0");
            return Status.INVALID;
        }
        if (this.curProxy == null || this.model == null) {
            logger.severe("Can't train before a model proxy is enabled");
            return Status.NULLPTR;
        }
        this.preFeatures = this.curProxy.getFeatureMap();
        DataSet trainSet = this.dataSets.get(RunType.TRAINMODE);
        if (trainSet == null) {
            logger.severe("not find train dataset");
            return Status.NULLPTR;
        }
        trainSet.padding();
        List<Callback> callbacks = initCallbacks(RunType.TRAINMODE, trainSet);
        this.model.setTrainMode(true);
        Status result = this.curProxy.runModel(i, callbacks, trainSet);
        if (result != Status.SUCCESS) {
            logger.severe("train loop failed");
        }
        return result;
    }

    /**
     * Switches the current model out of train mode and returns the padded dataset registered
     * for {@code runType}.
     *
     * @param runType dataset key to look up
     * @param label   short name of the run used in log messages
     * @return the dataset, or {@code null} when no proxy is selected or the dataset is missing
     */
    private DataSet prepareNonTrainDataSet(RunType runType, String label) {
        if (this.curProxy == null || this.model == null) {
            logger.severe("Can't " + label + " before a model proxy is enabled");
            return null;
        }
        this.model.setTrainMode(false);
        DataSet dataSet = this.dataSets.get(runType);
        if (dataSet == null) {
            logger.severe("not find " + label + " dataset");
            return null;
        }
        dataSet.padding();
        return dataSet;
    }

    /**
     * Runs the current proxy once over the {@code RunType.EVALMODE} dataset.
     *
     * @return the accuracy computed by {@link #getEvalAccuracy(List)}, or {@code Float.NaN}
     *     when no proxy is selected, the dataset is missing or the run fails
     */
    public float evalModel() {
        DataSet evalSet = prepareNonTrainDataSet(RunType.EVALMODE, "eval");
        if (evalSet == null) {
            return Float.NaN;
        }
        List<Callback> callbacks = initCallbacks(RunType.EVALMODE, evalSet);
        if (this.curProxy.runModel(1, callbacks, evalSet) == Status.SUCCESS) {
            return getEvalAccuracy(callbacks);
        }
        logger.severe("eval loop failed");
        return Float.NaN;
    }

    /**
     * Runs the current proxy once over the {@code RunType.INFERMODE} dataset.
     *
     * @return the result computed by {@link #getInferResult(List)}, or {@code null} when no
     *     proxy is selected, the dataset is missing or the run fails
     */
    public List<Object> inferModel() {
        DataSet inferSet = prepareNonTrainDataSet(RunType.INFERMODE, "infer");
        if (inferSet == null) {
            return null;
        }
        List<Callback> callbacks = initCallbacks(RunType.INFERMODE, inferSet);
        if (this.curProxy.runModel(1, callbacks, inferSet) == Status.SUCCESS) {
            return getInferResult(callbacks);
        }
        logger.severe("infer loop failed");
        return null;
    }

    /**
     * Exports the proxy's model to {@code str} by writing a {@code tmp_} sibling first and then
     * renaming it over the target, so a failed export does not truncate the existing file.
     *
     * @param modelProxy proxy whose model is saved; may be {@code null}
     * @param str        target path; may be {@code null} or empty
     * @return {@code true} when there is nothing to save or the export and rename succeeded
     */
    private boolean saveModelbyProxy(ModelProxy modelProxy, String str) {
        Model target = (modelProxy == null) ? null : modelProxy.getModel();
        if (str == null || str.isEmpty() || target == null) {
            logger.info("Path is empty or no model provided, no need to save model, out path:" + str);
            return true;
        }
        File outFile = new File(str);
        File tmpFile = siblingWithPrefix(outFile, TMP_PREFIX);
        // The raw (List) cast is kept from the original call to pick the same export overload.
        if (!target.export(tmpFile.getPath(), 0, false, (List) null)) {
            logger.severe("Export model to temp file failed:" + tmpFile.getPath());
            return false;
        }
        if (!tmpFile.renameTo(outFile)) {
            logger.severe("Rename temp model file to " + str + " failed");
            return false;
        }
        return true;
    }

    /**
     * Saves the train model and then the inference model to their configured paths.
     *
     * @param fLParameter      source of the train/infer model paths
     * @param localFLParameter unused by this implementation
     * @return {@code Status.SUCCESS} when both saves succeed (or are skipped because there is
     *     nothing to save), {@code Status.FAILED} otherwise
     */
    public Status saveModel(FLParameter fLParameter, LocalFLParameter localFLParameter) {
        String trainModelPath = fLParameter.getTrainModelPath();
        String inferModelPath = fLParameter.getInferModelPath();
        if (!saveModelbyProxy(this.trainModelProxy, trainModelPath)) {
            logger.severe("Save train model to file failed.");
            return Status.FAILED;
        }
        if (!saveModelbyProxy(this.inferModelProxy, inferModelPath)) {
            logger.severe("Save infer model to file failed.");
            return Status.FAILED;
        }
        return Status.SUCCESS;
    }

    /**
     * Computes the L2 norm of the difference between the pre-training snapshot and the train
     * proxy's current values, over the named features.
     *
     * @param arrayList names of the features to include
     * @return square root of the summed squared differences
     * @throws RuntimeException if no training snapshot or train proxy exists, a feature cannot
     *     be read, or a feature changed length
     */
    public float getDpWeightNorm(ArrayList<String> arrayList) {
        if (this.preFeatures == null) {
            throw new RuntimeException("Must call getDpWeightNorm after train.");
        }
        if (this.trainModelProxy == null) {
            throw new RuntimeException("Train model is not initialized, can't get weight norm.");
        }
        // Accumulate in double: summing many float squares in float loses precision.
        double squaredSum = 0.0d;
        for (String name : arrayList) {
            float[] before = this.preFeatures.get(name);
            float[] after = this.trainModelProxy.getFeature(name);
            if (before == null || after == null) {
                throw new RuntimeException("Get feature value failed, feature name:" + name);
            }
            if (before.length != after.length) {
                throw new RuntimeException("Length of " + name + " is changed after update, origin len:"
                        + before.length + " cur len:" + after.length);
            }
            for (int i = 0; i < before.length; i++) {
                double diff = before[i] - after[i];
                squaredSum += diff * diff;
            }
        }
        return (float) Math.sqrt(squaredSum);
    }

    /**
     * Reads a feature from the currently selected proxy.
     *
     * @param str feature name
     * @return the value returned by the proxy, or {@code null} when no proxy is selected
     */
    public float[] getFeature(String str) {
        if (this.curProxy == null) {
            logger.severe("Can't get feature before a model proxy is enabled, feature name:" + str);
            return null;
        }
        return this.curProxy.getFeature(str);
    }

    /**
     * Reads a feature from the snapshot taken at the start of the last {@link #trainModel(int)}.
     *
     * @param str feature name
     * @return the stored values, or {@code null} when no snapshot exists or the name is unknown
     */
    public float[] getPreFeature(String str) {
        return (this.preFeatures == null) ? null : this.preFeatures.get(str);
    }

    /**
     * Forwards a feature update to the train or the inference proxy.
     *
     * @param featureMap the feature to apply
     * @param z          {@code true} to update the train proxy, {@code false} for the
     *                   inference proxy
     * @return the proxy's status, or {@code Status.FAILED} when the requested proxy is missing
     */
    public Status updateFeature(FeatureMap featureMap, boolean z) {
        ModelProxy target = z ? this.trainModelProxy : this.inferModelProxy;
        if (target == null) {
            logger.severe("Can't get ModelProxy for " + (z ? "trainModel" : "inferModel"));
            return Status.FAILED;
        }
        return target.updateFeature(featureMap);
    }

    /**
     * Frees both proxies (a shared proxy is freed only once) and clears every reference to
     * them, including {@link #model}.
     */
    public void free() {
        if (this.trainModelProxy != null) {
            this.trainModelProxy.free();
        }
        if (this.inferModelProxy != null && this.inferModelProxy != this.trainModelProxy) {
            this.inferModelProxy.free();
        }
        this.trainModelProxy = null;
        this.inferModelProxy = null;
        this.curProxy = null;
        this.model = null;
    }

    /**
     * Sets the learning rate on the train proxy's model.
     *
     * @param f the learning rate
     * @return {@code Status.SUCCESS} when the model accepted the value, {@code Status.FAILED}
     *     when there is no train model or the model rejected it
     */
    public Status setLearningRate(float f) {
        Model trainModel = (this.trainModelProxy == null) ? null : this.trainModelProxy.getModel();
        if (trainModel != null && trainModel.setLearningRate(f)) {
            return Status.SUCCESS;
        }
        logger.severe("set learning rate failed");
        return Status.FAILED;
    }

    /**
     * Assigns the batch size to every dataset currently registered in {@link #dataSets}.
     *
     * @param i the batch size
     */
    public void setBatchSize(int i) {
        for (DataSet dataSet : this.dataSets.values()) {
            dataSet.batchSize = i;
        }
    }

    /**
     * Returns the upload loss held by the currently selected proxy.
     * Requires a proxy selected via {@link #EnableTrain(boolean)}; otherwise this throws a
     * {@link NullPointerException}.
     */
    public float getUploadLoss() {
        return this.curProxy.getUploadLoss();
    }

    /**
     * Stores the upload loss on the currently selected proxy.
     * Requires a proxy selected via {@link #EnableTrain(boolean)}; otherwise this throws a
     * {@link NullPointerException}.
     *
     * @param f the loss value
     */
    public void setUploadLoss(float f) {
        this.curProxy.setUploadLoss(f);
    }
}
