/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.classification.PriorProbClassifier;
import edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer;
import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticLoss;
import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.AbstractCBMOptimizer;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * CBM optimizer that trains both the multi-class classifier and the per-(component, label)
 * binary classifiers of a {@link CBM} as {@link LogisticRegression} models using
 * {@link ElasticNetLogisticTrainer}.
 *
 * <p>The regularization value and L1 ratio are configured separately for the binary and the
 * multi-class parts. The line-search and active-set switches are shared by both trainers.
 */
public class ENCBMOptimizer
extends AbstractCBMOptimizer {
    private static final Logger logger = LogManager.getLogger();
    /** Value passed to {@code setRegularization} of the multi-class trainer and used in its loss. */
    private double regularizationMultiClass = 1.0;
    /** Value passed to {@code setRegularization} of each binary trainer and used in its loss. */
    private double regularizationBinary = 1.0;
    /** Value passed to {@code setL1Ratio} of each binary trainer and used in its loss. */
    private double l1RatioBinary = 0.0;
    /** Value passed to {@code setL1Ratio} of the multi-class trainer and used in its loss. */
    private double l1RatioMultiClass = 0.0;
    /** Whether both trainers are built with line search enabled. */
    private boolean lineSearch = true;
    /** Whether both trainers run with their active-set option enabled. */
    private boolean activeSet = false;

    /**
     * Creates an optimizer for the given model and training data.
     *
     * @param cbm     the model whose classifiers are updated in place
     * @param dataSet the full training data set
     */
    public ENCBMOptimizer(CBM cbm, MultiLabelClfDataSet dataSet) {
        super(cbm, dataSet);
    }

    /**
     * Trains the binary classifier for one (component, label) slot on the given active subset.
     *
     * <p>If the slot is empty or currently holds a {@link PriorProbClassifier}, it is first
     * replaced by a fresh two-class {@link LogisticRegression}. The trainer then runs for at most
     * {@code binaryUpdatesPerIter} iterations.
     *
     * @param component     component index of the slot to update
     * @param label         label index; its 0/1 membership in each data point is the binary target
     * @param activeDataset data points used for this update
     * @param activeGammas  per-point weights for {@code activeDataset}. Presumably the component
     *                      responsibilities of the active points (confirm with the caller).
     */
    @Override
    protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        if (this.cbm.binaryClassifiers[component][label] == null
                || this.cbm.binaryClassifiers[component][label] instanceof PriorProbClassifier) {
            // A prior-only classifier has no weights to refine, so start from a trainable model.
            this.cbm.binaryClassifiers[component][label] = new LogisticRegression(2, activeDataset.getNumFeatures());
        }
        int[] binaryLabels = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
        double[][] targetsDistribution = DataSetUtil.labelsToDistributions(binaryLabels, 2);
        ElasticNetLogisticTrainer elasticNetLogisticTrainer = new ElasticNetLogisticTrainer.Builder(
                (LogisticRegression) this.cbm.binaryClassifiers[component][label],
                activeDataset, 2, targetsDistribution, activeGammas)
                .setRegularization(this.regularizationBinary)
                .setL1Ratio(this.l1RatioBinary)
                .setLineSearch(this.lineSearch)
                .build();
        elasticNetLogisticTrainer.setActiveSet(this.activeSet);
        elasticNetLogisticTrainer.getTerminator().setMaxIteration(this.binaryUpdatesPerIter);
        elasticNetLogisticTrainer.optimize();
        // Stop before reporting so the logged value is a completed measurement.
        stopWatch.stop();
        logger.debug("time spent on updating component {} label {} = {}", component, label, stopWatch);
    }

    /**
     * Trains the multi-class classifier on the full data set, using {@code gammas} as its
     * training targets. The trainer runs for at most {@code multiclassUpdatesPerIter} iterations.
     */
    @Override
    protected void updateMultiClassClassifier() {
        logger.debug("start updateMultiClassClassifier");
        ElasticNetLogisticTrainer elasticNetLogisticTrainer = new ElasticNetLogisticTrainer.Builder(
                (LogisticRegression) this.cbm.multiClassClassifier,
                (DataSet) this.dataSet,
                this.cbm.multiClassClassifier.getNumClasses(),
                this.gammas)
                .setRegularization(this.regularizationMultiClass)
                .setL1Ratio(this.l1RatioMultiClass)
                .setLineSearch(this.lineSearch)
                .build();
        elasticNetLogisticTrainer.setActiveSet(this.activeSet);
        elasticNetLogisticTrainer.getTerminator().setMaxIteration(this.multiclassUpdatesPerIter);
        elasticNetLogisticTrainer.optimize();
        logger.debug("finish updateMultiClassClassifier");
    }

    /**
     * Computes the objective of one binary classifier over the full data set.
     *
     * <p>Each data point is weighted by {@code gammas[i][component]}. The value returned is
     * {@link LogisticLoss#getValueEL()} for a loss built with the binary regularization settings.
     *
     * <p>NOTE(review): the slot is cast to {@link LogisticRegression}, but
     * {@link #updateBinaryClassifier} shows slots may hold a {@link PriorProbClassifier}. Such a
     * slot would throw {@code ClassCastException} here. TODO: confirm callers never reach this
     * for prior-only slots.
     *
     * @param component  component index
     * @param classIndex label index
     * @return the weighted loss value
     */
    @Override
    protected double binaryObj(int component, int classIndex) {
        int[] binaryLabels = DataSetUtil.toBinaryLabels(this.dataSet.getMultiLabels(), classIndex);
        double[][] targetsDistribution = DataSetUtil.labelsToDistributions(binaryLabels, 2);
        int numDataPoints = this.dataSet.getNumDataPoints();
        double[] weights = new double[numDataPoints];
        for (int i = 0; i < numDataPoints; ++i) {
            weights[i] = this.gammas[i][component];
        }
        LogisticLoss logisticLoss = new LogisticLoss(
                (LogisticRegression) this.cbm.binaryClassifiers[component][classIndex],
                this.dataSet, weights, targetsDistribution,
                this.regularizationBinary, this.l1RatioBinary, false);
        return logisticLoss.getValueEL();
    }

    /**
     * Computes the objective of the multi-class classifier over the full data set.
     *
     * <p>{@code gammas} is used as the target distribution. The value returned is
     * {@link LogisticLoss#getValueEL()} for a loss built with the multi-class regularization
     * settings.
     *
     * @return the loss value
     */
    @Override
    protected double multiClassClassifierObj() {
        LogisticLoss logisticLoss = new LogisticLoss(
                (LogisticRegression) this.cbm.multiClassClassifier,
                (DataSet) this.dataSet, this.gammas,
                this.regularizationMultiClass, this.l1RatioMultiClass, true);
        return logisticLoss.getValueEL();
    }

    /** @return the regularization value used for the multi-class classifier */
    public double getRegularizationMultiClass() {
        return this.regularizationMultiClass;
    }

    /** @param regularizationMultiClass regularization value for the multi-class classifier */
    public void setRegularizationMultiClass(double regularizationMultiClass) {
        this.regularizationMultiClass = regularizationMultiClass;
    }

    /** @return the regularization value used for the binary classifiers */
    public double getRegularizationBinary() {
        return this.regularizationBinary;
    }

    /** @param regularizationBinary regularization value for the binary classifiers */
    public void setRegularizationBinary(double regularizationBinary) {
        this.regularizationBinary = regularizationBinary;
    }

    /** @return the L1 ratio used for the binary classifiers */
    public double getL1RatioBinary() {
        return this.l1RatioBinary;
    }

    /**
     * @param l1RatioBinary L1 ratio for the binary classifiers. Presumably the elastic-net
     *                      L1/L2 mixing weight in [0, 1]; confirm against ElasticNetLogisticTrainer.
     */
    public void setL1RatioBinary(double l1RatioBinary) {
        this.l1RatioBinary = l1RatioBinary;
    }

    /** @return the L1 ratio used for the multi-class classifier */
    public double getL1RatioMultiClass() {
        return this.l1RatioMultiClass;
    }

    /**
     * @param l1RatioMultiClass L1 ratio for the multi-class classifier. Presumably the
     *                          elastic-net L1/L2 mixing weight in [0, 1]; confirm against
     *                          ElasticNetLogisticTrainer.
     */
    public void setL1RatioMultiClass(double l1RatioMultiClass) {
        this.l1RatioMultiClass = l1RatioMultiClass;
    }

    /** @return whether both trainers use line search */
    public boolean isLineSearch() {
        return this.lineSearch;
    }

    /** @param lineSearch whether both trainers use line search */
    public void setLineSearch(boolean lineSearch) {
        this.lineSearch = lineSearch;
    }

    /** @return whether both trainers run with the active-set option enabled */
    public boolean isActiveSet() {
        return this.activeSet;
    }

    /** @param activeSet whether both trainers run with the active-set option enabled */
    public void setActiveSet(boolean activeSet) {
        this.activeSet = activeSet;
    }
}

