/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM;
import edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor;
import java.util.List;
import org.apache.mahout.math.Vector;

/**
 * {@link PluginPredictor} for a {@link CBM} that delegates the choice of the
 * final label set to {@link GeneralF1Predictor}.
 *
 * <p>Two prediction modes are recognised by {@link #predict(Vector)}:
 * <ul>
 *   <li>{@code "support"} (the default) - scores every label combination in a
 *       caller-supplied support list with the CBM and hands the list and the
 *       scores to {@code GeneralF1Predictor};</li>
 *   <li>{@code "sampling"} - draws {@code numSamples} label combinations from
 *       the CBM and hands those samples to {@code GeneralF1Predictor}.</li>
 * </ul>
 *
 * <p>The {@code "support"} mode needs a support list, which is only set by the
 * two-argument constructor or by {@link #setSupport(List)}.
 */
public class PluginF1
implements PluginPredictor<CBM> {
    /** Model that is queried for samples / assignment probabilities. */
    CBM cbm;
    /** Either {@code "support"} or {@code "sampling"}; anything else makes {@link #predict(Vector)} throw. */
    private String predictionMode = "support";
    /** Number of samples requested from the CBM in {@code "sampling"} mode. */
    private int numSamples = 1000;
    /** Candidate label combinations used in {@code "support"} mode; may be null until configured. */
    private List<MultiLabel> support;
    /** Threshold forwarded to {@code CBM.predictAssignmentProbs} in {@code "support"} mode. */
    private double piThreshold = 0.001;
    /** Value forwarded to {@code GeneralF1Predictor.setMaxSize} before predicting. */
    private int maxSize = 20;

    /**
     * Sets the value passed to {@code GeneralF1Predictor.setMaxSize}.
     *
     * @param maxSize value forwarded to the F1 predictor
     */
    public void setMaxSize(int maxSize) {
        this.maxSize = maxSize;
    }

    /**
     * Creates a predictor without a support list. Because the default mode is
     * {@code "support"}, call {@link #setSupport(List)} or switch to
     * {@code "sampling"} via {@link #setPredictionMode(String)} before predicting.
     *
     * @param model the CBM to predict with
     */
    public PluginF1(CBM model) {
        this.cbm = model;
    }

    /**
     * Creates a predictor that is ready for {@code "support"} mode.
     *
     * @param cbm     the CBM to predict with
     * @param support candidate label combinations to score
     */
    public PluginF1(CBM cbm, List<MultiLabel> support) {
        this.cbm = cbm;
        this.support = support;
    }

    /**
     * Sets how many samples are requested from the CBM in {@code "sampling"} mode.
     *
     * @param numSamples number of samples to draw
     */
    public void setNumSamples(int numSamples) {
        this.numSamples = numSamples;
    }

    /**
     * Selects the prediction mode. The value is not validated here; an
     * unrecognised mode causes {@link #predict(Vector)} to throw
     * {@link IllegalArgumentException}.
     *
     * @param predictionMode {@code "support"} or {@code "sampling"}
     */
    public void setPredictionMode(String predictionMode) {
        this.predictionMode = predictionMode;
    }

    /**
     * @return the currently configured prediction mode
     */
    public String getPredictionMode() {
        return this.predictionMode;
    }

    /**
     * Sets the candidate label combinations used in {@code "support"} mode.
     *
     * @param support candidate label combinations to score
     */
    public void setSupport(List<MultiLabel> support) {
        this.support = support;
    }

    /**
     * Sets the threshold forwarded to {@code CBM.predictAssignmentProbs} in
     * {@code "support"} mode.
     *
     * @param piThreshold threshold value
     */
    public void setPiThreshold(double piThreshold) {
        this.piThreshold = piThreshold;
    }

    /**
     * Predicts a label set for one instance using the configured mode.
     *
     * @param vector feature vector of the instance
     * @return the label set chosen by {@code GeneralF1Predictor}
     * @throws IllegalArgumentException if the prediction mode is neither
     *                                  {@code "support"} nor {@code "sampling"}
     * @throws IllegalStateException    if the mode is {@code "support"} and no
     *                                  support list has been provided
     */
    @Override
    public MultiLabel predict(Vector vector) {
        switch (this.predictionMode) {
            case "support": {
                return this.predictBySupport(vector);
            }
            case "sampling": {
                return this.predictBySampling(vector);
            }
            default: {
                throw new IllegalArgumentException("unknown mode");
            }
        }
    }

    /** Draws {@code numSamples} label combinations from the CBM and lets the F1 predictor choose from them. */
    private MultiLabel predictBySampling(Vector vector) {
        List<MultiLabel> samples = this.cbm.samples(vector, this.numSamples);
        GeneralF1Predictor generalF1Predictor = new GeneralF1Predictor();
        generalF1Predictor.setMaxSize(this.maxSize);
        return generalF1Predictor.predict(this.cbm.getNumClasses(), samples);
    }

    /** Scores every support combination with the CBM (using {@code piThreshold}) and lets the F1 predictor choose. */
    private MultiLabel predictBySupport(Vector vector) {
        List<MultiLabel> candidates = this.requireSupport();
        double[] probs = this.cbm.predictAssignmentProbs(vector, candidates, this.piThreshold);
        GeneralF1Predictor generalF1Predictor = new GeneralF1Predictor();
        generalF1Predictor.setMaxSize(this.maxSize);
        return generalF1Predictor.predict(this.cbm.getNumClasses(), candidates, probs);
    }

    /**
     * Runs a support-based prediction and returns the analysis produced by
     * {@code GeneralF1Predictor.showSupportPrediction} for it.
     *
     * <p>NOTE(review): unlike {@link #predict(Vector)} in {@code "support"}
     * mode, this path calls the two-argument {@code predictAssignmentProbs}
     * overload (no {@code piThreshold}) and does not call {@code setMaxSize},
     * so the analysed prediction may differ from what {@code predict} returns.
     * Confirm whether that is intended.
     *
     * @param vector feature vector of the instance
     * @param truth  ground-truth label set passed through to the analysis
     * @return the analysis object built by {@code GeneralF1Predictor}
     * @throws IllegalStateException if no support list has been provided
     */
    public GeneralF1Predictor.Analysis showPredictBySupport(Vector vector, MultiLabel truth) {
        List<MultiLabel> candidates = this.requireSupport();
        double[] probArray = this.cbm.predictAssignmentProbs(vector, candidates);
        GeneralF1Predictor generalF1Predictor = new GeneralF1Predictor();
        MultiLabel prediction = generalF1Predictor.predict(this.cbm.getNumClasses(), candidates, probArray);
        return GeneralF1Predictor.showSupportPrediction(candidates, probArray, truth, prediction, this.cbm.getNumClasses());
    }

    /**
     * @return the CBM this predictor was constructed with
     */
    @Override
    public CBM getModel() {
        return this.cbm;
    }

    /**
     * Returns the configured support list, failing fast with an actionable
     * message when it is missing instead of passing null on to the CBM and
     * the F1 predictor.
     */
    private List<MultiLabel> requireSupport() {
        if (this.support == null) {
            throw new IllegalStateException("support is not set; use the (CBM, List<MultiLabel>) constructor or call setSupport before predicting by support");
        }
        return this.support;
    }
}

