/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.imlgb;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting;
import org.apache.mahout.math.Vector;

/**
 * {@link PluginPredictor} backed by an {@link IMLGradientBoosting} model that
 * decides every label independently: a label is included in the prediction
 * exactly when the model's score for that class is strictly positive.
 *
 * <p>NOTE(review): the class name suggests this is the Hamming-loss oriented
 * decision rule; that intent is inferred from the name only - confirm against
 * the project documentation.
 */
public class HammingPredictor implements PluginPredictor<IMLGradientBoosting> {
    private static final long serialVersionUID = 1L;

    /** The boosting model that supplies class scores and assignment probabilities. */
    private IMLGradientBoosting imlGradientBoosting;

    /**
     * Creates a predictor that delegates all scoring to the given model.
     *
     * @param imlGradientBoosting the underlying boosting model
     */
    public HammingPredictor(IMLGradientBoosting imlGradientBoosting) {
        this.imlGradientBoosting = imlGradientBoosting;
    }

    /**
     * Returns the wrapped boosting model.
     *
     * @return the model passed to the constructor
     */
    @Override
    public IMLGradientBoosting getModel() {
        return imlGradientBoosting;
    }

    /**
     * Builds a label set by testing each class score against zero.
     *
     * @param vector the feature vector to classify
     * @return a new {@link MultiLabel} containing every class index whose
     *         score is strictly greater than {@code 0.0}
     */
    @Override
    public MultiLabel predict(Vector vector) {
        MultiLabel positiveLabels = new MultiLabel();
        for (int classIndex = 0; classIndex < getNumClasses(); classIndex++) {
            // A NaN score compares false here, so it is never added.
            if (imlGradientBoosting.predictClassScore(vector, classIndex) > 0.0) {
                positiveLabels.addLabel(classIndex);
            }
        }
        return positiveLabels;
    }

    /**
     * Estimates the probability of a complete label assignment.
     *
     * @param vector     the feature vector being scored
     * @param assignment the candidate label set
     * @return {@code 0.0} when {@code assignment.outOfBound(numClasses)} is
     *         true for the model's class count; otherwise the value of the
     *         model's {@code predictAssignmentProbWithoutConstraint}
     */
    public double predictAssignmentProb(Vector vector, MultiLabel assignment) {
        // Presumably outOfBound flags label indices the model does not know
        // about - verify against MultiLabel.
        return assignment.outOfBound(imlGradientBoosting.getNumClasses())
                ? 0.0
                : imlGradientBoosting.predictAssignmentProbWithoutConstraint(vector, assignment);
    }
}

