/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.optimization.gradient_boosting;

import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.ScoreMatrix;
import edu.neu.ccs.pyramid.optimization.gradient_boosting.GradientBoosting;
import edu.neu.ccs.pyramid.regression.Regressor;
import edu.neu.ccs.pyramid.regression.RegressorFactory;
import edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree;
import java.util.Arrays;
import java.util.stream.IntStream;
import org.apache.mahout.math.Vector;

/**
 * Base class for optimizers that train a {@link GradientBoosting} model stage by stage.
 *
 * <p>Each call to {@link #iterate()} performs one boosting round. For every ensemble {@code k}
 * of the model it:
 * <ul>
 *   <li>fits a new regressor (via the {@link RegressorFactory}) to the values returned by
 *       {@link #gradient(int)};</li>
 *   <li>scales it by the shrinkage factor (only {@link RegressionTree} regressors are scaled);</li>
 *   <li>appends it to ensemble {@code k};</li>
 *   <li>adds its predictions to the staged score matrix.</li>
 * </ul>
 *
 * <p>Subclasses supply the priors, the per-ensemble gradient targets and any extra state kept
 * up to date through {@link #initializeOthers()} and {@link #updateOthers()}.
 *
 * <p>{@link #initialize()} must be called before the first {@link #iterate()}.
 */
public abstract class GBOptimizer {
    /**
     * Staged scores, created in {@link #initialize()} with dimensions
     * (number of data points, number of ensembles).
     */
    protected ScoreMatrix scoreMatrix;
    /** The model being trained; new regressors are appended to its ensembles. */
    protected GradientBoosting boosting;
    /** Factory used to fit one regressor per ensemble per boosting round. */
    protected RegressorFactory factory;
    /** Training data. */
    protected DataSet dataSet;
    /** Per-data-point weights handed to {@link RegressorFactory#fit}; all 1.0 by default. */
    protected double[] weights;
    /** Set to {@code true} once {@link #initialize()} has completed. */
    protected boolean isInitialized;
    /** Factor applied to each newly fitted regression tree; 1.0 means no shrinkage. */
    protected double shrinkage = 1.0;

    /**
     * Creates an optimizer with explicit per-data-point weights.
     *
     * <p>Side effect: sets {@code boosting.featureList} to the data set's feature list.
     *
     * @param boosting the model to train
     * @param dataSet  the training data
     * @param factory  the factory used to fit each new regressor
     * @param weights  per-data-point weights passed to {@code factory.fit}; presumably of length
     *                 {@code dataSet.getNumDataPoints()} (that is what the default constructor
     *                 builds), but this is not checked here
     */
    protected GBOptimizer(GradientBoosting boosting, DataSet dataSet, RegressorFactory factory, double[] weights) {
        this.boosting = boosting;
        this.factory = factory;
        this.dataSet = dataSet;
        this.weights = weights;
        boosting.featureList = dataSet.getFeatureList();
    }

    /**
     * Creates an optimizer in which every data point has weight 1.0.
     *
     * @param boosting the model to train
     * @param dataSet  the training data
     * @param factory  the factory used to fit each new regressor
     */
    protected GBOptimizer(GradientBoosting boosting, DataSet dataSet, RegressorFactory factory) {
        this(boosting, dataSet, factory, GBOptimizer.defaultWeights(dataSet.getNumDataPoints()));
    }

    /**
     * Prepares the optimizer for training; must be called before {@link #iterate()}.
     *
     * <p>The steps are, in order:
     * <ol>
     *   <li>If ensemble 0 has no regressors yet, call {@link #addPriors()}.</li>
     *   <li>Allocate a fresh score matrix.</li>
     *   <li>Replay every existing regressor into it.</li>
     *   <li>Run the {@link #initializeOthers()} and {@link #updateOthers()} hooks.</li>
     * </ol>
     *
     * <p>Because regressors already in the model are replayed, this also supports resuming
     * training of a partially trained model.
     */
    public void initialize() {
        // Only ensemble 0 is inspected; an empty ensemble 0 is taken to mean "untrained model".
        if (this.boosting.getEnsemble(0).getRegressors().size() == 0) {
            this.addPriors();
        }
        this.scoreMatrix = new ScoreMatrix(this.dataSet.getNumDataPoints(), this.boosting.getNumEnsembles());
        this.initStagedScores();
        this.initializeOthers();
        this.updateOthers();
        this.isInitialized = true;
    }

    /** Adds prior regressors to the model; called by {@link #initialize()} when ensemble 0 is empty. */
    protected abstract void addPriors();

    /**
     * Returns the regression targets for the next regressor of one ensemble.
     *
     * @param ensembleIndex index of the ensemble being updated
     * @return one target value per data point, passed to {@link RegressorFactory#fit}
     */
    protected abstract double[] gradient(int ensembleIndex);

    /** Hook for subclass-specific state setup; called once from {@link #initialize()}. */
    protected abstract void initializeOthers();

    /**
     * Fits a new regressor for one ensemble to the current gradient targets.
     *
     * @param ensembleIndex index of the ensemble being updated
     * @return the freshly fitted (not yet shrunk) regressor
     */
    protected Regressor fitRegressor(int ensembleIndex) {
        double[] gradients = this.gradient(ensembleIndex);
        Regressor regressor = this.factory.fit(this.dataSet, gradients, this.weights);
        return regressor;
    }

    /**
     * Applies the shrinkage factor to a newly fitted regressor.
     *
     * <p>Only {@link RegressionTree} instances are scaled. Any other {@link Regressor}
     * implementation is left unchanged.
     */
    protected void shrink(Regressor regressor) {
        if (regressor instanceof RegressionTree) {
            ((RegressionTree)regressor).shrink(this.shrinkage);
        }
    }

    /** Adds {@code regressor}'s prediction for one data point to that point's score for one ensemble. */
    protected void updateStagedScore(Regressor regressor, int ensembleIndex, int dataIndex) {
        Vector vector = this.dataSet.getRow(dataIndex);
        double score = regressor.predict(vector);
        this.scoreMatrix.increment(dataIndex, ensembleIndex, score);
    }

    /**
     * Adds {@code regressor}'s predictions for every data point to the given ensemble's scores.
     *
     * <p>Data points are processed with a parallel stream, and each task touches a distinct
     * {@code dataIndex}. NOTE(review): this assumes {@code ScoreMatrix.increment} is safe for
     * concurrent calls on different data indices; confirm against {@code ScoreMatrix}.
     */
    protected void updateStagedScores(Regressor regressor, int ensembleIndex) {
        int numDataPoints = this.dataSet.getNumDataPoints();
        IntStream.range(0, numDataPoints).parallel().forEach(dataIndex -> this.updateStagedScore(regressor, ensembleIndex, dataIndex));
    }

    /**
     * Performs one boosting round: fits, shrinks and appends one regressor per ensemble, then
     * calls {@link #updateOthers()}.
     *
     * <p>Ensembles are processed sequentially. The score matrix is already incremented for
     * ensembles {@code < k} when {@code gradient(k)} is called within the same round.
     *
     * @throws IllegalStateException if {@link #initialize()} has not been called
     */
    public void iterate() {
        if (!this.isInitialized) {
            // IllegalStateException extends RuntimeException, so existing catch blocks still match.
            throw new IllegalStateException("GBOptimizer is not initialized");
        }
        for (int k = 0; k < this.boosting.getNumEnsembles(); ++k) {
            Regressor regressor = this.fitRegressor(k);
            this.shrink(regressor);
            this.boosting.getEnsemble(k).add(regressor);
            this.updateStagedScores(regressor, k);
        }
        this.updateOthers();
    }

    /**
     * Runs {@link #iterate()} {@code numIterations} times; a non-positive count does nothing.
     *
     * @throws IllegalStateException if {@link #initialize()} has not been called and
     *                               {@code numIterations > 0}
     */
    public void iterate(int numIterations) {
        for (int i = 0; i < numIterations; ++i) {
            this.iterate();
        }
    }

    /** Replays every regressor already in the model into the (freshly allocated) score matrix. */
    protected void initStagedScores() {
        for (int k = 0; k < this.boosting.getNumEnsembles(); ++k) {
            for (Regressor regressor : this.boosting.getEnsemble(k).getRegressors()) {
                this.updateStagedScores(regressor, k);
            }
        }
    }

    /** Hook for refreshing subclass-specific state; called at the end of {@link #initialize()} and after every {@link #iterate()} round. */
    protected abstract void updateOthers();

    /**
     * Sets the factor applied to each newly fitted regression tree.
     *
     * <p>The value is not validated. It affects only regressors fitted after this call.
     */
    public void setShrinkage(double shrinkage) {
        this.shrinkage = shrinkage;
    }

    /** Returns the factory used to fit new regressors. */
    public RegressorFactory getRegressorFactory() {
        return this.factory;
    }

    /**
     * Builds a uniform weight vector.
     *
     * @param numData number of data points
     * @return a new array of length {@code numData} filled with 1.0
     */
    protected static double[] defaultWeights(int numData) {
        double[] weights = new double[numData];
        Arrays.fill(weights, 1.0);
        return weights;
    }
}

