/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.classification.Classifier;
import edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator;
import edu.neu.ccs.pyramid.classification.lkboost.LKBoost;
import edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer;
import edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer;
import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticLoss;
import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.eval.Entropy;
import edu.neu.ccs.pyramid.eval.KLDivergence;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.BMDistribution;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM;
import edu.neu.ccs.pyramid.optimization.Terminator;
import edu.neu.ccs.pyramid.regression.RegressorFactory;
import edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig;
import edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.Vector;

/**
 * Alternating optimizer for a {@link CBM} trained against noisy labels under a <em>fixed</em>
 * noise model.
 *
 * <p>The constructor evaluates a score {@code scores[n][j]} once, for every data point {@code n}
 * and every candidate label combination {@code j}. The candidate combinations are the distinct
 * label sets gathered from the ground-truth data set. Each score is the supplied
 * {@link MultiLabelClassifier.AssignmentProbEstimator}'s probability of the observed labels of
 * {@code n}, given that combination (optionally preceded by the data point's active features).
 * These scores are never updated afterwards.
 *
 * <p>Each {@link #iterate()} performs these steps:
 * <ol>
 *   <li>Multiply the CBM's combination probabilities by the fixed scores and normalize them to
 *       get per-point targets over combinations.</li>
 *   <li>Derive per-label binary target distributions and per-component responsibilities
 *       ({@code gammas}) from those targets.</li>
 *   <li>Refit the CBM's multi-class classifier and its binary classifiers in place.</li>
 *   <li>Recompute the CBM probabilities and record the objective in the {@link Terminator}.</li>
 * </ol>
 */
public class CBMNoiseOptimizerFixed {
    private static final Logger logger = LogManager.getLogger();
    private CBM cbm;
    private MultiLabelClfDataSet dataSet;
    private Terminator terminator;
    /** Responsibility of each CBM component for each data point, indexed [dataPoint][component]. */
    double[][] gammas;
    /** Transpose of {@link #gammas}, indexed [component][dataPoint]; used as instance weights. */
    double[][] gammasT;
    /** Per label and data point, the target pair {1 - P(label on), P(label on)}. */
    private double[][][] binaryTargetsDistributions;
    private double priorVarianceMultiClass = 1.0;
    private double priorVarianceBinary = 1.0;
    private double regularizationMultiClass = 1.0;
    private double regularizationBinary = 1.0;
    private double l1RatioBinary = 0.0;
    private double l1RatioMultiClass = 0.0;
    private boolean lineSearch = true;
    private int numLeavesBinary = 2;
    private int numLeavesMultiClass = 2;
    private double shrinkageBinary = 0.1;
    private double shrinkageMultiClass = 0.1;
    private int numIterationsBinary = 20;
    private int numIterationsMultiClass = 20;
    /** Candidate label combinations, gathered from the ground-truth data set. */
    private List<MultiLabel> combinations;
    /** Normalized weights over {@link #combinations}, indexed [dataPoint][combination]. */
    private double[][] targets;
    /** CBM probability of each combination, indexed [dataPoint][combination]. */
    private double[][] probabilities;
    /** Fixed noise-model scores computed in the constructor, indexed [dataPoint][combination]. */
    private double[][] scores;

    /**
     * Creates the optimizer and pre-computes the fixed noise-model scores.
     *
     * @param cbm the model to optimize; its classifiers are updated in place by {@link #iterate()}
     * @param dataSet the training data whose observed labels are scored by the noise model
     * @param dataSetGroundTruth data set whose distinct label sets define the candidate
     *     combinations
     * @param classifier noise model, queried as {@code predictAssignmentProb(input, observedLabels)}
     * @param includeFeature if {@code true}, the noise-model input is the data point's active
     *     features followed by the combination's labels (offset by the number of features). If
     *     {@code false}, the input is the combination alone. Must not be {@code null}.
     */
    public CBMNoiseOptimizerFixed(CBM cbm, MultiLabelClfDataSet dataSet, MultiLabelClfDataSet dataSetGroundTruth, MultiLabelClassifier.AssignmentProbEstimator classifier, Boolean includeFeature) {
        System.out.println("Enter CBMNoiseOptimizerFixed constructor ...");
        // Unbox once so a null flag fails immediately instead of inside the parallel stream.
        final boolean useFeatures = includeFeature.booleanValue();
        this.cbm = cbm;
        this.dataSet = dataSet;
        this.combinations = DataSetUtil.gatherMultiLabels(dataSetGroundTruth);
        this.terminator = new Terminator();
        this.terminator.setGoal(Terminator.Goal.MINIMIZE);
        final int numDataPoints = dataSet.getNumDataPoints();
        final int numCombinations = this.combinations.size();
        this.gammas = new double[numDataPoints][cbm.getNumComponents()];
        this.gammasT = new double[cbm.getNumComponents()][numDataPoints];
        this.binaryTargetsDistributions = new double[cbm.getNumClasses()][numDataPoints][2];
        this.scores = new double[numDataPoints][numCombinations];
        System.out.println("#data points: " + dataSet.getNumDataPoints() + ", #combinations: " + this.combinations.size());
        IntStream.range(0, numDataPoints).forEach(i -> {
            // These do not depend on the combination, so compute them once per data point.
            final MultiLabel truth = dataSet.getMultiLabels()[i];
            final List<Integer> activeFeatures = useFeatures ? activeFeatureIndices(dataSet, i) : null;
            IntStream.range(0, numCombinations).parallel().forEach(j -> {
                MultiLabel combination = this.combinations.get(j);
                double f;
                if (!useFeatures) {
                    f = classifier.predictAssignmentProb(combination.toVector(dataSet.getNumClasses()), truth);
                } else {
                    // Input layout: [features 0..F-1 | labels shifted by F].
                    MultiLabel xz = new MultiLabel();
                    for (int feature : activeFeatures) {
                        xz.addLabel(feature);
                    }
                    for (int k = 0; k < dataSet.getNumClasses(); ++k) {
                        if (combination.matchClass(k)) {
                            xz.addLabel(k + dataSet.getNumFeatures());
                        }
                    }
                    f = classifier.predictAssignmentProb(xz.toVector(dataSet.getNumFeatures() + dataSet.getNumClasses()), truth);
                }
                this.scores[i][j] = f;
            });
        });
        System.out.println("Finished evaluating fixed noise model p(y_n | z) ...");
        this.targets = new double[numDataPoints][numCombinations];
        this.probabilities = new double[numDataPoints][numCombinations];
        this.updateProbabilities();
        logger.debug("finish constructor");
    }

    /**
     * Returns the feature indices {@code k} in {@code [0, numFeatures)}, in increasing order, that
     * {@code new MultiLabel(row)} reports as matched for the given data point.
     */
    private static List<Integer> activeFeatureIndices(MultiLabelClfDataSet dataSet, int dataPointIndex) {
        MultiLabel x = new MultiLabel(dataSet.getRow(dataPointIndex));
        List<Integer> active = new ArrayList<>();
        for (int k = 0; k < dataSet.getNumFeatures(); ++k) {
            if (x.matchClass(k)) {
                active.add(k);
            }
        }
        return active;
    }

    /** Recomputes the CBM's probabilities of all candidate combinations for one data point. */
    private void updateProbabilities(int dataPointIndex) {
        this.probabilities[dataPointIndex] = this.cbm.predictAssignmentProbs(this.dataSet.getRow(dataPointIndex), this.combinations);
    }

    /** Recomputes {@link #probabilities} for every data point, in parallel. */
    private void updateProbabilities() {
        logger.debug("start updateProbabilities()");
        IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().forEach(this::updateProbabilities);
        logger.debug("finish updateProbabilities()");
    }

    /**
     * Sets the targets of one data point to the normalized products of CBM probability and fixed
     * noise score.
     *
     * <p>If every product is zero (or the sum is otherwise not positive), the data point gives no
     * usable evidence. In that case the targets fall back to the normalized CBM probabilities, or
     * to a uniform distribution if those are degenerate too. Without this fallback, dividing by
     * zero would produce NaN targets that propagate into {@code gammas} and the classifier weights.
     */
    private void updateTargets(int dataPointIndex) {
        double[] probs = this.probabilities[dataPointIndex];
        double[] s = this.scores[dataPointIndex];
        double[] target = this.targets[dataPointIndex];
        double[] product = new double[probs.length];
        for (int j = 0; j < probs.length; ++j) {
            product[j] = probs[j] * s[j];
        }
        double denominator = MathUtil.arraySum(product);
        if (denominator > 0.0) {
            for (int j = 0; j < probs.length; ++j) {
                target[j] = product[j] / denominator;
            }
            return;
        }
        double probSum = MathUtil.arraySum(probs);
        for (int j = 0; j < probs.length; ++j) {
            target[j] = probSum > 0.0 ? probs[j] / probSum : 1.0 / probs.length;
        }
    }

    /** Recomputes {@link #targets} for every data point, in parallel. */
    private void updateTargets() {
        logger.debug("start updateTargets()");
        IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().forEach(this::updateTargets);
        logger.debug("finish updateTargets()");
    }

    /** Recomputes {@link #binaryTargetsDistributions} for every data point, in parallel. */
    private void updateBinaryTargets() {
        logger.debug("start updateBinaryTargets()");
        IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().forEach(this::updateBinaryTarget);
        logger.debug("finish updateBinaryTargets()");
    }

    /**
     * Converts one data point's combination targets into per-label on/off distributions.
     *
     * <p>The "on" probability of a label is the total target weight of all combinations that
     * contain it, capped at 1.
     */
    private void updateBinaryTarget(int dataPointIndex) {
        double[] comProb = this.targets[dataPointIndex];
        double[] marginals = new double[this.cbm.getNumClasses()];
        for (int c = 0; c < comProb.length; ++c) {
            double prob = comProb[c];
            Iterator<Integer> iterator = this.combinations.get(c).getMatchedLabels().iterator();
            while (iterator.hasNext()) {
                int label = iterator.next();
                marginals[label] += prob;
            }
        }
        for (int l = 0; l < this.cbm.getNumClasses(); ++l) {
            // Clamp so that {1 - m, m} below is a valid distribution.
            if (marginals[l] > 1.0) {
                marginals[l] = 1.0;
            }
            this.binaryTargetsDistributions[l][dataPointIndex][0] = 1.0 - marginals[l];
            this.binaryTargetsDistributions[l][dataPointIndex][1] = marginals[l];
        }
    }

    public void setPriorVarianceMultiClass(double priorVarianceMultiClass) {
        this.priorVarianceMultiClass = priorVarianceMultiClass;
    }

    public void setPriorVarianceBinary(double priorVarianceBinary) {
        this.priorVarianceBinary = priorVarianceBinary;
    }

    public void setNumLeavesBinary(int numLeavesBinary) {
        this.numLeavesBinary = numLeavesBinary;
    }

    public void setNumLeavesMultiClass(int numLeavesMultiClass) {
        this.numLeavesMultiClass = numLeavesMultiClass;
    }

    public void setShrinkageBinary(double shrinkageBinary) {
        this.shrinkageBinary = shrinkageBinary;
    }

    public void setShrinkageMultiClass(double shrinkageMultiClass) {
        this.shrinkageMultiClass = shrinkageMultiClass;
    }

    public void setNumIterationsBinary(int numIterationsBinary) {
        this.numIterationsBinary = numIterationsBinary;
    }

    public void setNumIterationsMultiClass(int numIterationsMultiClass) {
        this.numIterationsMultiClass = numIterationsMultiClass;
    }

    /** Runs {@link #iterate()} at least once, repeating until the terminator says to stop. */
    public void optimize() {
        do {
            this.iterate();
        } while (!this.terminator.shouldTerminate());
    }

    /**
     * Performs one full pass: targets, binary targets, gammas, multi-class classifier, binary
     * classifiers, and probabilities. It then records the objective in the terminator.
     */
    public void iterate() {
        this.updateTargets();
        this.updateBinaryTargets();
        this.updateGamma();
        this.updateMultiClassClassifier();
        this.updateBinaryClassifiers();
        this.updateProbabilities();
        this.terminator.add(this.objective());
    }

    /** Recomputes {@link #gammas} and {@link #gammasT} for every data point, in parallel. */
    private void updateGamma() {
        logger.debug("start updateGamma()");
        IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().forEach(this::updateGamma);
        logger.debug("finish updateGamma()");
    }

    /**
     * Recomputes component responsibilities for data point {@code n}.
     *
     * <p>For each component, the method sums the BM log posterior membership over all
     * combinations, weighted by the current targets. The sums are softmax-normalized and written
     * to both {@code gammas[n]} and column {@code n} of {@code gammasT}.
     */
    private void updateGamma(int n) {
        Vector x = this.dataSet.getRow(n);
        BMDistribution bmDistribution = this.cbm.computeBM(x);
        List<double[]> logPosteriors = new ArrayList<>(this.combinations.size());
        for (MultiLabel combination : this.combinations) {
            logPosteriors.add(bmDistribution.logPosteriorMembership(combination));
        }
        double[] sums = new double[this.cbm.numComponents];
        for (int k = 0; k < this.cbm.numComponents; ++k) {
            double sum = 0.0;
            for (int c = 0; c < this.combinations.size(); ++c) {
                sum += this.targets[n][c] * logPosteriors.get(c)[k];
            }
            sums[k] = sum;
        }
        double[] posterior = MathUtil.softmax(sums);
        for (int k = 0; k < this.cbm.numComponents; ++k) {
            this.gammas[n][k] = posterior[k];
            this.gammasT[k][n] = posterior[k];
        }
    }

    /** Refits the binary classifiers of every component, one component at a time. */
    private void updateBinaryClassifiers() {
        logger.debug("start updateBinaryClassifiers()");
        IntStream.range(0, this.cbm.numComponents).forEach(this::updateBinaryClassifiers);
        logger.debug("finish updateBinaryClassifiers()");
    }

    /**
     * Refits all per-label binary classifiers of one component, according to the CBM's binary
     * classifier type ("lr", "boost" or "elasticnet").
     *
     * @throws IllegalArgumentException if the type is not recognized
     */
    private void updateBinaryClassifiers(int component) {
        switch (this.cbm.getBinaryClassifierType()) {
            case "lr": {
                IntStream.range(0, this.cbm.numLabels).parallel().forEach(l -> this.updateBinaryLogisticRegression(component, l));
                break;
            }
            case "boost": {
                IntStream.range(0, this.cbm.numLabels).forEach(l -> this.updateBinaryBoosting(component, l));
                break;
            }
            case "elasticnet": {
                IntStream.range(0, this.cbm.numLabels).parallel().forEach(l -> this.updateBinaryLogisticRegressionEL(component, l));
                break;
            }
            default: {
                throw new IllegalArgumentException("unknown type: " + this.cbm.getBinaryClassifierType());
            }
        }
    }

    /** Runs additional boosting iterations on one binary LKBoost classifier, weighted by gammasT. */
    private void updateBinaryBoosting(int componentIndex, int labelIndex) {
        int numIterations = this.numIterationsBinary;
        double shrinkage = this.shrinkageBinary;
        LKBoost boost = (LKBoost)this.cbm.binaryClassifiers[componentIndex][labelIndex];
        RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(this.numLeavesBinary);
        RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
        regTreeFactory.setLeafOutputCalculator(new LKBOutputCalculator(2));
        LKBoostOptimizer optimizer = new LKBoostOptimizer(boost, this.dataSet, regTreeFactory, this.gammasT[componentIndex], this.binaryTargetsDistributions[labelIndex]);
        optimizer.setShrinkage(shrinkage);
        optimizer.initialize();
        optimizer.iterate(numIterations);
    }

    /** Refits one binary logistic regression with the ridge optimizer (at most 10 iterations). */
    private void updateBinaryLogisticRegression(int componentIndex, int labelIndex) {
        RidgeLogisticOptimizer ridgeLogisticOptimizer = new RidgeLogisticOptimizer((LogisticRegression)this.cbm.binaryClassifiers[componentIndex][labelIndex], (DataSet)this.dataSet, this.gammasT[componentIndex], this.binaryTargetsDistributions[labelIndex], this.priorVarianceBinary, false);
        ridgeLogisticOptimizer.getOptimizer().getTerminator().setMaxIteration(10);
        ridgeLogisticOptimizer.optimize();
    }

    /** Refits one binary logistic regression with the elastic-net trainer (at most 10 iterations). */
    private void updateBinaryLogisticRegressionEL(int componentIndex, int labelIndex) {
        ElasticNetLogisticTrainer elasticNetLogisticTrainer = new ElasticNetLogisticTrainer.Builder((LogisticRegression)this.cbm.binaryClassifiers[componentIndex][labelIndex], this.dataSet, 2, this.binaryTargetsDistributions[labelIndex], this.gammasT[componentIndex]).setRegularization(this.regularizationBinary).setL1Ratio(this.l1RatioBinary).setLineSearch(this.lineSearch).build();
        elasticNetLogisticTrainer.getTerminator().setMaxIteration(10);
        elasticNetLogisticTrainer.optimize();
    }

    /**
     * Refits the multi-class (component) classifier against {@link #gammas}, according to the
     * CBM's multi-class classifier type ("lr", "boost" or "elasticnet").
     *
     * @throws IllegalArgumentException if the type is not recognized
     */
    private void updateMultiClassClassifier() {
        logger.debug("start updateMultiClassClassifier()");
        switch (this.cbm.getMultiClassClassifierType()) {
            case "lr": {
                this.updateMultiClassLR();
                break;
            }
            case "boost": {
                this.updateMultiClassBoost();
                break;
            }
            case "elasticnet": {
                this.updateMultiClassEL();
                break;
            }
            default: {
                throw new IllegalArgumentException("unknown type: " + this.cbm.getMultiClassClassifierType());
            }
        }
        logger.debug("finish updateMultiClassClassifier()");
    }

    /** Elastic-net refit of the multi-class classifier (at most 10 iterations). */
    private void updateMultiClassEL() {
        ElasticNetLogisticTrainer elasticNetLogisticTrainer = new ElasticNetLogisticTrainer.Builder((LogisticRegression)this.cbm.multiClassClassifier, (DataSet)this.dataSet, this.cbm.multiClassClassifier.getNumClasses(), this.gammas).setRegularization(this.regularizationMultiClass).setL1Ratio(this.l1RatioMultiClass).setLineSearch(this.lineSearch).build();
        elasticNetLogisticTrainer.getTerminator().setMaxIteration(10);
        elasticNetLogisticTrainer.optimize();
    }

    /** Ridge refit of the multi-class classifier (at most 10 iterations). */
    private void updateMultiClassLR() {
        RidgeLogisticOptimizer ridgeLogisticOptimizer = new RidgeLogisticOptimizer((LogisticRegression)this.cbm.multiClassClassifier, (DataSet)this.dataSet, this.gammas, this.priorVarianceMultiClass, true);
        ridgeLogisticOptimizer.getOptimizer().getTerminator().setMaxIteration(10);
        ridgeLogisticOptimizer.optimize();
    }

    /** Runs additional boosting iterations on the multi-class LKBoost classifier. */
    private void updateMultiClassBoost() {
        int numComponents = this.cbm.numComponents;
        int numIterations = this.numIterationsMultiClass;
        double shrinkage = this.shrinkageMultiClass;
        LKBoost boost = (LKBoost)this.cbm.multiClassClassifier;
        RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(this.numLeavesMultiClass);
        RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
        regTreeFactory.setLeafOutputCalculator(new LKBOutputCalculator(numComponents));
        LKBoostOptimizer optimizer = new LKBoostOptimizer(boost, (DataSet)this.dataSet, (RegressorFactory)regTreeFactory, this.gammas);
        optimizer.setShrinkage(shrinkage);
        optimizer.initialize();
        optimizer.iterate(numIterations);
    }

    /**
     * Returns the negative log of sum_j probabilities[n][j] * scores[n][j] for one data point.
     * The result is +Infinity when that sum is zero.
     */
    private double objective(int dataPointIndex) {
        double sum = 0.0;
        double[] p = this.probabilities[dataPointIndex];
        double[] s = this.scores[dataPointIndex];
        for (int j = 0; j < p.length; ++j) {
            sum += p[j] * s[j];
        }
        return -Math.log(sum);
    }

    /**
     * Returns the objective minimized by {@link #optimize()}: the sum of the per-point negative
     * log-likelihoods plus the classifier penalties computed by {@link #penalty()}.
     */
    public double objective() {
        logger.debug("start objective()");
        double obj = IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().mapToDouble(this::objective).sum();
        logger.debug("finish obj");
        double penalty = this.penalty();
        logger.debug("finish penalty");
        logger.debug("finish objective()");
        return obj + penalty;
    }

    /**
     * Sums the {@link LogisticLoss#penaltyValue()} terms of the multi-class classifier and all
     * binary classifiers.
     *
     * <p>Only classifiers that are {@link LogisticRegression} instances contribute. Other
     * classifiers (e.g. {@link LKBoost} when the type is "boost") have no such penalty and are
     * skipped. Casting them unconditionally would throw {@link ClassCastException} on every
     * {@link #iterate()} call.
     */
    private double penalty() {
        double sum = 0.0;
        if (this.cbm.multiClassClassifier instanceof LogisticRegression) {
            LogisticLoss logisticLoss = new LogisticLoss((LogisticRegression)this.cbm.multiClassClassifier, this.dataSet, this.gammas, this.priorVarianceMultiClass, true);
            sum += logisticLoss.penaltyValue();
        }
        for (int k = 0; k < this.cbm.numComponents; ++k) {
            for (int l = 0; l < this.cbm.getNumClasses(); ++l) {
                if (!(this.cbm.binaryClassifiers[k][l] instanceof LogisticRegression)) {
                    continue;
                }
                sum += new LogisticLoss((LogisticRegression)this.cbm.binaryClassifiers[k][l], (DataSet)this.dataSet, this.gammasT[k], this.binaryTargetsDistributions[l], this.priorVarianceBinary, true).penaltyValue();
            }
        }
        return sum;
    }

    /** Total entropy of the responsibilities over all data points. Not called within this class. */
    private double getEntropy() {
        return IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().mapToDouble(this::getEntropy).sum();
    }

    /** Entropy of the responsibilities of data point {@code i}. */
    private double getEntropy(int i) {
        return Entropy.entropy(this.gammas[i]);
    }

    /** Total binary-classifier objective over all components. Not called within this class. */
    private double binaryObj() {
        return IntStream.range(0, this.cbm.numComponents).mapToDouble(this::binaryObj).sum();
    }

    /** Sum of the binary-classifier objectives of one component over all labels. */
    private double binaryObj(int clusterIndex) {
        return IntStream.range(0, this.cbm.numLabels).parallel().mapToDouble(l -> this.binaryObj(clusterIndex, l)).sum();
    }

    /**
     * Objective of one binary classifier, dispatched on the binary classifier type.
     *
     * @throws IllegalArgumentException if the type is not recognized
     */
    private double binaryObj(int clusterIndex, int classIndex) {
        String type = this.cbm.getBinaryClassifierType();
        switch (type) {
            case "lr":
            case "elasticnet": {
                return this.binaryLRObj(clusterIndex, classIndex);
            }
            case "boost": {
                return this.binaryBoostObj(clusterIndex, classIndex);
            }
            default: {
                throw new IllegalArgumentException("unknown type: " + type);
            }
        }
    }

    /** Logistic loss value of one binary logistic-regression classifier. */
    private double binaryLRObj(int clusterIndex, int classIndex) {
        LogisticLoss logisticLoss = new LogisticLoss((LogisticRegression)this.cbm.binaryClassifiers[clusterIndex][classIndex], (DataSet)this.dataSet, this.gammasT[clusterIndex], this.binaryTargetsDistributions[classIndex], this.priorVarianceBinary, false);
        return logisticLoss.getValue();
    }

    /** Weighted KL divergence between one binary classifier's predictions and its targets. */
    private double binaryBoostObj(int clusterIndex, int classIndex) {
        Classifier.ProbabilityEstimator estimator = this.cbm.binaryClassifiers[clusterIndex][classIndex];
        double[][] targets = this.binaryTargetsDistributions[classIndex];
        double[] weights = this.gammasT[clusterIndex];
        return KLDivergence.kl(estimator, this.dataSet, targets, weights);
    }

    /**
     * Objective of the multi-class classifier, dispatched on its type. Not called within this
     * class.
     *
     * @throws IllegalArgumentException if the type is not recognized
     */
    private double multiClassClassifierObj() {
        String type = this.cbm.getMultiClassClassifierType();
        switch (type) {
            case "lr":
            case "elasticnet": {
                return this.multiClassLRObj();
            }
            case "boost": {
                return this.multiClassBoostObj();
            }
            default: {
                throw new IllegalArgumentException("unknown type: " + type);
            }
        }
    }

    /** KL divergence between the multi-class classifier's predictions and {@link #gammas}. */
    private double multiClassBoostObj() {
        Classifier.ProbabilityEstimator estimator = this.cbm.multiClassClassifier;
        double[][] targets = this.gammas;
        return KLDivergence.kl(estimator, this.dataSet, targets);
    }

    /** Logistic loss value of the multi-class logistic-regression classifier. */
    private double multiClassLRObj() {
        LogisticLoss logisticLoss = new LogisticLoss((LogisticRegression)this.cbm.multiClassClassifier, this.dataSet, this.gammas, this.priorVarianceMultiClass, true);
        return logisticLoss.getValue();
    }

    public Terminator getTerminator() {
        return this.terminator;
    }

    /** Returns the internal responsibilities array (not a copy). */
    public double[][] getGammas() {
        return this.gammas;
    }

    /**
     * Returns, for every data point, the multi-class classifier's component probabilities.
     * They are obtained by exponentiating {@code predictLogClassProbs}.
     */
    public double[][] getPIs() {
        double[][] PIs = new double[this.dataSet.getNumDataPoints()][this.cbm.getNumComponents()];
        for (int n = 0; n < PIs.length; ++n) {
            double[] logProbs = this.cbm.multiClassClassifier.predictLogClassProbs(this.dataSet.getRow(n));
            for (int k = 0; k < PIs[n].length; ++k) {
                PIs[n][k] = Math.exp(logProbs[k]);
            }
        }
        return PIs;
    }

    public double getRegularizationMultiClass() {
        return this.regularizationMultiClass;
    }

    public void setRegularizationMultiClass(double regularizationMultiClass) {
        this.regularizationMultiClass = regularizationMultiClass;
    }

    public double getRegularizationBinary() {
        return this.regularizationBinary;
    }

    public void setRegularizationBinary(double regularizationBinary) {
        this.regularizationBinary = regularizationBinary;
    }

    public boolean isLineSearch() {
        return this.lineSearch;
    }

    public void setLineSearch(boolean lineSearch) {
        this.lineSearch = lineSearch;
    }

    public double getL1RatioBinary() {
        return this.l1RatioBinary;
    }

    public void setL1RatioBinary(double l1RatioBinary) {
        this.l1RatioBinary = l1RatioBinary;
    }

    public double getL1RatioMultiClass() {
        return this.l1RatioMultiClass;
    }

    public void setL1RatioMultiClass(double l1RatioMultiClass) {
        this.l1RatioMultiClass = l1RatioMultiClass;
    }
}

