/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.crf;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF;
import edu.neu.ccs.pyramid.optimization.Optimizable;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Differentiable training objective for a {@link CMLCRF} fitted to soft targets.
 *
 * <p>For every data point {@code i} the loss term is
 * {@code logSumExp(s_i) - sum_j q_ij * s_ij}. Here {@code s_ij} is the model score of the j-th
 * supported label combination and {@code q_ij = targetDistribution[i][j]}. When every target
 * row sums to 1, this is the cross entropy between the target and the model distribution. That
 * equals the KL divergence up to the (weight-independent) target entropy.
 *
 * <p>A Gaussian prior penalty {@code ||w||^2 / (2 * gaussianPriorVariance)} is added on top of
 * the data term. With {@link #setRegularizeAll(boolean) regularizeAll} disabled, class biases and
 * label-pair weights are excluded from the penalty and from its gradient.
 *
 * <p>Instances cache score/probability matrices and the gradient in mutable, unsynchronized
 * fields, so a single instance should not be used from several threads at once. Internally,
 * heavy loops run on parallel streams while {@code isParallel} is set.
 */
public class KLLoss implements Optimizable.ByGradientValue {
    private static final Logger logger = LogManager.getLogger();

    private final CMLCRF cmlcrf;
    private final List<MultiLabel> supportedCombinations;
    private final int numSupport;
    private final MultiLabelClfDataSet dataSet;
    private final double gaussianPriorVariance;
    private final int numClasses;
    private final int numParameters;
    private final int numWeightsForFeatures;
    private final int numWeightsForLabelPairs;
    private final Vector gradient;
    private double value;
    private final double[] empiricalCounts;
    private int[] parameterToL1;
    private int[] parameterToL2;
    private int[] parameterToClass;
    private int[] parameterToFeature;
    private boolean[][] comContainsLabel;
    private boolean isParallel = true;
    private boolean isGradientCacheValid = false;
    private boolean isValueCacheValid = false;
    private final double[][] classScoreMatrix;
    private final double[][] classProbMatrix;
    private final double[][] combProbMatrix;
    private final double[][] combScoreMatrix;
    private final int numData;
    private List<List<Integer>> labelPairToCombination;
    private boolean regularizeAll = true;
    private final double[] combProbSums;
    private final double[][] targetDistribution;
    private double[][] targetMarginals;

    /**
     * Builds the loss and precomputes everything that depends only on the data and targets.
     *
     * <p>The precomputed state covers target marginals, the parameter index maps, the
     * label-pair to combination map and the empirical counts.
     *
     * @param cmlcrf model whose weights are optimized
     * @param dataSet training data; its rows and columns feed the score and gradient computations
     * @param targetDistribution per data point, a weight for each supported combination of
     *     {@code cmlcrf}. It must have at least {@code numDataPoints} rows, each of length
     *     {@code cmlcrf.getNumSupports()}.
     * @param gaussianPriorVariance variance of the Gaussian prior on the weights
     * @throws IllegalArgumentException if {@code targetDistribution} has the wrong shape
     */
    public KLLoss(CMLCRF cmlcrf, MultiLabelClfDataSet dataSet, double[][] targetDistribution, double gaussianPriorVariance) {
        this.cmlcrf = cmlcrf;
        this.supportedCombinations = cmlcrf.getSupportCombinations();
        this.numSupport = cmlcrf.getNumSupports();
        this.dataSet = dataSet;
        this.numData = dataSet.getNumDataPoints();
        this.numClasses = dataSet.getNumClasses();
        this.targetDistribution = targetDistribution;
        this.gaussianPriorVariance = gaussianPriorVariance;
        this.numParameters = cmlcrf.getWeights().totalSize();
        this.numWeightsForFeatures = cmlcrf.getWeights().getNumWeightsForFeatures();
        this.numWeightsForLabelPairs = cmlcrf.getWeights().getNumWeightsForLabels();
        this.classScoreMatrix = new double[this.numData][this.numClasses];
        this.classProbMatrix = new double[this.numData][this.numClasses];
        this.combScoreMatrix = new double[this.numData][this.numSupport];
        this.combProbMatrix = new double[this.numData][this.numSupport];
        this.empiricalCounts = new double[this.numParameters];
        this.gradient = new DenseVector(this.numParameters);
        this.combProbSums = new double[this.numSupport];
        // Fail fast with a clear message instead of an index error deep inside a parallel stream.
        this.validateTargetDistribution();
        // Order matters: the empirical counts need the marginals and both parameter maps.
        this.initTargetMarginals();
        this.mapParameters();
        this.initComContainsLabel();
        this.mapPairToCombination();
        this.initEmpiricalCounts();
    }

    /**
     * Chooses whether biases and label-pair weights are regularized as well.
     *
     * @param regularizeAll {@code true} to penalize every weight; {@code false} to penalize only
     *     the non-bias feature weights
     */
    public void setRegularizeAll(boolean regularizeAll) {
        this.regularizeAll = regularizeAll;
    }

    /**
     * Returns the gradient of {@link #getValue()} with respect to all weights.
     *
     * <p>The returned vector is the internal cache and is overwritten in place on the next
     * recomputation. Label-pair entries are only refreshed when {@code cmlcrf.considerPair()} is
     * true; otherwise they keep their previous contents.
     */
    @Override
    public Vector getGradient() {
        if (this.isGradientCacheValid) {
            return this.gradient;
        }
        logger.debug("start method getGradient()");
        this.updateClassScoreMatrix();
        this.updateAssignmentScoreMatrix();
        this.updateAssignmentProbMatrix();
        this.updateCombProbSums();
        this.updateClassProbMatrix();
        this.updateGradient();
        this.isGradientCacheValid = true;
        logger.debug("finish method getGradient()");
        return this.gradient;
    }

    /** Writes fresh feature (and, if enabled, label-pair) gradient entries into {@code gradient}. */
    private void updateGradient() {
        logger.debug("start method updateGradient()");
        this.updatedFeatureLabelGradient();
        if (this.cmlcrf.considerPair()) {
            this.updateLabelLabelGradient();
        }
        logger.debug("finish method updateGradient()");
    }

    /** Fills gradient entries {@code [0, numWeightsForFeatures)}. */
    private void updatedFeatureLabelGradient() {
        logger.debug("start method updatedFeatureLabelGradient()");
        this.indexRange(0, this.numWeightsForFeatures)
                .forEach(i -> this.gradient.set(i, this.calGradientForFeature(i)));
        logger.debug("finish method updatedFeatureLabelGradient()");
    }

    /** Fills the label-pair gradient entries, which follow the feature weights. */
    private void updateLabelLabelGradient() {
        logger.debug("start method updateLabelLabelGradient()");
        this.indexRange(this.numWeightsForFeatures, this.numWeightsForFeatures + this.numWeightsForLabelPairs)
                .forEach(i -> this.gradient.set(i, this.calGradientForLabelPair(i)));
        logger.debug("finish method updateLabelLabelGradient()");
    }

    /**
     * Computes the gradient entry for one label-pair weight.
     *
     * <p>The entry is the expected count minus the empirical count, plus the prior term when
     * {@code regularizeAll} is set.
     */
    private double calGradientForLabelPair(int parameterIndex) {
        int pos = parameterIndex - this.numWeightsForFeatures;
        double gradient = 0.0;
        for (int matched : this.labelPairToCombination.get(pos)) {
            gradient += this.combProbSums[matched];
        }
        gradient -= this.empiricalCounts[parameterIndex];
        if (this.regularizeAll) {
            gradient += this.cmlcrf.getWeights().getWeightForIndex(parameterIndex) / this.gaussianPriorVariance;
        }
        return gradient;
    }

    /**
     * Computes the gradient entry for one (class, feature) weight.
     *
     * <p>A feature index of {@code -1} denotes the class bias, whose implicit feature value is 1.
     * Bias weights receive the prior term only when {@code regularizeAll} is set, which mirrors
     * {@link #getPenalty()}.
     */
    private double calGradientForFeature(int parameterIndex) {
        int classIndex = this.parameterToClass[parameterIndex];
        int featureIndex = this.parameterToFeature[parameterIndex];
        boolean isBias = featureIndex == -1;
        double gradient = 0.0;
        if (isBias) {
            for (int i = 0; i < this.numData; ++i) {
                gradient += this.classProbMatrix[i][classIndex];
            }
        } else {
            for (Vector.Element element : this.dataSet.getColumn(featureIndex).nonZeroes()) {
                gradient += this.classProbMatrix[element.index()][classIndex] * element.get();
            }
        }
        gradient -= this.empiricalCounts[parameterIndex];
        if (this.regularizeAll || !isBias) {
            gradient += this.cmlcrf.getWeights().getWeightForIndex(parameterIndex) / this.gaussianPriorVariance;
        }
        return gradient;
    }

    /** Precomputes the target-side (empirical) count of every parameter. */
    private void initEmpiricalCounts() {
        this.indexRange(0, this.numParameters).forEach(this::calEmpiricalCount);
    }

    /** Stores the empirical count of one parameter; indices beyond the known ranges stay 0. */
    private void calEmpiricalCount(int parameterIndex) {
        if (parameterIndex < this.numWeightsForFeatures) {
            this.empiricalCounts[parameterIndex] = this.calEmpiricalCountForFeature(parameterIndex);
        } else if (parameterIndex < this.numWeightsForFeatures + this.numWeightsForLabelPairs) {
            this.empiricalCounts[parameterIndex] = this.calEmpiricalCountForLabelPair(parameterIndex);
        }
    }

    /** Sums the target mass of all combinations matching this label-pair weight over all data. */
    private double calEmpiricalCountForLabelPair(int parameterIndex) {
        List<Integer> comIndices = this.labelPairToCombination.get(parameterIndex - this.numWeightsForFeatures);
        double empiricalCount = 0.0;
        for (int i = 0; i < this.numData; ++i) {
            for (int matchedCom : comIndices) {
                empiricalCount += this.targetDistribution[i][matchedCom];
            }
        }
        return empiricalCount;
    }

    /** Sums the target marginal of the weight's class times its feature value over all data. */
    private double calEmpiricalCountForFeature(int parameterIndex) {
        int classIndex = this.parameterToClass[parameterIndex];
        int featureIndex = this.parameterToFeature[parameterIndex];
        double empiricalCount = 0.0;
        if (featureIndex == -1) {
            for (int i = 0; i < this.numData; ++i) {
                empiricalCount += this.targetMarginals[i][classIndex];
            }
        } else {
            for (Vector.Element element : this.dataSet.getColumn(featureIndex).nonZeroes()) {
                int dataIndex = element.index();
                empiricalCount += element.get() * this.targetMarginals[dataIndex][classIndex];
            }
        }
        return empiricalCount;
    }

    /**
     * Builds the parameter index maps.
     *
     * <p>Every label pair {@code (l1, l2)} with {@code l1 < l2} owns 4 consecutive label-pair
     * weights. Each feature weight is mapped to its class and feature index (feature index
     * {@code -1} marks the bias). This assumes {@code numWeightsForLabelPairs} equals
     * {@code 4 * C(numClasses, 2)}; if it is smaller, an index error is thrown here.
     */
    private void mapParameters() {
        this.parameterToL1 = new int[this.numWeightsForLabelPairs];
        this.parameterToL2 = new int[this.numWeightsForLabelPairs];
        int start = 0;
        for (int l1 = 0; l1 < this.numClasses; ++l1) {
            for (int l2 = l1 + 1; l2 < this.numClasses; ++l2) {
                for (int offset = 0; offset < 4; ++offset) {
                    this.parameterToL1[start + offset] = l1;
                    this.parameterToL2[start + offset] = l2;
                }
                start += 4;
            }
        }
        this.parameterToClass = new int[this.numWeightsForFeatures];
        this.parameterToFeature = new int[this.numWeightsForFeatures];
        for (int i = 0; i < this.numWeightsForFeatures; ++i) {
            this.parameterToClass[i] = this.cmlcrf.getWeights().getClassIndex(i);
            this.parameterToFeature[i] = this.cmlcrf.getWeights().getFeatureIndex(i);
        }
    }

    /** Tabulates, for each supported combination, which class labels it contains. */
    private void initComContainsLabel() {
        this.comContainsLabel = new boolean[this.numSupport][this.numClasses];
        for (int num = 0; num < this.numSupport; ++num) {
            MultiLabel combination = this.supportedCombinations.get(num);
            for (int l = 0; l < this.numClasses; ++l) {
                this.comContainsLabel[num][l] = combination.matchClass(l);
            }
        }
    }

    /** Returns the loss (data term plus prior penalty), cached until the next {@link #setParameters}. */
    @Override
    public double getValue() {
        if (this.isValueCacheValid) {
            return this.value;
        }
        this.value = this.getValueForAllData() + this.getPenalty();
        this.isValueCacheValid = true;
        return this.value;
    }

    /** Refreshes the score matrices and sums the per-data-point loss terms. */
    private double getValueForAllData() {
        this.updateClassScoreMatrix();
        this.updateAssignmentScoreMatrix();
        return this.indexRange(0, this.numData).mapToDouble(this::getValueForOneData).sum();
    }

    /** {@code logSumExp(scores) - sum_j target_j * score_j} for data point {@code i}. */
    private double getValueForOneData(int i) {
        double[] scores = this.combScoreMatrix[i];
        double[] targetProbs = this.targetDistribution[i];
        double sum = MathUtil.logSumExp(scores);
        for (int j = 0; j < this.numSupport; ++j) {
            sum -= scores[j] * targetProbs[j];
        }
        return sum;
    }

    /** Returns the model's weight vector as reported by {@code cmlcrf.getWeights().getAllWeights()}. */
    @Override
    public Vector getParameters() {
        return this.cmlcrf.getWeights().getAllWeights();
    }

    /**
     * Installs new weights into the model and invalidates the cached value and gradient.
     *
     * <p>It then calls {@code cmlcrf.updateCombLabelPartScores()}, which presumably refreshes
     * the model's cached label-pair part of the combination scores (TODO confirm).
     */
    @Override
    public void setParameters(Vector parameters) {
        this.cmlcrf.getWeights().setWeightVector(parameters);
        this.isValueCacheValid = false;
        this.isGradientCacheValid = false;
        this.cmlcrf.updateCombLabelPartScores();
    }

    /**
     * Returns the Gaussian prior penalty {@code sum(w^2) / (2 * gaussianPriorVariance)}.
     *
     * <p>Non-bias feature weights are always included. Biases and label-pair weights are
     * included only when {@code regularizeAll} is set.
     */
    public double getPenalty() {
        double weightSquare = 0.0;
        for (int k = 0; k < this.numClasses; ++k) {
            Vector weightVector = this.cmlcrf.getWeights().getWeightsWithoutBiasForClass(k);
            weightSquare += weightVector.dot(weightVector);
        }
        if (this.regularizeAll) {
            for (int k = 0; k < this.numClasses; ++k) {
                double bias = this.cmlcrf.getWeights().getBiasForClass(k);
                weightSquare += bias * bias;
            }
            Vector labelPairVector = this.cmlcrf.getWeights().getAllLabelPairWeights();
            weightSquare += labelPairVector.dot(labelPairVector);
        }
        return weightSquare / (2.0 * this.gaussianPriorVariance);
    }

    /** Recomputes per-class scores for every data point. */
    private void updateClassScoreMatrix() {
        logger.debug("start updateClassScoreMatrix()");
        this.indexRange(0, this.numData).forEach(i ->
                this.classScoreMatrix[i] = this.cmlcrf.predictClassScores(this.dataSet.getRow(i)));
        logger.debug("finish updateClassScoreMatrix()");
    }

    /** Recomputes per-combination scores from the class scores. */
    private void updateAssignmentScoreMatrix() {
        logger.debug("start updateAssignmentScoreMatrix()");
        this.indexRange(0, this.numData).forEach(i ->
                this.combScoreMatrix[i] = this.cmlcrf.predictCombinationScores(this.classScoreMatrix[i]));
        logger.debug("finish updateAssignmentScoreMatrix()");
    }

    /** Converts combination scores into combination probabilities. */
    private void updateAssignmentProbMatrix() {
        logger.debug("start updateAssignmentProbMatrix()");
        this.indexRange(0, this.numData).forEach(i ->
                this.combProbMatrix[i] = this.cmlcrf.predictCombinationProbs(this.combScoreMatrix[i]));
        logger.debug("finish updateAssignmentProbMatrix()");
    }

    /** Derives per-class probabilities from the combination probabilities. */
    private void updateClassProbMatrix() {
        logger.debug("start updateClassProbMatrix()");
        this.indexRange(0, this.numData).forEach(i ->
                this.classProbMatrix[i] = this.cmlcrf.calClassProbs(this.combProbMatrix[i]));
        logger.debug("finish updateClassProbMatrix()");
    }

    /** For each label-pair weight, collects the supported combinations that fire it. */
    private void mapPairToCombination() {
        this.labelPairToCombination = new ArrayList<>(this.numWeightsForLabelPairs);
        for (int i = 0; i < this.numWeightsForLabelPairs; ++i) {
            this.labelPairToCombination.add(new ArrayList<>());
        }
        // Each task writes only to its own pre-allocated list, so the parallel fill is race-free.
        this.indexRange(0, this.numWeightsForLabelPairs).forEach(this::collectCombinationsForPair);
    }

    /**
     * Fills the combination list of one label-pair weight.
     *
     * <p>The weight's offset within its group of 4 encodes the pair configuration it models.
     * Bit 0 means {@code l1} is present and bit 1 means {@code l2} is present, so the offsets
     * are: 0 = neither, 1 = l1 only, 2 = l2 only, 3 = both.
     */
    private void collectCombinationsForPair(int position) {
        List<Integer> matched = this.labelPairToCombination.get(position);
        int l1 = this.parameterToL1[position];
        int l2 = this.parameterToL2[position];
        int featureCase = position % 4;
        boolean wantL1 = (featureCase & 1) != 0;
        boolean wantL2 = (featureCase & 2) != 0;
        for (int c = 0; c < this.numSupport; ++c) {
            if (this.comContainsLabel[c][l1] == wantL1 && this.comContainsLabel[c][l2] == wantL2) {
                matched.add(c);
            }
        }
    }

    /** Sums the model probability of one combination over all data points. */
    private void updateCombProbSum(int combinationIndex) {
        double sum = 0.0;
        for (int i = 0; i < this.numData; ++i) {
            sum += this.combProbMatrix[i][combinationIndex];
        }
        this.combProbSums[combinationIndex] = sum;
    }

    /** Refreshes {@code combProbSums} for every supported combination. */
    private void updateCombProbSums() {
        this.indexRange(0, this.numSupport).forEach(this::updateCombProbSum);
    }

    /** Computes per-class target marginals by summing the target mass of combinations containing each label. */
    private void initTargetMarginals() {
        this.targetMarginals = new double[this.numData][this.numClasses];
        this.indexRange(0, this.numData).forEach(this::accumulateTargetMarginals);
    }

    /** Accumulates the target marginals of a single data point. */
    private void accumulateTargetMarginals(int dataPoint) {
        double[] joint = this.targetDistribution[dataPoint];
        double[] marginals = this.targetMarginals[dataPoint];
        for (int c = 0; c < joint.length; ++c) {
            double prob = joint[c];
            for (int l : this.supportedCombinations.get(c).getMatchedLabels()) {
                marginals[l] += prob;
            }
        }
    }

    /**
     * Checks that the target matrix covers every data point with one entry per supported combination.
     *
     * @throws IllegalArgumentException if the shape does not match
     */
    private void validateTargetDistribution() {
        if (this.targetDistribution == null) {
            throw new IllegalArgumentException("targetDistribution must not be null");
        }
        if (this.targetDistribution.length < this.numData) {
            throw new IllegalArgumentException("targetDistribution has " + this.targetDistribution.length
                    + " rows but the data set has " + this.numData + " data points");
        }
        for (int i = 0; i < this.numData; ++i) {
            double[] row = this.targetDistribution[i];
            if (row == null || row.length != this.numSupport) {
                throw new IllegalArgumentException("targetDistribution row " + i + " must have length "
                        + this.numSupport + " (number of supported combinations) but has "
                        + (row == null ? "null" : String.valueOf(row.length)));
            }
        }
    }

    /** Returns {@code [from, to)} as an IntStream, parallel when {@code isParallel} is set. */
    private IntStream indexRange(int from, int to) {
        IntStream stream = IntStream.range(from, to);
        return this.isParallel ? stream.parallel() : stream;
    }
}

