/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.crf;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF;
import edu.neu.ccs.pyramid.optimization.Terminator;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.mahout.math.Vector;

/**
 * Coordinate-descent trainer for a {@link CMLCRF} model with an elastic-net penalty.
 *
 * <p>The model's weights are updated in place through {@code cmlcrf.getWeights().getAllWeights()}.
 * The first {@code numWeightsForFeatures} weights are class/feature weights; the remaining ones
 * are label-pair weights, four per unordered class pair {@code (l1, l2)} with {@code l1 < l2},
 * one for each membership pattern: 0 = neither class, 1 = only l1, 2 = only l2, 3 = both.
 *
 * <p>Each coordinate update replaces one weight by the exact minimizer of a one-dimensional
 * quadratic model of {@code scale * NLL + regularization * (l1Ratio * |w| + (1 - l1Ratio) / 2 * w^2)},
 * where the penalty part is the same one reported by {@link #getValue()}.
 */
public class BlockwiseCD {
    private final CMLCRF cmlcrf;
    private final List<MultiLabel> supportedCombinations;
    private final int numSupport;
    private final MultiLabelClfDataSet dataSet;
    /** Overall penalty strength. */
    private final double regularization;
    /** Fraction of the penalty that is L1 (0 = pure L2, 1 = pure L1). */
    private final double l1Ratio;
    private final int numClasses;
    private final int numParameters;
    private final int numWeightsForFeatures;
    private final int numWeightsForLabelPairs;
    /** Last value computed by {@link #getValue()}. */
    private double value;
    /** Observed (empirical) count of every parameter's feature over the data set. */
    private final double[] empiricalCounts;
    /** First class of each label-pair parameter, indexed relative to the label-pair block. */
    private int[] parameterToL1;
    /** Second class of each label-pair parameter, indexed relative to the label-pair block. */
    private int[] parameterToL2;
    /** Class of each class/feature parameter. */
    private int[] parameterToClass;
    /** Feature of each class/feature parameter (-1 for the intercept term). */
    private int[] parameterToFeature;
    /** comContainsLabel[c][l] is true iff supported combination c contains class l. */
    private boolean[][] comContainsLabel;
    private boolean isParallel = true;
    private final double[][] classScoreMatrix;
    private final double[][] classProbMatrix;
    private final double[][] combProbMatrix;
    private final double[][] combScoreMatrix;
    private final int numData;
    /** For each label-pair parameter, the supported combinations whose membership pattern matches it. */
    private final List<List<Integer>> labelPairToCombination;
    /** For each supported combination, its probability summed over all data points. */
    private final double[] combProbSums;
    /** Index (into supportedCombinations) of each data point's observed label combination. */
    private final int[] labelComIndices;
    private final Terminator terminator;
    private final int numFeatures;
    /** Row of the first label-pair parameter of each class pair (l1 < l2), relative to the label-pair block. */
    private int[][] labelPairToParams;
    /** Scale applied to the data term in {@link #iterate()}: 1 / numData. */
    private final double weight;

    /**
     * Creates a trainer with a pure L2 penalty ({@code l1Ratio = 0}) of strength 1.
     *
     * @param cmlcrf  model whose weights are updated in place
     * @param dataSet training data; every label combination must be a supported combination of the model
     */
    public BlockwiseCD(CMLCRF cmlcrf, MultiLabelClfDataSet dataSet) {
        this(cmlcrf, dataSet, 0.0, 1.0);
    }

    /**
     * Creates a trainer with an elastic-net penalty.
     *
     * @param cmlcrf         model whose weights are updated in place
     * @param dataSet        training data; every label combination must be a supported combination of the model
     * @param l1Ratio        fraction of the penalty that is L1
     * @param regularization overall penalty strength
     * @throws IllegalArgumentException if a data point's label combination is not supported by the model
     */
    public BlockwiseCD(CMLCRF cmlcrf, MultiLabelClfDataSet dataSet, double l1Ratio, double regularization) {
        this.cmlcrf = cmlcrf;
        this.supportedCombinations = cmlcrf.getSupportCombinations();
        this.numSupport = cmlcrf.getNumSupports();
        this.dataSet = dataSet;
        this.numData = dataSet.getNumDataPoints();
        this.numClasses = dataSet.getNumClasses();
        this.numFeatures = dataSet.getNumFeatures();
        this.l1Ratio = l1Ratio;
        this.regularization = regularization;
        this.numParameters = cmlcrf.getWeights().totalSize();
        this.numWeightsForFeatures = cmlcrf.getWeights().getNumWeightsForFeatures();
        this.numWeightsForLabelPairs = cmlcrf.getWeights().getNumWeightsForLabels();
        this.weight = 1.0 / (double) this.numData;
        this.classScoreMatrix = new double[this.numData][this.numClasses];
        this.classProbMatrix = new double[this.numData][this.numClasses];
        this.combScoreMatrix = new double[this.numData][this.numSupport];
        this.combProbMatrix = new double[this.numData][this.numSupport];
        this.combProbSums = new double[this.numSupport];
        this.empiricalCounts = new double[this.numParameters];
        this.terminator = new Terminator();

        // Resolve every data point's observed combination first, so unsupported labels fail fast
        // with a clear message instead of an unboxing NullPointerException.
        HashMap<MultiLabel, Integer> combinationIndex = new HashMap<>();
        for (int s = 0; s < this.numSupport; ++s) {
            combinationIndex.put(this.supportedCombinations.get(s), s);
        }
        MultiLabel[] multiLabels = dataSet.getMultiLabels();
        this.labelComIndices = new int[this.numData];
        for (int i = 0; i < this.numData; ++i) {
            Integer index = combinationIndex.get(multiLabels[i]);
            if (index == null) {
                throw new IllegalArgumentException("data point " + i + " has label combination "
                        + multiLabels[i] + ", which is not a supported combination of the model");
            }
            this.labelComIndices[i] = index;
        }

        this.initCache();
        this.updateEmpiricalCounts();
        this.labelPairToCombination = new ArrayList<>(this.numWeightsForLabelPairs);
        for (int i = 0; i < this.numWeightsForLabelPairs; ++i) {
            this.labelPairToCombination.add(new ArrayList<>());
        }
        this.mapPairToCombination();
    }

    /**
     * Runs full sweeps via {@link #iterate()} until the terminator says to stop, printing
     * {@code "<iteration>: <loss>"} to standard output after every sweep.
     */
    public void optimize() {
        for (int iter = 1; ; ++iter) {
            this.iterate();
            double loss = this.getValue();
            this.terminator.add(loss);
            System.out.println(iter + ": " + loss);
            if (this.terminator.shouldTerminate()) {
                return;
            }
        }
    }

    /**
     * Performs one sweep over all parameters: every class/feature weight, then every label-pair
     * weight. The caches are recomputed before each individual update, so each coordinate step
     * sees the effect of all previous steps.
     */
    public void iterate() {
        for (int i = 0; i < this.numWeightsForFeatures; ++i) {
            this.refreshCaches();
            this.iterateForF(i);
        }
        for (int i = this.numWeightsForFeatures; i < this.numParameters; ++i) {
            this.refreshCaches();
            this.iterateForLF(i);
        }
    }

    /** Recomputes all score/probability caches and combination sums from the current weights. */
    private void refreshCaches() {
        this.updateClassScoreMatrix();
        this.updateAssignmentScoreMatrix();
        this.updateAssignmentProbMatrix();
        this.updateCombProbSums();
        this.updateClassProbMatrix();
    }

    /** Updates one label-pair weight; {@code parameterIndex} is an absolute weight index. */
    private void iterateForLF(int parameterIndex) {
        int pos = parameterIndex - this.numWeightsForFeatures;
        double gradient = this.calGradientForLabelPair(pos);
        double hessian = this.calHessiansForLabelPair(pos);
        this.updateCoordinate(parameterIndex, gradient, hessian, this.weight);
    }

    /** Updates one class/feature weight; {@code parameterIndex} is an absolute weight index. */
    private void iterateForF(int parameterIndex) {
        int classIndex = this.parameterToClass[parameterIndex];
        int featureIndex = this.parameterToFeature[parameterIndex];
        double gradient = this.calGradientForFeature(classIndex, featureIndex, parameterIndex);
        double hessian = this.calHessiansForFeature(classIndex, featureIndex);
        this.updateCoordinate(parameterIndex, gradient, hessian, this.weight);
    }

    /**
     * Block update for class {@code l}: refreshes the caches once, then updates every
     * class/feature weight of {@code l} (intercept first) and every label-pair weight of
     * {@code (l, l2)} with {@code l2 > l}, reusing the same cached probabilities for the whole block.
     * Not called from within this class. Unlike {@link #iterate()}, the data term is not scaled
     * by {@code weight}.
     */
    private void iterateForL(int l) {
        this.refreshCaches();
        for (int m = -1; m < this.numFeatures; ++m) {
            // NOTE(review): assumes class-major layout with the intercept first, i.e.
            // index = l * (numFeatures + 1) + (m + 1) — confirm against the weight layout.
            int parameterIndex = l * (this.numFeatures + 1) + m + 1;
            double gradient = this.calGradientForFeature(l, m, parameterIndex);
            double hessian = this.calHessiansForFeature(l, m);
            this.updateCoordinate(parameterIndex, gradient, hessian, 1.0);
        }
        for (int l2 = l + 1; l2 < this.numClasses; ++l2) {
            int first = this.labelPairToParams[l][l2];
            for (int pos = first; pos < first + 4; ++pos) {
                double gradient = this.calGradientForLabelPair(pos);
                double hessian = this.calHessiansForLabelPair(pos);
                this.updateCoordinate(pos + this.numWeightsForFeatures, gradient, hessian, 1.0);
            }
        }
    }

    /**
     * Replaces one weight by the minimizer of the penalized one-dimensional quadratic model
     * {@code scale * (-gradient * d + curvature * d^2 / 2) + l1 * |w + d| + l2 / 2 * (w + d)^2}.
     *
     * @param parameterIndex absolute index into the model's weight vector
     * @param gradient       empirical minus expected count (gradient of the log-likelihood)
     * @param hessian        value returned by a {@code calHessiansFor*} method, i.e. the
     *                       non-positive second derivative of the log-likelihood
     * @param scale          factor applied to the data term
     */
    private void updateCoordinate(int parameterIndex, double gradient, double hessian, double scale) {
        // The hessian helpers return sum(p^2 - p) <= 0. The loss curvature is its negation;
        // using it keeps the L2 term in the denominator with the correct (positive) sign.
        // With l2 = 0 this reproduces the previous update exactly.
        double curvature = -hessian;
        double oldCoeff = this.cmlcrf.getWeights().getWeightForIndex(parameterIndex);
        double numerator = this.softThreshold(scale * (curvature * oldCoeff + gradient));
        double denominator = scale * curvature + this.regularization * (1.0 - this.l1Ratio);
        double newCoeff = denominator != 0.0 ? numerator / denominator : 0.0;
        this.cmlcrf.getWeights().getAllWeights().set(parameterIndex, newCoeff);
    }

    /** Soft-thresholding operator: shrinks {@code z} toward zero by {@code gamma}, clipping at zero. */
    private static double softThreshold(double z, double gamma) {
        if (z > 0.0 && gamma < Math.abs(z)) {
            return z - gamma;
        }
        if (z < 0.0 && gamma < Math.abs(z)) {
            return z + gamma;
        }
        return 0.0;
    }

    /** Soft-thresholds {@code z} by this trainer's L1 strength {@code regularization * l1Ratio}. */
    private double softThreshold(double z) {
        return BlockwiseCD.softThreshold(z, this.regularization * this.l1Ratio);
    }

    /**
     * Gradient of the log-likelihood for label-pair parameter {@code pos} (relative index): its
     * empirical count minus the summed probability of the combinations matching its pattern.
     */
    private double calGradientForLabelPair(int pos) {
        double expected = 0.0;
        for (int matched : this.labelPairToCombination.get(pos)) {
            expected += this.combProbSums[matched];
        }
        return this.empiricalCounts[pos + this.numWeightsForFeatures] - expected;
    }

    /**
     * Second derivative of the log-likelihood for label-pair parameter {@code pos} (relative
     * index): {@code sum_i (q_i^2 - q_i)}, where {@code q_i} is data point i's total probability on
     * the combinations matching the parameter's pattern. The result is non-positive.
     */
    private double calHessiansForLabelPair(int pos) {
        List<Integer> matchedCombinations = this.labelPairToCombination.get(pos);
        double hessian = 0.0;
        for (int i = 0; i < this.numData; ++i) {
            double[] probs = this.combProbMatrix[i];
            double matchedProb = 0.0;
            for (int matched : matchedCombinations) {
                matchedProb += probs[matched];
            }
            hessian += matchedProb * matchedProb - matchedProb;
        }
        return hessian;
    }

    /**
     * Second derivative of the log-likelihood for the weight of class {@code l} on feature
     * {@code m} ({@code m == -1} is the intercept): {@code sum_i x_im^2 (p_il^2 - p_il)}, non-positive.
     */
    private double calHessiansForFeature(int l, int m) {
        double hessian = 0.0;
        if (m == -1) {
            for (int i = 0; i < this.numData; ++i) {
                double p = this.classProbMatrix[i][l];
                hessian += p * p - p;
            }
        } else {
            for (Vector.Element element : this.dataSet.getColumn(m).nonZeroes()) {
                double p = this.classProbMatrix[element.index()][l];
                double x = element.get();
                hessian += x * x * (p * p - p);
            }
        }
        return hessian;
    }

    /**
     * Gradient of the log-likelihood for the weight of class {@code l} on feature {@code m}
     * ({@code m == -1} is the intercept): empirical count minus {@code sum_i x_im * p_il}.
     */
    private double calGradientForFeature(int l, int m, int parameterIndex) {
        double expected = 0.0;
        if (m == -1) {
            for (int i = 0; i < this.numData; ++i) {
                expected += this.classProbMatrix[i][l];
            }
        } else {
            for (Vector.Element element : this.dataSet.getColumn(m).nonZeroes()) {
                expected += this.classProbMatrix[element.index()][l] * element.get();
            }
        }
        return this.empiricalCounts[parameterIndex] - expected;
    }

    /** Returns {@code [0, size)}, parallel when {@code isParallel} is set. */
    private IntStream indices(int size) {
        IntStream range = IntStream.range(0, size);
        return this.isParallel ? range.parallel() : range;
    }

    /** Fills {@link #empiricalCounts}; each task writes a distinct array slot. */
    private void updateEmpiricalCounts() {
        this.indices(this.numParameters).forEach(this::calEmpiricalCount);
    }

    /** Computes the empirical count of one parameter; indices past both blocks are left at 0. */
    private void calEmpiricalCount(int parameterIndex) {
        if (parameterIndex < this.numWeightsForFeatures) {
            this.empiricalCounts[parameterIndex] = this.calEmpiricalCountForFeature(parameterIndex);
        } else if (parameterIndex < this.numWeightsForFeatures + this.numWeightsForLabelPairs) {
            this.empiricalCounts[parameterIndex] = this.calEmpiricalCountForLabelPair(parameterIndex);
        }
    }

    /**
     * Tests whether a membership pattern matches label-pair case {@code featureCase}:
     * 0 = neither class, 1 = only l1, 2 = only l2, 3 = both.
     */
    private static boolean matchesPairCase(int featureCase, boolean hasL1, boolean hasL2) {
        switch (featureCase) {
            case 0:
                return !hasL1 && !hasL2;
            case 1:
                return hasL1 && !hasL2;
            case 2:
                return !hasL1 && hasL2;
            case 3:
                return hasL1 && hasL2;
            default:
                throw new RuntimeException("feature case :" + featureCase + " failed.");
        }
    }

    /** Number of data points whose labels match the pattern of label-pair parameter {@code parameterIndex} (absolute index). */
    private double calEmpiricalCountForLabelPair(int parameterIndex) {
        int start = parameterIndex - this.numWeightsForFeatures;
        int l1 = this.parameterToL1[start];
        int l2 = this.parameterToL2[start];
        int featureCase = start % 4;
        MultiLabel[] multiLabels = this.dataSet.getMultiLabels();
        double empiricalCount = 0.0;
        for (int i = 0; i < this.numData; ++i) {
            MultiLabel label = multiLabels[i];
            if (matchesPairCase(featureCase, label.matchClass(l1), label.matchClass(l2))) {
                empiricalCount += 1.0;
            }
        }
        return empiricalCount;
    }

    /**
     * Empirical count of a class/feature parameter: the number of data points carrying its class
     * (intercept), or the sum of the feature's values over those data points.
     */
    private double calEmpiricalCountForFeature(int parameterIndex) {
        int classIndex = this.parameterToClass[parameterIndex];
        int featureIndex = this.parameterToFeature[parameterIndex];
        MultiLabel[] multiLabels = this.dataSet.getMultiLabels();
        double empiricalCount = 0.0;
        if (featureIndex == -1) {
            for (int i = 0; i < this.numData; ++i) {
                if (multiLabels[i].matchClass(classIndex)) {
                    empiricalCount += 1.0;
                }
            }
        } else {
            for (Vector.Element element : this.dataSet.getColumn(featureIndex).nonZeroes()) {
                if (multiLabels[element.index()].matchClass(classIndex)) {
                    empiricalCount += element.get();
                }
            }
        }
        return empiricalCount;
    }

    /** Builds the parameter-to-class/feature/label-pair lookup tables and the combination membership table. */
    private void initCache() {
        this.parameterToL1 = new int[this.numWeightsForLabelPairs];
        this.parameterToL2 = new int[this.numWeightsForLabelPairs];
        this.labelPairToParams = new int[this.numClasses][this.numClasses];
        int start = 0;
        for (int l1 = 0; l1 < this.numClasses; ++l1) {
            for (int l2 = l1 + 1; l2 < this.numClasses; ++l2) {
                // Four consecutive parameters per class pair, one per membership pattern.
                for (int k = 0; k < 4; ++k) {
                    this.parameterToL1[start + k] = l1;
                    this.parameterToL2[start + k] = l2;
                }
                this.labelPairToParams[l1][l2] = start;
                start += 4;
            }
        }
        this.parameterToClass = new int[this.numWeightsForFeatures];
        this.parameterToFeature = new int[this.numWeightsForFeatures];
        for (int i = 0; i < this.numWeightsForFeatures; ++i) {
            this.parameterToClass[i] = this.cmlcrf.getWeights().getClassIndex(i);
            this.parameterToFeature[i] = this.cmlcrf.getWeights().getFeatureIndex(i);
        }
        this.comContainsLabel = new boolean[this.numSupport][this.numClasses];
        for (int c = 0; c < this.numSupport; ++c) {
            MultiLabel combination = this.supportedCombinations.get(c);
            for (int l = 0; l < this.numClasses; ++l) {
                this.comContainsLabel[c][l] = combination.matchClass(l);
            }
        }
    }

    /**
     * Computes, caches and returns the objective: the negative log-likelihood summed over all
     * data points plus the elastic-net penalty.
     *
     * <p>NOTE(review): {@link #iterate()} minimizes {@code weight * NLL + penalty} with
     * {@code weight = 1 / numData}, so this value weights the data term numData times more than
     * the objective actually being optimized — confirm which scaling is intended.
     */
    public double getValue() {
        this.value = this.getValueForAllData() + this.getPenalty();
        return this.value;
    }

    /** {@code regularization * ((1 - l1Ratio) / 2 * ||w||_2^2 + l1Ratio * ||w||_1)}. */
    private double getPenalty() {
        Vector vector = this.cmlcrf.getWeights().getAllWeights();
        double norm = (1.0 - this.l1Ratio) * 0.5 * Math.pow(vector.norm(2.0), 2.0) + this.l1Ratio * vector.norm(1.0);
        return norm * this.regularization;
    }

    /** Refreshes the score caches and sums the per-data-point negative log-likelihood. */
    private double getValueForAllData() {
        this.updateClassScoreMatrix();
        this.updateAssignmentScoreMatrix();
        return this.indices(this.numData).mapToDouble(this::getValueForOneData).sum();
    }

    /** Negative log-likelihood of data point {@code i}: log-sum-exp of combination scores minus the observed one. */
    private double getValueForOneData(int i) {
        double[] scores = this.combScoreMatrix[i];
        return MathUtil.logSumExp(scores) - scores[this.labelComIndices[i]];
    }

    /** Returns the model's weight vector (the same object this trainer updates). */
    public Vector getParameters() {
        return this.cmlcrf.getWeights().getAllWeights();
    }

    private void updateClassScoreMatrix() {
        this.indices(this.numData).forEach(i ->
                this.classScoreMatrix[i] = this.cmlcrf.predictClassScores(this.dataSet.getRow(i)));
    }

    private void updateAssignmentScoreMatrix() {
        this.indices(this.numData).forEach(i ->
                this.combScoreMatrix[i] = this.cmlcrf.predictCombinationScores(this.classScoreMatrix[i]));
    }

    private void updateAssignmentProbMatrix() {
        this.indices(this.numData).forEach(i ->
                this.combProbMatrix[i] = this.cmlcrf.predictCombinationProbs(this.combScoreMatrix[i]));
    }

    private void updateClassProbMatrix() {
        this.indices(this.numData).forEach(i ->
                this.classProbMatrix[i] = this.cmlcrf.calClassProbs(this.combProbMatrix[i]));
    }

    /** Fills {@link #labelPairToCombination}; each task writes only its own list. */
    private void mapPairToCombination() {
        this.indices(this.numWeightsForLabelPairs).forEach(this::mapPairToCombination);
    }

    /** Records which supported combinations match the pattern of label-pair parameter {@code position}. */
    private void mapPairToCombination(int position) {
        List<Integer> list = this.labelPairToCombination.get(position);
        int l1 = this.parameterToL1[position];
        int l2 = this.parameterToL2[position];
        int featureCase = position % 4;
        for (int c = 0; c < this.numSupport; ++c) {
            if (matchesPairCase(featureCase, this.comContainsLabel[c][l1], this.comContainsLabel[c][l2])) {
                list.add(c);
            }
        }
    }

    /** Sums combination {@code combinationIndex}'s probability over all data points. */
    private void updateCombProbSums(int combinationIndex) {
        double sum = 0.0;
        for (int i = 0; i < this.numData; ++i) {
            sum += this.combProbMatrix[i][combinationIndex];
        }
        this.combProbSums[combinationIndex] = sum;
    }

    private void updateCombProbSums() {
        this.indices(this.numSupport).forEach(this::updateCombProbSums);
    }

    /** Returns the convergence tracker fed by {@link #optimize()}. */
    public Terminator getTerminator() {
        return this.terminator;
    }
}

