/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.clustering.bm;

import edu.neu.ccs.pyramid.clustering.bm.BM;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.util.BernoulliDistribution;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.Vector;

/**
 * Iteratively fits a {@link BM} mixture model to a {@link DataSet}.
 *
 * <p>Each iteration runs an E step, which recomputes the per-data-point cluster
 * responsibilities ({@link #getGammas()}), followed by an M step, which re-estimates
 * every cluster's mixture coefficient and its per-feature {@link BernoulliDistribution}
 * from those responsibilities.
 *
 * <p>Both steps use parallel streams internally; the trainer itself takes no locks, so a
 * single instance should not be driven from several threads at once.
 */
public class BMTrainer {
    private static final Logger logger = LogManager.getLogger();
    /** Data the model is fitted to. */
    DataSet dataSet;
    /** {@code gammas[i][k]} is the responsibility of cluster {@code k} for data point {@code i}. */
    double[][] gammas;
    /** Number of clusters requested at construction time. */
    int numClusters;
    /** The model being trained; updated in place by every M step. */
    BM bm;
    /** Number of E/M iterations performed by {@link #train()}. */
    int numIterations = 200;

    /**
     * Creates a trainer together with a freshly constructed model.
     *
     * @param dataSet     data to cluster
     * @param numClusters number of mixture components
     * @param randomSeed  seed handed to the {@link BM} constructor
     */
    public BMTrainer(DataSet dataSet, int numClusters, long randomSeed) {
        this.numClusters = numClusters;
        this.dataSet = dataSet;
        this.gammas = new double[dataSet.getNumDataPoints()][numClusters];
        this.bm = new BM(numClusters, dataSet.getNumFeatures(), randomSeed);
    }

    /**
     * Runs {@code numIterations} iterations of {@link #iterate()}.
     *
     * @return the trained model (the same instance as {@link #getBm()})
     */
    public BM train() {
        for (int i = 0; i < this.numIterations; ++i) {
            this.iterate();
        }
        return this.bm;
    }

    /**
     * @return the model in its current state
     */
    public BM getBm() {
        return this.bm;
    }

    /**
     * @return the live responsibility matrix, indexed {@code [dataPoint][cluster]}; this is
     *         the internal array, not a copy, and it is overwritten by every E step
     */
    public double[][] getGammas() {
        return this.gammas;
    }

    /** Performs one E step followed by one M step. */
    public void iterate() {
        if (logger.isDebugEnabled()) {
            logger.debug("start one EM iteration");
        }
        this.eStep();
        this.mStep();
    }

    /** E step: recomputes the responsibilities of every data point under the current model. */
    public void eStep() {
        if (logger.isDebugEnabled()) {
            logger.debug("start E step");
        }
        this.updateGamma();
        if (logger.isDebugEnabled()) {
            logger.debug("finish E step");
            // The objective is a full pass over the data, so it is only computed when debugging.
            logger.debug("objective = " + this.getObjective());
        }
    }

    /** M step: re-estimates every cluster's parameters from the current responsibilities. */
    public void mStep() {
        if (logger.isDebugEnabled()) {
            logger.debug("start M step");
        }
        IntStream.range(0, this.numClusters).forEach(this::updateCluster);
        this.bm.updateLogClusterConditioinalForEmpty();
        if (logger.isDebugEnabled()) {
            logger.debug("finish M step");
            logger.debug("objective = " + this.getObjective());
        }
    }

    /**
     * Re-estimates the Bernoulli parameters and the mixture coefficient of cluster {@code k}.
     *
     * <p>When the cluster's total responsibility is not positive, its per-feature
     * distributions are left untouched: dividing by that total would yield {@code NaN}
     * parameters. Its mixture coefficient is still recomputed, so an empty cluster ends up
     * with weight 0. With an empty data set there is nothing to estimate from and the
     * cluster is left entirely unchanged.
     *
     * @param k cluster index
     */
    private void updateCluster(int k) {
        int numDataPoints = this.dataSet.getNumDataPoints();
        if (numDataPoints == 0) {
            // 0.0 / 0 below would write NaN into the mixture coefficients.
            return;
        }
        double effectiveTotal = IntStream.range(0, numDataPoints).parallel().mapToDouble(i -> this.gammas[i][k]).sum();
        if (effectiveTotal > 0.0) {
            IntStream.range(0, this.dataSet.getNumFeatures()).parallel().forEach(d -> {
                double sum = this.weightedSum(k, d);
                double average = sum / effectiveTotal;
                // Keep the parameter strictly below 1.
                if (average >= 1.0) {
                    average = 0.9999;
                }
                this.bm.distributions[k][d] = new BernoulliDistribution(average);
            });
        } else if (logger.isDebugEnabled()) {
            logger.debug("cluster " + k + " has no effective weight; keeping its previous distributions");
        }
        this.bm.mixtureCoefficients[k] = effectiveTotal / (double)numDataPoints;
        this.bm.logMixtureCoefficients[k] = Math.log(this.bm.mixtureCoefficients[k]);
    }

    /**
     * Sums the responsibilities of one cluster over the data points whose entry in the
     * given feature column is non-zero. Only the positions of the non-zero entries are
     * used, not their values.
     *
     * @param clusterIndex   cluster whose responsibilities are summed
     * @param dimensionIndex feature (column) index
     * @return the responsibility mass of the data points having that feature
     */
    private double weightedSum(int clusterIndex, int dimensionIndex) {
        Vector column = this.dataSet.getColumn(dimensionIndex);
        double sum = 0.0;
        for (Vector.Element nonzero : column.nonZeroes()) {
            int i = nonzero.index();
            sum += this.gammas[i][clusterIndex];
        }
        return sum;
    }

    /** Recomputes the responsibilities of all data points in parallel. */
    private void updateGamma() {
        IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().forEach(this::updateGamma);
    }

    /**
     * Recomputes the responsibilities of data point {@code n}. The work is done in log
     * space: each cluster's log mixture coefficient plus its conditional log-probability,
     * normalised by {@link MathUtil#logSumExp} before exponentiating.
     *
     * @param n data point (row) index
     */
    private void updateGamma(int n) {
        Vector feature = this.dataSet.getRow(n);
        int numClusters = this.bm.getNumClusters();
        double[] logClusterConditionalProbs = this.bm.clusterConditionalLogProbArr(feature);
        double[] logNumerators = new double[numClusters];
        for (int k = 0; k < numClusters; ++k) {
            logNumerators[k] = this.bm.logMixtureCoefficients[k] + logClusterConditionalProbs[k];
        }
        double logDenominator = MathUtil.logSumExp(logNumerators);
        for (int k = 0; k < numClusters; ++k) {
            this.gammas[n][k] = Math.exp(logNumerators[k] - logDenominator);
        }
    }

    /** Sums the per-data-point objective over the whole data set, in parallel. */
    private double exactObjective() {
        return IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().mapToDouble(this::exactObjective).sum();
    }

    /**
     * @param i data point (row) index
     * @return the negated log-probability the model assigns to row {@code i}
     */
    private double exactObjective(int i) {
        return -1.0 * this.bm.logProbability(this.dataSet.getRow(i));
    }

    /**
     * @return the training objective: the sum over all data points of the negated model
     *         log-probability (lower is better)
     */
    public double getObjective() {
        return this.exactObjective();
    }
}

