/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.hmlgb;

import edu.neu.ccs.pyramid.dataset.GradientMatrix;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.dataset.ProbabilityMatrix;
import edu.neu.ccs.pyramid.dataset.ScoreMatrix;
import edu.neu.ccs.pyramid.multilabel_classification.MLPriorProbClassifier;
import edu.neu.ccs.pyramid.multilabel_classification.hmlgb.HMLGBConfig;
import edu.neu.ccs.pyramid.multilabel_classification.hmlgb.HMLGBLeafOutputCalculator;
import edu.neu.ccs.pyramid.multilabel_classification.hmlgb.HMLGradientBoosting;
import edu.neu.ccs.pyramid.regression.ConstantRegressor;
import edu.neu.ccs.pyramid.regression.Regressor;
import edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig;
import edu.neu.ccs.pyramid.regression.regression_tree.RegTreeTrainer;
import edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.Vector;

/**
 * Trainer for an {@link HMLGradientBoosting} model.
 *
 * <p>The trainer caches, for every training data point:
 * <ul>
 *   <li>the per-class scores (sum of all regressor predictions for that class),</li>
 *   <li>a softmax probability over the candidate label assignments, where an assignment's score is
 *       the sum of the class scores of its matched labels,</li>
 *   <li>per-class probabilities, obtained by summing the probabilities of every assignment that
 *       contains the class,</li>
 *   <li>per-class gradients, {@code 1[class in true labels] - classProbability}.</li>
 * </ul>
 * Each call to {@link #iterate()} performs one boosting round: one regression tree per class is
 * fitted to the current gradients, added to the model, and all cached matrices are refreshed.
 */
public class HMLGBTrainer {
    private static final Logger logger = LogManager.getLogger();
    private final HMLGBConfig config;
    /** Candidate label assignments taken from the boosting model. */
    private final List<MultiLabel> assignments;
    /** [dataPoint][class] accumulated regressor outputs. */
    private final ScoreMatrix scoreMatrix;
    /** [dataPoint][assignment] softmax probability of each candidate assignment. */
    private final double[][] assignmentProbabilityMatrix;
    private final GradientMatrix gradientMatrix;
    private final ProbabilityMatrix probabilityMatrix;
    private final HMLGradientBoosting boosting;

    /**
     * Creates a trainer and initializes all cached matrices from the regressors already present in
     * {@code boosting}.
     *
     * <p>If the config requests a prior and class 0 has no regressors yet, one constant regressor per
     * class is first added to the model, derived from the label frequencies of the training set.
     *
     * @param config   training configuration; its data set is the training data
     * @param boosting model to train; it is mutated (feature list, label translator, regressors)
     * @throws IllegalArgumentException if the data set and the model disagree on the number of classes
     */
    public HMLGBTrainer(HMLGBConfig config, HMLGradientBoosting boosting) {
        if (config.getDataSet().getNumClasses() != boosting.getNumClasses()) {
            throw new IllegalArgumentException("config.getDataSet().getNumClasses()!=boosting.getNumClasses()");
        }
        this.config = config;
        this.boosting = boosting;
        this.assignments = boosting.getAssignments();
        MultiLabelClfDataSet dataSet = config.getDataSet();
        boosting.setFeatureList(dataSet.getFeatureList());
        boosting.setLabelTranslator(dataSet.getLabelTranslator());
        int numClasses = dataSet.getNumClasses();
        int numDataPoints = dataSet.getNumDataPoints();
        int numAssignments = this.assignments.size();
        this.scoreMatrix = new ScoreMatrix(numDataPoints, numClasses);
        // Only add the prior to a fresh model; a model that already has regressors keeps them as-is.
        if (config.usePrior() && boosting.getRegressors(0).size() == 0) {
            this.setPriorProbs(config.getDataSet());
        }
        this.initScoreMatrix(boosting);
        this.assignmentProbabilityMatrix = new double[numDataPoints][numAssignments];
        this.updateAssignmentProbMatrix();
        this.probabilityMatrix = new ProbabilityMatrix(numDataPoints, numClasses);
        this.updateProbabilityMatrix();
        this.gradientMatrix = new GradientMatrix(numDataPoints, numClasses, GradientMatrix.Objective.MAXIMIZE);
        this.updateClassGradientMatrix();
    }

    /**
     * Runs one boosting round.
     *
     * <p>For every class, a tree is fitted to the current gradients, added to the model, and applied
     * to the score matrix. Assignment probabilities, class probabilities and gradients are recomputed
     * once, after all classes have been updated, so every class in a round sees the same gradients.
     */
    public void iterate() {
        for (int k = 0; k < this.boosting.getNumClasses(); ++k) {
            RegressionTree regressor = this.fitClassK(k);
            this.boosting.addRegressor(regressor, k);
            this.updateClassScores(regressor, k);
        }
        this.updateAssignmentProbMatrix();
        this.updateProbabilityMatrix();
        this.updateClassGradientMatrix();
    }

    /**
     * Stores the active feature indices on the config.
     *
     * <p>NOTE(review): {@link #fitClassK(int)} does not visibly forward this setting to the
     * {@code RegTreeConfig} it builds. Confirm whether tree fitting actually honors it.
     */
    public void setActiveFeatures(int[] activeFeatures) {
        this.config.setActiveFeatures(activeFeatures);
    }

    /**
     * Stores the active data point indices on the config.
     *
     * <p>NOTE(review): {@link #fitClassK(int)} does not visibly forward this setting to the
     * {@code RegTreeConfig} it builds. Confirm whether tree fitting actually honors it.
     */
    public void setActiveDataPoints(int[] activeDataPoints) {
        this.config.setActiveDataPoints(activeDataPoints);
    }

    /**
     * Adds one constant regressor per class whose value is the centered log prior,
     * {@code log(p_k) - mean_j(log(p_j))}.
     *
     * <p>NOTE(review): a zero prior makes {@code Math.log} return -Infinity, which turns the mean
     * and therefore every score into an infinite or NaN value. Callers may need to smooth
     * zero-frequency labels.
     *
     * @param probs prior probability per class
     * @throws IllegalArgumentException if {@code probs.length} differs from the number of classes
     */
    private void setPriorProbs(double[] probs) {
        if (probs.length != this.boosting.getNumClasses()) {
            throw new IllegalArgumentException("probs.length!=this.numClasses");
        }
        double averageLogProb = Arrays.stream(probs).map(Math::log).average().getAsDouble();
        for (int k = 0; k < this.boosting.getNumClasses(); ++k) {
            // Center the log-priors. The previous code computed log(p_k - meanLog), which is not a
            // log-prior at all.
            double score = Math.log(probs[k]) - averageLogProb;
            ConstantRegressor constant = new ConstantRegressor(score);
            this.boosting.addRegressor(constant, k);
        }
    }

    /** Estimates per-class priors from {@code dataSet} and installs them as constant regressors. */
    private void setPriorProbs(MultiLabelClfDataSet dataSet) {
        MLPriorProbClassifier priorProbClassifier = new MLPriorProbClassifier(dataSet.getNumClasses());
        priorProbClassifier.fit(dataSet);
        double[] probs = priorProbClassifier.getClassProbs();
        this.setPriorProbs(probs);
    }

    /** Accumulates the predictions of every regressor already in {@code boosting} into the score matrix. */
    private void initScoreMatrix(HMLGradientBoosting boosting) {
        int numClasses = this.config.getDataSet().getNumClasses();
        for (int k = 0; k < numClasses; ++k) {
            for (Regressor regressor : boosting.getRegressors(k)) {
                this.updateClassScores(regressor, k);
            }
        }
    }

    /**
     * Adds {@code regressor}'s prediction to class {@code k}'s score for every data point.
     * Data points are processed in parallel; each task writes only its own row.
     */
    private void updateClassScores(Regressor regressor, int k) {
        int numDataPoints = this.config.getDataSet().getNumDataPoints();
        IntStream.range(0, numDataPoints).parallel().forEach(dataIndex -> this.updateClassScore(regressor, k, dataIndex));
    }

    /** Adds {@code regressor}'s prediction on row {@code dataIndex} to that row's class-{@code k} score. */
    private void updateClassScore(Regressor regressor, int k, int dataIndex) {
        Vector vector = this.config.getDataSet().getRow(dataIndex);
        double prediction = regressor.predict(vector);
        this.scoreMatrix.increment(dataIndex, k, prediction);
    }

    /** Recomputes the assignment softmax for every data point, in parallel. */
    private void updateAssignmentProbMatrix() {
        int numDataPoints = this.config.getDataSet().getNumDataPoints();
        IntStream.range(0, numDataPoints).parallel().forEach(this::updateAssignmentProbs);
    }

    /**
     * Writes, for one data point, the softmax over assignments of the assignment scores.
     * The normalizer is computed in log space via {@code logSumExp}.
     */
    private void updateAssignmentProbs(int dataPoint) {
        // Fetch the row once; it is not modified while the assignment scores are computed.
        float[] classScores = this.scoreMatrix.getScoresForData(dataPoint);
        int numAssignments = this.assignments.size();
        double[] assignmentScores = new double[numAssignments];
        for (int a = 0; a < numAssignments; ++a) {
            assignmentScores[a] = calAssignmentScore(classScores, this.assignments.get(a));
        }
        double logDenominator = MathUtil.logSumExp(assignmentScores);
        double[] assignmentProbs = this.assignmentProbabilityMatrix[dataPoint];
        for (int a = 0; a < numAssignments; ++a) {
            assignmentProbs[a] = Math.exp(assignmentScores[a] - logDenominator);
        }
    }

    /** Returns the sum of {@code classScores} over the labels matched by {@code assignment}. */
    private static double calAssignmentScore(float[] classScores, MultiLabel assignment) {
        double score = 0.0;
        for (int label : assignment.getMatchedLabels()) {
            score += classScores[label];
        }
        return score;
    }

    /**
     * Sets each class probability of {@code dataPoint} to the total probability of the assignments
     * that contain that class.
     */
    private void updateClassProbs(int dataPoint) {
        double[] assignmentProbs = this.assignmentProbabilityMatrix[dataPoint];
        int numAssignments = this.assignments.size();
        int numClasses = this.config.getDataSet().getNumClasses();
        for (int k = 0; k < numClasses; ++k) {
            this.probabilityMatrix.setProbability(dataPoint, k, 0.0);
        }
        for (int a = 0; a < numAssignments; ++a) {
            double prob = assignmentProbs[a];
            for (int label : this.assignments.get(a).getMatchedLabels()) {
                this.probabilityMatrix.increment(dataPoint, label, prob);
            }
        }
    }

    /** Recomputes class probabilities for every data point, in parallel. */
    private void updateProbabilityMatrix() {
        int numDataPoints = this.config.getDataSet().getNumDataPoints();
        IntStream.range(0, numDataPoints).parallel().forEach(this::updateClassProbs);
    }

    /** Recomputes class gradients for every data point, in parallel. */
    private void updateClassGradientMatrix() {
        int numDataPoints = this.config.getDataSet().getNumDataPoints();
        IntStream.range(0, numDataPoints).parallel().forEach(this::updateClassGradients);
    }

    /**
     * Sets the gradient of each class for {@code dataPoint} to
     * {@code 1[class is a true label] - classProbability}.
     */
    private void updateClassGradients(int dataPoint) {
        int numClasses = this.config.getDataSet().getNumClasses();
        MultiLabel multiLabel = this.config.getDataSet().getMultiLabels()[dataPoint];
        float[] classProbs = this.probabilityMatrix.getProbabilitiesForData(dataPoint);
        for (int k = 0; k < numClasses; ++k) {
            // Compute in double rather than rounding (1 - p) to float first.
            double target = multiLabel.matchClass(k) ? 1.0 : 0.0;
            this.gradientMatrix.setGradient(dataPoint, k, target - classProbs[k]);
        }
    }

    /**
     * Fits a regression tree to class {@code k}'s current gradients and shrinks it by the learning
     * rate.
     *
     * <p>NOTE(review): only the leaf count, minimum data per leaf and number of split intervals are
     * copied into {@code RegTreeConfig}. The active features and data points stored by the setters
     * above are not visibly passed on.
     */
    private RegressionTree fitClassK(int k) {
        double[] gradients = this.gradientMatrix.getGradientsForClass(k);
        int numClasses = this.config.getDataSet().getNumClasses();
        double learningRate = this.config.getLearningRate();
        HMLGBLeafOutputCalculator leafOutputCalculator = new HMLGBLeafOutputCalculator(numClasses);
        RegTreeConfig regTreeConfig = new RegTreeConfig();
        regTreeConfig.setMaxNumLeaves(this.config.getNumLeaves());
        regTreeConfig.setMinDataPerLeaf(this.config.getMinDataPerLeaf());
        regTreeConfig.setNumSplitIntervals(this.config.getNumSplitIntervals());
        RegressionTree regressionTree = RegTreeTrainer.fit(regTreeConfig, this.config.getDataSet(), gradients, leafOutputCalculator);
        regressionTree.shrink(learningRate);
        return regressionTree;
    }
}

