/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.crf;

import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLinearRegression;
import edu.neu.ccs.pyramid.optimization.Terminator;
import java.util.Arrays;
import java.util.stream.IntStream;
import org.apache.mahout.math.Vector;

/**
 * Fits the coefficients of a {@link CRFLinearRegression} by cyclic coordinate descent on a
 * weighted elastic-net least-squares objective:
 *
 * <pre>
 *   loss(beta) = sum_i w_i * (y_i - s_i)^2 / (2 * sumWeights)
 *              + regularization * ((1 - l1Ratio) / 2 * ||beta||_2^2 + l1Ratio * ||beta||_1)
 * </pre>
 *
 * where {@code s_i} is the model's score for row {@code i}, {@code y_i} its label and
 * {@code w_i} its instance weight. Each sweep visits feature columns {@code 0 .. numFeatures-1}
 * once and applies the closed-form soft-thresholded update to that coefficient.
 *
 * <p>The model passed to the constructor is modified in place. The label and weight arrays are
 * held by reference, not copied. No synchronization is performed.
 *
 * <p>NOTE(review): only coefficients {@code 0 .. numFeatures-1} are updated here, while the
 * penalty is computed over the whole vector returned by {@code getWeights().getWeights()}. If
 * the model also stores an intercept in that vector, it stays fixed but is still penalized.
 * TODO confirm against {@link CRFLinearRegression}.
 */
public class CRFElasticNetLinearRegOptimizer {
    /** Overall penalty strength (lambda); 0 disables regularization. Not range-checked. */
    private double regularization = 0.0;
    /** Elastic-net mixing: 1.0 is a pure L1 penalty, 0.0 a pure L2 penalty. Not range-checked. */
    private double l1Ratio = 0.0;
    private final Terminator terminator;
    private final DataSet dataSet;
    /** Regression targets, indexed by data point. */
    private final double[] labels;
    /** Per-data-point weights of the squared-error term (package-visible). */
    double[] instanceWeights;
    /** Normalizer of the squared-error term; the convenience constructors set it to sum(instanceWeights). */
    double sumWeights;
    private final CRFLinearRegression linearRegression;

    /**
     * Creates an optimizer with an explicitly supplied weight normalizer.
     *
     * @param linearRegression model whose coefficients are updated in place
     * @param dataSet          training data; rows are scored and columns are scanned per feature
     * @param labels           target value for every data point (held by reference)
     * @param instanceWeights  weight of every data point (held by reference)
     * @param sumWeights       normalizer for the data term; not checked against
     *                         {@code instanceWeights}
     */
    public CRFElasticNetLinearRegOptimizer(CRFLinearRegression linearRegression, DataSet dataSet, double[] labels, double[] instanceWeights, double sumWeights) {
        this.linearRegression = linearRegression;
        this.dataSet = dataSet;
        this.labels = labels;
        this.instanceWeights = instanceWeights;
        this.terminator = new Terminator();
        // The loss may be non-finite, e.g. when sumWeights == 0; let the terminator accept it.
        this.terminator.setAllowNaN(true);
        this.sumWeights = sumWeights;
    }

    /**
     * Creates an optimizer whose normalizer is the sum of {@code instanceWeights}.
     *
     * @param linearRegression model whose coefficients are updated in place
     * @param dataSet          training data
     * @param labels           target value for every data point
     * @param instanceWeights  weight of every data point
     */
    public CRFElasticNetLinearRegOptimizer(CRFLinearRegression linearRegression, DataSet dataSet, double[] labels, double[] instanceWeights) {
        this(linearRegression, dataSet, labels, instanceWeights, Arrays.stream(instanceWeights).parallel().sum());
    }

    /**
     * Creates an optimizer that gives every data point weight 1.0.
     *
     * @param linearRegression model whose coefficients are updated in place
     * @param dataSet          training data
     * @param labels           target value for every data point
     */
    public CRFElasticNetLinearRegOptimizer(CRFLinearRegression linearRegression, DataSet dataSet, double[] labels) {
        this(linearRegression, dataSet, labels, CRFElasticNetLinearRegOptimizer.defaultWeights(dataSet.getNumDataPoints()));
    }

    /** Returns the overall penalty strength. */
    public double getRegularization() {
        return this.regularization;
    }

    /** Sets the overall penalty strength (lambda). The value is not validated. */
    public void setRegularization(double regularization) {
        this.regularization = regularization;
    }

    /** Returns the L1/L2 mixing ratio of the penalty. */
    public double getL1Ratio() {
        return this.l1Ratio;
    }

    /** Sets the L1/L2 mixing ratio (1.0 = pure L1, 0.0 = pure L2). The value is not validated. */
    public void setL1Ratio(double l1Ratio) {
        this.l1Ratio = l1Ratio;
    }

    /**
     * Returns the mutable terminator that decides when {@link #optimize()} stops. It is created
     * in the constructor with NaN losses allowed.
     */
    public Terminator getTerminator() {
        return this.terminator;
    }

    /**
     * Runs coordinate-descent sweeps until the terminator signals that it should stop. At least
     * one sweep is always performed. After every sweep the objective value is recorded in the
     * terminator.
     */
    public void optimize() {
        int numDataPoints = this.dataSet.getNumDataPoints();
        double[] scores = new double[numDataPoints];
        // Each row's initial score is independent of the others, so they are computed in parallel.
        IntStream.range(0, numDataPoints).parallel()
                .forEach(i -> scores[i] = this.linearRegression.predict(this.dataSet.getRow(i)));
        do {
            this.iterate(scores);
            this.terminator.add(this.loss(scores));
        } while (!this.terminator.shouldTerminate());
    }

    /**
     * Performs one full sweep over all feature coefficients, keeping {@code scores} in sync with
     * the model.
     */
    private void iterate(double[] scores) {
        int numFeatures = this.dataSet.getNumFeatures();
        if (this.sumWeights == 0.0) {
            // The weighted data term is empty, so a positive penalty drives every coefficient to
            // zero. The cached scores must follow, or the subsequent loss is computed on stale
            // predictions.
            if (this.regularization > 0.0) {
                for (int j = 0; j < numFeatures; ++j) {
                    double oldCoeff = this.linearRegression.getWeights().getWeights().get(j);
                    this.linearRegression.getWeights().setWeight(j, 0.0);
                    if (oldCoeff != 0.0) {
                        shiftScores(scores, this.dataSet.getColumn(j), -oldCoeff);
                    }
                }
            }
            return;
        }
        for (int j = 0; j < numFeatures; ++j) {
            this.optimizeOneFeature(scores, j);
        }
    }

    /**
     * Replaces one coefficient with its exact minimizer while all other coefficients are held
     * fixed, then updates the affected cached scores.
     *
     * @param scores       current model score of every data point (updated in place)
     * @param featureIndex index of the coefficient / feature column to update
     */
    private void optimizeOneFeature(double[] scores, int featureIndex) {
        double oldCoeff = this.linearRegression.getWeights().getWeights().get(featureIndex);
        Vector featureColumn = this.dataSet.getColumn(featureIndex);
        double weightedCorrelation = 0.0; // sum_i w_i * x_ij * r_ij
        double weightedSquares = 0.0;     // sum_i w_i * x_ij^2
        // Rows where x_ij == 0 contribute nothing to either sum, so only non-zeros are visited.
        for (Vector.Element element : featureColumn.nonZeroes()) {
            int i = element.index();
            double x = element.get();
            // Residual with this feature's current contribution added back in.
            double partialResidual = this.labels[i] - scores[i] + x * oldCoeff;
            double weightedX = this.instanceWeights[i] * x;
            weightedCorrelation += weightedX * partialResidual;
            weightedSquares += weightedX * x;
        }
        double numerator = this.softThreshold(weightedCorrelation / this.sumWeights);
        double denominator = weightedSquares / this.sumWeights + this.regularization * (1.0 - this.l1Ratio);
        // A zero denominator (e.g. an all-zero column without L2 penalty) leaves the coefficient
        // unconstrained; it is set to 0 instead of dividing by zero.
        double newCoeff = denominator != 0.0 ? numerator / denominator : 0.0;
        this.linearRegression.getWeights().setWeight(featureIndex, newCoeff);
        double difference = newCoeff - oldCoeff;
        if (difference != 0.0) {
            shiftScores(scores, featureColumn, difference);
        }
    }

    /**
     * Adds {@code delta * x_ij} to the cached score of every row with a non-zero entry in the
     * given feature column. This reflects a change of {@code delta} in that feature's coefficient.
     */
    private static void shiftScores(double[] scores, Vector featureColumn, double delta) {
        for (Vector.Element element : featureColumn.nonZeroes()) {
            int i = element.index();
            scores[i] = scores[i] + delta * element.get();
        }
    }

    /**
     * Evaluates the elastic-net objective for the given scores. The result is non-finite when
     * {@code sumWeights} is 0.
     */
    private double loss(double[] scores) {
        double weightedSquaredError = IntStream.range(0, scores.length).parallel().mapToDouble(i -> {
            double residual = this.labels[i] - scores[i];
            return this.instanceWeights[i] * (residual * residual);
        }).sum();
        return weightedSquaredError / (2.0 * this.sumWeights) + this.penalty();
    }

    /**
     * Returns {@code regularization * ((1 - l1Ratio) / 2 * ||beta||_2^2 + l1Ratio * ||beta||_1)}
     * over the model's full weight vector.
     */
    private double penalty() {
        Vector coefficients = this.linearRegression.getWeights().getWeights();
        double l2Norm = coefficients.norm(2.0);
        double normCombination = (1.0 - this.l1Ratio) * 0.5 * (l2Norm * l2Norm) + this.l1Ratio * coefficients.norm(1.0);
        return this.regularization * normCombination;
    }

    /**
     * Soft-thresholding operator: shrinks {@code z} toward zero by {@code gamma} and returns 0
     * when {@code |z| <= gamma}. A NaN {@code z} or {@code gamma} also yields 0.
     */
    private static double softThreshold(double z, double gamma) {
        if (gamma < Math.abs(z)) {
            if (z > 0.0) {
                return z - gamma;
            }
            if (z < 0.0) {
                return z + gamma;
            }
        }
        return 0.0;
    }

    /** Applies soft-thresholding with the L1 part of the penalty, {@code regularization * l1Ratio}. */
    private double softThreshold(double z) {
        return CRFElasticNetLinearRegOptimizer.softThreshold(z, this.regularization * this.l1Ratio);
    }

    /** Returns an array of {@code numData} unit weights. */
    private static double[] defaultWeights(int numData) {
        double[] weights = new double[numData];
        Arrays.fill(weights, 1.0);
        return weights;
    }
}

