/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.eval;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.eval.Accuracy;
import edu.neu.ccs.pyramid.eval.FMeasure;
import edu.neu.ccs.pyramid.eval.Precision;
import edu.neu.ccs.pyramid.eval.Recall;
import java.io.IOException;

@JsonSerialize(using=LabelBasedMeasures.Serializer.class)
/**
 * Per-label (one-vs-rest) evaluation of multi-label predictions.
 *
 * <p>For every label index {@code k} this class accumulates a confusion matrix (true/false
 * positives/negatives) over all data points passed to {@link #update(MultiLabel, MultiLabel)},
 * and derives precision, recall, F1 and accuracy for that label from those counts.
 *
 * <p>Instances are mutable and have no internal synchronization.
 */
public class LabelBasedMeasures {
    /** Number of labels being evaluated; valid label indices are {@code 0 .. numLabels-1}. */
    protected int numLabels;
    /** Per-label count of data points where the label is both actual and predicted. */
    protected int[] truePositives;
    /** Per-label count of data points where the label is neither actual nor predicted. */
    protected int[] trueNegatives;
    /** Per-label count of data points where the label is predicted but not actual. */
    protected int[] falsePositives;
    /** Per-label count of data points where the label is actual but not predicted. */
    protected int[] falseNegatives;
    /**
     * Number of data points seen by {@link #update(MultiLabel, MultiLabel)}. The misspelled name
     * is kept because the field is protected and may be referenced by subclasses.
     */
    protected int numDataPoitns;
    /** Maps internal label indices to external label names (used by {@link Serializer}). */
    protected LabelTranslator labelTranslator;

    /**
     * Builds the measures for a whole data set in one pass.
     *
     * @param dataSet    supplies the number of labels, the label translator, and the
     *                   ground-truth label sets
     * @param prediction predicted label set for each data point, aligned by index with
     *                   {@code dataSet.getMultiLabels()}
     * @throws IllegalArgumentException if the data set has no labels or no data points, or if
     *                                  the number of predictions differs from the number of
     *                                  ground-truth label sets
     */
    public LabelBasedMeasures(MultiLabelClfDataSet dataSet, MultiLabel[] prediction) {
        this(dataSet.getNumClasses());
        this.labelTranslator = dataSet.getLabelTranslator();
        this.update(dataSet.getMultiLabels(), prediction);
    }

    /**
     * Creates empty measures (all counts zero) for {@code numLabels} labels, using a default
     * label translator.
     *
     * @param numLabels number of labels; must be positive
     * @throws IllegalArgumentException if {@code numLabels <= 0}
     */
    public LabelBasedMeasures(int numLabels) {
        if (numLabels <= 0) {
            // Negative sizes would otherwise surface later as a NegativeArraySizeException.
            throw new IllegalArgumentException(
                    "initialization with non-positive number of labels: " + numLabels);
        }
        this.numLabels = numLabels;
        this.truePositives = new int[numLabels];
        this.falsePositives = new int[numLabels];
        this.trueNegatives = new int[numLabels];
        this.falseNegatives = new int[numLabels];
        this.numDataPoitns = 0;
        this.labelTranslator = LabelTranslator.newDefaultLabelTranslator(numLabels);
    }

    /**
     * Returns the precision of one label, computed by {@code Precision.precision} from that
     * label's true-positive and false-positive counts.
     *
     * @param classIndex label index in {@code [0, numLabels)}
     */
    public double precision(int classIndex) {
        return Precision.precision(this.truePositives[classIndex], this.falsePositives[classIndex]);
    }

    /**
     * Returns the recall of one label, computed by {@code Recall.recall} from that label's
     * true-positive and false-negative counts.
     *
     * @param classIndex label index in {@code [0, numLabels)}
     */
    public double recall(int classIndex) {
        return Recall.recall(this.truePositives[classIndex], this.falseNegatives[classIndex]);
    }

    /**
     * Returns the F1 score of one label, computed by {@code FMeasure.f1} from
     * {@link #precision(int)} and {@link #recall(int)}.
     *
     * @param classIndex label index in {@code [0, numLabels)}
     */
    public double f1(int classIndex) {
        double precision = this.precision(classIndex);
        double recall = this.recall(classIndex);
        return FMeasure.f1(precision, recall);
    }

    /**
     * Returns the accuracy of one label, computed by {@code Accuracy.accuracy} from all four
     * confusion-matrix counts of that label.
     *
     * @param classIndex label index in {@code [0, numLabels)}
     */
    public double accuracy(int classIndex) {
        return Accuracy.accuracy(this.truePositives[classIndex], this.trueNegatives[classIndex],
                this.falsePositives[classIndex], this.falseNegatives[classIndex]);
    }

    /**
     * Adds one data point: for every label, increments exactly one of TP/FN/FP/TN depending on
     * whether the label is in {@code label} (actual) and/or in {@code prediction}.
     *
     * @param label      ground-truth label set of the data point
     * @param prediction predicted label set of the data point
     */
    public void update(MultiLabel label, MultiLabel prediction) {
        for (int i = 0; i < this.numLabels; ++i) {
            boolean actual = label.matchClass(i);
            boolean predicted = prediction.matchClass(i);
            if (actual && predicted) {
                ++this.truePositives[i];
            } else if (actual) {
                ++this.falseNegatives[i];
            } else if (predicted) {
                ++this.falsePositives[i];
            } else {
                ++this.trueNegatives[i];
            }
        }
        // One call == one data point, so count it once (not once per label).
        ++this.numDataPoitns;
    }

    /**
     * Adds a batch of data points, pairing {@code labels[i]} with {@code predictions[i]}.
     *
     * @param labels      ground-truth label sets; must be non-empty
     * @param predictions predicted label sets; must have the same length as {@code labels}
     * @throws IllegalArgumentException if {@code labels} is empty or the lengths differ
     */
    public void update(MultiLabel[] labels, MultiLabel[] predictions) {
        if (labels.length == 0) {
            throw new IllegalArgumentException("Empty given ground truth.");
        }
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException(
                    "The lengths of ground truth and predictions should be the same: "
                            + labels.length + " vs " + predictions.length + ".");
        }
        for (int i = 0; i < labels.length; ++i) {
            this.update(labels[i], predictions[i]);
        }
    }

    /** Returns one line per label index with its precision, recall, F1 and accuracy. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("LabelBasedMeasures{");
        for (int k = 0; k < this.numLabels; ++k) {
            sb.append("label=").append(k).append("; ");
            sb.append("precision=").append(this.precision(k)).append("; ");
            sb.append("recall=").append(this.recall(k)).append("; ");
            sb.append("f1=").append(this.f1(k)).append("; ");
            sb.append("accuracy=").append(this.accuracy(k)).append("\n");
        }
        sb.append('}');
        return sb.toString();
    }

    /**
     * Jackson serializer that writes a JSON array with one object per label, in label-index
     * order. Each object holds the external label name (via the instance's label translator),
     * the four confusion-matrix counts, and the derived precision, recall, F1 and accuracy.
     */
    public static class Serializer
    extends JsonSerializer<LabelBasedMeasures> {
        public void serialize(LabelBasedMeasures labelBasedMeasures, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException, JsonProcessingException {
            LabelTranslator labelTranslator = labelBasedMeasures.labelTranslator;
            jsonGenerator.writeStartArray();
            for (int k = 0; k < labelBasedMeasures.numLabels; ++k) {
                jsonGenerator.writeStartObject();
                jsonGenerator.writeStringField("label", labelTranslator.toExtLabel(k));
                jsonGenerator.writeNumberField("TP", labelBasedMeasures.truePositives[k]);
                jsonGenerator.writeNumberField("TN", labelBasedMeasures.trueNegatives[k]);
                jsonGenerator.writeNumberField("FP", labelBasedMeasures.falsePositives[k]);
                jsonGenerator.writeNumberField("FN", labelBasedMeasures.falseNegatives[k]);
                jsonGenerator.writeNumberField("precision", labelBasedMeasures.precision(k));
                jsonGenerator.writeNumberField("recall", labelBasedMeasures.recall(k));
                jsonGenerator.writeNumberField("f1", labelBasedMeasures.f1(k));
                jsonGenerator.writeNumberField("accuracy", labelBasedMeasures.accuracy(k));
                jsonGenerator.writeEndObject();
            }
            jsonGenerator.writeEndArray();
        }
    }
}

