/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.AugmentedLR;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBMS;
import edu.neu.ccs.pyramid.util.BernoulliDistribution;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
import org.apache.mahout.math.Vector;

/**
 * Conditional Bernoulli-mixture distribution over label sets for one input {@code x}.
 *
 * <p>The distribution is a mixture of {@code numComponents} components. Component {@code k} has log
 * mixing weight {@code logProportions[k]}. Inside a component the labels are scored independently:
 * {@code logClassProbs[k][l][1]} is the log probability that label {@code l} is present and
 * {@code logClassProbs[k][l][0]} that it is absent.
 *
 * <p>When built over an explicit support (the {@code List<MultiLabel>} constructor), per-label
 * probabilities are NOT computed and {@code logClassProbs} stays {@code null}. Instead each
 * component stores a normalized log distribution over the label sets in the support, and only
 * label sets from that support can be scored.
 */
public class BMDistribution {
    int numLabels;
    int numComponents;
    /** Log mixing weight of each component, indexed by component. */
    double[] logProportions;
    /** {@code [component][label][0 = absent, 1 = present]} log probabilities; {@code null} in support mode. */
    double[][][] logClassProbs;
    /** Candidate label sets; only set in support mode. */
    private List<MultiLabel> support;
    /** {@code [component][supportIndex]} log probabilities; only set in support mode. */
    double[][] normalizedLogProbs;
    /** {@code true} when this distribution was built over an explicit support. */
    private boolean ifSupport;

    /**
     * Builds the full mixture for {@code x}, using every component and every label of {@code cbm}.
     *
     * @param cbm trained model supplying the component classifier and the per-label classifiers
     * @param x   input feature vector
     */
    BMDistribution(CBM cbm, Vector x) {
        this.numLabels = cbm.numLabels;
        this.numComponents = cbm.numComponents;
        this.logProportions = cbm.multiClassClassifier.predictLogClassProbs(x);
        // Inner arrays are supplied by the classifiers, so only the outer two levels are allocated.
        this.logClassProbs = new double[this.numComponents][this.numLabels][];
        for (int k = 0; k < this.numComponents; k++) {
            for (int l = 0; l < this.numLabels; l++) {
                this.logClassProbs[k][l] = cbm.binaryClassifiers[k][l].predictLogClassProbs(x);
            }
        }
    }

    /**
     * Returns the log mixing weight of each component.
     *
     * @return the internal array (not a copy)
     */
    public double[] getLogProportions() {
        return this.logProportions;
    }

    /**
     * Returns the per-component, per-label log probabilities.
     *
     * @return the internal {@code [component][label][0|1]} array (not a copy), or {@code null}
     *     when this distribution was built over an explicit support
     */
    public double[][][] getLogClassProbs() {
        return this.logClassProbs;
    }

    /**
     * Builds a truncated mixture for {@code x} that keeps only the components whose mixing
     * probability is at least {@code threshold}.
     *
     * <p>The comparison is done in log space, so a {@code threshold} of 0 keeps every component.
     * The kept log proportions are copied as-is and are NOT renormalized.
     * NOTE(review): {@link #logProbability} therefore returns the log of the truncated (sub-unit)
     * mass. Also, if no component reaches the threshold, {@code numComponents} becomes 0. Confirm
     * that both behaviors are intended by callers.
     *
     * @param cbm       trained model supplying the component classifier and the per-label classifiers
     * @param x         input feature vector
     * @param threshold minimum mixing probability a component needs in order to be kept
     */
    BMDistribution(CBM cbm, Vector x, double threshold) {
        this.numLabels = cbm.numLabels;
        double[] allLogProportions = cbm.multiClassClassifier.predictLogClassProbs(x);
        double logThreshold = Math.log(threshold);
        // Original component indices that pass the threshold, in increasing order.
        int[] activeComponents = IntStream.range(0, allLogProportions.length)
                .filter(k -> allLogProportions[k] >= logThreshold)
                .toArray();
        this.numComponents = activeComponents.length;
        this.logProportions = new double[this.numComponents];
        this.logClassProbs = new double[this.numComponents][this.numLabels][];
        for (int i = 0; i < this.numComponents; i++) {
            int k = activeComponents[i];
            this.logProportions[i] = allLogProportions[k];
            for (int l = 0; l < this.numLabels; l++) {
                this.logClassProbs[i][l] = cbm.binaryClassifiers[k][l].predictLogClassProbs(x);
            }
        }
    }

    /**
     * Builds the full mixture for {@code x} from a {@link CBMS} model. Each label has one augmented
     * classifier that yields log probabilities for every component at once.
     *
     * @param cbms trained model
     * @param x    input feature vector
     */
    BMDistribution(CBMS cbms, Vector x) {
        this.numLabels = cbms.numLabels;
        this.numComponents = cbms.numComponents;
        this.logProportions = cbms.multiClassClassifier.predictLogClassProbs(x);
        this.logClassProbs = new double[this.numComponents][this.numLabels][2];
        for (int l = 0; l < this.numLabels; l++) {
            AugmentedLR augmentedLR = cbms.getBinaryClassifiers()[l];
            // lp is indexed [component][0|1] for label l.
            double[][] lp = augmentedLR.logAugmentedProbs(x);
            for (int k = 0; k < this.numComponents; k++) {
                // Copy the values rather than aliasing lp's rows.
                this.logClassProbs[k][l][0] = lp[k][0];
                this.logClassProbs[k][l][1] = lp[k][1];
            }
        }
    }

    /**
     * Builds a mixture restricted to an explicit set of candidate label sets.
     *
     * <p>For every component, each candidate is scored by summing the per-label scores of its
     * present labels. The scores come from {@code predictClassScores(x)[1]}, presumably the
     * positive-class score; TODO confirm. The candidate scores are then normalized with
     * {@code MathUtil.softmax}, and their logs are stored in {@code normalizedLogProbs}.
     * Per-label probabilities ({@code logClassProbs}) are not computed in this mode.
     *
     * @param cbm     trained model; its per-label classifiers must be {@link LogisticRegression}
     * @param x       input feature vector
     * @param support candidate label sets; kept by reference
     */
    BMDistribution(CBM cbm, Vector x, List<MultiLabel> support) {
        this.numLabels = cbm.numLabels;
        this.numComponents = cbm.numComponents;
        this.logProportions = cbm.multiClassClassifier.predictLogClassProbs(x);
        this.support = support;
        this.ifSupport = true;
        int supportSize = support.size();
        this.normalizedLogProbs = new double[this.numComponents][supportSize];
        for (int k = 0; k < this.numComponents; k++) {
            double[] labelScores = new double[this.numLabels];
            for (int l = 0; l < this.numLabels; l++) {
                labelScores[l] = ((LogisticRegression) cbm.binaryClassifiers[k][l]).predictClassScores(x)[1];
            }
            double[] supportScores = new double[supportSize];
            for (int s = 0; s < supportSize; s++) {
                for (int l : support.get(s).getMatchedLabels()) {
                    supportScores[s] += labelScores[l];
                }
            }
            double[] supportProbs = MathUtil.softmax(supportScores);
            for (int s = 0; s < supportSize; s++) {
                this.normalizedLogProbs[k][s] = Math.log(supportProbs[s]);
            }
        }
    }

    /** Returns log p(y | component k), dispatching on whether this is a support-mode distribution. */
    private double logYGivenComponent(MultiLabel y, int k) {
        return this.ifSupport
                ? this.logYGivenComponentBySupport(y, k)
                : this.logYGivenComponentByDefault(y, k);
    }

    /**
     * Looks up log p(y | component k) among the stored support probabilities.
     *
     * @throws IllegalArgumentException if {@code y} is not equal to any label set in the support
     */
    private double logYGivenComponentBySupport(MultiLabel y, int k) {
        for (int s = 0; s < this.support.size(); s++) {
            if (this.support.get(s).equals(y)) {
                return this.normalizedLogProbs[k][s];
            }
        }
        // Previously this indexed with -1 and threw a bare ArrayIndexOutOfBoundsException.
        throw new IllegalArgumentException("label set " + y + " is not in the support of this distribution");
    }

    /**
     * Computes log p(y | component k) as the sum, over all labels, of the log probability of
     * each label's observed state (present or absent) in {@code y}.
     *
     * @param y label set to score
     * @param k component index
     * @return the summed log probability
     * @throws IllegalStateException if this distribution was built over an explicit support
     */
    public double logYGivenComponentByDefault(MultiLabel y, int k) {
        this.requirePerLabelProbs("logYGivenComponentByDefault");
        double sum = 0.0;
        for (int l = 0; l < this.numLabels; l++) {
            sum += this.logClassProbs[k][l][y.matchClass(l) ? 1 : 0];
        }
        return sum;
    }

    /**
     * Same as {@link #logYGivenComponentByDefault}, but each label's term is multiplied by
     * {@code noiseLabelWeight[l]}.
     */
    private double logYGivenComponent(MultiLabel y, int k, double[] noiseLabelWeight) {
        double sum = 0.0;
        for (int l = 0; l < this.numLabels; l++) {
            sum += noiseLabelWeight[l] * this.logClassProbs[k][l][y.matchClass(l) ? 1 : 0];
        }
        return sum;
    }

    /** Returns {@code logProportions[k] + log p(y | k)} for every component k. */
    private double[] logJoint(MultiLabel y) {
        double[] logJoint = new double[this.numComponents];
        for (int k = 0; k < this.numComponents; k++) {
            logJoint[k] = this.logProportions[k] + this.logYGivenComponent(y, k);
        }
        return logJoint;
    }

    /**
     * Computes the posterior probability of each component given the observed label set.
     *
     * @param y observed label set
     * @return p(k | y) for each component k
     */
    public double[] posteriorMembership(MultiLabel y) {
        double[] membership = this.logPosteriorMembership(y);
        for (int k = 0; k < membership.length; k++) {
            membership[k] = Math.exp(membership[k]);
        }
        return membership;
    }

    /**
     * Computes the posterior probability of each component, weighting each label's log-likelihood
     * term by {@code noiseLabelWeight}.
     *
     * @param y                observed label set
     * @param noiseLabelWeight per-label weight applied to that label's log-likelihood term
     * @return the normalized weighted posterior for each component
     * @throws IllegalStateException if this distribution was built over an explicit support
     */
    double[] posteriorMembership(MultiLabel y, double[] noiseLabelWeight) {
        this.requirePerLabelProbs("posteriorMembership with noise label weights");
        double[] logNumerator = new double[this.numComponents];
        for (int k = 0; k < this.numComponents; k++) {
            logNumerator[k] = this.logProportions[k] + this.logYGivenComponent(y, k, noiseLabelWeight);
        }
        double logDenominator = MathUtil.logSumExp(logNumerator);
        double[] membership = new double[this.numComponents];
        for (int k = 0; k < this.numComponents; k++) {
            membership[k] = Math.exp(logNumerator[k] - logDenominator);
        }
        return membership;
    }

    /**
     * Computes the log posterior of each component given the observed label set.
     *
     * @param y observed label set
     * @return log p(k | y) for each component k
     */
    double[] logPosteriorMembership(MultiLabel y) {
        double[] membership = this.logJoint(y);
        double logDenominator = MathUtil.logSumExp(membership);
        for (int k = 0; k < membership.length; k++) {
            membership[k] -= logDenominator;
        }
        return membership;
    }

    /**
     * Computes log p(y), the log-sum-exp over components of the log joint probabilities.
     *
     * @param y label set to score
     * @return log probability of {@code y}
     */
    public double logProbability(MultiLabel y) {
        return MathUtil.logSumExp(this.logJoint(y));
    }

    /** Returns p(y), i.e. {@code exp(logProbability(y))}. */
    double probability(MultiLabel y) {
        return Math.exp(this.logProbability(y));
    }

    /** Returns p(label {@code labelIndex} is present), summed over the mixture components. */
    private double marginal(int labelIndex) {
        double sum = 0.0;
        for (int k = 0; k < this.numComponents; k++) {
            sum += Math.exp(this.logProportions[k]) * Math.exp(this.logClassProbs[k][labelIndex][1]);
        }
        return sum;
    }

    /**
     * Computes, for every label, the marginal probability that the label is present.
     *
     * @return per-label marginal probabilities
     * @throws IllegalStateException if this distribution was built over an explicit support
     */
    double[] marginals() {
        this.requirePerLabelProbs("marginals");
        double[] m = new double[this.numLabels];
        for (int l = 0; l < this.numLabels; l++) {
            m[l] = this.marginal(l);
        }
        return m;
    }

    /**
     * Draws label sets from the mixture. For each draw, a component is picked by its (exponentiated)
     * proportion, and then each label is drawn independently from that component's Bernoulli.
     *
     * @param numSamples number of label sets to draw; a non-positive value yields an empty list
     * @return the sampled label sets
     * @throws IllegalStateException if this distribution was built over an explicit support
     */
    List<MultiLabel> sample(int numSamples) {
        this.requirePerLabelProbs("sample");
        List<MultiLabel> samples = new ArrayList<>();
        double[] proportions = Arrays.stream(this.logProportions).map(Math::exp).toArray();
        int[] components = IntStream.range(0, this.numComponents).toArray();
        EnumeratedIntegerDistribution componentDistribution = new EnumeratedIntegerDistribution(components, proportions);
        BernoulliDistribution[][] bernoulliDistributions = new BernoulliDistribution[this.numComponents][this.numLabels];
        for (int k = 0; k < this.numComponents; k++) {
            for (int l = 0; l < this.numLabels; l++) {
                bernoulliDistributions[k][l] = new BernoulliDistribution(Math.exp(this.logClassProbs[k][l][1]));
            }
        }
        for (int num = 0; num < numSamples; num++) {
            MultiLabel multiLabel = new MultiLabel();
            int k = componentDistribution.sample();
            for (int l = 0; l < this.numLabels; l++) {
                if (bernoulliDistributions[k][l].sample() == 1) {
                    multiLabel.addLabel(l);
                }
            }
            samples.add(multiLabel);
        }
        return samples;
    }

    /**
     * Fails fast when per-label probabilities are needed but were never computed, which is the
     * case in support mode.
     */
    private void requirePerLabelProbs(String operation) {
        if (this.ifSupport) {
            throw new IllegalStateException(operation
                    + " requires per-label probabilities, which are not computed when the distribution is built over an explicit support");
        }
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("BMDistribution{");
        sb.append("numLabels=").append(this.numLabels);
        sb.append(", numComponents=").append(this.numComponents);
        sb.append(", logProportions=").append(Arrays.toString(this.logProportions));
        sb.append(", logClassProbs=").append(Arrays.deepToString(this.logClassProbs));
        sb.append('}');
        return sb.toString();
    }
}

