/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.regression.regression_tree;

import edu.neu.ccs.pyramid.feature.Feature;
import edu.neu.ccs.pyramid.regression.Rule;
import edu.neu.ccs.pyramid.regression.regression_tree.Checks;
import edu.neu.ccs.pyramid.regression.regression_tree.Node;
import edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.mahout.math.Vector;

/**
 * The decision path that one input vector follows through a {@link RegressionTree}, together
 * with the value of the leaf the path ends in.
 *
 * <p>Every internal node visited contributes one entry to each list in {@link #getChecks()}:
 * the feature index, the {@link Feature}, the split threshold, the direction taken
 * ({@code true} = left child) and the feature value that was compared. The reached leaf's value
 * becomes {@link #getScore()}.
 */
public class TreeRule
implements Rule {
    /**
     * Value substituted for a {@code NaN} feature value before it is compared with a split
     * threshold. Because the comparison is {@code value <= threshold}, such a value goes left
     * whenever the threshold is at least this value.
     * NOTE(review): presumably mirrors how missing values are handled where the tree is
     * trained/applied; confirm before changing.
     */
    private static final double MISSING_VALUE_SUBSTITUTE = -9999.0;

    private Checks checks = new Checks();
    private double score;

    /** Creates a rule holding a freshly constructed {@link Checks} and a score of 0.0. */
    TreeRule() {
    }

    /**
     * Builds the rule for {@code vector} by walking {@code tree} from its root down to a leaf.
     *
     * @param tree   the tree to walk
     * @param vector the input whose feature values choose the branch at each split
     */
    public TreeRule(RegressionTree tree, Vector vector) {
        this();
        // Use the private walk rather than the overridable add(...): calling an overridable
        // method from a constructor would let a subclass observe a partially built object.
        this.appendPath(tree, tree.getRoot(), vector);
    }

    /**
     * Returns the checks recorded along this rule's path.
     *
     * @return the live internal {@link Checks} object (not a copy)
     */
    public Checks getChecks() {
        return this.checks;
    }

    /**
     * Returns the score of this rule.
     *
     * @return the value of the leaf the path ended in, or a summed value for a merged rule
     */
    public double getScore() {
        return this.score;
    }

    /**
     * Appends the path from {@code node} down to a leaf, as selected by {@code vector}, and
     * sets this rule's score to that leaf's value.
     *
     * <p>At each split the walk goes to the left child when the feature value is
     * {@code <= threshold} and to the right child otherwise. A {@code NaN} feature value is
     * first replaced by {@link #MISSING_VALUE_SUBSTITUTE}. Checks are appended, so calling this
     * on a rule that already has checks extends the existing path.
     *
     * @param tree   tree that supplies the feature list
     * @param node   node to start the walk from
     * @param vector input whose feature values choose the branches
     */
    public void add(RegressionTree tree, Node node, Vector vector) {
        this.appendPath(tree, node, vector);
    }

    /** Non-overridable, iterative walk shared by the constructor and {@link #add}. */
    private void appendPath(RegressionTree tree, Node node, Vector vector) {
        // Fetched once for the whole walk instead of once per tree level.
        List<Feature> featureList = tree.getFeatureList().getAll();
        Node current = node;
        while (!current.isLeaf()) {
            int featureIndex = current.getFeatureIndex();
            Feature feature = featureList.get(featureIndex);
            double threshold = current.getThreshold();
            double featureValue = vector.get(featureIndex);
            if (Double.isNaN(featureValue)) {
                featureValue = MISSING_VALUE_SUBSTITUTE;
            }
            boolean goLeft = featureValue <= threshold;
            this.checks.featureIndices.add(featureIndex);
            this.checks.features.add(feature);
            this.checks.thresholds.add(threshold);
            this.checks.directions.add(goLeft);
            this.checks.values.add(featureValue);
            current = goLeft ? current.getLeftChild() : current.getRightChild();
        }
        this.score = current.getValue();
    }

    /**
     * Returns the checks' string form followed by {@code "score = <score>"}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(this.checks.toString());
        sb.append("score = ").append(this.score);
        return sb.toString();
    }

    /**
     * Combines two rules with equal decision paths into a new rule whose score is the sum of
     * their scores. The result holds a copy of {@code treeRule1}'s checks.
     *
     * @param treeRule1 first rule
     * @param treeRule2 second rule; its checks must equal {@code treeRule1}'s
     * @return a new merged rule
     * @throws IllegalArgumentException if the two rules' checks are not equal
     */
    public static TreeRule merge(TreeRule treeRule1, TreeRule treeRule2) {
        if (!treeRule1.getChecks().equals(treeRule2.getChecks())) {
            throw new IllegalArgumentException("cannot merge decisions with different decision paths");
        }
        TreeRule treeRule = new TreeRule();
        treeRule.checks = treeRule1.checks.copy();
        treeRule.score = treeRule1.score + treeRule2.score;
        return treeRule;
    }

    /**
     * Groups rules by equal checks and sums the scores within each group.
     *
     * <p>Grouping relies on {@link Checks}'s {@code equals}/{@code hashCode}. The result lists one
     * rule per distinct path, in order of first occurrence in {@code treeRules}. Each result rule
     * holds its own copy of the checks, so it does not share state with any input rule, which
     * matches {@link #merge(TreeRule, TreeRule)}.
     *
     * @param treeRules rules to merge; may be empty
     * @return a new mutable list of merged rules
     */
    public static List<TreeRule> merge(List<TreeRule> treeRules) {
        // LinkedHashMap: deterministic first-occurrence order (HashMap's order is unspecified).
        // NOTE(review): input checks must not be mutated while they serve as keys here.
        Map<Checks, Double> scoreByPath = new LinkedHashMap<Checks, Double>();
        for (TreeRule treeRule : treeRules) {
            scoreByPath.merge(treeRule.checks, treeRule.score, Double::sum);
        }
        List<TreeRule> merged = new ArrayList<TreeRule>(scoreByPath.size());
        for (Map.Entry<Checks, Double> entry : scoreByPath.entrySet()) {
            TreeRule treeRule = new TreeRule();
            treeRule.checks = entry.getKey().copy();
            treeRule.score = entry.getValue();
            merged.add(treeRule);
        }
        return merged;
    }
}

