/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.BMDistribution;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM;
import org.apache.mahout.math.Vector;

/**
 * {@link PluginPredictor} for a {@link CBM} that decides each label independently:
 * a label is included in the prediction when its entry in the array returned by
 * {@link BMDistribution#marginals()} is strictly greater than 0.5.
 */
public class MarginalPredictor implements PluginPredictor<CBM> {
    /** The model this predictor queries. */
    CBM cbm;

    /**
     * Value handed to the {@link BMDistribution} constructor on every prediction.
     * NOTE(review): presumably a cutoff on mixture-component weights; confirm against
     * {@code BMDistribution}.
     */
    double piThreshold = 0.001;

    /**
     * Creates a predictor backed by the given model.
     *
     * @param cbm the model to predict with
     */
    public MarginalPredictor(CBM cbm) {
        this.cbm = cbm;
    }

    /**
     * Replaces the threshold passed to {@link BMDistribution} for subsequent predictions.
     *
     * @param piThreshold the new threshold value
     */
    public void setPiThreshold(double piThreshold) {
        this.piThreshold = piThreshold;
    }

    /**
     * @return the model supplied at construction time
     */
    @Override
    public CBM getModel() {
        return this.cbm;
    }

    /**
     * Predicts the label set for one instance.
     *
     * @param vector the instance to classify
     * @return a {@link MultiLabel} containing every class index whose marginal value
     *         is strictly greater than 0.5
     */
    @Override
    public MultiLabel predict(Vector vector) {
        BMDistribution distribution = new BMDistribution(this.cbm, vector, this.piThreshold);
        double[] marginals = distribution.marginals();
        MultiLabel result = new MultiLabel();

        int numClasses = this.cbm.getNumClasses();
        int label = 0;
        while (label < numClasses) {
            // Strict comparison: a value of exactly 0.5 (or NaN) leaves the label out.
            if (marginals[label] > 0.5) {
                result.addLabel(label);
            }
            label++;
        }
        return result;
    }
}

