/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.classification.logistic_regression;

import edu.neu.ccs.pyramid.classification.ClassProbability;
import edu.neu.ccs.pyramid.classification.PredictionAnalysis;
import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.IdTranslator;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.feature.Feature;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.feature.FeatureUtility;
import edu.neu.ccs.pyramid.feature.Ngram;
import edu.neu.ccs.pyramid.feature.TopFeatures;
import edu.neu.ccs.pyramid.regression.ClassScoreCalculation;
import edu.neu.ccs.pyramid.regression.ConstantRule;
import edu.neu.ccs.pyramid.regression.LinearRule;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.mahout.math.Vector;

/**
 * Static helpers for inspecting a trained {@link LogisticRegression} model:
 * highest-weighted features per class, per-class score breakdowns for a single
 * data point, counts of features with non-zero weights, and a textual report
 * on n-gram feature usage.
 */
public class LogisticRegressionInspector {

    /**
     * Collects the features with the largest strictly positive weights for one class.
     *
     * @param logisticRegression the model to inspect
     * @param classIndex         internal index of the class
     * @param limit              maximum number of features to return
     * @return the top features (sorted by descending weight) together with the
     *         class index and the external class name taken from the model's
     *         label translator
     */
    public static TopFeatures topFeatures(LogisticRegression logisticRegression, int classIndex, int limit) {
        FeatureList featureList = logisticRegression.getFeatureList();
        Vector weights = logisticRegression.getWeights().getWeightsWithoutBiasForClass(classIndex);
        Comparator<FeatureUtility> byUtility = Comparator.comparing(FeatureUtility::getUtility);
        List<Feature> list = IntStream.range(0, weights.size())
                .mapToObj(i -> new FeatureUtility(featureList.get(i)).setUtility(weights.get(i)))
                // only features that push the score up are reported
                .filter(featureUtility -> featureUtility.getUtility() > 0.0)
                .sorted(byUtility.reversed())
                .map(FeatureUtility::getFeature)
                .limit(limit)
                .collect(Collectors.toList());
        TopFeatures topFeatures = new TopFeatures();
        topFeatures.setTopFeatures(list);
        topFeatures.setClassIndex(classIndex);
        LabelTranslator labelTranslator = logisticRegression.getLabelTranslator();
        topFeatures.setClassName(labelTranslator.toExtLabel(classIndex));
        return topFeatures;
    }

    /**
     * Breaks down the score of one class for one feature vector into a bias rule
     * plus the {@code limit} linear rules (weight * feature value) with the
     * largest absolute contribution.
     *
     * @param logisticRegression the model to inspect
     * @param labelTranslator    translator used to obtain the external class name
     * @param vector             feature vector of the data point
     * @param classIndex         internal index of the class
     * @param limit              maximum number of linear rules to include
     * @return the score calculation: the bias rule first, then the selected
     *         linear rules in descending order of absolute score
     */
    public static ClassScoreCalculation decisionProcess(LogisticRegression logisticRegression, LabelTranslator labelTranslator, Vector vector, int classIndex, int limit) {
        ClassScoreCalculation classScoreCalculation = new ClassScoreCalculation(classIndex, labelTranslator.toExtLabel(classIndex), logisticRegression.predictClassScore(vector, classIndex));
        ConstantRule bias = new ConstantRule(logisticRegression.getWeights().getBiasForClass(classIndex));
        classScoreCalculation.addRule(bias);

        // Loop-invariant lookups are done once instead of once per feature.
        FeatureList featureList = logisticRegression.getFeatureList();
        Vector classWeights = logisticRegression.getWeights().getWeightsWithoutBiasForClass(classIndex);
        int numFeatures = logisticRegression.getNumFeatures();

        List<LinearRule> linearRules = new ArrayList<LinearRule>();
        for (int j = 0; j < numFeatures; ++j) {
            double weight = classWeights.get(j);
            double featureValue = vector.get(j);
            LinearRule rule = new LinearRule();
            rule.setFeature(featureList.get(j));
            rule.setFeatureValue(featureValue);
            rule.setScore(weight * featureValue);
            rule.setWeight(weight);
            linearRules.add(rule);
        }

        Comparator<LinearRule> byAbsoluteScore = Comparator.comparingDouble(rule -> Math.abs(rule.getScore()));
        List<LinearRule> sorted = linearRules.stream()
                .sorted(byAbsoluteScore.reversed())
                .limit(limit)
                .collect(Collectors.toList());
        for (LinearRule linearRule : sorted) {
            classScoreCalculation.addRule(linearRule);
        }
        return classScoreCalculation;
    }

    /**
     * Builds a full analysis of the model's prediction for one data point:
     * ids, true and predicted labels, per-class probabilities and per-class
     * score calculations.
     *
     * @param logisticRegression the model to inspect
     * @param dataSet            data set containing the data point; also supplies
     *                           the id and label translators
     * @param dataPointIndex     internal index of the data point in {@code dataSet}
     * @param limit              maximum number of linear rules per class score calculation
     * @return the populated prediction analysis
     */
    public static PredictionAnalysis analyzePrediction(LogisticRegression logisticRegression, ClfDataSet dataSet, int dataPointIndex, int limit) {
        IdTranslator idTranslator = dataSet.getIdTranslator();
        LabelTranslator labelTranslator = dataSet.getLabelTranslator();
        Vector row = dataSet.getRow(dataPointIndex);
        int label = dataSet.getLabels()[dataPointIndex];

        PredictionAnalysis predictionAnalysis = new PredictionAnalysis();
        predictionAnalysis.setInternalId(dataPointIndex)
                .setId(idTranslator.toExtId(dataPointIndex))
                .setInternalLabel(label)
                .setLabel(labelTranslator.toExtLabel(label));

        int prediction = logisticRegression.predict(row);
        predictionAnalysis.setInternalPrediction(prediction);
        predictionAnalysis.setPrediction(labelTranslator.toExtLabel(prediction));

        double[] probs = logisticRegression.predictClassProbs(row);
        List<ClassProbability> classProbabilities = new ArrayList<ClassProbability>();
        for (int k = 0; k < probs.length; ++k) {
            classProbabilities.add(new ClassProbability(k, labelTranslator.toExtLabel(k), probs[k]));
        }
        predictionAnalysis.setClassProbabilities(classProbabilities);

        List<ClassScoreCalculation> classScoreCalculations = new ArrayList<ClassScoreCalculation>();
        for (int k = 0; k < probs.length; ++k) {
            classScoreCalculations.add(LogisticRegressionInspector.decisionProcess(logisticRegression, labelTranslator, row, k, limit));
        }
        predictionAnalysis.setClassScoreCalculations(classScoreCalculations);
        return predictionAnalysis;
    }

    /**
     * Counts, for every class, the non-zero entries of its weight vector
     * (bias excluded).
     *
     * @param logisticRegression the model to inspect
     * @return an array indexed by class holding the number of non-zero weights
     */
    public static int[] numOfUsedFeaturesEachClass(LogisticRegression logisticRegression) {
        int numClasses = logisticRegression.getNumClasses();
        int[] numbers = new int[numClasses];
        for (int k = 0; k < numClasses; ++k) {
            numbers[k] = logisticRegression.getWeights().getWeightsWithoutBiasForClass(k).getNumNonZeroElements();
        }
        return numbers;
    }

    /**
     * Counts the distinct feature indices that have a non-zero weight in at
     * least one class.
     *
     * @param logisticRegression the model to inspect
     * @return the size of the set returned by {@link #usedFeaturesCombined(LogisticRegression)}
     */
    public static int numOfUsedFeaturesCombined(LogisticRegression logisticRegression) {
        return usedFeaturesCombined(logisticRegression).size();
    }

    /**
     * Collects the distinct feature indices that have a non-zero weight in at
     * least one class (bias excluded).
     *
     * @param logisticRegression the model to inspect
     * @return a newly created set of feature indices
     */
    public static Set<Integer> usedFeaturesCombined(LogisticRegression logisticRegression) {
        Set<Integer> usedFeatures = new HashSet<Integer>();
        for (int k = 0; k < logisticRegression.getNumClasses(); ++k) {
            Vector vector = logisticRegression.getWeights().getWeightsWithoutBiasForClass(k);
            for (Vector.Element element : vector.nonZeroes()) {
                usedFeatures.add(element.index());
            }
        }
        return usedFeatures;
    }

    /**
     * Produces a multi-line textual report on how n-gram features are used by
     * the model, broken down by n-gram length. "Candidates" are all
     * {@link Ngram} features in the model's feature list; "selected" are the
     * n-grams with a non-zero weight in at least one class; "seeds" are the
     * selected unigrams.
     *
     * <p>Ratios are plain double divisions, so a length with a zero denominator
     * is printed as {@code NaN} or {@code Infinity}. If the feature list holds
     * no n-grams, only the section headers are emitted.
     *
     * @param logisticRegression the model to inspect
     * @return the report; sections are separated by {@code "\n"} and the last
     *         section has no trailing newline
     */
    public static String checkNgramUsage(LogisticRegression logisticRegression) {
        FeatureList featureList = logisticRegression.getFeatureList();
        Set<Integer> usedFeatures = usedFeaturesCombined(logisticRegression);
        List<Ngram> selected = usedFeatures.stream()
                .map(featureList::get)
                .filter(feature -> feature instanceof Ngram)
                .map(feature -> (Ngram) feature)
                .collect(Collectors.toList());
        List<Ngram> candidates = featureList.getAll().stream()
                .filter(feature -> feature instanceof Ngram)
                .map(feature -> (Ngram) feature)
                .collect(Collectors.toList());
        // orElse(0) avoids NoSuchElementException when there are no n-gram features.
        int maxLength = candidates.stream().mapToInt(Ngram::getN).max().orElse(0);

        Set<String> unigrams = selected.stream()
                .filter(ngram -> ngram.getN() == 1)
                .map(Ngram::getNgram)
                .collect(Collectors.toSet());
        List<Ngram> easyCandidateNgrams = candidates.stream()
                .filter(ngram -> LogisticRegressionInspector.isComposedOf(ngram.getNgram(), unigrams))
                .collect(Collectors.toList());
        List<Ngram> easySelectedNgrams = selected.stream()
                .filter(ngram -> LogisticRegressionInspector.isComposedOf(ngram.getNgram(), unigrams))
                .collect(Collectors.toList());

        int[] numberCandidates = countByLength(candidates, maxLength);
        int[] numberSelected = countByLength(selected, maxLength);
        int[] easyCandidates = countByLength(easyCandidateNgrams, maxLength);
        int[] easySelected = countByLength(easySelectedNgrams, maxLength);

        StringBuilder sb = new StringBuilder();
        appendCounts(sb, "number of ngram candidates: ", numberCandidates);
        sb.append("\n");
        appendCounts(sb, "number of selected ngram: ", numberSelected);
        sb.append("\n");
        appendCounts(sb, "number of ngram candidates that can be constructed from seeds: ", easyCandidates);
        sb.append("\n");
        appendCounts(sb, "number of selected ngrams that can be constructed from seeds: ", easySelected);
        sb.append("\n");
        appendRatios(sb, "percentage of selected ngrams that can be constructed from seeds: ", easySelected, numberSelected);
        sb.append("\n");
        appendRatios(sb, "feature selection ratio: ", numberSelected, numberCandidates);
        sb.append("\n");
        appendRatios(sb, "feature selection ratio based on seeds: ", easySelected, easyCandidates);
        return sb.toString();
    }

    /** Tallies n-grams by length; slot {@code i} holds the number of (i+1)-grams. */
    private static int[] countByLength(List<Ngram> ngrams, int maxLength) {
        int[] counts = new int[maxLength];
        for (Ngram ngram : ngrams) {
            counts[ngram.getN() - 1]++;
        }
        return counts;
    }

    /** Appends {@code header} followed by one "n-grams = count; " entry per length. */
    private static void appendCounts(StringBuilder sb, String header, int[] counts) {
        sb.append(header);
        for (int i = 0; i < counts.length; ++i) {
            sb.append(i + 1).append("-grams = ").append(counts[i]).append("; ");
        }
    }

    /** Appends {@code header} followed by one "n-grams = numerator/denominator; " entry per length. */
    private static void appendRatios(StringBuilder sb, String header, int[] numerators, int[] denominators) {
        sb.append(header);
        for (int i = 0; i < numerators.length; ++i) {
            double ratio = (double) numerators[i] / (double) denominators[i];
            sb.append(i + 1).append("-grams = ").append(ratio).append("; ");
        }
    }

    /**
     * Tells whether at least one space-separated term of {@code ngram} is
     * contained in {@code unigrams}.
     *
     * <p>NOTE(review): the name suggests that <em>every</em> term should be a
     * known unigram, but the check succeeds on the first match. Behavior is
     * kept as-is; confirm the intended semantics.
     */
    private static boolean isComposedOf(String ngram, Set<String> unigrams) {
        for (String term : ngram.split(" ")) {
            if (unigrams.contains(term)) {
                return true;
            }
        }
        return false;
    }
}

