/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.classification.logistic_regression;

import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.classification.logistic_regression.Weights;
import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.optimization.Terminator;
import edu.neu.ccs.pyramid.regression.linear_regression.ElasticNetLinearRegOptimizer;
import edu.neu.ccs.pyramid.regression.linear_regression.LinearRegression;
import java.util.Arrays;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Trains a {@link LogisticRegression} under an elastic-net penalty.
 *
 * <p>Each {@link #iterate() iteration} visits every class in turn. For one class, the
 * log-likelihood is approximated by a weighted least-squares problem and solved in place with
 * {@link ElasticNetLinearRegOptimizer}. The problem's inputs are:
 * <ul>
 *   <li>working labels: the current class score plus the clipped update {@code (y - p) / (p (1 - p))};</li>
 *   <li>working weights: {@code instanceWeight * p (1 - p)}.</li>
 * </ul>
 * The optimizer writes directly into that class's slice of the model weights. When line search
 * is enabled, the resulting weight change becomes a search direction, and a backtracking line
 * search on the full penalized loss picks the step length.
 *
 * <p>The loss reported by {@link #getLoss()} is the weighted negative log-likelihood divided by
 * the sum of instance weights, plus
 * {@code regularization * sum_k [(1 - l1Ratio) / 2 * ||w_k||_2^2 + l1Ratio * ||w_k||_1]}.
 * Bias terms are excluded from the penalty.
 */
public class ElasticNetLogisticTrainer {
    private static final Logger logger = LogManager.getLogger();
    /** Maximum number of step-length halvings tried by the line search before giving up. */
    private static final int MAX_LINE_SEARCH_ITERATIONS = 100;

    private LogisticRegression logisticRegression;
    private DataSet dataSet;
    private int numClasses;
    /** Target distribution per data point: {@code targets[dataPoint][class]}. */
    private double[][] targets;
    /** Per-data-point instance weights. */
    private double[] weights;
    /** Sum of {@link #weights}; the loss and gradient are normalized by it. */
    private double sumWeights;
    private double regularization;
    private double l1Ratio;
    /** Stored from the builder; not read by this class. */
    private double epsilon;
    /** Weighted sum of target * feature value for every model parameter. */
    private Vector empiricalCounts;
    /** Weighted sum of predicted probability * feature value for every model parameter. */
    private Vector predictedCounts;
    private int numParameters;
    /** Cached class probabilities: {@code probabilityMatrix[class][dataPoint]}. */
    private double[][] probabilityMatrix;
    private Terminator terminator;
    private boolean lineSearch = true;
    private boolean isActiveSet = false;

    /** Returns whether the active-set option is forwarded to each per-class linear solver. */
    public boolean isActiveSet() {
        return this.isActiveSet;
    }

    /** Sets the active-set option forwarded to each per-class {@link ElasticNetLinearRegOptimizer}. */
    public void setActiveSet(boolean activeSet) {
        this.isActiveSet = activeSet;
    }

    public static Builder newBuilder(LogisticRegression logisticRegression, DataSet dataSet, int numClasses, double[][] targets, double[] weights) {
        return new Builder(logisticRegression, dataSet, numClasses, targets, weights);
    }

    public static Builder newBuilder(LogisticRegression logisticRegression, DataSet dataSet, int numClasses, double[][] targets) {
        return new Builder(logisticRegression, dataSet, numClasses, targets);
    }

    public static Builder newBuilder(LogisticRegression logisticRegression, DataSet dataSet, int numClasses, int[] labels) {
        return new Builder(logisticRegression, dataSet, numClasses, labels);
    }

    public static Builder newBuilder(LogisticRegression logisticRegression, ClfDataSet dataSet) {
        return new Builder(logisticRegression, dataSet);
    }

    /** Runs {@link #iterate()} at least once, and repeatedly until the terminator says to stop. */
    public void optimize() {
        this.logisticRegression.setFeatureList(this.dataSet.getFeatureList());
        do {
            this.iterate();
        } while (!this.terminator.shouldTerminate());
    }

    /** Performs one pass over all classes, then records the resulting loss in the terminator. */
    public void iterate() {
        for (int k = 0; k < this.numClasses; ++k) {
            this.optimizeOneClass(k);
        }
        this.terminator.add(this.getLoss());
    }

    /** Returns the current penalized, weight-normalized loss. */
    public double getLoss() {
        return this.loss();
    }

    public Terminator getTerminator() {
        return this.terminator;
    }

    /**
     * Updates the weights of a single class.
     *
     * <p>The class's weights are refit by solving a weighted least-squares surrogate. When line
     * search is enabled, the weight change is instead used as a search direction from the
     * previous weights.
     *
     * @param classIndex the class whose weights are updated
     */
    private void optimizeOneClass(int classIndex) {
        int numDataPoints = this.dataSet.getNumDataPoints();
        double[] realLabels = new double[numDataPoints];
        double[] instanceWeights = new double[numDataPoints];
        IntStream.range(0, numDataPoints).parallel().forEach(i -> {
            Vector row = this.dataSet.getRow(i);
            double prob = this.logisticRegression.predictClassProbs(row)[classIndex];
            double classScore = this.logisticRegression.predictClassScore(row, classIndex);
            double curvature = prob * (1.0 - prob);
            // The Newton update is undefined when the probability saturates at 0 or 1.
            double frac = curvature != 0.0 ? (this.targets[i][classIndex] - prob) / curvature : 0.0;
            // Clip the update to [-1, 1] so working labels stay bounded.
            frac = Math.max(-1.0, Math.min(1.0, frac));
            realLabels[i] = classScore + frac;
            instanceWeights[i] = this.weights[i] * curvature;
        });
        Weights oldWeights = this.lineSearch ? this.logisticRegression.getWeights().deepCopy() : null;
        // The linear regression shares this class's weight slice, so the optimizer updates the model in place.
        LinearRegression linearRegression = new LinearRegression(this.dataSet.getNumFeatures(),
                this.logisticRegression.getWeights().getWeightsForClass(classIndex));
        ElasticNetLinearRegOptimizer linearRegTrainer = new ElasticNetLinearRegOptimizer(linearRegression,
                this.dataSet, realLabels, instanceWeights, this.sumWeights);
        linearRegTrainer.setRegularization(this.regularization);
        linearRegTrainer.setL1Ratio(this.l1Ratio);
        linearRegTrainer.setActiveSet(this.isActiveSet);
        linearRegTrainer.getTerminator().setMaxIteration(10);
        logger.debug("start linearRegTrainer.optimize()");
        linearRegTrainer.optimize();
        logger.debug("finish linearRegTrainer.optimize()");
        if (this.lineSearch) {
            // minus() allocates a fresh vector, so no copy of the new weights is needed.
            Vector searchDirection = this.logisticRegression.getWeights().getAllWeights()
                    .minus(oldWeights.getAllWeights());
            if (logger.isDebugEnabled()) {
                logger.debug("norm of the search direction = " + searchDirection.norm(2.0));
            }
            this.logisticRegression.getWeights().setWeightVector(oldWeights.getAllWeights());
            // Gradient of the unpenalized loss. Counts are instance-weighted and normalized by
            // sumWeights, exactly like loss().
            Vector gradient = this.predictedCounts.minus(this.empiricalCounts).divide(this.sumWeights);
            this.lineSearch(searchDirection, gradient);
            this.updatePredictedCounts();
            this.updateClassProbMatrix();
        }
        if (logger.isDebugEnabled()) {
            logger.debug("loss after optimization of one class = " + this.loss());
        }
    }

    /** Refreshes the cached class probabilities of one data point. */
    private void updateClassProbs(int dataPointIndex) {
        double[] probs = this.logisticRegression.predictClassProbs(this.dataSet.getRow(dataPointIndex));
        for (int k = 0; k < this.numClasses; ++k) {
            this.probabilityMatrix[k][dataPointIndex] = probs[k];
        }
    }

    /** Refreshes the cached class probabilities of all data points, in parallel. */
    private void updateClassProbMatrix() {
        logger.debug("start updateClassProbMatrix()");
        IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().forEach(this::updateClassProbs);
        logger.debug("finish updateClassProbMatrix()");
    }

    /** Returns the penalized loss at the current weights. */
    private double loss() {
        return this.loss(this.penalty());
    }

    /** Returns the weighted negative log-likelihood over sumWeights plus a precomputed penalty. */
    private double loss(double penalty) {
        double negativeLogLikelihood = this.logisticRegression.dataSetLogLikelihood(this.dataSet, this.targets, this.weights) * -1.0;
        return negativeLogLikelihood / this.sumWeights + penalty;
    }

    /** Returns the elastic-net penalty summed over all classes. */
    private double penalty() {
        return IntStream.range(0, this.logisticRegression.getNumClasses()).parallel().mapToDouble(k -> this.penalty(k)).sum();
    }

    /** Returns the elastic-net penalty of class {@code k}, excluding its bias. */
    private double penalty(int k) {
        Vector vector = this.logisticRegression.getWeights().getWeightsWithoutBiasForClass(k);
        double normCombination = (1.0 - this.l1Ratio) * 0.5 * Math.pow(vector.norm(2.0), 2.0) + this.l1Ratio * vector.norm(1.0);
        return this.regularization * normCombination;
    }

    /**
     * Backtracking line search along {@code searchDirection} from the current weights.
     *
     * <p>Starting from step 1, the step is halved until the Armijo-type sufficient-decrease test
     * passes. If no step passes within {@link #MAX_LINE_SEARCH_ITERATIONS} halvings, the starting
     * weights are restored. This can happen when the direction is not a descent direction or the
     * loss is not finite.
     *
     * @param searchDirection direction in the full parameter space
     * @param gradient gradient of the unpenalized loss at the starting weights
     */
    private void lineSearch(Vector searchDirection, Vector gradient) {
        final double shrinkage = 0.5;
        final double c = 1.0E-4;
        Vector start = this.logisticRegression.getWeights().getAllWeights();
        double penalty = this.penalty();
        double value = this.loss(penalty);
        if (logger.isDebugEnabled()) {
            logger.debug("start line search");
            logger.debug("initial loss = " + value);
        }
        double product = gradient.dot(searchDirection);
        double stepLength = 1.0;
        for (int trial = 0; trial < MAX_LINE_SEARCH_ITERATIONS; ++trial) {
            Vector target = start.plus(searchDirection.times(stepLength));
            this.logisticRegression.getWeights().setWeightVector(target);
            double targetPenalty = this.penalty();
            double targetValue = this.loss(targetPenalty);
            if (targetValue <= value + c * stepLength * (product + targetPenalty - penalty)) {
                if (logger.isDebugEnabled()) {
                    logger.debug("step size = " + stepLength);
                    logger.debug("final loss = " + targetValue);
                    logger.debug("line search done");
                }
                return;
            }
            stepLength *= shrinkage;
        }
        // No acceptable step: fall back to the zero step (the starting weights).
        this.logisticRegression.getWeights().setWeightVector(start);
        logger.warn("line search found no acceptable step after " + MAX_LINE_SEARCH_ITERATIONS
                + " trials; keeping previous weights");
    }

    /** Recomputes {@link #empiricalCounts} for every parameter, in parallel. */
    private void updateEmpricalCounts() {
        IntStream.range(0, this.numParameters).parallel().forEach(i -> this.empiricalCounts.set(i, this.calEmpricalCount(i)));
    }

    /**
     * Returns the instance-weighted sum of target * feature value for one parameter.
     * A feature index of -1 denotes the bias, whose feature value is treated as 1.
     */
    private double calEmpricalCount(int parameterIndex) {
        int classIndex = this.logisticRegression.getWeights().getClassIndex(parameterIndex);
        int featureIndex = this.logisticRegression.getWeights().getFeatureIndex(parameterIndex);
        double count = 0.0;
        if (featureIndex == -1) {
            int numDataPoints = this.dataSet.getNumDataPoints();
            for (int i = 0; i < numDataPoints; ++i) {
                count += this.weights[i] * this.targets[i][classIndex];
            }
        } else {
            for (Vector.Element element : this.dataSet.getColumn(featureIndex).nonZeroes()) {
                int dataPointIndex = element.index();
                count += this.weights[dataPointIndex] * element.get() * this.targets[dataPointIndex][classIndex];
            }
        }
        return count;
    }

    /** Recomputes {@link #predictedCounts} from the cached probabilities, in parallel. */
    private void updatePredictedCounts() {
        logger.debug("start updatePredictedCounts()");
        IntStream.range(0, this.numParameters).parallel().forEach(i -> this.predictedCounts.set(i, this.calPredictedCount(i)));
        logger.debug("finish updatePredictedCounts()");
    }

    /**
     * Returns the instance-weighted sum of predicted probability * feature value for one
     * parameter. The probabilities come from {@link #probabilityMatrix}. A feature index of -1
     * denotes the bias.
     */
    private double calPredictedCount(int parameterIndex) {
        int classIndex = this.logisticRegression.getWeights().getClassIndex(parameterIndex);
        int featureIndex = this.logisticRegression.getWeights().getFeatureIndex(parameterIndex);
        double count = 0.0;
        double[] probs = this.probabilityMatrix[classIndex];
        if (featureIndex == -1) {
            int numDataPoints = this.dataSet.getNumDataPoints();
            for (int i = 0; i < numDataPoints; ++i) {
                count += this.weights[i] * probs[i];
            }
        } else {
            for (Vector.Element element : this.dataSet.getColumn(featureIndex).nonZeroes()) {
                int dataPointIndex = element.index();
                count += this.weights[dataPointIndex] * probs[dataPointIndex] * element.get();
            }
        }
        return count;
    }

    /** Fluent builder for {@link ElasticNetLogisticTrainer}. */
    public static class Builder {
        private LogisticRegression logisticRegression;
        private DataSet dataSet;
        private double[][] targets;
        private double[] weights;
        private double sumWeights;
        private int numClasses;
        private double regularization = 1.0E-5;
        private double l1Ratio = 0.0;
        private double epsilon = 0.001;
        private boolean lineSearch = true;

        /** Builds one-hot targets from hard labels; every data point gets weight 1. */
        public Builder(LogisticRegression logisticRegression, DataSet dataSet, int numClasses, int[] labels) {
            this(logisticRegression, dataSet, numClasses, oneHotTargets(dataSet.getNumDataPoints(), numClasses, labels));
        }

        public Builder(LogisticRegression logisticRegression, ClfDataSet dataSet) {
            this(logisticRegression, (DataSet) dataSet, dataSet.getNumClasses(), dataSet.getLabels());
        }

        /** Uses soft targets; every data point gets weight 1. */
        public Builder(LogisticRegression logisticRegression, DataSet dataSet, int numClasses, double[][] targets) {
            this(logisticRegression, dataSet, numClasses, targets, uniformWeights(dataSet.getNumDataPoints()));
        }

        /** Uses soft targets and explicit per-data-point instance weights. */
        public Builder(LogisticRegression logisticRegression, DataSet dataSet, int numClasses, double[][] targets, double[] weights) {
            this.logisticRegression = logisticRegression;
            this.dataSet = dataSet;
            this.numClasses = numClasses;
            this.targets = targets;
            this.weights = weights;
            this.sumWeights = Arrays.stream(weights).parallel().sum();
        }

        /** Returns a {@code numDataPoints x numClasses} matrix with a 1 at each data point's label. */
        private static double[][] oneHotTargets(int numDataPoints, int numClasses, int[] labels) {
            double[][] targets = new double[numDataPoints][numClasses];
            for (int i = 0; i < numDataPoints; ++i) {
                targets[i][labels[i]] = 1.0;
            }
            return targets;
        }

        /** Returns an array of {@code size} ones. */
        private static double[] uniformWeights(int size) {
            double[] weights = new double[size];
            Arrays.fill(weights, 1.0);
            return weights;
        }

        /** @throws IllegalArgumentException unless {@code regularization >= 0} (rejects NaN) */
        public Builder setRegularization(double regularization) {
            if (!(regularization >= 0.0)) {
                throw new IllegalArgumentException("regularization>=0");
            }
            this.regularization = regularization;
            return this;
        }

        /** @throws IllegalArgumentException unless {@code 0 <= l1Ratio <= 1} (rejects NaN) */
        public Builder setL1Ratio(double l1Ratio) {
            if (!(l1Ratio >= 0.0 && l1Ratio <= 1.0)) {
                throw new IllegalArgumentException("(l1Ratio>=0)&&(l1Ratio<=1)");
            }
            this.l1Ratio = l1Ratio;
            return this;
        }

        /** @throws IllegalArgumentException unless {@code 0 < epsilon < 1} (rejects NaN) */
        public Builder setEpsilon(double epsilon) {
            if (!(epsilon > 0.0 && epsilon < 1.0)) {
                throw new IllegalArgumentException("(epsilon>0)&&(epsilon<1)");
            }
            this.epsilon = epsilon;
            return this;
        }

        public Builder setLineSearch(boolean lineSearch) {
            this.lineSearch = lineSearch;
            return this;
        }

        /**
         * Creates the trainer.
         *
         * <p>Also computes the empirical counts, the class-probability cache and the predicted
         * counts at the model's current weights.
         */
        public ElasticNetLogisticTrainer build() {
            ElasticNetLogisticTrainer trainer = new ElasticNetLogisticTrainer();
            trainer.logisticRegression = this.logisticRegression;
            trainer.dataSet = this.dataSet;
            trainer.targets = this.targets;
            trainer.weights = this.weights;
            trainer.sumWeights = this.sumWeights;
            trainer.numClasses = this.numClasses;
            trainer.regularization = this.regularization;
            trainer.l1Ratio = this.l1Ratio;
            trainer.epsilon = this.epsilon;
            trainer.lineSearch = this.lineSearch;
            trainer.numParameters = this.logisticRegression.getWeights().totalSize();
            trainer.empiricalCounts = new DenseVector(trainer.numParameters);
            trainer.predictedCounts = new DenseVector(trainer.numParameters);
            trainer.probabilityMatrix = new double[this.numClasses][this.dataSet.getNumDataPoints()];
            trainer.updateEmpricalCounts();
            trainer.updateClassProbMatrix();
            trainer.updatePredictedCounts();
            trainer.terminator = new Terminator();
            return trainer;
        }
    }
}

