/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.regression.linear_regression;

import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.RegDataSet;
import edu.neu.ccs.pyramid.optimization.Terminator;
import edu.neu.ccs.pyramid.regression.linear_regression.LinearRegression;
import java.util.Arrays;
import java.util.BitSet;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.Vector;

/**
 * Cyclic coordinate-descent optimizer for instance-weighted linear regression with an
 * elastic-net penalty on the non-bias coefficients.
 *
 * <p>The objective tracked by {@link #optimize()} (see {@code loss}) is
 *
 * <pre>
 *   sum_i w_i (y_i - score_i)^2 / (2 * sumWeights)
 *     + regularization * ((1 - l1Ratio) / 2 * ||beta||_2^2 + l1Ratio * ||beta||_1)
 * </pre>
 *
 * where {@code beta} excludes the bias, which is refit without penalty at the start of every sweep.
 */
public class ElasticNetLinearRegOptimizer {
    private static final Logger logger = LogManager.getLogger();
    /**
     * Extra active-set-only sweeps allowed after the first one before a full sweep is forced;
     * the inner active-set loop therefore runs at most {@code MAX_EXTRA_ACTIVE_SET_SWEEPS + 1} sweeps.
     */
    private static final int MAX_EXTRA_ACTIVE_SET_SWEEPS = 5;

    private double regularization = 0.0;
    private double l1Ratio = 0.0;
    private Terminator terminator;
    private LinearRegression linearRegression;
    private DataSet dataSet;
    private double[] labels;
    double[] instanceWeights;
    double sumWeights;
    private boolean isActiveSet = false;

    /** Returns whether {@link #optimize()} uses the active-set strategy instead of full sweeps. */
    public boolean isActiveSet() {
        return this.isActiveSet;
    }

    /** Chooses between the active-set strategy ({@code true}) and plain full sweeps ({@code false}). */
    public void setActiveSet(boolean activeSet) {
        this.isActiveSet = activeSet;
    }

    /**
     * Creates an optimizer over explicit labels and instance weights.
     *
     * @param linearRegression model whose weights are updated in place
     * @param dataSet          feature matrix (rows are instances)
     * @param labels           regression target per instance
     * @param instanceWeights  non-negative weight per instance
     * @param sumWeights       sum of {@code instanceWeights}; not validated here — callers are
     *                         expected to pass the actual sum
     */
    public ElasticNetLinearRegOptimizer(LinearRegression linearRegression, DataSet dataSet, double[] labels, double[] instanceWeights, double sumWeights) {
        this.linearRegression = linearRegression;
        this.dataSet = dataSet;
        this.labels = labels;
        this.instanceWeights = instanceWeights;
        this.terminator = new Terminator();
        this.sumWeights = sumWeights;
    }

    /** Same as the five-argument constructor, computing {@code sumWeights} from {@code instanceWeights}. */
    public ElasticNetLinearRegOptimizer(LinearRegression linearRegression, DataSet dataSet, double[] labels, double[] instanceWeights) {
        this(linearRegression, dataSet, labels, instanceWeights, Arrays.stream(instanceWeights).parallel().sum());
    }

    /** Same as the four-argument constructor with every instance weighted 1.0. */
    public ElasticNetLinearRegOptimizer(LinearRegression linearRegression, DataSet dataSet, double[] labels) {
        this(linearRegression, dataSet, labels, ElasticNetLinearRegOptimizer.defaultWeights(dataSet.getNumDataPoints()));
    }

    /** Fits against the labels stored in {@code dataSet}, with unit instance weights. */
    public ElasticNetLinearRegOptimizer(LinearRegression linearRegression, RegDataSet dataSet) {
        this(linearRegression, dataSet, dataSet.getLabels());
    }

    /** Returns the overall penalty strength (lambda). */
    public double getRegularization() {
        return this.regularization;
    }

    /** Sets the overall penalty strength (lambda); 0 disables the penalty. */
    public void setRegularization(double regularization) {
        this.regularization = regularization;
    }

    /** Returns the L1 share of the penalty (alpha): 1 is lasso, 0 is ridge. */
    public double getL1Ratio() {
        return this.l1Ratio;
    }

    /** Sets the L1 share of the penalty (alpha): 1 is lasso, 0 is ridge. */
    public void setL1Ratio(double l1Ratio) {
        this.l1Ratio = l1Ratio;
    }

    /** Returns the terminator that decides when optimization stops; callers may configure it. */
    public Terminator getTerminator() {
        return this.terminator;
    }

    /**
     * Runs coordinate descent, updating the model's weights in place.
     *
     * <p>In active-set mode the terminator is switched to {@code FINISH_MAX_ITER}. That switch
     * persists on the terminator after this call returns.
     */
    public void optimize() {
        if (this.isActiveSet) {
            this.terminator.setMode(Terminator.Mode.FINISH_MAX_ITER);
            this.activeSetOptimize();
        } else {
            this.normalOptimize();
        }
    }

    /** Computes the current model prediction for every data point. */
    private double[] computeScores() {
        double[] scores = new double[this.dataSet.getNumDataPoints()];
        IntStream.range(0, scores.length).parallel()
                .forEach(i -> scores[i] = this.linearRegression.predict(this.dataSet.getRow(i)));
        return scores;
    }

    /**
     * Alternates cheap sweeps over the current non-zero coefficients with full sweeps.
     *
     * <p>Stops when a full sweep leaves the set of non-zero coefficients unchanged, or when the
     * terminator says so.
     */
    private void activeSetOptimize() {
        double[] scores = this.computeScores();
        // One full sweep first, so the initial active set reflects the data.
        this.iterate(scores);
        // A constant dummy loss: in FINISH_MAX_ITER mode the terminator presumably stops on
        // iteration count alone — confirm in Terminator.
        this.terminator.add(1.0);
        BitSet activeSet = this.updateActiveSet();
        boolean converged = false;
        while (!converged) {
            int sweeps = 0;
            do {
                this.activeSetIterate(scores, activeSet);
                this.terminator.add(1.0);
                sweeps++;
            } while (!this.terminator.shouldTerminate() && sweeps <= MAX_EXTRA_ACTIVE_SET_SWEEPS);
            // A full sweep lets coefficients outside the active set become non-zero again.
            this.iterate(scores);
            this.terminator.add(1.0);
            if (this.terminator.shouldTerminate()) {
                return;
            }
            BitSet latestActiveSet = this.updateActiveSet();
            converged = this.isActiveSetUnchanged(activeSet, latestActiveSet);
            activeSet = latestActiveSet;
        }
    }

    /** Returns {@code true} when both sets contain exactly the same feature indices. */
    private boolean isActiveSetUnchanged(BitSet previous, BitSet latest) {
        return previous.equals(latest);
    }

    /** Returns the indices of the currently non-zero coefficients (bias excluded). */
    private BitSet updateActiveSet() {
        BitSet activeSet = new BitSet();
        for (Vector.Element element : this.linearRegression.getWeights().getWeightsWithoutBias().nonZeroes()) {
            activeSet.set(element.index());
        }
        return activeSet;
    }

    /** Repeats full coordinate sweeps until the terminator is satisfied with the loss sequence. */
    private void normalOptimize() {
        double[] scores = this.computeScores();
        if (logger.isDebugEnabled()) {
            logger.debug("initial loss = " + this.loss(this.linearRegression, scores, this.labels, this.instanceWeights, this.sumWeights));
        }
        double loss;
        do {
            this.iterate(scores);
            loss = this.loss(this.linearRegression, scores, this.labels, this.instanceWeights, this.sumWeights);
            if (logger.isDebugEnabled()) {
                logger.debug("loss = " + loss);
            }
            this.terminator.add(loss);
        } while (!this.terminator.shouldTerminate());
        if (logger.isDebugEnabled()) {
            logger.debug("final loss = " + loss);
        }
    }

    /**
     * Prologue shared by every sweep. It handles the zero-total-weight case, or refits the
     * unpenalized bias and shifts {@code scores} by the bias change.
     *
     * @return {@code false} if the sweep must stop here because the total instance weight is zero
     */
    private boolean refitBias(double[] scores) {
        if (this.sumWeights == 0.0) {
            // With no data weight only the penalty remains. For positive regularization it is
            // minimized by all-zero coefficients. Note that scores are not updated here.
            if (this.regularization > 0.0) {
                for (int j = 0; j < this.dataSet.getNumFeatures(); ++j) {
                    this.linearRegression.getWeights().setWeight(j, 0.0);
                }
            }
            return false;
        }
        double oldBias = this.linearRegression.getWeights().getBias();
        // Optimal bias = weighted mean of the residuals with the current bias added back.
        double newBias = IntStream.range(0, this.dataSet.getNumDataPoints()).parallel()
                .mapToDouble(i -> this.instanceWeights[i] * (this.labels[i] - scores[i] + oldBias))
                .sum() / this.sumWeights;
        this.linearRegression.getWeights().setBias(newBias);
        double difference = newBias - oldBias;
        if (difference != 0.0) {
            IntStream.range(0, this.dataSet.getNumDataPoints()).parallel()
                    .forEach(i -> scores[i] = scores[i] + difference);
        }
        return true;
    }

    /** One sweep: refit the bias, then update only the coefficients whose index is in {@code activeSet}. */
    private void activeSetIterate(double[] scores, BitSet activeSet) {
        if (!this.refitBias(scores)) {
            return;
        }
        for (int j = activeSet.nextSetBit(0); j >= 0; j = activeSet.nextSetBit(j + 1)) {
            this.optimizeOneFeature(scores, j);
        }
    }

    /** One full sweep: refit the bias, then update every coefficient in index order. */
    private void iterate(double[] scores) {
        if (!this.refitBias(scores)) {
            return;
        }
        for (int j = 0; j < this.dataSet.getNumFeatures(); ++j) {
            this.optimizeOneFeature(scores, j);
        }
    }

    /**
     * Exact coordinate minimization of the objective along one coefficient, keeping
     * {@code scores} in sync with the change.
     */
    private void optimizeOneFeature(double[] scores, int featureIndex) {
        double oldCoeff = this.linearRegression.getWeights().getWeightsWithoutBias().get(featureIndex);
        double fit = 0.0;
        double denominator = 0.0;
        Vector featureColumn = this.dataSet.getColumn(featureIndex);
        for (Vector.Element element : featureColumn.nonZeroes()) {
            int i = element.index();
            double x = element.get();
            // Residual with this feature's current contribution added back.
            double partialResidual = this.labels[i] - scores[i] + x * oldCoeff;
            double weightedX = this.instanceWeights[i] * x;
            fit += weightedX * partialResidual;
            denominator += x * weightedX;
        }
        fit /= this.sumWeights;
        double numerator = this.softThreshold(fit);
        denominator = denominator / this.sumWeights + this.regularization * (1.0 - this.l1Ratio);
        // A zero denominator (e.g. an empty column with no ridge term) leaves the coefficient at 0.
        double newCoeff = denominator != 0.0 ? numerator / denominator : 0.0;
        this.linearRegression.getWeights().setWeight(featureIndex, newCoeff);
        double difference = newCoeff - oldCoeff;
        if (difference != 0.0) {
            for (Vector.Element element : featureColumn.nonZeroes()) {
                int i = element.index();
                scores[i] = scores[i] + difference * element.get();
            }
        }
    }

    /**
     * Penalized objective: weighted squared error / (2 * sumWeights) plus {@link #penalty}.
     * When {@code sumWeights} is zero the data term is taken as 0 instead of the NaN that
     * 0 / 0 would produce, so the terminator never receives NaN.
     */
    private double loss(LinearRegression linearRegression, double[] scores, double[] labels, double[] instanceWeights, double sumWeights) {
        double penalty = this.penalty(linearRegression);
        if (sumWeights == 0.0) {
            return penalty;
        }
        double mse = IntStream.range(0, scores.length).parallel()
                .mapToDouble(i -> instanceWeights[i] * Math.pow(labels[i] - scores[i], 2.0))
                .sum();
        return mse / (2.0 * sumWeights) + penalty;
    }

    /** Elastic-net penalty on the non-bias coefficients. */
    private double penalty(LinearRegression linearRegression) {
        Vector vector = linearRegression.getWeights().getWeightsWithoutBias();
        double normCombination = (1.0 - this.l1Ratio) * 0.5 * Math.pow(vector.norm(2.0), 2.0) + this.l1Ratio * vector.norm(1.0);
        return this.regularization * normCombination;
    }

    /** Soft-thresholding operator: shrinks {@code z} toward 0 by {@code gamma}, clamping to 0 inside [-gamma, gamma]. */
    private static double softThreshold(double z, double gamma) {
        if (z > 0.0 && gamma < Math.abs(z)) {
            return z - gamma;
        }
        if (z < 0.0 && gamma < Math.abs(z)) {
            return z + gamma;
        }
        return 0.0;
    }

    /** Soft-thresholds {@code z} by the L1 part of the penalty, {@code regularization * l1Ratio}. */
    private double softThreshold(double z) {
        return ElasticNetLinearRegOptimizer.softThreshold(z, this.regularization * this.l1Ratio);
    }

    /** Returns an array of {@code numData} unit weights. */
    private static double[] defaultWeights(int numData) {
        double[] weights = new double[numData];
        Arrays.fill(weights, 1.0);
        return weights;
    }
}

