package com.mindspore.flclient;

import com.mindspore.config.Version;
import com.mindspore.flclient.common.FLLoggerGenerater;
import com.mindspore.flclient.model.Client;
import com.mindspore.flclient.model.ClientManager;
import com.mindspore.flclient.model.RunType;
import com.mindspore.flclient.model.Status;
import com.mindspore.flclient.pki.PkiUtil;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.logging.Logger;

/* loaded from: input_file:com/mindspore/flclient/SyncFLJob.class */
public class SyncFLJob {
    private static Logger LOGGER = FLLoggerGenerater.getModelLogger(SyncFLJob.class.toString());
    private IFLJobResultCallback flJobResultCallback;
    private FLClientStatus curStatus;
    private FLParameter flParameter = FLParameter.getInstance();
    private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
    private int tryTimePerIter = 0;
    private int lastIteration = -1;
    private int waitTryTime = 0;
    private ArrayList<String> msVersions = new ArrayList<>(Arrays.asList("MindSpore Lite 1.8.1", "MindSpore Lite 1.9.0", "MindSpore Lite 2.1.0"));

    private void initFlIDForPkiVerify() {
        if (!this.flParameter.isPkiVerify()) {
            LOGGER.info("pkiVerify mode is not open!");
            this.localFLParameter.setFlID(this.flParameter.getClientID());
            return;
        }
        LOGGER.info("pkiVerify mode is open!");
        String genEquipCertHash = PkiUtil.genEquipCertHash(this.flParameter.getClientID());
        if (genEquipCertHash == null || genEquipCertHash.isEmpty()) {
            LOGGER.severe("equipCertHash is empty, please check your mobile phone, only Huawei phones are supported now.");
            throw new IllegalArgumentException();
        }
        LOGGER.info("flID for pki verify is: " + genEquipCertHash);
        this.localFLParameter.setFlID(genEquipCertHash);
    }

    private void msVersionCheck() {
        String version = Version.version();
        boolean z = false;
        Iterator<String> it = this.msVersions.iterator();
        while (true) {
            if (!it.hasNext()) {
                break;
            } else if (version.equals(it.next())) {
                z = true;
                break;
            }
        }
        if (z) {
            LOGGER.info("Got compatible mindspore lite version:" + version);
        } else {
            if (version.compareTo(this.msVersions.get(0)) <= 0) {
                throw new RuntimeException("Expect mindspore lite version in " + this.msVersions.toString() + ", but got incompatible mindspore lite version:" + version);
            }
            LOGGER.warning("Expect mindspore lite version in " + this.msVersions.toString() + ", but got incompatible mindspore lite version:" + version);
        }
    }

    public SyncFLJob() {
        try {
            LOGGER.info("the flName: " + this.flParameter.getFlName());
            Class.forName(this.flParameter.getFlName());
            msVersionCheck();
        } catch (ClassNotFoundException e) {
            LOGGER.severe("catch ClassNotFoundException error, the set flName does not exist, please check: " + e.getMessage());
            throw new IllegalArgumentException();
        }
    }

    public FLClientStatus flJobRun() {
        this.flJobResultCallback = this.flParameter.getIflJobResultCallback();
        if (LocalFLParameter.ANDROID.equals(this.flParameter.getDeployEnv())) {
            Common.setSecureRandom(Common.getFastSecureRandom());
        } else {
            Common.setSecureRandom(new SecureRandom());
        }
        initFlIDForPkiVerify();
        this.localFLParameter.setMsConfig(0, this.flParameter.getThreadNum(), this.flParameter.getCpuBindMode(), false);
        Client client = ClientManager.getClient(this.flParameter.getFlName());
        if (client.initModel(this.flParameter) != Status.SUCCESS) {
            LOGGER.severe("initModel failed");
            client.free();
            return FLClientStatus.FAILED;
        }
        FLLiteClient fLLiteClient = new FLLiteClient();
        LOGGER.info("recovery StopJobFlag to false in the start of fl job");
        this.localFLParameter.setStopJobFlag(false);
        InitialParameters();
        LOGGER.info("flJobRun start");
        flRunLoop(fLLiteClient);
        if (this.curStatus == FLClientStatus.SUCCESS) {
            client.saveModel(this.flParameter, this.localFLParameter);
        }
        LOGGER.info("flJobRun finish");
        this.flJobResultCallback.onFlJobFinished(this.flParameter.getFlName(), fLLiteClient.getIterations(), fLLiteClient.getRetCode());
        client.free();
        return this.curStatus;
    }

    private void flRunLoop(FLLiteClient fLLiteClient) {
        while (!tryTimeExceedsLimit().booleanValue() && !checkStopJobFlag()) {
            LOGGER.info("flName: " + this.flParameter.getFlName());
            if (!fLLiteClient.initDataSets()) {
                this.curStatus = FLClientStatus.FAILED;
                failed("unsolved error code in <flLiteClient.setInput>: the return trainDataSize<=0, setInput", fLLiteClient);
                return;
            }
            this.curStatus = fLLiteClient.startFLJob();
            if (this.curStatus == FLClientStatus.RESTART) {
                this.tryTimePerIter++;
                resetContext("[startFLJob]", fLLiteClient.getNextRequestTime(), fLLiteClient);
            } else {
                if (this.curStatus != FLClientStatus.SUCCESS) {
                    failed("[startFLJob]", fLLiteClient);
                    return;
                }
                LOGGER.info("[startFLJob] startFLJob succeed, curIteration: " + fLLiteClient.getIteration());
                updateTryTimePerIter(fLLiteClient);
                if (checkEvalPath()) {
                    this.curStatus = fLLiteClient.evaluateModel();
                    if (this.curStatus != FLClientStatus.SUCCESS) {
                        failed("[evaluate] evaluate", fLLiteClient);
                        return;
                    }
                    LOGGER.info("[evaluate] evaluate succeed");
                } else {
                    LOGGER.info("[evaluate] the data map set by user do not contain evaluation dataset, don't evaluate the model after getting model from server");
                }
                ClientManager.getClient(this.flParameter.getFlName()).EnableTrain(true);
                this.curStatus = fLLiteClient.getFeatureMask();
                if (this.curStatus == FLClientStatus.RESTART) {
                    resetContext("[Encrypt] creatMask", fLLiteClient.getNextRequestTime(), fLLiteClient);
                } else {
                    if (this.curStatus != FLClientStatus.SUCCESS) {
                        failed("[Encrypt] createMask", fLLiteClient);
                        return;
                    }
                    this.curStatus = fLLiteClient.localTrain();
                    if (this.curStatus != FLClientStatus.SUCCESS) {
                        failed("[train] train", fLLiteClient);
                        return;
                    }
                    LOGGER.info("[train] train succeed");
                    this.curStatus = fLLiteClient.updateModel();
                    if (this.curStatus == FLClientStatus.RESTART) {
                        resetContext("[updateModel]", fLLiteClient.getNextRequestTime(), fLLiteClient);
                    } else {
                        if (this.curStatus != FLClientStatus.SUCCESS) {
                            failed("[updateModel] updateModel", fLLiteClient);
                            return;
                        }
                        this.curStatus = fLLiteClient.unMasking();
                        if (this.curStatus == FLClientStatus.RESTART) {
                            resetContext("[Encrypt] unmasking", fLLiteClient.getNextRequestTime(), fLLiteClient);
                        } else {
                            if (this.curStatus != FLClientStatus.SUCCESS) {
                                failed("[Encrypt] unmasking", fLLiteClient);
                                return;
                            }
                            this.curStatus = getResult(fLLiteClient);
                            if (this.curStatus == FLClientStatus.RESTART) {
                                resetContext("[getResult]", fLLiteClient.getNextRequestTime(), fLLiteClient);
                            } else {
                                if (this.curStatus != FLClientStatus.SUCCESS) {
                                    failed("[getResult] getResult", fLLiteClient);
                                    return;
                                }
                                if (this.localFLParameter.getEncryptLevel() == EncryptLevel.DP_ENCRYPT) {
                                    this.curStatus = getModel(fLLiteClient);
                                    if (this.curStatus == FLClientStatus.RESTART) {
                                        resetContext("[getModel]", fLLiteClient.getNextRequestTime(), fLLiteClient);
                                    } else {
                                        if (this.curStatus != FLClientStatus.SUCCESS) {
                                            failed("[getModel] getModel", fLLiteClient);
                                            return;
                                        }
                                        fLLiteClient.updateDpNormClip();
                                    }
                                }
                                LOGGER.info("========================================================the total response of " + fLLiteClient.getIteration() + ": " + this.curStatus + "======================================================================");
                                this.flJobResultCallback.onFlJobIterationFinished(this.flParameter.getFlName(), fLLiteClient.getIteration(), fLLiteClient.getRetCode());
                                this.tryTimePerIter = 0;
                            }
                        }
                    }
                }
            }
            if (fLLiteClient.getIteration() >= fLLiteClient.getIterations()) {
                return;
            }
        }
    }

    private void InitialParameters() {
        this.tryTimePerIter = 0;
        this.lastIteration = -1;
        this.waitTryTime = 0;
    }

    private Boolean tryTimeExceedsLimit() {
        if (this.tryTimePerIter <= 1) {
            return false;
        }
        LOGGER.severe("[tryTimeExceedsLimit] the repeated request time exceeds the limit, current repeated request time is: " + this.tryTimePerIter + " the limited time is: 1");
        this.curStatus = FLClientStatus.FAILED;
        return true;
    }

    private void updateTryTimePerIter(FLLiteClient fLLiteClient) {
        if (this.lastIteration != -1 && this.lastIteration == fLLiteClient.getIteration()) {
            this.tryTimePerIter++;
        } else {
            this.tryTimePerIter = 1;
            this.lastIteration = fLLiteClient.getIteration();
        }
    }

    private Boolean waitTryTimeExceedsLimit() {
        if (this.waitTryTime <= 18) {
            return false;
        }
        LOGGER.severe("[waitTryTimeExceedsLimit] the waitTryTime exceeds the limit, current waitTryTime is: " + this.waitTryTime + " the limited time is: 18");
        this.curStatus = FLClientStatus.FAILED;
        return true;
    }

    private FLClientStatus getResult(FLLiteClient fLLiteClient) {
        FLClientStatus result = fLLiteClient.getResult();
        this.waitTryTime = 0;
        while (true) {
            if (result != FLClientStatus.WAIT) {
                break;
            }
            this.waitTryTime++;
            if (waitTryTimeExceedsLimit().booleanValue()) {
                result = FLClientStatus.FAILED;
                break;
            }
            if (checkStopJobFlag()) {
                result = FLClientStatus.FAILED;
                break;
            }
            waitSomeTime();
            result = fLLiteClient.getResult();
        }
        return result;
    }

    private FLClientStatus getModel(FLLiteClient fLLiteClient) {
        FLClientStatus model = fLLiteClient.getModel();
        this.waitTryTime = 0;
        while (true) {
            if (model != FLClientStatus.WAIT) {
                break;
            }
            this.waitTryTime++;
            if (waitTryTimeExceedsLimit().booleanValue()) {
                model = FLClientStatus.FAILED;
                break;
            }
            if (checkStopJobFlag()) {
                model = FLClientStatus.FAILED;
                break;
            }
            waitSomeTime();
            model = fLLiteClient.getModel();
        }
        return model;
    }

    private boolean checkEvalPath() {
        if (this.flParameter.getDataMap().containsKey(RunType.EVALMODE)) {
            return true;
        }
        LOGGER.info("[evaluate] the data map set by user do not contain evaluation dataset, don't evaluate the model after getting model from server");
        return false;
    }

    private boolean checkStopJobFlag() {
        if (!this.localFLParameter.isStopJobFlag()) {
            return false;
        }
        LOGGER.info("the stopJObFlag is set to true, the job will be stop");
        this.curStatus = FLClientStatus.FAILED;
        return true;
    }

    public List<Object> modelInfer() {
        Client client = ClientManager.getClient(this.flParameter.getFlName());
        this.localFLParameter.setMsConfig(0, this.flParameter.getThreadNum(), this.flParameter.getCpuBindMode(), false);
        this.localFLParameter.setStopJobFlag(false);
        if (null != this.flParameter.getInputShape()) {
            LOGGER.info("[model inference] the inference model has dynamic input.");
        }
        if (client.initDataSets(this.flParameter.getDataMap()).isEmpty()) {
            LOGGER.severe("[model inference] initDataSets failed, please check");
            client.free();
            return null;
        }
        if (client.initModel(this.flParameter) != Status.SUCCESS) {
            LOGGER.severe("initModel failed");
            return null;
        }
        if (!client.EnableTrain(false)) {
            LOGGER.severe("[model inference] call EnableTrain failed");
            client.free();
            return null;
        }
        client.setBatchSize(this.flParameter.getBatchSize());
        LOGGER.info("===========model inference=============");
        List<Object> inferModel = client.inferModel();
        if (inferModel == null || inferModel.size() == 0) {
            LOGGER.severe("[model inference] the returned label from client.inferModel() is null, please check");
            client.free();
            return null;
        }
        LOGGER.fine("[model inference] the predicted outputs: " + Arrays.deepToString(inferModel.toArray()));
        client.free();
        LOGGER.info("[model inference] inference finish");
        return inferModel;
    }

    public FLClientStatus getModel() {
        if (LocalFLParameter.ANDROID.equals(this.flParameter.getDeployEnv())) {
            Common.setSecureRandom(Common.getFastSecureRandom());
        } else {
            Common.setSecureRandom(new SecureRandom());
        }
        this.localFLParameter.setServerMod(this.flParameter.getServerMod().toString());
        this.localFLParameter.setMsConfig(0, 1, 0, false);
        Client client = ClientManager.getClient(this.flParameter.getFlName());
        if (client.initModel(this.flParameter) != Status.SUCCESS) {
            LOGGER.severe("initModel failed");
            client.free();
            return null;
        }
        FLClientStatus model = new FLLiteClient().getModel();
        if (model == FLClientStatus.SUCCESS) {
            client.saveModel(this.flParameter, this.localFLParameter);
        }
        client.free();
        return model;
    }

    public void stopFLJob() {
        LOGGER.info("will stop the flJob");
        this.localFLParameter.setStopJobFlag(true);
        Common.notifyObject();
    }

    private void waitSomeTime() {
        if (this.flParameter.getSleepTime() != 0) {
            Common.sleep(this.flParameter.getSleepTime());
        } else {
            Common.sleep(10000L);
        }
    }

    private void waitNextReqTime(String str) {
        Common.sleep(Common.getWaitTime(str));
    }

    private void resetContext(String str, String str2, FLLiteClient fLLiteClient) {
        LOGGER.info(str + " out of time: need wait and request startFLJob again");
        waitNextReqTime(str2);
        this.flJobResultCallback.onFlJobIterationFinished(this.flParameter.getFlName(), fLLiteClient.getIteration(), fLLiteClient.getRetCode());
    }

    private void failed(String str, FLLiteClient fLLiteClient) {
        LOGGER.info(str + " failed");
        LOGGER.info("=========================================the total response of " + fLLiteClient.getIteration() + ": " + this.curStatus + "=========================================");
        this.flJobResultCallback.onFlJobIterationFinished(this.flParameter.getFlName(), fLLiteClient.getIteration(), fLLiteClient.getRetCode());
    }

    private static Map<RunType, List<String>> createDatasetMap(String str, String str2, String str3, String str4) {
        HashMap hashMap = new HashMap();
        if (str == null || "null".equals(str) || str.isEmpty()) {
            LOGGER.info("the trainDataPath is null or empty, please check if you are in the case of only inference");
        } else {
            hashMap.put(RunType.TRAINMODE, Arrays.asList(str.split(str4)));
            LOGGER.info("the trainDataPath: " + Arrays.toString(str.split(str4)));
        }
        if (str2 == null || "null".equals(str2) || str2.isEmpty()) {
            LOGGER.info("the evalDataPath is null or empty, please check if you are in the case of only training without evaluation");
        } else {
            hashMap.put(RunType.EVALMODE, Arrays.asList(str2.split(str4)));
            LOGGER.info("the evalDataPath: " + Arrays.toString(str2.split(str4)));
        }
        if (str3 == null || "null".equals(str3) || str3.isEmpty()) {
            LOGGER.info("the inferDataPath is null or empty, please check if you are in the case of training without inference");
        } else {
            hashMap.put(RunType.INFERMODE, Arrays.asList(str3.split(str4)));
            LOGGER.info("the inferDataPath: " + Arrays.toString(str3.split(str4)));
        }
        return hashMap;
    }

    private static void createWeightNameList(String str, String str2, String str3, FLParameter fLParameter) {
        if (str == null || "null".equals(str) || str.isEmpty()) {
            LOGGER.info("the trainWeightName is null or empty");
        } else {
            fLParameter.setHybridWeightName(Arrays.asList(str.split(str3)), RunType.TRAINMODE);
            LOGGER.info("the trainWeightName: " + Arrays.toString(str.split(str3)));
        }
        if (str2 == null || "null".equals(str2) || str2.isEmpty()) {
            LOGGER.info("the inferWeightName is null or empty");
        } else {
            fLParameter.setHybridWeightName(Arrays.asList(str2.split(str3)), RunType.INFERMODE);
            LOGGER.info("the inferWeightName: " + Arrays.toString(str2.split(str3)));
        }
    }

    /* JADX WARN: Type inference failed for: r0v5, types: [int[], int[][]] */
    private static int[][] getInputShapeArray(String str) {
        String[] split = str.split(";");
        int length = split.length;
        ?? r0 = new int[length];
        for (int i = 0; i < length; i++) {
            r0[i] = Arrays.stream(split[i].split(",")).mapToInt(Integer::parseInt).toArray();
        }
        return r0;
    }

    private static void task(String[] strArr) {
        String str = strArr[0];
        String str2 = strArr[1];
        String str3 = strArr[2];
        String str4 = strArr[3];
        String str5 = strArr[4];
        String str6 = strArr[5];
        String str7 = strArr[6];
        String str8 = strArr[7];
        String str9 = strArr[8];
        String str10 = strArr[9];
        String str11 = strArr[10];
        boolean parseBoolean = Boolean.parseBoolean(strArr[11]);
        int parseInt = Integer.parseInt(strArr[12]);
        String str12 = strArr[13];
        int parseInt2 = Integer.parseInt(strArr[14]);
        String str13 = strArr[15];
        String str14 = strArr[16];
        String str15 = strArr[17];
        String str16 = strArr[18];
        String str17 = strArr[19];
        String str18 = strArr[21];
        int parseInt3 = Integer.parseInt(strArr[20]);
        FLParameter fLParameter = FLParameter.getInstance();
        if (!"null".equals(str18) && str18 != null) {
            fLParameter.setInputShape(getInputShapeArray(str18));
        }
        Map<RunType, List<String>> createDatasetMap = createDatasetMap(str, str2, str3, str4);
        createWeightNameList(str14, str15, str16, fLParameter);
        fLParameter.setFlName(str5);
        SyncFLJob syncFLJob = new SyncFLJob();
        boolean z = -1;
        switch (str12.hashCode()) {
            case 110621192:
                if (str12.equals("train")) {
                    z = false;
                    break;
                }
                break;
            case 962456601:
                if (str12.equals("inference")) {
                    z = true;
                    break;
                }
                break;
            case 1959895411:
                if (str12.equals("getModel")) {
                    z = 2;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                LOGGER.info("start syncFLJob.flJobRun()");
                fLParameter.setDataMap(createDatasetMap);
                fLParameter.setTrainModelPath(str6);
                fLParameter.setInferModelPath(str7);
                fLParameter.setSslProtocol(str8);
                fLParameter.setDeployEnv(str9);
                fLParameter.setDomainName(str10);
                if (Common.isHttps()) {
                    fLParameter.setCertPath(str11);
                }
                fLParameter.setUseElb(parseBoolean);
                fLParameter.setServerNum(parseInt);
                fLParameter.setThreadNum(parseInt2);
                fLParameter.setCpuBindMode(BindMode.valueOf(str13));
                fLParameter.setBatchSize(parseInt3);
                syncFLJob.flJobRun();
                return;
            case true:
                LOGGER.info("start syncFLJob.modelInference()");
                fLParameter.setDataMap(createDatasetMap);
                fLParameter.setInferModelPath(str7);
                fLParameter.setThreadNum(parseInt2);
                fLParameter.setCpuBindMode(BindMode.valueOf(str13));
                fLParameter.setBatchSize(parseInt3);
                syncFLJob.modelInfer();
                return;
            case true:
                LOGGER.info("start syncFLJob.getModel()");
                fLParameter.setTrainModelPath(str6);
                fLParameter.setInferModelPath(str7);
                fLParameter.setSslProtocol(str8);
                fLParameter.setDeployEnv(str9);
                fLParameter.setDomainName(str10);
                if (Common.isHttps()) {
                    fLParameter.setCertPath(str11);
                }
                fLParameter.setUseElb(parseBoolean);
                fLParameter.setServerNum(parseInt);
                fLParameter.setServerMod(ServerMod.valueOf(str17));
                syncFLJob.getModel();
                return;
            default:
                LOGGER.info("do not do any thing!");
                return;
        }
    }

    public static void main(String[] strArr) {
        if (strArr[4] == null || strArr[4].isEmpty()) {
            LOGGER.severe("the parameter of <args[4]> is null, please check");
            throw new IllegalArgumentException();
        }
        task(strArr);
    }
}
