/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.classification.lkboost;

import edu.neu.ccs.pyramid.classification.ClassProbability;
import edu.neu.ccs.pyramid.classification.PredictionAnalysis;
import edu.neu.ccs.pyramid.classification.lkboost.LKBoost;
import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.IdTranslator;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.feature.Feature;
import edu.neu.ccs.pyramid.feature.TopFeatures;
import edu.neu.ccs.pyramid.regression.ClassScoreCalculation;
import edu.neu.ccs.pyramid.regression.ConstantRegressor;
import edu.neu.ccs.pyramid.regression.ConstantRule;
import edu.neu.ccs.pyramid.regression.Regressor;
import edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector;
import edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree;
import edu.neu.ccs.pyramid.regression.regression_tree.TreeRule;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.mahout.math.Vector;

/**
 * Static inspection helpers for {@link LKBoost} models.
 *
 * <p>Provides:
 * <ul>
 *   <li>feature rankings aggregated over the regression trees of one class;</li>
 *   <li>the features used by the last regressor of a class ensemble;</li>
 *   <li>per-data-point breakdowns of class probabilities and class scores.</li>
 * </ul>
 */
public class LKBInspector {

    /**
     * Ranks every feature that appears in the regression trees of one class ensemble.
     *
     * <p>A feature's score is the sum, over all {@link RegressionTree}s in the ensemble, of the
     * value returned by {@link RegTreeInspector#featureImportance}. Non-tree regressors are
     * ignored. Ties are returned in unspecified order.
     *
     * @param boosting   the model to inspect
     * @param classIndex internal index of the class whose ensemble is inspected
     * @return all contributing features, highest total contribution first, tagged with the
     *         class index and its external label
     */
    public static TopFeatures topFeatures(LKBoost boosting, int classIndex) {
        Map<Feature, Double> totalContributions = new HashMap<>();
        addContributions(boosting, classIndex, totalContributions);
        // Long.MAX_VALUE: keep every feature.
        return toTopFeatures(totalContributions, Long.MAX_VALUE, classIndex, boosting.getLabelTranslator());
    }

    /**
     * Same as {@link #topFeatures(LKBoost, int)}, but keeps only the {@code limit} highest-ranked
     * features.
     *
     * @param boosting   the model to inspect
     * @param classIndex internal index of the class whose ensemble is inspected
     * @param limit      maximum number of features to return
     * @return at most {@code limit} features, highest total contribution first
     * @throws IllegalArgumentException if {@code limit} is negative (raised by {@code Stream.limit})
     */
    public static TopFeatures topFeatures(LKBoost boosting, int classIndex, int limit) {
        Map<Feature, Double> totalContributions = new HashMap<>();
        addContributions(boosting, classIndex, totalContributions);
        return toTopFeatures(totalContributions, limit, classIndex, boosting.getLabelTranslator());
    }

    /**
     * Ranks features for one class by summing tree contributions across several models.
     *
     * <p>The external class name is taken from the label translator of the first model. This
     * assumes all models share the same label space (NOTE(review): not checked here).
     *
     * @param lkBoosts   models to aggregate over; must not be empty
     * @param classIndex internal index of the class whose ensembles are inspected
     * @return all contributing features, highest total contribution first
     * @throws IllegalArgumentException if {@code lkBoosts} is empty
     */
    public static TopFeatures topFeatures(List<LKBoost> lkBoosts, int classIndex) {
        if (lkBoosts.isEmpty()) {
            throw new IllegalArgumentException("lkBoosts must contain at least one model");
        }
        Map<Feature, Double> totalContributions = new HashMap<>();
        for (LKBoost lkBoost : lkBoosts) {
            addContributions(lkBoost, classIndex, totalContributions);
        }
        return toTopFeatures(totalContributions, Long.MAX_VALUE, classIndex, lkBoosts.get(0).getLabelTranslator());
    }

    /**
     * Returns the features used by the most recently added regressor of class {@code k}.
     *
     * @param boosting the model to inspect
     * @param k        internal class index
     * @return the features of the last regressor if it is a {@link RegressionTree}; an empty set
     *         if the last regressor is not a tree or the ensemble has no regressors
     */
    public static Set<Integer> recentlyUsedFeatures(LKBoost boosting, int k) {
        HashSet<Integer> features = new HashSet<>();
        List<Regressor> regressors = boosting.getEnsemble(k).getRegressors();
        if (regressors.isEmpty()) {
            return features;
        }
        Regressor lastOne = regressors.get(regressors.size() - 1);
        if (lastOne instanceof RegressionTree) {
            features.addAll(RegTreeInspector.features((RegressionTree) lastOne));
        }
        return features;
    }

    /**
     * Builds a {@link PredictionAnalysis} for one data point.
     *
     * <p>The analysis contains:
     * <ul>
     *   <li>its ids and labels (internal and external);</li>
     *   <li>the predicted class;</li>
     *   <li>the probability of every class;</li>
     *   <li>for every class, a {@link ClassScoreCalculation} built by
     *       {@link #decisionProcess}.</li>
     * </ul>
     *
     * @param boosting       the model used for prediction
     * @param dataSet        the data set holding the point
     * @param dataPointIndex internal index of the point in {@code dataSet}
     * @param limit          maximum number of tree rules reported per class
     * @return the populated analysis
     */
    public static PredictionAnalysis analyzePrediction(LKBoost boosting, ClfDataSet dataSet, int dataPointIndex, int limit) {
        IdTranslator idTranslator = dataSet.getIdTranslator();
        LabelTranslator labelTranslator = dataSet.getLabelTranslator();
        // Fetch the row and gold label once instead of re-reading them for every call below.
        Vector row = dataSet.getRow(dataPointIndex);
        int label = dataSet.getLabels()[dataPointIndex];

        PredictionAnalysis predictionAnalysis = new PredictionAnalysis();
        predictionAnalysis.setInternalId(dataPointIndex)
                .setId(idTranslator.toExtId(dataPointIndex))
                .setInternalLabel(label)
                .setLabel(labelTranslator.toExtLabel(label));

        int prediction = boosting.predict(row);
        predictionAnalysis.setInternalPrediction(prediction);
        predictionAnalysis.setPrediction(labelTranslator.toExtLabel(prediction));

        double[] probs = boosting.predictClassProbs(row);
        List<ClassProbability> classProbabilities = new ArrayList<>(probs.length);
        for (int k = 0; k < probs.length; ++k) {
            classProbabilities.add(new ClassProbability(k, labelTranslator.toExtLabel(k), probs[k]));
        }
        predictionAnalysis.setClassProbabilities(classProbabilities);

        // One score breakdown per class; the number of classes is taken from the probability vector.
        List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>(probs.length);
        for (int k = 0; k < probs.length; ++k) {
            classScoreCalculations.add(decisionProcess(boosting, labelTranslator, row, k, limit));
        }
        predictionAnalysis.setClassScoreCalculations(classScoreCalculations);
        return predictionAnalysis;
    }

    /**
     * Explains how the score of one class is computed for {@code vector}.
     *
     * <p>Rules are added in two groups:
     * <ol>
     *   <li>One {@link ConstantRule} per {@link ConstantRegressor}, in ensemble order.</li>
     *   <li>Tree rules: one {@link TreeRule} per {@link RegressionTree}, combined with
     *       {@link TreeRule#merge}, sorted by descending absolute score and truncated to
     *       {@code limit}.</li>
     * </ol>
     *
     * @param boosting        the model to inspect
     * @param labelTranslator translates {@code classIndex} to its external label
     * @param vector          the input point
     * @param classIndex      internal class index
     * @param limit           maximum number of tree rules to report
     * @return the score calculation, including the total class score from
     *         {@code boosting.predictClassScore}
     * @throws IllegalArgumentException if {@code limit} is negative (raised by {@code Stream.limit})
     */
    public static ClassScoreCalculation decisionProcess(LKBoost boosting, LabelTranslator labelTranslator, Vector vector, int classIndex, int limit) {
        ClassScoreCalculation classScoreCalculation = new ClassScoreCalculation(classIndex, labelTranslator.toExtLabel(classIndex), boosting.predictClassScore(vector, classIndex));
        List<TreeRule> treeRules = new ArrayList<>();
        for (Regressor regressor : boosting.getEnsemble(classIndex).getRegressors()) {
            if (regressor instanceof ConstantRegressor) {
                classScoreCalculation.addRule(new ConstantRule(((ConstantRegressor) regressor).getScore()));
            }
            if (regressor instanceof RegressionTree) {
                treeRules.add(new TreeRule((RegressionTree) regressor, vector));
            }
        }
        // Report the rules with the largest influence (by magnitude) first.
        Comparator<TreeRule> byMagnitude = Comparator.comparingDouble(rule -> Math.abs(rule.getScore()));
        List<TreeRule> topRules = TreeRule.merge(treeRules).stream()
                .sorted(byMagnitude.reversed())
                .limit(limit)
                .collect(Collectors.toList());
        for (TreeRule treeRule : topRules) {
            classScoreCalculation.addRule(treeRule);
        }
        return classScoreCalculation;
    }

    /**
     * Adds the per-tree contributions of one class ensemble into {@code totals}.
     *
     * <p>For each {@link RegressionTree} in the ensemble of {@code classIndex}, the values from
     * {@link RegTreeInspector#featureImportance} are summed into {@code totals}. Other regressor
     * types are skipped.
     */
    private static void addContributions(LKBoost boosting, int classIndex, Map<Feature, Double> totals) {
        for (Regressor regressor : boosting.getEnsemble(classIndex).getRegressors()) {
            if (!(regressor instanceof RegressionTree)) {
                continue;
            }
            Map<Feature, Double> contributions = RegTreeInspector.featureImportance((RegressionTree) regressor);
            for (Map.Entry<Feature, Double> entry : contributions.entrySet()) {
                totals.merge(entry.getKey(), entry.getValue(), Double::sum);
            }
        }
    }

    /**
     * Builds a labelled {@link TopFeatures} result from aggregated contributions.
     *
     * <p>Features are sorted by descending total contribution and at most {@code maxSize} are
     * kept. The result is tagged with {@code classIndex} and its external label.
     */
    private static TopFeatures toTopFeatures(Map<Feature, Double> totalContributions, long maxSize,
                                             int classIndex, LabelTranslator labelTranslator) {
        List<Feature> ranked = totalContributions.entrySet().stream()
                .sorted(Map.Entry.<Feature, Double>comparingByValue().reversed())
                .limit(maxSize)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
        TopFeatures topFeatures = new TopFeatures();
        topFeatures.setTopFeatures(ranked);
        topFeatures.setClassIndex(classIndex);
        topFeatures.setClassName(labelTranslator.toExtLabel(classIndex));
        return topFeatures;
    }
}

