/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.imllr;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.imllr.IMLLogisticRegression;
import edu.neu.ccs.pyramid.optimization.Optimizable;
import java.util.stream.IntStream;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Regularized loss of an {@link IMLLogisticRegression} model on a fixed
 * {@link MultiLabelClfDataSet}, exposed through {@link Optimizable.ByGradientValue}.
 *
 * <p>The value is the negated data-set log-likelihood reported by the model plus
 * {@code w.w / (2 * gaussianPriorVariance)}, where {@code w} is the model's full weight vector.
 * The gradient is {@code predictedCounts - empiricalCounts + w / gaussianPriorVariance}.
 *
 * <p>Both the value and the gradient are cached; the caches are dropped whenever
 * {@link #setParameters(Vector)} is called. Empirical counts depend only on the data set and
 * are computed once in the constructor.
 */
public class IMLLogisticLoss implements Optimizable.ByGradientValue {
    /**
     * Feature index for which the count routines scan every data point with an implicit
     * feature value of 1 instead of reading a feature column.
     * NOTE(review): presumably this marks the bias/intercept parameter - confirm against the
     * weights class.
     */
    private static final int IMPLICIT_FEATURE = -1;

    private IMLLogisticRegression logisticRegression;
    private MultiLabelClfDataSet dataSet;
    private double gaussianPriorVariance;
    /** Per-parameter counts derived from the observed labels; filled once at construction. */
    private Vector empiricalCounts;
    /** Per-parameter counts derived from the current class probabilities. */
    private Vector predictedCounts;
    private Vector gradient;
    private int numParameters;
    /** Class probabilities indexed as [dataPoint][class]; refreshed by {@link #getGradient()}. */
    private double[][] classProbMatrix;
    private double value;
    private boolean isGradientCacheValid;
    private boolean isValueCacheValid;

    /**
     * Builds the loss for the given model and data and precomputes the empirical counts.
     *
     * @param mlLogisticRegression  model whose weights are read and updated by this loss
     * @param dataSet               data set the loss is evaluated on
     * @param gaussianPriorVariance divisor used for the quadratic penalty on the weights
     */
    public IMLLogisticLoss(IMLLogisticRegression mlLogisticRegression, MultiLabelClfDataSet dataSet, double gaussianPriorVariance) {
        logisticRegression = mlLogisticRegression;
        numParameters = mlLogisticRegression.getWeights().totalSize();
        this.dataSet = dataSet;
        this.gaussianPriorVariance = gaussianPriorVariance;
        empiricalCounts = new DenseVector(numParameters);
        predictedCounts = new DenseVector(numParameters);
        classProbMatrix = new double[dataSet.getNumDataPoints()][dataSet.getNumClasses()];
        updateEmpiricalCounts();
        isValueCacheValid = false;
        isGradientCacheValid = false;
    }

    /**
     * @return the vector returned by the model's weights object for all weights
     */
    @Override
    public Vector getParameters() {
        return logisticRegression.getWeights().getAllWeights();
    }

    /**
     * Hands the given vector to the model's weights and invalidates both caches.
     *
     * @param parameters new weight vector
     */
    @Override
    public void setParameters(Vector parameters) {
        logisticRegression.getWeights().setWeightVector(parameters);
        isValueCacheValid = false;
        isGradientCacheValid = false;
    }

    /**
     * @return the cached loss value, recomputing it first if the cache is stale
     */
    @Override
    public double getValue() {
        if (!isValueCacheValid) {
            Vector weights = getParameters();
            double negativeLogLikelihood = -1.0 * logisticRegression.dataSetLogLikelihood(dataSet);
            double penalty = weights.dot(weights) / (2.0 * gaussianPriorVariance);
            value = negativeLogLikelihood + penalty;
            isValueCacheValid = true;
        }
        return value;
    }

    /**
     * @return the cached gradient, recomputing class probabilities, predicted counts and the
     *         gradient itself first if the cache is stale
     */
    @Override
    public Vector getGradient() {
        if (!isGradientCacheValid) {
            // Order matters: predicted counts read the probability matrix, the gradient reads the counts.
            updateClassProbMatrix();
            updatePredictedCounts();
            updateGradient();
            isGradientCacheValid = true;
        }
        return gradient;
    }

    /** Assembles the gradient from the current counts and the penalty term. */
    private void updateGradient() {
        Vector weights = logisticRegression.getWeights().getAllWeights();
        Vector countDifference = predictedCounts.minus(empiricalCounts);
        Vector penaltyGradient = weights.divide(gaussianPriorVariance);
        gradient = countDifference.plus(penaltyGradient);
    }

    /** Fills every slot of {@code empiricalCounts}, one parameter per parallel task. */
    private void updateEmpiricalCounts() {
        IntStream.range(0, numParameters).parallel().forEach(this::storeEmpiricalCount);
    }

    /** Fills every slot of {@code predictedCounts}, one parameter per parallel task. */
    private void updatePredictedCounts() {
        IntStream.range(0, numParameters).parallel().forEach(this::storePredictedCount);
    }

    private void storeEmpiricalCount(int parameterIndex) {
        empiricalCounts.set(parameterIndex, calEmpiricalCount(parameterIndex));
    }

    private void storePredictedCount(int parameterIndex) {
        predictedCounts.set(parameterIndex, calPredictedCount(parameterIndex));
    }

    /**
     * Empirical count for one parameter: the sum, over data points whose label matches the
     * parameter's class, of the parameter's feature value (or 1 for the implicit feature).
     */
    private double calEmpiricalCount(int parameterIndex) {
        int classIndex = logisticRegression.getWeights().getClassIndex(parameterIndex);
        MultiLabel[] labels = dataSet.getMultiLabels();
        int featureIndex = logisticRegression.getWeights().getFeatureIndex(parameterIndex);
        return featureIndex == IMPLICIT_FEATURE
                ? countMatchingLabels(labels, classIndex)
                : sumFeatureOverMatchingLabels(labels, classIndex, featureIndex);
    }

    /** Counts the data points whose label matches {@code classIndex}. */
    private double countMatchingLabels(MultiLabel[] labels, int classIndex) {
        double total = 0.0;
        for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
            if (labels[row].matchClass(classIndex)) {
                total += 1.0;
            }
        }
        return total;
    }

    /** Sums the non-zero entries of a feature column over data points matching {@code classIndex}. */
    private double sumFeatureOverMatchingLabels(MultiLabel[] labels, int classIndex, int featureIndex) {
        double total = 0.0;
        for (Vector.Element entry : dataSet.getColumn(featureIndex).nonZeroes()) {
            if (labels[entry.index()].matchClass(classIndex)) {
                total += entry.get();
            }
        }
        return total;
    }

    /**
     * Predicted count for one parameter: the sum, over data points, of the stored class
     * probability times the parameter's feature value (or 1 for the implicit feature).
     */
    private double calPredictedCount(int parameterIndex) {
        int classIndex = logisticRegression.getWeights().getClassIndex(parameterIndex);
        int featureIndex = logisticRegression.getWeights().getFeatureIndex(parameterIndex);
        return featureIndex == IMPLICIT_FEATURE
                ? sumClassProbs(classIndex)
                : sumWeightedClassProbs(classIndex, featureIndex);
    }

    /** Sums the stored probability of {@code classIndex} over all data points. */
    private double sumClassProbs(int classIndex) {
        double total = 0.0;
        for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
            total += classProbMatrix[row][classIndex];
        }
        return total;
    }

    /** Sums probability times feature value over the non-zero entries of a feature column. */
    private double sumWeightedClassProbs(int classIndex, int featureIndex) {
        double total = 0.0;
        for (Vector.Element entry : dataSet.getColumn(featureIndex).nonZeroes()) {
            total += classProbMatrix[entry.index()][classIndex] * entry.get();
        }
        return total;
    }

    /**
     * Returns the stored class-probability row for a data point. The row is the internal array
     * (not a copy) and reflects the state as of the last gradient computation.
     *
     * @param dataPointIndex row index into the data set
     * @return the stored probabilities for that data point
     */
    public double[] getClassProbs(int dataPointIndex) {
        return classProbMatrix[dataPointIndex];
    }

    /** Replaces every row of the probability matrix with the model's current prediction. */
    private void updateClassProbMatrix() {
        IntStream.range(0, dataSet.getNumDataPoints()).parallel().forEach(this::storeClassProbs);
    }

    private void storeClassProbs(int dataPointIndex) {
        classProbMatrix[dataPointIndex] = logisticRegression.predictClassProbs(dataSet.getRow(dataPointIndex));
    }
}

