package com.mindspore.flclient.model;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.logging.Logger;

/* loaded from: input_file:com/mindspore/flclient/model/AUCCalculator.class */
/**
 * Computes the area under the ROC curve (ROC AUC) for binary classification results.
 *
 * <p>Steps:
 * <ol>
 *   <li>Order samples by descending predicted score.</li>
 *   <li>Record cumulative true/false positive counts at the last sample of every group of equal
 *       scores.</li>
 *   <li>Drop interior points where neither cumulative count changes slope.</li>
 *   <li>Prepend the origin and normalize counts to rates.</li>
 *   <li>Integrate the curve with the trapezoidal rule.</li>
 * </ol>
 * NOTE(review): the structure resembles scikit-learn's {@code roc_curve} + {@code auc}; this is
 * inferred from the algorithm's shape, not stated anywhere in this file.
 *
 * <p>All intermediate data is local to each call; the class keeps no per-call state.
 */
public class AUCCalculator {
    private static final Logger LOGGER = Logger.getLogger(AUCCalculator.class.toString());

    /**
     * Returns the ROC AUC of the predicted scores {@code list2} against the labels {@code list}.
     *
     * <p>Every label and prediction must be non-null, finite and within [0, 1]. Labels are summed
     * as positive-class weights (1 = positive, 0 = negative).
     *
     * @param list the ground-truth labels
     * @param list2 the predicted scores, index-aligned with {@code list}
     * @return the AUC; {@code 0.0f} if either list is null, the lists are empty or of different
     *     lengths, any value is invalid, or the input contains only one class (AUC undefined)
     */
    public Float getAuc(List<Float> list, List<Float> list2) {
        if (list == null || list2 == null) {
            LOGGER.severe("The input label or predict is null.");
            return Float.valueOf(0.0f);
        }
        // Raw per-sample data can be large and sensitive: log it only at FINE, built lazily.
        LOGGER.fine(() -> "label is:" + list);
        LOGGER.fine(() -> "predict is:" + list2);

        List<Double> fps = new ArrayList<>();
        List<Double> tps = new ArrayList<>();
        if (!binaryClfCurve(list, list2, fps, tps)) {
            LOGGER.severe("Do binaryClfCurve failed.");
            return Float.valueOf(0.0f);
        }
        List<Double> fpr = new ArrayList<>();
        List<Double> tpr = new ArrayList<>();
        rocCurve(fps, tps, fpr, tpr);
        double auc = trapz(fpr, tpr);
        // Single-class input yields NaN rates (division by a zero total); report 0 in that case.
        return Float.valueOf(Double.isNaN(auc) ? 0.0f : (float) auc);
    }

    /** Returns true if {@code value} is non-null, finite and within [0, 1]. */
    private static boolean isValidValue(Float value) {
        return value != null && !value.isInfinite() && !value.isNaN()
                && value >= 0.0f && value <= 1.0f;
    }

    /**
     * Fills {@code fps}/{@code tps} with cumulative false/true positive counts, one entry per
     * distinct score, in descending score order.
     *
     * @return false (after logging) if the inputs are mismatched, empty or contain invalid values
     */
    private static boolean binaryClfCurve(List<Float> labels, List<Float> predicts,
            List<Double> fps, List<Double> tps) {
        int n = labels.size();
        if (n != predicts.size()) {
            LOGGER.severe("The input len of label is not same to predict.");
            return false;
        }
        if (n == 0) {
            LOGGER.severe("The input label and predict are empty.");
            return false;
        }
        for (int i = 0; i < n; i++) {
            if (!isValidValue(labels.get(i)) || !isValidValue(predicts.get(i))) {
                LOGGER.severe("Get invalid value, idx " + i + " label value:" + labels.get(i)
                        + " predict value:" + predicts.get(i));
                return false;
            }
        }

        // "+ 0.0f" maps -0.0f to 0.0f, so Float.compare (which orders -0.0f before 0.0f) and ==
        // agree that they are the same score. This keeps the comparator's contract and the
        // tie grouping below consistent.
        float[] scores = new float[n];
        List<Integer> order = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            scores[i] = predicts.get(i) + 0.0f;
            order.add(i);
        }
        Comparator<Integer> byScoreDescending = (a, b) -> Float.compare(scores[b], scores[a]);
        order.sort(byScoreDescending);

        // Accumulate in double: float running sums stop being exact beyond 2^24 samples.
        // Order inside a tie group is irrelevant because counts are recorded only at group ends.
        double truePositives = 0.0;
        for (int k = 0; k < n; k++) {
            int idx = order.get(k);
            truePositives += labels.get(idx);
            boolean lastOfTieGroup = k == n - 1 || scores[order.get(k + 1)] != scores[idx];
            if (lastOfTieGroup) {
                tps.add(truePositives);
                fps.add((k + 1) - truePositives);
            }
        }
        return true;
    }

    /**
     * Converts cumulative counts into ROC rates, appending them to {@code fpr}/{@code tpr}.
     *
     * <p>The first and last points are always kept. An interior point is dropped when both
     * counts are collinear with their neighbours (zero second difference); such points add
     * nothing to the trapezoidal area. The origin is prepended. If a final total is <= 0, every
     * rate of that axis is NaN.
     */
    private static void rocCurve(List<Double> fps, List<Double> tps,
            List<Double> fpr, List<Double> tpr) {
        int size = fps.size();
        List<Integer> kept = new ArrayList<>();
        kept.add(0);
        for (int j = 1; j < size - 1; j++) {
            if (changesSlope(fps, j) || changesSlope(tps, j)) {
                kept.add(j);
            }
        }
        if (size > 1) {
            kept.add(size - 1);
        }

        double fpTotal = fps.get(size - 1);
        double tpTotal = tps.get(size - 1);
        fpr.add(rate(0.0, fpTotal));
        tpr.add(rate(0.0, tpTotal));
        for (int j : kept) {
            fpr.add(rate(fps.get(j), fpTotal));
            tpr.add(rate(tps.get(j), tpTotal));
        }
    }

    /** Returns {@code count / total}, or NaN when {@code total <= 0}. */
    private static double rate(double count, double total) {
        return total <= 0.0 ? Double.NaN : count / total;
    }

    /** Returns true if the second difference of {@code values} at interior index j is non-zero. */
    private static boolean changesSlope(List<Double> values, int j) {
        double prev = values.get(j - 1);
        double cur = values.get(j);
        double next = values.get(j + 1);
        return (next - cur) - (cur - prev) != 0.0;
    }

    /**
     * Integrates y over x with the trapezoidal rule.
     *
     * @throws IllegalArgumentException if the lists differ in size or are empty
     */
    private static double trapz(List<Double> x, List<Double> y) {
        if (x.size() != y.size()) {
            throw new IllegalArgumentException("x.length != y.length");
        }
        if (y.isEmpty()) {
            throw new IllegalArgumentException("y.length == 0");
        }
        double twiceArea = 0.0;
        for (int i = 1; i < y.size(); i++) {
            twiceArea += (x.get(i) - x.get(i - 1)) * (y.get(i) + y.get(i - 1));
        }
        return twiceArea / 2.0;
    }
}
