/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.eval.InstanceAverage;
import edu.neu.ccs.pyramid.eval.MLConfusionMatrix;
import edu.neu.ccs.pyramid.util.Sampling;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;

/**
 * Static helpers that enumerate every label subset over {@code n} labels and
 * build either a pairwise loss matrix or a random probability vector over
 * those subsets.
 *
 * <p>A subset is identified by an index in {@code [0, 2^n)}. The index is
 * written as an {@code n}-character binary string (left-padded with zeros) and
 * label {@code k} belongs to the subset when character {@code k} of that string
 * is {@code '1'}; label 0 therefore corresponds to the most significant bit.
 */
public class LossMatrixGenerator {
    /** Largest label count for which {@code 2^n} still fits in a positive {@code int}. */
    private static final int MAX_LABELS = 30;

    /**
     * Builds the {@code 2^n x 2^n} loss matrix for the named loss.
     *
     * <p>Entry {@code (i, j)} is computed from an {@link MLConfusionMatrix}
     * built with subset {@code i} as the single true label set and subset
     * {@code j} as the single predicted label set.
     *
     * @param n        number of labels; must be in {@code [0, 30]}
     * @param lossName one of {@code hamming}, {@code overlap}, {@code accuracy},
     *                 {@code precision}, {@code recall}, {@code f1}
     *                 (matched case-insensitively, independent of the default locale)
     * @return a dense matrix of loss values
     * @throws IllegalArgumentException if {@code n} is out of range or the loss name is unknown
     * @throws NullPointerException     if {@code lossName} is {@code null}
     */
    public static Matrix matrix(int n, String lossName) {
        int size = LossMatrixGenerator.labelSetCount(n);
        // Locale.ROOT keeps the match stable under locales with special casing
        // rules (e.g. Turkish dotless i would break "PRECISION").
        String normalizedLoss = lossName.toLowerCase(Locale.ROOT);

        // Encode every subset index once instead of once per matrix cell.
        String[] encodings = new String[size];
        for (int index = 0; index < size; ++index) {
            encodings[index] = LossMatrixGenerator.toBinary(index, n);
        }

        double[][] losses = new double[size][size];
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < size; ++j) {
                // Fresh MultiLabel instances per cell, exactly as before, so no
                // label object is shared between confusion matrices.
                MultiLabel[] trueLabels = new MultiLabel[]{LossMatrixGenerator.toML(encodings[i])};
                MultiLabel[] predicted = new MultiLabel[]{LossMatrixGenerator.toML(encodings[j])};
                MLConfusionMatrix confusionMatrix = new MLConfusionMatrix(n, trueLabels, predicted);
                InstanceAverage instanceAverage = new InstanceAverage(confusionMatrix);
                losses[i][j] = LossMatrixGenerator.lossValue(normalizedLoss, instanceAverage, n);
            }
        }
        return new DenseMatrix(losses);
    }

    /**
     * Maps an already lower-cased loss name to its value for one instance.
     *
     * <p>Hamming loss is reported as the instance-average Hamming loss scaled by
     * {@code n}; every other loss is {@code 1 - measure}.
     *
     * @throws IllegalArgumentException if the name is not a supported loss
     */
    private static double lossValue(String normalizedLoss, InstanceAverage instanceAverage, int n) {
        switch (normalizedLoss) {
            case "hamming":
                return instanceAverage.getHammingLoss() * (double) n;
            case "overlap":
                return 1.0 - instanceAverage.getOverlap();
            case "accuracy":
                return 1.0 - instanceAverage.getAccuracy();
            case "precision":
                return 1.0 - instanceAverage.getPrecision();
            case "recall":
                return 1.0 - instanceAverage.getRecall();
            case "f1":
                return 1.0 - instanceAverage.getF1();
            default:
                throw new IllegalArgumentException("unknown loss");
        }
    }

    /**
     * Returns {@code 2^numLabels}, rejecting counts that are negative or whose
     * power of two would not fit in a positive {@code int}.
     *
     * @throws IllegalArgumentException if {@code numLabels} is outside {@code [0, 30]}
     */
    private static int labelSetCount(int numLabels) {
        if (numLabels < 0 || numLabels > MAX_LABELS) {
            throw new IllegalArgumentException(
                    "number of labels must be between 0 and " + MAX_LABELS + ", got " + numLabels);
        }
        return 1 << numLabels;
    }

    /**
     * Renders {@code number} in binary, left-padded with zeros to at least
     * {@code length} characters. Longer representations are returned unpadded.
     */
    private static String toBinary(int number, int length) {
        String digits = Integer.toBinaryString(number);
        StringBuilder sb = new StringBuilder();
        for (int pad = length - digits.length(); pad > 0; --pad) {
            sb.append('0');
        }
        return sb.append(digits).toString();
    }

    /**
     * Converts a binary string into a label set: label {@code i} is added when
     * character {@code i} is {@code '1'}.
     */
    private static MultiLabel toML(String str) {
        MultiLabel multiLabel = new MultiLabel();
        for (int i = 0; i < str.length(); ++i) {
            if (str.charAt(i) == '1') {
                multiLabel.addLabel(i);
            }
        }
        return multiLabel;
    }

    /**
     * Draws a random vector of {@code 2^numLabels} entries.
     *
     * <p>Every entry except the last is obtained from
     * {@code Sampling.doubleUniform(0.0, remaining)}, where {@code remaining}
     * is one minus the sum of the entries drawn so far (presumably a uniform
     * draw on that range; confirm against {@code Sampling}). The last entry
     * takes whatever mass is left, so the entries add up to 1 up to
     * floating-point rounding.
     *
     * @param numLabels number of labels; must be in {@code [0, 30]}
     * @return the sampled values, one per label subset, in index order
     * @throws IllegalArgumentException if {@code numLabels} is out of range
     */
    public static List<Double> sampleDistribution(int numLabels) {
        int size = LossMatrixGenerator.labelSetCount(numLabels);
        List<Double> distribution = new ArrayList<Double>(size);
        double used = 0.0;
        for (int i = 0; i < size; ++i) {
            double remaining = 1.0 - used;
            double prob = i == size - 1 ? remaining : Sampling.doubleUniform(0.0, remaining);
            distribution.add(prob);
            used += prob;
        }
        return distribution;
    }
}

