/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.eval;

import edu.neu.ccs.pyramid.classification.Classifier;
import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM;
import edu.neu.ccs.pyramid.util.ArgSort;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.Vector;

/**
 * Static helpers for computing Kullback-Leibler divergences
 * {@code KL(p || q) = sum_i p_i * (log p_i - log q_i)} between explicit distributions and between
 * empirical targets and classifier predictions. All logarithms are natural ({@link Math#log}).
 */
public class KLDivergence {
    private static final Logger logger = LogManager.getLogger();

    /** Minimum number of occurrences of a z pattern before {@link #kl_conditional} prints diagnostics for it. */
    private static final int OCCURRENCE_THRESHOLD = 10;

    /** Marginal probabilities not strictly greater than this value are not printed by {@link #kl_conditional}. */
    private static final double MARGINAL_THRESHOLD = 0.01;

    /**
     * Computes {@code KL(trueDistribution || estimatedDistribution)}.
     *
     * <p>Entries where the true probability is 0 contribute nothing. If the estimate is 0 at an
     * index where the true probability is non-zero, the result is {@link Double#POSITIVE_INFINITY}
     * and both distributions are logged at debug level. Iteration runs over
     * {@code trueDistribution.length}; the estimate is assumed to be at least as long.
     *
     * @param trueDistribution      the reference distribution p
     * @param estimatedDistribution the approximating distribution q
     * @return the divergence, possibly {@code +Infinity}
     * @throws RuntimeException if the computed divergence is NaN (e.g. negative inputs)
     */
    public static double kl(double[] trueDistribution, double[] estimatedDistribution) {
        double r = 0.0;
        for (int i = 0; i < trueDistribution.length; ++i) {
            double p = trueDistribution[i];
            // 0 * log(0 / q) is taken to be 0 by convention
            if (p == 0.0) {
                continue;
            }
            double q = estimatedDistribution[i];
            if (q == 0.0) {
                // positive mass where the estimate has none: the divergence is unbounded
                r = Double.POSITIVE_INFINITY;
                break;
            }
            r += p * (Math.log(p) - Math.log(q));
        }
        if (Double.isInfinite(r) && logger.isDebugEnabled()) {
            logger.debug("true distribution = " + Arrays.toString(trueDistribution));
            logger.debug("estimated distribution = " + Arrays.toString(estimatedDistribution));
        }
        if (Double.isNaN(r)) {
            throw new RuntimeException("KL divergence between " + Arrays.toString(trueDistribution) + " and " + Arrays.toString(estimatedDistribution) + " is NaN");
        }
        return r;
    }

    /**
     * Computes {@code KL(p || q)} when q is already given in log space.
     *
     * <p>Entries with {@code p_i == 0} are skipped. Unlike {@link #kl(double[], double[])}, no NaN
     * check is performed. A {@code -Infinity} log-probability yields {@code +Infinity}.
     *
     * @param targetDistribution       the reference distribution p
     * @param logEstimatedDistribution natural-log probabilities of the approximating distribution q
     * @return the divergence
     */
    public static double klGivenPLogQ(double[] targetDistribution, double[] logEstimatedDistribution) {
        double r = 0.0;
        for (int i = 0; i < targetDistribution.length; ++i) {
            double p = targetDistribution[i];
            if (p == 0.0) {
                continue;
            }
            r += p * (Math.log(p) - logEstimatedDistribution[i]);
        }
        return r;
    }

    /**
     * Computes the divergence from {@code targetDistribution} to the estimator's predicted class
     * distribution for one input.
     *
     * @param estimator          model whose {@code predictLogClassProbs} supplies log q
     * @param vector             input features
     * @param targetDistribution the reference distribution p over classes
     * @return {@code KL(target || predicted)}
     */
    public static double kl(Classifier.ProbabilityEstimator estimator, Vector vector, double[] targetDistribution) {
        double[] logEstimation = estimator.predictLogClassProbs(vector);
        return KLDivergence.klGivenPLogQ(targetDistribution, logEstimation);
    }

    /**
     * Computes the weighted sum, over all data points, of the per-point divergence from the target
     * distribution to the estimator's prediction.
     *
     * @param estimator           model producing class probabilities
     * @param dataSet             inputs; row {@code n} is paired with {@code targetDistributions[n]}
     * @param targetDistributions one reference distribution per data point
     * @param weights             one weight per data point
     * @return {@code sum_n weights[n] * KL(target_n || predicted_n)}
     */
    public static double kl(Classifier.ProbabilityEstimator estimator, DataSet dataSet, double[][] targetDistributions, double[] weights) {
        double sum = 0.0;
        for (int n = 0; n < dataSet.getNumDataPoints(); ++n) {
            sum += weights[n] * KLDivergence.kl(estimator, dataSet.getRow(n), targetDistributions[n]);
        }
        return sum;
    }

    /**
     * Unweighted variant of {@link #kl(Classifier.ProbabilityEstimator, DataSet, double[][], double[])}:
     * every data point gets weight 1.
     *
     * @param estimator           model producing class probabilities
     * @param dataSet             inputs
     * @param targetDistributions one reference distribution per data point
     * @return the summed divergence over all data points
     */
    public static double kl(Classifier.ProbabilityEstimator estimator, DataSet dataSet, double[][] targetDistributions) {
        double[] weights = new double[dataSet.getNumDataPoints()];
        Arrays.fill(weights, 1.0);
        return KLDivergence.kl(estimator, dataSet, targetDistributions, weights);
    }

    /**
     * Returns {@code -sum_i log q(y_i | x_i)}: the summed negative log-probability that the model
     * assigns to each data point's observed label set.
     *
     * <p>Despite the name, this is a negative log-likelihood (cross-entropy) rather than a KL
     * divergence. It differs from the KL divergence between the empirical and model distributions
     * by scaling and a model-independent entropy term.
     *
     * @param multiLabelClassifier model providing {@code predictLogAssignmentProb}
     * @param dataSet              data points and their observed label sets
     * @return the summed negative log-probability
     */
    public static double kl(MultiLabelClassifier.AssignmentProbEstimator multiLabelClassifier, MultiLabelClfDataSet dataSet) {
        // fetch the label array once instead of once per data point
        MultiLabel[] multiLabels = dataSet.getMultiLabels();
        return -IntStream.range(0, dataSet.getNumDataPoints())
                .mapToDouble(i -> multiLabelClassifier.predictLogAssignmentProb(dataSet.getRow(i), multiLabels[i]))
                .sum();
    }

    /**
     * Computes the conditional divergence
     * {@code sum_z p^(z) * sum_y p^(y|z) * (log p^(y|z) - log q(y|z))},
     * where z is each row's feature vector re-interpreted as a {@link MultiLabel}, y is the row's
     * observed label set, and {@code p^} denotes empirical frequencies in {@code dataSet}.
     *
     * <p>Side effects: prints diagnostics to {@code System.out}. For every z pattern seen at least
     * {@value #OCCURRENCE_THRESHOLD} times, it prints the empirical vs. estimated {@code p(y|z)} and
     * {@code p(y=z|z)}. It also prints the empirical marginals and, when the classifier is a
     * {@link CBM}, the estimated marginals. For a {@link CBM}, it finally prints the weights of each
     * {@link LogisticRegression} binary classifier in the first component.
     *
     * <p>NOTE(review): feature indices are used as label indices (marginals are sized by
     * {@code getNumFeatures()}, weights are printed via the label translator), so this presumably
     * assumes features are label indicators aligned with the label space. Confirm with callers.
     *
     * @param multiLabelClassifier model providing assignment probabilities
     * @param dataSet              data whose rows encode z and whose labels give y
     * @return the conditional divergence
     */
    public static double kl_conditional(MultiLabelClassifier.AssignmentProbEstimator multiLabelClassifier, MultiLabelClfDataSet dataSet) {
        Map<MultiLabel, Integer> zCounts = new HashMap<>();
        Map<MultiLabel, Map<MultiLabel, Integer>> yGivenZCounts = new HashMap<>();
        countOccurrences(dataSet, zCounts, yGivenZCounts);
        double kl = conditionalKL(multiLabelClassifier, dataSet, zCounts, yGivenZCounts);
        printConditionalDiagnostics(multiLabelClassifier, dataSet, zCounts, yGivenZCounts);
        printLogisticRegressionWeights(multiLabelClassifier, dataSet);
        System.out.println("---");
        return kl;
    }

    /** Tallies how often each z pattern occurs, and how often each label set y occurs with each z. */
    private static void countOccurrences(MultiLabelClfDataSet dataSet, Map<MultiLabel, Integer> zCounts, Map<MultiLabel, Map<MultiLabel, Integer>> yGivenZCounts) {
        MultiLabel[] multiLabels = dataSet.getMultiLabels();
        for (int i = 0; i < dataSet.getNumDataPoints(); ++i) {
            MultiLabel z = new MultiLabel(dataSet.getRow(i));
            MultiLabel y = multiLabels[i];
            zCounts.put(z, zCounts.getOrDefault(z, 0) + 1);
            Map<MultiLabel, Integer> yCounts = yGivenZCounts.get(z);
            if (yCounts == null) {
                yCounts = new HashMap<>();
                yGivenZCounts.put(z, yCounts);
            }
            yCounts.put(y, yCounts.getOrDefault(y, 0) + 1);
        }
    }

    /** Computes the empirically weighted conditional KL divergence from the tallied counts. */
    private static double conditionalKL(MultiLabelClassifier.AssignmentProbEstimator multiLabelClassifier, MultiLabelClfDataSet dataSet, Map<MultiLabel, Integer> zCounts, Map<MultiLabel, Map<MultiLabel, Integer>> yGivenZCounts) {
        int numFeatures = dataSet.getNumFeatures();
        double kl = 0.0;
        for (Map.Entry<MultiLabel, Integer> zEntry : zCounts.entrySet()) {
            MultiLabel z = zEntry.getKey();
            int zCount = zEntry.getValue();
            Vector zVector = z.toVector(numFeatures);
            double klGivenZ = 0.0;
            for (Map.Entry<MultiLabel, Integer> yEntry : yGivenZCounts.get(z).entrySet()) {
                int yCount = yEntry.getValue();
                // yCount >= 1, so the empirical probability is strictly positive and its log is finite
                double empiricalProb = (double) yCount / zCount;
                double logEstimatedProb = multiLabelClassifier.predictLogAssignmentProb(zVector, yEntry.getKey());
                klGivenZ += empiricalProb * (Math.log(empiricalProb) - logEstimatedProb);
            }
            double zProb = (double) zCount / dataSet.getNumDataPoints();
            kl += zProb * klGivenZ;
        }
        return kl;
    }

    /** Prints per-pattern empirical vs. estimated conditionals and marginals for frequent z patterns. */
    private static void printConditionalDiagnostics(MultiLabelClassifier.AssignmentProbEstimator multiLabelClassifier, MultiLabelClfDataSet dataSet, Map<MultiLabel, Integer> zCounts, Map<MultiLabel, Map<MultiLabel, Integer>> yGivenZCounts) {
        int numFeatures = dataSet.getNumFeatures();
        for (Map.Entry<MultiLabel, Integer> zEntry : zCounts.entrySet()) {
            int zCount = zEntry.getValue();
            // nothing is printed for rare patterns, so skip their (potentially costly) predictions
            if (zCount < OCCURRENCE_THRESHOLD) {
                continue;
            }
            MultiLabel z = zEntry.getKey();
            Vector zVector = z.toVector(numFeatures);
            double[] empiricalMarginals = new double[numFeatures];
            for (Map.Entry<MultiLabel, Integer> yEntry : yGivenZCounts.get(z).entrySet()) {
                MultiLabel y = yEntry.getKey();
                int yCount = yEntry.getValue();
                double estimatedProb = multiLabelClassifier.predictAssignmentProb(zVector, y);
                double empiricalProb = (double) yCount / zCount;
                System.out.println("#z:" + zCount + ",z=" + z.toStringWithExtLabels(dataSet.getLabelTranslator()) + "->{" + y.toStringWithExtLabels(dataSet.getLabelTranslator()) + "},#y:" + yCount + ",p_y|z_empirical:" + empiricalProb + ",p_y|z_estimated:" + estimatedProb);
                for (int k = 0; k < numFeatures; ++k) {
                    if (y.matchClass(k)) {
                        empiricalMarginals[k] += yCount;
                    }
                }
            }
            double estimatedProbZGivenZ = multiLabelClassifier.predictAssignmentProb(zVector, z);
            System.out.println("p(y=z|z)=" + estimatedProbZGivenZ);
            // per-label marginals are only available from a CBM; other models skip this section
            if (multiLabelClassifier instanceof CBM) {
                System.out.println("p_y|z_estimated marginals are: ");
                double[] estimatedMarginals = ((CBM) multiLabelClassifier).predictClassProbs(zVector);
                printMarginalsAboveThreshold(dataSet, estimatedMarginals);
            }
            System.out.println("p_y|z_empirical marginals are: ");
            for (int k = 0; k < numFeatures; ++k) {
                empiricalMarginals[k] /= zCount;
            }
            printMarginalsAboveThreshold(dataSet, empiricalMarginals);
        }
    }

    /** Prints {@code label:value} in descending value order for entries above {@link #MARGINAL_THRESHOLD}. */
    private static void printMarginalsAboveThreshold(MultiLabelClfDataSet dataSet, double[] marginals) {
        int[] order = ArgSort.argSortDescending(marginals);
        for (int index : order) {
            if (marginals[index] > MARGINAL_THRESHOLD) {
                System.out.println(dataSet.getLabelTranslator().toExtLabel(index) + ":" + marginals[index]);
            }
        }
    }

    /** For a CBM, prints bias and sorted weights of each logistic-regression binary classifier in component 0. */
    private static void printLogisticRegressionWeights(MultiLabelClassifier.AssignmentProbEstimator multiLabelClassifier, MultiLabelClfDataSet dataSet) {
        System.out.println("LRs for each label:");
        if (!(multiLabelClassifier instanceof CBM)) {
            return;
        }
        Classifier.ProbabilityEstimator[] estimators = ((CBM) multiLabelClassifier).getBinaryClassifiers()[0];
        for (int i = 0; i < estimators.length; ++i) {
            // other binary classifier types have no weight vector to report
            if (!(estimators[i] instanceof LogisticRegression)) {
                continue;
            }
            System.out.println("LR:" + dataSet.getLabelTranslator().toExtLabel(i));
            LogisticRegression lr = (LogisticRegression) estimators[i];
            Vector weightVector = lr.getWeights().getWeightsWithoutBiasForClass(1);
            double[] weights = new double[weightVector.size()];
            for (int j = 0; j < weights.length; ++j) {
                weights[j] = weightVector.get(j);
            }
            System.out.println("bias:" + lr.getWeights().getBiasForClass(1));
            for (int featureIndex : ArgSort.argSortDescending(weights)) {
                System.out.println(dataSet.getLabelTranslator().toExtLabel(featureIndex) + ":" + weights[featureIndex]);
            }
        }
    }
}

