/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.classification.Classifier;
import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.classification.logistic_regression.Weights;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.eval.Entropy;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM;
import edu.neu.ccs.pyramid.util.ArgSort;
import edu.neu.ccs.pyramid.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.ojalgo.access.Access2D;
import org.ojalgo.matrix.BasicMatrix;
import org.ojalgo.matrix.PrimitiveMatrix;

/**
 * Diagnostic helpers for inspecting a trained {@link CBM}: the weights of its per-component,
 * per-label binary classifiers, its mixture components, and its predictions for single instances.
 * All methods are static; several of them print to {@code System.out} rather than returning a value.
 */
public class CBMInspector {
    /** Factory for the dense ojalgo matrices built in {@link #covariance}. */
    private static final BasicMatrix.Factory<PrimitiveMatrix> MATRIX_FACTORY = PrimitiveMatrix.FACTORY;

    /**
     * Lists the labels whose probability under {@code cbm.predictClassProbs(vector)} is at least
     * {@code probabilityThreshold}, most probable first (ties keep ascending label order).
     *
     * @param cbm                  the model to query
     * @param vector               feature vector of a single instance
     * @param probabilityThreshold minimum probability (inclusive) for a label to be listed
     * @return entries of the form {@code labelIndex:probability} joined by {@code ", "};
     *         the empty string when no label reaches the threshold
     */
    public static String topLabels(CBM cbm, Vector vector, double probabilityThreshold) {
        double[] marginals = cbm.predictClassProbs(vector);
        Comparator<Pair<Integer, Double>> byProbability = Comparator.comparing(Pair::getSecond);
        return IntStream.range(0, cbm.getNumClasses())
                .mapToObj(l -> new Pair<Integer, Double>(l, marginals[l]))
                .filter(pair -> pair.getSecond() >= probabilityThreshold)
                .sorted(byProbability.reversed())
                .map(pair -> pair.getFirst() + ":" + pair.getSecond())
                .collect(Collectors.joining(", "));
    }

    /**
     * Averages, over all mixture components, the full weight vector
     * ({@code getWeights().getAllWeights()}) of the binary classifier for {@code label}.
     * Every binary classifier is cast to {@link LogisticRegression}. The classifier of
     * component 0 / label 0 sizes the result, so all of them are assumed to share one layout.
     *
     * @param bmm   the model to inspect
     * @param label internal label index
     * @return the averaged weights wrapped as a 2-class {@link Weights} object
     * @throws ClassCastException if a binary classifier is not a {@link LogisticRegression}
     */
    public static Weights getMean(CBM bmm, int label) {
        Classifier.ProbabilityEstimator[][] estimators = bmm.getBinaryClassifiers();
        int numClusters = bmm.getNumComponents();
        LogisticRegression reference = (LogisticRegression) estimators[0][0];
        int length = reference.getWeights().getAllWeights().size();
        int numFeatures = reference.getNumFeatures();
        // Vector#plus / #divide return a new Vector, so the accumulator must be typed as Vector.
        Vector mean = new DenseVector(length);
        for (int k = 0; k < numClusters; k++) {
            mean = mean.plus(((LogisticRegression) estimators[k][label]).getWeights().getAllWeights());
        }
        mean = mean.divide((double) numClusters);
        return new Weights(2, numFeatures, mean);
    }

    /**
     * Averages {@link #distanceFromMean(CBM, int)} over all labels.
     *
     * @param bmm the model to inspect
     * @return the mean per-label distance
     * @throws java.util.NoSuchElementException if the model has no classes
     */
    public static double distanceFromMean(CBM bmm) {
        int numClasses = bmm.getNumClasses();
        return IntStream.range(0, numClasses)
                .mapToDouble(l -> distanceFromMean(bmm, l))
                .average()
                .getAsDouble();
    }

    /**
     * For one label, computes the average Euclidean (L2) distance between each component's
     * class-1 weight vector (bias excluded) and the mean of those vectors across components.
     *
     * @param bmm   the model to inspect; its binary classifiers are cast to {@link LogisticRegression}
     * @param label internal label index
     * @return average distance from the cross-component mean weight vector
     */
    public static double distanceFromMean(CBM bmm, int label) {
        Classifier.ProbabilityEstimator[][] logistics = bmm.getBinaryClassifiers();
        int numClusters = bmm.getNumComponents();
        int numFeatures = ((LogisticRegression) logistics[0][0]).getNumFeatures();
        // Fetch each component's weights once; they are needed for both the mean and the distances.
        Vector[] positiveVectors = new Vector[numClusters];
        Vector average = new DenseVector(numFeatures);
        for (int k = 0; k < numClusters; k++) {
            positiveVectors[k] = ((LogisticRegression) logistics[k][label]).getWeights().getWeightsWithoutBiasForClass(1);
            average = average.plus(positiveVectors[k]);
        }
        average = average.divide((double) numClusters);
        double dis = 0.0;
        for (int k = 0; k < numClusters; k++) {
            dis += positiveVectors[k].minus(average).norm(2.0);
        }
        return dis / (double) numClusters;
    }

    /**
     * For each mixture component, sums over all data points the component probability predicted
     * by the model's multi-class classifier. The sums are grouped by the data point's label set.
     *
     * @param bmm     the model to inspect
     * @param dataSet data whose rows and label sets are used
     * @return one map per component (index = component), from label set to accumulated probability
     */
    public static List<Map<MultiLabel, Double>> visualizeClusters(CBM bmm, MultiLabelClfDataSet dataSet) {
        int numClusters = bmm.getNumComponents();
        List<Map<MultiLabel, Double>> list = new ArrayList<>(numClusters);
        for (int k = 0; k < numClusters; k++) {
            list.add(new HashMap<>());
        }
        MultiLabel[] multiLabels = dataSet.getMultiLabels();
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            double[] clusterProbs = bmm.getMultiClassClassifier().predictClassProbs(dataSet.getRow(i));
            for (int k = 0; k < numClusters; k++) {
                list.get(k).merge(multiLabels[i], clusterProbs[k], Double::sum);
            }
        }
        return list;
    }

    /**
     * Prints a diagnostic breakdown of the model's prediction for one instance to {@code System.out}.
     *
     * Output, in order:
     * <ul>
     *   <li>"interesting case !" when {@code cbm.predict} differs from thresholding (>= 0.5) the
     *       label probabilities of the component with the largest proportion;</li>
     *   <li>"very interesting case !" when it also differs from the same thresholding applied to
     *       the runner-up component (only checked when there are at least two components);</li>
     *   <li>component proportions, label probabilities and labels above 0.5, all ordered by
     *       decreasing proportion;</li>
     *   <li>the perplexity 2^H of the proportions;</li>
     *   <li>label probabilities per component in the original component order.</li>
     * </ul>
     *
     * @param cbm             the model to inspect
     * @param vector          feature vector of a single instance
     * @param labelTranslator maps internal label indices to external label names
     */
    public static void visualizePrediction(CBM cbm, Vector vector, LabelTranslator labelTranslator) {
        int numClusters = cbm.getNumComponents();
        double[] proportions = cbm.getMultiClassClassifier().predictClassProbs(vector);
        double[][] probabilities = componentLabelProbs(cbm, vector);
        // Component indices ordered from largest to smallest proportion.
        int[] sorted = ArgSort.argSortDescending(proportions);

        MultiLabel trivialPred = labelsAtLeastHalf(probabilities[sorted[0]]);
        MultiLabel predicted = cbm.predict(vector);
        if (!predicted.equals(trivialPred)) {
            System.out.println("interesting case !");
            // A runner-up component exists only with two or more components;
            // indexing sorted[1] unconditionally would throw for a single-component model.
            if (sorted.length > 1 && !predicted.equals(labelsAtLeastHalf(probabilities[sorted[1]]))) {
                System.out.println("very interesting case !");
            }
        }

        double[] sortedProportions = new double[numClusters];
        for (int rank = 0; rank < sorted.length; rank++) {
            sortedProportions[rank] = proportions[sorted[rank]];
        }
        System.out.println("proportions = " + Arrays.toString(sortedProportions));
        for (int rank = 0; rank < sorted.length; rank++) {
            System.out.println("prob" + (rank + 1) + " = " + Arrays.toString(probabilities[sorted[rank]]));
        }
        for (int rank = 0; rank < sorted.length; rank++) {
            double[] probs = probabilities[sorted[rank]];
            List<String> labels = new ArrayList<>();
            // NOTE(review): strict "> 0.5" here, whereas the predictions above use ">= 0.5".
            for (int l = 0; l < probs.length; l++) {
                if (probs[l] > 0.5) {
                    labels.add("\"" + labelTranslator.toExtLabel(l) + "\"");
                }
            }
            System.out.println("labels" + (rank + 1) + " = " + labels);
        }
        System.out.println("perplexity=" + Math.pow(2.0, Entropy.entropy2Based(proportions)));
        for (int k = 0; k < numClusters; k++) {
            System.out.println("cluster" + (k + 1) + " = " + Arrays.toString(probabilities[k]));
        }
    }

    /**
     * Prints the 20 label pairs with the largest absolute correlation for one instance.
     * Equivalent to {@code covariance(cbm, vector, labelTranslator, 20)}.
     *
     * @param cbm             the model to inspect
     * @param vector          feature vector of a single instance
     * @param labelTranslator maps internal label indices to external label names
     */
    public static void covariance(CBM cbm, Vector vector, LabelTranslator labelTranslator) {
        covariance(cbm, vector, labelTranslator, 20);
    }

    /**
     * Computes the label covariance implied by the model for one instance and prints the
     * {@code topK} label pairs with the largest absolute correlation.
     *
     * With proportions pi_k and per-component label probabilities p_k, the covariance is
     * {@code sum_k pi_k * (diag(p_k * (1 - p_k)) + p_k p_k^T) - m m^T}, where
     * {@code m = sum_k pi_k p_k}. Pairs whose correlation is not finite are skipped; this happens
     * when a label has zero (or, through rounding, negative) variance. NaN compares greater than
     * every finite double, so those pairs would otherwise occupy the top of the ranking.
     *
     * @param cbm             the model to inspect
     * @param vector          feature vector of a single instance
     * @param labelTranslator maps internal label indices to external label names
     * @param topK            maximum number of pairs to print; must be non-negative
     * @throws IllegalArgumentException if {@code topK} is negative
     */
    public static void covariance(CBM cbm, Vector vector, LabelTranslator labelTranslator, int topK) {
        if (topK < 0) {
            throw new IllegalArgumentException("topK must be non-negative, got " + topK);
        }
        int numClusters = cbm.getNumComponents();
        int numClasses = cbm.getNumClasses();
        double[] proportions = cbm.getMultiClassClassifier().predictClassProbs(vector);
        double[][] probabilities = componentLabelProbs(cbm, vector);

        double[] meanValues = new double[numClasses];
        for (int l = 0; l < numClasses; l++) {
            for (int k = 0; k < numClusters; k++) {
                meanValues[l] += proportions[k] * probabilities[k][l];
            }
        }
        BasicMatrix mean = columnVector(meanValues);

        BasicMatrix covariance = MATRIX_FACTORY.makeZero(numClasses, numClasses);
        for (int k = 0; k < numClusters; k++) {
            BasicMatrix muK = columnVector(probabilities[k]);
            BasicMatrix sigmaK = bernoulliVarianceDiagonal(probabilities[k]);
            BasicMatrix contribution = sigmaK.add(muK.multiply(muK.transpose())).multiply(proportions[k]);
            covariance = covariance.add(contribution);
        }
        covariance = covariance.subtract(mean.multiply(mean.transpose()));

        List<Pair<String, Double>> pairs = new ArrayList<>();
        for (int l = 0; l < numClasses; l++) {
            double sdL = Math.sqrt(covariance.get(l, l).doubleValue());
            for (int j = 0; j < l; j++) {
                double sdJ = Math.sqrt(covariance.get(j, j).doubleValue());
                double correlation = covariance.get(l, j).doubleValue() / (sdL * sdJ);
                if (!Double.isFinite(correlation)) {
                    continue;
                }
                String name = labelTranslator.toExtLabel(l) + ", " + labelTranslator.toExtLabel(j);
                pairs.add(new Pair<String, Double>(name, correlation));
            }
        }
        Comparator<Pair<String, Double>> byAbsCorrelation =
                Comparator.comparingDouble(pair -> Math.abs(pair.getSecond()));
        List<Pair<String, Double>> top = pairs.stream()
                .sorted(byAbsCorrelation.reversed())
                .limit(topK)
                .collect(Collectors.toList());
        System.out.println(top);
    }

    /** Returns P(label l = 1) from the binary classifier of component k, indexed [k][l]. */
    private static double[][] componentLabelProbs(CBM cbm, Vector vector) {
        Classifier.ProbabilityEstimator[][] estimators = cbm.getBinaryClassifiers();
        int numClusters = cbm.getNumComponents();
        int numClasses = cbm.getNumClasses();
        double[][] probs = new double[numClusters][numClasses];
        for (int k = 0; k < numClusters; k++) {
            for (int l = 0; l < numClasses; l++) {
                probs[k][l] = estimators[k][l].predictClassProb(vector, 1);
            }
        }
        return probs;
    }

    /** Builds the label set of every index whose probability is >= 0.5. */
    private static MultiLabel labelsAtLeastHalf(double[] probs) {
        MultiLabel multiLabel = new MultiLabel();
        for (int l = 0; l < probs.length; l++) {
            if (probs[l] >= 0.5) {
                multiLabel.addLabel(l);
            }
        }
        return multiLabel;
    }

    /** Wraps {@code values} as an n x 1 column matrix. */
    private static BasicMatrix columnVector(double[] values) {
        Access2D.Builder<PrimitiveMatrix> builder = MATRIX_FACTORY.getBuilder(values.length, 1);
        for (int i = 0; i < values.length; i++) {
            builder.set(i, 0, values[i]);
        }
        return builder.build();
    }

    /** Builds the diagonal matrix diag(p * (1 - p)) of independent Bernoulli variances. */
    private static BasicMatrix bernoulliVarianceDiagonal(double[] probs) {
        Access2D.Builder<PrimitiveMatrix> builder = MATRIX_FACTORY.getBuilder(probs.length, probs.length);
        for (int l = 0; l < probs.length; l++) {
            builder.set(l, l, probs[l] * (1.0 - probs[l]));
        }
        return builder.build();
    }
}

