/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.regression.regression_tree;

import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.RegDataSet;
import edu.neu.ccs.pyramid.regression.regression_tree.AverageOutputCalculator;
import edu.neu.ccs.pyramid.regression.regression_tree.LeafOutputCalculator;
import edu.neu.ccs.pyramid.regression.regression_tree.Node;
import edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig;
import edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree;
import edu.neu.ccs.pyramid.regression.regression_tree.SplitResult;
import edu.neu.ccs.pyramid.regression.regression_tree.Splitter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.stream.IntStream;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Trains {@link RegressionTree}s by greedy, best-first leaf splitting.
 *
 * <p>Every node carries a per-data-point weight array ("probs"). The root starts with the
 * caller-supplied weights. When a node is split, each data point's weight is routed as follows:
 * <ul>
 *   <li>to the left child when its feature value is {@code <=} the threshold;</li>
 *   <li>to the right child when it is {@code >} the threshold;</li>
 *   <li>when the feature value is {@code NaN}, shared between both children in proportion to the
 *       left/right count fractions recorded for the split.</li>
 * </ul>
 * On each iteration the splittable leaf with the largest reduction is split. Growth stops when
 * the configured maximum number of leaves is reached or no leaf can be split.
 */
public class RegTreeTrainer {

    /**
     * Fits a tree to the labels stored in {@code regDataSet}, using unit weights and an
     * {@link AverageOutputCalculator} for leaf values.
     */
    public static RegressionTree fit(RegTreeConfig regTreeConfig, RegDataSet regDataSet) {
        return fit(regTreeConfig, regDataSet, regDataSet.getLabels());
    }

    /**
     * Fits a tree to {@code labels}, using unit weights and an {@link AverageOutputCalculator}
     * for leaf values.
     */
    public static RegressionTree fit(RegTreeConfig regTreeConfig, DataSet dataSet, double[] labels) {
        return fit(regTreeConfig, dataSet, labels, new AverageOutputCalculator());
    }

    /**
     * Fits a tree to {@code labels} with every data point weighted 1.0.
     */
    public static RegressionTree fit(RegTreeConfig regTreeConfig, DataSet dataSet, double[] labels,
                                     LeafOutputCalculator leafOutputCalculator) {
        double[] weights = new double[labels.length];
        Arrays.fill(weights, 1.0);
        return fit(regTreeConfig, dataSet, labels, weights, leafOutputCalculator);
    }

    /**
     * Fits a weighted regression tree.
     *
     * @param regTreeConfig        supplies the maximum number of leaves and the split search settings
     * @param dataSet              feature matrix; its columns are read when splitting
     * @param labels               regression targets, passed to the splitter and the leaf calculator
     * @param weights              initial per-data-point weight; only the first
     *                             {@code dataSet.getNumDataPoints()} entries are used
     * @param leafOutputCalculator computes each final leaf's value from its weights and the labels
     * @return the trained tree; leaf weight arrays are cleared before returning
     * @throws IllegalArgumentException if {@code weights} is shorter than the number of data points
     */
    public static RegressionTree fit(RegTreeConfig regTreeConfig, DataSet dataSet, double[] labels,
                                     double[] weights, LeafOutputCalculator leafOutputCalculator) {
        int numDataPoints = dataSet.getNumDataPoints();
        if (weights.length < numDataPoints) {
            throw new IllegalArgumentException("weights has length " + weights.length
                    + " but the data set has " + numDataPoints + " data points");
        }

        RegressionTree tree = new RegressionTree();
        tree.setFeatureList(dataSet.getFeatureList());
        tree.leaves = new ArrayList<>();
        tree.root = new Node();
        tree.root.setId(tree.numNodes);
        ++tree.numNodes;
        // Copy so the tree never aliases (and later clears) the caller's array.
        tree.root.setProbs(Arrays.copyOf(weights, numDataPoints));
        updateNode(tree.root, regTreeConfig, dataSet, labels);
        tree.leaves.add(tree.root);
        tree.root.setLeaf(true);

        int maxNumLeaves = regTreeConfig.getMaxNumLeaves();
        while (tree.leaves.size() < maxNumLeaves) {
            Optional<Node> leafToSplit = findLeafToSplit(tree.leaves);
            if (!leafToSplit.isPresent()) {
                break; // no leaf has a usable split
            }
            splitNode(tree, leafToSplit.get(), regTreeConfig, dataSet, labels);
        }

        setLeavesOutputs(tree.leaves, leafOutputCalculator, labels);
        cleanLeaves(tree.leaves);
        normalizeReductions(tree, dataSet);
        return tree;
    }

    /**
     * Builds a single-leaf tree that outputs {@code score}.
     *
     * <p>The leaf list and node counter are initialized the same way {@code fit} initializes them,
     * so the returned tree does not depend on {@link RegressionTree}'s constructor for them.
     */
    public static RegressionTree constantTree(double score) {
        RegressionTree tree = new RegressionTree();
        tree.leaves = new ArrayList<>();
        tree.root = new Node();
        tree.root.setId(tree.numNodes);
        ++tree.numNodes;
        tree.root.setValue(score);
        tree.root.setLeaf(true);
        tree.leaves.add(tree.root);
        return tree;
    }

    /**
     * Splits {@code leafToSplit} on its stored feature and threshold, replacing it in
     * {@code tree.leaves} with two new leaf children.
     *
     * <p>The children's own best splits are computed only if the tree could still grow after this
     * split. Otherwise that work would be wasted.
     */
    private static void splitNode(RegressionTree tree, Node leafToSplit, RegTreeConfig regTreeConfig,
                                  DataSet dataSet, double[] labels) {
        int numDataPoints = dataSet.getNumDataPoints();
        double threshold = leafToSplit.getThreshold();
        Vector inputVector = dataSet.getColumn(leafToSplit.getFeatureIndex());
        // Densified once up front, presumably so the per-point get(i) below is cheap.
        Vector columnVector = inputVector.isDense() ? inputVector : new DenseVector(inputVector);

        Node leftChild = new Node();
        leftChild.setId(tree.numNodes);
        ++tree.numNodes;
        Node rightChild = new Node();
        rightChild.setId(tree.numNodes);
        ++tree.numNodes;

        double[] parentProbs = leafToSplit.getProbs();
        double[] leftProbs = new double[numDataPoints];
        double[] rightProbs = new double[numDataPoints];
        // Fractions used to share missing-value points between the children; loop-invariant.
        double leftFraction = leafToSplit.getLeftProb();
        double rightFraction = leafToSplit.getRightProb();
        // Each iteration writes only index i of leftProbs/rightProbs, so the parallel loop is race-free.
        IntStream.range(0, numDataPoints).parallel().forEach(i ->
                routeDataPoint(i, columnVector.get(i), threshold, parentProbs,
                        leftFraction, rightFraction, leftProbs, rightProbs));
        leftChild.setProbs(leftProbs);
        rightChild.setProbs(rightProbs);

        // After this split the tree has tree.leaves.size() + 1 leaves.
        int maxNumLeaves = regTreeConfig.getMaxNumLeaves();
        if (tree.leaves.size() + 1 < maxNumLeaves) {
            updateNode(leftChild, regTreeConfig, dataSet, labels);
            updateNode(rightChild, regTreeConfig, dataSet, labels);
        }

        leafToSplit.setLeftChild(leftChild);
        leafToSplit.setRightChild(rightChild);
        leafToSplit.setLeaf(false);
        leafToSplit.clearProbs();
        tree.leaves.remove(leafToSplit);
        leftChild.setLeaf(true);
        rightChild.setLeaf(true);
        tree.leaves.add(leftChild);
        tree.leaves.add(rightChild);
    }

    /**
     * Distributes data point {@code i}'s parent weight between the two children.
     *
     * <p>{@code NaN} feature values are shared in proportion to the split's left/right fractions.
     * Other values go entirely left ({@code <= threshold}) or entirely right.
     */
    private static void routeDataPoint(int i, double featureValue, double threshold, double[] parentProbs,
                                       double leftFraction, double rightFraction,
                                       double[] leftProbs, double[] rightProbs) {
        double parentProb = parentProbs[i];
        if (Double.isNaN(featureValue)) {
            leftProbs[i] = parentProb * leftFraction;
            rightProbs[i] = parentProb * rightFraction;
        } else if (featureValue <= threshold) {
            leftProbs[i] = parentProb;
            rightProbs[i] = 0.0;
        } else {
            leftProbs[i] = 0.0;
            rightProbs[i] = parentProb;
        }
    }

    /**
     * Runs the splitter on {@code node}'s weights.
     *
     * <p>If a split is found, the method stores the feature, threshold, reduction and the
     * left/right count fractions on the node and marks it splittable. Otherwise it marks the node
     * unsplittable.
     */
    private static void updateNode(Node node, RegTreeConfig regTreeConfig, DataSet dataSet, double[] labels) {
        Optional<SplitResult> splitResultOptional = Splitter.split(regTreeConfig, dataSet, labels, node.getProbs());
        if (splitResultOptional.isPresent()) {
            SplitResult splitResult = splitResultOptional.get();
            node.setFeatureIndex(splitResult.getFeatureIndex());
            node.setThreshold(splitResult.getThreshold());
            node.setReduction(splitResult.getReduction());
            double leftCount = splitResult.getLeftCount();
            double rightCount = splitResult.getRightCount();
            // NOTE(review): assumes the splitter never reports leftCount + rightCount == 0,
            // which would make both fractions NaN -- confirm in Splitter.
            double totalCount = leftCount + rightCount;
            node.setLeftProb(leftCount / totalCount);
            node.setRightProb(rightCount / totalCount);
            node.setSplitable(true);
        } else {
            node.setSplitable(false);
        }
    }

    /** Releases the per-data-point weight arrays held by the final leaves. */
    private static void cleanLeaves(List<Node> leaves) {
        for (Node leaf : leaves) {
            leaf.clearProbs();
        }
    }

    /**
     * Sets every leaf's value via {@code calculator}.
     *
     * <p>Leaves are processed with a parallel stream, so {@code calculator} may be invoked from
     * several threads at once.
     */
    private static void setLeavesOutputs(List<Node> leaves, LeafOutputCalculator calculator, double[] labels) {
        leaves.parallelStream().forEach(leaf -> setLeafOutput(leaf, calculator, labels));
    }

    /** Computes and stores one leaf's output from its weights and the labels. */
    private static void setLeafOutput(Node leaf, LeafOutputCalculator calculator, double[] labels) {
        double[] probs = leaf.getProbs();
        double output = calculator.getLeafOutput(probs, labels);
        leaf.setValue(output);
    }

    /** Returns the splittable leaf with the largest reduction, or empty if none is splittable. */
    private static Optional<Node> findLeafToSplit(List<Node> leaves) {
        return leaves.stream()
                .filter(Node::isSplitable)
                .max(Comparator.comparingDouble(Node::getReduction));
    }

    /**
     * Divides every node's reduction by the number of data points.
     *
     * <p>The method does nothing for an empty data set, which would otherwise produce NaN or
     * infinite reductions.
     */
    private static void normalizeReductions(RegressionTree tree, DataSet dataSet) {
        int numDataPoints = dataSet.getNumDataPoints();
        if (numDataPoints == 0) {
            return;
        }
        List<Node> nodes = tree.traverse();
        for (Node node : nodes) {
            node.setReduction(node.getReduction() / (double) numDataPoints);
        }
    }
}

