/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.thresholding;

import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import java.util.Arrays;
import java.util.Objects;
import org.apache.mahout.math.Vector;

/**
 * Multi-label classifier that applies an independent threshold to each class's marginal
 * probability.
 *
 * <p>For every class index {@code l}, the wrapped {@link MultiLabelClassifier.ClassProbEstimator}
 * supplies a probability; label {@code l} is included in the prediction iff that probability is
 * strictly greater than {@code thresholds[l]}. Thresholds are supplied via the constructor,
 * {@link #setThresholds(double[])} or {@link #setThresholdSameValue(double)}.
 */
public class TunedMarginalClassifier
implements MultiLabelClassifier {
    private static final long serialVersionUID = 1L;
    private final MultiLabelClassifier.ClassProbEstimator classProbEstimator;
    private double[] thresholds;

    /**
     * Creates a classifier whose thresholds are all {@code 0.0}, one per class reported by the
     * estimator.
     *
     * @param classProbEstimator source of per-class probabilities; must not be null
     * @throws NullPointerException if {@code classProbEstimator} is null
     */
    public TunedMarginalClassifier(MultiLabelClassifier.ClassProbEstimator classProbEstimator) {
        this.classProbEstimator = Objects.requireNonNull(classProbEstimator, "classProbEstimator");
        this.thresholds = new double[classProbEstimator.getNumClasses()];
    }

    /**
     * Creates a classifier with the given per-class thresholds.
     *
     * <p>The array is stored by reference (not copied), so later changes made to it by the caller
     * affect this classifier's predictions. Its length is not validated here; {@link #predict}
     * reads one entry per class and fails if fewer than {@link #getNumClasses()} are present.
     *
     * @param classProbEstimator source of per-class probabilities; must not be null
     * @param thresholds per-class decision thresholds, indexed by class
     * @throws NullPointerException if {@code classProbEstimator} is null
     */
    public TunedMarginalClassifier(MultiLabelClassifier.ClassProbEstimator classProbEstimator, double[] thresholds) {
        this.classProbEstimator = Objects.requireNonNull(classProbEstimator, "classProbEstimator");
        this.thresholds = thresholds;
    }

    /**
     * Returns the live threshold array (not a copy); writes to it change future predictions.
     *
     * @return the per-class thresholds, indexed by class
     */
    public double[] getThresholds() {
        return this.thresholds;
    }

    /**
     * Replaces the threshold array. The array is stored by reference, not copied.
     *
     * @param thresholds per-class decision thresholds, indexed by class
     */
    public void setThresholds(double[] thresholds) {
        this.thresholds = thresholds;
    }

    /**
     * Sets every entry of the current threshold array to {@code threshold}, in place.
     *
     * @param threshold value assigned to every class's threshold
     */
    public void setThresholdSameValue(double threshold) {
        Arrays.fill(this.thresholds, threshold);
    }

    /**
     * Returns the number of classes, as reported by the wrapped estimator.
     */
    @Override
    public int getNumClasses() {
        return this.classProbEstimator.getNumClasses();
    }

    /**
     * Predicts the label set for one instance by thresholding each class probability.
     *
     * @param vector feature vector of the instance
     * @return a label set containing every class {@code l} with
     *     {@code probs[l] > thresholds[l]}
     */
    @Override
    public MultiLabel predict(Vector vector) {
        MultiLabel multiLabel = new MultiLabel();
        int numClasses = this.classProbEstimator.getNumClasses();
        double[] probs = this.classProbEstimator.predictClassProbs(vector);
        for (int l = 0; l < numClasses; ++l) {
            // Strict comparison: a probability equal to the threshold, or NaN, is not predicted.
            if (probs[l] > this.thresholds[l]) {
                multiLabel.addLabel(l);
            }
        }
        return multiLabel;
    }

    /**
     * Not provided by this wrapper.
     *
     * @return always {@code null}
     */
    @Override
    public FeatureList getFeatureList() {
        return null;
    }

    /**
     * Not provided by this wrapper.
     *
     * @return always {@code null}
     */
    @Override
    public LabelTranslator getLabelTranslator() {
        return null;
    }
}

