/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.regression.m_boost;

import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.RegDataSet;
import edu.neu.ccs.pyramid.optimization.gradient_boosting.GBOptimizer;
import edu.neu.ccs.pyramid.optimization.gradient_boosting.GradientBoosting;
import edu.neu.ccs.pyramid.regression.ConstantRegressor;
import edu.neu.ccs.pyramid.regression.RegressorFactory;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.Arrays;
import java.util.Objects;
import java.util.stream.IntStream;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Gradient-boosting optimizer for regression whose pseudo-responses are clipped
 * (Huber-style) residuals.
 *
 * <p>The prior is a constant equal to the weighted median of the labels. At each iteration
 * the residuals {@code labels[i] - score[i]} are computed. A residual whose absolute value
 * exceeds the {@code alpha}-quantile of all absolute residuals is replaced by that quantile,
 * keeping the residual's sign as given by {@code MathUtil.sign}.
 */
public class MBoostOptimizer
extends GBOptimizer {
    private static final Logger logger = LogManager.getLogger();

    /** Regression targets; indexed by data point, in the same order as the data set. */
    private final double[] labels;

    /** Quantile in (0, 1] of the absolute residuals used as the clipping threshold. */
    private final double alpha;

    /**
     * Creates an optimizer with explicit per-instance weights.
     *
     * @param boosting the boosting model being trained
     * @param dataSet  the training data
     * @param factory  factory for the base regressors
     * @param weights  per-data-point weights (also used for the weighted-median prior)
     * @param labels   regression targets
     * @param alpha    clipping quantile; must be in (0, 1]
     * @throws NullPointerException     if {@code labels} is null
     * @throws IllegalArgumentException if {@code alpha} is not in (0, 1]
     */
    public MBoostOptimizer(GradientBoosting boosting, DataSet dataSet, RegressorFactory factory, double[] weights, double[] labels, double alpha) {
        super(boosting, dataSet, factory, weights);
        this.labels = Objects.requireNonNull(labels, "labels");
        this.alpha = checkAlpha(alpha);
    }

    /**
     * Creates an optimizer using the superclass's default weights.
     *
     * @param boosting the boosting model being trained
     * @param dataSet  the training data
     * @param factory  factory for the base regressors
     * @param labels   regression targets
     * @param alpha    clipping quantile; must be in (0, 1]
     * @throws NullPointerException     if {@code labels} is null
     * @throws IllegalArgumentException if {@code alpha} is not in (0, 1]
     */
    public MBoostOptimizer(GradientBoosting boosting, DataSet dataSet, RegressorFactory factory, double[] labels, double alpha) {
        super(boosting, dataSet, factory);
        this.labels = Objects.requireNonNull(labels, "labels");
        this.alpha = checkAlpha(alpha);
    }

    /**
     * Creates an optimizer whose labels are taken from the regression data set.
     *
     * @param boosting the boosting model being trained
     * @param dataSet  the training data; its labels are used as targets
     * @param factory  factory for the base regressors
     * @param alpha    clipping quantile; must be in (0, 1]
     * @throws IllegalArgumentException if {@code alpha} is not in (0, 1]
     */
    public MBoostOptimizer(GradientBoosting boosting, RegDataSet dataSet, RegressorFactory factory, double alpha) {
        this(boosting, dataSet, factory, dataSet.getLabels(), alpha);
    }

    /**
     * Validates the clipping quantile.
     *
     * <p>{@code DescriptiveStatistics.getPercentile} only accepts percentiles in (0, 100].
     * Rejecting anything else here surfaces the problem at construction time instead of on
     * the first call to {@link #gradient(int)}. The negated form also rejects NaN.
     */
    private static double checkAlpha(double alpha) {
        if (!(alpha > 0.0 && alpha <= 1.0)) {
            throw new IllegalArgumentException("alpha must be in (0, 1], got " + alpha);
        }
        return alpha;
    }

    /** Seeds ensemble 0 with a constant equal to the weighted median of the labels. */
    @Override
    protected void addPriors() {
        double median = MathUtil.weightedMedian(this.labels, this.weights);
        ConstantRegressor constant = new ConstantRegressor(median);
        this.boosting.getEnsemble(0).add(constant);
    }

    /**
     * Computes the clipped residuals used as pseudo-responses for the next base regressor.
     *
     * <p>Only the first score column of each data point is read. {@code ensembleIndex} is not
     * used; presumably this model has a single ensemble. NOTE(review): confirm.
     *
     * @param ensembleIndex ignored
     * @return one clipped residual per data point
     */
    @Override
    protected double[] gradient(int ensembleIndex) {
        int numDataPoints = this.dataSet.getNumDataPoints();
        double[] residual = IntStream.range(0, numDataPoints).parallel()
                .mapToDouble(i -> this.labels[i] - this.scoreMatrix.getScoresForData(i)[0])
                .toArray();
        double[] absResidual = Arrays.stream(residual).map(Math::abs).toArray();
        // Clip at the alpha-quantile of |residual| so large outliers have bounded influence.
        double threshold = new DescriptiveStatistics(absResidual).getPercentile(this.alpha * 100.0);
        double[] gradient = new double[residual.length];
        for (int i = 0; i < gradient.length; i++) {
            gradient[i] = absResidual[i] <= threshold
                    ? residual[i]
                    : threshold * MathUtil.sign(residual[i]);
        }
        return gradient;
    }

    /** No optimizer-specific state needs initializing. */
    @Override
    protected void initializeOthers() {
    }

    /** No optimizer-specific state needs updating between iterations. */
    @Override
    protected void updateOthers() {
    }
}

