/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.regression.regression_tree;

import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.regression.Regressor;
import edu.neu.ccs.pyramid.regression.regression_tree.Node;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Stack;
import java.util.concurrent.LinkedBlockingDeque;
import org.apache.mahout.math.Vector;

/**
 * A binary regression tree.
 *
 * <p>Each internal node splits on {@code vector.get(featureIndex) <= threshold}: inputs that
 * satisfy the test go to the left child, all others to the right child. When the split feature
 * is missing ({@code NaN}) the input is sent down both branches, weighted by the parent's
 * {@link Node#getLeftProb()} and {@link Node#getRightProb()}. The prediction is the sum over all
 * leaves of {@code P(input reaches leaf) * leaf value}.
 */
public class RegressionTree implements Regressor, Serializable {
    private static final long serialVersionUID = 3L;
    /** Number of nodes created so far; node ids are used as indices into arrays of this size. */
    int numNodes = 0;
    protected Node root;
    protected List<Node> leaves = new ArrayList<Node>();
    private FeatureList featureList;

    protected RegressionTree() {
    }

    /**
     * Scales every leaf's output value by {@code shrinkage}, in place.
     *
     * @param shrinkage multiplicative factor applied to each leaf value
     */
    public void shrink(double shrinkage) {
        for (Node leaf : this.leaves) {
            leaf.setValue(leaf.getValue() * shrinkage);
        }
    }

    /**
     * Builds a depth-one tree with a single split on {@code featureIndex}.
     *
     * <p>Node ids are assigned root = 0, left leaf = 1, right leaf = 2, and the leaves list is
     * {@code [left, right]}. The root's missing-value probabilities (left/right prob) are not set
     * here, so NaN inputs are weighted by whatever {@code Node} defaults them to.
     *
     * @param featureIndex index of the feature tested at the root
     * @param threshold    inputs with {@code value <= threshold} go to the left leaf
     * @param leftOutput   value of the left leaf
     * @param rightOutput  value of the right leaf
     * @return the new stump
     */
    public static RegressionTree newStump(int featureIndex, double threshold, double leftOutput, double rightOutput) {
        RegressionTree tree = new RegressionTree();
        Node root = new Node();
        root.setId(tree.numNodes++);
        root.setFeatureIndex(featureIndex);
        root.setThreshold(threshold);
        root.setLeaf(false);
        tree.root = root;

        Node leftChild = tree.addLeaf(root, leftOutput);
        Node rightChild = tree.addLeaf(root, rightOutput);
        root.setLeftChild(leftChild);
        root.setRightChild(rightChild);
        return tree;
    }

    /** Creates a leaf with the next free id and {@code value}, links it to {@code parent}, and records it. */
    private Node addLeaf(Node parent, double value) {
        Node leaf = new Node();
        leaf.setId(this.numNodes++);
        leaf.setLeaf(true);
        leaf.setValue(value);
        // probability() walks from a leaf up to the root through getParent(), so this link is required.
        leaf.setParent(parent);
        this.leaves.add(leaf);
        return leaf;
    }

    /** @return the number of leaves in this tree */
    public int getNumLeaves() {
        return this.leaves.size();
    }

    /** @return the root node (may be {@code null} if the tree has not been built) */
    public Node getRoot() {
        return this.root;
    }

    /**
     * Returns {@code sum over leaves of P(vector reaches leaf) * leaf value}.
     *
     * <p>Reach probabilities are memoised per call in arrays indexed by node id, so every node id
     * must lie in {@code [0, numNodes)}.
     */
    @Override
    public double predict(Vector vector) {
        boolean[] calculated = new boolean[this.numNodes];
        double[] probs = new double[this.numNodes];
        double prediction = 0.0;
        for (Node leaf : this.leaves) {
            prediction += this.probability(vector, leaf, calculated, probs) * leaf.getValue();
        }
        return prediction;
    }

    /**
     * Probability that {@code vector} reaches {@code node}, memoised by node id.
     *
     * @param calculated {@code calculated[id]} is true once {@code probs[id]} holds a result
     * @param probs      cache of reach probabilities, indexed by node id
     */
    double probability(Vector vector, Node node, boolean[] calculated, double[] probs) {
        int id = node.getId();
        if (calculated[id]) {
            return probs[id];
        }
        if (node == this.root) {
            return 1.0;
        }
        double branch = branchProbability(vector, node);
        // A zero-weight edge zeroes the whole path, so there is no need to walk the ancestors.
        double prob = branch == 0.0
                ? 0.0
                : branch * this.probability(vector, node.getParent(), calculated, probs);
        calculated[id] = true;
        probs[id] = prob;
        return prob;
    }

    /** Probability that {@code vector} reaches {@code node} (no memoisation). */
    double probability(Vector vector, Node node) {
        if (node == this.root) {
            return 1.0;
        }
        double branch = branchProbability(vector, node);
        return branch == 0.0 ? 0.0 : branch * this.probability(vector, node.getParent());
    }

    /**
     * Weight of the single edge from {@code node}'s parent to {@code node}.
     *
     * <p>If the split feature is NaN, the weight is the parent's left/right probability. Otherwise
     * it is 1 when the {@code <=} test routes the input to this child and 0 when it does not. A NaN
     * threshold makes {@code <=} false, so every non-missing value is routed right.
     */
    private static double branchProbability(Vector vector, Node node) {
        Node parent = node.getParent();
        boolean isLeftChild = node == parent.getLeftChild();
        double featureValue = vector.get(parent.getFeatureIndex());
        if (Double.isNaN(featureValue)) {
            return isLeftChild ? parent.getLeftProb() : parent.getRightProb();
        }
        boolean goesLeft = featureValue <= parent.getThreshold();
        return goesLeft == isLeftChild ? 1.0 : 0.0;
    }

    /**
     * Returns the split feature index of every internal node, in breadth-first order.
     * A feature used by several nodes appears once per node.
     */
    public List<Integer> getFeatureIndices() {
        List<Integer> featureIndices = new ArrayList<Integer>();
        ArrayDeque<Node> queue = new ArrayDeque<Node>();
        queue.offer(this.root);
        while (!queue.isEmpty()) {
            Node node = queue.poll();
            if (node.isLeaf()) {
                continue;
            }
            featureIndices.add(node.getFeatureIndex());
            queue.offer(node.getLeftChild());
            queue.offer(node.getRightChild());
        }
        return featureIndices;
    }

    /**
     * Renders one line per leaf: the root-to-leaf split conditions, then {@code ": "} and the leaf value.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("RegTree{");
        for (Node leaf : this.leaves) {
            // Push leaf..root so that popping yields the path from the root down to the leaf.
            ArrayDeque<Node> path = new ArrayDeque<Node>();
            for (Node current = leaf; current != null; current = current.getParent()) {
                path.push(current);
            }
            while (!path.isEmpty()) {
                Node current = path.pop();
                if (current.isLeaf()) {
                    sb.append(": ").append(current.getValue()).append("\n");
                    continue;
                }
                Node next = path.peek();
                String op = next == current.getLeftChild() ? "<=" : ">";
                sb.append(current.getFeatureIndex()).append(op).append(current.getThreshold()).append("   ");
            }
        }
        sb.append("}");
        return sb.toString();
    }

    /** Returns all nodes in pre-order (node, then left subtree, then right subtree). */
    List<Node> traverse() {
        List<Node> visited = new ArrayList<Node>();
        ArrayDeque<Node> pending = new ArrayDeque<Node>();
        pending.push(this.root);
        while (!pending.isEmpty()) {
            Node visit = pending.pop();
            visited.add(visit);
            if (!visit.isLeaf()) {
                // Push right first so the left child is popped (visited) first.
                pending.push(visit.getRightChild());
                pending.push(visit.getLeftChild());
            }
        }
        return visited;
    }

    /** @return the feature list attached to this tree, or {@code null} if none was set */
    @Override
    public FeatureList getFeatureList() {
        return this.featureList;
    }

    /** @param featureList feature list to attach to this tree */
    public void setFeatureList(FeatureList featureList) {
        this.featureList = featureList;
    }
}

