/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.classification.PriorProbClassifier;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM;
import edu.neu.ccs.pyramid.util.ArgSort;
import edu.neu.ccs.pyramid.util.MathUtil;
import org.apache.mahout.math.Vector;

/**
 * Posterior membership of one example {@code (x, y)} over the components of a {@link CBM},
 * computed with short-circuiting so that components which cannot matter are not fully
 * evaluated.
 *
 * <p>For component {@code k} the unnormalized log posterior ("log numerator") is
 * {@code logProportions[k] + logYGivenComponent[k]}:
 * <ul>
 *   <li>{@code logProportions} comes from {@code cbm.multiClassClassifier.predictLogClassProbs(x)}.</li>
 *   <li>{@code logYGivenComponent[k]} is the sum over labels {@code l} of the log probability
 *       that {@code cbm.binaryClassifiers[k][l]} assigns to label {@code l}'s observed value in
 *       {@code y}.</li>
 * </ul>
 * {@link #posteriorMembership()} normalizes the log numerators with a softmax.
 *
 * <p>Short-circuiting: components are visited in descending order of log proportion while a
 * running maximum {@code max} of the log numerators is maintained.
 * <ul>
 *   <li>A component is skipped entirely when {@code logProportions[k] <= max - skipThreshold}.
 *       Its {@code logYGivenComponent[k]} stays 0.0.</li>
 *   <li>Otherwise label terms are accumulated until the partial log numerator drops below
 *       {@code max - skipThreshold}. The partial sum is then stored.</li>
 * </ul>
 * Assuming the per-label values are log-probabilities (i.e. {@code <= 0}), both cases store an
 * over-estimate. Even so, the component's log numerator stays at most {@code skipThreshold}
 * below the best fully-evaluated one. Its posterior weight is therefore only approximate, but
 * tiny: roughly {@code exp(-skipThreshold)} relative to the best component.
 */
public class ShortCircuitPosterior {
    /** Number of labels, copied from the CBM. */
    int numLabels;
    /** Number of components, copied from the CBM. */
    int numComponents;
    /** Per-component log proportions for {@code x}, from the CBM's multi-class classifier. */
    double[] logProportions;
    /**
     * Per-component log-likelihood of {@code y}. This is exact for components that were fully
     * evaluated, a partial sum for those cut off early, and 0.0 for those skipped.
     */
    double[] logYGivenComponent;
    MultiLabel y;
    CBM cbm;
    Vector x;
    /**
     * Log-space margin: evaluation of a component stops once its (partial) log numerator is
     * more than this far below the running maximum. It is only read during construction.
     */
    double skipThreshold = 30.0;

    /**
     * Evaluates the component log-likelihoods of {@code y} eagerly, short-circuiting
     * components whose contribution is negligible (see the class documentation).
     *
     * @param cbm the model whose components are scored
     * @param x   the input features
     * @param y   the observed label assignment
     */
    public ShortCircuitPosterior(CBM cbm, Vector x, MultiLabel y) {
        this.numLabels = cbm.numLabels;
        this.y = y;
        this.x = x;
        this.cbm = cbm;
        this.numComponents = cbm.numComponents;
        this.logProportions = cbm.multiClassClassifier.predictLogClassProbs(x);
        this.logYGivenComponent = new double[this.numComponents];

        // Visiting high-proportion components first raises the running maximum early, so more
        // of the later components can be short-circuited.
        double max = Double.NEGATIVE_INFINITY;
        for (int k : ArgSort.argSortDescending(this.logProportions)) {
            // Written as !(a > b) rather than a <= b to keep the original NaN behaviour.
            if (!(this.logProportions[k] > max - this.skipThreshold)) {
                // Even a log-likelihood of 0 could not lift this component into contention, so
                // logYGivenComponent[k] is left at 0.0.
                continue;
            }
            this.logYGivenComponent[k] = this.computeLogYGivenComponent(k, max);
            double logNumerator = this.logProportions[k] + this.logYGivenComponent[k];
            if (logNumerator > max) {
                max = logNumerator;
            }
        }
    }

    /**
     * Sums the per-label log probabilities of {@code y} under component {@code k}. It returns
     * the partial sum early once the component's log numerator falls below
     * {@code max - skipThreshold}.
     *
     * <p>Labels are visited in three disjoint passes, so each label contributes exactly once:
     * <ol>
     *   <li>positive labels whose classifier is a {@link PriorProbClassifier};</li>
     *   <li>negative labels whose classifier is a {@link PriorProbClassifier};</li>
     *   <li>all labels whose classifier is not a {@link PriorProbClassifier}.</li>
     * </ol>
     * Prior-only classifiers are tried first, presumably because they are cheap to evaluate
     * (TODO confirm). Positive labels are tried first because they are visited through
     * {@code y.getMatchedLabels()}.
     *
     * @param k   component index
     * @param max running maximum of the log numerators of components evaluated so far
     * @return the log-likelihood of {@code y} under component {@code k}, or a partial sum of it
     *     when evaluation was short-circuited
     */
    private double computeLogYGivenComponent(int k, double max) {
        double sum = 0.0;

        // Pass 1: positive labels with prior-only classifiers.
        for (int l : this.y.getMatchedLabels()) {
            if (!this.isPriorClassifier(k, l)) {
                continue;
            }
            sum += this.labelLogProb(k, l, true);
            if (this.isNegligible(k, sum, max)) {
                return sum;
            }
        }

        // Pass 2: negative labels with prior-only classifiers. Positive ones were already
        // counted in pass 1; adding them here again would double-count their log probability.
        for (int l = 0; l < this.numLabels; ++l) {
            if (!this.isPriorClassifier(k, l) || this.y.matchClass(l)) {
                continue;
            }
            sum += this.labelLogProb(k, l, false);
            if (this.isNegligible(k, sum, max)) {
                return sum;
            }
        }

        // Pass 3: every label whose classifier is not prior-only, positive or negative.
        for (int l = 0; l < this.numLabels; ++l) {
            if (this.isPriorClassifier(k, l)) {
                continue;
            }
            sum += this.labelLogProb(k, l, this.y.matchClass(l));
            if (this.isNegligible(k, sum, max)) {
                return sum;
            }
        }
        return sum;
    }

    /** Returns whether component {@code k}'s classifier for label {@code l} is prior-only. */
    private boolean isPriorClassifier(int k, int l) {
        return this.cbm.binaryClassifiers[k][l] instanceof PriorProbClassifier;
    }

    /**
     * Returns component {@code k}'s log probability for label {@code l} given {@code x}.
     * Index 1 of {@code predictLogClassProbs} is the positive class and index 0 the negative.
     */
    private double labelLogProb(int k, int l, boolean positive) {
        return this.cbm.binaryClassifiers[k][l].predictLogClassProbs(this.x)[positive ? 1 : 0];
    }

    /**
     * Returns whether component {@code k}, with partial log-likelihood {@code partialSum}, is
     * already more than {@code skipThreshold} below the running maximum {@code max}.
     */
    private boolean isNegligible(int k, double partialSum, double max) {
        return partialSum + this.logProportions[k] < max - this.skipThreshold;
    }

    /**
     * Returns the posterior membership over components: the softmax of
     * {@code logProportions[k] + logYGivenComponent[k]} for each component {@code k}.
     *
     * @return a new array of length {@code numComponents}
     */
    public double[] posteriorMembership() {
        double[] logNumerator = new double[this.numComponents];
        for (int k = 0; k < this.numComponents; ++k) {
            logNumerator[k] = this.logProportions[k] + this.logYGivenComponent[k];
        }
        return MathUtil.softmax(logNumerator);
    }
}

