/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.classification.logistic_regression;

import edu.neu.ccs.pyramid.classification.Classifier;
import edu.neu.ccs.pyramid.classification.logistic_regression.Weights;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.util.MathUtil;
import edu.neu.ccs.pyramid.util.Vectors;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.ObjectInputStream;
import java.util.stream.IntStream;
import org.apache.mahout.math.Vector;

/**
 * Multinomial (softmax) logistic regression classifier.
 *
 * <p>The score of class {@code k} for a data point {@code x} is the linear function
 * {@code bias_k + w_k . x}, and class probabilities are the softmax of those scores
 * ({@code p_k = exp(score_k - logSumExp(scores))}). All parameters live in a
 * {@link Weights} instance.
 *
 * <p>The optional {@link FeatureList} and {@link LabelTranslator} are carried as metadata
 * only; none of the prediction methods in this class read them.
 */
public class LogisticRegression
implements Classifier.ProbabilityEstimator,
Classifier.ScoreEstimator {
    private static final long serialVersionUID = 2L;
    private final int numClasses;
    private final int numFeatures;
    /** Per-class bias and feature weights. */
    private final Weights weights;
    private FeatureList featureList;
    private LabelTranslator labelTranslator;

    /**
     * Creates a model for {@code numClasses} classes over {@code numFeatures} features.
     *
     * @param numClasses number of classes
     * @param numFeatures number of features
     * @param random forwarded to {@link Weights}; presumably selects random rather than
     *     default initialization of the parameters — TODO confirm in {@code Weights}
     */
    public LogisticRegression(int numClasses, int numFeatures, boolean random) {
        this.numClasses = numClasses;
        this.numFeatures = numFeatures;
        this.weights = new Weights(numClasses, numFeatures, random);
    }

    /**
     * Creates a model whose parameters use the default initialization of
     * {@link Weights#Weights(int, int)} (presumably all zeros — TODO confirm).
     *
     * @param numClasses number of classes
     * @param numFeatures number of features
     */
    public LogisticRegression(int numClasses, int numFeatures) {
        this.numClasses = numClasses;
        this.numFeatures = numFeatures;
        this.weights = new Weights(numClasses, numFeatures);
    }

    /**
     * Creates a model whose parameters are taken from {@code weightVector}.
     *
     * @param numClasses number of classes
     * @param numFeatures number of features
     * @param weightVector flat parameter vector; its layout is defined by {@link Weights}
     */
    public LogisticRegression(int numClasses, int numFeatures, Vector weightVector) {
        this.numClasses = numClasses;
        this.numFeatures = numFeatures;
        this.weights = new Weights(numClasses, numFeatures, weightVector);
    }

    /**
     * Creates a model whose class biases are initialized from prior class probabilities.
     *
     * <p>Each bias {@code l} is set to {@code MathUtil.inverseSoftMax(priorProbabilities)[l]}.
     * Presumably this makes the model predict exactly the priors while the feature weights
     * are still at their default (zero?) values — confirm against {@code Weights}.
     *
     * @param numClasses number of classes
     * @param numFeatures number of features
     * @param priorProbabilities prior probability per class; must have at least
     *     {@code numClasses} entries
     */
    public LogisticRegression(int numClasses, int numFeatures, double[] priorProbabilities) {
        this(numClasses, numFeatures);
        double[] priorScores = MathUtil.inverseSoftMax(priorProbabilities);
        for (int l = 0; l < numClasses; ++l) {
            this.weights.setBiasForClass(priorScores[l], l);
        }
    }

    /**
     * Returns the live parameter object (not a copy); mutations affect this model.
     *
     * @return the model's weights
     */
    public Weights getWeights() {
        return this.weights;
    }

    @Override
    public int getNumClasses() {
        return this.numClasses;
    }

    /**
     * @return number of features this model was created with
     */
    public int getNumFeatures() {
        return this.numFeatures;
    }

    /**
     * Predicts the class with the highest score.
     *
     * <p>On ties the lowest class index wins. If no score is greater than negative
     * infinity (e.g. all scores are NaN), class {@code 0} is returned.
     *
     * @param vector data point
     * @return index of the arg-max class
     */
    @Override
    public int predict(Vector vector) {
        double[] scores = this.predictClassScores(vector);
        int bestClass = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < this.numClasses; ++k) {
            // Strict '>' keeps the earliest index on ties; NaN comparisons are false.
            if (scores[k] > bestScore) {
                bestScore = scores[k];
                bestClass = k;
            }
        }
        return bestClass;
    }

    /**
     * Computes the linear score {@code bias_k + w_k . dataPoint} for class {@code k}.
     *
     * @param dataPoint data point
     * @param k class index
     * @return unnormalized score (logit) of class {@code k}
     */
    @Override
    public double predictClassScore(Vector dataPoint, int k) {
        double bias = this.weights.getBiasForClass(k);
        return bias + Vectors.dot(this.weights.getWeightsWithoutBiasForClass(k), dataPoint);
    }

    /**
     * Computes the score of every class.
     *
     * @param dataPoint data point
     * @return array of length {@code numClasses} with the score of each class
     */
    @Override
    public double[] predictClassScores(Vector dataPoint) {
        double[] scores = new double[this.numClasses];
        for (int k = 0; k < this.numClasses; ++k) {
            scores[k] = this.predictClassScore(dataPoint, k);
        }
        return scores;
    }

    /**
     * Computes the softmax class probabilities.
     *
     * @param vector data point
     * @return array of length {@code numClasses} with {@code P(k | vector)}
     */
    @Override
    public double[] predictClassProbs(Vector vector) {
        double[] scores = this.predictClassScores(vector);
        // Subtracting the log-sum-exp before exponentiating avoids overflow for large scores.
        double logPartition = MathUtil.logSumExp(scores);
        double[] probs = new double[this.numClasses];
        for (int k = 0; k < this.numClasses; ++k) {
            probs[k] = Math.exp(scores[k] - logPartition);
        }
        return probs;
    }

    /**
     * Computes the log of the softmax class probabilities.
     *
     * @param vector data point
     * @return array of length {@code numClasses} with {@code log P(k | vector)}
     */
    @Override
    public double[] predictLogClassProbs(Vector vector) {
        double[] scores = this.predictClassScores(vector);
        double logPartition = MathUtil.logSumExp(scores);
        double[] logProbs = new double[this.numClasses];
        for (int k = 0; k < this.numClasses; ++k) {
            logProbs[k] = scores[k] - logPartition;
        }
        return logProbs;
    }

    /**
     * Log-likelihood of one data point against (possibly soft) targets.
     *
     * <p>Computes {@code sum_k targets[k] * score_k - logSumExp(scores)}. When the targets
     * sum to 1 (one-hot or a distribution), this equals {@code sum_k targets[k] * log p_k}.
     *
     * @param vector data point
     * @param targets per-class target values; must have at least as many entries as there
     *     are scores
     * @return the log-likelihood
     */
    double logLikelihood(Vector vector, double[] targets) {
        double[] scores = this.predictClassScores(vector);
        double logPartition = MathUtil.logSumExp(scores);
        double targetScore = 0.0;
        for (int k = 0; k < scores.length; ++k) {
            targetScore += targets[k] * scores[k];
        }
        return targetScore - logPartition;
    }

    /**
     * Instance-weighted log-likelihood of one data point.
     *
     * @param vector data point
     * @param targets per-class target values
     * @param weight multiplier applied to the log-likelihood
     * @return {@code weight * logLikelihood(vector, targets)}
     */
    double logLikelihood(Vector vector, double[] targets, double weight) {
        return weight * this.logLikelihood(vector, targets);
    }

    /**
     * Sums the per-row log-likelihood over a data set, evaluating rows in parallel.
     *
     * <p>NOTE(review): rows are scored concurrently, which assumes {@code dataSet.getRow}
     * and the weight reads are safe to call from multiple threads. The parallel summation
     * order is unspecified, so the last bits of the result may vary between runs.
     *
     * @param dataSet data set whose row {@code i} is scored
     * @param targets {@code targets[i]} is the target array for row {@code i}
     * @return total log-likelihood
     */
    public double dataSetLogLikelihood(DataSet dataSet, double[][] targets) {
        return IntStream.range(0, dataSet.getNumDataPoints())
                .parallel()
                .mapToDouble(i -> this.logLikelihood(dataSet.getRow(i), targets[i]))
                .sum();
    }

    /**
     * Sums the instance-weighted per-row log-likelihood over a data set, in parallel.
     *
     * <p>The same concurrency and summation-order caveats as
     * {@link #dataSetLogLikelihood(DataSet, double[][])} apply.
     *
     * @param dataSet data set whose row {@code i} is scored
     * @param targets {@code targets[i]} is the target array for row {@code i}
     * @param weights {@code weights[i]} is the weight of row {@code i}
     * @return total weighted log-likelihood
     */
    public double dataSetLogLikelihood(DataSet dataSet, double[][] targets, double[] weights) {
        return IntStream.range(0, dataSet.getNumDataPoints())
                .parallel()
                .mapToDouble(i -> this.logLikelihood(dataSet.getRow(i), targets[i], weights[i]))
                .sum();
    }

    /**
     * Delegates to {@link Weights#truncateByThreshold(double)}; see that method for the
     * exact truncation rule.
     *
     * @param threshold truncation threshold
     */
    public void truncateByThreshold(double threshold) {
        this.weights.truncateByThreshold(threshold);
    }

    /**
     * Reads a model that was written with Java object serialization.
     *
     * <p>SECURITY: this uses Java native deserialization. A crafted stream can execute
     * arbitrary code during {@code readObject}, so only load files from trusted sources.
     *
     * @param file file containing a single serialized {@code LogisticRegression}
     * @return the deserialized model
     * @throws Exception if the file cannot be read, is not a valid serialization stream,
     *     or holds an object of another type ({@link ClassCastException})
     */
    public static LogisticRegression deserialize(File file) throws Exception {
        // NOTE(review): the decompiler could not recover this method; this body is a
        // reconstruction assuming the model was written as one object via
        // ObjectOutputStream. Confirm against the matching serialize routine.
        try (FileInputStream fileIn = new FileInputStream(file);
             BufferedInputStream bufferedIn = new BufferedInputStream(fileIn);
             ObjectInputStream objectIn = new ObjectInputStream(bufferedIn)) {
            return (LogisticRegression) objectIn.readObject();
        }
    }

    @Override
    public FeatureList getFeatureList() {
        return this.featureList;
    }

    /**
     * @param featureList feature metadata to attach to this model
     */
    public void setFeatureList(FeatureList featureList) {
        this.featureList = featureList;
    }

    @Override
    public LabelTranslator getLabelTranslator() {
        return this.labelTranslator;
    }

    /**
     * @param labelTranslator label metadata to attach to this model
     */
    void setLabelTranslator(LabelTranslator labelTranslator) {
        this.labelTranslator = labelTranslator;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("LogisticRegression{");
        sb.append("numClasses=").append(this.numClasses);
        sb.append(", numFeatures=").append(this.numFeatures);
        sb.append(", weights=").append(this.weights);
        sb.append('}');
        return sb.toString();
    }
}

