/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.util;

import edu.neu.ccs.pyramid.util.Pair;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;

public class MathUtil {

    /**
     * Shuffles {@code array} in place using an unbiased Fisher-Yates shuffle.
     *
     * <p>The method name keeps its historical misspelling so existing callers keep compiling.
     *
     * @param array the array to shuffle; modified in place
     * @return the same {@code array} instance, for call chaining
     */
    public static int[] shffuleArray(int[] array) {
        shuffleInPlace(array, new Random());
        return array;
    }

    /**
     * Fisher-Yates shuffle of {@code array} in place.
     *
     * <p>Each position {@code i} is swapped only with a position drawn from {@code [0, i]}, which
     * makes every permutation equally likely. Swapping every slot with a position drawn from the
     * whole array yields n^n equally likely swap sequences that cannot map evenly onto the n!
     * permutations, i.e. a biased shuffle.
     */
    private static void shuffleInPlace(int[] array, Random random) {
        for (int i = array.length - 1; i > 0; --i) {
            int j = random.nextInt(i + 1);
            int temp = array[i];
            array[i] = array[j];
            array[j] = temp;
        }
    }

    /**
     * Returns a new array of {@code m} zeros.
     *
     * @param m the array length
     * @return a freshly allocated zero-filled array (Java zero-initializes new arrays)
     */
    public static double[] zeros(int m) {
        return new double[m];
    }

    /**
     * Returns the consecutive integers {@code start, start + 1, ..., end - 1}.
     *
     * @param start first value (inclusive)
     * @param end last value (exclusive)
     * @return an array of length {@code end - start}
     * @throws NegativeArraySizeException if {@code end < start}
     */
    public static int[] range(int start, int end) {
        int[] results = new int[end - start];
        for (int i = 0; i < results.length; ++i) {
            results[i] = start + i;
        }
        return results;
    }

    /**
     * Returns the integers in {@code [start, end)} in a uniformly random order.
     *
     * @param start first value (inclusive)
     * @param end last value (exclusive)
     * @return a random permutation of {@link #range(int, int)}
     * @throws NegativeArraySizeException if {@code end < start}
     */
    public static int[] randomRange(int start, int end) {
        int[] array = range(start, end);
        shuffleInPlace(array, new Random());
        return array;
    }

    /**
     * Computes {@code log(sum_i exp(arr[i]))} without overflow by factoring out the maximum.
     *
     * @param arr the log-space values
     * @return the log of the sum of exponentials:
     *     <ul>
     *       <li>{@code -Infinity} for an empty array or when every element is {@code -Infinity};
     *       <li>{@code +Infinity} if any element is {@code +Infinity};
     *       <li>{@code NaN} if any element is {@code NaN}.
     *     </ul>
     */
    public static double logSumExp(double[] arr) {
        double maxElement = Double.NEGATIVE_INFINITY;
        for (double number : arr) {
            if (Double.isNaN(number)) {
                return Double.NaN;
            }
            if (number > maxElement) {
                maxElement = number;
            }
        }
        // Empty input or all -inf: the sum of exponentials is 0, whose log is -inf.
        if (maxElement == Double.NEGATIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY;
        }
        // A +inf term dominates; handling it here also avoids inf - inf = NaN below.
        if (maxElement == Double.POSITIVE_INFINITY) {
            return Double.POSITIVE_INFINITY;
        }
        double sum = 0.0;
        for (double number : arr) {
            sum += Math.exp(number - maxElement);
        }
        return Math.log(sum) + maxElement;
    }

    /**
     * {@code float[]} overload of {@link #logSumExp(double[])}; values are widened to double.
     *
     * @param arr the log-space values
     * @return the log of the sum of exponentials, computed in double precision
     */
    public static double logSumExp(float[] arr) {
        double[] d = new double[arr.length];
        for (int i = 0; i < d.length; ++i) {
            d[i] = arr[i];
        }
        return MathUtil.logSumExp(d);
    }

    /**
     * Returns the L1 norm (sum of absolute values) of {@code arr}.
     *
     * @param arr the vector
     * @return {@code sum_i |arr[i]|}, or 0 for an empty array
     */
    public static double l1Norm(double[] arr) {
        double norm = 0.0;
        for (double number : arr) {
            norm += Math.abs(number);
        }
        return norm;
    }

    /**
     * Returns the Euclidean (L2) norm of {@code arr}.
     *
     * @param arr the vector
     * @return {@code sqrt(sum_i arr[i]^2)}, or 0 for an empty array
     */
    public static double l2Norm(double[] arr) {
        double norm = 0.0;
        for (double number : arr) {
            norm += number * number;
        }
        return Math.sqrt(norm);
    }

    /**
     * Returns the max (L-infinity) norm of {@code arr}.
     *
     * @param arr the vector
     * @return the largest absolute value, or 0 for an empty array; NaN elements are skipped
     */
    public static double maxNorm(double[] arr) {
        double norm = 0.0;
        for (double number : arr) {
            double abs = Math.abs(number);
            if (abs > norm) {
                norm = abs;
            }
        }
        return norm;
    }

    /**
     * Returns the Shannon entropy of {@code distribution}, in bits.
     *
     * @param distribution probabilities (not checked to sum to 1)
     * @return {@code -sum_i p_i * log2(p_i)}, treating {@code 0 * log(0)} as 0
     */
    public static double entropy(double[] distribution) {
        double ln2 = Math.log(2.0);
        double entropy = 0.0;
        for (double prob : distribution) {
            if (prob == 0.0) {
                continue; // lim p->0 of p*log(p) is 0; skipping avoids 0 * -inf = NaN
            }
            entropy -= prob * Math.log(prob) / ln2;
        }
        return entropy;
    }

    /**
     * Returns the sum of the elements of {@code arr}.
     *
     * @param arr the values
     * @return their sum, or 0 for an empty array
     */
    public static double arraySum(double[] arr) {
        double sum = 0.0;
        for (double num : arr) {
            sum += num;
        }
        return sum;
    }

    /**
     * Returns the sum of the elements of {@code arr}, accumulated in float precision.
     *
     * @param arr the values
     * @return their sum, or 0 for an empty array
     */
    public static float arraySum(float[] arr) {
        float sum = 0.0f;
        for (float num : arr) {
            sum += num;
        }
        return sum;
    }

    /**
     * Converts unnormalized log-scores into probabilities: {@code p_k = exp(s_k) / sum_j exp(s_j)}.
     *
     * @param scores the log-scores
     * @return a new array of probabilities, normalized via {@link #logSumExp(double[])}
     */
    public static double[] softmax(double[] scores) {
        double[] probVector = new double[scores.length];
        double logDenominator = MathUtil.logSumExp(scores);
        for (int k = 0; k < scores.length; ++k) {
            probVector[k] = Math.exp(scores[k] - logDenominator);
        }
        return probVector;
    }

    /**
     * Returns the log of {@link #softmax(double[])}, computed directly in log space.
     *
     * @param scores the log-scores
     * @return a new array with {@code scores[k] - logSumExp(scores)}
     */
    public static double[] logSoftmax(double[] scores) {
        double[] logProbVector = new double[scores.length];
        double logDenominator = MathUtil.logSumExp(scores);
        for (int k = 0; k < scores.length; ++k) {
            logProbVector[k] = scores[k] - logDenominator;
        }
        return logProbVector;
    }

    /**
     * Returns {@code log(1 / (1 + exp(-score)))} in a numerically stable way.
     *
     * @param score the logit
     * @return the log of the sigmoid of {@code score}; always {@code <= 0}, NaN for NaN input
     */
    public static double logSigmoid(double score) {
        // Pick the branch whose exp argument is <= 0 so it cannot overflow, and use log1p so that
        // tiny results such as -exp(-40) are not rounded to exactly 0.
        if (score >= 0.0) {
            return -Math.log1p(Math.exp(-score));
        }
        return score - Math.log1p(Math.exp(score));
    }

    /**
     * Returns centered log-probabilities, one preimage of {@link #softmax(double[])}.
     *
     * <p>Zero probabilities are treated as {@code 1e-10} so their logs stay finite. The input array
     * is not modified.
     *
     * @param probabilities the probabilities to invert
     * @return {@code log(p_i) - mean_j log(p_j)}
     */
    public static double[] inverseSoftMax(double[] probabilities) {
        int len = probabilities.length;
        double[] logs = new double[len];
        for (int i = 0; i < len; ++i) {
            // Clamp locally instead of writing back into the caller's array.
            double p = probabilities[i] == 0.0 ? 1.0E-10 : probabilities[i];
            logs[i] = Math.log(p);
        }
        double average = MathUtil.arraySum(logs) / (double) len;
        double[] scores = new double[len];
        for (int i = 0; i < len; ++i) {
            scores[i] = logs[i] - average;
        }
        return scores;
    }

    /**
     * Returns the logit {@code log(p / (1 - p))}, the inverse of the sigmoid.
     *
     * <p>Exactly 0 is clamped to {@code 1e-10} and exactly 1 to {@code 0.9999} so the result stays
     * finite.
     *
     * <p>NOTE(review): the clamps are asymmetric (about -23.0 at p = 0 versus about 9.2 at p = 1).
     * Confirm whether {@code 0.9999} is intentional before changing it.
     *
     * @param prob the probability
     * @return the logit of {@code prob}; NaN outside {@code [0, 1]}
     */
    public static double inverseSigmoid(double prob) {
        double p = prob;
        if (p == 0.0) {
            p = 1.0E-10;
        }
        if (p == 1.0) {
            p = 0.9999;
        }
        return -Math.log(1.0 / p - 1.0);
    }

    /**
     * Returns the 50th percentile of {@code arr} as computed by commons-math3
     * {@code DescriptiveStatistics.getPercentile(50.0)}.
     *
     * @param arr the values
     * @return their median
     */
    public static double median(double[] arr) {
        DescriptiveStatistics stats = new DescriptiveStatistics(arr);
        return stats.getPercentile(50.0);
    }

    /**
     * Returns the sign of {@code d}.
     *
     * @param d the value
     * @return 1.0 if positive, -1.0 if negative, otherwise 0.0 (including for -0.0 and NaN)
     */
    public static double sign(double d) {
        if (d > 0.0) {
            return 1.0;
        }
        if (d < 0.0) {
            return -1.0;
        }
        return 0.0;
    }

    /**
     * Returns the weighted median of {@code scores}.
     *
     * <p>Scores are visited in ascending order, with ties kept in input order. The first score
     * whose cumulative weight reaches at least half of the total weight is returned.
     *
     * @param scores the values
     * @param weights the weight of each value; must have the same length as {@code scores}
     * @return the weighted median, or NaN if no prefix reaches half the total (e.g. empty input)
     * @throws IllegalArgumentException if {@code scores} and {@code weights} differ in length
     */
    public static double weightedMedian(double[] scores, double[] weights) {
        if (scores.length != weights.length) {
            throw new IllegalArgumentException("scores and weights must have the same length: "
                    + scores.length + " vs " + weights.length);
        }
        List<Integer> order = new ArrayList<>(scores.length);
        for (int i = 0; i < scores.length; ++i) {
            order.add(i);
        }
        // List.sort is stable, and comparingDouble orders like Double.compareTo.
        order.sort(Comparator.comparingDouble(i -> scores[i]));
        double halfWeight = MathUtil.arraySum(weights) / 2.0;
        double cumulative = 0.0;
        for (int index : order) {
            cumulative += weights[index];
            if (cumulative >= halfWeight) {
                return scores[index];
            }
        }
        return Double.NaN;
    }
}

