/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.classification.l2boost;

import edu.neu.ccs.pyramid.classification.Classifier;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.optimization.gradient_boosting.GradientBoosting;
import edu.neu.ccs.pyramid.util.MathUtil;
import org.apache.mahout.math.Vector;

/**
 * Binary classifier built on {@link GradientBoosting} with a single boosted ensemble.
 *
 * <p>Class 0 is a reference class whose score is always {@code 0.0}. The score of class 1 is
 * the output of ensemble 0. Class probabilities are the softmax of these two scores, which for
 * this layout equals the logistic sigmoid of the class-1 score.
 *
 * <p>NOTE(review): the "L2" in the name suggests a squared-error boosting loss. The loss is not
 * defined in this class; confirm against the trainer.
 */
public class L2Boost
extends GradientBoosting
implements Classifier.ScoreEstimator,
Classifier.ProbabilityEstimator {
    /** Number of classes this model distinguishes (always binary). */
    private static final int NUM_CLASSES = 2;

    private FeatureList featureList;
    private LabelTranslator labelTranslator;

    /**
     * Creates an untrained model backed by exactly one ensemble.
     */
    public L2Boost() {
        // Only the class-1 score is learned; class 0 is pinned to 0.0 (see predictClassScore).
        // NOTE(review): assumes GradientBoosting's int argument is the number of ensembles.
        super(1);
    }

    /**
     * Predicts class probabilities for one instance.
     *
     * @param vector the feature vector of the instance
     * @return a new array of length {@link #getNumClasses()} with the probability of each class
     */
    @Override
    public double[] predictClassProbs(Vector vector) {
        return this.predictClassProbs(this.predictClassScores(vector));
    }

    /**
     * Converts per-class scores into probabilities via softmax, computed in log space.
     *
     * @param scores one score per class, indexed by class; must hold exactly
     *     {@link #getNumClasses()} entries
     * @return a new array of class probabilities, indexed like {@code scores}
     * @throws IllegalArgumentException if {@code scores} does not hold one entry per class
     */
    double[] predictClassProbs(double[] scores) {
        // A mismatched length would either index out of bounds or normalize over scores that
        // are not returned, so the resulting probabilities would not sum to 1.
        if (scores.length != NUM_CLASSES) {
            throw new IllegalArgumentException(
                    "expected " + NUM_CLASSES + " class scores, got " + scores.length);
        }
        // Subtract log(sum(exp(scores))) instead of dividing by sum(exp(scores)).
        // NOTE(review): assumes MathUtil.logSumExp is a numerically stable log-sum-exp.
        double logDenominator = MathUtil.logSumExp(scores);
        double[] probs = new double[NUM_CLASSES];
        for (int k = 0; k < NUM_CLASSES; k++) {
            probs[k] = Math.exp(scores[k] - logDenominator);
        }
        return probs;
    }

    /**
     * Returns the raw score of class {@code k} for one instance.
     *
     * @param vector the feature vector of the instance
     * @param k the class index, either 0 or 1
     * @return {@code 0.0} for class 0, otherwise the score of ensemble 0
     * @throws IllegalArgumentException if {@code k} is not a valid class index
     */
    @Override
    public double predictClassScore(Vector vector, int k) {
        if (k < 0 || k >= NUM_CLASSES) {
            throw new IllegalArgumentException(
                    "class index must be in [0, " + NUM_CLASSES + "), got " + k);
        }
        // Class 0 is the reference class; fixing its score removes the softmax's
        // shift redundancy, so only one ensemble is needed.
        if (k == 0) {
            return 0.0;
        }
        return this.getEnsemble(0).score(vector);
    }

    /**
     * Returns the number of classes; this model is always binary.
     *
     * @return {@code 2}
     */
    @Override
    public int getNumClasses() {
        return NUM_CLASSES;
    }

    /**
     * Returns the feature list associated with this model.
     *
     * @return the feature list, or {@code null} if it has not been set
     */
    @Override
    public FeatureList getFeatureList() {
        return this.featureList;
    }

    /**
     * Sets the feature list associated with this model (package-private, for the trainer).
     *
     * @param featureList the feature list to store
     */
    void setFeatureList(FeatureList featureList) {
        this.featureList = featureList;
    }

    /**
     * Returns the label translator associated with this model.
     *
     * @return the label translator, or {@code null} if it has not been set
     */
    @Override
    public LabelTranslator getLabelTranslator() {
        return this.labelTranslator;
    }

    /**
     * Sets the label translator associated with this model (package-private, for the trainer).
     *
     * @param labelTranslator the label translator to store
     */
    void setLabelTranslator(LabelTranslator labelTranslator) {
        this.labelTranslator = labelTranslator;
    }
}

