/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.classification.l2boost;

import edu.neu.ccs.pyramid.classification.l2boost.L2BLeafOutputCalculator;
import edu.neu.ccs.pyramid.classification.l2boost.L2Boost;
import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.ProbabilityMatrix;
import edu.neu.ccs.pyramid.optimization.gradient_boosting.GBOptimizer;
import edu.neu.ccs.pyramid.regression.RegressorFactory;
import edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig;
import edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory;
import java.util.stream.IntStream;

/**
 * Gradient-boosting optimizer for {@link L2Boost} with two classes.
 *
 * <p>The optimizer keeps a two-column {@link ProbabilityMatrix} in sync with the current
 * scores. Each new regressor is fitted to the residual between the target distribution and
 * the predicted probability of class 1.
 */
public class L2BoostOptimizer extends GBOptimizer {
    /** The model being trained; also used to turn raw scores into class probabilities. */
    private final L2Boost boosting;

    /**
     * Predicted class probabilities: one row per data point, two columns.
     * It is allocated in {@link #initializeOthers()} and refreshed in {@link #updateOthers()}.
     *
     * <p>NOTE(review): this field intentionally has no initializer. A value assigned during
     * superclass construction therefore cannot be overwritten. Confirm whether
     * GBOptimizer's constructor invokes the initialize hooks.
     */
    private ProbabilityMatrix probabilityMatrix;

    /**
     * Target distribution per data point. Only column 1 is read, by {@link #gradient(int)}.
     */
    private final double[][] targetDistribution;

    /**
     * Creates an optimizer with every dependency supplied explicitly.
     *
     * @param boosting           the L2Boost model to train
     * @param dataSet            training data
     * @param factory            factory for the regressor fitted at each boosting step
     * @param weights            per-data-point weights, passed to {@link GBOptimizer}
     * @param targetDistribution per-data-point targets; {@code targetDistribution[i][1]} is the
     *                           class-1 target for data point {@code i}
     */
    public L2BoostOptimizer(L2Boost boosting, DataSet dataSet, RegressorFactory factory,
                            double[] weights, double[][] targetDistribution) {
        super(boosting, dataSet, factory, weights);
        this.boosting = boosting;
        this.targetDistribution = targetDistribution;
    }

    /**
     * Creates an optimizer that uses {@code defaultWeights(numDataPoints)} as its weights.
     *
     * @param boosting           the L2Boost model to train
     * @param dataSet            training data
     * @param targetDistribution per-data-point targets (column 1 is the class-1 target)
     * @param factory            factory for the regressor fitted at each boosting step
     */
    public L2BoostOptimizer(L2Boost boosting, DataSet dataSet, double[][] targetDistribution,
                            RegressorFactory factory) {
        this(boosting, dataSet, factory,
                L2BoostOptimizer.defaultWeights(dataSet.getNumDataPoints()), targetDistribution);
    }

    /**
     * Creates an optimizer whose targets are {@code DataSetUtil.labelDistribution(dataSet)}.
     *
     * @param boosting the L2Boost model to train
     * @param dataSet  labelled training data
     * @param factory  factory for the regressor fitted at each boosting step
     */
    public L2BoostOptimizer(L2Boost boosting, ClfDataSet dataSet, RegressorFactory factory) {
        this(boosting, dataSet, DataSetUtil.labelDistribution(dataSet), factory);
    }

    /**
     * Creates an optimizer with label-derived targets and the default regression-tree factory
     * (see {@code defaultFactory()}).
     *
     * @param boosting the L2Boost model to train
     * @param dataSet  labelled training data
     */
    public L2BoostOptimizer(L2Boost boosting, ClfDataSet dataSet) {
        this(boosting, dataSet, defaultFactory());
    }

    /** Allocates the probability matrix: one row per data point, two class columns. */
    @Override
    protected void initializeOthers() {
        final int rows = this.dataSet.getNumDataPoints();
        this.probabilityMatrix = new ProbabilityMatrix(rows, 2);
    }

    /** Recomputes every row of the probability matrix from the current scores. */
    @Override
    protected void updateOthers() {
        refreshProbabilities();
    }

    /** Intentionally a no-op: this optimizer adds no priors. */
    @Override
    protected void addPriors() {
        // nothing to add
    }

    /**
     * Refreshes every row of the probability matrix with a parallel stream; each task writes
     * only its own row. NOTE(review): this assumes ProbabilityMatrix tolerates concurrent
     * writes to distinct rows. Confirm.
     */
    private void refreshProbabilities() {
        IntStream.range(0, this.dataSet.getNumDataPoints())
                .parallel()
                .forEach(this::refreshProbability);
    }

    /**
     * Converts the single raw score of one data point into two class probabilities.
     * Class 0's score is fixed at 0, and column 0 of the score matrix is used as class 1's score.
     *
     * @param dataIndex row to refresh
     */
    private void refreshProbability(int dataIndex) {
        final double[] classScores = {0.0, this.scoreMatrix.getScoresForData(dataIndex)[0]};
        final double[] classProbs = this.boosting.predictClassProbs(classScores);
        this.probabilityMatrix.setProbability(dataIndex, 0, classProbs[0]);
        this.probabilityMatrix.setProbability(dataIndex, 1, classProbs[1]);
    }

    /**
     * Computes, for each data point, the class-1 target minus the currently predicted class-1
     * probability.
     *
     * @param ensembleIndex not used by this optimizer
     * @return one residual per data point, in data-point order
     */
    @Override
    protected double[] gradient(int ensembleIndex) {
        return IntStream.range(0, this.dataSet.getNumDataPoints())
                .parallel()
                .mapToDouble(this::positiveClassResidual)
                .toArray();
    }

    /**
     * Returns the class-1 target minus the predicted class-1 probability for one data point.
     *
     * @param dataIndex data point to evaluate
     * @return the residual for that data point
     */
    private double positiveClassResidual(int dataIndex) {
        final double target = this.targetDistribution[dataIndex][1];
        final double predicted = this.probabilityMatrix.getProbabilitiesForData(dataIndex)[1];
        return target - predicted;
    }

    /**
     * Builds the default regressor factory: a {@link RegTreeFactory} with a default
     * {@link RegTreeConfig}, whose leaf values come from {@link L2BLeafOutputCalculator}.
     *
     * @return the configured factory
     */
    private static RegressorFactory defaultFactory() {
        final RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig());
        treeFactory.setLeafOutputCalculator(new L2BLeafOutputCalculator());
        return treeFactory;
    }
}

