/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.regression.regression_tree;

import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.regression.regression_tree.IntervalSplitter;
import edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig;
import edu.neu.ccs.pyramid.regression.regression_tree.SplitResult;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Chooses a split for a regression-tree node by running {@link IntervalSplitter} on every
 * feature and then picking among the best candidates.
 */
public class Splitter {
    private static final Logger logger = LogManager.getLogger();

    /**
     * Chooses a split for the instances described by {@code labels} and {@code probs}.
     *
     * <p>Every feature is evaluated (in parallel) with {@link IntervalSplitter}. The
     * {@code regTreeConfig.getRandomLevel()} candidates with the largest reduction are kept, and
     * one of them is sampled with probability proportional to its reduction (see
     * {@link #sample(List)}).
     *
     * @param regTreeConfig tree configuration; supplies the random level
     * @param dataSet data whose features are scanned
     * @param labels per-instance labels
     * @param probs per-instance weights, indexed like {@code labels}
     * @return the chosen split, or empty if no kept candidate has a strictly positive reduction
     */
    static Optional<SplitResult> split(RegTreeConfig regTreeConfig, DataSet dataSet, double[] labels, double[] probs) {
        GlobalStats globalStats = new GlobalStats(labels, probs);
        // Parameterized message: globalStats.toString() only runs when debug logging is enabled.
        logger.debug("global statistics = {}", globalStats);
        int randomLevel = regTreeConfig.getRandomLevel();
        List<SplitResult> splitResults = IntStream.range(0, dataSet.getNumFeatures())
                .parallel()
                .mapToObj(featureIndex -> split(regTreeConfig, dataSet, labels, probs, featureIndex, globalStats))
                .filter(Optional::isPresent)
                .map(Optional::get)
                .sorted(Comparator.comparingDouble(SplitResult::getReduction).reversed())
                .limit(randomLevel)
                .collect(Collectors.toList());
        return sample(splitResults);
    }

    /**
     * Evaluates every feature and returns the split {@link IntervalSplitter} produces for it,
     * without ranking or sampling.
     *
     * @param regTreeConfig tree configuration passed through to {@link IntervalSplitter}
     * @param dataSet data whose features are scanned
     * @param labels per-instance labels
     * @param probs per-instance weights, indexed like {@code labels}
     * @return one entry per feature that yielded a split, in feature-index order
     */
    public static List<SplitResult> getAllSplits(RegTreeConfig regTreeConfig, DataSet dataSet, double[] labels, double[] probs) {
        GlobalStats globalStats = new GlobalStats(labels, probs);
        return IntStream.range(0, dataSet.getNumFeatures())
                .parallel()
                .mapToObj(featureIndex -> split(regTreeConfig, dataSet, labels, probs, featureIndex, globalStats))
                .filter(Optional::isPresent)
                .map(Optional::get)
                .collect(Collectors.toList());
    }

    /**
     * Same as {@link #getAllSplits(RegTreeConfig, DataSet, double[], double[])} with every
     * instance given weight {@code 1.0}.
     *
     * @param regTreeConfig tree configuration passed through to {@link IntervalSplitter}
     * @param dataSet data whose features are scanned
     * @param labels per-instance labels
     * @return one entry per feature that yielded a split, in feature-index order
     */
    public static List<SplitResult> getAllSplits(RegTreeConfig regTreeConfig, DataSet dataSet, double[] labels) {
        double[] probs = new double[labels.length];
        Arrays.fill(probs, 1.0);
        return getAllSplits(regTreeConfig, dataSet, labels, probs);
    }

    /**
     * Computes the split for a single feature by delegating to {@link IntervalSplitter}.
     *
     * @param featureIndex index of the feature to evaluate
     * @param globalStats node-level totals shared across all features
     * @return the split for this feature, or empty if {@link IntervalSplitter} found none
     */
    static Optional<SplitResult> split(RegTreeConfig regTreeConfig, DataSet dataSet, double[] labels, double[] probs, int featureIndex, GlobalStats globalStats) {
        return IntervalSplitter.split(regTreeConfig, dataSet, labels, probs, featureIndex, globalStats);
    }

    /**
     * Picks one split at random, with probability proportional to its reduction.
     *
     * <p>Only candidates with a strictly positive reduction take part. A zero weight could never
     * be drawn anyway. A negative weight would make {@link EnumeratedIntegerDistribution} throw,
     * and if every weight were negative, normalizing by their negative sum would silently pick a
     * harmful split. NaN reductions are dropped as well.
     *
     * @param splitResults candidate splits, in any order
     * @return the sampled split, or empty if no candidate has a strictly positive reduction
     */
    static Optional<SplitResult> sample(List<SplitResult> splitResults) {
        List<SplitResult> candidates = splitResults.stream()
                .filter(splitResult -> splitResult.getReduction() > 0.0)
                .collect(Collectors.toList());
        if (candidates.isEmpty()) {
            return Optional.empty();
        }
        if (candidates.size() == 1) {
            // A single candidate would be chosen with probability 1; skip building a distribution.
            return Optional.of(candidates.get(0));
        }
        double total = candidates.stream().mapToDouble(SplitResult::getReduction).sum();
        double[] probs = candidates.stream().mapToDouble(splitResult -> splitResult.getReduction() / total).toArray();
        int[] singletons = IntStream.range(0, candidates.size()).toArray();
        EnumeratedIntegerDistribution distribution = new EnumeratedIntegerDistribution(singletons, probs);
        return Optional.of(candidates.get(distribution.sample()));
    }

    /**
     * Weighted totals over all instances at a node. They are computed once per split search and
     * shared by every per-feature evaluation.
     */
    static class GlobalStats {
        private final double weightedLabelSum;
        private final double probabilisticCount;
        private final int binaryCount;

        /**
         * Accumulates the totals.
         *
         * @param labels per-instance labels
         * @param probs per-instance weights; must have at least {@code labels.length} entries
         */
        GlobalStats(double[] labels, double[] probs) {
            double labelSum = 0.0;
            double weightSum = 0.0;
            int positiveWeightCount = 0;
            for (int i = 0; i < labels.length; ++i) {
                double prob = probs[i];
                labelSum += labels[i] * prob;
                weightSum += prob;
                if (prob > 0.0) {
                    ++positiveWeightCount;
                }
            }
            this.weightedLabelSum = labelSum;
            this.probabilisticCount = weightSum;
            this.binaryCount = positiveWeightCount;
        }

        /** @return sum of {@code labels[i] * probs[i]} over all instances */
        public double getWeightedLabelSum() {
            return this.weightedLabelSum;
        }

        /** @return sum of {@code probs[i]} over all instances */
        public double getProbabilisticCount() {
            return this.probabilisticCount;
        }

        /** @return number of instances whose weight is strictly positive */
        public int getBinaryCount() {
            return this.binaryCount;
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder("GlobalStats{");
            sb.append("WeightedLabelSum=").append(this.weightedLabelSum);
            sb.append(", probabilisticCount=").append(this.probabilisticCount);
            sb.append(", binaryCount=").append(this.binaryCount);
            sb.append('}');
            return sb.toString();
        }
    }
}

