/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.eval;

import edu.neu.ccs.pyramid.classification.Classifier;
import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import java.util.stream.IntStream;

/**
 * Static helpers that compute accuracy for single-label and multi-label
 * predictions.
 *
 * <p>Exact-match accuracy is the fraction of instances whose prediction equals
 * the true label. Partial accuracy averages, over instances, the ratio
 * {@code |intersection| / |union|} of the true and predicted label sets.
 */
public class Accuracy {
    /**
     * Computes accuracy from confusion-matrix counts.
     *
     * <p>The counts are summed as {@code long} so that large counts cannot
     * overflow 32-bit integer arithmetic before the division.
     *
     * @param tp number of true positives
     * @param tn number of true negatives
     * @param fp number of false positives
     * @param fn number of false negatives
     * @return {@code (tp + tn) / (tp + tn + fp + fn)}; {@code NaN} if all four counts are zero
     */
    public static double accuracy(int tp, int tn, int fp, int fn) {
        long correct = (long) tp + tn;
        long total = correct + fp + fn;
        return (double) correct / (double) total;
    }

    /**
     * Predicts every instance of the data set with the classifier and scores
     * the predictions against the data set's labels.
     *
     * @param classifier classifier used to produce the predictions
     * @param clfDataSet data set providing both the instances and the true labels
     * @return fraction of predictions that equal the corresponding label
     */
    public static double accuracy(Classifier classifier, ClfDataSet clfDataSet) {
        int[] prediction = classifier.predict(clfDataSet);
        return Accuracy.accuracy(clfDataSet.getLabels(), prediction);
    }

    /**
     * Computes the fraction of positions at which the label and the prediction agree.
     *
     * @param labels true labels
     * @param predictions predicted labels, aligned by index with {@code labels}
     * @return number of matching positions divided by {@code labels.length};
     *         {@code NaN} if both arrays are empty
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public static double accuracy(int[] labels, int[] predictions) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels.length!=predictions.length");
        }
        double numCorrect = IntStream.range(0, labels.length)
                .parallel()
                .filter(i -> labels[i] == predictions[i])
                .count();
        return numCorrect / (double) labels.length;
    }

    /**
     * Computes exact-match accuracy for multi-label predictions: an instance
     * counts as correct only if its predicted label set {@code equals} the
     * true label set.
     *
     * @param multiLabels true label sets
     * @param predictions predicted label sets, aligned by index with {@code multiLabels}
     * @return number of exactly matching instances divided by {@code multiLabels.length}
     * @throws IllegalArgumentException if {@code multiLabels} is empty or the
     *         two arrays differ in length
     */
    public static double accuracy(MultiLabel[] multiLabels, MultiLabel[] predictions) {
        if (multiLabels.length == 0) {
            throw new IllegalArgumentException("multi labels length is zero.");
        }
        if (multiLabels.length != predictions.length) {
            throw new IllegalArgumentException("multi labels length is not equal to predictions length.");
        }
        double numCorrect = IntStream.range(0, multiLabels.length)
                .parallel()
                .filter(i -> multiLabels[i].equals(predictions[i]))
                .count();
        return numCorrect / (double) multiLabels.length;
    }

    /**
     * Predicts every instance of the data set with the classifier and computes
     * exact-match accuracy against the data set's label sets.
     *
     * @param classifier multi-label classifier used to produce the predictions
     * @param dataSet data set providing both the instances and the true label sets
     * @return exact-match accuracy of the classifier on the data set
     * @throws IllegalArgumentException if the data set yields no label sets or
     *         the number of predictions differs from the number of label sets
     */
    public static double accuracy(MultiLabelClassifier classifier, MultiLabelClfDataSet dataSet) {
        return Accuracy.accuracy(dataSet.getMultiLabels(), classifier.predict(dataSet));
    }

    /**
     * Computes the average, over instances, of
     * {@code |label ∩ prediction| / |label ∪ prediction|}.
     *
     * <p>When both the true and the predicted label set of an instance are
     * empty, the union is empty and the ratio would be {@code 0/0}; such an
     * instance is scored as 1.0 (a perfect match) so that it does not turn the
     * whole average into {@code NaN}.
     *
     * @param multiLabels true label sets
     * @param predictions predicted label sets, aligned by index with {@code multiLabels}
     * @return mean per-instance intersection-over-union
     * @throws IllegalArgumentException if {@code multiLabels} is empty or the
     *         two arrays differ in length
     */
    public static double partialAccuracy(MultiLabel[] multiLabels, MultiLabel[] predictions) {
        // Same validation (and messages) as the exact-match overload above.
        if (multiLabels.length == 0) {
            throw new IllegalArgumentException("multi labels length is zero.");
        }
        if (multiLabels.length != predictions.length) {
            throw new IllegalArgumentException("multi labels length is not equal to predictions length.");
        }
        double sum = 0.0;
        for (int i = 0; i < multiLabels.length; ++i) {
            MultiLabel label = multiLabels[i];
            MultiLabel prediction = predictions[i];
            double unionSize = MultiLabel.union(label, prediction).size();
            if (unionSize == 0.0) {
                // Both sets are empty: the prediction matches exactly.
                sum += 1.0;
                continue;
            }
            double intersectionSize = MultiLabel.intersection(label, prediction).size();
            sum += intersectionSize / unionSize;
        }
        return sum / (double) multiLabels.length;
    }

    /**
     * Predicts every instance of the data set with the classifier and computes
     * partial accuracy against the data set's label sets.
     *
     * @param classifier multi-label classifier used to produce the predictions
     * @param dataSet data set providing both the instances and the true label sets
     * @return mean per-instance intersection-over-union of the classifier on the data set
     * @throws IllegalArgumentException if the data set yields no label sets or
     *         the number of predictions differs from the number of label sets
     */
    public static double partialAccuracy(MultiLabelClassifier classifier, MultiLabelClfDataSet dataSet) {
        return Accuracy.partialAccuracy(dataSet.getMultiLabels(), classifier.predict(dataSet));
    }
}

