/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.regression.regression_tree;

import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.regression.regression_tree.Interval;
import edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig;
import edu.neu.ccs.pyramid.regression.regression_tree.SplitResult;
import edu.neu.ccs.pyramid.regression.regression_tree.Splitter;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.Vector;

class IntervalSplitter {
    /** Class-wide log4j logger used for the debug traces emitted by the static helpers below. */
    private static final Logger logger = LogManager.getLogger();

    /** Package-private no-arg constructor; the methods visible in this class are all static. */
    IntervalSplitter() {
    }

    /**
     * Searches for the best split threshold on a single feature.
     *
     * <p>The feature column is bucketed by {@code generateIntervals}, empty buckets are
     * removed by {@code compress}, and the surviving bucket boundaries are scored by
     * {@code findBest}.
     *
     * @param regTreeConfig tree configuration, passed through to interval generation and scoring
     * @param dataSet       data set from which the feature column is read
     * @param labels        per-row labels, indexed like the feature column
     * @param probs         per-row weights, indexed like the feature column
     * @param featureIndex  index of the column to split on
     * @param globalStats   aggregate statistics passed through to interval generation
     * @return the result of {@code findBest} on the compressed intervals
     */
    static Optional<SplitResult> split(RegTreeConfig regTreeConfig, DataSet dataSet, double[] labels, double[] probs, int featureIndex, Splitter.GlobalStats globalStats) {
        List<Interval> buckets = generateIntervals(regTreeConfig, dataSet.getColumn(featureIndex), probs, labels, globalStats);
        return findBest(regTreeConfig, compress(buckets), featureIndex);
    }

    /**
     * Buckets one feature column into equal-width intervals and accumulates, per interval,
     * the summed weight and the weight-multiplied label sum of the rows that fall into it.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Compute {@link FeatureStats} for the column to obtain min, max and the
     *       zero / non-zero / NaN tallies.</li>
     *   <li>Return an empty list when no min/max was found or when min equals max.</li>
     *   <li>Create {@code regTreeConfig.getNumSplitIntervals()} intervals of equal width
     *       starting at min.</li>
     *   <li>Add every explicitly stored, non-NaN entry with a positive weight to its interval.</li>
     *   <li>Add the aggregated zero-valued mass to the interval that contains {@code 0.0}.</li>
     *   <li>Record each interval's share of {@code globalStats.getProbabilisticCount()} and,
     *       if there are NaN entries, add that share of the NaN mass to the interval.</li>
     * </ol>
     *
     * @param regTreeConfig supplies the number of intervals to create
     * @param featureValues the feature column; only its {@code nonZeroes()} are iterated
     * @param probs         per-row weights, indexed by the element index of the column
     * @param labels        per-row labels, indexed by the element index of the column
     * @param globalStats   aggregate statistics used to derive the zero-valued mass and the
     *                      per-interval percentage
     * @return the generated intervals (possibly containing empty ones), or an empty list
     *         when the feature has no usable range
     */
    static List<Interval> generateIntervals(RegTreeConfig regTreeConfig, Vector featureValues, double[] probs, double[] labels, Splitter.GlobalStats globalStats) {
        FeatureStats featureStats = new FeatureStats(featureValues, probs, labels, globalStats);
        if (logger.isDebugEnabled()) {
            logger.debug("feature statistics = " + featureStats);
        }
        int numIntervals = regTreeConfig.getNumSplitIntervals();
        List<Interval> intervals = new ArrayList<>(numIntervals);
        double maxFeature = featureStats.getMax();
        double minFeature = featureStats.getMin();
        // Min/max still at their sentinel values: no value was observed for this feature.
        if (maxFeature == Double.NEGATIVE_INFINITY || minFeature == Double.POSITIVE_INFINITY) {
            return intervals;
        }
        // A constant feature cannot be split.
        if (minFeature == maxFeature) {
            if (logger.isDebugEnabled()) {
                logger.debug("num generated intervals = " + intervals.size());
            }
            return intervals;
        }
        if (logger.isDebugEnabled()) {
            logger.debug("min = " + minFeature);
            logger.debug("max = " + maxFeature);
        }

        double intervalLength = (maxFeature - minFeature) / (double) numIntervals;
        for (int i = 0; i < numIntervals; ++i) {
            Interval interval = new Interval();
            double lower = minFeature + (double) i * intervalLength;
            interval.setLower(lower);
            interval.setUpper(lower + intervalLength);
            intervals.add(interval);
        }

        if (featureStats.getNonZeroBinaryCount() > 0) {
            for (Vector.Element element : featureValues.nonZeroes()) {
                int row = element.index();
                double featureValue = element.get();
                double probability = probs[row];
                // Same filter as FeatureStats (weight must be > 0): a row that did not
                // contribute to min/max must not be bucketed, otherwise its value may lie
                // outside [min, max] and map to an invalid interval index.
                if (Double.isNaN(featureValue) || !(probability > 0.0)) {
                    continue;
                }
                int intervalIndex = IntervalSplitter.getIntervalIndex(featureValue, minFeature, intervalLength, numIntervals);
                Interval interval = intervals.get(intervalIndex);
                interval.setProbabilisticCount(interval.getProbabilisticCount() + probability);
                interval.setWeightedSum(interval.getWeightedSum() + labels[row] * probability);
            }
        }

        // Zero-valued rows are not iterated individually; their aggregated mass goes into
        // the single interval that contains 0.0.
        if (featureStats.getZeroBinaryCount() > 0) {
            int intervalIndex = IntervalSplitter.getIntervalIndex(0.0, minFeature, intervalLength, numIntervals);
            Interval zeroInterval = intervals.get(intervalIndex);
            zeroInterval.setProbabilisticCount(zeroInterval.getProbabilisticCount() + featureStats.getZeroProbCount());
            zeroInterval.setWeightedSum(zeroInterval.getWeightedSum() + featureStats.getZeroWeightedLabelSum());
        }

        // The percentage is computed before any NaN mass is added to the counts.
        for (Interval interval : intervals) {
            interval.setPercentage(interval.getProbabilisticCount() / globalStats.getProbabilisticCount());
        }

        // Spread the NaN mass over the intervals in proportion to each interval's percentage.
        if (featureStats.getNanBinaryCount() > 0) {
            for (Interval interval : intervals) {
                double share = interval.getPercentage();
                interval.setProbabilisticCount(interval.getProbabilisticCount() + share * featureStats.getNanProbCount());
                interval.setWeightedSum(interval.getWeightedSum() + share * featureStats.getNanWeightedLabelSum());
            }
        }

        if (logger.isDebugEnabled()) {
            logger.debug("num generated intervals = " + intervals.size());
        }
        return intervals;
    }

    /**
     * Removes intervals whose probabilistic count is exactly zero and closes the gaps they
     * leave behind.
     *
     * <p>For every run of consecutive empty intervals that is followed by a non-empty
     * interval, the midpoint between the run's first lower bound and last upper bound becomes
     * the upper bound of the interval just before the run and the lower bound of the interval
     * just after it. A run that extends to the end of the list is dropped without adjusting
     * any bound.
     *
     * <p>The {@link Interval} objects in {@code intervals} are modified in place; the list
     * itself is not.
     *
     * @param intervals intervals in ascending order of position, as produced by
     *                  {@code generateIntervals}
     * @return a new list holding only the intervals with a non-zero probabilistic count
     */
    static List<Interval> compress(List<Interval> intervals) {
        if (logger.isDebugEnabled()) {
            logger.debug("number of intervals to compress = " + intervals.size());
            logger.debug("intervals = " + intervals);
            if (!intervals.isEmpty()) {
                logger.debug("first interval prob count=" + intervals.get(0));
            }
        }

        boolean inBlock = false;
        int start = 0;
        int end = 0;
        for (int i = 0; i < intervals.size(); ++i) {
            boolean empty = intervals.get(i).getProbabilisticCount() == 0.0;
            // A run may start at index 1 or later: that is enough to guarantee that
            // intervals.get(start - 1) exists when the run is closed.
            if (empty && i > 0) {
                if (!inBlock) {
                    inBlock = true;
                    start = i;
                }
                end = i;
                continue;
            }
            if (!inBlock) {
                continue;
            }
            // Interval i ends the run [start, end]; split the gap at its midpoint.
            inBlock = false;
            double mid = (intervals.get(start).getLower() + intervals.get(end).getUpper()) / 2.0;
            if (logger.isDebugEnabled()) {
                logger.debug("in block and start = " + start);
            }
            intervals.get(start - 1).setUpper(mid);
            intervals.get(end + 1).setLower(mid);
        }

        List<Interval> compressed = new ArrayList<>(intervals.size());
        for (Interval interval : intervals) {
            if (interval.getProbabilisticCount() != 0.0) {
                compressed.add(interval);
            }
        }
        return compressed;
    }

    /**
     * Scores every boundary between two consecutive intervals and returns the best one.
     *
     * <p>For the boundary after interval {@code i}, everything up to and including
     * {@code i} forms the left side and the rest forms the right side. The score is
     * {@code leftSum^2 / leftCount + rightSum^2 / rightCount - totalSum^2 / totalCount},
     * and the threshold is the upper bound of interval {@code i}. Boundaries where either
     * side's count is below {@code regTreeConfig.getMinDataPerLeaf()} are discarded.
     *
     * @param regTreeConfig supplies the minimum count required on each side
     * @param intervals     intervals to scan, in order
     * @param featureIndex  feature index recorded in each candidate result
     * @return the admissible candidate with the largest reduction, or an empty
     *         {@code Optional} when there is none
     */
    private static Optional<SplitResult> findBest(RegTreeConfig regTreeConfig, List<Interval> intervals, int featureIndex) {
        int minDataPerLeaf = regTreeConfig.getMinDataPerLeaf();

        // Totals over all intervals, accumulated in list order.
        double totalCount = 0.0;
        double totalSum = 0.0;
        for (Interval each : intervals) {
            totalCount += each.getProbabilisticCount();
            totalSum += each.getWeightedSum();
        }
        double parentScore = totalSum * totalSum / totalCount;

        List<SplitResult> admissible = new ArrayList<>(intervals.size());
        double leftCount = 0.0;
        double leftSum = 0.0;
        // The last interval has nothing to its right, so it never ends a left side.
        for (int i = 0; i < intervals.size() - 1; ++i) {
            Interval boundary = intervals.get(i);
            leftSum += boundary.getWeightedSum();
            leftCount += boundary.getProbabilisticCount();
            double rightSum = totalSum - leftSum;
            double rightCount = totalCount - leftCount;
            double reduction = leftSum * leftSum / leftCount + rightSum * rightSum / rightCount - parentScore;

            SplitResult candidate = new SplitResult();
            candidate.setFeatureIndex(featureIndex)
                    .setLeftCount(leftCount)
                    .setRightCount(rightCount)
                    .setReduction(reduction)
                    .setThreshold(boundary.getUpper());

            boolean leftOk = candidate.getLeftCount() >= (double) minDataPerLeaf;
            boolean rightOk = candidate.getRightCount() >= (double) minDataPerLeaf;
            if (leftOk && rightOk) {
                admissible.add(candidate);
            }
        }
        return admissible.stream().max(Comparator.comparing(SplitResult::getReduction));
    }

    /**
     * Maps a feature value to the index of the equal-width interval that contains it.
     *
     * <p>Interval {@code k} covers {@code (minFeature + k * intervalLength,
     * minFeature + (k + 1) * intervalLength]}, except that interval 0 also includes
     * {@code minFeature} itself. The result is clamped on both ends: values above the last
     * interval map to {@code numIntervals - 1}, and values below {@code minFeature}
     * (including negative infinity) or NaN map to {@code 0}.
     *
     * @param featureValue   value to locate
     * @param minFeature     lower bound of interval 0
     * @param intervalLength width of each interval
     * @param numIntervals   number of intervals
     * @return an index in {@code [0, numIntervals - 1]} when {@code numIntervals} is
     *         positive; {@code 0} otherwise
     */
    static int getIntervalIndex(double featureValue, double minFeature, double intervalLength, int numIntervals) {
        int ceil = (int) Math.ceil((featureValue - minFeature) / intervalLength);
        // Clamp the upper end first, then the lower end; testing the sign before
        // subtracting avoids integer overflow when the cast saturates at Integer.MIN_VALUE.
        int bucket = Math.min(ceil, numIntervals);
        return bucket <= 0 ? 0 : bucket - 1;
    }

    /**
     * Summary statistics of one feature column, split into three groups: explicitly stored
     * non-NaN values ("nonZero"), explicitly stored NaN values ("nan"), and everything else
     * ("zero"). For each group it holds a row count, a summed weight and a summed
     * weight-times-label, plus the minimum and maximum value seen.
     *
     * <p>Only entries returned by {@code featureValues.nonZeroes()} whose weight is strictly
     * positive are visited. The "zero" group is never iterated; its figures are obtained by
     * subtracting the other two groups from the supplied global statistics.
     */
    static class FeatureStats {
        private int zeroBinaryCount;
        private int nonZeroBinaryCount;
        private int nanBinaryCount;
        private double zeroProbCount;
        private double nonZeroProbCount;
        private double nanProbCount;
        private double zeroWeightedLabelSum;
        private double nonZeroWeightedLabelSum;
        private double nanWeightedLabelSum;
        // Sentinels: they stay infinite when no non-NaN value and no zero row is observed.
        private double min = Double.POSITIVE_INFINITY;
        private double max = Double.NEGATIVE_INFINITY;

        /**
         * Scans the column once and derives all statistics.
         *
         * @param featureValues feature column; only its {@code nonZeroes()} are iterated
         * @param probs         per-row weights, indexed by element index
         * @param labels        per-row labels, indexed by element index
         * @param globalStats   totals from which the "zero" group is derived by subtraction
         */
        FeatureStats(Vector featureValues, double[] probs, double[] labels, Splitter.GlobalStats globalStats) {
            for (Vector.Element entry : featureValues.nonZeroes()) {
                int row = entry.index();
                double weight = probs[row];
                if (weight > 0.0) {
                    accumulate(entry.get(), weight, labels[row]);
                }
            }

            zeroBinaryCount = globalStats.getBinaryCount() - nonZeroBinaryCount - nanBinaryCount;
            zeroProbCount = globalStats.getProbabilisticCount() - nonZeroProbCount - nanProbCount;
            zeroWeightedLabelSum = globalStats.getWeightedLabelSum() - nonZeroWeightedLabelSum - nanWeightedLabelSum;

            // When zero-valued rows exist, the range has to include 0.
            if (zeroBinaryCount > 0) {
                if (min > 0.0) {
                    min = 0.0;
                }
                if (max < 0.0) {
                    max = 0.0;
                }
            }
        }

        /** Adds one positively weighted stored entry to either the NaN or the non-zero group. */
        private void accumulate(double value, double weight, double label) {
            if (Double.isNaN(value)) {
                nanBinaryCount++;
                nanProbCount += weight;
                nanWeightedLabelSum += weight * label;
                return;
            }
            nonZeroBinaryCount++;
            nonZeroProbCount += weight;
            nonZeroWeightedLabelSum += weight * label;
            if (value < min) {
                min = value;
            }
            if (value > max) {
                max = value;
            }
        }

        /** @return number of rows in the "zero" group */
        int getZeroBinaryCount() {
            return zeroBinaryCount;
        }

        /** @return number of stored non-NaN entries with positive weight */
        int getNonZeroBinaryCount() {
            return nonZeroBinaryCount;
        }

        /** @return number of stored NaN entries with positive weight */
        int getNanBinaryCount() {
            return nanBinaryCount;
        }

        /** @return smallest value seen, or positive infinity if none */
        double getMin() {
            return min;
        }

        /** @return largest value seen, or negative infinity if none */
        double getMax() {
            return max;
        }

        /** @return summed weight of the "zero" group */
        public double getZeroProbCount() {
            return zeroProbCount;
        }

        /** @return summed weight of the non-zero group */
        public double getNonZeroProbCount() {
            return nonZeroProbCount;
        }

        /** @return summed weight of the NaN group */
        public double getNanProbCount() {
            return nanProbCount;
        }

        /** @return summed weight-times-label of the "zero" group */
        public double getZeroWeightedLabelSum() {
            return zeroWeightedLabelSum;
        }

        /** @return summed weight-times-label of the non-zero group */
        public double getNonZeroWeightedLabelSum() {
            return nonZeroWeightedLabelSum;
        }

        /** @return summed weight-times-label of the NaN group */
        public double getNanWeightedLabelSum() {
            return nanWeightedLabelSum;
        }

        /** Renders all fields in declaration order, for debug logging. */
        @Override
        public String toString() {
            return "FeatureStats{"
                    + "zeroBinaryCount=" + zeroBinaryCount
                    + ", nonZeroBinaryCount=" + nonZeroBinaryCount
                    + ", nanBinaryCount=" + nanBinaryCount
                    + ", zeroProbCount=" + zeroProbCount
                    + ", nonZeroProbCount=" + nonZeroProbCount
                    + ", nanProbCount=" + nanProbCount
                    + ", zeroWeightedLabelSum=" + zeroWeightedLabelSum
                    + ", nonZeroWeightedLabelSum=" + nonZeroWeightedLabelSum
                    + ", nanWeightedLabelSum=" + nanWeightedLabelSum
                    + ", min=" + min
                    + ", max=" + max
                    + '}';
        }
    }
}

