/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.imlgb;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting;
import edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.mahout.math.Vector;

/**
 * {@link PluginPredictor} backed by an {@link IMLGradientBoosting} model.
 *
 * <p>It turns the model's per-assignment probabilities into a single {@link MultiLabel} by
 * delegating the final decision to {@link GeneralF1Predictor}. Judging by the class name, the
 * intent is to pick the label set that maximizes expected instance-level F1. The actual decision
 * rule lives in {@code GeneralF1Predictor} — NOTE(review): confirm there.
 */
public class InstanceF1Predictor
implements PluginPredictor<IMLGradientBoosting> {
    private static final long serialVersionUID = 1L;

    /** Model queried for assignment probabilities on every {@link #predict(Vector)} call. */
    private final IMLGradientBoosting imlGradientBoosting;

    /**
     * Creates a predictor around an existing model.
     *
     * @param imlGradientBoosting the model to predict with; the reference is stored as-is, not copied
     */
    public InstanceF1Predictor(IMLGradientBoosting imlGradientBoosting) {
        this.imlGradientBoosting = imlGradientBoosting;
    }

    /**
     * Returns the wrapped model.
     *
     * @return the same instance that was passed to the constructor
     */
    @Override
    public IMLGradientBoosting getModel() {
        return imlGradientBoosting;
    }

    /**
     * Predicts a label set for one instance.
     *
     * <p>Steps:
     * <ol>
     *   <li>Ask the model for the probability of each assignment via
     *       {@code predictAllAssignmentProbsWithConstraint}.</li>
     *   <li>Box those probabilities into a {@code List<Double>}.</li>
     *   <li>Hand them to a fresh {@link GeneralF1Predictor}, together with the model's class count
     *       and assignments.</li>
     * </ol>
     * Presumably the i-th probability belongs to the i-th entry of {@code getAssignments()} —
     * verify against {@code IMLGradientBoosting}.
     *
     * @param vector feature vector of the instance to classify
     * @return the label set chosen by {@code GeneralF1Predictor}
     */
    @Override
    public MultiLabel predict(Vector vector) {
        final IMLGradientBoosting model = this.imlGradientBoosting;
        final double[] assignmentProbs = model.predictAllAssignmentProbsWithConstraint(vector);
        // GeneralF1Predictor expects boxed values, so convert the primitive array to a List<Double>.
        final List<Double> boxedProbs = Arrays.stream(assignmentProbs)
                .boxed()
                .collect(Collectors.toList());
        return new GeneralF1Predictor()
                .predict(model.getNumClasses(), model.getAssignments(), boxedProbs);
    }
}

