/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.imllr;

import edu.neu.ccs.pyramid.classification.logistic_regression.Weights;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.mahout.math.Vector;

/**
 * Multi-label logistic regression with one independent linear scorer per class.
 *
 * <p>Each class {@code k} has a score {@code s_k = bias_k + w_k . x}. Without an
 * assignment list, labels are treated independently: a label is predicted when its
 * score is positive. With an assignment list, prediction and probabilities are
 * restricted to the given label sets, and an assignment's score is the sum of the
 * scores of its labels.
 */
public class IMLLogisticRegression
implements MultiLabelClassifier,
MultiLabelClassifier.ClassScoreEstimator {
    private static final long serialVersionUID = 1L;
    private final int numClasses;
    private final int numFeatures;
    private final Weights weights;
    private FeatureList featureList;
    private LabelTranslator labelTranslator;
    /** Allowed label sets; {@code null} means no constraint on predictions. */
    private final List<MultiLabel> assignments;

    /**
     * Creates a model with freshly constructed weights.
     *
     * @param numClasses  number of classes (labels)
     * @param numFeatures number of input features
     * @param assignments allowed label sets, or {@code null} for unconstrained prediction;
     *                    the list is stored by reference, not copied
     */
    public IMLLogisticRegression(int numClasses, int numFeatures, List<MultiLabel> assignments) {
        this.numClasses = numClasses;
        this.numFeatures = numFeatures;
        this.weights = new Weights(numClasses, numFeatures);
        this.assignments = assignments;
    }

    /**
     * Creates a model whose weights are built from an existing weight vector.
     *
     * @param numClasses   number of classes (labels)
     * @param numFeatures  number of input features
     * @param assignments  allowed label sets, or {@code null} for unconstrained prediction;
     *                     the list is stored by reference, not copied
     * @param weightVector vector passed to the {@link Weights} constructor
     */
    public IMLLogisticRegression(int numClasses, int numFeatures, List<MultiLabel> assignments, Vector weightVector) {
        this.numClasses = numClasses;
        this.numFeatures = numFeatures;
        this.weights = new Weights(numClasses, numFeatures, weightVector);
        this.assignments = assignments;
    }

    /** Returns the allowed label sets (the stored list itself), or {@code null} if unconstrained. */
    public List<MultiLabel> getAssignments() {
        return this.assignments;
    }

    /** Returns the underlying (mutable) weight object. */
    public Weights getWeights() {
        return this.weights;
    }

    /** Returns the number of input features. */
    public int getNumFeatures() {
        return this.numFeatures;
    }

    @Override
    public int getNumClasses() {
        return this.numClasses;
    }

    /**
     * Predicts a label set for {@code vector}.
     *
     * <p>The prediction is constrained to {@link #getAssignments()} when that list is
     * non-null. In that case the result is {@code null} if the list is empty.
     */
    @Override
    public MultiLabel predict(Vector vector) {
        return this.assignments != null
                ? this.predictWithConstraints(vector)
                : this.predictWithoutConstraints(vector);
    }

    /** Includes every class whose score is strictly positive (NaN scores are excluded). */
    private MultiLabel predictWithoutConstraints(Vector vector) {
        MultiLabel prediction = new MultiLabel();
        for (int k = 0; k < this.numClasses; ++k) {
            if (this.predictClassScore(vector, k) > 0.0) {
                prediction.addLabel(k);
            }
        }
        return prediction;
    }

    /**
     * Returns the allowed assignment with the highest summed class score.
     *
     * <p>On ties, the earliest assignment in the list wins because the comparison is strict.
     * The result is {@code null} when the list is empty or no score exceeds negative infinity
     * (for example, all NaN).
     */
    private MultiLabel predictWithConstraints(Vector vector) {
        double[] classScores = this.predictClassScores(vector);
        double bestScore = Double.NEGATIVE_INFINITY;
        MultiLabel best = null;
        for (MultiLabel assignment : this.assignments) {
            double score = this.calAssignmentScore(assignment, classScores);
            if (score > bestScore) {
                bestScore = score;
                best = assignment;
            }
        }
        return best;
    }

    /** Returns the linear score {@code bias_k + w_k . dataPoint} for class {@code k}. */
    @Override
    public double predictClassScore(Vector dataPoint, int k) {
        return this.weights.getBiasForClass(k)
                + this.weights.getWeightsWithoutBiasForClass(k).dot(dataPoint);
    }

    /** Returns the linear scores of all classes, indexed by class. */
    @Override
    public double[] predictClassScores(Vector dataPoint) {
        double[] scores = new double[this.numClasses];
        for (int k = 0; k < this.numClasses; ++k) {
            scores[k] = this.predictClassScore(dataPoint, k);
        }
        return scores;
    }

    /**
     * Returns the model probability of {@code assignment} given {@code vector}.
     *
     * <p>Returns 0 when {@code assignment.outOfBound(numClasses)} is true. Presumably this
     * means the assignment uses a label index {@code >= numClasses}; verify against
     * {@code MultiLabel}. Otherwise the constrained or unconstrained model is used, depending
     * on whether an assignment list is set.
     */
    public double predictAssignmentProb(Vector vector, MultiLabel assignment) {
        if (assignment.outOfBound(this.numClasses)) {
            return 0.0;
        }
        if (this.assignments != null) {
            return this.predictAssignmentProbWithConstraint(vector, assignment);
        }
        return this.predictAssignmentProbWithoutConstraint(vector, assignment);
    }

    /**
     * Softmax probability of {@code assignment} over the allowed assignments, using summed
     * class scores as logits. Returns 0 if {@code assignment} is not in the allowed list.
     */
    double predictAssignmentProbWithConstraint(Vector vector, MultiLabel assignment) {
        if (!this.assignments.contains(assignment)) {
            return 0.0;
        }
        double[] classScores = this.predictClassScores(vector);
        double[] assignmentScores = new double[this.assignments.size()];
        for (int i = 0; i < assignmentScores.length; ++i) {
            assignmentScores[i] = this.calAssignmentScore(this.assignments.get(i), classScores);
        }
        double logNumerator = this.calAssignmentScore(assignment, classScores);
        double logDenominator = MathUtil.logSumExp(assignmentScores);
        return Math.exp(logNumerator - logDenominator);
    }

    /**
     * Probability of {@code assignment} when every class is an independent binary logistic
     * model. This is {@code exp(logLikelihood(vector, assignment))}.
     */
    double predictAssignmentProbWithoutConstraint(Vector vector, MultiLabel assignment) {
        // Shares the exact computation with logLikelihood instead of duplicating it.
        return Math.exp(this.logLikelihood(vector, assignment));
    }

    /** Returns the independent (sigmoid) probability that class {@code classIndex} is present. */
    public double predictClassProb(Vector vector, int classIndex) {
        double score = this.predictClassScore(vector, classIndex);
        return Math.exp(score - logOnePlusExp(score));
    }

    /** Returns {@link #predictClassProb} for every class, indexed by class. */
    public double[] predictClassProbs(Vector vector) {
        return IntStream.range(0, this.numClasses).mapToDouble(k -> this.predictClassProb(vector, k)).toArray();
    }

    /** Sums the class scores of the labels matched by {@code assignment}. */
    double calAssignmentScore(MultiLabel assignment, double[] classScores) {
        double score = 0.0;
        for (int label : assignment.getMatchedLabels()) {
            score += classScores[label];
        }
        return score;
    }

    /**
     * Unconstrained log-likelihood of {@code assignment}. For each class this is
     * {@code y_k * s_k - log(1 + exp(s_k))}, where {@code y_k} is 1 if the assignment
     * matches class {@code k} and 0 otherwise, and the terms are summed over all classes.
     */
    double logLikelihood(Vector vector, MultiLabel assignment) {
        double[] classScores = this.predictClassScores(vector);
        double logProb = 0.0;
        for (int k = 0; k < this.numClasses; ++k) {
            double logNumerator = assignment.matchClass(k) ? classScores[k] : 0.0;
            logProb += logNumerator;
            logProb -= logOnePlusExp(classScores[k]);
        }
        return logProb;
    }

    /** Sums {@link #logLikelihood} over all rows of {@code dataSet}, in parallel. */
    double dataSetLogLikelihood(MultiLabelClfDataSet dataSet) {
        MultiLabel[] multiLabels = dataSet.getMultiLabels();
        return IntStream.range(0, dataSet.getNumDataPoints()).parallel().mapToDouble(i -> this.logLikelihood(dataSet.getRow(i), multiLabels[i])).sum();
    }

    /**
     * Returns {@code log(1 + exp(score))}, the log-normalizer of a binary logistic model.
     * It is computed with {@link MathUtil#logSumExp} over {@code {0, score}}, exactly as the
     * original inline code did.
     */
    private static double logOnePlusExp(double score) {
        return MathUtil.logSumExp(new double[]{0.0, score});
    }

    @Override
    public FeatureList getFeatureList() {
        return this.featureList;
    }

    public void setFeatureList(FeatureList featureList) {
        this.featureList = featureList;
    }

    @Override
    public LabelTranslator getLabelTranslator() {
        return this.labelTranslator;
    }

    public void setLabelTranslator(LabelTranslator labelTranslator) {
        this.labelTranslator = labelTranslator;
    }
}

