/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.classification.logistic_regression;

import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.eval.KLDivergence;
import edu.neu.ccs.pyramid.optimization.Optimizable;
import edu.neu.ccs.pyramid.util.Vectors;
import java.util.Arrays;
import java.util.stream.IntStream;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Weighted KL-divergence objective for a {@link LogisticRegression} model, exposed
 * through {@link Optimizable.ByGradientValue}.
 *
 * <p>Two objective flavours are provided:
 * <ul>
 *   <li>{@link #getValue()} / {@link #getGradient()}: weighted KL divergence plus a
 *       Gaussian-prior (L2) penalty controlled by {@code priorGaussianVariance};</li>
 *   <li>{@link #getValueEL()}: weighted KL divergence averaged over the data points plus
 *       an elastic-net penalty controlled by {@code regularization} and {@code l1Ratio}.</li>
 * </ul>
 *
 * <p>NOTE(review): the elastic-net constructors leave {@code priorGaussianVariance} at
 * {@code 0.0}, so {@link #getValue()} and {@link #getGradient()} divide by zero on such
 * instances; presumably those callers only use {@link #getValueEL()} - confirm.
 *
 * <p>Values, gradient and class probabilities are cached and invalidated by
 * {@link #setParameters(Vector)}.
 */
public class LogisticLoss implements Optimizable.ByGradientValue {
    private static final Logger logger = LogManager.getLogger();

    private final LogisticRegression logisticRegression;
    private final DataSet dataSet;
    /** Per-data-point weights; a weight of exactly 0.0 makes the data point be skipped. */
    private final double[] weights;
    /** Indexed as {@code [dataPoint][class]}. */
    private final double[][] targetDistributions;
    private final Vector empiricalCounts;
    private final Vector predictedCounts;
    private Vector gradient;
    private final int numParameters;
    private final int numClasses;
    /** Indexed as {@code [class][dataPoint]}. */
    private final double[][] logProbabilityMatrix;
    /** Indexed as {@code [class][dataPoint]}. */
    private final double[][] probabilityMatrix;
    /** Cached result of {@link #getValue()}. */
    private double value;
    /** Cached result of {@link #getValueEL()}; kept apart from {@link #value} because the two objectives differ. */
    private double valueEL;
    private boolean isGradientCacheValid;
    private boolean isValueCacheValid;
    private boolean isValueELCacheValid;
    private boolean isProbabilityCacheValid;
    private final boolean isParallel;
    private final double priorGaussianVariance;
    private final double regularization;
    private final double l1Ratio;

    /**
     * Shared initialisation for all public constructors.
     *
     * @throws IllegalArgumentException if {@code targetDistributions} has no rows
     */
    private LogisticLoss(LogisticRegression logisticRegression, DataSet dataSet, double[] weights, double[][] targetDistributions, double priorGaussianVariance, double regularization, double l1Ratio, boolean parallel) {
        if (targetDistributions.length == 0) {
            // The number of classes is read from the first row, so at least one row is required.
            throw new IllegalArgumentException("targetDistributions must contain at least one row");
        }
        this.logisticRegression = logisticRegression;
        this.targetDistributions = targetDistributions;
        this.isParallel = parallel;
        this.numParameters = logisticRegression.getWeights().totalSize();
        this.dataSet = dataSet;
        this.weights = weights;
        this.priorGaussianVariance = priorGaussianVariance;
        this.regularization = regularization;
        this.l1Ratio = l1Ratio;
        this.empiricalCounts = new DenseVector(this.numParameters);
        this.predictedCounts = new DenseVector(this.numParameters);
        this.numClasses = targetDistributions[0].length;
        int numDataPoints = dataSet.getNumDataPoints();
        this.logProbabilityMatrix = new double[this.numClasses][numDataPoints];
        this.probabilityMatrix = new double[this.numClasses][numDataPoints];
        // Empirical counts depend only on data, targets and weights, so they are computed once.
        this.updateEmpricalCounts();
        this.invalidateCaches();
    }

    /**
     * Creates a loss with a Gaussian-prior penalty.
     *
     * @param logisticRegression    model whose weights are optimised
     * @param dataSet               training data
     * @param weights               per-data-point weights
     * @param targetDistributions   target class distribution per data point, {@code [dataPoint][class]}
     * @param priorGaussianVariance variance used by the Gaussian-prior penalty
     * @param parallel              whether internal loops use parallel streams
     */
    public LogisticLoss(LogisticRegression logisticRegression, DataSet dataSet, double[] weights, double[][] targetDistributions, double priorGaussianVariance, boolean parallel) {
        this(logisticRegression, dataSet, weights, targetDistributions, priorGaussianVariance, 0.0, 0.0, parallel);
    }

    /**
     * Creates a loss with an elastic-net penalty (see {@link #getValueEL()}).
     *
     * @param logisticRegression  model whose weights are optimised
     * @param dataSet             training data
     * @param weights             per-data-point weights
     * @param targetDistributions target class distribution per data point, {@code [dataPoint][class]}
     * @param regularization      overall multiplier of the elastic-net penalty
     * @param l1Ratio             share of the L1 term in the elastic-net penalty
     * @param parallel            whether internal loops use parallel streams
     */
    public LogisticLoss(LogisticRegression logisticRegression, DataSet dataSet, double[] weights, double[][] targetDistributions, double regularization, double l1Ratio, boolean parallel) {
        this(logisticRegression, dataSet, weights, targetDistributions, 0.0, regularization, l1Ratio, parallel);
    }

    /** Elastic-net variant with every data point weighted 1.0. */
    public LogisticLoss(LogisticRegression logisticRegression, DataSet dataSet, double[][] targetDistributions, double regularization, double l1Ratio, boolean parallel) {
        this(logisticRegression, dataSet, LogisticLoss.defaultWeights(dataSet.getNumDataPoints()), targetDistributions, regularization, l1Ratio, parallel);
    }

    /** Gaussian-prior variant with every data point weighted 1.0. */
    public LogisticLoss(LogisticRegression logisticRegression, DataSet dataSet, double[][] targetDistributions, double gaussianPriorVariance, boolean parallel) {
        this(logisticRegression, dataSet, LogisticLoss.defaultWeights(dataSet.getNumDataPoints()), targetDistributions, gaussianPriorVariance, parallel);
    }

    /** Gaussian-prior variant with unit weights and one-hot targets built from the data set's labels. */
    public LogisticLoss(LogisticRegression logisticRegression, ClfDataSet dataSet, double gaussianPriorVariance, boolean parallel) {
        this(logisticRegression, dataSet, LogisticLoss.defaultTargetDistribution(dataSet), gaussianPriorVariance, parallel);
    }

    /** Returns the model's full weight vector as provided by the underlying weights object. */
    @Override
    public Vector getParameters() {
        return this.logisticRegression.getWeights().getAllWeights();
    }

    /** Installs a new weight vector on the model and invalidates every cache. */
    @Override
    public void setParameters(Vector parameters) {
        this.logisticRegression.getWeights().setWeightVector(parameters);
        this.invalidateCaches();
    }

    /** Marks the value, gradient and probability caches as stale. */
    private void invalidateCaches() {
        this.isValueCacheValid = false;
        this.isValueELCacheValid = false;
        this.isGradientCacheValid = false;
        this.isProbabilityCacheValid = false;
    }

    /**
     * Returns the weighted KL divergence plus the Gaussian-prior penalty.
     * The result is cached until the next {@link #setParameters(Vector)}.
     */
    @Override
    public double getValue() {
        if (this.isValueCacheValid) {
            return this.value;
        }
        double kl = this.kl();
        logger.debug("kl divergence = {}", kl);
        this.value = kl + this.penaltyValue();
        this.isValueCacheValid = true;
        return this.value;
    }

    /**
     * Returns the weighted KL divergence divided by the number of data points plus the
     * elastic-net penalty. Cached separately from {@link #getValue()} so that the two
     * objectives never return each other's cached result.
     */
    public double getValueEL() {
        if (this.isValueELCacheValid) {
            return this.valueEL;
        }
        double kl = this.kl();
        logger.debug("kl divergence = {}", kl);
        this.valueEL = kl / (double) this.dataSet.getNumDataPoints() + this.penaltyValueEL();
        this.isValueELCacheValid = true;
        return this.valueEL;
    }

    /** Sums the weighted per-data-point KL terms, refreshing class probabilities if stale. */
    private double kl() {
        if (!this.isProbabilityCacheValid) {
            this.updateClassProbMatrix();
        }
        return this.indices(this.dataSet.getNumDataPoints()).mapToDouble(this::kl).sum();
    }

    /** Weighted KL term of one data point; zero-weight data points contribute 0. */
    private double kl(int dataPointIndex) {
        double weight = this.weights[dataPointIndex];
        if (weight == 0.0) {
            return 0.0;
        }
        // The matrix is stored class-major, so gather this data point's column.
        double[] predictedLogProbs = new double[this.numClasses];
        for (int k = 0; k < this.numClasses; ++k) {
            predictedLogProbs[k] = this.logProbabilityMatrix[k][dataPointIndex];
        }
        return weight * KLDivergence.klGivenPLogQ(this.targetDistributions[dataPointIndex], predictedLogProbs);
    }

    /** Gaussian-prior penalty of one class: squared norm of its non-bias weights over twice the variance. */
    private double penaltyValue(int classIndex) {
        Vector weightVector = this.logisticRegression.getWeights().getWeightsWithoutBiasForClass(classIndex);
        return Vectors.dot(weightVector, weightVector) / (2.0 * this.priorGaussianVariance);
    }

    /** Returns the Gaussian-prior penalty summed over all classes. */
    public double penaltyValue() {
        return this.indices(this.numClasses).mapToDouble(this::penaltyValue).sum();
    }

    /** Elastic-net penalty of one class, computed on its non-bias weights. */
    private double penaltyValueEL(int classIndex) {
        Vector vector = this.logisticRegression.getWeights().getWeightsWithoutBiasForClass(classIndex);
        double l2Part = (1.0 - this.l1Ratio) * 0.5 * Math.pow(vector.norm(2.0), 2.0);
        double l1Part = this.l1Ratio * vector.norm(1.0);
        return this.regularization * (l2Part + l1Part);
    }

    /** Returns the elastic-net penalty summed over all classes. */
    public double penaltyValueEL() {
        return this.indices(this.numClasses).mapToDouble(this::penaltyValueEL).sum();
    }

    /**
     * Returns the gradient: predicted counts minus empirical counts plus the
     * Gaussian-prior penalty gradient. Cached until the next {@link #setParameters(Vector)}.
     */
    @Override
    public Vector getGradient() {
        boolean debug = logger.isDebugEnabled();
        StopWatch stopWatch = null;
        if (debug) {
            stopWatch = new StopWatch();
            stopWatch.start();
        }
        if (!this.isGradientCacheValid) {
            // Reuse probabilities already computed by a value call for the same parameters.
            if (!this.isProbabilityCacheValid) {
                this.updateClassProbMatrix();
            }
            this.updatePredictedCounts();
            this.updateGradient();
            this.isGradientCacheValid = true;
        }
        if (debug) {
            logger.debug("time spent on getGradient = {}", stopWatch);
        }
        return this.gradient;
    }

    private void updateGradient() {
        this.gradient = this.predictedCounts.minus(this.empiricalCounts).plus(this.penaltyGradient());
    }

    /** Gradient of the Gaussian-prior penalty: weights divided by the variance, with bias positions zeroed. */
    private Vector penaltyGradient() {
        Vector weightsVector = this.logisticRegression.getWeights().getAllWeights();
        // plus() returns the Vector interface type, so the local must not be typed DenseVector.
        Vector penalty = new DenseVector(weightsVector.size()).plus(weightsVector.divide(this.priorGaussianVariance));
        for (int j : this.logisticRegression.getWeights().getAllBiasPositions()) {
            penalty.set(j, 0.0);
        }
        return penalty;
    }

    private void updateEmpricalCounts() {
        this.indices(this.numParameters).forEach(i -> this.empiricalCounts.set(i, this.calEmpricalCount(i)));
    }

    private void updatePredictedCounts() {
        boolean debug = logger.isDebugEnabled();
        StopWatch stopWatch = null;
        if (debug) {
            stopWatch = new StopWatch();
            stopWatch.start();
        }
        this.indices(this.numParameters).forEach(i -> this.predictedCounts.set(i, this.calPredictedCount(i)));
        if (debug) {
            logger.debug("time spent on updatePredictedCounts = {}", stopWatch);
        }
    }

    /**
     * Target-distribution-weighted feature sum for one parameter. A feature index of -1
     * (which the weights layout appears to use for the bias term) sums over every data point
     * without a feature value.
     */
    private double calEmpricalCount(int parameterIndex) {
        int classIndex = this.logisticRegression.getWeights().getClassIndex(parameterIndex);
        int featureIndex = this.logisticRegression.getWeights().getFeatureIndex(parameterIndex);
        double count = 0.0;
        if (featureIndex == -1) {
            for (int i = 0; i < this.dataSet.getNumDataPoints(); ++i) {
                count += this.targetDistributions[i][classIndex] * this.weights[i];
            }
        } else {
            Vector featureColumn = this.dataSet.getColumn(featureIndex);
            for (Vector.Element element : featureColumn.nonZeroes()) {
                int dataPointIndex = element.index();
                if (this.weights[dataPointIndex] == 0.0) {
                    continue;
                }
                double featureValue = element.get();
                count += featureValue * this.targetDistributions[dataPointIndex][classIndex] * this.weights[dataPointIndex];
            }
        }
        return count;
    }

    /**
     * Predicted-probability-weighted feature sum for one parameter; mirrors
     * {@link #calEmpricalCount(int)} but uses the cached class probabilities.
     */
    private double calPredictedCount(int parameterIndex) {
        int classIndex = this.logisticRegression.getWeights().getClassIndex(parameterIndex);
        int featureIndex = this.logisticRegression.getWeights().getFeatureIndex(parameterIndex);
        double count = 0.0;
        double[] probs = this.probabilityMatrix[classIndex];
        if (featureIndex == -1) {
            for (int i = 0; i < this.dataSet.getNumDataPoints(); ++i) {
                if (this.weights[i] == 0.0) {
                    continue;
                }
                count += probs[i] * this.weights[i];
            }
        } else {
            Vector featureColumn = this.dataSet.getColumn(featureIndex);
            for (Vector.Element element : featureColumn.nonZeroes()) {
                int dataPointIndex = element.index();
                if (this.weights[dataPointIndex] == 0.0) {
                    continue;
                }
                double featureValue = element.get();
                count += probs[dataPointIndex] * featureValue * this.weights[dataPointIndex];
            }
        }
        return count;
    }

    /** Stores log-probabilities and probabilities of one data point; zero-weight points are skipped. */
    private void updateClassProbs(int dataPointIndex) {
        if (this.weights[dataPointIndex] == 0.0) {
            return;
        }
        double[] logProbs = this.logisticRegression.predictLogClassProbs(this.dataSet.getRow(dataPointIndex));
        for (int k = 0; k < this.numClasses; ++k) {
            this.logProbabilityMatrix[k][dataPointIndex] = logProbs[k];
            this.probabilityMatrix[k][dataPointIndex] = Math.exp(logProbs[k]);
        }
    }

    /** Recomputes class probabilities for every data point and marks the probability cache valid. */
    private void updateClassProbMatrix() {
        boolean debug = logger.isDebugEnabled();
        StopWatch stopWatch = null;
        if (debug) {
            stopWatch = new StopWatch();
            stopWatch.start();
        }
        this.indices(this.dataSet.getNumDataPoints()).forEach(this::updateClassProbs);
        this.isProbabilityCacheValid = true;
        if (debug) {
            logger.debug("time spent on updateClassProbMatrix = {}", stopWatch);
        }
    }

    /** Returns the index stream {@code [0, endExclusive)}, parallel when this loss is configured as parallel. */
    private IntStream indices(int endExclusive) {
        IntStream stream = IntStream.range(0, endExclusive);
        return this.isParallel ? stream.parallel() : stream;
    }

    /** Builds a weight array with every entry set to 1.0. */
    private static double[] defaultWeights(int numDataPoints) {
        double[] weights = new double[numDataPoints];
        Arrays.fill(weights, 1.0);
        return weights;
    }

    /** Builds one-hot target distributions from the data set's labels. */
    private static double[][] defaultTargetDistribution(ClfDataSet dataSet) {
        double[][] targetDistributions = new double[dataSet.getNumDataPoints()][dataSet.getNumClasses()];
        int[] labels = dataSet.getLabels();
        for (int i = 0; i < labels.length; ++i) {
            targetDistributions[i][labels[i]] = 1.0;
        }
        return targetDistributions;
    }
}

