/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.BMDistribution;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;

public class CBMPredictor {
    /** Number of enumeration rounds used by {@link #predictByDynamic()}. */
    private static final int DEFAULT_MAX_ITER = 10;

    /** Number of mixture components, copied from {@code BMDistribution.numComponents}. */
    int numClusters;
    /** Number of labels, copied from {@code BMDistribution.numLabels}. */
    private final int numLabels;
    /** Cluster proportions on the linear scale: {@code exp(logisticLogProb[k])}. */
    private final double[] logisticProb;
    /** Cluster log-proportions, a copy of {@code BMDistribution.logProportions}. */
    private final double[] logisticLogProb;
    /**
     * {@code logProbs[k][l][c]}: log-probability used for label {@code l} in cluster {@code k};
     * {@code c == 1} is used when the label is present in a candidate, {@code c == 0} when absent.
     */
    private final double[][][] logProbs;
    /** Linear-scale counterpart of {@link #logProbs}. */
    private final double[][][] probs;
    /** Stored by {@link #setNumSamples(int)}; not read anywhere in this class. */
    private int numSample;
    /** Whether {@link #predictByDynamic()} may score the empty label set as a candidate. */
    private boolean allowEmpty = false;

    /**
     * Builds a predictor from a snapshot of the given distribution's parameters.
     *
     * <p>The log-scale arrays are copied rather than aliased: the linear-scale arrays are
     * derived from them exactly once here, so sharing the distribution's arrays would let the
     * two views drift apart if the distribution were modified afterwards.
     *
     * @param bmDistribution source of cluster proportions and per-cluster label probabilities
     */
    public CBMPredictor(BMDistribution bmDistribution) {
        this.numClusters = bmDistribution.numComponents;
        this.numLabels = bmDistribution.numLabels;
        this.logisticLogProb = bmDistribution.logProportions.clone();
        this.logisticProb = new double[this.numClusters];
        // Inner arrays are filled with copies below, so only the outer two dimensions are allocated.
        this.logProbs = new double[this.numClusters][this.numLabels][];
        this.probs = new double[this.numClusters][this.numLabels][2];
        for (int k = 0; k < this.numClusters; ++k) {
            this.logisticProb[k] = Math.exp(this.logisticLogProb[k]);
            for (int l = 0; l < this.numLabels; ++l) {
                this.logProbs[k][l] = bmDistribution.logClassProbs[k][l].clone();
                for (int c = 0; c < 2; ++c) {
                    this.probs[k][l][c] = Math.exp(this.logProbs[k][l][c]);
                }
            }
        }
    }

    /**
     * Equivalent to {@code predictByDynamic(10)}.
     *
     * @return the best-scoring label set found
     */
    public MultiLabel predictByDynamic() {
        return this.predictByDynamic(DEFAULT_MAX_ITER);
    }

    /**
     * Searches for the label set with the highest mixture log-probability by drawing
     * candidates from one {@link DynamicProgramming} enumerator per cluster.
     *
     * <p>In each round, every still-active cluster contributes its next candidate vector,
     * scored by {@link #logProbYnGivenXnLogisticProb(MultiLabel)}. The best score seen so far
     * is kept, with later ties winning. A cluster is deactivated when {@link #checkStop} reports
     * that it can be pruned, or when its queue is empty. If {@code allowEmpty} is false, empty
     * candidates are skipped without being scored. The search ends when no cluster is active or
     * after {@code maxIter} rounds.
     *
     * <p>NOTE(review): the pruning bounds presumably assume each enumerator yields vectors in
     * non-increasing probability order; confirm against {@code DynamicProgramming}.
     *
     * @param maxIter maximum number of rounds; must be at least 1
     * @return the best-scoring label set found; an empty {@link MultiLabel} if no candidate was scored
     * @throws IllegalArgumentException if {@code maxIter < 1}
     */
    public MultiLabel predictByDynamic(int maxIter) {
        if (maxIter < 1) {
            throw new IllegalArgumentException("maxIter must be at least 1, got " + maxIter);
        }
        Map<Integer, DynamicProgramming> dps = new HashMap<Integer, DynamicProgramming>();
        double[] maxClusterProb = new double[this.numClusters];
        for (int k = 0; k < this.numClusters; ++k) {
            DynamicProgramming dp = new DynamicProgramming(this.probs[k], this.logProbs[k]);
            dps.put(k, dp);
            // First value reported for cluster k, taken before any vector is drawn from it.
            maxClusterProb[k] = dp.nextHighestProb();
        }

        // Per-cluster constants used by checkStop's pruning bounds.
        double[] cond1 = new double[this.numClusters];
        double[] sumPiD = new double[this.numClusters];
        for (int k = 0; k < this.numClusters; ++k) {
            cond1[k] = maxClusterProb[k] - 1.0 / this.logisticProb[k] + 1.0;
            double sum = 0.0;
            for (int r = 0; r < this.numClusters; ++r) {
                if (r != k) {
                    sum += this.logisticProb[r] * maxClusterProb[r];
                }
            }
            sumPiD[k] = sum;
        }

        double maxLogProb = Double.NEGATIVE_INFINITY;
        MultiLabel bestMultiLabel = new MultiLabel();
        for (int iter = 0; iter < maxIter && !dps.isEmpty(); ++iter) {
            Iterator<Map.Entry<Integer, DynamicProgramming>> it = dps.entrySet().iterator();
            while (it.hasNext()) {
                Map.Entry<Integer, DynamicProgramming> entry = it.next();
                int k = entry.getKey();
                DynamicProgramming dp = entry.getValue();
                double prob = dp.nextHighestProb();
                MultiLabel multiLabel = dp.nextHighestVector();
                if (multiLabel.getNumMatchedLabels() == 0 && !this.allowEmpty) {
                    // The empty set is not a valid answer; drop the cluster only once it is exhausted.
                    if (dp.getQueue().size() == 0) {
                        it.remove();
                    }
                    continue;
                }
                double logProb = this.logProbYnGivenXnLogisticProb(multiLabel);
                if (logProb >= maxLogProb) {
                    bestMultiLabel = multiLabel;
                    maxLogProb = logProb;
                }
                if (this.checkStop(prob, cond1[k], maxLogProb, sumPiD[k], k)
                        || dp.getQueue().size() == 0) {
                    it.remove();
                }
            }
        }
        return bestMultiLabel;
    }

    /**
     * Decides whether cluster {@code k} can stop enumerating.
     *
     * @param q          probability reported for the candidate just drawn from cluster {@code k}
     * @param c1         precomputed threshold {@code cond1[k]}
     * @param maxLogProb best mixture log-probability found so far
     * @param sumPiDk    {@code sum_{r != k} pi_r * maxClusterProb[r]}
     * @param k          cluster index
     * @return true if any of the three bound checks says cluster {@code k} can be pruned
     */
    private boolean checkStop(double q, double c1, double maxLogProb, double sumPiDk, int k) {
        if (q <= c1) {
            return true;
        }
        double bestProb = Math.exp(maxLogProb);
        double weighted = this.logisticProb[k] * q;
        return weighted <= bestProb / (double) this.numClusters || weighted + sumPiDk <= bestProb;
    }

    /**
     * Predicts using only the cluster with the largest proportion. Every label whose
     * present-probability in that cluster exceeds 0.5 is included.
     *
     * @return the thresholded label set of the dominant cluster (first index wins ties)
     */
    public MultiLabel predictByHardAssignment() {
        int maxK = 0;
        double maxPi = this.logisticLogProb[0];
        for (int k = 1; k < this.logisticLogProb.length; ++k) {
            if (maxPi < this.logisticLogProb[k]) {
                maxK = k;
                maxPi = this.logisticLogProb[k];
            }
        }
        MultiLabel predict = new MultiLabel();
        for (int l = 0; l < this.numLabels; ++l) {
            if (this.probs[maxK][l][1] > 0.5) {
                predict.addLabel(l);
            }
        }
        return predict;
    }

    /**
     * Mixture log-probability of a label set: {@code log sum_k exp(log pi_k + log p(Y | k))},
     * combined with {@link MathUtil#logSumExp}.
     *
     * @param Y candidate label set
     * @return the mixture log-probability of {@code Y}
     */
    private double logProbYnGivenXnLogisticProb(MultiLabel Y) {
        double[] logPYnk = this.clusterConditionalLogProbArr(Y);
        double[] sumLog = new double[this.logisticLogProb.length];
        for (int k = 0; k < this.numClusters; ++k) {
            sumLog[k] = this.logisticLogProb[k] + logPYnk[k];
        }
        return MathUtil.logSumExp(sumLog);
    }

    /**
     * Computes, for each cluster, the sum over labels of the per-label log-probability of the
     * label's presence/absence in {@code Y}.
     *
     * @param Y label set to evaluate
     * @return array of length {@code numClusters} holding {@code log p(Y | k)}
     */
    public double[] clusterConditionalLogProbArr(MultiLabel Y) {
        double[] probArr = new double[this.numClusters];
        for (int k = 0; k < this.numClusters; ++k) {
            double logProb = 0.0;
            for (int l = 0; l < this.numLabels; ++l) {
                logProb += this.logProbs[k][l][Y.matchClass(l) ? 1 : 0];
                // Short-circuit: once the sum is -infinity, remaining labels need not be visited.
                if (logProb == Double.NEGATIVE_INFINITY) {
                    break;
                }
            }
            probArr[k] = logProb;
        }
        return probArr;
    }

    /**
     * Stores a sample count; the value is not used by the prediction methods in this class.
     *
     * @param numSample value to store
     */
    public void setNumSamples(int numSample) {
        this.numSample = numSample;
    }

    /**
     * Controls whether {@link #predictByDynamic()} may score the empty label set.
     *
     * @param allowEmpty true to score empty candidates, false to skip them
     */
    public void setAllowEmpty(boolean allowEmpty) {
        this.allowEmpty = allowEmpty;
    }
}

