/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.eval;

import edu.neu.ccs.pyramid.classification.Classifier;
import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.eval.Precision;
import edu.neu.ccs.pyramid.eval.SafeDivide;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import edu.neu.ccs.pyramid.util.ArgSort;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.IntStream;

public class AveragePrecision {

    /**
     * Computes average precision for a list of relevance flags given in rank order.
     *
     * <p>Element {@code i} describes the item ranked at position {@code i + 1}. Every relevant
     * position contributes precision@(i+1) to a sum, and the sum is divided by the number of
     * relevant items.
     *
     * @param relevance relevance flags in rank order (best-ranked first)
     * @return the average precision, or 1.0 when no item is relevant
     */
    public static double averagePrecision(boolean[] relevance) {
        double totalRelevant = 0.0;
        double sumPrecisionAtK = 0.0;
        for (int i = 0; i < relevance.length; ++i) {
            if (!relevance[i]) {
                continue;
            }
            totalRelevant += 1.0;
            // totalRelevant is now the number of relevant items within the top (i + 1)
            sumPrecisionAtK += totalRelevant / (i + 1);
        }
        return SafeDivide.divide(sumPrecisionAtK, totalRelevant, 1.0);
    }

    /**
     * Computes average precision for a list of relevance values given in rank order.
     *
     * <p>Only entries equal to {@code 1} count as relevant; every other value is treated as not
     * relevant. Otherwise identical to {@link #averagePrecision(boolean[])}.
     *
     * @param relevance relevance values in rank order (best-ranked first)
     * @return the average precision, or 1.0 when no entry equals 1
     */
    public static double averagePrecision(int[] relevance) {
        double totalRelevant = 0.0;
        double sumPrecisionAtK = 0.0;
        for (int i = 0; i < relevance.length; ++i) {
            if (relevance[i] != 1) {
                continue;
            }
            totalRelevant += 1.0;
            // totalRelevant is now the number of relevant items within the top (i + 1)
            sumPrecisionAtK += totalRelevant / (i + 1);
        }
        return SafeDivide.divide(sumPrecisionAtK, totalRelevant, 1.0);
    }

    /**
     * Ranks items by descending score and computes the average precision of that ranking.
     *
     * @param binaryLabels per-item labels; entries equal to 1 are relevant
     * @param scores per-item scores; higher scores are ranked first
     * @return the average precision, or 1.0 when no label equals 1
     * @throws IllegalArgumentException if the two arrays have different lengths
     */
    public static double averagePrecision(int[] binaryLabels, double[] scores) {
        if (binaryLabels.length != scores.length) {
            // Without this check a mismatch either indexes out of bounds or silently drops items.
            throw new IllegalArgumentException("binaryLabels.length=" + binaryLabels.length
                    + " != scores.length=" + scores.length);
        }
        int[] sortedIndices = ArgSort.argSortDescending(scores);
        int[] relevance = new int[binaryLabels.length];
        for (int i = 0; i < relevance.length; ++i) {
            relevance[i] = binaryLabels[sortedIndices[i]];
        }
        return AveragePrecision.averagePrecision(relevance);
    }

    /**
     * Computes average precision of a binary probabilistic classifier on a labeled data set.
     *
     * @param classifier a classifier that must report exactly two classes
     * @param dataSet the data set whose own labels are used as ground truth
     * @return the average precision of the ranking by class-1 probability
     * @throws IllegalArgumentException if the classifier does not have exactly two classes
     */
    public static double averagePrecision(Classifier.ProbabilityEstimator classifier, ClfDataSet dataSet) {
        if (classifier.getNumClasses() != 2) {
            throw new IllegalArgumentException("classifier.getNumClasses()!=2");
        }
        return AveragePrecision.averagePrecision(classifier, dataSet, dataSet.getLabels());
    }

    /**
     * Ranks data points by the classifier's probability for class 1 and computes the average
     * precision of that ranking against {@code labels}.
     *
     * <p>Class probabilities are computed in a parallel stream. Each task writes only its own
     * array slot.
     *
     * @param classifier the probability estimator; index 1 of its output is used as the score
     * @param dataSet the data points to score
     * @param labels per-data-point labels; entries equal to 1 are relevant
     * @return the average precision, or 1.0 when no label equals 1
     * @throws IllegalArgumentException if {@code labels} does not match the number of data points
     */
    public static double averagePrecision(Classifier.ProbabilityEstimator classifier, DataSet dataSet, int[] labels) {
        double[] probs = new double[dataSet.getNumDataPoints()];
        IntStream.range(0, dataSet.getNumDataPoints()).parallel().forEach(i -> {
            probs[i] = classifier.predictClassProbs(dataSet.getRow(i))[1];
        });
        return AveragePrecision.averagePrecision(labels, probs);
    }

    /**
     * Computes the mean, over all data points, of the per-instance average precision of the
     * classifier's label ranking.
     *
     * @param classifier the multi-label score estimator used to rank labels for each data point
     * @param dataSet the data set providing rows and their matched labels
     * @return the mean per-instance average precision; an empty data set yields NaN (0.0 / 0.0)
     */
    public static double averagePrecision(MultiLabelClassifier.ClassScoreEstimator classifier, MultiLabelClfDataSet dataSet) {
        MultiLabel[] labels = dataSet.getMultiLabels();
        double sum = 0.0;
        for (int i = 0; i < labels.length; ++i) {
            Set<Integer> matched = labels[i].getMatchedLabels();
            double[] scores = classifier.predictClassScores(dataSet.getRow(i));
            sum += AveragePrecision.averagePrecision(matched, scores);
        }
        return sum / (double) labels.length;
    }

    /**
     * Computes average precision for one instance by ranking label indices by descending score.
     *
     * <p>For every rank position {@code k} whose label index is in {@code label}, precision over
     * the top {@code k + 1} predictions is added to the sum. The sum is divided by
     * {@code label.size()}, so true labels that never appear among the ranked indices still
     * count in the denominator.
     *
     * @param label the set of true label indices
     * @param scores per-label scores; higher scores are ranked first
     * @return the average precision, or 1.0 when {@code label} is empty (previously NaN), matching
     *     the convention of the public overloads
     */
    private static double averagePrecision(Set<Integer> label, double[] scores) {
        int[] sortedIndices = ArgSort.argSortDescending(scores);
        double sumPrecision = 0.0;
        // A running count replaces rebuilding/rescanning the prefix set at every hit (was O(n^2)).
        int hits = 0;
        for (int k = 0; k < sortedIndices.length; ++k) {
            if (!label.contains(sortedIndices[k])) {
                continue;
            }
            ++hits;
            int retrieved = k + 1;
            sumPrecision += Precision.precision(hits, retrieved - hits);
        }
        return SafeDivide.divide(sumPrecision, (double) label.size(), 1.0);
    }
}

