/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.classification.Classifier;
import edu.neu.ccs.pyramid.classification.lkboost.LKBoost;
import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.BMDistribution;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBMPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.ShortCircuitPosterior;
import edu.neu.ccs.pyramid.util.ArgSort;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

/**
 * Multi-label classifier made of one multi-class probability estimator over
 * {@code numComponents} components plus a {@code numComponents x numLabels}
 * grid of binary probability estimators (one per component and label).
 *
 * <p>All probability computations are delegated to {@link BMDistribution},
 * which is built from this model and a feature vector. Instances are created
 * through {@link #getBuilder()}; the builder allocates the binary-estimator
 * grid but leaves its entries {@code null}, so they have to be filled in by
 * whoever trains the model.
 *
 * <p>NOTE(review): "CBM" presumably stands for "conditional Bernoulli
 * mixture" - confirm against the project documentation.
 */
public class CBM
implements MultiLabelClassifier.ClassProbEstimator,
MultiLabelClassifier.AssignmentProbEstimator,
Serializable {
    private static final long serialVersionUID = 2L;

    /**
     * Value passed as {@code piThreshold} to {@link BMDistribution} when
     * computing per-label marginals in {@link #predictClassProbs(Vector)}.
     * NOTE(review): presumably components whose mixture weight falls below
     * this threshold are ignored - confirm in BMDistribution.
     */
    private static final double MARGINAL_PI_THRESHOLD = 0.1;

    /** Number of labels (classes) the model predicts. */
    int numLabels;
    /** Number of mixture components. */
    int numComponents;
    /** Dimensionality of the feature vectors. */
    private int numFeatures;
    /** Sample count forwarded to {@link CBMPredictor#setNumSamples}. */
    private int numSample = 100;
    /** Forwarded to {@link CBMPredictor#setAllowEmpty}. */
    private boolean allowEmpty = false;
    /** One of "support", "marginal", "dynamic" or "hard"; see {@link #predict(Vector)}. */
    private String predictMode = "dynamic";
    /** Candidate label combinations used by the "support" predict mode; may be null. */
    private List<MultiLabel> support;
    /** Binary estimators indexed as {@code [component][label]}. */
    Classifier.ProbabilityEstimator[][] binaryClassifiers;
    /** Estimator producing the mixture proportions over components. */
    Classifier.ProbabilityEstimator multiClassClassifier;
    private String binaryClassifierType;
    private String multiClassClassifierType;

    /** Instances are only created through {@link Builder#build()}. */
    private CBM() {
    }

    /** @return the binary classifier type name given to the builder */
    public String getBinaryClassifierType() {
        return this.binaryClassifierType;
    }

    /** @return the multi-class classifier type name given to the builder */
    public String getMultiClassClassifierType() {
        return this.multiClassClassifierType;
    }

    /** @return the number of labels */
    @Override
    public int getNumClasses() {
        return this.numLabels;
    }

    /**
     * Computes the posterior component membership of {@code y} given {@code x}
     * by delegating to {@link BMDistribution#posteriorMembership}.
     */
    double[] posteriorMembership(Vector x, MultiLabel y) {
        return this.computeBM(x).posteriorMembership(y);
    }

    /**
     * Same quantity as {@link #posteriorMembership(Vector, MultiLabel)} but
     * computed through {@link ShortCircuitPosterior}.
     */
    double[] posteriorMembershipShortCircuit(Vector x, MultiLabel y) {
        return new ShortCircuitPosterior(this, x, y).posteriorMembership();
    }

    /**
     * Posterior component membership with per-label noise weights, delegated
     * to {@link BMDistribution#posteriorMembership}.
     */
    double[] posteriorMembership(Vector x, MultiLabel y, double[] noiseLabelWeights) {
        return this.computeBM(x).posteriorMembership(y, noiseLabelWeights);
    }

    /**
     * Builds the distribution over label combinations for one feature vector.
     *
     * @param x feature vector
     * @return a new {@link BMDistribution} for this model and {@code x}
     */
    public BMDistribution computeBM(Vector x) {
        return new BMDistribution(this, x);
    }

    /**
     * @param x feature vector
     * @param y label combination
     * @return the log-probability of {@code y} given {@code x}
     */
    @Override
    public double predictLogAssignmentProb(Vector x, MultiLabel y) {
        return this.computeBM(x).logProbability(y);
    }

    /**
     * @param vector feature vector
     * @param assignment label combination
     * @return the probability of {@code assignment}, i.e. the exponential of
     *         {@link #predictLogAssignmentProb(Vector, MultiLabel)}
     */
    @Override
    public double predictAssignmentProb(Vector vector, MultiLabel assignment) {
        return Math.exp(this.predictLogAssignmentProb(vector, assignment));
    }

    /**
     * Scores several label combinations against a single distribution, so the
     * distribution is built only once.
     *
     * @param x feature vector
     * @param assignments label combinations to score
     * @return log-probabilities, in the order of {@code assignments}
     */
    @Override
    public double[] predictLogAssignmentProbs(Vector x, List<MultiLabel> assignments) {
        return logProbabilities(this.computeBM(x), assignments);
    }

    /**
     * List-returning variant of {@link #predictLogAssignmentProbs(Vector, List)}.
     *
     * @param x feature vector
     * @param assignments label combinations to score
     * @return log-probabilities, in the iteration order of {@code assignments}
     */
    public List<Double> predictLogAssignmentProbsAsList(Vector x, List<MultiLabel> assignments) {
        BMDistribution bmDistribution = this.computeBM(x);
        // Presize: exactly one entry is added per assignment.
        List<Double> logProbs = new ArrayList<Double>(assignments.size());
        for (MultiLabel assignment : assignments) {
            logProbs.add(bmDistribution.logProbability(assignment));
        }
        return logProbs;
    }

    /**
     * @param vector feature vector
     * @param assignments label combinations to score
     * @return probabilities, in the order of {@code assignments}
     */
    @Override
    public double[] predictAssignmentProbs(Vector vector, List<MultiLabel> assignments) {
        double[] logProbs = this.predictLogAssignmentProbs(vector, assignments);
        return Arrays.stream(logProbs).map(Math::exp).toArray();
    }

    /**
     * Variant of {@link #predictLogAssignmentProb(Vector, MultiLabel)} that
     * builds the distribution with an explicit {@code piThreshold}.
     *
     * @param x feature vector
     * @param y label combination
     * @param piThreshold threshold forwarded to {@link BMDistribution}
     * @return the log-probability of {@code y}
     */
    public double predictLogAssignmentProb(Vector x, MultiLabel y, double piThreshold) {
        return new BMDistribution(this, x, piThreshold).logProbability(y);
    }

    /**
     * @param vector feature vector
     * @param assignment label combination
     * @param piThreshold threshold forwarded to {@link BMDistribution}
     * @return the probability of {@code assignment}
     */
    public double predictAssignmentProb(Vector vector, MultiLabel assignment, double piThreshold) {
        return Math.exp(this.predictLogAssignmentProb(vector, assignment, piThreshold));
    }

    /**
     * @param x feature vector
     * @param assignments label combinations to score
     * @param piThreshold threshold forwarded to {@link BMDistribution}
     * @return log-probabilities, in the order of {@code assignments}
     */
    public double[] predictLogAssignmentProbs(Vector x, List<MultiLabel> assignments, double piThreshold) {
        return logProbabilities(new BMDistribution(this, x, piThreshold), assignments);
    }

    /**
     * @param vector feature vector
     * @param assignments label combinations to score
     * @param piThreshold threshold forwarded to {@link BMDistribution}
     * @return probabilities, in the order of {@code assignments}
     */
    public double[] predictAssignmentProbs(Vector vector, List<MultiLabel> assignments, double piThreshold) {
        double[] logProbs = this.predictLogAssignmentProbs(vector, assignments, piThreshold);
        return Arrays.stream(logProbs).map(Math::exp).toArray();
    }

    /** Evaluates the log-probability of every assignment under one distribution. */
    private static double[] logProbabilities(BMDistribution bmDistribution, List<MultiLabel> assignments) {
        double[] logProbs = new double[assignments.size()];
        for (int i = 0; i < logProbs.length; ++i) {
            logProbs[i] = bmDistribution.logProbability(assignments.get(i));
        }
        return logProbs;
    }

    /**
     * Predicts a label combination using the strategy selected with
     * {@link #setPredictMode(String)}: "support", "marginal", "dynamic"
     * (the default) or "hard".
     *
     * @param vector feature vector
     * @return the predicted label combination
     * @throws RuntimeException if the predict mode is not one of the four above
     * @throws IllegalStateException in "support" mode when no support was configured
     */
    @Override
    public MultiLabel predict(Vector vector) {
        switch (this.predictMode) {
            case "support": {
                return this.predictBySupport(vector);
            }
            case "marginal": {
                return this.predictByMarginals(vector);
            }
            case "dynamic": {
                return this.newPredictor(vector).predictByDynamic();
            }
            case "hard": {
                return this.newPredictor(vector).predictByHardAssignment();
            }
            default: {
                throw new RuntimeException("Unknown predictMode: " + this.predictMode);
            }
        }
    }

    /** Creates a {@link CBMPredictor} for {@code vector} configured with this model's settings. */
    private CBMPredictor newPredictor(Vector vector) {
        CBMPredictor predictor = new CBMPredictor(this.computeBM(vector));
        predictor.setNumSamples(this.numSample);
        predictor.setAllowEmpty(this.allowEmpty);
        return predictor;
    }

    /**
     * Returns the label combination from {@code support} with the highest
     * log-probability for {@code vector}; ties go to the earliest entry.
     * The returned object is the element stored in the support list, not a copy.
     *
     * @throws IllegalStateException if no (or an empty) support was configured
     */
    private MultiLabel predictBySupport(Vector vector) {
        if (this.support == null || this.support.isEmpty()) {
            throw new IllegalStateException(
                    "predictMode \"support\" requires a non-empty support; provide one via Builder.setSupport");
        }
        double[] logProbs = this.predictLogAssignmentProbs(vector, this.support);
        int best = 0;
        for (int i = 1; i < logProbs.length; ++i) {
            if (logProbs[i] > logProbs[best]) {
                best = i;
            }
        }
        return this.support.get(best);
    }

    /**
     * @param vector feature vector
     * @return the per-label marginals of a {@link BMDistribution} built with
     *         {@link #MARGINAL_PI_THRESHOLD}
     */
    @Override
    public double[] predictClassProbs(Vector vector) {
        return new BMDistribution(this, vector, MARGINAL_PI_THRESHOLD).marginals();
    }

    /**
     * Predicts every label whose marginal probability is strictly greater
     * than 0.5.
     */
    MultiLabel predictByMarginals(Vector vector) {
        double[] marginals = this.predictClassProbs(vector);
        MultiLabel prediction = new MultiLabel();
        for (int label = 0; label < this.numLabels; ++label) {
            if (marginals[label] > 0.5) {
                prediction.addLabel(label);
            }
        }
        return prediction;
    }

    /**
     * Predicts the {@code top} labels with the largest marginal probabilities.
     *
     * @param vector feature vector
     * @param top number of labels to include; values larger than the number
     *            of available labels are treated as "all labels", values
     *            {@code <= 0} yield an empty prediction
     * @return the prediction holding the selected labels
     */
    public MultiLabel predictByMarginals(Vector vector, int top) {
        double[] marginals = this.predictClassProbs(vector);
        int[] ranked = ArgSort.argSortDescending(marginals);
        // Clamp so an over-large request cannot index past the ranking.
        int limit = Math.min(top, ranked.length);
        MultiLabel prediction = new MultiLabel();
        for (int i = 0; i < limit; ++i) {
            prediction.addLabel(ranked[i]);
        }
        return prediction;
    }

    /**
     * Renders the model as text: label and component counts, the mixture
     * proportion of every component, the multi-class estimator and every
     * binary estimator.
     *
     * <p>The proportions are the multi-class estimator's class probabilities
     * for a freshly created sparse vector of {@code numFeatures} dimensions
     * with no entries set.
     */
    @Override
    public String toString() {
        Vector emptyVector = new RandomAccessSparseVector(this.numFeatures);
        double[] mixtureCoefficients = this.multiClassClassifier.predictClassProbs(emptyVector);
        StringBuilder sb = new StringBuilder("CBM{\n");
        sb.append("numLabels=").append(this.numLabels).append("\n");
        sb.append("numComponents=").append(this.numComponents).append("\n");
        for (int k = 0; k < this.numComponents; ++k) {
            sb.append("cluster ").append(k).append(":\n");
            sb.append("proportion = ").append(mixtureCoefficients[k]).append("\n");
        }
        sb.append("multi-class component = \n");
        sb.append(this.multiClassClassifier);
        sb.append("binary components = \n");
        for (int k = 0; k < this.numComponents; ++k) {
            for (int l = 0; l < this.numLabels; ++l) {
                sb.append("component ").append(k).append(" class ").append(l).append("\n");
                sb.append(this.binaryClassifiers[k][l]).append("\n");
            }
        }
        sb.append('}');
        return sb.toString();
    }

    /**
     * Selects the strategy used by {@link #predict(Vector)}. The value is not
     * validated here; an unsupported mode makes {@code predict} throw.
     *
     * @param mode "support", "marginal", "dynamic" or "hard"
     */
    public void setPredictMode(String mode) {
        this.predictMode = mode;
    }

    /** @param allowEmpty value forwarded to {@link CBMPredictor#setAllowEmpty} on prediction */
    public void setAllowEmpty(boolean allowEmpty) {
        this.allowEmpty = allowEmpty;
    }

    /** @return the current allow-empty setting */
    public boolean getAllowEmpty() {
        return this.allowEmpty;
    }

    /** @return always {@code null}; this model does not keep a feature list */
    @Override
    public FeatureList getFeatureList() {
        return null;
    }

    /** @return always {@code null}; this model does not keep a label translator */
    @Override
    public LabelTranslator getLabelTranslator() {
        return null;
    }

    /** @param numSample sample count forwarded to {@link CBMPredictor#setNumSamples} on prediction */
    public void setNumSample(int numSample) {
        this.numSample = numSample;
    }

    /**
     * Draws label combinations from the distribution for {@code x}.
     *
     * @param x feature vector
     * @param numSamples number of samples requested from {@link BMDistribution#sample}
     * @return the sampled label combinations
     */
    public List<MultiLabel> samples(Vector x, int numSamples) {
        return this.computeBM(x).sample(numSamples);
    }

    /** @return the internal {@code [component][label]} estimator grid (not a copy) */
    public Classifier.ProbabilityEstimator[][] getBinaryClassifiers() {
        return this.binaryClassifiers;
    }

    /** @return the estimator over mixture components */
    public Classifier.ProbabilityEstimator getMultiClassClassifier() {
        return this.multiClassClassifier;
    }

    /** @return the number of mixture components */
    public int getNumComponents() {
        return this.numComponents;
    }

    /** @return a new builder with both classifier types defaulting to "lr" */
    public static Builder getBuilder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link CBM}. Supported classifier type names are
     * "lr", "boost" and "elasticnet"; both types default to "lr".
     */
    public static class Builder {
        private int numClasses;
        private int numComponents;
        private int numFeatures;
        private List<MultiLabel> support;
        private String binaryClassifierType = "lr";
        private String multiClassClassifierType = "lr";

        private Builder() {
        }

        /** @param numClasses number of labels */
        public Builder setNumClasses(int numClasses) {
            this.numClasses = numClasses;
            return this;
        }

        /** @param numComponents number of mixture components */
        public Builder setNumComponents(int numComponents) {
            this.numComponents = numComponents;
            return this;
        }

        /** @param numFeatures dimensionality of the feature vectors */
        public Builder setNumFeatures(int numFeatures) {
            this.numFeatures = numFeatures;
            return this;
        }

        /** @param binaryClassifierType "lr", "boost" or "elasticnet" */
        public Builder setBinaryClassifierType(String binaryClassifierType) {
            this.binaryClassifierType = binaryClassifierType;
            return this;
        }

        /** @param multiClassClassifierType "lr", "boost" or "elasticnet" */
        public Builder setMultiClassClassifierType(String multiClassClassifierType) {
            this.multiClassClassifierType = multiClassClassifierType;
            return this;
        }

        /** @param support candidate label combinations for the "support" predict mode */
        public Builder setSupport(List<MultiLabel> support) {
            this.support = support;
            return this;
        }

        /**
         * Creates the model. The binary-estimator grid is allocated with
         * {@code numComponents x numClasses} slots that are left {@code null};
         * the multi-class estimator is a {@link LogisticRegression} for "lr"
         * and "elasticnet" and an {@link LKBoost} for "boost".
         *
         * @return the configured, untrained model
         * @throws IllegalArgumentException if either classifier type is not
         *         one of "lr", "boost" or "elasticnet"
         */
        public CBM build() {
            CBM cbm = new CBM();
            cbm.numLabels = this.numClasses;
            cbm.numComponents = this.numComponents;
            cbm.numFeatures = this.numFeatures;
            cbm.binaryClassifierType = this.binaryClassifierType;
            cbm.multiClassClassifierType = this.multiClassClassifierType;
            cbm.support = this.support;
            switch (this.binaryClassifierType) {
                // Every supported type uses the same empty grid; the type only
                // matters to the code that later fills it in.
                case "lr":
                case "boost":
                case "elasticnet": {
                    cbm.binaryClassifiers = new Classifier.ProbabilityEstimator[this.numComponents][this.numClasses];
                    break;
                }
                default: {
                    throw new IllegalArgumentException("binaryClassifierType can be lr, boost or elasticnet. Given: " + this.binaryClassifierType);
                }
            }
            switch (this.multiClassClassifierType) {
                case "lr":
                case "elasticnet": {
                    cbm.multiClassClassifier = new LogisticRegression(this.numComponents, this.numFeatures, true);
                    break;
                }
                case "boost": {
                    cbm.multiClassClassifier = new LKBoost(this.numComponents);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("multiClassClassifierType can be lr, boost or elasticnet. Given: " + this.multiClassClassifierType);
                }
            }
            return cbm;
        }
    }
}

