/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.multi_label_logistic_regression;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.multi_label_logistic_regression.MLLogisticRegression;
import edu.neu.ccs.pyramid.optimization.Optimizable;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Regularized negative log-likelihood objective for an {@link MLLogisticRegression} model.
 *
 * <p>Value: {@code -dataSetLogLikelihood(dataSet) + ||w||^2 / (2 * gaussianPriorVariance)},
 * i.e. an L2 penalty derived from a zero-mean Gaussian prior on the weights.
 *
 * <p>Gradient: {@code predictedCounts - empiricalCounts + w / gaussianPriorVariance}. The
 * empirical counts depend only on the data and are computed once in the constructor; the
 * predicted counts are rebuilt from per-data-point class probabilities each time the gradient
 * is recomputed.
 *
 * <p>The value and the gradient are cached independently; both caches are invalidated by
 * {@link #setParameters(Vector)}.
 */
public class MLLogisticLoss implements Optimizable.ByGradient, Optimizable.ByGradientValue {
    private static final Logger logger = LogManager.getLogger();

    private final MLLogisticRegression mlLogisticRegression;
    private final MultiLabelClfDataSet dataSet;
    /** Variance of the Gaussian prior; the penalty is divided by it, so larger means weaker regularization. */
    private final double gaussianPriorVariance;
    /** Total number of model weights, fixed at construction. */
    private final int numParameters;

    /** Per-parameter sums of feature values over data points whose label matches the parameter's class. */
    private final Vector empiricalCounts;
    /** Per-parameter sums of feature values weighted by the predicted probability of the parameter's class. */
    private final Vector predictedCounts;
    /** Most recently computed gradient; null until the first call to getGradient. */
    private Vector gradient;

    // Per-data-point intermediates; each row is derived from the same row of the previous matrix:
    // class scores -> assignment scores -> assignment probabilities -> class probabilities.
    private final double[][] classScoreMatrix;
    private final double[][] assignmentScoreMatrix;
    private final double[][] assignmentProbMatrix;
    private final double[][] classProbMatrix;

    private double value;
    private boolean isValueCacheValid;
    private boolean isGradientCacheValid;

    /**
     * Builds the loss and precomputes the data-only empirical counts.
     *
     * @param mlLogisticRegression model whose weight vector is the optimization parameter
     * @param dataSet training data
     * @param gaussianPriorVariance variance of the zero-mean Gaussian prior on the weights
     */
    public MLLogisticLoss(MLLogisticRegression mlLogisticRegression, MultiLabelClfDataSet dataSet, double gaussianPriorVariance) {
        final int rows = dataSet.getNumDataPoints();
        final int assignmentCount = mlLogisticRegression.getAssignments().size();
        final int classCount = dataSet.getNumClasses();

        this.mlLogisticRegression = mlLogisticRegression;
        this.dataSet = dataSet;
        this.gaussianPriorVariance = gaussianPriorVariance;
        this.numParameters = mlLogisticRegression.getWeights().totalSize();

        this.empiricalCounts = new DenseVector(this.numParameters);
        this.predictedCounts = new DenseVector(this.numParameters);

        this.classScoreMatrix = new double[rows][classCount];
        this.classProbMatrix = new double[rows][classCount];
        this.assignmentScoreMatrix = new double[rows][assignmentCount];
        this.assignmentProbMatrix = new double[rows][assignmentCount];

        this.updateEmpiricalCounts();
        this.invalidateCaches();
    }

    /** Returns the model's full weight vector as reported by its weights object. */
    @Override
    public Vector getParameters() {
        return this.mlLogisticRegression.getWeights().getAllWeights();
    }

    /**
     * Writes new weights into the model and marks both the cached value and gradient stale.
     *
     * @param parameters the new weight vector
     */
    @Override
    public void setParameters(Vector parameters) {
        this.mlLogisticRegression.getWeights().setWeightVector(parameters);
        this.invalidateCaches();
    }

    /** Marks both cached quantities as needing recomputation. */
    private void invalidateCaches() {
        this.isValueCacheValid = false;
        this.isGradientCacheValid = false;
    }

    /**
     * Returns the regularized negative log-likelihood, recomputing it only if the cache is stale.
     *
     * @return {@code -logLikelihood + ||w||^2 / (2 * gaussianPriorVariance)}
     */
    @Override
    public double getValue() {
        if (!this.isValueCacheValid) {
            final Vector w = this.getParameters();
            final double negativeLogLikelihood = -1.0 * this.mlLogisticRegression.dataSetLogLikelihood(this.dataSet);
            final double priorPenalty = w.dot(w) / (2.0 * this.gaussianPriorVariance);
            this.value = negativeLogLikelihood + priorPenalty;
            this.isValueCacheValid = true;
        }
        return this.value;
    }

    /**
     * Returns the gradient of {@link #getValue()}, recomputing all per-data-point matrices and
     * the predicted counts only if the cache is stale. The returned vector is the cached
     * instance, not a copy.
     *
     * @return the gradient vector
     */
    @Override
    public Vector getGradient() {
        if (!this.isGradientCacheValid) {
            this.refreshPerDataPointMatrices();
            this.updatePredictedCounts();
            this.updateGradient();
            this.isGradientCacheValid = true;
        }
        return this.gradient;
    }

    /** Combines the data term (predicted minus empirical) with the Gaussian-prior term. */
    private void updateGradient() {
        final Vector weights = this.mlLogisticRegression.getWeights().getAllWeights();
        final Vector dataTerm = this.predictedCounts.minus(this.empiricalCounts);
        final Vector priorTerm = weights.divide(this.gaussianPriorVariance);
        this.gradient = dataTerm.plus(priorTerm);
    }

    /** Fills empiricalCounts in parallel; each task writes only its own parameter index. */
    private void updateEmpiricalCounts() {
        this.parameterIndices().forEach(p -> this.empiricalCounts.set(p, this.calEmpiricalCount(p)));
    }

    /** Fills predictedCounts in parallel from the current classProbMatrix. */
    private void updatePredictedCounts() {
        if (logger.isDebugEnabled()) {
            logger.debug("start method  updatePredictedCounts");
        }
        this.parameterIndices().forEach(p -> this.predictedCounts.set(p, this.calPredictedCount(p)));
        if (logger.isDebugEnabled()) {
            logger.debug("finish method  updatePredictedCounts");
        }
    }

    /**
     * Sums the parameter's feature value over data points whose label matches the parameter's
     * class (per {@code MultiLabel.matchClass}). A feature index of -1 contributes 1.0 per
     * matching data point, i.e. it acts as a constant feature (presumably the bias term).
     */
    private double calEmpiricalCount(int parameterIndex) {
        final int classIndex = this.mlLogisticRegression.getWeights().getClassIndex(parameterIndex);
        final int featureIndex = this.mlLogisticRegression.getWeights().getFeatureIndex(parameterIndex);
        final MultiLabel[] labels = this.dataSet.getMultiLabels();
        double total = 0.0;
        if (featureIndex != -1) {
            // Only non-zero entries can contribute to the sum.
            for (Vector.Element entry : this.dataSet.getColumn(featureIndex).nonZeroes()) {
                if (labels[entry.index()].matchClass(classIndex)) {
                    total += entry.get();
                }
            }
            return total;
        }
        for (int row = 0; row < this.dataSet.getNumDataPoints(); ++row) {
            if (labels[row].matchClass(classIndex)) {
                total += 1.0;
            }
        }
        return total;
    }

    /**
     * Sums the parameter's feature value weighted by the predicted probability of its class,
     * taken from classProbMatrix. A feature index of -1 is treated as a constant feature of 1.0.
     */
    private double calPredictedCount(int parameterIndex) {
        final int classIndex = this.mlLogisticRegression.getWeights().getClassIndex(parameterIndex);
        final int featureIndex = this.mlLogisticRegression.getWeights().getFeatureIndex(parameterIndex);
        double total = 0.0;
        if (featureIndex != -1) {
            for (Vector.Element entry : this.dataSet.getColumn(featureIndex).nonZeroes()) {
                final double featureValue = entry.get();
                total += this.classProbMatrix[entry.index()][classIndex] * featureValue;
            }
            return total;
        }
        for (int row = 0; row < this.dataSet.getNumDataPoints(); ++row) {
            total += this.classProbMatrix[row][classIndex];
        }
        return total;
    }

    /**
     * Returns the class-probability row for one data point from the last gradient computation
     * (all zeros before the first one). The internal array is returned directly, not a copy.
     *
     * @param dataPointIndex index of the data point
     * @return the class probabilities for that data point
     */
    public double[] getClassProbs(int dataPointIndex) {
        return this.classProbMatrix[dataPointIndex];
    }

    /**
     * Recomputes the four per-data-point matrices in dependency order. Each pass runs in
     * parallel over data points and completes before the next starts; each task replaces only
     * its own row.
     */
    private void refreshPerDataPointMatrices() {
        this.dataPointIndices().forEach(i ->
                this.classScoreMatrix[i] = this.mlLogisticRegression.predictClassScores(this.dataSet.getRow(i)));
        this.dataPointIndices().forEach(i ->
                this.assignmentScoreMatrix[i] = this.mlLogisticRegression.calAssignmentScores(this.classScoreMatrix[i]));
        this.dataPointIndices().forEach(i ->
                this.assignmentProbMatrix[i] = this.mlLogisticRegression.calAssignmentProbs(this.assignmentScoreMatrix[i]));
        this.dataPointIndices().forEach(i ->
                this.classProbMatrix[i] = this.mlLogisticRegression.calClassProbs(this.assignmentProbMatrix[i]));
    }

    /** Parallel stream over data-point indices {@code [0, numDataPoints)}. */
    private IntStream dataPointIndices() {
        return IntStream.range(0, this.dataSet.getNumDataPoints()).parallel();
    }

    /** Parallel stream over parameter indices {@code [0, numParameters)}. */
    private IntStream parameterIndices() {
        return IntStream.range(0, this.numParameters).parallel();
    }
}

