/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.classification.PriorProbClassifier;
import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticLoss;
import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.AbstractCBMOptimizer;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * CBM optimizer in which the multi-class classifier and every per-(component, label) binary
 * classifier are {@link LogisticRegression} models, fitted with {@link RidgeLogisticOptimizer}
 * and scored with {@link LogisticLoss}.
 *
 * <p>Regularization is controlled by two independent prior variances: one for the multi-class
 * classifier and one shared by all binary classifiers. Both default to {@code 1.0}.
 */
public class LRCBMOptimizer
extends AbstractCBMOptimizer {
    private static final Logger logger = LogManager.getLogger();
    /** Prior variance handed to the multi-class classifier's optimizer and loss; always > 0. */
    private double priorVarianceMultiClass = 1.0;
    /** Prior variance handed to every binary classifier's optimizer and loss; always > 0. */
    private double priorVarianceBinary = 1.0;

    /**
     * Creates an optimizer for the given model and training data.
     *
     * @param cbm the model whose classifiers are updated in place
     * @param dataSet the multi-label training data
     */
    public LRCBMOptimizer(CBM cbm, MultiLabelClfDataSet dataSet) {
        super(cbm, dataSet);
    }

    /**
     * Sets the prior variance used when fitting and scoring the multi-class classifier.
     *
     * @param priorVarianceMultiClass strictly positive variance ({@code +Infinity} is allowed)
     * @throws IllegalArgumentException if the value is zero, negative or NaN
     */
    public void setPriorVarianceMultiClass(double priorVarianceMultiClass) {
        this.priorVarianceMultiClass = requirePositiveVariance(priorVarianceMultiClass, "priorVarianceMultiClass");
    }

    /**
     * Sets the prior variance used when fitting and scoring the binary classifiers.
     *
     * @param priorVarianceBinary strictly positive variance ({@code +Infinity} is allowed)
     * @throws IllegalArgumentException if the value is zero, negative or NaN
     */
    public void setPriorVarianceBinary(double priorVarianceBinary) {
        this.priorVarianceBinary = requirePositiveVariance(priorVarianceBinary, "priorVarianceBinary");
    }

    /**
     * Refits the binary classifier of one (component, label) pair on the active data.
     *
     * <p>A missing classifier, or a {@link PriorProbClassifier} placeholder, is first replaced
     * by a fresh two-class {@link LogisticRegression}. Otherwise the existing model object is
     * handed to the optimizer as is. At most {@code binaryUpdatesPerIter} optimizer iterations
     * are run.
     *
     * @param component component index
     * @param label label whose 0/1 indicator is the binary target
     * @param activeDataset data points used for this update
     * @param activeGammas per-data-point weights passed to the optimizer together with
     *                     {@code activeDataset} (presumably aligned with its rows)
     */
    @Override
    protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        if (this.cbm.binaryClassifiers[component][label] == null || this.cbm.binaryClassifiers[component][label] instanceof PriorProbClassifier) {
            this.cbm.binaryClassifiers[component][label] = new LogisticRegression(2, activeDataset.getNumFeatures());
        }
        int[] binaryLabels = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
        // The explicit DataSet cast is kept to pin the intended constructor overload.
        RidgeLogisticOptimizer ridgeLogisticOptimizer = new RidgeLogisticOptimizer((LogisticRegression) this.cbm.binaryClassifiers[component][label], (DataSet) activeDataset, binaryLabels, activeGammas, this.priorVarianceBinary, false);
        ridgeLogisticOptimizer.getOptimizer().getTerminator().setMaxIteration(this.binaryUpdatesPerIter);
        ridgeLogisticOptimizer.optimize();
        // Parameterized message: formatting (and StopWatch.toString) only happens when enabled.
        logger.debug("time spent on updating component {} label {} = {}", component, label, stopWatch);
    }

    /**
     * Refits the multi-class classifier on the full data set, with {@code gammas} passed as
     * its targets. At most {@code multiclassUpdatesPerIter} optimizer iterations are run.
     */
    @Override
    protected void updateMultiClassClassifier() {
        logger.debug("start updateMultiClassClassifier");
        RidgeLogisticOptimizer ridgeLogisticOptimizer = new RidgeLogisticOptimizer((LogisticRegression) this.cbm.multiClassClassifier, (DataSet) this.dataSet, this.gammas, this.priorVarianceMultiClass, true);
        ridgeLogisticOptimizer.getOptimizer().getTerminator().setMaxIteration(this.multiclassUpdatesPerIter);
        ridgeLogisticOptimizer.optimize();
        logger.debug("finish updateMultiClassClassifier");
    }

    /**
     * Evaluates the objective of one binary classifier over the full data set.
     *
     * <p>Each data point's target is its 0/1 indicator for {@code classIndex}, given as a
     * two-class distribution. Its weight is {@code gammas[i][component]}.
     *
     * @param component component index selecting both the classifier and the gamma column
     * @param classIndex label index
     * @return the value reported by {@link LogisticLoss#getValue()}
     */
    @Override
    protected double binaryObj(int component, int classIndex) {
        int[] binaryLabels = DataSetUtil.toBinaryLabels(this.dataSet.getMultiLabels(), classIndex);
        double[][] targetsDistribution = DataSetUtil.labelsToDistributions(binaryLabels, 2);
        int numDataPoints = this.dataSet.getNumDataPoints();
        double[] weights = new double[numDataPoints];
        for (int i = 0; i < numDataPoints; ++i) {
            weights[i] = this.gammas[i][component];
        }
        // NOTE(review): this cast fails with ClassCastException if the slot still holds a
        // PriorProbClassifier (updateBinaryClassifier only replaces it when refitting).
        LogisticLoss logisticLoss = new LogisticLoss((LogisticRegression) this.cbm.binaryClassifiers[component][classIndex], (DataSet) this.dataSet, weights, targetsDistribution, this.priorVarianceBinary, false);
        return logisticLoss.getValue();
    }

    /**
     * Evaluates the multi-class classifier's objective over the full data set, with
     * {@code gammas} passed as the targets.
     *
     * @return the value reported by {@link LogisticLoss#getValue()}
     */
    @Override
    protected double multiClassClassifierObj() {
        LogisticLoss logisticLoss = new LogisticLoss((LogisticRegression) this.cbm.multiClassClassifier, this.dataSet, this.gammas, this.priorVarianceMultiClass, true);
        return logisticLoss.getValue();
    }

    /**
     * Validates a prior variance.
     *
     * @param variance candidate value
     * @param name parameter name used in the error message
     * @return {@code variance} unchanged
     * @throws IllegalArgumentException if {@code variance} is not strictly positive
     */
    private static double requirePositiveVariance(double variance, String name) {
        // !(x > 0) rejects NaN as well as zero and negatives; none is a meaningful ridge prior variance.
        if (!(variance > 0)) {
            throw new IllegalArgumentException(name + " must be positive, got " + variance);
        }
        return variance;
    }
}

