/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.eval;

import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.eval.AveragePrecision;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Mean average precision (MAP) of a multi-label probability estimator.
 *
 * <p>For each requested label, the classifier's predicted probability for that label is paired
 * with the binary ground truth (whether each data point carries the label), and the per-label
 * average precision is computed by {@link AveragePrecision#averagePrecision}. MAP is the
 * unweighted mean of those per-label values.
 */
public class MAP {

    /**
     * Computes the mean average precision of {@code classifier} on {@code dataSet}, averaged over
     * the given labels.
     *
     * <p>Class probabilities are predicted once per data point (in parallel) and then reused for
     * every label in {@code labels}.
     *
     * @param classifier probability estimator; must report the same number of classes as
     *                   {@code dataSet}
     * @param dataSet    evaluation data supplying the feature rows and the true label sets
     * @param labels     label indices to average over, each in {@code [0, dataSet.getNumClasses())};
     *                   a label listed more than once is counted that many times
     * @return the mean of the per-label average precisions; {@code NaN} when {@code labels} is
     *         empty (the sum is divided by zero)
     * @throws IllegalArgumentException if the classifier and data set disagree on the number of
     *                                  classes, or if any label is outside the valid range
     */
    public static double map(MultiLabelClassifier.ClassProbEstimator classifier, MultiLabelClfDataSet dataSet, List<Integer> labels) {
        if (classifier.getNumClasses() != dataSet.getNumClasses()) {
            throw new IllegalArgumentException("classifier.getNumClasses()!=dataSet.getNumClasses()");
        }
        final int numClasses = dataSet.getNumClasses();
        // Validate up front so a bad label fails fast, before the expensive prediction pass below.
        for (int l : labels) {
            if (l < 0 || l >= numClasses) {
                throw new IllegalArgumentException("label " + l + " is out of range [0, " + numClasses + ")");
            }
        }
        final int numData = dataSet.getNumDataPoints();
        // Allocate only the outer array: every row is replaced by the array returned from
        // predictClassProbs, so pre-allocating numData x numClasses doubles would be pure garbage.
        final double[][] probs = new double[numData][];
        // Each task writes a distinct row index, so the parallel fill needs no extra synchronization.
        IntStream.range(0, numData).parallel().forEach(i -> {
            probs[i] = classifier.predictClassProbs(dataSet.getRow(i));
        });
        double sum = 0.0;
        for (int l : labels) {
            // binaryLabels[i] is 1 iff data point i carries label l; marginals[i] is its predicted probability.
            int[] binaryLabels = new int[numData];
            double[] marginals = new double[numData];
            for (int i = 0; i < numData; ++i) {
                if (dataSet.getMultiLabels()[i].matchClass(l)) {
                    binaryLabels[i] = 1;
                }
                marginals[i] = probs[i][l];
            }
            sum += AveragePrecision.averagePrecision(binaryLabels, marginals);
        }
        return sum / (double) labels.size();
    }

    /**
     * Computes the mean average precision of {@code classifier} on {@code dataSet}, averaged over
     * every label index {@code 0 .. dataSet.getNumClasses() - 1}.
     *
     * @param classifier probability estimator; must report the same number of classes as
     *                   {@code dataSet}
     * @param dataSet    evaluation data
     * @return the mean of the per-label average precisions over all classes ({@code NaN} if the
     *         data set has no classes)
     * @throws IllegalArgumentException if the classifier and data set disagree on the number of
     *                                  classes
     */
    public static double map(MultiLabelClassifier.ClassProbEstimator classifier, MultiLabelClfDataSet dataSet) {
        List<Integer> labels = IntStream.range(0, dataSet.getNumClasses()).boxed().collect(Collectors.toList());
        return MAP.map(classifier, dataSet, labels);
    }
}

