/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.classification.Classifier;
import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.AugmentedLR;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.BMDistribution;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBMPredictor;
import edu.neu.ccs.pyramid.util.ArgSort;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.apache.mahout.math.Vector;

/**
 * Multi-label classifier composed of one multi-class model over mixture components and one
 * binary model per label.
 *
 * <p>The constructor builds a {@link LogisticRegression} with {@code numComponents} classes
 * ({@link #multiClassClassifier}) and one {@link AugmentedLR} per label
 * ({@link #binaryClassifiers}). Every probabilistic query is delegated to a
 * {@link BMDistribution} constructed from this model and the input vector. Judging by the
 * package name this is a conditional Bernoulli mixture (CBM) variant — NOTE(review): confirm
 * against the training code.
 */
public class CBMS implements MultiLabelClassifier.ClassProbEstimator, Serializable {
    private static final long serialVersionUID = 2L;
    /** Number of labels; also the length of {@link #binaryClassifiers}. */
    int numLabels;
    /** Number of mixture components; the class count of {@link #multiClassClassifier}. */
    int numComponents;
    /** Number of samples handed to {@link CBMPredictor#setNumSamples(int)} in dynamic mode. */
    private int numSample = 100;
    /** Whether the dynamic predictor may return an empty label set. */
    private boolean allowEmpty = false;
    /**
     * Prediction strategy used by {@link #predict(Vector)}: {@code "marginals"} or
     * {@code "dynamic"} (the default; also used for any unrecognized value).
     */
    private String predictMode = "dynamic";
    /** One binary model per label, indexed by label id. */
    AugmentedLR[] binaryClassifiers;
    /** Multi-class model over the {@link #numComponents} components. */
    Classifier.ProbabilityEstimator multiClassClassifier;

    /**
     * Creates an untrained model.
     *
     * @param numLabels     number of labels (one {@link AugmentedLR} is created per label)
     * @param numComponents number of mixture components
     * @param numFeatures   input feature dimension passed to every sub-model
     */
    public CBMS(int numLabels, int numComponents, int numFeatures) {
        this.numLabels = numLabels;
        this.numComponents = numComponents;
        this.binaryClassifiers = new AugmentedLR[numLabels];
        for (int l = 0; l < numLabels; ++l) {
            this.binaryClassifiers[l] = new AugmentedLR(numFeatures, numComponents);
        }
        this.multiClassClassifier = new LogisticRegression(numComponents, numFeatures);
    }

    /** Returns the number of labels. */
    @Override
    public int getNumClasses() {
        return this.numLabels;
    }

    /**
     * Returns {@link BMDistribution#posteriorMembership(MultiLabel)} for {@code x} and
     * {@code y} — presumably one posterior weight per component (TODO confirm).
     */
    double[] posteriorMembership(Vector x, MultiLabel y) {
        BMDistribution bmDistribution = this.computeBM(x);
        return bmDistribution.posteriorMembership(y);
    }

    /** Builds a fresh {@link BMDistribution} for input {@code x}; nothing is cached. */
    BMDistribution computeBM(Vector x) {
        return new BMDistribution(this, x);
    }

    /** Log-probability of the label set {@code y} given {@code x}. */
    private double predictLogAssignmentProb(Vector x, MultiLabel y) {
        BMDistribution bmDistribution = this.computeBM(x);
        return bmDistribution.logProbability(y);
    }

    /**
     * Probability of the label set {@code assignment} given {@code vector}, computed as
     * {@code exp} of the distribution's log-probability.
     */
    public double predictAssignmentProb(Vector vector, MultiLabel assignment) {
        return Math.exp(this.predictLogAssignmentProb(vector, assignment));
    }

    /**
     * Log-probabilities of several label sets, sharing one {@link BMDistribution}.
     *
     * @return array aligned with {@code assignments}
     */
    private double[] predictLogAssignmentProbs(Vector x, List<MultiLabel> assignments) {
        BMDistribution bmDistribution = this.computeBM(x);
        double[] probs = new double[assignments.size()];
        for (int c = 0; c < assignments.size(); ++c) {
            probs[c] = bmDistribution.logProbability(assignments.get(c));
        }
        return probs;
    }

    /**
     * Probabilities of several label sets given {@code vector}.
     *
     * @return array aligned with {@code assignments}
     */
    public double[] predictAssignmentProbs(Vector vector, List<MultiLabel> assignments) {
        double[] logProbs = this.predictLogAssignmentProbs(vector, assignments);
        return Arrays.stream(logProbs).map(Math::exp).toArray();
    }

    /**
     * Predicts a label set for {@code vector}.
     *
     * <p>If the predict mode is {@code "marginals"}, the labels whose marginal probability
     * exceeds 0.5 are returned (see {@link #predictByMarginals(Vector)}); the sample count and
     * allow-empty settings do not apply in that mode. For any other mode, including the default
     * {@code "dynamic"}, a {@link CBMPredictor} configured with the current sample count and
     * allow-empty flag is used via {@link CBMPredictor#predictByDynamic()}.
     */
    @Override
    public MultiLabel predict(Vector vector) {
        // Previously predictMode was stored but never consulted, so setPredictMode had no effect.
        if ("marginals".equals(this.predictMode)) {
            return this.predictByMarginals(vector);
        }
        CBMPredictor predictor = new CBMPredictor(this.computeBM(vector));
        predictor.setNumSamples(this.numSample);
        predictor.setAllowEmpty(this.allowEmpty);
        return predictor.predictByDynamic();
    }

    /** Returns {@link BMDistribution#marginals()} for {@code vector}, one entry per label. */
    @Override
    public double[] predictClassProbs(Vector vector) {
        BMDistribution bmDistribution = this.computeBM(vector);
        return bmDistribution.marginals();
    }

    /**
     * Predicts every label whose marginal probability is strictly greater than 0.5.
     * May return an empty label set.
     */
    MultiLabel predictByMarginals(Vector vector) {
        double[] probs = this.predictClassProbs(vector);
        MultiLabel prediction = new MultiLabel();
        for (int l = 0; l < this.numLabels; ++l) {
            // NaN compares false, so NaN marginals are never selected.
            if (probs[l] > 0.5) {
                prediction.addLabel(l);
            }
        }
        return prediction;
    }

    /**
     * Predicts the {@code top} labels with the highest marginal probabilities.
     *
     * @param vector input features
     * @param top    number of labels to return; values larger than the number of labels are
     *               clamped, and values {@code <= 0} yield an empty label set
     * @return the selected labels
     */
    public MultiLabel predictByMarginals(Vector vector, int top) {
        double[] probs = this.predictClassProbs(vector);
        int[] sortedIndices = ArgSort.argSortDescending(probs);
        // Clamp so that requesting more labels than exist cannot overrun the index array.
        int limit = Math.min(top, sortedIndices.length);
        MultiLabel prediction = new MultiLabel();
        for (int i = 0; i < limit; ++i) {
            prediction.addLabel(sortedIndices[i]);
        }
        return prediction;
    }

    /**
     * Sets the strategy used by {@link #predict(Vector)}: {@code "marginals"} or
     * {@code "dynamic"}. Unrecognized values fall back to {@code "dynamic"}.
     */
    public void setPredictMode(String mode) {
        this.predictMode = mode;
    }

    /** Sets whether the dynamic predictor may return an empty label set. */
    public void setAllowEmpty(boolean allowEmpty) {
        this.allowEmpty = allowEmpty;
    }

    /** Not tracked by this class; always returns {@code null}. */
    @Override
    public FeatureList getFeatureList() {
        return null;
    }

    /** Not tracked by this class; always returns {@code null}. */
    @Override
    public LabelTranslator getLabelTranslator() {
        return null;
    }

    /** Sets the sample count passed to the dynamic predictor. */
    public void setNumSample(int numSample) {
        this.numSample = numSample;
    }

    /** Draws {@code numSamples} label sets from the distribution for {@code x}. */
    public List<MultiLabel> samples(Vector x, int numSamples) {
        BMDistribution bmDistribution = this.computeBM(x);
        return bmDistribution.sample(numSamples);
    }

    /** Returns the internal per-label model array (not a copy). */
    public AugmentedLR[] getBinaryClassifiers() {
        return this.binaryClassifiers;
    }

    /** Returns the internal multi-class component model (not a copy). */
    public Classifier.ProbabilityEstimator getMultiClassClassifier() {
        return this.multiClassClassifier;
    }

    /** Returns the number of mixture components. */
    public int getNumComponents() {
        return this.numComponents;
    }
}

