/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.eval;

import edu.neu.ccs.pyramid.classification.Classifier;
import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * Computes the area under the ROC curve (AUC) for binary classification.
 *
 * <p>A label equal to {@code 1} is treated as the positive class; every other
 * label value is treated as negative. Examples are ranked by descending score,
 * examples with equal scores are collapsed into a single ROC point, and the
 * area is obtained with the trapezoid rule.
 */
public class AUC {
    /** Label value counted as the positive class; all other values count as negative. */
    private static final int POSITIVE_LABEL = 1;

    /**
     * Computes the AUC of an estimator on a two-class data set, using the data
     * set's own labels.
     *
     * @param probEstimator estimator whose {@code predictClassProbs(row)[1]} is used
     *                      as the ranking score (presumably the probability of class 1)
     * @param dataSet       data set to score; must report exactly two classes
     * @return the AUC, see {@link #auc(double[], int[])}
     * @throws IllegalArgumentException if {@code dataSet.getNumClasses() != 2}
     */
    public static double auc(Classifier.ProbabilityEstimator probEstimator, ClfDataSet dataSet) {
        if (dataSet.getNumClasses() != 2) {
            throw new IllegalArgumentException("dataSet.getNumClasses()!=2");
        }
        // Rows are scored independently of each other, so they are evaluated in parallel.
        double[] probForOne = IntStream.range(0, dataSet.getNumDataPoints())
                .parallel()
                .mapToDouble(i -> probEstimator.predictClassProbs(dataSet.getRow(i))[1])
                .toArray();
        int[] labels = dataSet.getLabels();
        return AUC.auc(probForOne, labels);
    }

    /**
     * Computes the AUC of an estimator on a data set against externally supplied labels.
     *
     * @param probEstimator estimator whose {@code predictClassProbs(row)[1]} is used
     *                      as the ranking score (presumably the probability of class 1)
     * @param dataSet       data set whose rows are scored
     * @param labels        one label per data point; {@code 1} marks a positive example
     * @return the AUC, see {@link #auc(double[], int[])}
     * @throws IllegalArgumentException if the number of labels differs from the
     *                                  number of data points
     */
    public static double auc(Classifier.ProbabilityEstimator probEstimator, DataSet dataSet, int[] labels) {
        double[] probForOne = IntStream.range(0, dataSet.getNumDataPoints())
                .parallel()
                .mapToDouble(i -> probEstimator.predictClassProbs(dataSet.getRow(i))[1])
                .toArray();
        return AUC.auc(probForOne, labels);
    }

    /**
     * Computes the AUC from raw scores and labels.
     *
     * <p>If no label equals {@code 1}, or every label equals {@code 1} (this
     * includes empty input), the ROC curve is undefined and {@code 1.0} is returned.
     *
     * @param scores ranking scores, higher meaning "more likely positive"
     * @param labels labels aligned with {@code scores}; {@code 1} marks a positive
     *               example, any other value a negative one
     * @return the area under the ROC curve, in {@code [0, 1]}
     * @throws IllegalArgumentException if {@code scores} and {@code labels} differ in length
     */
    public static double auc(double[] scores, int[] labels) {
        if (scores.length != labels.length) {
            throw new IllegalArgumentException(
                    "scores.length (" + scores.length + ") != labels.length (" + labels.length + ")");
        }
        // Count positives with the same rule the ROC sweep uses (label == 1) rather than
        // summing the labels, so the degenerate-case check cannot disagree with the sweep.
        long numPositives = Arrays.stream(labels).filter(label -> label == POSITIVE_LABEL).count();
        if (numPositives == 0 || numPositives == labels.length) {
            return 1.0;
        }
        Comparator<Pair<Double, Integer>> byScore = Comparator.comparing(Pair::getFirst);
        List<Pair<Double, Integer>> sortedPairs = IntStream.range(0, scores.length)
                .parallel()
                .mapToObj(i -> new Pair<Double, Integer>(scores[i], labels[i]))
                .sorted(byScore.reversed())
                .collect(Collectors.toList());
        List<List<Double>> rates = AUC.getRates(sortedPairs, numPositives);
        return AUC.area(rates);
    }

    /**
     * Sweeps the score-sorted examples and records one ROC point per distinct score.
     *
     * @param sortedPairs  (score, label) pairs sorted by descending score
     * @param numPositives number of pairs whose label is {@code 1}; must be
     *                     strictly between 0 and {@code sortedPairs.size()}
     * @return a two-element list: true-positive rates at index 0 and false-positive
     *         rates at index 1, both starting with {@code 0.0}
     */
    private static List<List<Double>> getRates(List<Pair<Double, Integer>> sortedPairs, long numPositives) {
        int numData = sortedPairs.size();
        double numNegatives = numData - numPositives;
        List<Double> truePositiveRates = new ArrayList<>();
        List<Double> falsePositiveRates = new ArrayList<>();
        // The curve always starts at the origin.
        truePositiveRates.add(0.0);
        falsePositiveRates.add(0.0);
        double truePositive = 0.0;
        double falsePositive = 0.0;
        for (int i = 0; i < numData; ++i) {
            Pair<Double, Integer> pair = sortedPairs.get(i);
            double score = pair.getFirst();
            if (pair.getSecond() == POSITIVE_LABEL) {
                truePositive += 1.0;
            } else {
                falsePositive += 1.0;
            }
            // Emit a point only when the score changes (or at the very end), so that
            // tied scores form one diagonal segment instead of an arbitrary staircase.
            boolean isLast = i == numData - 1;
            if (isLast || score != sortedPairs.get(i + 1).getFirst()) {
                truePositiveRates.add(truePositive / numPositives);
                falsePositiveRates.add(falsePositive / numNegatives);
            }
        }
        List<List<Double>> rates = new ArrayList<>();
        rates.add(truePositiveRates);
        rates.add(falsePositiveRates);
        return rates;
    }

    /**
     * Integrates the ROC curve with the trapezoid rule.
     *
     * @param rates true-positive rates at index 0 and false-positive rates at index 1,
     *              as produced by {@code getRates}
     * @return the area under the curve
     */
    private static double area(List<List<Double>> rates) {
        List<Double> tpr = rates.get(0);
        List<Double> fpr = rates.get(1);
        // Each term is twice the (negated) area of one trapezoid; the sign comes from
        // subtracting the later, larger false-positive rate from the earlier one.
        double twiceSignedArea = IntStream.range(0, tpr.size() - 1)
                .mapToDouble(i -> (fpr.get(i) - fpr.get(i + 1)) * (tpr.get(i) + tpr.get(i + 1)))
                .sum();
        return Math.abs(twiceSignedArea) / 2.0;
    }
}

