/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.eval.Entropy;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.AugmentedLRLoss;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBMS;
import edu.neu.ccs.pyramid.optimization.LBFGS;
import edu.neu.ccs.pyramid.optimization.Terminator;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.Vector;

/**
 * Fits a {@link CBMS} model to a {@link MultiLabelClfDataSet} by alternating an E step and an
 * M step.
 *
 * <p>The E step stores, for every data point, the values returned by
 * {@code cbms.posteriorMembership(x, y)} into {@link #gammas}. The M step then refits each
 * per-label binary classifier (via {@link AugmentedLRLoss} minimized with {@link LBFGS}) and the
 * multi-class classifier (via {@link RidgeLogisticOptimizer}) using those gammas.
 *
 * <p>Callers may either invoke {@link #optimize()} and let the {@link Terminator} decide when to
 * stop, or drive the loop themselves by calling {@link #iterate()} repeatedly.
 */
public class CBMSOptimizer {
    private static final Logger logger = LogManager.getLogger();

    private final CBMS cbms;
    private final MultiLabelClfDataSet dataSet;
    private final Terminator terminator;

    /** {@code gammas[n][k]} holds the E-step value for data point {@code n} and component {@code k}. */
    double[][] gammas;

    /** Whether the per-data-point and per-label loops run on parallel streams (default). */
    private boolean isParallel = true;

    private double priorVarianceMultiClass = 1.0;
    private double priorVarianceBinary = 1.0;
    private double componentWeightsVariance = 1.0;
    private int numMultiClassParaUpdates = 10;
    private int numBinaryParaUpdates = 10;

    /**
     * Creates an optimizer for the given model and training data.
     *
     * <p>The terminator is created with goal {@code MINIMIZE}, and the gamma matrix is allocated
     * with one row per data point and one column per component.
     *
     * @param cbms the model whose classifiers are updated in place
     * @param dataSet the training data
     * @throws NullPointerException if either argument is {@code null}
     */
    public CBMSOptimizer(CBMS cbms, MultiLabelClfDataSet dataSet) {
        // Fail fast with a named argument instead of an anonymous NPE further down.
        if (cbms == null) {
            throw new NullPointerException("cbms must not be null");
        }
        if (dataSet == null) {
            throw new NullPointerException("dataSet must not be null");
        }
        this.cbms = cbms;
        this.dataSet = dataSet;
        this.terminator = new Terminator();
        this.terminator.setGoal(Terminator.Goal.MINIMIZE);
        this.gammas = new double[dataSet.getNumDataPoints()][cbms.getNumComponents()];
    }

    /**
     * Sets the iteration cap applied to the multi-class optimizer in each M step.
     *
     * @param numMultiClassParaUpdates maximum iterations per M step (default 10)
     */
    public void setNumMultiClassParaUpdates(int numMultiClassParaUpdates) {
        this.numMultiClassParaUpdates = numMultiClassParaUpdates;
    }

    /**
     * Sets the iteration cap applied to each binary classifier's LBFGS run in each M step.
     *
     * @param numBinaryParaUpdates maximum iterations per M step (default 10)
     */
    public void setNumBinaryParaUpdates(int numBinaryParaUpdates) {
        this.numBinaryParaUpdates = numBinaryParaUpdates;
    }

    /**
     * Sets the component-weights variance forwarded to every {@link AugmentedLRLoss}.
     *
     * @param componentWeightsVariance the variance value (default 1.0)
     */
    public void setComponentWeightsVariance(double componentWeightsVariance) {
        this.componentWeightsVariance = componentWeightsVariance;
    }

    /**
     * Sets the prior variance forwarded to the multi-class {@link RidgeLogisticOptimizer}.
     *
     * @param priorVarianceMultiClass the variance value (default 1.0)
     */
    public void setPriorVarianceMultiClass(double priorVarianceMultiClass) {
        this.priorVarianceMultiClass = priorVarianceMultiClass;
    }

    /**
     * Sets the prior variance forwarded to every binary {@link AugmentedLRLoss}.
     *
     * @param priorVarianceBinary the variance value (default 1.0)
     */
    public void setPriorVarianceBinary(double priorVarianceBinary) {
        this.priorVarianceBinary = priorVarianceBinary;
    }

    /**
     * Chooses whether the per-data-point and per-label loops use parallel streams.
     *
     * @param isParallel {@code true} (the default) for parallel streams, {@code false} for
     *     sequential ones
     */
    public void setParallel(boolean isParallel) {
        this.isParallel = isParallel;
    }

    /**
     * Runs {@link #iterate()} at least once and keeps going until the terminator says to stop.
     *
     * <p>After every pass the current objective is reported to the terminator. Nothing else in
     * this class feeds the terminator, so without that report its answer could never change
     * while this loop is running.
     */
    public void optimize() {
        do {
            this.iterate();
            // NOTE(review): assumes Terminator.add(double) records one objective value per
            // iteration for its stopping rule; confirm against Terminator.
            this.terminator.add(this.getObjective());
        } while (!this.terminator.shouldTerminate());
    }

    /**
     * Performs one E step followed by one M step.
     *
     * <p>This method does not report anything to the terminator, so callers looping over it
     * manually decide for themselves when to stop.
     */
    public void iterate() {
        this.eStep();
        this.mStep();
    }

    /** Recomputes all gammas; at debug level also logs the objective afterwards. */
    public void eStep() {
        if (logger.isDebugEnabled()) {
            logger.debug("start E step");
        }
        this.updateGamma();
        // The guard also avoids evaluating the objective when debug logging is off.
        if (logger.isDebugEnabled()) {
            logger.debug("finish E step");
            logger.debug("objective = " + this.getObjective());
        }
    }

    /** Recomputes the gamma row of every data point. */
    private void updateGamma() {
        this.indices(this.dataSet.getNumDataPoints()).forEach(this::updateGamma);
    }

    /**
     * Recomputes the gamma row of one data point from the model's posterior membership.
     *
     * @param n index of the data point
     */
    private void updateGamma(int n) {
        Vector x = this.dataSet.getRow(n);
        MultiLabel y = this.dataSet.getMultiLabels()[n];
        double[] posterior = this.cbms.posteriorMembership(x, y);
        // Each call writes only its own row, so concurrent calls for different n do not collide.
        System.arraycopy(posterior, 0, this.gammas[n], 0, this.cbms.numComponents);
    }

    /**
     * Refits the binary classifiers and then the multi-class classifier; at debug level also
     * logs the objective afterwards.
     */
    public void mStep() {
        if (logger.isDebugEnabled()) {
            logger.debug("start M step");
        }
        this.updateBinaryClassifiers();
        this.updateMultiClassClassifier();
        // The guard also avoids evaluating the objective when debug logging is off.
        if (logger.isDebugEnabled()) {
            logger.debug("finish M step");
            logger.debug("objective = " + this.getObjective());
        }
    }

    /** Refits the binary classifier of every label. */
    private void updateBinaryClassifiers() {
        if (logger.isDebugEnabled()) {
            logger.debug("start updateBinaryClassifiers");
        }
        this.indices(this.cbms.numLabels).forEach(this::updateBinaryLogisticRegression);
        if (logger.isDebugEnabled()) {
            logger.debug("finish updateBinaryClassifiers");
        }
    }

    /**
     * Runs LBFGS on the augmented loss of one label, capped at {@code numBinaryParaUpdates}
     * iterations.
     *
     * @param labelIndex index of the label whose classifier is refitted
     */
    private void updateBinaryLogisticRegression(int labelIndex) {
        LBFGS lbfgs = new LBFGS(this.newBinaryLoss(labelIndex));
        lbfgs.getTerminator().setMaxIteration(this.numBinaryParaUpdates);
        lbfgs.getTerminator().setGoal(Terminator.Goal.MINIMIZE);
        lbfgs.optimize();
    }

    /**
     * Runs the ridge logistic optimizer on the multi-class classifier, capped at
     * {@code numMultiClassParaUpdates} iterations.
     */
    private void updateMultiClassLR() {
        RidgeLogisticOptimizer ridgeLogisticOptimizer = this.newMultiClassOptimizer();
        ridgeLogisticOptimizer.getOptimizer().getTerminator().setMaxIteration(this.numMultiClassParaUpdates);
        ridgeLogisticOptimizer.optimize();
    }

    /** Refits the multi-class classifier, logging the start at debug level. */
    private void updateMultiClassClassifier() {
        if (logger.isDebugEnabled()) {
            logger.debug("start updateMultiClassClassifier()");
        }
        this.updateMultiClassLR();
    }

    /**
     * Returns the value being minimized: the multi-class objective plus the sum of the binary
     * objectives over all labels.
     *
     * @return the current objective
     */
    public double getObjective() {
        return this.multiClassClassifierObj() + this.binaryObj();
    }

    /** Sums the entropy of every gamma row. Currently not called from within this class. */
    private double getEntropy() {
        return this.indices(this.dataSet.getNumDataPoints()).mapToDouble(this::getEntropy).sum();
    }

    /**
     * Returns the entropy of one gamma row.
     *
     * @param i index of the data point
     */
    private double getEntropy(int i) {
        return Entropy.entropy(this.gammas[i]);
    }

    /**
     * Sums the augmented loss value over all labels.
     *
     * @return the binary part of the objective
     */
    double binaryObj() {
        return this.indices(this.cbms.numLabels).mapToDouble(this::binaryObj).sum();
    }

    /**
     * Returns the augmented loss value of one label under the current gammas.
     *
     * @param labelIndex index of the label
     */
    private double binaryObj(int labelIndex) {
        return this.newBinaryLoss(labelIndex).getValue();
    }

    /**
     * Returns the value of the multi-class optimizer's function under the current gammas.
     *
     * @return the multi-class part of the objective
     */
    double multiClassClassifierObj() {
        return this.newMultiClassOptimizer().getFunction().getValue();
    }

    /**
     * Returns the terminator consulted by {@link #optimize()}.
     *
     * @return the terminator, not a copy
     */
    public Terminator getTerminator() {
        return this.terminator;
    }

    /**
     * Returns the gamma matrix.
     *
     * @return the internal array itself, not a copy
     */
    public double[][] getGammas() {
        return this.gammas;
    }

    /** Builds the loss for one label from the current gammas and variance settings. */
    private AugmentedLRLoss newBinaryLoss(int labelIndex) {
        return new AugmentedLRLoss(this.dataSet, labelIndex, this.gammas, this.cbms.getBinaryClassifiers()[labelIndex], this.priorVarianceBinary, this.componentWeightsVariance);
    }

    /** Builds the multi-class optimizer from the current gammas and prior variance. */
    private RidgeLogisticOptimizer newMultiClassOptimizer() {
        return new RidgeLogisticOptimizer((LogisticRegression)this.cbms.multiClassClassifier, (DataSet)this.dataSet, this.gammas, this.priorVarianceMultiClass, true);
    }

    /** Returns the indices {@code 0..size-1} as a parallel or sequential stream per the flag. */
    private IntStream indices(int size) {
        IntStream range = IntStream.range(0, size);
        return this.isParallel ? range.parallel() : range;
    }
}

