/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.eval;

import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.eval.FMeasure;
import edu.neu.ccs.pyramid.eval.HammingLoss;
import edu.neu.ccs.pyramid.eval.MLConfusionMatrix;
import edu.neu.ccs.pyramid.eval.Overlap;
import edu.neu.ccs.pyramid.eval.Precision;
import edu.neu.ccs.pyramid.eval.Recall;
import java.util.NoSuchElementException;
import java.util.function.IntToDoubleFunction;
import java.util.stream.IntStream;
import org.apache.mahout.math.Vector;

/**
 * Instance-averaged multi-label evaluation measures.
 *
 * <p>For every data point (row) of an {@link MLConfusionMatrix} the constructor
 * counts true positives, false negatives and false positives among the row's
 * non-zero cells, derives the true-negative count from the remaining classes,
 * evaluates each measure on that single row, and stores the mean of the
 * per-row values. Subset accuracy is the fraction of rows that contain no
 * false-negative and no false-positive cell.
 *
 * <p>All values are computed once in the constructor; instances are immutable
 * afterwards.
 */
public class InstanceAverage {
    /** Cell value that the constructor counts as a true positive. */
    private static final double TRUE_POSITIVE = 1.0;
    /** Cell value that the constructor counts as a false negative. */
    private static final double FALSE_NEGATIVE = 2.0;
    /** Cell value that the constructor counts as a false positive. */
    private static final double FALSE_POSITIVE = 3.0;

    private final double f1;
    private final double overlap;
    private final double precision;
    private final double recall;
    private final double hammingLoss;
    private final double accuracy;

    /**
     * Evaluates a single prediction against a single true label by wrapping
     * both in one-element arrays and building a one-row confusion matrix.
     *
     * @param numClasses number of classes passed on to the confusion matrix
     * @param trueLabel  ground-truth label set
     * @param prediction predicted label set
     */
    public InstanceAverage(int numClasses, MultiLabel trueLabel, MultiLabel prediction) {
        this(new MLConfusionMatrix(numClasses, InstanceAverage.toArray(trueLabel), InstanceAverage.toArray(prediction)));
    }

    /** Wraps one label in a single-element array. */
    private static MultiLabel[] toArray(MultiLabel multiLabel) {
        return new MultiLabel[]{multiLabel};
    }

    /**
     * Computes all instance-averaged measures from the given confusion matrix.
     *
     * @param confusionMatrix per-data-point, per-class confusion entries
     * @throws NoSuchElementException if the matrix reports zero data points,
     *                                because there is nothing to average
     */
    public InstanceAverage(MLConfusionMatrix confusionMatrix) {
        int numClasses = confusionMatrix.getNumClasses();
        int numDataPoints = confusionMatrix.getNumDataPoints();
        DataSet entries = confusionMatrix.getEntries();
        double[] tpArray = new double[numDataPoints];
        double[] tnArray = new double[numDataPoints];
        double[] fpArray = new double[numDataPoints];
        double[] fnArray = new double[numDataPoints];

        // Each index i only touches slot i of every array, so the parallel
        // traversal never writes the same element from two threads.
        IntStream.range(0, numDataPoints).parallel().forEach(i -> {
            Vector row = entries.getRow(i);
            for (Vector.Element element : row.nonZeroes()) {
                double v = element.get();
                if (v == TRUE_POSITIVE) {
                    tpArray[i] += 1.0;
                } else if (v == FALSE_NEGATIVE) {
                    fnArray[i] += 1.0;
                } else if (v == FALSE_POSITIVE) {
                    fpArray[i] += 1.0;
                }
            }
            // Every class without a non-zero cell is treated as a true negative.
            tnArray[i] = numClasses - row.getNumNonZeroElements();

            // Turn the raw counts into fractions of the number of classes.
            tpArray[i] /= numClasses;
            tnArray[i] /= numClasses;
            fpArray[i] /= numClasses;
            fnArray[i] /= numClasses;
        });

        this.precision = meanOverInstances(numDataPoints, i -> Precision.precision(tpArray[i], fpArray[i]));
        this.recall = meanOverInstances(numDataPoints, i -> Recall.recall(tpArray[i], fnArray[i]));
        this.f1 = meanOverInstances(numDataPoints, i -> FMeasure.f1(tpArray[i], fpArray[i], fnArray[i]));
        this.overlap = meanOverInstances(numDataPoints, i -> Overlap.overlap(tpArray[i], fpArray[i], fnArray[i]));
        this.hammingLoss = meanOverInstances(numDataPoints,
                i -> HammingLoss.hammingLoss(tpArray[i], tnArray[i], fpArray[i], fnArray[i]));

        long numCorrect = IntStream.range(0, numDataPoints)
                .parallel()
                .filter(i -> correct(entries.getRow(i)))
                .count();
        this.accuracy = (double) numCorrect / (double) numDataPoints;
    }

    /**
     * Averages a per-data-point measure over the indices {@code 0..numDataPoints-1}.
     *
     * @param numDataPoints number of data points to average over
     * @param measure       maps a data point index to its measure value
     * @return the arithmetic mean of the measure values
     * @throws NoSuchElementException if {@code numDataPoints} yields an empty range
     */
    private static double meanOverInstances(int numDataPoints, IntToDoubleFunction measure) {
        return IntStream.range(0, numDataPoints)
                .parallel()
                .mapToDouble(measure)
                .average()
                .orElseThrow(() -> new NoSuchElementException(
                        "cannot average instance measures over zero data points"));
    }

    /** @return the mean per-instance F1 value */
    public double getF1() {
        return this.f1;
    }

    /** @return the mean per-instance overlap (reported as Jaccard index by {@link #toString()}) */
    public double getOverlap() {
        return this.overlap;
    }

    /** @return the mean per-instance precision */
    public double getPrecision() {
        return this.precision;
    }

    /** @return the mean per-instance recall */
    public double getRecall() {
        return this.recall;
    }

    /** @return the mean per-instance Hamming loss */
    public double getHammingLoss() {
        return this.hammingLoss;
    }

    /** @return the fraction of data points with no false-negative and no false-positive cell */
    public double getAccuracy() {
        return this.accuracy;
    }

    /**
     * Tells whether a confusion-matrix row is free of errors.
     *
     * @param dataEntry one row of the confusion-matrix entries
     * @return {@code false} as soon as a false-negative or false-positive cell
     *         is found among the non-zero cells, {@code true} otherwise
     */
    private static boolean correct(Vector dataEntry) {
        for (Vector.Element element : dataEntry.nonZeroes()) {
            double v = element.get();
            if (v == FALSE_NEGATIVE || v == FALSE_POSITIVE) {
                return false;
            }
        }
        return true;
    }

    /** Renders every measure on its own line, each terminated by a newline. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("instance subset accuracy = ").append(this.accuracy).append("\n");
        sb.append("instance Jaccard index = ").append(this.overlap).append("\n");
        sb.append("instance Hamming loss = ").append(this.hammingLoss).append("\n");
        sb.append("instance F1 = ").append(this.f1).append("\n");
        sb.append("instance precision = ").append(this.precision).append("\n");
        sb.append("instance recall = ").append(this.recall).append("\n");
        return sb.toString();
    }
}

