/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.eval;

import edu.neu.ccs.pyramid.classification.Classifier;
import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.eval.SafeDivide;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import edu.neu.ccs.pyramid.util.ArgSort;
import java.util.stream.IntStream;

/**
 * Static helpers for computing precision of single-label and multi-label classifiers.
 */
public class Precision {

    /**
     * Computes precision from raw counts, {@code tp / (tp + fp)}.
     *
     * @param tp number of true positives
     * @param fp number of false positives
     * @return the precision; the division is delegated to {@link SafeDivide#divide} with a
     *     fallback of {@code 1.0}, which is presumably the value returned when {@code tp + fp}
     *     is zero (confirm against SafeDivide)
     */
    public static double precision(double tp, double fp) {
        return SafeDivide.divide(tp, tp + fp, 1.0);
    }

    /**
     * Computes the precision of a single class {@code k} for a classifier on a dataset.
     *
     * @param classifier classifier whose predictions are evaluated
     * @param dataSet dataset supplying the inputs and the true labels
     * @param k the class whose precision is computed
     * @return precision of class {@code k}
     */
    public static double precision(Classifier classifier, ClfDataSet dataSet, int k) {
        int[] labels = dataSet.getLabels();
        int[] predictions = classifier.predict(dataSet);
        return precision(labels, predictions, k);
    }

    /**
     * Computes the precision of class {@code k}: among the data points predicted as {@code k},
     * the fraction whose true label is also {@code k}.
     *
     * @param labels true labels, one per data point
     * @param predictions predicted labels, aligned index-by-index with {@code labels}
     * @param k the class whose precision is computed
     * @return precision of class {@code k} (see {@link #precision(double, double)} for the
     *     value used when nothing is predicted as {@code k})
     * @throws IllegalArgumentException if {@code labels} and {@code predictions} differ in length
     */
    public static double precision(int[] labels, int[] predictions, int k) {
        // A mismatch used to be silently truncated or to crash with an index error.
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions differ in length: "
                    + labels.length + " vs " + predictions.length);
        }
        int truePositive = 0;
        int falsePositive = 0;
        for (int i = 0; i < labels.length; i++) {
            // Only points predicted as k contribute to the precision of k.
            if (predictions[i] != k) {
                continue;
            }
            if (labels[i] == k) {
                truePositive++;
            } else {
                falsePositive++;
            }
        }
        return precision(truePositive, falsePositive);
    }

    /**
     * Computes precision@k for one data point: the fraction of the {@code k} highest-scoring
     * classes that belong to the ground truth.
     *
     * <p>If fewer than {@code k} classes are scored, the missing ranks count as misses. The
     * denominator is always {@code k}.
     *
     * @param scores per-class scores; a higher score means a higher rank
     * @param groundTruth the true label set of the data point
     * @param k number of top-ranked classes to inspect; must be positive
     * @return number of ground-truth classes among the top {@code k}, divided by {@code k}
     * @throws IllegalArgumentException if {@code k <= 0}
     */
    public static double precisionAtK(double[] scores, MultiLabel groundTruth, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive, got " + k);
        }
        int[] ranked = ArgSort.argSortDescending(scores);
        // Guard against k exceeding the number of ranked classes (previously an index error).
        int depth = Math.min(k, ranked.length);
        int hits = 0;
        for (int r = 0; r < depth; r++) {
            if (groundTruth.matchClass(ranked[r])) {
                hits++;
            }
        }
        return (double) hits / k;
    }

    /**
     * Computes precision@k averaged over every data point in a dataset, evaluating data points
     * with a parallel stream.
     *
     * @param multiLabelClassifier estimator producing per-class probabilities for each row
     * @param dataSet dataset supplying the rows and the true label sets
     * @param k number of top-ranked classes to inspect per data point; must be positive
     * @return mean precision@k over all data points
     * @throws java.util.NoSuchElementException if the dataset has no data points
     * @throws IllegalArgumentException if {@code k <= 0} and the dataset is non-empty
     */
    public static double precisionAtK(MultiLabelClassifier.ClassProbEstimator multiLabelClassifier,
                                      MultiLabelClfDataSet dataSet, int k) {
        // Fetch the label array once instead of once per data point inside the lambda.
        MultiLabel[] groundTruth = dataSet.getMultiLabels();
        return IntStream.range(0, dataSet.getNumDataPoints())
                .parallel()
                .mapToDouble(i -> precisionAtK(
                        multiLabelClassifier.predictClassProbs(dataSet.getRow(i)), groundTruth[i], k))
                .average()
                .getAsDouble();
    }

    /**
     * Computes example-based multi-label precision: for each data point, the size of the
     * intersection of prediction and truth divided by the size of the prediction, averaged
     * over all data points. A data point with an empty prediction contributes {@code 1.0}.
     *
     * @param multiLabels true label sets, one per data point
     * @param predictions predicted label sets, aligned index-by-index with {@code multiLabels}
     * @return the average precision; {@code NaN} when both arrays are empty (0 / 0)
     * @throws IllegalArgumentException if the two arrays differ in length
     * @deprecated retained for backward compatibility; this file does not name a replacement
     */
    @Deprecated
    public static double precision(MultiLabel[] multiLabels, MultiLabel[] predictions) {
        if (multiLabels.length != predictions.length) {
            throw new IllegalArgumentException("multiLabels and predictions differ in length: "
                    + multiLabels.length + " vs " + predictions.length);
        }
        double sum = 0.0;
        for (int i = 0; i < multiLabels.length; i++) {
            MultiLabel prediction = predictions[i];
            int numPredicted = prediction.getMatchedLabels().size();
            if (numPredicted == 0) {
                // An empty prediction makes no false-positive claims; treated as perfect precision.
                sum += 1.0;
                continue;
            }
            sum += (double) MultiLabel.intersection(multiLabels[i], prediction).size() / numPredicted;
        }
        return sum / multiLabels.length;
    }

    /**
     * Computes example-based multi-label precision of a classifier's predictions on a dataset.
     *
     * @param classifier multi-label classifier whose predictions are evaluated
     * @param dataSet dataset supplying the inputs and the true label sets
     * @return the average per-data-point precision
     * @deprecated retained for backward compatibility; this file does not name a replacement
     */
    @Deprecated
    public static double precision(MultiLabelClassifier classifier, MultiLabelClfDataSet dataSet) {
        return precision(dataSet.getMultiLabels(), classifier.predict(dataSet));
    }
}

