/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification;

import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer;
import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.ClfDataSetBuilder;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Turns the raw per-class scores of a {@link MultiLabelClassifier.ClassScoreEstimator}
 * into per-class probabilities.
 *
 * <p>A single binary logistic regression with exactly one input feature (the raw score)
 * is fitted over every (data point, class) pair of the training set. The same fitted
 * model is then applied to the score of every class at prediction time, i.e. the
 * scaling is shared ("flat") across classes rather than fitted per class.
 *
 * <p>Label-set prediction, the number of classes, the feature list and the label
 * translator are all taken from the wrapped score estimator; only
 * {@link #predictClassProbs(Vector)} uses the fitted scaling model.
 */
public class MLFlatScaling implements MultiLabelClassifier.ClassProbEstimator {
    private static final long serialVersionUID = 1L;

    /** Column of the calibration data set that holds the raw score. */
    private static final int SCORE_FEATURE_INDEX = 0;
    /** Label given to a (data point, class) pair whose class is in the true label set. */
    private static final int POSITIVE_LABEL = 1;

    /** Settings passed to the {@link RidgeLogisticTrainer} builder. */
    private static final double TRAINER_EPSILON = 1.0;
    private static final double TRAINER_GAUSSIAN_PRIOR_VARIANCE = 100.0;
    private static final int TRAINER_HISTORY = 5;

    /** Wrapped model that produces the raw per-class scores. */
    private final MultiLabelClassifier.ClassScoreEstimator scoreEstimator;
    /** One-feature binary model mapping a raw score to the probability of {@link #POSITIVE_LABEL}. */
    private final LogisticRegression logisticRegression;

    /**
     * Fits the score-to-probability mapping on the given data set.
     *
     * @param dataSet        multi-label data whose rows are scored by {@code scoreEstimator}
     *                       and whose label sets provide the binary targets
     * @param scoreEstimator model producing one raw score per class for a feature vector
     * @throws NullPointerException if either argument is {@code null}
     */
    public MLFlatScaling(MultiLabelClfDataSet dataSet, MultiLabelClassifier.ClassScoreEstimator scoreEstimator) {
        // Fail fast with a descriptive message instead of an anonymous NPE deeper in training.
        if (dataSet == null) {
            throw new NullPointerException("dataSet must not be null");
        }
        if (scoreEstimator == null) {
            throw new NullPointerException("scoreEstimator must not be null");
        }
        this.scoreEstimator = scoreEstimator;

        ClfDataSet calibrationSet = buildCalibrationSet(dataSet, scoreEstimator);
        RidgeLogisticTrainer trainer = RidgeLogisticTrainer.getBuilder()
                .setEpsilon(TRAINER_EPSILON)
                .setGaussianPriorVariance(TRAINER_GAUSSIAN_PRIOR_VARIANCE)
                .setHistory(TRAINER_HISTORY)
                .build();
        this.logisticRegression = trainer.train(calibrationSet);
    }

    /**
     * Builds the dense, single-feature, two-class data set the scaling model is trained on:
     * one row per (data point, class) pair, whose only feature is the raw score of that
     * class for that data point.
     */
    private static ClfDataSet buildCalibrationSet(MultiLabelClfDataSet dataSet,
                                                  MultiLabelClassifier.ClassScoreEstimator scoreEstimator) {
        int numDataPoints = dataSet.getNumDataPoints();
        int numClasses = dataSet.getNumClasses();
        ClfDataSet calibrationSet = ClfDataSetBuilder.getBuilder()
                .numDataPoints(numDataPoints * numClasses)
                .numFeatures(1)
                .numClasses(2)
                .dense(true)
                .missingValue(false)
                .build();

        MultiLabel[] multiLabels = dataSet.getMultiLabels();
        int rowIndex = 0;
        for (int i = 0; i < numDataPoints; ++i) {
            double[] scores = scoreEstimator.predictClassScores(dataSet.getRow(i));
            MultiLabel multiLabel = multiLabels[i];
            for (int k = 0; k < numClasses; ++k) {
                calibrationSet.setFeatureValue(rowIndex, SCORE_FEATURE_INDEX, scores[k]);
                // Only positives are labelled explicitly; the remaining rows keep the
                // data set's initial label (presumably 0 - confirm against ClfDataSetBuilder).
                if (multiLabel.matchClass(k)) {
                    calibrationSet.setLabel(rowIndex, POSITIVE_LABEL);
                }
                ++rowIndex;
            }
        }
        return calibrationSet;
    }

    /**
     * Scores the vector with the wrapped estimator and maps each class score through the
     * fitted scaling model.
     *
     * @param vector feature vector in the wrapped estimator's feature space
     * @return one value per score returned by the wrapped estimator: the scaling model's
     *         probability of {@link #POSITIVE_LABEL} for that score
     */
    @Override
    public double[] predictClassProbs(Vector vector) {
        double[] scores = this.scoreEstimator.predictClassScores(vector);
        double[] probs = new double[scores.length];
        for (int k = 0; k < scores.length; ++k) {
            // The scaling model expects a one-dimensional input holding just the raw score.
            DenseVector scoreFeatureVector = new DenseVector(1);
            scoreFeatureVector.set(SCORE_FEATURE_INDEX, scores[k]);
            probs[k] = this.logisticRegression.predictClassProb((Vector)scoreFeatureVector, POSITIVE_LABEL);
        }
        return probs;
    }

    /** Returns the number of classes of the wrapped score estimator. */
    @Override
    public int getNumClasses() {
        return this.scoreEstimator.getNumClasses();
    }

    /** Returns the wrapped score estimator's prediction; the scaling model is not consulted. */
    @Override
    public MultiLabel predict(Vector vector) {
        return this.scoreEstimator.predict(vector);
    }

    /**
     * Returns the feature list of the wrapped score estimator.
     *
     * <p>Vectors passed to this classifier are handed straight to the wrapped estimator,
     * so its feature list is the one that describes them. The internal scaling model has
     * a single input (the raw score) and its feature list would not describe the input.
     */
    @Override
    public FeatureList getFeatureList() {
        return this.scoreEstimator.getFeatureList();
    }

    /**
     * Returns the label translator of the wrapped score estimator.
     *
     * <p>This keeps the translator consistent with {@link #getNumClasses()} and
     * {@link #predict(Vector)}; the internal scaling model only distinguishes two classes.
     */
    @Override
    public LabelTranslator getLabelTranslator() {
        return this.scoreEstimator.getLabelTranslator();
    }
}

