/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.AugmentedLR;
import edu.neu.ccs.pyramid.optimization.Optimizable;
import edu.neu.ccs.pyramid.util.MathUtil;
import edu.neu.ccs.pyramid.util.Vectors;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Regularized, gamma-weighted binary logistic loss for a single label, optimized over the
 * weights of an {@link AugmentedLR}.
 *
 * <p>The objective value is
 * {@code sum_i sum_k gammas[i][k] * (-log P(y_i | x_i, component k)) + penalty}, where
 * {@code y_i} is the binary indicator of label {@code labelIndex} for data point {@code i}.
 * The penalty is {@code ||featureWeights||^2 / (2 * featureWeightVariance)
 * + ||componentWeights||^2 / (2 * componentWeightVariance)}; the bias is not penalized.
 *
 * <p>Gradient layout (matching the empirical/predicted count vectors): indices
 * {@code [0, numFeatures)} are feature weights,
 * {@code [numFeatures, numFeatures + numComponents)} are component weights, and index
 * {@code numFeatures + numComponents} is the bias. The gradient is computed as
 * {@code predictedCounts - empiricalCounts + penaltyGradient}.
 *
 * <p>NOTE(review): the empirical counts for feature weights and bias do not multiply by
 * {@code gammas}, which is only equivalent to the weighted NLL gradient if each row
 * {@code gammas[i]} sums to 1 — presumably guaranteed by the caller; confirm.
 *
 * <p>Probabilities, value and gradient are cached in unsynchronized mutable fields and are
 * invalidated only by {@link #setParameters(Vector)}; instances should not be shared across
 * threads without external synchronization.
 */
public class AugmentedLRLoss
implements Optimizable.ByGradientValue {
    private final MultiLabelClfDataSet dataSet;
    /** Responsibilities indexed as {@code gammas[dataIndex][componentIndex]}. */
    private final double[][] gammas;
    private final AugmentedLR augmentedLR;
    /** Binary indicator of the target label for each data point. */
    private final int[] binaryLabels;
    private final int numFeatures;
    private final int numComponents;
    private final int numData;
    /** Observed sufficient statistics; depend only on data, labels and gammas, so computed once. */
    private final Vector empiricalCounts;
    /** Model-expected sufficient statistics; recomputed whenever the gradient is refreshed. */
    private final Vector predictedCounts;
    private Vector gradient;
    /**
     * {@code logProbs[i][k][c]}: log probability of binary label {@code c} for data point
     * {@code i} under component {@code k}, as returned by {@code logAugmentedProbs}.
     * Rows are filled lazily by {@link #updateProbs()}.
     */
    private final double[][][] logProbs;
    /** {@code expectedProbs[i] = sum_k gammas[i][k] * P(label = 1 | x_i, component k)}. */
    private final double[] expectedProbs;
    private double value;
    private boolean isGradientCacheValid;
    private boolean isValueCacheValid;
    private boolean isProbabilityCacheValid;
    private final double featureWeightVariance;
    private final double componentWeightVariance;

    /**
     * Creates the loss for one label and precomputes the empirical counts.
     *
     * @param dataSet multi-label training data
     * @param labelIndex label whose binary indicator is the regression target
     * @param gammas responsibilities, indexed {@code [dataIndex][componentIndex]}
     * @param augmentedLR model whose weights are optimized (mutated by {@link #setParameters})
     * @param featureWeightVariance prior variance of the feature weights
     * @param componentWeightVariance prior variance of the component weights
     */
    public AugmentedLRLoss(MultiLabelClfDataSet dataSet, int labelIndex, double[][] gammas, AugmentedLR augmentedLR, double featureWeightVariance, double componentWeightVariance) {
        this.dataSet = dataSet;
        this.gammas = gammas;
        this.augmentedLR = augmentedLR;
        this.featureWeightVariance = featureWeightVariance;
        this.componentWeightVariance = componentWeightVariance;
        this.binaryLabels = DataSetUtil.toBinaryLabels(dataSet.getMultiLabels(), labelIndex);
        this.numFeatures = dataSet.getNumFeatures();
        this.numComponents = augmentedLR.getNumComponents();
        this.empiricalCounts = new DenseVector(this.numFeatures + this.numComponents + 1);
        this.predictedCounts = new DenseVector(this.numFeatures + this.numComponents + 1);
        this.numData = dataSet.getNumDataPoints();
        // Only the outer array is needed: every row is replaced by updateProbs() before use.
        this.logProbs = new double[this.numData][][];
        this.expectedProbs = new double[this.numData];
        this.updateEmpiricalCounts();
        this.isGradientCacheValid = false;
        this.isValueCacheValid = false;
        this.isProbabilityCacheValid = false;
    }

    @Override
    public Vector getParameters() {
        return this.augmentedLR.getAllWeights();
    }

    /** Sets the model weights and invalidates every cache derived from them. */
    @Override
    public void setParameters(Vector parameters) {
        this.augmentedLR.setWeights(parameters);
        this.isGradientCacheValid = false;
        this.isValueCacheValid = false;
        this.isProbabilityCacheValid = false;
    }

    /** Returns the penalized negative log-likelihood, cached until the parameters change. */
    @Override
    public double getValue() {
        if (this.isValueCacheValid) {
            return this.value;
        }
        double nll = this.computeNLL();
        this.value = nll + this.penalty();
        this.isValueCacheValid = true;
        return this.value;
    }

    /**
     * Returns the gradient of {@link #getValue()}, cached until the parameters change. The
     * returned vector is the internal cached instance.
     */
    @Override
    public Vector getGradient() {
        if (this.isGradientCacheValid) {
            return this.gradient;
        }
        // Reuse probabilities already computed by getValue() for the same parameters.
        this.ensureProbs();
        this.updateExpectedProbs();
        this.updatePredictedCounts();
        this.updateGradient();
        this.isGradientCacheValid = true;
        return this.gradient;
    }

    /** Sum of feature {@code d} over the data points whose binary label is 1. */
    private double calEmpiricalCountFeatureWeight(int d) {
        Vector featureColumn = this.dataSet.getColumn(d);
        double sum = 0.0;
        for (Vector.Element element : featureColumn.nonZeroes()) {
            int dataIndex = element.index();
            double feature = element.get();
            if (this.binaryLabels[dataIndex] != 1) continue;
            sum += feature;
        }
        return sum;
    }

    /** {@code sum_i y_i * gammas[i][k]}. */
    private double calEmpiricalCountComponentWeight(int k) {
        double sum = 0.0;
        for (int i = 0; i < this.numData; ++i) {
            sum += (double)this.binaryLabels[i] * this.gammas[i][k];
        }
        return sum;
    }

    /** {@code sum_i y_i}. */
    private double calEmpiricalCountBias() {
        double sum = 0.0;
        for (int i = 0; i < this.numData; ++i) {
            sum += (double)this.binaryLabels[i];
        }
        return sum;
    }

    /** Fills {@link #empiricalCounts} using the feature / component / bias layout. */
    private void updateEmpiricalCounts() {
        for (int d = 0; d < this.numFeatures; ++d) {
            this.empiricalCounts.set(d, this.calEmpiricalCountFeatureWeight(d));
        }
        for (int k = 0; k < this.numComponents; ++k) {
            this.empiricalCounts.set(this.numFeatures + k, this.calEmpiricalCountComponentWeight(k));
        }
        this.empiricalCounts.set(this.numFeatures + this.numComponents, this.calEmpiricalCountBias());
    }

    /** Recomputes the per-component log probabilities only if the cache is stale. */
    private void ensureProbs() {
        if (!this.isProbabilityCacheValid) {
            this.updateProbs();
        }
    }

    /** Recomputes {@link #logProbs} for every data point and marks the cache valid. */
    private void updateProbs() {
        for (int i = 0; i < this.numData; ++i) {
            this.logProbs[i] = this.augmentedLR.logAugmentedProbs(this.dataSet.getRow(i));
        }
        this.isProbabilityCacheValid = true;
    }

    /** {@code sum_i x_id * expectedProbs[i]} over the non-zero entries of feature {@code d}. */
    private double calPredictedCountFeatureWeight(int d) {
        Vector featureColumn = this.dataSet.getColumn(d);
        double sum = 0.0;
        for (Vector.Element element : featureColumn.nonZeroes()) {
            int dataIndex = element.index();
            double feature = element.get();
            sum += feature * this.expectedProbs[dataIndex];
        }
        return sum;
    }

    /** {@code sum_i P(label = 1 | x_i, component k) * gammas[i][k]}. */
    private double calPredictedCountComponentWeight(int k) {
        double sum = 0.0;
        for (int i = 0; i < this.numData; ++i) {
            sum += Math.exp(this.logProbs[i][k][1]) * this.gammas[i][k];
        }
        return sum;
    }

    /** Sets {@code expectedProbs[i]} to the gamma-weighted probability of label 1. */
    private void updateExpectedProb(int i) {
        double sum = 0.0;
        for (int k = 0; k < this.numComponents; ++k) {
            sum += this.gammas[i][k] * Math.exp(this.logProbs[i][k][1]);
        }
        this.expectedProbs[i] = sum;
    }

    private void updateExpectedProbs() {
        for (int i = 0; i < this.numData; ++i) {
            this.updateExpectedProb(i);
        }
    }

    private double calPredictedCountBias() {
        return MathUtil.arraySum(this.expectedProbs);
    }

    /** Fills {@link #predictedCounts}; requires {@link #expectedProbs} to be up to date. */
    private void updatePredictedCounts() {
        for (int d = 0; d < this.numFeatures; ++d) {
            this.predictedCounts.set(d, this.calPredictedCountFeatureWeight(d));
        }
        for (int k = 0; k < this.numComponents; ++k) {
            this.predictedCounts.set(this.numFeatures + k, this.calPredictedCountComponentWeight(k));
        }
        this.predictedCounts.set(this.numFeatures + this.numComponents, this.calPredictedCountBias());
    }

    /**
     * Gaussian-prior penalty on feature and component weights. A variance of
     * {@code Double.POSITIVE_INFINITY} disables the corresponding term; a variance of 0 makes
     * it infinite or NaN.
     */
    private double penalty() {
        Vector featureWeights = this.augmentedLR.featureWeights();
        double featurePenalty = Vectors.dot(featureWeights, featureWeights) / (2.0 * this.featureWeightVariance);
        Vector componentWeights = this.augmentedLR.componentWeights();
        double componentPenalty = Vectors.dot(componentWeights, componentWeights) / (2.0 * this.componentWeightVariance);
        return featurePenalty + componentPenalty;
    }

    /** Gradient of {@link #penalty()}; the bias entry is left at 0 since it is not penalized. */
    private Vector penaltyGradient() {
        Vector featureWeights = this.augmentedLR.featureWeights();
        Vector componentWeights = this.augmentedLR.componentWeights();
        DenseVector penaltyGradient = new DenseVector(this.augmentedLR.getAllWeights().size());
        for (int d = 0; d < this.numFeatures; ++d) {
            penaltyGradient.set(d, featureWeights.get(d) / this.featureWeightVariance);
        }
        for (int k = 0; k < this.numComponents; ++k) {
            penaltyGradient.set(this.numFeatures + k, componentWeights.get(k) / this.componentWeightVariance);
        }
        return penaltyGradient;
    }

    private void updateGradient() {
        this.gradient = this.predictedCounts.minus(this.empiricalCounts).plus(this.penaltyGradient());
    }

    /** Unpenalized negative log-likelihood summed over all data points. */
    private double computeNLL() {
        this.ensureProbs();
        double sum = 0.0;
        for (int i = 0; i < this.numData; ++i) {
            sum += this.computeNLL(i);
        }
        return sum;
    }

    /** {@code -sum_k gammas[i][k] * log P(observed label of i | x_i, component k)}. */
    private double computeNLL(int i) {
        // Labels other than 1 are treated as the negative class (index 0).
        int observed = this.binaryLabels[i] == 1 ? 1 : 0;
        double sum = 0.0;
        for (int k = 0; k < this.numComponents; ++k) {
            sum += this.gammas[i][k] * this.logProbs[i][k][observed];
        }
        return -1.0 * sum;
    }
}

