/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.util;

import edu.neu.ccs.pyramid.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.BinomialDistribution;

/**
 * Static helpers for random sampling: drawing index subsets without replacement (plain and
 * stratified), drawing with replacement, uniform and log-uniform scalars, and random categorical
 * probability vectors.
 */
public class Sampling {

    /**
     * Draws {@code sampleSize} distinct indices uniformly at random from {@code [0, totalSize)}.
     *
     * <p>Uses a partial Fisher–Yates shuffle: only the first {@code sampleSize} positions are
     * shuffled, and the indices are kept as primitives instead of boxed integers.
     *
     * @param totalSize number of indices to choose from
     * @param sampleSize number of distinct indices to return
     * @return a new array of {@code sampleSize} distinct indices in random order
     * @throws IllegalArgumentException if {@code sampleSize} is negative or exceeds {@code totalSize}
     */
    public static int[] sampleBySize(int totalSize, int sampleSize) {
        if (sampleSize < 0 || sampleSize > totalSize) {
            throw new IllegalArgumentException(
                    "sampleSize=" + sampleSize + " must be in [0, totalSize=" + totalSize + "]");
        }
        int[] pool = IntStream.range(0, totalSize).toArray();
        Random random = new Random();
        for (int i = 0; i < sampleSize; ++i) {
            // Pick uniformly from the not-yet-chosen suffix and move it into slot i.
            int j = i + random.nextInt(totalSize - i);
            int tmp = pool[i];
            pool[i] = pool[j];
            pool[j] = tmp;
        }
        int[] sample = new int[sampleSize];
        System.arraycopy(pool, 0, sample, 0, sampleSize);
        return sample;
    }

    /**
     * Draws {@code ceil(percentage * totalSize)} distinct indices uniformly at random from
     * {@code [0, totalSize)}.
     *
     * @param totalSize number of indices to choose from
     * @param percentage fraction of the indices to keep, in {@code [0, 1]}
     * @return a new array of distinct indices in random order
     * @throws IllegalArgumentException if {@code percentage} is outside {@code [0, 1]} (or NaN),
     *     or if {@code totalSize} is negative
     */
    public static int[] sampleByPercentage(int totalSize, double percentage) {
        return sampleBySize(totalSize, sampleSizeFor(totalSize, percentage));
    }

    /**
     * Stratified sampling: groups the positions of {@code labels} by label value and keeps
     * {@code ceil(percentage * groupSize)} randomly chosen positions from every group.
     *
     * <p>Positions from the same label are contiguous in the result. The order of the label groups
     * follows {@link HashMap} iteration order and is therefore unspecified.
     *
     * @param labels label of each position; the returned values are indices into this array
     * @param percentage fraction of each label group to keep, in {@code [0, 1]}
     * @return a new mutable list of selected positions
     * @throws IllegalArgumentException if {@code percentage} is outside {@code [0, 1]} (or NaN) and
     *     {@code labels} is non-empty
     */
    public static List<Integer> stratified(int[] labels, double percentage) {
        Map<Integer, List<Integer>> indicesByLabel = new HashMap<>();
        for (int i = 0; i < labels.length; ++i) {
            indicesByLabel.computeIfAbsent(labels[i], key -> new ArrayList<>()).add(i);
        }
        List<Integer> sample = new ArrayList<>();
        for (List<Integer> indices : indicesByLabel.values()) {
            sample.addAll(sampleByPercentage(indices, percentage));
        }
        return sample;
    }

    /**
     * Returns {@code ceil(percentage * indices.size())} elements of {@code indices} chosen uniformly
     * at random without replacement, in random order.
     *
     * <p>The argument is not modified: the shuffle happens on a private copy. The result is an
     * independent mutable list, not a view of {@code indices}, so later changes to
     * {@code indices} do not affect it.
     *
     * @param indices the values to sample from (may be unmodifiable)
     * @param percentage fraction of the values to keep, in {@code [0, 1]}
     * @return a new mutable list holding the sample
     * @throws IllegalArgumentException if {@code percentage} is outside {@code [0, 1]} (or NaN)
     */
    public static List<Integer> sampleByPercentage(List<Integer> indices, double percentage) {
        List<Integer> shuffled = new ArrayList<>(indices);
        Collections.shuffle(shuffled);
        int sampleSize = sampleSizeFor(shuffled.size(), percentage);
        // Drop the tail in place so the returned list is a standalone ArrayList, not a view.
        shuffled.subList(sampleSize, shuffled.size()).clear();
        return shuffled;
    }

    /**
     * Converts a sampling fraction into a sample size, rounding up.
     *
     * @param totalSize population size
     * @param percentage fraction to keep, in {@code [0, 1]}
     * @return {@code ceil(percentage * totalSize)}, ignoring floating-point rounding noise
     * @throws IllegalArgumentException if {@code percentage} is outside {@code [0, 1]} or NaN
     */
    private static int sampleSizeFor(int totalSize, double percentage) {
        if (!(percentage >= 0.0 && percentage <= 1.0)) {
            throw new IllegalArgumentException("percentage=" + percentage + " must be in [0, 1]");
        }
        double raw = percentage * totalSize;
        // Decimal fractions such as 0.07 have no exact binary representation, so the product can
        // land a few ulps above the intended integer (0.07 * 100 == 7.000000000000001 in double
        // arithmetic). Discount that noise so the ceiling does not add a spurious extra element.
        return (int) Math.ceil(raw - 4 * Math.ulp(raw));
    }

    /**
     * Draws {@code sampleSize} integers uniformly from {@code [start, end)}, with replacement.
     *
     * @param sampleSize number of values in the stream
     * @param start inclusive lower bound
     * @param end exclusive upper bound
     * @return a stream of {@code sampleSize} values
     * @throws IllegalArgumentException if {@code sampleSize < 0} or {@code start >= end}
     *     (propagated from {@link Random#ints(long, int, int)})
     */
    public static IntStream sampleWithReplacement(int sampleSize, int start, int end) {
        return new Random().ints(sampleSize, start, end);
    }

    /**
     * Draws {@code sampleSize} elements of {@code indices} uniformly at random, with replacement.
     *
     * @param sampleSize number of values in the stream
     * @param indices the values to draw from; must be non-empty when {@code sampleSize > 0}
     * @return a stream of {@code sampleSize} values taken from {@code indices}
     */
    public static IntStream sampleWithReplacement(int sampleSize, List<Integer> indices) {
        return sampleWithReplacement(sampleSize, 0, indices.size()).map(indices::get);
    }

    /**
     * Returns a double drawn uniformly from {@code [min, max)}.
     *
     * @param min inclusive lower bound
     * @param max exclusive upper bound
     * @return {@code min + U * (max - min)} with {@code U} uniform in {@code [0, 1)}
     */
    public static double doubleUniform(double min, double max) {
        return Math.random() * (max - min) + min;
    }

    /**
     * Returns a double whose logarithm is uniform between {@code log(min)} and {@code log(max)}.
     *
     * <p>This is useful for scale-type parameters spanning several orders of magnitude.
     *
     * @param min lower bound, must be positive
     * @param max upper bound, must be positive
     * @return a log-uniformly distributed value
     * @throws IllegalArgumentException if {@code min <= 0} or {@code max <= 0}
     */
    public static double doubleLogUniform(double min, double max) {
        if (min <= 0.0) {
            throw new IllegalArgumentException("min<=0");
        }
        if (max <= 0.0) {
            // log(max) would be -Infinity or NaN, silently producing 0 or NaN samples.
            throw new IllegalArgumentException("max<=0");
        }
        double minLog = Math.log(min);
        double maxLog = Math.log(max);
        return Math.exp(doubleUniform(minLog, maxLog));
    }

    /**
     * Returns an int drawn uniformly from the closed range {@code [min, max]}.
     *
     * @param min inclusive lower bound
     * @param max inclusive upper bound
     * @return a uniformly distributed value in {@code [min, max]}
     * @throws IllegalArgumentException if {@code max < min}
     */
    public static int intUniform(int min, int max) {
        // Compute the span in long arithmetic: max - min + 1 overflows int for wide ranges.
        long span = (long) max - (long) min + 1L;
        if (span <= 0L) {
            throw new IllegalArgumentException("max=" + max + " must be >= min=" + min);
        }
        Random random = new Random();
        if (span <= Integer.MAX_VALUE) {
            return min + random.nextInt((int) span);
        }
        // The span is at least 2^31, which nextInt(bound) cannot express. At least half of all
        // ints fall inside [min, max], so rejection sampling needs under two draws on average.
        int value;
        do {
            value = random.nextInt();
        } while (value < min || value > max);
        return value;
    }

    /**
     * Selects {@code size} distinct data indices by repeatedly sweeping {@code probs}. On each
     * sweep, every not-yet-selected index is added with its paired probability, until {@code size}
     * indices are selected.
     *
     * <p>If {@code probs} contains at most {@code size} distinct indices, all of them are returned
     * without any random draws. If there are more distinct indices, at least {@code size} of them
     * must carry a positive probability; otherwise the sweep could never finish.
     *
     * @param probs (data index, selection probability) pairs; probabilities must lie in
     *     {@code [0, 1]}
     * @param size number of distinct indices to select; non-positive values yield an empty set
     * @return a new mutable set of selected indices
     * @throws IllegalArgumentException if a probability is outside {@code [0, 1]} (or NaN), or if
     *     fewer than {@code size} distinct indices have positive probability
     */
    public static Set<Integer> rotate(List<Pair<Integer, Double>> probs, int size) {
        HashSet<Integer> res = new HashSet<>();
        if (size <= 0) {
            return res;
        }
        for (Pair<Integer, Double> pair : probs) {
            res.add(pair.getFirst());
        }
        // Covers both "fewer pairs than requested" and duplicate indices that leave too few
        // distinct candidates. Either way every candidate is selected.
        if (res.size() <= size) {
            return res;
        }
        res.clear();

        Set<Integer> reachable = new HashSet<>();
        for (Pair<Integer, Double> pair : probs) {
            double prob = pair.getSecond();
            if (!(prob >= 0.0 && prob <= 1.0)) {
                throw new IllegalArgumentException(
                        "probability " + prob + " for index " + pair.getFirst() + " is not in [0, 1]");
            }
            if (prob > 0.0) {
                reachable.add(pair.getFirst());
            }
        }
        if (reachable.size() < size) {
            throw new IllegalArgumentException("only " + reachable.size()
                    + " indices have positive probability; cannot select " + size);
        }

        Random random = new Random();
        while (res.size() < size) {
            for (Pair<Integer, Double> pair : probs) {
                if (res.size() == size) {
                    break;
                }
                int dataIndex = pair.getFirst();
                // Bernoulli(prob) draw: nextDouble() is uniform in [0, 1), so this succeeds with
                // probability prob. It avoids building a distribution object (and its RNG) per draw.
                if (!res.contains(dataIndex) && random.nextDouble() < pair.getSecond()) {
                    res.add(dataIndex);
                }
            }
        }
        return res;
    }

    /**
     * Builds a random probability vector of length {@code dimension} by stick-breaking. Each
     * component except the last takes a uniform fraction of the probability mass still
     * unassigned, and the last component receives the remainder.
     *
     * <p>Note: this is not uniform over the probability simplex. Earlier components tend to be
     * larger; the first has expected value 1/2 regardless of {@code dimension}.
     *
     * @param dimension number of categories; 0 yields an empty array
     * @return non-negative entries summing to 1 (up to floating-point rounding)
     */
    public static double[] randomCategoricalDis(int dimension) {
        double[] vector = new double[dimension];
        double remaining = 1.0;
        double used = 0.0;
        for (int i = 0; i < dimension; ++i) {
            remaining = 1.0 - used;
            double prob = (i == dimension - 1) ? remaining : doubleUniform(0.0, remaining);
            vector[i] = prob;
            used += prob;
        }
        return vector;
    }
}

