/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.crf;

import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import edu.neu.ccs.pyramid.multilabel_classification.crf.Weights;
import edu.neu.ccs.pyramid.util.ArgMax;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.mahout.math.Vector;

/**
 * Multi-label classifier that scores a fixed list of "support" label
 * combinations and predicts the highest-scoring one.
 *
 * <p>The score of a combination is the sum of the per-class scores of its
 * matched labels, plus (when {@link #considerPair()} is {@code true}) a
 * feature-independent pairwise label term that is cached per combination in
 * {@code combinationLabelPartScores}. The cache is recomputed by
 * {@link #updateCombLabelPartScores()}, which callers must invoke after
 * changing the pairwise entries of {@link #getWeights()}.
 */
public class CMLCRF
implements MultiLabelClassifier,
MultiLabelClassifier.AssignmentProbEstimator,
Serializable {
    private static final long serialVersionUID = 3L;
    private int numClasses;
    private int numFeatures;
    private Weights weights;
    /** Candidate label combinations this model can predict. */
    private List<MultiLabel> supportCombinations;
    /** Cached size of {@link #supportCombinations}. */
    private int numSupports;
    /** Cached pairwise label score for each support combination, by index. */
    private double[] combinationLabelPartScores;
    private boolean considerPair = true;
    private double lossStrength = 1.0;
    private LabelTranslator labelTranslator;
    private FeatureList featureList;

    /**
     * Builds a model sized for the given data set, using the label
     * combinations gathered from it as the support.
     *
     * @param dataSet data set supplying class/feature counts, the support
     *                combinations, the label translator and the feature list
     */
    public CMLCRF(MultiLabelClfDataSet dataSet) {
        this.numClasses = dataSet.getNumClasses();
        this.numFeatures = dataSet.getNumFeatures();
        this.weights = new Weights(this.numClasses, this.numFeatures);
        this.supportCombinations = DataSetUtil.gatherMultiLabels(dataSet);
        this.numSupports = this.supportCombinations.size();
        this.combinationLabelPartScores = new double[this.numSupports];
        this.updateCombLabelPartScores();
        this.labelTranslator = dataSet.getLabelTranslator();
        this.featureList = dataSet.getFeatureList();
    }

    /**
     * Builds a model from explicit dimensions and support combinations.
     * The label translator and feature list are left {@code null}.
     *
     * @param numClasses          number of labels
     * @param numFeatures         number of input features
     * @param supportCombinations candidate label combinations (stored by
     *                            reference, not copied)
     */
    public CMLCRF(int numClasses, int numFeatures, List<MultiLabel> supportCombinations) {
        this.numClasses = numClasses;
        this.numFeatures = numFeatures;
        this.weights = new Weights(numClasses, numFeatures);
        this.supportCombinations = supportCombinations;
        this.numSupports = supportCombinations.size();
        this.combinationLabelPartScores = new double[this.numSupports];
        this.updateCombLabelPartScores();
    }

    /** @return the multiplier applied to the loss term in loss-augmented scores */
    public double getLossStrength() {
        return this.lossStrength;
    }

    /** @param lossStrength multiplier applied to the loss term in loss-augmented scores */
    public void setLossStrength(double lossStrength) {
        this.lossStrength = lossStrength;
    }

    /** @return whether the pairwise label term is added to combination scores */
    public boolean considerPair() {
        return this.considerPair;
    }

    /**
     * Enables or disables the pairwise label term and refreshes the cached
     * per-combination pairwise scores.
     *
     * @param considerPair {@code true} to include the pairwise term
     */
    public void setConsiderPair(boolean considerPair) {
        this.considerPair = considerPair;
        this.updateCombLabelPartScores();
    }

    /**
     * Computes the linear score of one class: feature weights dotted with the
     * input, plus the class bias.
     */
    double predictClassScore(Vector vector, int classIndex) {
        double score = 0.0;
        score += this.weights.getWeightsWithoutBiasForClass(classIndex).dot(vector);
        score += this.weights.getBiasForClass(classIndex);
        return score;
    }

    /** Computes {@link #predictClassScore(Vector, int)} for every class. */
    double[] predictClassScores(Vector vector) {
        double[] scores = new double[this.numClasses];
        for (int k = 0; k < this.numClasses; ++k) {
            scores[k] = this.predictClassScore(vector, k);
        }
        return scores;
    }

    /**
     * Scores every support combination for the given input.
     *
     * @param vector input feature vector
     * @return one score per support combination, in support order
     */
    public double[] predictCombinationScores(Vector vector) {
        return this.predictCombinationScores(this.predictClassScores(vector));
    }

    /** Scores every support combination from precomputed per-class scores. */
    double[] predictCombinationScores(double[] classScores) {
        double[] scores = new double[this.numSupports];
        for (int k = 0; k < scores.length; ++k) {
            scores[k] = this.predictCombinationScore(k, classScores);
        }
        return scores;
    }

    /**
     * Scores one support combination: the sum of the class scores of its
     * matched labels, plus the cached pairwise term when enabled.
     */
    private double predictCombinationScore(int labelComIndex, double[] classScores) {
        MultiLabel label = this.supportCombinations.get(labelComIndex);
        double score = 0.0;
        for (Integer l : label.getMatchedLabels()) {
            score += classScores[l];
        }
        if (this.considerPair) {
            score += this.combinationLabelPartScores[labelComIndex];
        }
        return score;
    }

    /**
     * Scores every support combination and adds
     * {@code lossStrength * lossMatrix[trueComIndex][k]} to the k-th score.
     *
     * @param trueComIndex index of the true combination (row of {@code lossMatrix})
     * @param vector       input feature vector
     * @param lossMatrix   loss indexed as {@code [true combination][predicted combination]}
     * @return one loss-augmented score per support combination
     */
    public double[] predictLossAugmentedCombinationScores(int trueComIndex, Vector vector, double[][] lossMatrix) {
        double[] classScores = this.predictClassScores(vector);
        return this.predictLossAugmentedCombinationScores(trueComIndex, classScores, lossMatrix);
    }

    /** Loss-augmented combination scores from precomputed per-class scores. */
    double[] predictLossAugmentedCombinationScores(int trueComIndex, double[] classScores, double[][] lossMatrix) {
        double[] scores = new double[this.numSupports];
        for (int k = 0; k < scores.length; ++k) {
            scores[k] = this.predictLossAugmentedCombinationScore(trueComIndex, k, classScores, lossMatrix);
        }
        return scores;
    }

    /** Plain combination score plus the scaled loss against the true combination. */
    private double predictLossAugmentedCombinationScore(int trueComIndex, int predictComIndex, double[] classScores, double[][] lossMatrix) {
        double original = this.predictCombinationScore(predictComIndex, classScores);
        return original + this.lossStrength * lossMatrix[trueComIndex][predictComIndex];
    }

    /**
     * Computes the pairwise label score of one support combination.
     *
     * <p>Starting at index {@code weights.getNumWeightsForFeatures()}, the
     * weights are read in groups of four, one group per label pair
     * {@code (l1, l2)} with {@code l1 < l2}, enumerated with {@code l1} as
     * the outer loop. Within a group the offset is chosen by which labels
     * are present in the combination: 0 = neither, 1 = only {@code l1},
     * 2 = only {@code l2}, 3 = both.
     */
    double computeLabelPartScore(int labelComIndex) {
        MultiLabel label = this.supportCombinations.get(labelComIndex);
        boolean[] matches = new boolean[this.numClasses];
        for (int match : label.getMatchedLabels()) {
            matches[match] = true;
        }
        double score = 0.0;
        int pos = this.weights.getNumWeightsForFeatures();
        for (int l1 = 0; l1 < this.numClasses; ++l1) {
            for (int l2 = l1 + 1; l2 < this.numClasses; ++l2) {
                int offset;
                if (matches[l1]) {
                    offset = matches[l2] ? 3 : 1;
                } else {
                    offset = matches[l2] ? 2 : 0;
                }
                score += this.weights.getWeightForIndex(pos + offset);
                pos += 4;
            }
        }
        return score;
    }

    /**
     * Recomputes the cached pairwise score of every support combination.
     * Runs in parallel; each task writes only its own array slot.
     */
    void updateCombLabelPartScores() {
        IntStream.range(0, this.supportCombinations.size()).parallel().forEach(c -> {
            this.combinationLabelPartScores[c] = this.computeLabelPartScore(c);
        });
    }

    /**
     * @param vector input feature vector
     * @return the result of {@code MathUtil.softmax} over the combination scores
     */
    public double[] predictCombinationProbs(Vector vector) {
        double[] combinationScores = this.predictCombinationScores(vector);
        return this.predictCombinationProbs(combinationScores);
    }

    /**
     * @param combinationScores one score per support combination
     * @return the result of {@code MathUtil.softmax} over the given scores
     */
    public double[] predictCombinationProbs(double[] combinationScores) {
        return MathUtil.softmax(combinationScores);
    }

    /**
     * Computes, for each support combination, its score minus
     * {@code MathUtil.logSumExp} of all combination scores.
     *
     * @param vector input feature vector
     * @return one log-probability per support combination
     */
    public double[] predictLogCombinationProbs(Vector vector) {
        double[] scoreVector = this.predictCombinationScores(vector);
        double logDenominator = MathUtil.logSumExp(scoreVector);
        double[] logProbVector = new double[this.numSupports];
        for (int k = 0; k < this.numSupports; ++k) {
            logProbVector[k] = scoreVector[k] - logDenominator;
        }
        return logProbVector;
    }

    /**
     * Marginalizes combination probabilities into per-class probabilities:
     * each class receives the summed probability of every support
     * combination that contains it.
     *
     * @param assignmentProbs one probability per support combination
     * @return one value per class
     */
    public double[] calClassProbs(double[] assignmentProbs) {
        double[] classProbs = new double[this.numClasses];
        for (int a = 0; a < this.numSupports; ++a) {
            MultiLabel assignment = this.supportCombinations.get(a);
            double prob = assignmentProbs[a];
            for (Integer label : assignment.getMatchedLabels()) {
                classProbs[label] += prob;
            }
        }
        return classProbs;
    }

    /**
     * @param vector input feature vector
     * @return per-class probabilities derived from the combination probabilities
     */
    public double[] predictClassProbs(Vector vector) {
        double[] combProbs = this.predictCombinationProbs(vector);
        return this.calClassProbs(combProbs);
    }

    @Override
    public int getNumClasses() {
        return this.numClasses;
    }

    /** @return the number of support combinations */
    public int getNumSupports() {
        return this.numSupports;
    }

    /** @return the number of input features */
    public int getNumFeatures() {
        return this.numFeatures;
    }

    /** @return the live weights object (not a copy) */
    public Weights getWeights() {
        return this.weights;
    }

    /** @return the live list of support combinations (not a copy) */
    public List<MultiLabel> getSupportCombinations() {
        return this.supportCombinations;
    }

    /**
     * Predicts the support combination with the highest score.
     *
     * @param vector input feature vector
     * @return a copy of the best-scoring support combination
     */
    @Override
    public MultiLabel predict(Vector vector) {
        double[] scores = this.predictCombinationScores(vector);
        int predictedCombination = ArgMax.argMax(scores);
        return this.supportCombinations.get(predictedCombination).copy();
    }

    /**
     * Predicts every label whose individual class score is strictly positive,
     * ignoring the support combinations and the pairwise term.
     *
     * @param vector input feature vector
     * @return a new label set containing the positively scored classes
     */
    public MultiLabel predictByArgmax(Vector vector) {
        double[] scores = this.predictClassScores(vector);
        MultiLabel label = new MultiLabel();
        for (int l = 0; l < scores.length; ++l) {
            if (scores[l] > 0.0) {
                label.addLabel(l);
            }
        }
        return label;
    }

    @Override
    public FeatureList getFeatureList() {
        return this.featureList;
    }

    @Override
    public LabelTranslator getLabelTranslator() {
        return this.labelTranslator;
    }

    @Override
    public String toString() {
        return this.getWeights().toString();
    }

    /**
     * Reads a model previously written by {@link #serialize(File)}.
     *
     * <p>NOTE(review): this uses Java native deserialization; only load files
     * from trusted sources.
     *
     * @param file file containing a serialized {@code CMLCRF}
     * @return the deserialized model
     * @throws Exception if the file cannot be read or does not contain a
     *                   {@code CMLCRF}
     */
    public static CMLCRF deserialize(File file) throws Exception {
        try (FileInputStream fileInputStream = new FileInputStream(file);
             BufferedInputStream bufferedInputStream = new BufferedInputStream(fileInputStream);
             ObjectInputStream objectInputStream = new ObjectInputStream(bufferedInputStream)) {
            return (CMLCRF) objectInputStream.readObject();
        }
    }

    /**
     * Reads a model from the file at the given path.
     *
     * @param file path of a file written by {@code serialize}
     * @return the deserialized model
     * @throws Exception if the file cannot be read or does not contain a
     *                   {@code CMLCRF}
     */
    public static CMLCRF deserialize(String file) throws Exception {
        return CMLCRF.deserialize(new File(file));
    }

    /**
     * Writes this model to {@code file} using Java serialization, creating
     * any missing parent directories first.
     *
     * @param file destination file
     * @throws Exception if the file cannot be opened or written
     */
    @Override
    public void serialize(File file) throws Exception {
        // getParentFile() is null for a bare file name such as "model.ser".
        File parent = file.getParentFile();
        if (parent != null && !parent.exists()) {
            // The result is not checked: if creation fails, opening the
            // output stream below throws with the underlying cause.
            parent.mkdirs();
        }
        try (FileOutputStream fileOutputStream = new FileOutputStream(file);
             BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(fileOutputStream);
             ObjectOutputStream objectOutputStream = new ObjectOutputStream(bufferedOutputStream)) {
            objectOutputStream.writeObject(this);
        }
    }

    /**
     * Writes this model to the file at the given path.
     *
     * @param file destination path
     * @throws Exception if the file cannot be opened or written
     */
    @Override
    public void serialize(String file) throws Exception {
        this.serialize(new File(file));
    }

    /**
     * Returns the log-probability of the given assignment.
     *
     * @param vector     input feature vector
     * @param assignment label combination to evaluate
     * @return the log-probability of the first support combination equal to
     *         {@code assignment}, or {@link Double#NEGATIVE_INFINITY} if the
     *         assignment is not among the support combinations
     */
    @Override
    public double predictLogAssignmentProb(Vector vector, MultiLabel assignment) {
        double[] logComProbs = this.predictLogCombinationProbs(vector);
        for (int c = 0; c < this.numSupports; ++c) {
            if (this.supportCombinations.get(c).equals(assignment)) {
                return logComProbs[c];
            }
        }
        return Double.NEGATIVE_INFINITY;
    }
}

