/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.eval;

import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.DataSetBuilder;
import edu.neu.ccs.pyramid.dataset.Density;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import java.util.stream.IntStream;
import org.apache.mahout.math.Vector;

/**
 * Per-data-point, per-class confusion codes comparing true multi-label sets with predictions.
 *
 * <p>{@link #getEntries()} is a {@link DataSet} built with {@link Density#SPARSE_RANDOM}. It has
 * one row per data point and one feature column per class. Each cell is written with one of
 * {@link #TRUE_POSITIVE}, {@link #FALSE_NEGATIVE} or {@link #FALSE_POSITIVE}. Cells where both
 * the label indicator and the prediction indicator are exactly 0.0 (true negatives) are never
 * written. They presumably read back as 0 from the sparse data set (TODO: confirm against the
 * {@code DataSet} implementation).
 */
public class MLConfusionMatrix {
    /** Cell code: the class is in the true label set and in the prediction. */
    public static final double TRUE_POSITIVE = 1.0;

    /** Cell code: the class is in the true label set but not in the prediction. */
    public static final double FALSE_NEGATIVE = 2.0;

    /**
     * Cell code: the class is in the prediction but not in the true label set. It is also
     * written for any indicator pair that is neither a true positive, a false negative, nor
     * (0.0, 0.0).
     */
    public static final double FALSE_POSITIVE = 3.0;

    private final int numClasses;
    private final int numDataPoints;
    private final DataSet entries;

    /**
     * Returns the number of classes, i.e. the number of feature columns in the entries.
     *
     * @return the number of classes
     */
    public int getNumClasses() {
        return this.numClasses;
    }

    /**
     * Returns the internal confusion-code data set.
     *
     * <p>The returned object is the live internal instance, not a copy. Modifying it modifies
     * this matrix.
     *
     * @return the data set of confusion codes
     */
    public DataSet getEntries() {
        return this.entries;
    }

    /**
     * Returns the number of data points, i.e. the number of rows in the entries.
     *
     * @return the number of data points
     */
    public int getNumDataPoints() {
        return this.numDataPoints;
    }

    /**
     * Builds the confusion codes from parallel arrays of true labels and predictions.
     *
     * <p>Each {@link MultiLabel} is converted with {@code toVector(numClasses)}. Entries are
     * compared by exact equality with 1.0 and 0.0, which presumes indicator vectors (TODO
     * confirm).
     *
     * @param numClasses  number of classes; used as the vector length and column count
     * @param trueLabels  true label set for each data point
     * @param predictions predicted label set for each data point, index-aligned with
     *                    {@code trueLabels}
     * @throws IllegalArgumentException if the two arrays have different lengths
     */
    public MLConfusionMatrix(int numClasses, MultiLabel[] trueLabels, MultiLabel[] predictions) {
        // Fail fast. Without this check, shorter predictions throw deep inside the loop and
        // longer predictions are silently truncated.
        if (predictions.length != trueLabels.length) {
            throw new IllegalArgumentException(
                    "trueLabels and predictions must have the same length, but got "
                            + trueLabels.length + " and " + predictions.length);
        }
        this.numClasses = numClasses;
        this.numDataPoints = trueLabels.length;
        this.entries = DataSetBuilder.getBuilder()
                .numDataPoints(this.numDataPoints)
                .numFeatures(numClasses)
                .density(Density.SPARSE_RANDOM)
                .build();
        for (int i = 0; i < this.numDataPoints; ++i) {
            Vector labelVector = trueLabels[i].toVector(numClasses);
            Vector predVector = predictions[i].toVector(numClasses);
            for (int l = 0; l < numClasses; ++l) {
                double labelMatch = labelVector.get(l);
                double predMatch = predVector.get(l);
                if (labelMatch == 1.0 && predMatch == 1.0) {
                    this.entries.setFeatureValue(i, l, TRUE_POSITIVE);
                } else if (labelMatch == 1.0 && predMatch == 0.0) {
                    this.entries.setFeatureValue(i, l, FALSE_NEGATIVE);
                } else if (labelMatch != 0.0 || predMatch != 0.0) {
                    this.entries.setFeatureValue(i, l, FALSE_POSITIVE);
                }
                // Otherwise this is a true negative: left unset to keep the data set sparse.
            }
        }
    }

    /**
     * Predicts every data point of {@code dataSet} with {@code classifier} and compares the
     * results with the data set's own labels.
     *
     * @param classifier classifier whose {@code predict(dataSet)} output is evaluated
     * @param dataSet    data set supplying the class count and true labels
     */
    public MLConfusionMatrix(MultiLabelClassifier classifier, MultiLabelClfDataSet dataSet) {
        this(dataSet.getNumClasses(), dataSet.getMultiLabels(), classifier.predict(dataSet));
    }

    /**
     * Compares precomputed predictions with the labels of {@code dataSet}.
     *
     * @param dataSet     data set supplying the class count and true labels
     * @param predictions predicted label sets, index-aligned with the data set's labels
     * @throws IllegalArgumentException if the prediction count differs from the label count
     */
    public MLConfusionMatrix(MultiLabelClfDataSet dataSet, MultiLabel[] predictions) {
        this(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    }
}

