/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.regression.regression_tree;

import edu.neu.ccs.pyramid.feature.Feature;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.regression.regression_tree.Node;
import edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.mahout.math.Vector;

/**
 * Static helpers for inspecting a {@link RegressionTree}: which features it splits on, how much
 * each feature contributes, and which leaf a given input vector is routed to.
 */
public class RegTreeInspector {

    /**
     * Collects the indices of all features used for splitting in the given tree.
     *
     * @param tree the tree to inspect
     * @return the set of feature indices found on internal (non-leaf) nodes
     */
    public static Set<Integer> features(RegressionTree tree) {
        return tree.traverse().stream()
                .filter(node -> !node.isLeaf())
                .map(Node::getFeatureIndex)
                .collect(Collectors.toSet());
    }

    /**
     * Computes a per-feature importance score as the sum of {@link Node#getReduction()} over every
     * internal node that splits on that feature.
     *
     * @param tree the tree to inspect; its feature list is used to map split indices to features
     * @return a mutable map from feature to its accumulated reduction; features that are never
     *     used for a split do not appear in the map
     */
    public static Map<Feature, Double> featureImportance(RegressionTree tree) {
        FeatureList featureList = tree.getFeatureList();
        Map<Feature, Double> importance = new HashMap<>();
        for (Node node : tree.traverse()) {
            if (node.isLeaf()) {
                continue;
            }
            Feature feature = featureList.get(node.getFeatureIndex());
            // A feature may be split on several times; merge adds this node's reduction to any
            // total already accumulated for the feature.
            importance.merge(feature, node.getReduction(), Double::sum);
        }
        return importance;
    }

    /**
     * Finds the index (into {@code tree.leaves}) of the leaf that {@code vector} falls into.
     *
     * <p>A leaf counts as matched when {@code tree.probability(vector, leaf)} is exactly
     * {@code 1.0}. The first such leaf is returned.
     *
     * @param tree the tree to route the vector through
     * @param vector the input vector
     * @return the index of the first matched leaf, or {@code 0} if no leaf reports probability
     *     exactly {@code 1.0}. NOTE(review): this fallback cannot be distinguished from a genuine
     *     match on leaf 0. Confirm callers never hit it, e.g. inputs for which {@code probability}
     *     spreads mass across several leaves.
     */
    public static int getMatchedLeaf(RegressionTree tree, Vector vector) {
        for (int i = 0; i < tree.getNumLeaves(); i++) {
            if (tree.probability(vector, tree.leaves.get(i)) == 1.0) {
                return i;
            }
        }
        return 0;
    }

    /**
     * Applies {@link #getMatchedLeaf(RegressionTree, Vector)} to each tree in order.
     *
     * @param trees the trees to inspect, typically the members of an ensemble
     * @param vector the input vector routed through every tree
     * @return one matched-leaf index per tree, in the same order as {@code trees}
     */
    public static List<Integer> getMatchedPath(List<RegressionTree> trees, Vector vector) {
        List<Integer> path = new ArrayList<>(trees.size());
        for (RegressionTree tree : trees) {
            path.add(getMatchedLeaf(tree, vector));
        }
        return path;
    }
}

