/*
 * Decompiled with CFR 0.152.
 */
package com.mindspore.flclient.model;

import com.mindspore.Model;
import com.mindspore.flclient.FLParameter;
import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.flclient.common.FLLoggerGenerater;
import com.mindspore.flclient.model.Callback;
import com.mindspore.flclient.model.DataSet;
import com.mindspore.flclient.model.ModelProxy;
import com.mindspore.flclient.model.RunType;
import com.mindspore.flclient.model.Status;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.CopyOption;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import mindspore.fl.schema.FeatureMap;

/**
 * Base class of a federated-learning client model.
 *
 * <p>Owns an optional train {@link ModelProxy} and an optional infer {@link ModelProxy}. They may
 * be the same instance when both paths are identical. One of them is selected as the "current"
 * proxy via {@link #EnableTrain(boolean)}. Subclasses supply data-set loading, callbacks and
 * result extraction.
 */
public abstract class Client {
    private static final Logger logger = FLLoggerGenerater.getModelLogger(Client.class.toString());

    /** File-name prefix of the backup copy created next to the train model file. */
    private static final String BACKUP_PREFIX = "bak_";

    /** File-name prefix of the temporary file a model is exported to before being moved into place. */
    private static final String TMP_PREFIX = "tmp_";

    /** Model of the currently selected proxy; assigned by {@link #EnableTrain(boolean)}. */
    public Model model;

    /** Proxy for the train model, or {@code null} if no train model path was configured. */
    private ModelProxy trainModelProxy = null;

    /** Proxy for the infer model; same instance as the train proxy when both paths are equal. */
    private ModelProxy inferModelProxy = null;

    /** Proxy selected by {@link #EnableTrain(boolean)}; used by train/eval/infer runs. */
    private ModelProxy curProxy = null;

    /** Data sets keyed by run mode; read by train, eval and infer runs. */
    public Map<RunType, DataSet> dataSets = new HashMap<RunType, DataSet>();

    /** Feature snapshot taken right before the last training run; used by {@link #getDpWeightNorm}. */
    private Map<String, float[]> preFeatures = null;

    /**
     * Creates the callbacks passed to the model proxy's run loop for one run.
     *
     * @param runType run mode the callbacks are created for
     * @param dataSet data set that will be fed to the model
     * @return callbacks handed to {@code ModelProxy.runModel}
     */
    public abstract List<Callback> initCallbacks(RunType runType, DataSet dataSet);

    /**
     * Initializes the data sets for each run mode.
     *
     * @param files per-run-mode lists of strings (presumably data file paths; confirm with implementations)
     * @return an integer per run mode (meaning defined by implementations, e.g. sample count; TODO confirm)
     */
    public abstract Map<RunType, Integer> initDataSets(Map<RunType, List<String>> files);

    /**
     * Computes the evaluation accuracy from the callbacks of a finished eval run.
     *
     * @param callbacks callbacks created by {@link #initCallbacks} for {@link RunType#EVALMODE}
     * @return evaluation accuracy
     */
    public abstract float getEvalAccuracy(List<Callback> callbacks);

    /**
     * Extracts inference results from the callbacks of a finished infer run.
     *
     * @param callbacks callbacks created by {@link #initCallbacks} for {@link RunType#INFERMODE}
     * @return inference results
     */
    public abstract List<Object> getInferResult(List<Callback> callbacks);

    /** Returns true when {@code s} is {@code null} or has length 0. */
    private static boolean isNullOrEmpty(String s) {
        return s == null || s.isEmpty();
    }

    /** Returns the file named {@code prefix + file.getName()} in the same directory as {@code file}. */
    private static File siblingFile(File file, String prefix) {
        // File(File, String) tolerates a null parent (bare file name), unlike "getParent() + '/'".
        return new File(file.getParentFile(), prefix + file.getName());
    }

    /**
     * Copies {@code origFile} to a {@code bak_}-prefixed sibling unless the backup already exists
     * or the original is missing. Failures are logged, not thrown.
     */
    private void backupModelFile(String origFile) {
        File ofile = new File(origFile);
        File dfile = siblingFile(ofile, BACKUP_PREFIX);
        if (dfile.exists() || !ofile.exists()) {
            return;
        }
        try {
            logger.info("Backup model file:" + origFile + " to :" + dfile.getPath());
            Files.copy(ofile.toPath(), dfile.toPath());
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Backup model file failed: " + origFile, e);
        }
    }

    /**
     * Restores {@code fileName} from its {@code bak_}-prefixed sibling, overwriting the current file.
     * Failures are logged, not thrown.
     *
     * @param fileName model file to restore
     */
    public void restoreModelFile(String fileName) {
        File dfile = new File(fileName);
        File bfile = siblingFile(dfile, BACKUP_PREFIX);
        if (!bfile.exists()) {
            logger.severe("Restore failed, backup file:" + bfile.getName() + " not exist.");
            return;
        }
        if (dfile.exists()) {
            logger.info("Overwrite the origin file:" + dfile.getName());
        }
        try {
            logger.info("Restore model file:" + dfile.getName() + " from :" + bfile.getName());
            // REPLACE_EXISTING instead of delete-then-copy: an ignored delete() failure would
            // otherwise make the copy throw FileAlreadyExistsException.
            Files.copy(bfile.toPath(), dfile.toPath(), StandardCopyOption.REPLACE_EXISTING);
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Restore model file failed: " + fileName, e);
        }
    }

    /**
     * Creates and initializes the train and/or infer model proxies from {@code flParameter}.
     * The train model file is backed up before it is loaded.
     *
     * @param flParameter supplies the train/infer model paths and input shapes
     * @return the init status; FAILED when no model path is configured. A failed train-model
     *     init is returned immediately instead of being masked by a later infer-model init.
     */
    public Status initModel(FLParameter flParameter) {
        String trainModelPath = flParameter.getTrainModelPath();
        String inferModelPath = flParameter.getInferModelPath();
        int[][] inputShapes = flParameter.getInputShape();
        Status initRet = Status.FAILED;
        if (!isNullOrEmpty(trainModelPath)) {
            backupModelFile(trainModelPath);
            trainModelProxy = new ModelProxy();
            initRet = trainModelProxy.initModel(trainModelPath, inputShapes);
            if (initRet != Status.SUCCESS) {
                logger.severe("Init train model failed, path:" + trainModelPath);
                return initRet;
            }
        }
        if (isNullOrEmpty(inferModelPath)) {
            if (isNullOrEmpty(trainModelPath)) {
                logger.severe("Neither train model path nor infer model path is provided.");
            }
            return initRet;
        }
        if (inferModelPath.equals(trainModelPath)) {
            // Same file: share a single proxy instead of loading the model twice.
            inferModelProxy = trainModelProxy;
        } else {
            inferModelProxy = new ModelProxy();
            initRet = inferModelProxy.initModel(inferModelPath, inputShapes);
        }
        return initRet;
    }

    /**
     * Selects the train proxy ({@code enableFlg == true}) or the infer proxy as current and
     * exposes its model through {@link #model}.
     *
     * @param enableFlg true to select the train proxy, false for the infer proxy
     * @return false when the requested proxy was not initialized
     */
    public boolean EnableTrain(boolean enableFlg) {
        ModelProxy selected = enableFlg ? trainModelProxy : inferModelProxy;
        if (selected == null) {
            logger.severe("Can't get Model proxy for " + (enableFlg ? "train" : "infer"));
            return false;
        }
        curProxy = selected;
        model = selected.getModel();
        return true;
    }

    /**
     * Trains the current model for {@code epochs} epochs on the {@link RunType#TRAINMODE} data set.
     *
     * @param epochs number of epochs, must be positive
     * @return INVALID for a non-positive epoch count, NULLPTR when no proxy is selected or the train
     *     data set is missing, otherwise the status of the run loop
     */
    public Status trainModel(int epochs) {
        if (epochs <= 0) {
            logger.severe("epochs must be greater than 0, got:" + epochs);
            return Status.INVALID;
        }
        if (curProxy == null || model == null) {
            logger.severe("No model proxy selected, call EnableTrain before training");
            return Status.NULLPTR;
        }
        DataSet trainDataSet = dataSets.get(RunType.TRAINMODE);
        if (trainDataSet == null) {
            logger.severe("not find train dataset");
            return Status.NULLPTR;
        }
        // Snapshot weights right before training so getDpWeightNorm can measure this update.
        preFeatures = curProxy.getFeatureMap();
        trainDataSet.padding();
        List<Callback> trainCallbacks = initCallbacks(RunType.TRAINMODE, trainDataSet);
        model.setTrainMode(true);
        Status status = curProxy.runModel(epochs, trainCallbacks, trainDataSet);
        if (status != Status.SUCCESS) {
            logger.severe("train loop failed");
        }
        return status;
    }

    /**
     * Runs one evaluation epoch on the {@link RunType#EVALMODE} data set.
     *
     * @return the accuracy from {@link #getEvalAccuracy}, or {@link Float#NaN} when no proxy is
     *     selected, the eval data set is missing, or the run fails
     */
    public float evalModel() {
        if (curProxy == null || model == null) {
            logger.severe("No model proxy selected, call EnableTrain before evaluation");
            return Float.NaN;
        }
        model.setTrainMode(false);
        DataSet evalDataSet = dataSets.get(RunType.EVALMODE);
        if (evalDataSet == null) {
            logger.severe("not find eval dataset");
            return Float.NaN;
        }
        evalDataSet.padding();
        List<Callback> evalCallbacks = initCallbacks(RunType.EVALMODE, evalDataSet);
        Status status = curProxy.runModel(1, evalCallbacks, evalDataSet);
        if (status != Status.SUCCESS) {
            logger.severe("eval loop failed");
            return Float.NaN;
        }
        return getEvalAccuracy(evalCallbacks);
    }

    /**
     * Runs one inference epoch on the {@link RunType#INFERMODE} data set.
     *
     * @return results from {@link #getInferResult}, or {@code null} when no proxy is selected,
     *     the infer data set is missing, or the run fails
     */
    public List<Object> inferModel() {
        if (curProxy == null || model == null) {
            logger.severe("No model proxy selected, call EnableTrain before inference");
            return null;
        }
        model.setTrainMode(false);
        DataSet inferDataSet = dataSets.get(RunType.INFERMODE);
        if (inferDataSet == null) {
            logger.severe("not find infer dataset");
            return null;
        }
        inferDataSet.padding();
        List<Callback> inferCallbacks = initCallbacks(RunType.INFERMODE, inferDataSet);
        Status status = curProxy.runModel(1, inferCallbacks, inferDataSet);
        if (status != Status.SUCCESS) {
            logger.severe("infer loop failed");
            return null;
        }
        return getInferResult(inferCallbacks);
    }

    /**
     * Exports the proxy's model to {@code outFile} via a {@code tmp_}-prefixed sibling file.
     *
     * @return true on success or when there is nothing to save (empty path / no model)
     */
    private boolean saveModelbyProxy(ModelProxy proxy, String outFile) {
        if (isNullOrEmpty(outFile) || proxy == null || proxy.getModel() == null) {
            logger.info("Path is empty or no model provided, no need to save model, out path:" + outFile);
            return true;
        }
        File dstFile = new File(outFile);
        File tmpFile = siblingFile(dstFile, TMP_PREFIX);
        // Export to a temp file first; outFile is only replaced after a successful export.
        if (!proxy.getModel().export(tmpFile.getPath(), 0, false, null)) {
            logger.severe("Export model to temp file failed:" + tmpFile.getPath());
            return false;
        }
        try {
            // File.renameTo can fail silently when the target exists on some platforms;
            // Files.move with REPLACE_EXISTING overwrites it and reports the cause.
            Files.move(tmpFile.toPath(), dstFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
            return true;
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Move temp model file to " + outFile + " failed", e);
            return false;
        }
    }

    /**
     * Saves the train and infer models to their configured paths.
     *
     * @param flParameter supplies the train/infer model paths
     * @param localFLParameter currently unused
     * @return SUCCESS when every required export succeeded, otherwise FAILED
     */
    public Status saveModel(FLParameter flParameter, LocalFLParameter localFLParameter) {
        String trainModelPath = flParameter.getTrainModelPath();
        String inferModelPath = flParameter.getInferModelPath();
        if (!saveModelbyProxy(trainModelProxy, trainModelPath)) {
            logger.severe("Save train model to file failed.");
            return Status.FAILED;
        }
        boolean sameTarget = inferModelProxy == trainModelProxy
                && inferModelPath != null && inferModelPath.equals(trainModelPath);
        // A shared proxy with the same path was already written above.
        if (!sameTarget && !saveModelbyProxy(inferModelProxy, inferModelPath)) {
            logger.severe("Save infer model to file failed.");
            return Status.FAILED;
        }
        return Status.SUCCESS;
    }

    /**
     * Computes the L2 norm of the weight change of the given features since the last training run.
     *
     * @param featureList names of the features to include
     * @return sqrt of the sum of squared differences between the pre-train snapshot and current weights
     * @throws IllegalStateException if no training has run, the train model is missing, a feature
     *     cannot be read, or a feature's length changed
     */
    public float getDpWeightNorm(ArrayList<String> featureList) {
        if (preFeatures == null) {
            throw new IllegalStateException("Must call getDpWeightNorm after train.");
        }
        if (trainModelProxy == null) {
            throw new IllegalStateException("Train model is not initialized.");
        }
        // Accumulate in double: a float sum over many weights loses precision.
        double sumOfSquares = 0.0;
        for (String key : featureList) {
            float[] preData = preFeatures.get(key);
            float[] curData = trainModelProxy.getFeature(key);
            if (preData == null || curData == null) {
                throw new IllegalStateException("Get feature value failed, feature name:" + key);
            }
            if (preData.length != curData.length) {
                throw new IllegalStateException("Length of " + key + " is changed after update, origin len:"
                        + preData.length + " cur len:" + curData.length);
            }
            for (int j = 0; j < preData.length; ++j) {
                float diff = preData[j] - curData[j];
                sumOfSquares += (double) diff * diff;
            }
        }
        return (float) Math.sqrt(sumOfSquares);
    }

    /**
     * Returns the current weights of {@code weightName} from the selected proxy.
     *
     * @throws IllegalStateException if no proxy has been selected via {@link #EnableTrain(boolean)}
     */
    public float[] getFeature(String weightName) {
        return requireCurProxy().getFeature(weightName);
    }

    /**
     * Returns the pre-train snapshot of {@code weightName}, or {@code null} when absent or when no
     * training has run yet.
     */
    public float[] getPreFeature(String weightName) {
        return preFeatures == null ? null : preFeatures.get(weightName);
    }

    /**
     * Writes {@code newFeature} into the train ({@code trainFlg == true}) or infer model.
     *
     * @return the proxy's update status, or FAILED when the requested proxy is not initialized
     */
    public Status updateFeature(FeatureMap newFeature, boolean trainFlg) {
        ModelProxy target = trainFlg ? trainModelProxy : inferModelProxy;
        if (target == null) {
            logger.severe("Can't get ModelProxy for " + (trainFlg ? "trainModel" : "inferModel"));
            return Status.FAILED;
        }
        return target.updateFeature(newFeature);
    }

    /** Frees the model proxies (a shared proxy only once) and clears all proxy/model references. */
    public void free() {
        if (trainModelProxy != null) {
            trainModelProxy.free();
        }
        if (inferModelProxy != null && inferModelProxy != trainModelProxy) {
            inferModelProxy.free();
        }
        trainModelProxy = null;
        inferModelProxy = null;
        curProxy = null;
        model = null;
    }

    /**
     * Sets the learning rate of the train model.
     *
     * @return SUCCESS if the train model accepted the rate, otherwise FAILED
     */
    public Status setLearningRate(float lr) {
        Model trainModel = trainModelProxy == null ? null : trainModelProxy.getModel();
        if (trainModel != null && trainModel.setLearningRate(lr)) {
            return Status.SUCCESS;
        }
        logger.severe("set learning rate failed");
        return Status.FAILED;
    }

    /** Sets {@code batchSize} on every registered data set. */
    public void setBatchSize(int batchSize) {
        for (DataSet dataset : dataSets.values()) {
            dataset.batchSize = batchSize;
        }
    }

    /**
     * Returns the upload loss recorded by the selected proxy.
     *
     * @throws IllegalStateException if no proxy has been selected via {@link #EnableTrain(boolean)}
     */
    public float getUploadLoss() {
        return requireCurProxy().getUploadLoss();
    }

    /**
     * Stores the upload loss on the selected proxy.
     *
     * @throws IllegalStateException if no proxy has been selected via {@link #EnableTrain(boolean)}
     */
    public void setUploadLoss(float uploadLoss) {
        requireCurProxy().setUploadLoss(uploadLoss);
    }

    /** Returns the selected proxy, failing with a descriptive message instead of an NPE. */
    private ModelProxy requireCurProxy() {
        if (curProxy == null) {
            throw new IllegalStateException("No model proxy selected, call EnableTrain first.");
        }
        return curProxy;
    }
}

