/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.imlgb;

import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.dataset.ScoreMatrix;
import edu.neu.ccs.pyramid.multilabel_classification.MLPriorProbClassifier;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBConfig;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting;
import edu.neu.ccs.pyramid.regression.ConstantRegressor;
import edu.neu.ccs.pyramid.regression.Regressor;
import edu.neu.ccs.pyramid.regression.regression_tree.AverageOutputCalculator;
import edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig;
import edu.neu.ccs.pyramid.regression.regression_tree.RegTreeTrainer;
import edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.Vector;

/**
 * Trains an {@link IMLGradientBoosting} model by adding one shrunken regression tree per
 * class per boosting iteration.
 *
 * <p>The trainer keeps a running raw score for every (data point, class) pair in a
 * {@link ScoreMatrix}. A new tree for class {@code k} is fit to the gradient
 * {@code y - p}. Here {@code y} is 1 if the data point carries label {@code k} (otherwise 0),
 * and {@code p} is the logistic transform of the current score.
 */
public class IMLGBTrainer {
    private static final Logger logger = LogManager.getLogger();

    private final IMLGBConfig config;
    /** Accumulated raw scores, indexed by data point and class. */
    private final ScoreMatrix scoreMatrix;
    private final IMLGradientBoosting boosting;
    /** Per-class flags; {@link #iterate()} skips every class whose flag is {@code true}. */
    private final boolean[] shouldStop;

    /**
     * Prepares {@code boosting} for training on {@code config.getDataSet()}.
     *
     * <p>The constructor copies the feature list, label translator and label assignments of the
     * data set onto the model. If {@code config.usePrior()} is set and the model has no
     * regressors yet for class 0, it adds a constant prior regressor for every class. Every
     * regressor the model already holds is then replayed into the score matrix, so training
     * continues from the model's current scores.
     *
     * @param config training configuration, including the data set
     * @param boosting the model to train; it is mutated by this constructor
     * @throws IllegalArgumentException if the data set and the model disagree on the number of
     *     classes
     */
    public IMLGBTrainer(IMLGBConfig config, IMLGradientBoosting boosting) {
        MultiLabelClfDataSet dataSet = config.getDataSet();
        if (dataSet.getNumClasses() != boosting.getNumClasses()) {
            throw new IllegalArgumentException("config.getDataSet().getNumClasses()!=boosting.getNumClasses()");
        }
        this.config = config;
        this.boosting = boosting;
        boosting.setFeatureList(dataSet.getFeatureList());
        boosting.setLabelTranslator(dataSet.getLabelTranslator());
        int numClasses = dataSet.getNumClasses();
        this.scoreMatrix = new ScoreMatrix(dataSet.getNumDataPoints(), numClasses);
        // Add the prior only to a fresh model; a model that already has trees keeps its own start.
        if (config.usePrior() && boosting.getRegressors(0).size() == 0) {
            setPriorProbs(dataSet);
        }
        initStagedClassScoreMatrix(boosting);
        List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(dataSet);
        boosting.setAssignments(assignments);
        this.shouldStop = new boolean[numClasses];
    }

    /**
     * Marks class {@code classIndex} so that later calls to {@link #iterate()} skip it.
     *
     * @param classIndex index of the class to stop
     */
    public void setShouldStop(int classIndex) {
        shouldStop[classIndex] = true;
        logger.debug("class {} is set to stop", classIndex);
    }

    /**
     * Returns the live per-class stop flags. This is the internal array, not a copy, so changes
     * made by the caller affect {@link #iterate()}.
     *
     * @return the stop flags, one per class
     */
    public boolean[] getShouldStop() {
        return shouldStop;
    }

    /**
     * Runs one boosting iteration. For every class that is not stopped, it fits a tree, adds the
     * tree to the model, and adds the tree's predictions to the score matrix.
     */
    public void iterate() {
        for (int k = 0; k < boosting.getNumClasses(); ++k) {
            if (shouldStop[k]) {
                continue;
            }
            logger.debug("updating class {}", k);
            RegressionTree regressor = fitClassK(k);
            boosting.addRegressor(regressor, k);
            updateStagedClassScores(regressor, k);
        }
    }

    /**
     * Runs one boosting iteration without updating the score matrix.
     *
     * <p>Trees are fit and added to the model, but their predictions are not added to the
     * staged scores. Gradients computed by later calls therefore do not reflect these trees.
     *
     * @param shouldStop caller-supplied per-class flags; classes flagged {@code true} are skipped
     *     (the trainer's own flags are ignored here)
     */
    public void iterateWithoutStagingScores(boolean[] shouldStop) {
        for (int k = 0; k < boosting.getNumClasses(); ++k) {
            if (shouldStop[k]) {
                continue;
            }
            logger.debug("updating class {}", k);
            RegressionTree regressor = fitClassK(k);
            boosting.addRegressor(regressor, k);
        }
    }

    /**
     * Adds one constant regressor per class derived from the given prior probabilities.
     *
     * <p>The constant is {@code sign(s) * sqrt(|s|)}, where {@code s = inverseSigmoid(p)}.
     * NOTE(review): this damps the prior log-odds instead of using them directly; confirm that
     * this is intended.
     *
     * @param probs prior probability for each class
     * @throws IllegalArgumentException if {@code probs.length} differs from the number of classes
     */
    private void setPriorProbs(double[] probs) {
        int numClasses = boosting.getNumClasses();
        if (probs.length != numClasses) {
            throw new IllegalArgumentException("probs.length!=this.numClasses");
        }
        for (int k = 0; k < numClasses; ++k) {
            double logOdds = MathUtil.inverseSigmoid(probs[k]);
            double soft = Math.sqrt(Math.abs(logOdds));
            if (logOdds < 0.0) {
                soft = -soft;
            }
            boosting.addRegressor(new ConstantRegressor(soft), k);
        }
    }

    /** Estimates per-class prior probabilities from {@code dataSet} and installs them as constant regressors. */
    private void setPriorProbs(MultiLabelClfDataSet dataSet) {
        MLPriorProbClassifier priorProbClassifier = new MLPriorProbClassifier(dataSet.getNumClasses());
        priorProbClassifier.fit(dataSet);
        setPriorProbs(priorProbClassifier.getClassProbs());
    }

    /** Adds the predictions of every regressor already in {@code boosting} to the score matrix. */
    private void initStagedClassScoreMatrix(IMLGradientBoosting boosting) {
        int numClasses = config.getDataSet().getNumClasses();
        for (int k = 0; k < numClasses; ++k) {
            for (Regressor regressor : boosting.getRegressors(k)) {
                updateStagedClassScores(regressor, k);
            }
        }
    }

    /**
     * Adds {@code regressor}'s prediction to class {@code k}'s score for every data point.
     * The work is split across data points in parallel.
     */
    private void updateStagedClassScores(Regressor regressor, int k) {
        MultiLabelClfDataSet dataSet = config.getDataSet();
        // NOTE(review): assumes ScoreMatrix.increment is safe for concurrent calls on distinct rows.
        IntStream.range(0, dataSet.getNumDataPoints())
                .parallel()
                .forEach(dataIndex -> updateStagedClassScore(dataSet, regressor, k, dataIndex));
    }

    /** Adds {@code regressor}'s prediction for one data point to that point's class-{@code k} score. */
    private void updateStagedClassScore(MultiLabelClfDataSet dataSet, Regressor regressor, int k, int dataIndex) {
        Vector vector = dataSet.getRow(dataIndex);
        scoreMatrix.increment(dataIndex, k, regressor.predict(vector));
    }

    /**
     * Returns the logistic transform of the current class-{@code k} score of {@code dataPoint}.
     * It is computed as {@code exp(score - logSumExp(0, score))}, which stays stable for large
     * |score|.
     */
    private double calClassProb(int dataPoint, int k) {
        double score = scoreMatrix.getScoresForData(dataPoint)[k];
        double logDenominator = MathUtil.logSumExp(new double[]{0.0, score});
        return Math.exp(score - logDenominator);
    }

    /** Computes the gradient {@code y - p} for class {@code k} over all data points, in parallel. */
    private double[] computeGradientForClass(int k) {
        MultiLabelClfDataSet dataSet = config.getDataSet();
        // Fetch the label array once instead of once per data point.
        MultiLabel[] multiLabels = dataSet.getMultiLabels();
        return IntStream.range(0, dataSet.getNumDataPoints())
                .parallel()
                .mapToDouble(i -> computeGradient(multiLabels[i], k, i))
                .toArray();
    }

    /**
     * Returns {@code y - p} for one data point. {@code y} is 1 if the point carries label
     * {@code k} (else 0), and {@code p} is its current class probability.
     */
    private double computeGradient(MultiLabel multiLabel, int k, int dataPoint) {
        double target = multiLabel.matchClass(k) ? 1.0 : 0.0;
        return target - calClassProb(dataPoint, k);
    }

    /**
     * Fits one regression tree to class {@code k}'s gradients and shrinks it by the configured
     * learning rate.
     */
    private RegressionTree fitClassK(int k) {
        double[] gradients = computeGradientForClass(k);
        RegTreeConfig regTreeConfig = new RegTreeConfig();
        regTreeConfig.setMaxNumLeaves(config.getNumLeaves());
        regTreeConfig.setMinDataPerLeaf(config.getMinDataPerLeaf());
        regTreeConfig.setNumSplitIntervals(config.getNumSplitIntervals());
        RegressionTree tree = RegTreeTrainer.fit(regTreeConfig, config.getDataSet(), gradients, new AverageOutputCalculator());
        tree.shrink(config.getLearningRate());
        return tree;
    }
}

