/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.crf;

import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.dataset.SequentialSparseDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF;
import edu.neu.ccs.pyramid.multilabel_classification.crf.CRFElasticNetLinearRegOptimizer;
import edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLinearRegression;
import edu.neu.ccs.pyramid.optimization.Terminator;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;

/**
 * Trains a {@link CMLCRF} under an elastic-net penalty.
 *
 * <p>Each call to {@link #iterate()}:
 * <ol>
 *   <li>refreshes the per-instance class/combination score and probability caches for the current
 *       weights, and the predicted counts derived from them;</li>
 *   <li>for every supported label combination, fits an elastic-net weighted linear regression
 *       ({@link CRFElasticNetLinearRegOptimizer}) on an expanded copy of the data set using an
 *       IRLS-style working response, sums the resulting weight vectors, and resets the CRF weights
 *       after each fit;</li>
 *   <li>uses the summed vector as a search direction for a backtracking line search on
 *       {@link #getValue()}.</li>
 * </ol>
 *
 * <p>The objective is the sum over data points of
 * {@code logSumExp(combinationScores) - score(trueCombination)} plus
 * {@code regularization * ((1 - l1Ratio) / 2 * ||w||_2^2 + l1Ratio * ||w||_1)}.
 *
 * <p>The instance keeps mutable caches and modifies the wrapped CRF's weights in place without any
 * synchronization.
 */
public class CMLCRFElasticNet {
    /** Indicator parameters per label pair: (neither, only l1, only l2, both) present. */
    private static final int NUM_PAIR_CASES = 4;
    /** Maximum number of step-length halvings before the line search gives up. */
    private static final int MAX_LINE_SEARCH_STEPS = 100;

    private final Terminator terminator;
    private final CMLCRF cmlcrf;
    private final List<MultiLabel> supportedCombinations;
    private final int numSupport;
    private final MultiLabelClfDataSet dataSet;
    private final int numClasses;
    private final int numParameters;
    private final int numWeightsForFeatures;
    private final int numWeightsForLabelPairs;
    private double value;
    private final Vector empiricalCounts;
    private final Vector predictedCounts;
    private int[] parameterToL1;
    private int[] parameterToL2;
    private int[] parameterToClass;
    private int[] parameterToFeature;
    private boolean[][] comContainsLabel;
    private boolean isParallel = true;
    private boolean isValueCacheValid = false;
    private final double[][] classScoreMatrix;
    private final double[][] classProbMatrix;
    private final double[][] combProbMatrix;
    private final double[][] combScoreMatrix;
    private final int numData;
    /** For each data point, the index of its label combination in {@link #supportedCombinations}. */
    private final int[] labelComIndices;
    private final double l1Ratio;
    private final double regularization;
    private final int numFeature;
    /** For each supported combination, the label-pair parameter positions that are active for it. */
    private final List<List<Integer>> combinationToLabelPair;
    /** For each label-pair parameter position, the supported combinations for which it is active. */
    private final List<List<Integer>> labelPairToCombination;
    /** For each supported combination, the sum of its predicted probability over all data points. */
    private final double[] combProbSums;

    /**
     * Creates a trainer for {@code cmlcrf} on {@code dataSet}.
     *
     * @param cmlcrf the CRF whose weights are optimized in place
     * @param dataSet training data; each data point's label combination must be one of
     *     {@code cmlcrf.getSupportCombinations()}
     * @param l1Ratio weight of the L1 term in the penalty (the L2 term is weighted by {@code 1 - l1Ratio})
     * @param regularization overall multiplier of the penalty
     * @throws IllegalArgumentException if a data point's label combination is not supported
     */
    public CMLCRFElasticNet(CMLCRF cmlcrf, MultiLabelClfDataSet dataSet, double l1Ratio, double regularization) {
        this.l1Ratio = l1Ratio;
        this.regularization = regularization;
        this.terminator = new Terminator();
        this.terminator.setGoal(Terminator.Goal.MINIMIZE);
        this.numFeature = dataSet.getNumFeatures();
        this.cmlcrf = cmlcrf;
        this.supportedCombinations = cmlcrf.getSupportCombinations();
        this.numSupport = cmlcrf.getNumSupports();
        this.dataSet = dataSet;
        this.numData = dataSet.getNumDataPoints();
        this.numClasses = dataSet.getNumClasses();
        this.numParameters = cmlcrf.getWeights().totalSize();
        this.numWeightsForFeatures = cmlcrf.getWeights().getNumWeightsForFeatures();
        this.numWeightsForLabelPairs = cmlcrf.getWeights().getNumWeightsForLabels();
        this.classScoreMatrix = new double[this.numData][this.numClasses];
        this.classProbMatrix = new double[this.numData][this.numClasses];
        this.combScoreMatrix = new double[this.numData][this.numSupport];
        this.combProbMatrix = new double[this.numData][this.numSupport];
        this.isValueCacheValid = false;
        this.empiricalCounts = new DenseVector(this.numParameters);
        this.predictedCounts = new DenseVector(this.numParameters);
        this.initCache();
        this.updateEmpiricalCounts();
        this.combinationToLabelPair = new ArrayList<>(this.numSupport);
        for (int s = 0; s < this.numSupport; ++s) {
            this.combinationToLabelPair.add(new ArrayList<>());
        }
        this.labelPairToCombination = new ArrayList<>(this.numWeightsForLabelPairs);
        for (int position = 0; position < this.numWeightsForLabelPairs; ++position) {
            this.labelPairToCombination.add(new ArrayList<>());
        }
        this.mapCombinationToPair();
        this.mapPairToCombination();
        this.combProbSums = new double[this.numSupport];
        this.labelComIndices = this.indexTrueCombinations();
    }

    /** Runs {@link #iterate()} until the terminator reports convergence. */
    public void optimize() {
        do {
            this.iterate();
        } while (!this.terminator.shouldTerminate());
    }

    /**
     * Performs one optimization step: refresh caches, fit one elastic-net regression per supported
     * combination, then line-search along the sum of the fitted weight vectors.
     */
    public void iterate() {
        this.updateClassScoreMatrix();
        this.cmlcrf.updateCombLabelPartScores();
        this.updateAssignmentScoreMatrix();
        this.updateAssignmentProbMatrix();
        this.updateCombProbSums();
        // classProbMatrix feeds calPredictedFeatureCounts, so it must be refreshed first.
        this.updateClassProbMatrix();
        this.updatePredictedCounts();

        Vector accumulateWeights = new SequentialAccessSparseVector(this.numParameters);
        Vector oldWeights = this.cmlcrf.getWeights().deepCopy().getAllWeights();
        for (int l = 0; l < this.numSupport; ++l) {
            DataSet newData = this.expandData(l);
            this.iterateForOneComb(newData, l);
            accumulateWeights = accumulateWeights.plus(this.cmlcrf.getWeights().getAllWeights());
            // Each combination's regression starts from the same pre-iteration weights.
            // NOTE(review): if setWeightVector stores the reference rather than copying, the next
            // regression may mutate oldWeights in place — confirm against the weights class.
            this.cmlcrf.getWeights().setWeightVector(oldWeights);
        }
        // NOTE(review): the gradient is averaged over data points while getValue() sums over them;
        // confirm the intended scaling of the sufficient-decrease test.
        Vector gradient = this.predictedCounts.minus(this.empiricalCounts).divide((double) this.numData);
        this.lineSearch(accumulateWeights, gradient);
        this.terminator.add(this.getValue());
    }

    /**
     * Builds the regression design matrix for supported combination {@code l}. For each class
     * {@code y} in the combination, column {@code (numFeature + 1) * y} holds 1.0 and the following
     * columns hold the instance's features (shifted by one). Each label-pair parameter active for
     * this combination gets a 1.0 in column {@code numWeightsForFeatures + position}.
     */
    private DataSet expandData(int l) {
        SequentialSparseDataSet newData = new SequentialSparseDataSet(this.numData, this.numParameters, false);
        MultiLabel combination = this.supportedCombinations.get(l);
        List<Integer> labelPairs = this.combinationToLabelPair.get(l);
        int blockSize = this.numFeature + 1;
        for (int i = 0; i < this.numData; ++i) {
            Vector row = this.dataSet.getRow(i);
            for (int y : combination.getMatchedLabels()) {
                int offset = blockSize * y;
                newData.setFeatureValue(i, offset, 1.0);
                for (Vector.Element element : row.nonZeroes()) {
                    newData.setFeatureValue(i, offset + element.index() + 1, element.get());
                }
            }
            for (int position : labelPairs) {
                newData.setFeatureValue(i, this.numWeightsForFeatures + position, 1.0);
            }
        }
        return newData;
    }

    /**
     * Fits the elastic-net weighted linear regression for combination {@code l}. The response is
     * {@code score + clip((indicator - p) / (p * (1 - p)), -1, 1)} and the instance weight is
     * {@code p * (1 - p)}, where {@code p} is the instance's predicted probability of combination l.
     */
    private void iterateForOneComb(DataSet newData, int l) {
        double[] realLabels = new double[this.numData];
        double[] instanceWeights = new double[this.numData];
        this.indexStream(0, this.numData).forEach(i -> {
            double prob = this.combProbMatrix[i][l];
            double combScore = this.combScoreMatrix[i][l];
            double indicator = this.labelComIndices[i] == l ? 1.0 : 0.0;
            double variance = prob * (1.0 - prob);
            double frac = 0.0;
            // Avoid dividing by zero when the model is certain.
            if (prob != 0.0 && prob != 1.0) {
                frac = (indicator - prob) / variance;
            }
            frac = Math.max(-1.0, Math.min(1.0, frac));
            realLabels[i] = combScore + frac;
            instanceWeights[i] = variance;
        });
        CRFLinearRegression linearRegression = new CRFLinearRegression(this.numParameters, this.cmlcrf.getWeights().getAllWeights());
        CRFElasticNetLinearRegOptimizer linearRegTrainer = new CRFElasticNetLinearRegOptimizer(linearRegression, newData, realLabels, instanceWeights);
        linearRegTrainer.setRegularization(this.regularization);
        linearRegTrainer.setL1Ratio(this.l1Ratio);
        linearRegTrainer.optimize();
        this.isValueCacheValid = false;
    }

    private void updateClassScoreMatrix() {
        this.indexStream(0, this.numData).forEach(i ->
                this.classScoreMatrix[i] = this.cmlcrf.predictClassScores(this.dataSet.getRow(i)));
    }

    private void updateAssignmentScoreMatrix() {
        this.indexStream(0, this.numData).forEach(i ->
                this.combScoreMatrix[i] = this.cmlcrf.predictCombinationScores(this.classScoreMatrix[i]));
    }

    private void updateAssignmentProbMatrix() {
        this.indexStream(0, this.numData).forEach(i ->
                this.combProbMatrix[i] = this.cmlcrf.predictCombinationProbs(this.combScoreMatrix[i]));
    }

    /**
     * Precomputes index lookups. The tables map each label-pair position to its two classes and each
     * feature parameter to its class and feature index (-1 for the bias). A matrix records which
     * supported combinations contain which classes.
     */
    private void initCache() {
        this.parameterToL1 = new int[this.numWeightsForLabelPairs];
        this.parameterToL2 = new int[this.numWeightsForLabelPairs];
        int start = 0;
        for (int l1 = 0; l1 < this.numClasses; ++l1) {
            for (int l2 = l1 + 1; l2 < this.numClasses; ++l2) {
                for (int k = 0; k < NUM_PAIR_CASES; ++k) {
                    this.parameterToL1[start + k] = l1;
                    this.parameterToL2[start + k] = l2;
                }
                start += NUM_PAIR_CASES;
            }
        }
        this.parameterToClass = new int[this.numWeightsForFeatures];
        this.parameterToFeature = new int[this.numWeightsForFeatures];
        for (int i = 0; i < this.numWeightsForFeatures; ++i) {
            this.parameterToClass[i] = this.cmlcrf.getWeights().getClassIndex(i);
            this.parameterToFeature[i] = this.cmlcrf.getWeights().getFeatureIndex(i);
        }
        this.comContainsLabel = new boolean[this.numSupport][this.numClasses];
        for (int s = 0; s < this.numSupport; ++s) {
            MultiLabel combination = this.supportedCombinations.get(s);
            for (int l = 0; l < this.numClasses; ++l) {
                this.comContainsLabel[s][l] = combination.matchClass(l);
            }
        }
    }

    /**
     * Returns whether a label-pair indicator of the given case fires. The cases are 0 = neither class
     * present, 1 = only l1, 2 = only l2, 3 = both.
     */
    private static boolean matchesPairCase(int featureCase, boolean hasL1, boolean hasL2) {
        switch (featureCase) {
            case 0:
                return !hasL1 && !hasL2;
            case 1:
                return hasL1 && !hasL2;
            case 2:
                return !hasL1 && hasL2;
            case 3:
                return hasL1 && hasL2;
            default:
                throw new RuntimeException("feature case :" + featureCase + " failed.");
        }
    }

    /** Fills {@link #labelPairToCombination}. Each position writes only its own list. */
    private void mapPairToCombination() {
        this.indexStream(0, this.numWeightsForLabelPairs).forEach(this::mapPairToCombination);
    }

    private void mapPairToCombination(int position) {
        List<Integer> combinations = this.labelPairToCombination.get(position);
        int l1 = this.parameterToL1[position];
        int l2 = this.parameterToL2[position];
        int featureCase = position % NUM_PAIR_CASES;
        for (int c = 0; c < this.numSupport; ++c) {
            if (matchesPairCase(featureCase, this.comContainsLabel[c][l1], this.comContainsLabel[c][l2])) {
                combinations.add(c);
            }
        }
    }

    /** Fills {@link #combinationToLabelPair}, sequentially. */
    private void mapCombinationToPair() {
        for (int s = 0; s < this.numSupport; ++s) {
            this.mapCombinationToPair(s);
        }
    }

    private void mapCombinationToPair(int s) {
        List<Integer> positions = this.combinationToLabelPair.get(s);
        for (int position = 0; position < this.numWeightsForLabelPairs; ++position) {
            int l1 = this.parameterToL1[position];
            int l2 = this.parameterToL2[position];
            if (matchesPairCase(position % NUM_PAIR_CASES, this.comContainsLabel[s][l1], this.comContainsLabel[s][l2])) {
                positions.add(position);
            }
        }
    }

    /**
     * Returns the objective for the current weights: the data term plus the elastic-net penalty.
     * The result is cached until the weights are changed through this class.
     */
    public double getValue() {
        if (this.isValueCacheValid) {
            return this.value;
        }
        this.value = this.getValueForAllData() + this.getPenalty();
        this.isValueCacheValid = true;
        return this.value;
    }

    /** {@code regularization * ((1 - l1Ratio) * 0.5 * ||w||_2^2 + l1Ratio * ||w||_1)}. */
    private double getPenalty() {
        Vector vector = this.cmlcrf.getWeights().getAllWeights();
        double norm = (1.0 - this.l1Ratio) * 0.5 * Math.pow(vector.norm(2.0), 2.0) + this.l1Ratio * vector.norm(1.0);
        return norm * this.regularization;
    }

    /** Recomputes scores for the current weights and sums the per-instance data term. */
    private double getValueForAllData() {
        this.updateClassScoreMatrix();
        // Same refresh order as iterate(); presumably needed so that label-pair weight changes
        // (e.g. during the line search) are reflected in the combination scores.
        this.cmlcrf.updateCombLabelPartScores();
        this.updateAssignmentScoreMatrix();
        return this.indexStream(0, this.numData).mapToDouble(this::getValueForOneData).sum();
    }

    /** {@code logSumExp(combinationScores) - score(trueCombination)} for data point i. */
    private double getValueForOneData(int i) {
        double[] scores = this.combScoreMatrix[i];
        return MathUtil.logSumExp(scores) - scores[this.labelComIndices[i]];
    }

    /** Maps each data point to the index of its label combination among the supported combinations. */
    private int[] indexTrueCombinations() {
        HashMap<MultiLabel, Integer> combinationIndex = new HashMap<>();
        for (int s = 0; s < this.numSupport; ++s) {
            combinationIndex.put(this.supportedCombinations.get(s), s);
        }
        MultiLabel[] multiLabels = this.dataSet.getMultiLabels();
        int[] indices = new int[this.numData];
        for (int i = 0; i < this.numData; ++i) {
            Integer index = combinationIndex.get(multiLabels[i]);
            if (index == null) {
                throw new IllegalArgumentException("label combination " + multiLabels[i] + " of data point " + i
                        + " is not among the supported combinations");
            }
            indices[i] = index;
        }
        return indices;
    }

    /** Returns {@code [from, to)}, made parallel when {@link #isParallel} is set. */
    private IntStream indexStream(int from, int to) {
        IntStream stream = IntStream.range(from, to);
        return this.isParallel ? stream.parallel() : stream;
    }

    private void updateEmpiricalCounts() {
        this.indexStream(0, this.numParameters).forEach(this::calEmpiricalCount);
    }

    private void calEmpiricalCount(int parameterIndex) {
        if (parameterIndex < this.numWeightsForFeatures) {
            this.empiricalCounts.set(parameterIndex, this.calEmpiricalCountForFeature(parameterIndex));
        } else if (parameterIndex < this.numWeightsForFeatures + this.numWeightsForLabelPairs) {
            this.empiricalCounts.set(parameterIndex, this.calEmpiricalCountForLabelPair(parameterIndex));
        }
    }

    /** Counts the data points whose true labels fire the given label-pair indicator. */
    private double calEmpiricalCountForLabelPair(int parameterIndex) {
        int position = parameterIndex - this.numWeightsForFeatures;
        int l1 = this.parameterToL1[position];
        int l2 = this.parameterToL2[position];
        int featureCase = position % NUM_PAIR_CASES;
        MultiLabel[] multiLabels = this.dataSet.getMultiLabels();
        double empiricalCount = 0.0;
        for (int i = 0; i < this.numData; ++i) {
            MultiLabel label = multiLabels[i];
            if (matchesPairCase(featureCase, label.matchClass(l1), label.matchClass(l2))) {
                empiricalCount += 1.0;
            }
        }
        return empiricalCount;
    }

    /**
     * Sums the feature value over data points labeled with the parameter's class. For the bias
     * parameter (feature index -1), it counts those data points instead.
     */
    private double calEmpiricalCountForFeature(int parameterIndex) {
        int classIndex = this.parameterToClass[parameterIndex];
        int featureIndex = this.parameterToFeature[parameterIndex];
        MultiLabel[] multiLabels = this.dataSet.getMultiLabels();
        double empiricalCount = 0.0;
        if (featureIndex == -1) {
            for (int i = 0; i < this.numData; ++i) {
                if (multiLabels[i].matchClass(classIndex)) {
                    empiricalCount += 1.0;
                }
            }
        } else {
            for (Vector.Element element : this.dataSet.getColumn(featureIndex).nonZeroes()) {
                if (multiLabels[element.index()].matchClass(classIndex)) {
                    empiricalCount += element.get();
                }
            }
        }
        return empiricalCount;
    }

    private void updateClassProbMatrix() {
        this.indexStream(0, this.numData).forEach(i ->
                this.classProbMatrix[i] = this.cmlcrf.calClassProbs(this.combProbMatrix[i]));
    }

    /**
     * Recomputes model-expected counts. This requires {@link #classProbMatrix} and
     * {@link #combProbSums} to be current.
     */
    private void updatePredictedCounts() {
        this.indexStream(0, this.numWeightsForFeatures)
                .forEach(i -> this.predictedCounts.set(i, this.calPredictedFeatureCounts(i)));
        // Same upper bound as calEmpiricalCount, so no out-of-range label-pair position is looked up.
        this.indexStream(this.numWeightsForFeatures, this.numWeightsForFeatures + this.numWeightsForLabelPairs)
                .forEach(i -> this.predictedCounts.set(i, this.calPredictedLabelPairCounts(i)));
    }

    private double calPredictedLabelPairCounts(int parameterIndex) {
        int position = parameterIndex - this.numWeightsForFeatures;
        double count = 0.0;
        for (int combination : this.labelPairToCombination.get(position)) {
            count += this.combProbSums[combination];
        }
        return count;
    }

    private double calPredictedFeatureCounts(int parameterIndex) {
        int classIndex = this.parameterToClass[parameterIndex];
        int featureIndex = this.parameterToFeature[parameterIndex];
        double count = 0.0;
        if (featureIndex == -1) {
            for (int i = 0; i < this.numData; ++i) {
                count += this.classProbMatrix[i][classIndex];
            }
        } else {
            for (Vector.Element element : this.dataSet.getColumn(featureIndex).nonZeroes()) {
                count += this.classProbMatrix[element.index()][classIndex] * element.get();
            }
        }
        return count;
    }

    private void updateCombProbSums(int combinationIndex) {
        double sum = 0.0;
        for (int i = 0; i < this.numData; ++i) {
            sum += this.combProbMatrix[i][combinationIndex];
        }
        this.combProbSums[combinationIndex] = sum;
    }

    private void updateCombProbSums() {
        this.indexStream(0, this.numSupport).forEach(this::updateCombProbSums);
    }

    /**
     * Backtracking line search from the current weights along {@code searchDirection}. It halves the
     * step until {@code f(x + t d) <= f(x) + c * t * (g . d + P(x + t d) - P(x))}. If no step is
     * accepted within {@link #MAX_LINE_SEARCH_STEPS} halvings, the starting weights are restored.
     */
    private void lineSearch(Vector searchDirection, Vector gradient) {
        double shrinkage = 0.5;
        double c = 1.0E-4;
        double stepLength = 1.0;
        // Snapshot the start so later setWeightVector calls cannot alias or mutate it.
        Vector start = this.cmlcrf.getWeights().deepCopy().getAllWeights();
        double penalty = this.getPenalty();
        double startValue = this.getValue();
        double product = gradient.dot(searchDirection);
        for (int trial = 0; trial < MAX_LINE_SEARCH_STEPS; ++trial) {
            Vector target = start.plus(searchDirection.times(stepLength));
            this.cmlcrf.getWeights().setWeightVector(target);
            // Weights changed: the cached objective belongs to the previous point.
            this.isValueCacheValid = false;
            double targetPenalty = this.getPenalty();
            double targetValue = this.getValue();
            if (targetValue <= startValue + c * stepLength * (product + targetPenalty - penalty)) {
                return;
            }
            stepLength *= shrinkage;
        }
        // No acceptable step (e.g. the objective became NaN): keep the starting point.
        this.cmlcrf.getWeights().setWeightVector(start);
        this.isValueCacheValid = false;
    }
}

