/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBMS;
import edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor;
import java.util.List;
import org.apache.mahout.math.Vector;

/**
 * Plug-in predictor that turns a {@link CBMS} model into multi-label predictions.
 *
 * <p>It scores a fixed candidate list of label combinations (the "support") with
 * {@link CBMS#predictAssignmentProbs} and hands the candidates plus their probabilities to
 * {@link GeneralF1Predictor}, which makes the final choice. Judging by the names, that choice
 * targets F1; confirm against {@code GeneralF1Predictor}.
 *
 * <p>The support must be provided before {@link #predict(Vector)} is called. Use either the
 * two-argument constructor or {@link #setSupport(List)}.
 */
public class CBMSF1Predictor
implements PluginPredictor<CBMS> {
    // Package-private: other classes in this package may read it directly.
    CBMS cbm;
    // Candidate label combinations scored at prediction time; null until supplied.
    private List<MultiLabel> support;

    /**
     * Creates a predictor without a support set.
     * {@link #setSupport(List)} must be called before {@link #predict(Vector)}.
     *
     * @param model the trained CBMS model used for scoring
     */
    public CBMSF1Predictor(CBMS model) {
        this.cbm = model;
    }

    /**
     * Creates a predictor with the given model and candidate support set.
     *
     * @param cbm     the trained CBMS model used for scoring
     * @param support candidate label combinations to score; stored by reference, not copied
     */
    public CBMSF1Predictor(CBMS cbm, List<MultiLabel> support) {
        this.cbm = cbm;
        this.support = support;
    }

    /**
     * Replaces the candidate support set used by subsequent predictions.
     *
     * @param support candidate label combinations to score; stored by reference, not copied
     */
    public void setSupport(List<MultiLabel> support) {
        this.support = support;
    }

    /**
     * Predicts a label set for one instance by scoring the support and applying
     * {@link GeneralF1Predictor}.
     *
     * @param vector feature vector of the instance
     * @return the label set chosen by {@link GeneralF1Predictor}
     * @throws IllegalStateException if no support set has been provided
     */
    @Override
    public MultiLabel predict(Vector vector) {
        return this.predictBySupport(vector);
    }

    /** Scores every support candidate for {@code vector} and lets the F1 rule pick one. */
    private MultiLabel predictBySupport(Vector vector) {
        // Fail fast with a clear message instead of an opaque NPE inside the model code
        // when the single-argument constructor was used and setSupport() was never called.
        if (this.support == null) {
            throw new IllegalStateException(
                    "support is not set; call setSupport(...) or use the two-argument constructor before predict()");
        }
        double[] probs = this.cbm.predictAssignmentProbs(vector, this.support);
        GeneralF1Predictor generalF1Predictor = new GeneralF1Predictor();
        return generalF1Predictor.predict(this.cbm.getNumClasses(), this.support, probs);
    }

    /**
     * Returns the wrapped model.
     *
     * @return the CBMS model this predictor was constructed with
     */
    @Override
    public CBMS getModel() {
        return this.cbm;
    }
}

