/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.regression.p_boost;

import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.RegDataSet;
import edu.neu.ccs.pyramid.optimization.gradient_boosting.GBOptimizer;
import edu.neu.ccs.pyramid.optimization.gradient_boosting.GradientBoosting;
import edu.neu.ccs.pyramid.regression.ConstantRegressor;
import edu.neu.ccs.pyramid.regression.RegressorFactory;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Gradient-boosting optimizer for P-Boost regression.
 *
 * <p>Ensemble 0 is seeded with a constant prior (see {@link #addPriors()}). Every later step
 * fits a regressor to {@link #gradient(int)}, the per-data-point derivative of
 *
 * <pre>
 *   f(p) = sum_i(pd_i * yd_i) / sqrt(sum_i(pd_i^2))
 * </pre>
 *
 * where {@code pd} and {@code yd} are the predictions and labels centered on their means.
 * Up to the constant norm of the centered labels, {@code f} is the Pearson correlation
 * between predictions and labels.
 */
public class PBoostOptimizer extends GBOptimizer {
    private static final Logger logger = LogManager.getLogger();

    /** Regression targets; entry {@code i} belongs to data point {@code i} of {@code dataSet}. */
    private final double[] labels;

    /**
     * Creates an optimizer with explicit per-data-point weights.
     *
     * @param boosting model whose ensemble 0 receives the prior; passed to {@link GBOptimizer}
     * @param dataSet  training data; passed to {@link GBOptimizer}
     * @param factory  regressor factory; passed to {@link GBOptimizer}
     * @param weights  per-data-point weights; passed to {@link GBOptimizer}
     * @param labels   regression targets, one per data point
     */
    public PBoostOptimizer(GradientBoosting boosting, DataSet dataSet, RegressorFactory factory, double[] weights, double[] labels) {
        super(boosting, dataSet, factory, weights);
        this.labels = labels;
    }

    /**
     * Creates an optimizer that uses the {@link GBOptimizer} default weights.
     *
     * @param boosting model whose ensemble 0 receives the prior; passed to {@link GBOptimizer}
     * @param dataSet  training data; passed to {@link GBOptimizer}
     * @param factory  regressor factory; passed to {@link GBOptimizer}
     * @param labels   regression targets, one per data point
     */
    public PBoostOptimizer(GradientBoosting boosting, DataSet dataSet, RegressorFactory factory, double[] labels) {
        super(boosting, dataSet, factory);
        this.labels = labels;
    }

    /**
     * Creates an optimizer whose targets are the labels stored in {@code dataSet}.
     *
     * @param boosting model whose ensemble 0 receives the prior
     * @param dataSet  training data; its {@code getLabels()} supplies the targets
     * @param factory  regressor factory; passed to {@link GBOptimizer}
     */
    public PBoostOptimizer(GradientBoosting boosting, RegDataSet dataSet, RegressorFactory factory) {
        this(boosting, dataSet, factory, dataSet.getLabels());
    }

    /**
     * Seeds ensemble 0 with a constant regressor equal to the weighted label mean
     * {@code sum(w_i * y_i) / sum(w_i)}.
     *
     * <p>If the total weight is zero (including an empty dataset) the prior is 0.
     */
    @Override
    protected void addPriors() {
        int n = this.dataSet.getNumDataPoints();
        // Divide by the total weight rather than by n: averaging w_i * y_i over n is only a
        // weighted mean when the weights happen to sum to n.
        double weightedSum = IntStream.range(0, n).parallel()
                .mapToDouble(i -> this.labels[i] * this.weights[i]).sum();
        double totalWeight = IntStream.range(0, n).parallel()
                .mapToDouble(i -> this.weights[i]).sum();
        double prior = totalWeight == 0.0 ? 0.0 : weightedSum / totalWeight;
        this.boosting.getEnsemble(0).add(new ConstantRegressor(prior));
    }

    /**
     * Returns, for each data point, the derivative of the correlation objective with respect
     * to that point's current prediction.
     *
     * @param ensembleIndex not used; the prediction is always the first score
     *                      ({@code getScoresForData(i)[0]}) of the score matrix
     * @return one gradient entry per data point
     */
    @Override
    protected double[] gradient(int ensembleIndex) {
        int n = this.dataSet.getNumDataPoints();
        double[] pred = IntStream.range(0, n)
                .mapToDouble(i -> this.scoreMatrix.getScoresForData(i)[0]).toArray();
        return correlationGradient(this.labels, pred);
    }

    /**
     * Computes d f / d pred_i for {@code f = sum(pd*yd) / sqrt(sum(pd^2))}, where {@code pd}
     * and {@code yd} are the centered predictions and labels:
     * {@code [(yd_i - mean(yd)) * S - A * (pd_i - mean(pd)) / S] / S^2}, with
     * {@code A = sum(pd*yd)} and {@code S = sqrt(sum(pd^2))}.
     *
     * @param labels regression targets (the mean is taken over the whole array)
     * @param pred   current predictions, one per data point
     * @return the gradient, same length as {@code pred}
     */
    private static double[] correlationGradient(double[] labels, double[] pred) {
        int n = pred.length;
        double labelAve = MathUtil.arraySum(labels) / (double) n;
        double[] labelDev = IntStream.range(0, n).mapToDouble(i -> labels[i] - labelAve).toArray();
        // Centered values only average to zero up to rounding; subtracting the residual keeps
        // each term the exact derivative through the mean.
        double labelDevAve = MathUtil.arraySum(labelDev) / (double) n;

        double[] gradient = new double[n];

        // When every prediction is identical (e.g. right after the constant prior), the
        // correlation is undefined. Detect this on the raw predictions. The mean of n equal
        // doubles generally does not round-trip: three predictions of 0.1 average to
        // 0.10000000000000002. So predDev would be ~1e-17 rather than 0, an exact
        // sigmaSquare == 0 test would miss it, and the gradient would be inflated by ~1/sigma.
        // With predDev == 0 and sigmaSquare taken as 1, the general formula reduces to the
        // centered labels.
        boolean constantPredictions = IntStream.range(0, n).allMatch(i -> pred[i] == pred[0]);
        if (constantPredictions) {
            for (int i = 0; i < n; ++i) {
                gradient[i] = labelDev[i] - labelDevAve;
            }
            return gradient;
        }

        double predAve = MathUtil.arraySum(pred) / (double) n;
        double[] predDev = IntStream.range(0, n).mapToDouble(i -> pred[i] - predAve).toArray();
        double predDevAve = MathUtil.arraySum(predDev) / (double) n;
        double product = IntStream.range(0, n).mapToDouble(i -> predDev[i] * labelDev[i]).sum();
        double sigmaSquare = IntStream.range(0, n).mapToDouble(i -> Math.pow(predDev[i], 2.0)).sum();
        if (sigmaSquare == 0.0) {
            // Last-resort guard against division by zero (e.g. squares underflowing to 0).
            sigmaSquare = 1.0;
        }
        double sigma = Math.sqrt(sigmaSquare);
        for (int i = 0; i < n; ++i) {
            double numerator = (labelDev[i] - labelDevAve) * sigma
                    - 1.0 / sigma * (product * (predDev[i] - predDevAve));
            gradient[i] = numerator / sigmaSquare;
        }
        return gradient;
    }

    /** No optimizer-specific state to initialize. */
    @Override
    protected void initializeOthers() {
    }

    /** No optimizer-specific state to update between steps. */
    @Override
    protected void updateOthers() {
    }
}

