/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.classification.lkboost;

import edu.neu.ccs.pyramid.classification.PriorProbClassifier;
import edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator;
import edu.neu.ccs.pyramid.classification.lkboost.LKBoost;
import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.ProbabilityMatrix;
import edu.neu.ccs.pyramid.optimization.gradient_boosting.GBOptimizer;
import edu.neu.ccs.pyramid.regression.ConstantRegressor;
import edu.neu.ccs.pyramid.regression.RegressorFactory;
import edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig;
import edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.Arrays;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Gradient-boosting optimizer for {@link LKBoost} multi-class models.
 *
 * <p>One regressor ensemble is maintained per class. After each update the per-class scores in
 * {@code scoreMatrix} are converted into class probabilities with a log-sum-exp softmax, and the
 * fitting target for ensemble {@code k} at data point {@code i} is the residual
 * {@code targetDistribution[i][k] - P(k | x_i)}.
 */
public class LKBoostOptimizer extends GBOptimizer {
    private static final Logger logger = LogManager.getLogger();

    /**
     * Initial (prior) class scores produced by {@code MathUtil.inverseSoftMax} are clamped to
     * {@code [-MAX_PRIOR_SCORE, MAX_PRIOR_SCORE]}. Presumably this keeps the starting model away
     * from extreme scores when a class prior is close to 0 or 1.
     */
    private static final double MAX_PRIOR_SCORE = 5.0;

    /** Per-data-point class probabilities; allocated in {@link #initializeOthers()}. */
    private ProbabilityMatrix probabilityMatrix;
    /** Target class distribution, indexed as {@code [dataPoint][classIndex]}. */
    private final double[][] targetDistribution;
    /** The model being trained, kept with its concrete type for per-class ensemble access. */
    private final LKBoost boosting;
    /** Number of classes (and therefore ensembles), read from {@code boosting} at construction. */
    private final int numClasses;

    /**
     * Creates an optimizer with explicit per-instance weights and target distribution.
     *
     * @param boosting           the model to train; {@code getNumClasses()} fixes the class count
     * @param dataSet            the training data
     * @param factory            factory for the regressors fitted at each boosting step
     * @param weights            per-data-point weights, forwarded to {@link GBOptimizer}
     * @param targetDistribution target class distribution, indexed as {@code [dataPoint][classIndex]}
     */
    public LKBoostOptimizer(LKBoost boosting, DataSet dataSet, RegressorFactory factory, double[] weights, double[][] targetDistribution) {
        super(boosting, dataSet, factory, weights);
        this.boosting = boosting;
        this.targetDistribution = targetDistribution;
        this.numClasses = boosting.getNumClasses();
    }

    /**
     * Creates an optimizer with the given target distribution.
     *
     * <p>Weights are taken from {@code defaultWeights}, which is declared outside this class.
     *
     * @param boosting           the model to train
     * @param dataSet            the training data
     * @param factory            factory for the regressors fitted at each boosting step
     * @param targetDistribution target class distribution, indexed as {@code [dataPoint][classIndex]}
     */
    public LKBoostOptimizer(LKBoost boosting, DataSet dataSet, RegressorFactory factory, double[][] targetDistribution) {
        this(boosting, dataSet, factory, LKBoostOptimizer.defaultWeights(dataSet.getNumDataPoints()), targetDistribution);
    }

    /**
     * Creates an optimizer for a labelled dataset.
     *
     * <p>The target is {@code DataSetUtil.labelDistribution(dataSet)}, and the dataset's label
     * translator is copied onto {@code boosting}.
     *
     * @param boosting the model to train
     * @param dataSet  the labelled training data
     * @param factory  factory for the regressors fitted at each boosting step
     * @param weights  per-data-point weights
     */
    public LKBoostOptimizer(LKBoost boosting, ClfDataSet dataSet, RegressorFactory factory, double[] weights) {
        this(boosting, dataSet, factory, weights, DataSetUtil.labelDistribution(dataSet));
        this.boosting.labelTranslator = dataSet.getLabelTranslator();
    }

    /**
     * Creates an optimizer for a labelled dataset using {@code defaultWeights}.
     *
     * @param boosting the model to train
     * @param dataSet  the labelled training data
     * @param factory  factory for the regressors fitted at each boosting step
     */
    public LKBoostOptimizer(LKBoost boosting, ClfDataSet dataSet, RegressorFactory factory) {
        // Delegate instead of duplicating the label-distribution / label-translator setup.
        this(boosting, dataSet, factory, LKBoostOptimizer.defaultWeights(dataSet.getNumDataPoints()));
    }

    /**
     * Creates an optimizer for a labelled dataset using the default regression-tree factory.
     *
     * @param boosting the model to train
     * @param dataSet  the labelled training data
     * @param weights  per-data-point weights
     */
    public LKBoostOptimizer(LKBoost boosting, ClfDataSet dataSet, double[] weights) {
        this(boosting, dataSet, LKBoostOptimizer.defaultFactory(dataSet.getNumClasses()), weights);
    }

    /**
     * Creates an optimizer for a labelled dataset using the default factory and default weights.
     *
     * @param boosting the model to train
     * @param dataSet  the labelled training data
     */
    public LKBoostOptimizer(LKBoost boosting, ClfDataSet dataSet) {
        this(boosting, dataSet, LKBoostOptimizer.defaultFactory(dataSet.getNumClasses()));
    }

    /** Allocates the probability matrix with one row per data point and one column per class. */
    @Override
    protected void initializeOthers() {
        this.probabilityMatrix = new ProbabilityMatrix(this.dataSet.getNumDataPoints(), this.numClasses);
    }

    /** Recomputes class probabilities from the current scores. */
    @Override
    protected void updateOthers() {
        this.updateProbabilityMatrix();
    }

    /**
     * Seeds every class ensemble with a constant regressor.
     *
     * <p>A {@link PriorProbClassifier} is fitted on the target distribution and weights. Its class
     * probabilities are mapped back to scores with {@code MathUtil.inverseSoftMax}, and each score
     * is clamped to {@code [-MAX_PRIOR_SCORE, MAX_PRIOR_SCORE]}.
     */
    @Override
    protected void addPriors() {
        PriorProbClassifier priorProbClassifier = new PriorProbClassifier(this.numClasses);
        priorProbClassifier.fit(this.dataSet, this.targetDistribution, this.weights);
        double[] scores = MathUtil.inverseSoftMax(priorProbClassifier.getClassProbs());
        for (int k = 0; k < scores.length; ++k) {
            // Math.min/Math.max leave NaN untouched, exactly like the comparison-based clamp.
            scores[k] = Math.max(-MAX_PRIOR_SCORE, Math.min(MAX_PRIOR_SCORE, scores[k]));
        }
        for (int k = 0; k < this.numClasses; ++k) {
            this.boosting.getEnsemble(k).add(new ConstantRegressor(scores[k]));
        }
    }

    /**
     * Computes the fitting target for one class ensemble over all data points, in parallel.
     *
     * @param ensembleIndex class index of the ensemble being updated
     * @return {@code targetDistribution[i][ensembleIndex] - P(ensembleIndex | x_i)} for every data point {@code i}
     */
    @Override
    protected double[] gradient(int ensembleIndex) {
        return IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().mapToDouble(i -> this.gradient(ensembleIndex, i)).toArray();
    }

    /** Residual between the target probability and the current predicted probability. */
    private double gradient(int ensembleIndex, int dataPoint) {
        double prob = this.probabilityMatrix.getProbabilitiesForData(dataPoint)[ensembleIndex];
        return this.targetDistribution[dataPoint][ensembleIndex] - prob;
    }

    /**
     * Recomputes the softmax class probabilities of data point {@code i} from its scores.
     *
     * @throws IllegalStateException if any probability evaluates to NaN; nothing is written for
     *                               that class in that case
     */
    private void updateClassProb(int i) {
        float[] scores = this.scoreMatrix.getScoresForData(i);
        // Softmax in log space: p_k = exp(s_k - logSumExp(s)) avoids overflow in exp(s_k).
        double logDenominator = MathUtil.logSumExp(scores);
        for (int k = 0; k < this.numClasses; ++k) {
            double logNumerator = scores[k];
            double prob = Math.exp(logNumerator - logDenominator);
            if (Double.isNaN(prob)) {
                throw new IllegalStateException("pro=NaN, logNumerator = " + logNumerator + ", logDenominator=" + logDenominator + ", scores = " + Arrays.toString(scores));
            }
            this.probabilityMatrix.setProbability(i, k, prob);
        }
    }

    /**
     * Refreshes the probabilities of all data points in parallel.
     *
     * <p>Each task writes only row {@code i}. This assumes {@link ProbabilityMatrix} tolerates
     * concurrent writes to distinct rows — TODO confirm.
     */
    private void updateProbabilityMatrix() {
        int numDataPoints = this.dataSet.getNumDataPoints();
        IntStream.range(0, numDataPoints).parallel().forEach(this::updateClassProb);
    }

    /**
     * Builds the default weak-learner factory.
     *
     * <p>It is a {@link RegTreeFactory} with a default {@link RegTreeConfig}, whose leaf outputs
     * come from an {@link LKBOutputCalculator}.
     *
     * @param numClasses number of classes, passed to the leaf output calculator
     * @return the configured factory
     */
    private static RegressorFactory defaultFactory(int numClasses) {
        RegTreeConfig regTreeConfig = new RegTreeConfig();
        RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
        regTreeFactory.setLeafOutputCalculator(new LKBOutputCalculator(numClasses));
        return regTreeFactory;
    }
}

