package com.mindspore.flclient;

import com.mindspore.flclient.common.FLLoggerGenerater;
import com.mindspore.flclient.model.Client;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/* loaded from: input_file:com/mindspore/flclient/SecureProtocol.class */
/**
 * Client-side helper that bundles three privacy-related routines used during federated learning:
 * <ul>
 *   <li>pairwise masking of weights through a {@code CipherClient} ({@code pw*} methods),</li>
 *   <li>noise-scale calibration from {@code dpEps}, {@code dpDelta} and {@code dpNormClip}
 *       ({@link #calculateSigma()}),</li>
 *   <li>random selection of update dimensions, tagged "SignDS" in the logs
 *       ({@link #signDSModel(Client, boolean)}).</li>
 * </ul>
 *
 * <p>Instances hold mutable state and perform no synchronization.
 */
public class SecureProtocol {
    private static final Logger LOGGER = FLLoggerGenerater.getModelLogger(SecureProtocol.class.toString());

    /** Tolerance at which the bisection searches of the noise calibration stop. */
    private static final double DELTA_ERROR = 1.0E-6d;

    /**
     * Number of trapezoid segments used by {@link #calculateErf(double)}. The step width and the loop
     * bound must use the same count, so it lives in one dedicated constant.
     */
    private static final int ERF_SEGMENT_NUM = 10000;

    /** Upper bound used by the bisection loops' iteration counter. */
    private static final int MAX_BISECTION_COUNT = 1000;

    // Not referenced anywhere in this class; retained unchanged.
    private static Map<String, float[]> modelMap;

    private int iteration;
    private CipherClient cipherClient;
    private FLClientStatus status;
    private double dpEps;
    private double dpDelta;
    private double dpNormClip;
    private int retCode;
    private float signK;
    private float signEps;
    private float signThrRatio;
    private float signGlobalLr;
    private int signDimOut;
    private FLParameter flParameter = FLParameter.getInstance();
    private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
    private float[] featureMask = new float[0];
    private ArrayList<String> updateFeatureName = new ArrayList<>();

    /**
     * @return the status recorded by the most recent {@code CipherClient} call made through this
     *     object; {@code null} if none has been made yet
     */
    public FLClientStatus getStatus() {
        return this.status;
    }

    /**
     * @return the return code last copied from the {@code CipherClient}
     */
    public int getRetCode() {
        return this.retCode;
    }

    /**
     * Stores the iteration number and creates the {@code CipherClient} used by the pairwise-mask
     * methods.
     *
     * @param iter iteration number, stored and forwarded to the {@code CipherClient}
     * @param minSecretNum second {@code CipherClient} constructor argument (meaning assumed from the
     *     name used here — TODO confirm against {@code CipherClient})
     * @param prime prime bytes; must be non-null and non-empty
     * @param featureSize fourth {@code CipherClient} constructor argument (meaning assumed — TODO
     *     confirm against {@code CipherClient})
     * @throws IllegalArgumentException if {@code prime} is {@code null} or empty
     */
    public void setPWParameter(int iter, int minSecretNum, byte[] prime, int featureSize) {
        if (prime == null || prime.length == 0) {
            LOGGER.severe("[PairWiseMask] the input argument <prime> is null, please check!");
            throw new IllegalArgumentException();
        }
        this.iteration = iter;
        this.cipherClient = new CipherClient(this.iteration, minSecretNum, prime, featureSize);
    }

    /**
     * Stores the parameters later consumed by {@link #calculateSigma()}.
     *
     * @param iter iteration number
     * @param diffEps value stored as {@code dpEps}
     * @param diffDelta value stored as {@code dpDelta}
     * @param diffNorm value stored as {@code dpNormClip}
     * @return always {@code FLClientStatus.SUCCESS}
     */
    public FLClientStatus setDPParameter(int iter, double diffEps, double diffDelta, double diffNorm) {
        this.iteration = iter;
        this.dpEps = diffEps;
        this.dpDelta = diffDelta;
        this.dpNormClip = diffNorm;
        return FLClientStatus.SUCCESS;
    }

    /**
     * Stores the parameters later consumed by {@link #signDSModel(Client, boolean)}.
     *
     * @param k fraction of all dimensions treated as the top-k set
     * @param eps exponent applied (via {@code Math.exp}) to the weight of outcomes at or above the
     *     threshold
     * @param thrRatio ratio used when the output dimension has to be searched for
     * @param globalLr stored as {@code signGlobalLr}; not read by this class
     * @param dimOut number of dimensions to output; {@code 0} means "search for it on first use"
     * @return always {@code FLClientStatus.SUCCESS}
     */
    public FLClientStatus setDSParameter(float k, float eps, float thrRatio, float globalLr, int dimOut) {
        this.signK = k;
        this.signEps = eps;
        this.signThrRatio = thrRatio;
        this.signGlobalLr = globalLr;
        this.signDimOut = dimOut;
        return FLClientStatus.SUCCESS;
    }

    /**
     * @return the live (not copied) list of feature names used by {@link #signDSModel(Client, boolean)}
     */
    public ArrayList<String> getUpdateFeatureName() {
        return this.updateFeatureName;
    }

    /**
     * @param updateFeatureName list of feature names; stored by reference
     */
    public void setUpdateFeatureName(ArrayList<String> updateFeatureName) {
        this.updateFeatureName = updateFeatureName;
    }

    /**
     * @return the next request time reported by the {@code CipherClient}; requires a prior call to
     *     {@link #setPWParameter(int, int, byte[], int)}
     */
    public String getNextRequestTime() {
        return this.cipherClient.getNextRequestTime();
    }

    /**
     * @return the value stored as {@code dpNormClip}
     */
    public double getDpNormClip() {
        return this.dpNormClip;
    }

    /**
     * Logs and reports whether the stop-job flag of the local parameters is set.
     */
    private boolean isJobStopped() {
        if (this.localFLParameter.isStopJobFlag()) {
            LOGGER.info("the stopJObFlag is set to true, the job will be stop");
            return true;
        }
        return false;
    }

    /**
     * Runs the key exchange and secret sharing steps on the {@code CipherClient} and then fetches
     * the feature mask. The stop-job flag is checked before each step; when it is set the current
     * status is returned without doing further work.
     *
     * @return the status of the last executed step, or {@code FLClientStatus.FAILED} when the
     *     returned mask is {@code null} or empty
     */
    public FLClientStatus pwCreateMask() {
        LOGGER.info(String.format("[PairWiseMask] ==============request flID: %s ==============", this.localFLParameter.getFlID()));
        if (isJobStopped()) {
            return this.status;
        }

        this.status = this.cipherClient.exchangeKeys();
        this.retCode = this.cipherClient.getRetCode();
        // The format string must carry the status itself; a surplus argument would be ignored by String.format.
        LOGGER.info(String.format("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: %s ============", this.status));
        if (this.status != FLClientStatus.SUCCESS) {
            return this.status;
        }
        if (isJobStopped()) {
            return this.status;
        }

        this.status = this.cipherClient.shareSecrets();
        this.retCode = this.cipherClient.getRetCode();
        LOGGER.info(String.format("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: %s =============", this.status));
        if (this.status != FLClientStatus.SUCCESS) {
            return this.status;
        }
        if (isJobStopped()) {
            return this.status;
        }

        this.featureMask = this.cipherClient.doubleMaskingWeight();
        if (this.featureMask == null || this.featureMask.length <= 0) {
            LOGGER.severe("[Encrypt] the returned featureMask from cipherClient.doubleMaskingWeight is null, please check!");
            return FLClientStatus.FAILED;
        }
        this.retCode = this.cipherClient.getRetCode();
        LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS=============");
        return this.status;
    }

    /**
     * Computes {@code data[j] * scale + featureMask[maskOffset + j]} for every element of
     * {@code data}.
     *
     * @param scale multiplier applied to each data element before the mask is added
     * @param data values to mask; not modified
     * @param maskOffset index in the feature mask that lines up with {@code data[0]}
     * @return a new array of the same length as {@code data}
     * @throws IllegalArgumentException if {@code data} is {@code null} or {@code maskOffset} is negative
     * @throws RuntimeException if no feature mask is available or it is too short for the request
     */
    public float[] pwMaskWeight(int scale, float[] data, int maskOffset) {
        if (this.featureMask == null || this.featureMask.length == 0) {
            throw new RuntimeException("[pwMaskWeight] feature mask is null, please check");
        }
        if (data == null) {
            throw new IllegalArgumentException("[pwMaskWeight] the input data is null, please check");
        }
        if (maskOffset < 0) {
            throw new IllegalArgumentException("[pwMaskWeight] the mask offset is negative: " + maskOffset);
        }
        // Compare via subtraction: maskOffset + data.length could overflow int and slip past the check.
        if (data.length > this.featureMask.length - maskOffset) {
            throw new RuntimeException("[pwMaskWeight] the data length is out of range for array featureMask, featureMask length:" + this.featureMask.length + " data length:" + data.length);
        }
        LOGGER.info(String.format("[pwMaskWeight] feature mask size: %s", this.featureMask.length));
        LOGGER.info(String.format("[pwMaskWeight] feature  size: %s", data.length));
        float[] masked = new float[data.length];
        for (int j = 0; j < data.length; j++) {
            masked[j] = (data[j] * scale) + this.featureMask[maskOffset + j];
        }
        return masked;
    }

    /**
     * Asks the {@code CipherClient} to reconstruct secrets and records its status and return code.
     *
     * @return the status returned by {@code CipherClient.reconstructSecrets()}
     */
    public FLClientStatus pwUnmasking() {
        this.status = this.cipherClient.reconstructSecrets();
        this.retCode = this.cipherClient.getRetCode();
        LOGGER.info(String.format("[Encrypt] =============GetClientList+SendReconstructSecret: %s =============", this.status));
        return this.status;
    }

    /**
     * Approximates {@code erf(x) = 2/sqrt(pi) * integral_0^x exp(-t^2) dt} with the composite
     * trapezoid rule over {@link #ERF_SEGMENT_NUM} segments.
     *
     * @param erfInput upper integration limit {@code x}
     * @return the approximation, narrowed to {@code float}
     */
    private static float calculateErf(double erfInput) {
        double step = erfInput / ERF_SEGMENT_NUM;
        // exp(-0^2) == 1 is the left end point; interior points count twice.
        double sum = 0.0d + 1.0d;
        for (int j = 1; j < ERF_SEGMENT_NUM; j++) {
            sum += 2.0d * Math.exp(-Math.pow(step * j, 2.0d));
        }
        // Right end point.
        sum += Math.exp(-Math.pow(step * ERF_SEGMENT_NUM, 2.0d));
        return (float) ((sum * step) / Math.pow(Math.PI, 0.5d));
    }

    /**
     * Standard normal CDF expressed through {@link #calculateErf(double)}.
     */
    private static double calculatePhi(double t) {
        return 0.5d * (1.0d + calculateErf(t / Math.sqrt(2.0d)));
    }

    /**
     * Evaluates {@code Phi(sqrt(eps*s)) - exp(eps) * Phi(-sqrt(eps*(s+2)))}.
     */
    private static double calculateBPositive(double eps, double s) {
        return calculatePhi(Math.sqrt(eps * s)) - (Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0d))));
    }

    /**
     * Evaluates {@code Phi(-sqrt(eps*s)) - exp(eps) * Phi(-sqrt(eps*(s+2)))}.
     */
    private static double calculateBNegative(double eps, double s) {
        return calculatePhi(-Math.sqrt(eps * s)) - (Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0d))));
    }

    /**
     * Searches for {@code s} such that {@link #calculateBPositive(double, double)} is just below
     * {@code targetDelta}: the bracket is doubled while the function value at its upper end does not
     * exceed the target, then bisected.
     *
     * @param eps first argument passed to {@code calculateBPositive}
     * @param targetDelta value to approach from below
     * @param sInf initial lower end of the bracket
     * @param sSup initial upper end of the bracket
     * @return the last midpoint examined
     */
    private static double calculateSPositive(double eps, double targetDelta, double sInf, double sSup) {
        double lower = sInf;
        double upper = sSup;
        double valueAtUpper = calculateBPositive(eps, upper);
        while (valueAtUpper <= targetDelta) {
            lower = upper;
            upper = 2.0d * lower;
            valueAtUpper = calculateBPositive(eps, upper);
        }
        double mid = lower + ((upper - lower) / 2.0d);
        int count = 0;
        do {
            double valueAtMid = calculateBPositive(eps, mid);
            if (valueAtMid > targetDelta) {
                upper = mid;
            } else {
                if (targetDelta - valueAtMid <= DELTA_ERROR) {
                    // Close enough from below: keep the current midpoint.
                    break;
                }
                lower = mid;
            }
            mid = lower + ((upper - lower) / 2.0d);
            count++;
        } while (count <= MAX_BISECTION_COUNT);
        return mid;
    }

    /**
     * Counterpart of {@link #calculateSPositive(double, double, double, double)} for
     * {@link #calculateBNegative(double, double)}: the bracket is doubled while the function value
     * at its upper end exceeds the target, then bisected with the branch directions swapped.
     *
     * @param eps first argument passed to {@code calculateBNegative}
     * @param targetDelta value to approach from below
     * @param sInf initial lower end of the bracket
     * @param sSup initial upper end of the bracket
     * @return the last midpoint examined
     */
    private static double calculateSNegative(double eps, double targetDelta, double sInf, double sSup) {
        double lower = sInf;
        double upper = sSup;
        double valueAtUpper = calculateBNegative(eps, upper);
        while (valueAtUpper > targetDelta) {
            lower = upper;
            upper = 2.0d * lower;
            valueAtUpper = calculateBNegative(eps, upper);
        }
        double mid = lower + ((upper - lower) / 2.0d);
        int count = 0;
        do {
            double valueAtMid = calculateBNegative(eps, mid);
            if (valueAtMid > targetDelta) {
                lower = mid;
            } else {
                if (targetDelta - valueAtMid <= DELTA_ERROR) {
                    // Close enough from below: keep the current midpoint.
                    break;
                }
                upper = mid;
            }
            mid = lower + ((upper - lower) / 2.0d);
            count++;
        } while (count <= MAX_BISECTION_COUNT);
        return mid;
    }

    /**
     * Derives the noise scale from {@code dpEps}, {@code dpDelta} and {@code dpNormClip}.
     *
     * <p>NOTE(review): the structure looks like the analytic Gaussian mechanism calibration
     * (Balle and Wang, 2018) — confirm before relying on that correspondence.
     *
     * @return {@code alpha * dpNormClip / sqrt(2 * dpEps)}, where {@code alpha} is 1 when
     *     {@code dpDelta} equals the value of {@code calculateBPositive(dpEps, 0)} and is otherwise
     *     obtained from one of the two bisection searches
     */
    public double calculateSigma() {
        double deltaZero = calculateBPositive(this.dpEps, 0.0d);
        double alpha = 1.0d;
        if (this.dpDelta > deltaZero) {
            double s = calculateSPositive(this.dpEps, this.dpDelta, 0.0d, 1.0d);
            alpha = Math.sqrt(1.0d + (s / 2.0d)) - Math.sqrt(s / 2.0d);
        } else if (this.dpDelta < deltaZero) {
            double s = calculateSNegative(this.dpEps, this.dpDelta, 0.0d, 1.0d);
            alpha = Math.sqrt(1.0d + (s / 2.0d)) + Math.sqrt(s / 2.0d);
        } else {
            LOGGER.info("[Encrypt] targetDelta = deltaZero");
        }
        return (alpha * this.dpNormClip) / Math.sqrt(2.0d * this.dpEps);
    }

    /**
     * Binomial coefficient "n choose k" computed in floating point.
     *
     * @return the coefficient, or {@code 0} unless {@code 0 <= k <= n}
     */
    private static double comb(double n, double k) {
        if (!(k <= n && n >= 0.0d && k >= 0.0d)) {
            return 0.0d;
        }
        double nPlusOne = n + 1.0d;
        // Use the smaller of k and n-k to keep the product short; the bound is loop-invariant.
        double limit = Math.min(k, n - k);
        double result = 1.0d;
        for (int j = 1; j <= limit; j++) {
            result = (result * (nPlusOne - j)) / j;
        }
        return result;
    }

    /**
     * Number of ways to choose {@code outputDim} of {@code inputDim} dimensions so that exactly
     * {@code numInters} of them come from the {@code topkDim} top-k dimensions.
     */
    private static double countCombs(int numInters, int topkDim, int inputDim, int outputDim) {
        return comb(topkDim, numInters) * comb(inputDim - topkDim, outputDim - numInters);
    }

    /**
     * Builds the normalized probability mass function over the intersection count
     * {@code 0..outputDim}; counts at or above {@code thr} are weighted by {@code exp(eps)}.
     *
     * @return the normalized masses, or an empty list when the total mass is exactly zero
     */
    private static List<Double> calcPmf(int thr, int topkDim, int inputDim, int outputDim, float eps) {
        List<Double> pmf = new ArrayList<>();
        double total = 0.0d;
        for (int numInters = 0; numInters <= outputDim; numInters++) {
            double mass = countCombs(numInters, topkDim, inputDim, outputDim);
            if (numInters >= thr) {
                mass = mass * Math.exp(eps);
            }
            pmf.add(mass);
        }
        for (int j = 0; j < pmf.size(); j++) {
            total += pmf.get(j);
        }
        if (total == 0.0d) {
            LOGGER.severe("[SignDS] probability mass function is 0, please check");
            return new ArrayList<>();
        }
        for (int j = 0; j < pmf.size(); j++) {
            pmf.set(j, pmf.get(j) / total);
        }
        return pmf;
    }

    /**
     * Expected index under the given mass function: {@code sum(j * pmf[j])}.
     */
    private static double calcExpectation(List<Double> pmf) {
        double expectation = 0.0d;
        for (int j = 0; j < pmf.size(); j++) {
            expectation += j * pmf.get(j).doubleValue();
        }
        return expectation;
    }

    /**
     * Scans thresholds {@code 1..outputDim} in order and stops at the first one that does not
     * strictly increase the expected intersection count.
     *
     * @return the last threshold that improved the expectation, at least {@code 1}
     */
    private static int calcOptThr(int topkDim, int inputDim, int outputDim, float eps) {
        double bestExpectation = 0.0d;
        int bestThr = 0;
        for (int thr = 1; thr <= outputDim; thr++) {
            double expectation = calcExpectation(calcPmf(thr, topkDim, inputDim, outputDim, eps));
            if (expectation <= bestExpectation) {
                break;
            }
            bestExpectation = expectation;
            bestThr = thr;
        }
        return Math.max(bestThr, 1);
    }

    /**
     * Finds the largest output dimension, counting up from 1, for which the expected intersection
     * count divided by the output dimension is still at least {@code thrInterRatio}.
     *
     * <p>The search never goes beyond {@code inputDim}: more dimensions than exist cannot be
     * selected, and without this bound a non-positive {@code thrInterRatio} would never terminate.
     *
     * @return the output dimension found, at least {@code 1}
     */
    private static int findOptOutputDim(float thrInterRatio, int topkDim, int inputDim, float eps) {
        int outputDim = 1;
        while (outputDim <= inputDim) {
            int thr = calcOptThr(topkDim, inputDim, outputDim, eps);
            double ratio = calcExpectation(calcPmf(thr, topkDim, inputDim, outputDim, eps)) / outputDim;
            if (ratio < thrInterRatio || Double.isNaN(ratio)) {
                break;
            }
            outputDim++;
        }
        return Math.max(1, outputDim - 1);
    }

    /**
     * Draws the number of top-k dimensions to output by inverse-CDF sampling: a uniform random
     * number is compared against the running sum of the (weighted) masses divided by
     * {@code denominator}.
     *
     * <p>The result is capped at {@code min(topkDim, outputDim)}. Every mass beyond that count is
     * zero, so if rounding (or a non-finite denominator) keeps the running sum below the random
     * draw, the loop would otherwise never end.
     *
     * @return the sampled intersection count
     */
    private static int countInters(int thr, double denominator, int topkDim, int inputDim, int outputDim, float eps) {
        double randomDraw = new SecureRandom().nextDouble();
        int maxInters = Math.min(topkDim, outputDim);
        int numInters = 0;
        double cumulative = countCombs(0, topkDim, inputDim, outputDim) / denominator;
        while (true) {
            if (cumulative >= randomDraw || numInters >= maxInters) {
                return numInters;
            }
            numInters++;
            if (numInters < thr) {
                cumulative = cumulative + (countCombs(numInters, topkDim, inputDim, outputDim) / denominator);
            } else {
                cumulative = cumulative + ((Math.exp(eps) * countCombs(numInters, topkDim, inputDim, outputDim)) / denominator);
            }
        }
    }

    /**
     * Picks {@code num} distinct elements of {@code inputList} uniformly at random (partial
     * Fisher-Yates shuffle; {@code inputList} is reordered in place) and writes them to
     * {@code outputList} starting at {@code outputOffset}.
     *
     * <p>Selecting zero elements is a no-op. A negative {@code num}, or one larger than
     * {@code inputList.length}, is logged as an error and leaves {@code outputList} untouched.
     */
    private static void randomSelect(SecureRandom secureRandom, int[] inputList, int num, int[] outputList, int outputOffset) {
        if (num == 0) {
            // Nothing to select; this happens legitimately, so do not report it as an error.
            return;
        }
        if (num < 0) {
            LOGGER.severe("[SignDS] The number to be selected is set incorrectly!");
            return;
        }
        if (inputList.length < num) {
            LOGGER.severe("[SignDS] The size of inputList is small than num!");
            return;
        }
        for (int remaining = inputList.length; remaining > inputList.length - num; remaining--) {
            int pick = secureRandom.nextInt(remaining);
            int chosen = inputList[pick];
            // Move the chosen element out of the still-selectable prefix.
            inputList[pick] = inputList[remaining - 1];
            inputList[remaining - 1] = chosen;
            outputList[(outputOffset + inputList.length) - remaining] = chosen;
        }
    }

    /**
     * Selects {@code signDimOut} dimension indices from the flattened update
     * ({@code feature - preFeature} over all names in {@code updateFeatureName}).
     *
     * <p>The update values are ranked; the first {@code (int) (signK * total)} ranks form the top-k
     * set. A randomly drawn number of indices is taken from the top-k set and the rest from the
     * remaining dimensions. If {@code signDimOut} is 0 it is first determined by a search and the
     * result is stored in this object.
     *
     * @param client source of the current and previous feature arrays
     * @param descending {@code true} to rank updates from largest to smallest, {@code false} for
     *     smallest to largest
     * @return the selected indices in ascending order, or an empty array when the parameters are
     *     inconsistent
     */
    public int[] signDSModel(Client client, boolean descending) {
        int layerNum = this.updateFeatureName.size();
        int inputDim = 0;
        for (int layer = 0; layer < layerNum; layer++) {
            inputDim += client.getPreFeature(this.updateFeatureName.get(layer)).length;
        }
        int topkDim = (int) (this.signK * inputDim);
        if (this.signDimOut == 0) {
            this.signDimOut = findOptOutputDim(this.signThrRatio, topkDim, inputDim, this.signEps);
        }
        int thr = calcOptThr(topkDim, inputDim, this.signDimOut, this.signEps);

        // Normalizing constant: plain counts below the threshold, exp(eps)-weighted counts from it on.
        double combsBelowThr = 0.0d;
        for (int numInters = 0; numInters < thr; numInters++) {
            combsBelowThr += countCombs(numInters, topkDim, inputDim, this.signDimOut);
        }
        double combsFromThr = 0.0d;
        for (int numInters = thr; numInters <= this.signDimOut; numInters++) {
            combsFromThr += countCombs(numInters, topkDim, inputDim, this.signDimOut);
        }
        double denominator = combsBelowThr + (Math.exp(this.signEps) * combsFromThr);
        if (denominator == 0.0d) {
            LOGGER.severe("[SignDS] denominator is 0, please check");
            return new int[0];
        }

        int numInters = countInters(thr, denominator, topkDim, inputDim, this.signDimOut, this.signEps);
        int numOuters = this.signDimOut - numInters;
        // Also reject requests the two index pools cannot satisfy, so the result never contains
        // unfilled (zero) placeholder entries.
        if (topkDim < numInters || this.signDimOut <= 0 || numOuters < 0 || numOuters > inputDim - topkDim) {
            LOGGER.severe("[SignDS] topkDim or signDimOut is ERROR! please check");
            return new int[0];
        }

        final float[] update = new float[inputDim];
        Integer[] ranking = new Integer[inputDim];
        int flatIndex = 0;
        for (int layer = 0; layer < layerNum; layer++) {
            String featureName = this.updateFeatureName.get(layer);
            float[] feature = client.getFeature(featureName);
            float[] preFeature = client.getPreFeature(featureName);
            for (int j = 0; j < feature.length; j++) {
                update[flatIndex] = feature[j] - preFeature[j];
                ranking[flatIndex] = flatIndex;
                flatIndex++;
            }
        }
        if (descending) {
            Arrays.sort(ranking, (left, right) -> Float.compare(update[right.intValue()], update[left.intValue()]));
        } else {
            Arrays.sort(ranking, (left, right) -> Float.compare(update[left.intValue()], update[right.intValue()]));
        }

        int[] topkIndices = new int[topkDim];
        for (int j = 0; j < topkDim; j++) {
            topkIndices[j] = ranking[j].intValue();
        }
        int[] otherIndices = new int[inputDim - topkDim];
        for (int j = topkDim; j < inputDim; j++) {
            otherIndices[j - topkDim] = ranking[j].intValue();
        }

        int[] outputDimensionIndexList = new int[numInters + numOuters];
        SecureRandom secureRandom = Common.getSecureRandom();
        randomSelect(secureRandom, topkIndices, numInters, outputDimensionIndexList, 0);
        randomSelect(secureRandom, otherIndices, numOuters, outputDimensionIndexList, numInters);
        Arrays.sort(outputDimensionIndexList);
        LOGGER.info("[SignDS] outputDimension size is " + outputDimensionIndexList.length);
        return outputDimensionIndexList;
    }
}
