/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.eval;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.eval.FMeasure;
import edu.neu.ccs.pyramid.eval.HammingLoss;
import edu.neu.ccs.pyramid.eval.MLConfusionMatrix;
import edu.neu.ccs.pyramid.eval.Overlap;
import edu.neu.ccs.pyramid.eval.Precision;
import edu.neu.ccs.pyramid.eval.Recall;
import edu.neu.ccs.pyramid.util.PrintUtil;
import java.io.IOException;
import java.util.Arrays;
import java.util.stream.IntStream;
import org.apache.mahout.math.Vector;

@JsonSerialize(using=Serializer.class)
public class MacroAverage {
    private int numClasses;
    private double f1;
    private double overlap;
    private double precision;
    private double recall;
    private double hammingLoss;
    private double binaryAccuracy;
    private int[] labelWiseTP;
    private int[] labelWiseTN;
    private int[] labelWiseFP;
    private int[] labelWiseFN;
    private double[] labelWisePrecision;
    private double[] labelWiseRecall;
    private double[] labelWiseOverlap;
    private double[] labelWiseF1;
    private double[] labelWiseHammingLoss;
    private double[] labelWiseAccuracy;
    private LabelTranslator labelTranslator;

    public MacroAverage(MLConfusionMatrix confusionMatrix) {
        this.numClasses = confusionMatrix.getNumClasses();
        int numDataPoints = confusionMatrix.getNumDataPoints();
        DataSet entries = confusionMatrix.getEntries();
        this.labelWiseTP = new int[this.numClasses];
        this.labelWiseTN = new int[this.numClasses];
        this.labelWiseFP = new int[this.numClasses];
        this.labelWiseFN = new int[this.numClasses];
        this.labelWisePrecision = new double[this.numClasses];
        this.labelWiseRecall = new double[this.numClasses];
        this.labelWiseOverlap = new double[this.numClasses];
        this.labelWiseF1 = new double[this.numClasses];
        this.labelWiseHammingLoss = new double[this.numClasses];
        this.labelWiseAccuracy = new double[this.numClasses];
        IntStream.range(0, this.numClasses).parallel().forEach(l -> {
            Vector vector = entries.getColumn(l);
            for (Vector.Element element : vector.nonZeroes()) {
                double v = element.get();
                if (v == 1.0) {
                    int n = l;
                    this.labelWiseTP[n] = this.labelWiseTP[n] + 1;
                    continue;
                }
                if (v == 2.0) {
                    int n = l;
                    this.labelWiseFN[n] = this.labelWiseFN[n] + 1;
                    continue;
                }
                if (v != 3.0) continue;
                int n = l;
                this.labelWiseFP[n] = this.labelWiseFP[n] + 1;
            }
            this.labelWiseTN[l] = numDataPoints - vector.getNumNonZeroElements();
            double tp = (double)this.labelWiseTP[l] / (double)numDataPoints;
            double tn = (double)this.labelWiseTN[l] / (double)numDataPoints;
            double fp = (double)this.labelWiseFP[l] / (double)numDataPoints;
            double fn = (double)this.labelWiseFN[l] / (double)numDataPoints;
            this.labelWisePrecision[l] = Precision.precision(tp, fp);
            this.labelWiseRecall[l] = Recall.recall(tp, fn);
            this.labelWiseF1[l] = FMeasure.f1(tp, fp, fn);
            this.labelWiseOverlap[l] = Overlap.overlap(tp, fp, fn);
            this.labelWiseHammingLoss[l] = HammingLoss.hammingLoss(tp, tn, fp, fn);
            this.labelWiseAccuracy[l] = tp + tn;
        });
        this.precision = Arrays.stream(this.labelWisePrecision).average().getAsDouble();
        this.recall = Arrays.stream(this.labelWiseRecall).average().getAsDouble();
        this.f1 = Arrays.stream(this.labelWiseF1).average().getAsDouble();
        this.overlap = Arrays.stream(this.labelWiseOverlap).average().getAsDouble();
        this.hammingLoss = Arrays.stream(this.labelWiseHammingLoss).average().getAsDouble();
        this.binaryAccuracy = Arrays.stream(this.labelWiseAccuracy).average().getAsDouble();
        this.labelTranslator = LabelTranslator.newDefaultLabelTranslator(this.numClasses);
    }

    public void setLabelTranslator(LabelTranslator labelTranslator) {
        this.labelTranslator = labelTranslator;
    }

    public double getF1() {
        return this.f1;
    }

    public double getOverlap() {
        return this.overlap;
    }

    public double getPrecision() {
        return this.precision;
    }

    public double getRecall() {
        return this.recall;
    }

    public double getHammingLoss() {
        return this.hammingLoss;
    }

    public int[] getLabelWiseTP() {
        return this.labelWiseTP;
    }

    public int[] getLabelWiseTN() {
        return this.labelWiseTN;
    }

    public int[] getLabelWiseFP() {
        return this.labelWiseFP;
    }

    public int[] getLabelWiseFN() {
        return this.labelWiseFN;
    }

    public double[] getLabelWisePrecision() {
        return this.labelWisePrecision;
    }

    public double[] getLabelWiseRecall() {
        return this.labelWiseRecall;
    }

    public double[] getLabelWiseOverlap() {
        return this.labelWiseOverlap;
    }

    public double[] getLabelWiseF1() {
        return this.labelWiseF1;
    }

    public double[] getLabelWiseHammingLoss() {
        return this.labelWiseHammingLoss;
    }

    public double getBinaryAccuracy() {
        return this.binaryAccuracy;
    }

    public double[] getLabelWiseAccuracy() {
        return this.labelWiseAccuracy;
    }

    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("label Jaccard Index = ").append(this.overlap).append("\n");
        sb.append("label Hamming loss = ").append(this.hammingLoss).append("\n");
        sb.append("label F1 = ").append(this.f1).append("\n");
        sb.append("label precision = ").append(this.precision).append("\n");
        sb.append("label recall = ").append(this.recall).append("\n");
        sb.append("label binary accuracy = ").append(this.binaryAccuracy).append("\n");
        return sb.toString();
    }

    public String printDetail() {
        StringBuilder sb = new StringBuilder();
        sb.append("f1=").append(this.f1);
        sb.append(", overlap=").append(this.overlap);
        sb.append(", precision=").append(this.precision);
        sb.append(", recall=").append(this.recall);
        sb.append(", hammingLoss=").append(this.hammingLoss);
        sb.append(", binaryAccuracy=").append(this.binaryAccuracy);
        sb.append(", labelWisePrecision=").append(PrintUtil.printWithIndex(this.labelWisePrecision));
        sb.append(", labelWiseRecall=").append(PrintUtil.printWithIndex(this.labelWiseRecall));
        sb.append(", labelWiseOverlap=").append(PrintUtil.printWithIndex(this.labelWiseOverlap));
        sb.append(", labelWiseF1=").append(PrintUtil.printWithIndex(this.labelWiseF1));
        sb.append(", labelWiseHammingLoss=").append(PrintUtil.printWithIndex(this.labelWiseHammingLoss));
        sb.append(", labelWiseAccuracy=").append(PrintUtil.printWithIndex(this.labelWiseAccuracy));
        sb.append(", labelWiseTP=").append(PrintUtil.printWithIndex(this.labelWiseTP));
        sb.append(", labelWiseTN=").append(PrintUtil.printWithIndex(this.labelWiseTN));
        sb.append(", labelWiseFP=").append(PrintUtil.printWithIndex(this.labelWiseFP));
        sb.append(", labelWiseFN=").append(PrintUtil.printWithIndex(this.labelWiseFN));
        return sb.toString();
    }

    public static class Serializer
    extends JsonSerializer<MacroAverage> {
        public void serialize(MacroAverage macroAverage, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException, JsonProcessingException {
            jsonGenerator.writeStartArray();
            for (int k = 0; k < macroAverage.numClasses; ++k) {
                jsonGenerator.writeStartObject();
                jsonGenerator.writeStringField("label", macroAverage.labelTranslator.toExtLabel(k));
                jsonGenerator.writeNumberField("TP", macroAverage.labelWiseTP[k]);
                jsonGenerator.writeNumberField("TN", macroAverage.labelWiseTN[k]);
                jsonGenerator.writeNumberField("FP", macroAverage.labelWiseFP[k]);
                jsonGenerator.writeNumberField("FN", macroAverage.labelWiseFN[k]);
                jsonGenerator.writeNumberField("precision", macroAverage.labelWisePrecision[k]);
                jsonGenerator.writeNumberField("recall", macroAverage.labelWiseRecall[k]);
                jsonGenerator.writeNumberField("f1", macroAverage.labelWiseF1[k]);
                jsonGenerator.writeNumberField("accuracy", macroAverage.labelWiseAccuracy[k]);
                jsonGenerator.writeEndObject();
            }
            jsonGenerator.writeEndArray();
        }
    }
}

