/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification;

import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.dataset.MLClfDataSetBuilder;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import edu.neu.ccs.pyramid.multilabel_classification.multi_label_logistic_regression.MLLogisticRegression;
import edu.neu.ccs.pyramid.multilabel_classification.multi_label_logistic_regression.MLLogisticTrainer;
import java.util.List;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Calibrates the raw per-class scores of a {@link MultiLabelClassifier.ClassScoreEstimator}
 * into probabilities, Platt-scaling style.
 *
 * <p>At construction time, every training point is mapped to a new feature vector that holds
 * one feature per class: the wrapped estimator's score for that class. A multi-label logistic
 * regression ({@link MLLogisticRegression}) is then fit on these score vectors against the
 * original labels. At prediction time, the wrapped estimator's scores for the input are fed
 * through that fitted model.
 *
 * <p>Hard predictions ({@link #predict(Vector)}) are delegated unchanged to the wrapped
 * estimator. Only the probability outputs are calibrated.
 */
public class MLACPlattScaling
implements MultiLabelClassifier.ClassProbEstimator,
MultiLabelClassifier.AssignmentProbEstimator {
    private static final long serialVersionUID = 1L;

    /**
     * Variance of the Gaussian prior used when fitting the calibration model. A large variance
     * means a weak penalty on the calibration weights.
     */
    private static final double PRIOR_VARIANCE = 100000.0;

    /** The uncalibrated estimator whose class scores are being calibrated. */
    private final MultiLabelClassifier.ClassScoreEstimator scoreEstimator;
    /** Calibration model mapping a vector of class scores to probabilities. */
    private final MLLogisticRegression logisticRegression;

    /**
     * Fits the calibration model on the scores that {@code scoreEstimator} produces for
     * {@code dataSet}.
     *
     * @param dataSet        labeled data used to fit the calibration model
     * @param scoreEstimator estimator whose per-class scores are calibrated; it is expected to
     *                       return at least {@code dataSet.getNumClasses()} scores per row
     */
    public MLACPlattScaling(MultiLabelClfDataSet dataSet, MultiLabelClassifier.ClassScoreEstimator scoreEstimator) {
        this.scoreEstimator = scoreEstimator;
        int numDataPoints = dataSet.getNumDataPoints();
        int numClasses = dataSet.getNumClasses();
        // Feature k of point i in the score data set is the estimator's score for class k.
        MultiLabelClfDataSet scoreDataSet = MLClfDataSetBuilder.getBuilder()
                .numDataPoints(numDataPoints)
                .numFeatures(numClasses)
                .numClasses(numClasses)
                .missingValue(false)
                .build();
        MultiLabel[] labels = dataSet.getMultiLabels();
        for (int i = 0; i < numDataPoints; ++i) {
            // Keep the original labels; only the features are replaced by class scores.
            scoreDataSet.addLabels(i, labels[i].getMatchedLabels());
            double[] scores = scoreEstimator.predictClassScores(dataSet.getRow(i));
            for (int k = 0; k < numClasses; ++k) {
                scoreDataSet.setFeatureValue(i, k, scores[k]);
            }
        }
        MLLogisticTrainer trainer = MLLogisticTrainer.getBuilder().setGaussianPriorVariance(PRIOR_VARIANCE).build();
        List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(scoreDataSet);
        this.logisticRegression = trainer.train(scoreDataSet, assignments);
    }

    /**
     * Returns calibrated per-class probabilities for {@code vector}.
     *
     * @param vector input feature vector, in the wrapped estimator's feature space
     * @return per-class probabilities produced by the calibration model
     */
    @Override
    public double[] predictClassProbs(Vector vector) {
        return this.logisticRegression.predictClassProbs(toScoreVector(vector));
    }

    /** Returns the number of classes of the calibration model. */
    @Override
    public int getNumClasses() {
        return this.logisticRegression.getNumClasses();
    }

    /**
     * Returns the wrapped estimator's own prediction for {@code vector}; calibration does not
     * affect it.
     */
    @Override
    public MultiLabel predict(Vector vector) {
        return this.scoreEstimator.predict(vector);
    }

    /**
     * Returns the calibration model's feature list. Its features are the per-class scores,
     * not the original input features.
     */
    @Override
    public FeatureList getFeatureList() {
        return this.logisticRegression.getFeatureList();
    }

    /** Returns the calibration model's label translator. */
    @Override
    public LabelTranslator getLabelTranslator() {
        return this.logisticRegression.getLabelTranslator();
    }

    /**
     * Returns the natural log of {@link #predictAssignmentProb(Vector, MultiLabel)}.
     *
     * <p>Previously this always returned {@code 0.0} (i.e. probability 1), which disagreed with
     * {@code predictAssignmentProb}. A zero probability yields {@code Double.NEGATIVE_INFINITY}.
     *
     * @param vector     input feature vector
     * @param assignment label set whose log-probability is requested
     * @return log-probability of {@code assignment} under the calibrated model
     */
    @Override
    public double predictLogAssignmentProb(Vector vector, MultiLabel assignment) {
        return Math.log(predictAssignmentProb(vector, assignment));
    }

    /**
     * Returns the calibrated probability of the full label set {@code assignment} for
     * {@code vector}.
     *
     * @param vector     input feature vector, in the wrapped estimator's feature space
     * @param assignment label set whose probability is requested
     * @return probability assigned by the calibration model
     */
    @Override
    public double predictAssignmentProb(Vector vector, MultiLabel assignment) {
        return this.logisticRegression.predictAssignmentProb(toScoreVector(vector), assignment);
    }

    /**
     * Copies the wrapped estimator's class scores for {@code vector} into a dense vector. This
     * vector is the input format the calibration model was trained on.
     */
    private Vector toScoreVector(Vector vector) {
        double[] scores = this.scoreEstimator.predictClassScores(vector);
        Vector scoreVector = new DenseVector(scores.length);
        for (int k = 0; k < scores.length; ++k) {
            scoreVector.set(k, scores[k]);
        }
        return scoreVector;
    }
}

