/*
 * Decompiled with CFR 0.152.
 */
package com.mindspore.flclient.model;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.logging.Logger;

/**
 * Computes the area under the ROC curve (ROC AUC) for binary classification results.
 *
 * <p>The algorithm appears to follow scikit-learn's {@code roc_curve} (with intermediate,
 * collinear points dropped) followed by trapezoidal integration. The intermediate curve points
 * are stored in unsynchronized instance fields and overwritten on every call, so an instance must
 * not be used by several threads at once.
 */
public class AUCCalculator {
    private static final Logger LOGGER = Logger.getLogger(AUCCalculator.class.toString());

    /** Cumulative false-positive counts, one entry per curve point. */
    private List<Float> fps = new ArrayList<Float>();
    /** Cumulative true-positive counts (sum of label values), one entry per curve point. */
    private List<Float> tps = new ArrayList<Float>();
    /** Decision threshold (prediction score) of each curve point, in descending order. */
    private List<Float> thresholds = new ArrayList<Float>();
    /** False-positive rate of each curve point ({@code fps[i] / fps[last]}). */
    private List<Float> fpr = new ArrayList<Float>();
    /** True-positive rate of each curve point ({@code tps[i] / tps[last]}). */
    private List<Float> tpr = new ArrayList<Float>();

    /** Returns whether {@code val} is non-null, finite and within the closed range [0, 1]. */
    private static boolean isValidValue(Float val) {
        return val != null && !val.isInfinite() && !val.isNaN() && val >= 0.0f && val <= 1.0f;
    }

    /**
     * Computes the ROC AUC of {@code predict} scores against the {@code label} ground truth.
     *
     * @param label ground-truth labels; every value must be finite and in [0, 1]. Values are
     *     summed as positive weight, so 1 marks a positive sample and 0 a negative one.
     * @param predict predicted scores; every value must be finite and in [0, 1], and the list must
     *     have the same size as {@code label}
     * @return the AUC, or 0 when the inputs are null, empty, of different sizes, contain an invalid
     *     value, or the AUC is undefined (NaN, e.g. only one class is present)
     */
    public Float getAuc(List<Float> label, List<Float> predict) {
        if (label == null || predict == null) {
            LOGGER.severe("The input label or predict is null.");
            return 0.0f;
        }
        // Suppliers defer building (possibly very large) list strings until INFO is enabled.
        LOGGER.info(() -> "label is:" + label.toString());
        LOGGER.info(() -> "predict is:" + predict.toString());
        fps.clear();
        tps.clear();
        thresholds.clear();
        fpr.clear();
        tpr.clear();
        if (!binaryClfCurve(label, predict)) {
            LOGGER.severe("Do binaryClfCurve failed.");
            return 0.0f;
        }
        if (!rocCurve()) {
            LOGGER.severe("Do rocCurve failed.");
            return 0.0f;
        }
        float val = trapz(fpr, tpr);
        return Float.isNaN(val) ? 0.0f : val;
    }

    /**
     * Builds cumulative TP/FP counts for every distinct prediction score.
     *
     * <p>Samples are ordered by descending score. Among equal scores, the sample with the larger
     * original index comes first, because the index list is seeded in reverse and the sort is
     * stable. One point is emitted per group of equal scores, taken at the group's last sample.
     * Results are appended to {@link #tps}, {@link #fps} and {@link #thresholds}.
     *
     * @return false (after logging) if the inputs differ in size, are empty, or contain an
     *     invalid value; true otherwise
     */
    private boolean binaryClfCurve(List<Float> label, List<Float> predict) {
        if (label.size() != predict.size()) {
            LOGGER.severe("The input len of label is not same to predict.");
            return false;
        }
        final int size = label.size();
        if (size == 0) {
            LOGGER.severe("The input label and predict are empty.");
            return false;
        }
        // Copy once into primitive arrays so the sort comparator neither unboxes nor walks a
        // possibly non-random-access list on every comparison.
        final float[] labels = new float[size];
        final float[] scores = new float[size];
        for (int i = 0; i < size; ++i) {
            Float labelVal = label.get(i);
            Float predVal = predict.get(i);
            if (!isValidValue(labelVal) || !isValidValue(predVal)) {
                LOGGER.severe("Get invalid value, idx " + i + " label value:" + labelVal
                        + " predict value:" + predVal);
                return false;
            }
            labels[i] = labelVal;
            // -0.0f passes the range check; fold it into +0.0f so both are one tied score and the
            // comparator below stays a consistent total order.
            scores[i] = predVal == 0.0f ? 0.0f : predVal;
        }

        List<Integer> order = new ArrayList<Integer>(size);
        for (int i = size - 1; i >= 0; --i) {
            order.add(i);
        }
        // List.sort is stable, so tied scores keep the reversed-index seeding above.
        order.sort((a, b) -> Float.compare(scores[b], scores[a]));

        // Accumulate in double: a float running sum loses integer precision past 2^24 samples.
        double tpSum = 0.0;
        for (int i = 0; i < size; ++i) {
            int src = order.get(i);
            tpSum += labels[src];
            boolean lastOfGroup = i == size - 1 || scores[order.get(i + 1)] != scores[src];
            if (lastOfGroup) {
                tps.add((float) tpSum);
                fps.add((float) (i + 1 - tpSum));
                thresholds.add(scores[src]);
            }
        }
        return true;
    }

    /** Returns the first-order differences {@code data[i] - data[i - 1]}; needs a non-empty list. */
    private static List<Float> getDiff(List<Float> data) {
        List<Float> diff = new ArrayList<Float>();
        float prev = data.get(0);
        for (int i = 1; i < data.size(); ++i) {
            float cur = data.get(i);
            diff.add(cur - prev);
            prev = cur;
        }
        return diff;
    }

    /**
     * Applies {@link #getDiff} {@code level} times. It stops early once fewer than two elements
     * remain.
     */
    private static List<Float> getDiffByLev(List<Float> data, int level) {
        List<Float> result = data;
        for (int lev = 0; lev < level && result.size() > 1; ++lev) {
            result = getDiff(result);
        }
        return result;
    }

    /**
     * Integrates {@code y} over {@code x} with the trapezoidal rule. NaN inputs propagate to a
     * NaN result.
     *
     * @throws IllegalArgumentException if the lists differ in size or are empty
     */
    private static float trapz(List<Float> x, List<Float> y) {
        if (x.size() != y.size()) {
            throw new IllegalArgumentException("x.length != y.length");
        }
        if (y.isEmpty()) {
            throw new IllegalArgumentException("y.length == 0");
        }
        // Sum of dx * (y0 + y1), i.e. twice the area; accumulated in double for accuracy.
        double twiceArea = 0.0;
        float x0 = x.get(0);
        float y0 = y.get(0);
        for (int i = 1; i < y.size(); ++i) {
            float x1 = x.get(i);
            float y1 = y.get(i);
            twiceArea += (double) (x1 - x0) * (y0 + y1);
            x0 = x1;
            y0 = y1;
        }
        return (float) (twiceArea / 2.0);
    }

    /**
     * Turns the per-threshold counts into ROC points and fills {@link #fpr} and {@link #tpr}.
     *
     * <p>Steps:
     * <ul>
     *   <li>With more than two points, intermediate points whose fps and tps second differences
     *       are both zero are dropped; the first and last points are always kept.</li>
     *   <li>An origin point (0, 0) with threshold {@code thresholds[0] + 1} is then prepended.</li>
     *   <li>Rates are the counts divided by the final count. If that final count is not positive,
     *       the whole rate series is NaN.</li>
     * </ul>
     *
     * @return false (after logging) if the second-difference series differ in size; true otherwise
     */
    private boolean rocCurve() {
        if (fps.size() > 2) {
            List<Float> fpsDiff2 = getDiffByLev(fps, 2);
            List<Float> tpsDiff2 = getDiffByLev(tps, 2);
            if (fpsDiff2.size() != tpsDiff2.size()) {
                LOGGER.severe("The size of fps_diff2 " + fpsDiff2.size()
                        + " is not same to tps_diff2 " + tpsDiff2.size());
                return false;
            }
            List<Integer> optimalIdxs = new ArrayList<Integer>();
            optimalIdxs.add(0);
            for (int i = 0; i < fpsDiff2.size(); ++i) {
                // Keep a point only where the curve bends (a non-zero second difference).
                if (fpsDiff2.get(i) != 0.0f || tpsDiff2.get(i) != 0.0f) {
                    optimalIdxs.add(i + 1);
                }
            }
            optimalIdxs.add(fpsDiff2.size() + 1);

            List<Float> optFps = new ArrayList<Float>(optimalIdxs.size() + 1);
            List<Float> optTps = new ArrayList<Float>(optimalIdxs.size() + 1);
            List<Float> optThresholds = new ArrayList<Float>(optimalIdxs.size() + 1);
            optFps.add(0.0f);
            optTps.add(0.0f);
            optThresholds.add(thresholds.get(0) + 1.0f);
            for (int idx : optimalIdxs) {
                optFps.add(fps.get(idx));
                optTps.add(tps.get(idx));
                optThresholds.add(thresholds.get(idx));
            }
            fps = optFps;
            tps = optTps;
            thresholds = optThresholds;
        } else {
            fps.add(0, 0.0f);
            tps.add(0, 0.0f);
            thresholds.add(0, thresholds.get(0) + 1.0f);
        }

        float lastFps = fps.get(fps.size() - 1);
        float lastTps = tps.get(tps.size() - 1);
        // Without any negatives (or positives) the corresponding rate is undefined.
        boolean fprUndefined = lastFps <= 0.0f;
        boolean tprUndefined = lastTps <= 0.0f;
        for (int i = 0; i < fps.size(); ++i) {
            fpr.add(fprUndefined ? Float.NaN : fps.get(i) / lastFps);
            tpr.add(tprUndefined ? Float.NaN : tps.get(i) / lastTps);
        }
        return true;
    }
}

