/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.clustering.bm;

import edu.neu.ccs.pyramid.util.ArgSort;
import edu.neu.ccs.pyramid.util.BernoulliDistribution;
import edu.neu.ccs.pyramid.util.MathUtil;
import edu.neu.ccs.pyramid.util.Pair;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.RandomGeneratorFactory;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Mixture of multivariate Bernoulli distributions.
 *
 * <p>The model holds one mixing coefficient per cluster and, for every
 * (cluster, dimension) pair, an independent {@link BernoulliDistribution}.
 * Input vectors are treated as binary: every non-zero entry counts as 1.
 *
 * <p>The package-private fields are mutable; after changing
 * {@link #distributions}, {@link #updateLogClusterConditioinalForEmpty()} has to
 * be called so that the cached all-zeros log-likelihoods stay in sync.
 *
 * <p>Field names (including the historical spelling "Conditioinal") are kept
 * as-is because they are part of the serialized form.
 */
public class BM
implements Serializable {
    private static final long serialVersionUID = 2L;
    private int numClusters;
    private int dimension;
    /** Bernoulli distributions indexed as [cluster][dimension]. */
    BernoulliDistribution[][] distributions;
    /** Mixing coefficient of each cluster. */
    double[] mixtureCoefficients;
    /** Logarithm of each mixing coefficient. */
    double[] logMixtureCoefficients;
    /** Per cluster: cached log-likelihood of the all-zeros vector. */
    private double[] logClusterConditioinalForEmpty;
    /** Display label of each dimension, used by {@link #toString()}. */
    private List<String> names;

    /**
     * Creates a model with uniform mixing coefficients and Bernoulli parameters
     * drawn uniformly from the interval (0.25, 0.75).
     *
     * @param numClusters number of mixture components
     * @param dimension   number of binary features
     * @param randomSeed  seed for the parameter initialisation
     */
    public BM(int numClusters, int dimension, long randomSeed) {
        this.numClusters = numClusters;
        this.dimension = dimension;
        this.distributions = new BernoulliDistribution[numClusters][dimension];
        this.mixtureCoefficients = new double[numClusters];
        Arrays.fill(this.mixtureCoefficients, 1.0 / (double)numClusters);
        this.logMixtureCoefficients = new double[numClusters];
        Arrays.fill(this.logMixtureCoefficients, Math.log(1.0 / (double)numClusters));
        Random random = new Random(randomSeed);
        RandomGenerator randomGenerator = RandomGeneratorFactory.createRandomGenerator(random);
        UniformRealDistribution uniform = new UniformRealDistribution(randomGenerator, 0.25, 0.75);
        for (int k = 0; k < numClusters; ++k) {
            for (int d = 0; d < dimension; ++d) {
                double p = uniform.sample();
                this.distributions[k][d] = new BernoulliDistribution(p);
            }
        }
        this.logClusterConditioinalForEmpty = new double[numClusters];
        this.updateLogClusterConditioinalForEmpty();
        // Default labels are simply the dimension indices.
        this.names = new ArrayList<String>(dimension);
        for (int d = 0; d < dimension; ++d) {
            this.names.add(String.valueOf(d));
        }
    }

    /**
     * Replaces the feature labels used by {@link #toString()}. The list is
     * stored by reference, not copied.
     *
     * @param names one label per dimension; dimensions without a label are
     *              printed with their index instead
     */
    public void setNames(List<String> names) {
        this.names = names;
    }

    /**
     * @return the feature label list currently held by the model (not a copy)
     */
    public List<String> getNames() {
        return this.names;
    }

    /**
     * Log-likelihood of {@code vector} under a single cluster.
     *
     * <p>Normally this starts from the cached log-likelihood of the all-zeros
     * vector and, for each non-zero entry, swaps the "feature is 0" term for the
     * "feature is 1" term, so the cost is proportional to the number of non-zeros.
     * When the cached value is not finite that shortcut would evaluate
     * {@code -Infinity - (-Infinity) = NaN}, so the likelihood is then summed
     * directly over all dimensions.
     *
     * @param vector       observation; every non-zero entry is treated as 1
     * @param clusterIndex index of the cluster
     * @return log p(vector | cluster)
     */
    public double clusterConditionalLogProb(Vector vector, int clusterIndex) {
        double emptyLogProb = this.logClusterConditioinalForEmpty[clusterIndex];
        if (!Double.isFinite(emptyLogProb)) {
            return this.directClusterConditionalLogProb(vector, clusterIndex);
        }
        double logProb = emptyLogProb;
        for (Vector.Element nonzero : vector.nonZeroes()) {
            BernoulliDistribution distribution = this.distributions[clusterIndex][nonzero.index()];
            logProb -= distribution.logProbability(0);
            logProb += distribution.logProbability(1);
        }
        return logProb;
    }

    /** Sums the per-dimension log-probabilities without relying on the all-zeros cache. */
    private double directClusterConditionalLogProb(Vector vector, int clusterIndex) {
        boolean[] active = new boolean[this.dimension];
        for (Vector.Element nonzero : vector.nonZeroes()) {
            active[nonzero.index()] = true;
        }
        double logProb = 0.0;
        for (int l = 0; l < this.dimension; ++l) {
            logProb += this.distributions[clusterIndex][l].logProbability(active[l] ? 1 : 0);
        }
        return logProb;
    }

    /** Log-likelihood of the all-zeros vector under the given cluster. */
    private double computeLogClusterConditionalForEmpty(int clusterIndex) {
        double logProb = 0.0;
        for (int l = 0; l < this.dimension; ++l) {
            BernoulliDistribution distribution = this.distributions[clusterIndex][l];
            logProb += distribution.logProbability(0);
        }
        return logProb;
    }

    /**
     * Recomputes the cached all-zeros log-likelihood of every cluster from the
     * current {@link #distributions}.
     */
    void updateLogClusterConditioinalForEmpty() {
        for (int k = 0; k < this.numClusters; ++k) {
            this.logClusterConditioinalForEmpty[k] = this.computeLogClusterConditionalForEmpty(k);
        }
    }

    /**
     * @param vector observation; every non-zero entry is treated as 1
     * @return array of log p(vector | cluster), one entry per cluster
     */
    public double[] clusterConditionalLogProbArr(Vector vector) {
        double[] probArr = new double[this.numClusters];
        for (int clusterIndex = 0; clusterIndex < this.numClusters; ++clusterIndex) {
            probArr[clusterIndex] = this.clusterConditionalLogProb(vector, clusterIndex);
        }
        return probArr;
    }

    /**
     * Log-likelihood of {@code vector} under the whole mixture: the log-sum-exp
     * over clusters of (log mixing coefficient + cluster-conditional log-likelihood).
     *
     * @param vector observation; every non-zero entry is treated as 1
     * @return log p(vector)
     */
    public double logProbability(Vector vector) {
        double[] clusterConditionalLogProbArr = this.clusterConditionalLogProbArr(vector);
        double[] arr = new double[this.numClusters];
        for (int k = 0; k < this.numClusters; ++k) {
            arr[k] = this.logMixtureCoefficients[k] + clusterConditionalLogProbArr[k];
        }
        return MathUtil.logSumExp(arr);
    }

    /**
     * Draws a cluster according to the mixing coefficients and then samples every
     * dimension from that cluster's Bernoulli distributions.
     *
     * @return a dense vector of length {@link #getDimension()}
     */
    public Vector sample() {
        int[] clusters = IntStream.range(0, this.numClusters).toArray();
        EnumeratedIntegerDistribution clusterDistribution = new EnumeratedIntegerDistribution(clusters, this.mixtureCoefficients);
        return this.sampleFromCluster(clusterDistribution.sample());
    }

    /**
     * Samples every dimension from the Bernoulli distributions of one cluster.
     *
     * @param kCluster index of the cluster to sample from
     * @return a dense vector of length {@link #getDimension()}
     * @throws IllegalArgumentException if {@code kCluster} is outside
     *                                  {@code [0, numClusters)}
     */
    public Vector sample(int kCluster) {
        if (kCluster < 0 || kCluster >= this.numClusters) {
            throw new IllegalArgumentException("Illegal cluster index " + kCluster
                    + "; expected a value in [0, " + this.numClusters + ")");
        }
        return this.sampleFromCluster(kCluster);
    }

    /** Draws one value per dimension from the given cluster; the index must already be validated. */
    private Vector sampleFromCluster(int cluster) {
        DenseVector vector = new DenseVector(this.dimension);
        for (int d = 0; d < this.dimension; ++d) {
            vector.set(d, (double)this.distributions[cluster][d].sample());
        }
        return vector;
    }

    public int getNumClusters() {
        return this.numClusters;
    }

    public int getDimension() {
        return this.dimension;
    }

    /**
     * @return the internal [cluster][dimension] array (not a copy)
     */
    public BernoulliDistribution[][] getDistributions() {
        return this.distributions;
    }

    /**
     * @return the internal mixing coefficient array (not a copy)
     */
    public double[] getMixtureCoefficients() {
        return this.mixtureCoefficients;
    }

    /** Label of dimension {@code d}; falls back to the index when no label is available. */
    private String featureName(int d) {
        if (this.names != null && d < this.names.size()) {
            return this.names.get(d);
        }
        return String.valueOf(d);
    }

    /**
     * Lists the clusters in order of decreasing mixing coefficient; within each
     * cluster the features are listed in order of decreasing Bernoulli parameter.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("BMM{");
        sb.append("numClusters=").append(this.numClusters);
        sb.append(", dimension=").append(this.dimension).append("\n");
        Comparator<Pair<String, Double>> byProbability = Comparator.comparing(Pair::getSecond);
        for (int k : ArgSort.argSortDescending(this.mixtureCoefficients)) {
            sb.append("cluster ").append(k).append(":\n");
            sb.append("proportion = ").append(this.mixtureCoefficients[k]).append("\n");
            List<Pair<String, Double>> pairs = new ArrayList<Pair<String, Double>>(this.dimension);
            for (int d = 0; d < this.dimension; ++d) {
                pairs.add(new Pair<String, Double>(this.featureName(d), this.distributions[k][d].getP()));
            }
            String probabilities = pairs.stream()
                    .sorted(byProbability.reversed())
                    .map(pair -> pair.getFirst() + ":" + pair.getSecond())
                    .collect(Collectors.joining(", "));
            sb.append("probabilities = ").append("[").append(probabilities).append("]");
            sb.append("\n");
        }
        sb.append('}');
        return sb.toString();
    }
}

