/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.classification.lkboost;

import edu.neu.ccs.pyramid.classification.Classifier;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.optimization.gradient_boosting.Ensemble;
import edu.neu.ccs.pyramid.optimization.gradient_boosting.GradientBoosting;
import edu.neu.ccs.pyramid.util.ArgMax;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.ObjectInputStream;
import org.apache.mahout.math.Vector;

/**
 * Multi-class gradient boosting classifier that keeps one boosted {@code Ensemble} per
 * class index {@code 0 .. numClasses-1}.
 *
 * <p>Raw per-class scores come from the {@link GradientBoosting} base class. Predicted
 * labels are the arg-max of those scores. Class probabilities are the softmax of the
 * scores, normalised in log space with {@link MathUtil#logSumExp}.
 */
public class LKBoost
extends GradientBoosting
implements Classifier.ProbabilityEstimator,
Classifier.ScoreEstimator {
    private static final long serialVersionUID = 5L;
    private final int numClasses;
    // Never assigned inside this class; presumably set by a trainer elsewhere in this
    // package. May be null if nothing assigned it — TODO confirm with callers.
    LabelTranslator labelTranslator;

    /**
     * Creates an untrained model for {@code numClasses} classes.
     *
     * @param numClasses number of classes; also forwarded to {@link GradientBoosting}
     */
    public LKBoost(int numClasses) {
        super(numClasses);
        this.numClasses = numClasses;
    }

    /**
     * Predicts the class with the highest raw score.
     *
     * @param vector feature vector
     * @return index of the arg-max class score, as computed by {@code ArgMax.argMax}
     */
    @Override
    public int predict(Vector vector) {
        double[] scores = this.predictClassScores(vector);
        return ArgMax.argMax(scores);
    }

    /** @return the number of classes this model was constructed with */
    @Override
    public int getNumClasses() {
        return this.numClasses;
    }

    /**
     * Returns the raw (unnormalised) score of class {@code k}.
     *
     * @param vector feature vector
     * @param k class index
     * @return the score computed by the base class for class {@code k}
     */
    @Override
    public double predictClassScore(Vector vector, int k) {
        return this.score(vector, k);
    }

    /**
     * Returns the raw (unnormalised) scores of all classes.
     *
     * @param vector feature vector
     * @return per-class scores computed by the base class
     */
    @Override
    public double[] predictClassScores(Vector vector) {
        return this.scores(vector);
    }

    /**
     * Returns softmax class probabilities derived from the raw scores.
     *
     * <p>Each probability is {@code exp(score[k] - logSumExp(scores))}. Subtracting the
     * log-normaliser before exponentiating avoids overflow for large scores.
     *
     * @param vector feature vector
     * @return array of length {@code numClasses} holding class probabilities
     */
    @Override
    public double[] predictClassProbs(Vector vector) {
        double[] scores = this.predictClassScores(vector);
        double logNormalizer = MathUtil.logSumExp(scores);
        double[] probs = new double[this.numClasses];
        for (int k = 0; k < this.numClasses; ++k) {
            probs[k] = Math.exp(scores[k] - logNormalizer);
        }
        return probs;
    }

    /**
     * Returns log-softmax class probabilities derived from the raw scores.
     *
     * @param vector feature vector
     * @return array of length {@code numClasses} with {@code score[k] - logSumExp(scores)}
     */
    @Override
    public double[] predictLogClassProbs(Vector vector) {
        double[] scores = this.predictClassScores(vector);
        double logNormalizer = MathUtil.logSumExp(scores);
        double[] logProbs = new double[this.numClasses];
        for (int k = 0; k < this.numClasses; ++k) {
            logProbs[k] = scores[k] - logNormalizer;
        }
        return logProbs;
    }

    /**
     * Reads an {@code LKBoost} previously written with Java object serialization.
     *
     * <p>Security: Java native deserialization can trigger arbitrary code paths on
     * crafted input. Only call this on files from a trusted source.
     *
     * @param file file containing a single serialized {@code LKBoost}
     * @return the deserialized model
     * @throws java.io.IOException if the file cannot be opened or is not a valid
     *     serialization stream
     * @throws ClassNotFoundException if a class referenced by the stream is not on the
     *     classpath
     * @throws IllegalArgumentException if the stream holds an object of another type
     */
    public static LKBoost deserialize(File file) throws Exception {
        // Reconstructed: the decompiled stub unconditionally threw IllegalStateException.
        // Each stream is a try-with-resources resource, so all are closed on every path.
        try (FileInputStream fileIn = new FileInputStream(file);
             BufferedInputStream bufferedIn = new BufferedInputStream(fileIn);
             ObjectInputStream objectIn = new ObjectInputStream(bufferedIn)) {
            Object obj = objectIn.readObject();
            if (!(obj instanceof LKBoost)) {
                String found = obj == null ? "null" : obj.getClass().getName();
                throw new IllegalArgumentException(
                        "file " + file + " does not contain a serialized LKBoost but " + found);
            }
            return (LKBoost) obj;
        }
    }

    /**
     * Renders every tree of every per-class ensemble, grouped by class index.
     *
     * @return a multi-line human-readable dump of the model
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int k = 0; k < this.numClasses; ++k) {
            sb.append("for class ").append(k).append("\n");
            Ensemble trees = this.getEnsemble(k);
            for (int i = 0; i < trees.getRegressors().size(); ++i) {
                sb.append("tree ").append(i).append(":");
                sb.append(trees.get(i).toString());
            }
        }
        return sb.toString();
    }

    /**
     * @return the label translator held by this model; this class never assigns it, so
     *     it may be null unless set elsewhere
     */
    @Override
    public LabelTranslator getLabelTranslator() {
        return this.labelTranslator;
    }
}

