/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.imlgb;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting;
import org.apache.mahout.math.Vector;

/**
 * Plugin predictor for an {@link IMLGradientBoosting} model that picks, for each input,
 * the highest-scoring assignment among the model's legal assignments.
 *
 * <p>Presumably intended to target subset accuracy (per the class name); the selection
 * criterion itself is whatever {@link IMLGradientBoosting#calAssignmentScore} computes.
 */
public class SubsetAccPredictor
implements PluginPredictor<IMLGradientBoosting> {
    private static final long serialVersionUID = 1L;
    private final IMLGradientBoosting imlGradientBoosting;

    /**
     * Creates a predictor backed by the given model.
     *
     * @param imlGradientBoosting the trained model whose scores and legal assignments are used
     */
    public SubsetAccPredictor(IMLGradientBoosting imlGradientBoosting) {
        this.imlGradientBoosting = imlGradientBoosting;
    }

    /**
     * Returns the model this predictor wraps.
     *
     * @return the {@link IMLGradientBoosting} passed to the constructor
     */
    @Override
    public IMLGradientBoosting getModel() {
        // Previously returned null, which broke any caller relying on PluginPredictor#getModel.
        return this.imlGradientBoosting;
    }

    /**
     * Predicts the legal assignment with the highest score for the given input.
     *
     * <p>Class scores are computed once with {@code predictClassScores}. Each legal assignment
     * is then scored with {@code calAssignmentScore}. When several assignments share the maximum
     * score, the first one in iteration order is returned.
     *
     * @param vector the input feature vector
     * @return the highest-scoring legal assignment, or {@code null} if there are no legal
     *     assignments or none has a score greater than negative infinity
     * @throws IllegalStateException if the model has no legal assignments configured
     */
    @Override
    public MultiLabel predict(Vector vector) {
        if (this.imlGradientBoosting.getAssignments() == null) {
            throw new IllegalStateException("CRF is used but legal assignments is not specified!");
        }
        double maxScore = Double.NEGATIVE_INFINITY;
        MultiLabel prediction = null;
        // Class scores do not depend on the candidate assignment, so compute them once.
        double[] classScores = this.imlGradientBoosting.predictClassScores(vector);
        for (MultiLabel assignment : this.imlGradientBoosting.getAssignments()) {
            double score = this.imlGradientBoosting.calAssignmentScore(assignment, classScores);
            // Strict '>' keeps the first maximum on ties and never selects a NaN score.
            if (score > maxScore) {
                maxScore = score;
                prediction = assignment;
            }
        }
        return prediction;
    }

    /**
     * Returns the model's probability for a specific assignment, restricted to legal assignments.
     *
     * @param vector the input feature vector
     * @param assignment the candidate label assignment
     * @return 0.0 if {@code assignment} references labels outside the model's class range;
     *     otherwise the value of {@code predictAssignmentProbWithConstraint}
     */
    public double predictAssignmentProb(Vector vector, MultiLabel assignment) {
        // Labels the model does not know about cannot have any probability mass.
        if (assignment.outOfBound(this.imlGradientBoosting.getNumClasses())) {
            return 0.0;
        }
        return this.imlGradientBoosting.predictAssignmentProbWithConstraint(vector, assignment);
    }
}

