/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.imlgb;

import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.HammingPredictor;
import edu.neu.ccs.pyramid.regression.Regressor;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.mahout.math.Vector;

/**
 * Multi-label gradient boosting model.
 *
 * <p>Each class {@code k} owns an additive ensemble of {@link Regressor}s. The raw score of
 * class {@code k} is the sum of that ensemble's predictions. The score is turned into a
 * probability independently per class (logistic), or jointly over a list of legal label
 * assignments.
 */
public class IMLGradientBoosting
implements MultiLabelClassifier.ClassScoreEstimator,
MultiLabelClassifier.ClassProbEstimator {
    private static final long serialVersionUID = 3L;
    /** One regressor list per class; index k holds the ensemble for class k. */
    private final List<List<Regressor>> regressors;
    private final int numClasses;
    /** Legal label combinations used by the "with constraint" methods; may be null. */
    private List<MultiLabel> assignments;
    private FeatureList featureList;
    private LabelTranslator labelTranslator;
    // Not read anywhere in this class. Presumably kept so previously serialized models still
    // load with the same field layout; verify before removing.
    @Deprecated
    private PredictFashion predictFashion = PredictFashion.INDEPENDENT;

    /**
     * Creates an empty model with one (initially empty) regressor ensemble per class.
     *
     * @param numClasses number of output classes
     */
    public IMLGradientBoosting(int numClasses) {
        this.numClasses = numClasses;
        this.regressors = new ArrayList<>(numClasses);
        for (int k = 0; k < numClasses; ++k) {
            this.regressors.add(new ArrayList<>());
        }
    }

    @Override
    public int getNumClasses() {
        return this.numClasses;
    }

    /** Returns the legal label assignments, or {@code null} if none were set. */
    public List<MultiLabel> getAssignments() {
        return this.assignments;
    }

    /** Sets the legal label assignments used by the constrained probability methods. */
    public void setAssignments(List<MultiLabel> assignments) {
        this.assignments = assignments;
    }

    /** Appends {@code regressor} to the ensemble of class {@code k}. */
    void addRegressor(Regressor regressor, int k) {
        this.regressors.get(k).add(regressor);
    }

    /** Predicts a label set by delegating to a {@link HammingPredictor} over this model. */
    @Override
    public MultiLabel predict(Vector vector) {
        HammingPredictor hammingPredictor = new HammingPredictor(this);
        return hammingPredictor.predict(vector);
    }

    /**
     * Sums the class scores of every label matched by {@code assignment}.
     *
     * @param assignment  label set to score
     * @param classScores per-class scores indexed by class id
     * @return the summed score of the assignment's matched labels
     */
    double calAssignmentScore(MultiLabel assignment, double[] classScores) {
        double score = 0.0;
        for (int label : assignment.getMatchedLabels()) {
            score += classScores[label];
        }
        return score;
    }

    /** Returns the raw score of class {@code k}: the sum of its regressors' predictions. */
    @Override
    public double predictClassScore(Vector vector, int k) {
        double score = 0.0;
        for (Regressor regressor : this.regressors.get(k)) {
            score += regressor.predict(vector);
        }
        return score;
    }

    /** Returns the raw score of every class, indexed by class id. */
    @Override
    public double[] predictClassScores(Vector vector) {
        double[] scores = new double[this.numClasses];
        for (int k = 0; k < this.numClasses; ++k) {
            scores[k] = this.predictClassScore(vector, k);
        }
        return scores;
    }

    /** Returns the (mutable) regressor list backing class {@code k}. */
    public List<Regressor> getRegressors(int k) {
        return this.regressors.get(k);
    }

    /**
     * Returns the independent probability that class {@code classIndex} is present.
     *
     * <p>Computed as {@code exp(s - logSumExp(0, s))} for class score {@code s}, i.e. the
     * logistic sigmoid of the score, evaluated through log-sum-exp.
     */
    @Override
    public double predictClassProb(Vector vector, int classIndex) {
        double score = this.predictClassScore(vector, classIndex);
        double logDenominator = MathUtil.logSumExp(new double[]{0.0, score});
        return Math.exp(score - logDenominator);
    }

    /**
     * Returns {@code logSoftmax({0, s})} for the score {@code s} of class {@code classIndex}.
     * Element 1 corresponds to the class score (matching {@link #predictClassProb}); element
     * 0 corresponds to the zero reference score.
     */
    public double[] predictLogClassProbs(Vector vector, int classIndex) {
        double score = this.predictClassScore(vector, classIndex);
        double[] scores = new double[]{0.0, score};
        return MathUtil.logSoftmax(scores);
    }

    /** Returns {@link #predictClassProb} for every class, indexed by class id. */
    @Override
    public double[] predictClassProbs(Vector vector) {
        return IntStream.range(0, this.numClasses).mapToDouble(k -> this.predictClassProb(vector, k)).toArray();
    }

    /**
     * Returns the probability of {@code assignment}, normalized over the legal assignments only.
     *
     * @return 0 if {@code assignment} is not among the legal assignments; otherwise
     *         {@code exp(score(assignment) - logSumExp(score(all legal assignments)))}
     * @throws IllegalStateException if no legal assignments have been set
     */
    double predictAssignmentProbWithConstraint(Vector vector, MultiLabel assignment) {
        this.requireAssignments();
        if (!this.assignments.contains(assignment)) {
            return 0.0;
        }
        double[] classScores = this.predictClassScores(vector);
        double logDenominator = MathUtil.logSumExp(this.scoreAllAssignments(classScores));
        double logNumerator = this.calAssignmentScore(assignment, classScores);
        return Math.exp(logNumerator - logDenominator);
    }

    /**
     * Returns the probability of every legal assignment, normalized over the legal set.
     * Element {@code i} corresponds to {@code getAssignments().get(i)}.
     *
     * @throws IllegalStateException if no legal assignments have been set
     */
    public double[] predictAllAssignmentProbsWithConstraint(Vector vector) {
        this.requireAssignments();
        double[] classScores = this.predictClassScores(vector);
        // Score each assignment once and reuse it for both the normalizer and the numerators.
        double[] assignmentScores = this.scoreAllAssignments(classScores);
        double logDenominator = MathUtil.logSumExp(assignmentScores);
        double[] probs = new double[assignmentScores.length];
        for (int i = 0; i < assignmentScores.length; ++i) {
            probs[i] = Math.exp(assignmentScores[i] - logDenominator);
        }
        return probs;
    }

    /**
     * Returns the probability of {@code assignment} as a product of independent per-class
     * logistic probabilities. Present classes contribute {@code sigmoid(s_k)} and absent
     * classes contribute {@code 1 - sigmoid(s_k)}. The product is accumulated in log space.
     */
    double predictAssignmentProbWithoutConstraint(Vector vector, MultiLabel assignment) {
        double[] classScores = this.predictClassScores(vector);
        double logProb = 0.0;
        for (int k = 0; k < this.numClasses; ++k) {
            double score = classScores[k];
            if (assignment.matchClass(k)) {
                logProb += score;
            }
            logProb -= MathUtil.logSumExp(new double[]{0.0, score});
        }
        return Math.exp(logProb);
    }

    /** Fails fast when a constrained method is called before legal assignments are set. */
    private void requireAssignments() {
        if (this.assignments == null) {
            throw new IllegalStateException("CRF is used but legal assignments is not specified!");
        }
    }

    /** Scores every legal assignment against {@code classScores}, in list order. */
    private double[] scoreAllAssignments(double[] classScores) {
        double[] assignmentScores = new double[this.assignments.size()];
        for (int i = 0; i < assignmentScores.length; ++i) {
            assignmentScores[i] = this.calAssignmentScore(this.assignments.get(i), classScores);
        }
        return assignmentScores;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int k = 0; k < this.numClasses; ++k) {
            sb.append("for class ").append(k).append("\n");
            List<Regressor> trees = this.getRegressors(k);
            for (int i = 0; i < trees.size(); ++i) {
                sb.append("tree ").append(i).append(":");
                sb.append(trees.get(i).toString());
            }
        }
        return sb.toString();
    }

    /** Loads a model previously written with Java serialization from the file at {@code file}. */
    public static IMLGradientBoosting deserialize(String file) throws Exception {
        return IMLGradientBoosting.deserialize(new File(file));
    }

    /**
     * Loads a model previously written with Java serialization.
     *
     * <p>SECURITY: Java native deserialization can run class-defined code while reading.
     * Only load model files from trusted sources.
     *
     * @param file serialized model file
     * @return the deserialized model
     * @throws Exception on I/O failure, unknown classes, or if the stream does not hold an
     *                   {@code IMLGradientBoosting}
     */
    public static IMLGradientBoosting deserialize(File file) throws Exception {
        try (FileInputStream fileInputStream = new FileInputStream(file);
             BufferedInputStream bufferedInputStream = new BufferedInputStream(fileInputStream);
             ObjectInputStream objectInputStream = new ObjectInputStream(bufferedInputStream)) {
            return (IMLGradientBoosting) objectInputStream.readObject();
        }
    }

    @Override
    public FeatureList getFeatureList() {
        return this.featureList;
    }

    void setFeatureList(FeatureList featureList) {
        this.featureList = featureList;
    }

    @Override
    public LabelTranslator getLabelTranslator() {
        return this.labelTranslator;
    }

    void setLabelTranslator(LabelTranslator labelTranslator) {
        this.labelTranslator = labelTranslator;
    }

    /** Legacy prediction modes; retained only for the deprecated {@code predictFashion} field. */
    public static enum PredictFashion {
        CRF,
        INDEPENDENT,
        CRF_PLUS_HIGH_PROB;

    }
}

