/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.classification.PriorProbClassifier;
import edu.neu.ccs.pyramid.clustering.bm.BMSelector;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.DataSetBuilder;
import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.Density;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM;
import edu.neu.ccs.pyramid.util.ArgMax;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.ArrayList;
import java.util.Random;
import java.util.stream.IntStream;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.Vector;

/**
 * Base class for EM-style training of a {@link CBM}.
 *
 * <p>Keeps a membership matrix {@code gammas} (one row per data point, one column per component)
 * and alternates an E step ({@link #eStep()}), which refreshes the memberships from the current
 * model, with an M step ({@link #mStep()}), which refits the per-component binary classifiers and
 * then the multi-class classifier. Subclasses supply the actual classifier fitting and the
 * objective terms.
 */
public abstract class AbstractCBMOptimizer {
    private static final Logger logger = LogManager.getLogger();
    protected CBM cbm;
    protected MultiLabelClfDataSet dataSet;
    /** gammas[n][k]: membership weight of data point n in component k. */
    protected double[][] gammas;
    /**
     * A label's binary classifier in a component is replaced by a constant prior when its weighted
     * positive fraction in that component is below this value or above {@code 1 - this value}.
     */
    protected double skipLabelThreshold = 1.0E-5;
    /**
     * Data points whose membership weight for a component is below this value are left out of that
     * component's binary training set (the arg-max point is always kept).
     */
    protected double skipDataThreshold = 1.0E-5;
    // Not read in this class; presumably consumed by subclass optimizers.
    protected int multiclassUpdatesPerIter = 20;
    // Not read in this class; presumably consumed by subclass optimizers.
    protected int binaryUpdatesPerIter = 20;
    /** Pseudo-count strength pulling per-component label priors toward the global label frequency. */
    protected double smoothingStrength = 1.0E-4;
    /** positiveCounts[l]: number of data points carrying label l. */
    private final int[] positiveCounts;
    /** Sparse (numDataPoints x numClasses) matrix with 1.0 wherever a data point carries a label. */
    protected DataSet labelMatrix;
    /** Whether the per-label binary updates of one component run on a parallel stream. */
    protected boolean parallelBinaryUpdates = true;

    /**
     * Creates an optimizer with uniform memberships and precomputed label statistics.
     *
     * @param cbm the model to train
     * @param dataSet the training data
     */
    public AbstractCBMOptimizer(CBM cbm, MultiLabelClfDataSet dataSet) {
        this.cbm = cbm;
        this.dataSet = dataSet;
        int numData = dataSet.getNumDataPoints();
        int numComponents = cbm.getNumComponents();
        int numClasses = dataSet.getNumClasses();

        // Uniform memberships until an initializer or E step overwrites them.
        this.gammas = new double[numData][numComponents];
        double average = 1.0 / (double) numComponents;
        for (double[] row : this.gammas) {
            for (int k = 0; k < numComponents; ++k) {
                row[k] = average;
            }
        }

        this.labelMatrix = DataSetBuilder.getBuilder().numDataPoints(numData).numFeatures(numClasses).density(Density.SPARSE_RANDOM).build();
        for (int i = 0; i < numData; ++i) {
            MultiLabel multiLabel = dataSet.getMultiLabels()[i];
            for (int l : multiLabel.getMatchedLabels()) {
                this.labelMatrix.setFeatureValue(i, l, 1.0);
            }
        }

        this.positiveCounts = new int[numClasses];
        for (int l = 0; l < numClasses; ++l) {
            this.positiveCounts[l] = this.labelMatrix.getColumn(l).getNumNonZeroElements();
        }
    }

    public void setSmoothingStrength(double smoothingStrength) {
        this.smoothingStrength = smoothingStrength;
    }

    public void setMulticlassUpdatesPerIter(int multiclassUpdatesPerIter) {
        this.multiclassUpdatesPerIter = multiclassUpdatesPerIter;
    }

    public void setBinaryUpdatesPerIter(int binaryUpdatesPerIter) {
        this.binaryUpdatesPerIter = binaryUpdatesPerIter;
    }

    public void setSkipLabelThreshold(double skipLabelThreshold) {
        this.skipLabelThreshold = skipLabelThreshold;
    }

    public void setSkipDataThreshold(double skipDataThreshold) {
        this.skipDataThreshold = skipDataThreshold;
    }

    /**
     * Replaces the memberships with those chosen by {@code BMSelector.selectGammas} and runs one M
     * step from them.
     */
    public void initialize() {
        this.gammas = BMSelector.selectGammas(this.dataSet.getNumClasses(), this.dataSet.getMultiLabels(), this.cbm.getNumComponents());
        if (logger.isDebugEnabled()) {
            logger.debug("performing M step");
        }
        this.mStep();
    }

    /**
     * Sets every data point's memberships to a random normalized distribution and runs one M step.
     */
    public void randInitialize() {
        this.randInitialize(new Random());
    }

    /**
     * Sets every data point's memberships to a random normalized distribution drawn from
     * {@code random}, then runs one M step. Passing a seeded generator makes initialization
     * reproducible.
     *
     * @param random source of the uniform draws that are normalized into memberships
     */
    public void randInitialize(Random random) {
        int numComponents = this.cbm.getNumComponents();
        for (int i = 0; i < this.dataSet.getNumDataPoints(); ++i) {
            double[] dist = new double[numComponents];
            for (int k = 0; k < numComponents; ++k) {
                dist[k] = random.nextDouble();
            }
            double sum = MathUtil.arraySum(dist);
            for (int k = 0; k < numComponents; ++k) {
                this.gammas[i][k] = dist[k] / sum;
            }
        }
        if (logger.isDebugEnabled()) {
            logger.debug("performing random M step");
        }
        this.mStep();
    }

    /**
     * Runs one M step from the current memberships, which are uniform unless an initializer or an
     * E step has already changed them.
     */
    public void averageInitialize() {
        if (logger.isDebugEnabled()) {
            logger.debug("performing average M step");
        }
        this.mStep();
    }

    /** Performs one full EM iteration: an E step followed by an M step. */
    public void iterate() {
        this.eStep();
        this.mStep();
    }

    /** Refreshes all memberships from the current model. */
    protected void eStep() {
        if (logger.isDebugEnabled()) {
            logger.debug("start E step");
        }
        this.updateGamma();
        if (logger.isDebugEnabled()) {
            logger.debug("finish E step");
        }
    }

    /** Updates the memberships of every data point, in parallel; each task writes only its own row. */
    protected void updateGamma() {
        IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().forEach(this::updateGamma);
    }

    /**
     * Overwrites row {@code n} of {@code gammas} with the model's posterior membership for that data
     * point.
     *
     * @param n data point index
     */
    protected void updateGamma(int n) {
        Vector x = this.dataSet.getRow(n);
        MultiLabel y = this.dataSet.getMultiLabels()[n];
        double[] posterior = this.cbm.posteriorMembershipShortCircuit(x, y);
        System.arraycopy(posterior, 0, this.gammas[n], 0, this.cbm.numComponents);
    }

    /** Refits the binary classifiers of every component, then the multi-class classifier. */
    void mStep() {
        if (logger.isDebugEnabled()) {
            logger.debug("start M step");
        }
        this.updateBinaryClassifiers();
        this.updateMultiClassClassifier();
        if (logger.isDebugEnabled()) {
            logger.debug("finish M step");
        }
    }

    /** Refits the binary classifiers of each component, one component at a time. */
    protected void updateBinaryClassifiers() {
        if (logger.isDebugEnabled()) {
            logger.debug("start updateBinaryClassifiers");
        }
        IntStream.range(0, this.cbm.numComponents).forEach(this::updateBinaryClassifiers);
        if (logger.isDebugEnabled()) {
            logger.debug("finish updateBinaryClassifiers");
        }
    }

    /**
     * Refits all binary classifiers of one component.
     *
     * <p>Builds an "active" subset of the data (points whose membership is at least
     * {@code skipDataThreshold}, plus the arg-max point), then for every label either installs a
     * prior or delegates to {@link #updateBinaryClassifier}.
     *
     * @param component component index
     */
    protected void updateBinaryClassifiers(int component) {
        if (logger.isDebugEnabled()) {
            logger.debug("computing active dataset for component " + component);
        }
        int numData = this.dataSet.getNumDataPoints();
        double[] gammasForComponent = new double[numData];
        for (int i = 0; i < numData; ++i) {
            gammasForComponent[i] = this.gammas[i][component];
        }
        // The arg-max point is kept even when it is below the threshold.
        int maxIndex = ArgMax.argMax(gammasForComponent);

        ArrayList<Integer> activeIndices = new ArrayList<>();
        double weightedTotal = 0.0;
        double thresholdedWeightedTotal = 0.0;
        for (int i = 0; i < numData; ++i) {
            double v = gammasForComponent[i];
            weightedTotal += v;
            if (v >= this.skipDataThreshold || i == maxIndex) {
                activeIndices.add(i);
                thresholdedWeightedTotal += v;
            }
        }
        double[] activeGammas = activeIndices.stream().mapToDouble(idx -> gammasForComponent[idx]).toArray();

        if (logger.isDebugEnabled()) {
            logger.debug("number of active data  = " + activeIndices.size());
            logger.debug("total weight  = " + weightedTotal);
            logger.debug("total weight of active data  = " + thresholdedWeightedTotal);
            logger.debug("creating active dataset");
        }
        MultiLabelClfDataSet activeDataSet = DataSetUtil.sampleData(this.dataSet, activeIndices);
        if (logger.isDebugEnabled()) {
            // Counting active features scans every column, so only do it when the result is logged.
            long activeFeatures = IntStream.range(0, activeDataSet.getNumFeatures())
                    .filter(j -> activeDataSet.getColumn(j).getNumNonZeroElements() > 0)
                    .count();
            logger.debug("active dataset created");
            logger.debug("number of active features = " + activeFeatures);
        }

        // Note: the skip decision uses the weight over ALL data points, not just the active ones.
        double totalWeight = weightedTotal;
        IntStream labels = IntStream.range(0, this.cbm.numLabels);
        if (this.parallelBinaryUpdates) {
            labels = labels.parallel();
        }
        labels.forEach(l -> this.skipOrUpdateBinaryClassifier(component, l, activeDataSet, activeGammas, totalWeight));
    }

    /**
     * Either installs a constant prior classifier for ({@code component}, {@code label}) or trains
     * one via {@link #updateBinaryClassifier}.
     *
     * <p>A prior is used when the component has no positive total weight, or when the label's
     * weighted positive fraction in the component is within {@code skipLabelThreshold} of 0 or 1.
     * The prior is the positive fraction smoothed toward the global label frequency by
     * {@code smoothingStrength}, capped at 1.
     *
     * @param component component index
     * @param label label index
     * @param activeDataSet the component's active training subset
     * @param activeGammas membership weights aligned with {@code activeDataSet}'s rows
     * @param totalWeight sum of this component's memberships over all data points
     */
    protected void skipOrUpdateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataSet, double[] activeGammas, double totalWeight) {
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        int numData = this.dataSet.getNumDataPoints();
        double effectivePositives = this.effectivePositives(component, label);
        double nonSmoothedPositiveProb = effectivePositives / totalWeight;
        double globalPositiveProb = (double) this.positiveCounts[label] / (double) numData;
        double smoothedPositiveProb = (effectivePositives + this.smoothingStrength * (double) this.positiveCounts[label])
                / (totalWeight + this.smoothingStrength * (double) numData);
        // 0/0 happens when the component has no weight and smoothing is off; use the global frequency.
        double prior = Double.isNaN(smoothedPositiveProb) ? globalPositiveProb : Math.min(smoothedPositiveProb, 1.0);

        // A zero (or NaN) total weight makes the non-smoothed fraction NaN, which would fail both
        // threshold tests and send an all-zero-weight problem to the trainer; skip it explicitly.
        boolean skip = !(totalWeight > 0.0)
                || nonSmoothedPositiveProb < this.skipLabelThreshold
                || nonSmoothedPositiveProb > 1.0 - this.skipLabelThreshold;
        if (skip) {
            this.cbm.binaryClassifiers[component][label] = new PriorProbClassifier(new double[]{1.0 - prior, prior});
        }

        if (logger.isDebugEnabled()) {
            StringBuilder sb = new StringBuilder();
            sb.append("for component ").append(component).append(", label ").append(label);
            sb.append(", weighted positives = ").append(effectivePositives);
            sb.append(", non-smoothed positive fraction = ").append(nonSmoothedPositiveProb);
            sb.append(", global positive fraction = ").append(globalPositiveProb);
            sb.append(", smoothed positive fraction = ").append(smoothedPositiveProb);
            if (skip) {
                sb.append(", skip, use prior = ").append(prior);
                sb.append(", time spent = ").append(stopWatch.toString());
            }
            logger.debug(sb.toString());
        }

        if (!skip) {
            this.updateBinaryClassifier(component, label, activeDataSet, activeGammas);
        }
    }

    /**
     * Trains the binary classifier for ({@code component}, {@code label}) on the active subset.
     *
     * @param component component index
     * @param label label index
     * @param activeDataSet the component's active training subset
     * @param activeGammas membership weights aligned with {@code activeDataSet}'s rows
     */
    protected abstract void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataSet, double[] activeGammas);

    /** Refits the multi-class (component-selection) classifier. */
    protected abstract void updateMultiClassClassifier();

    /** Sums the component's memberships over the data points that carry {@code labelIndex}. */
    private double effectivePositives(int componentIndex, int labelIndex) {
        double sum = 0.0;
        Vector labelColumn = this.labelMatrix.getColumn(labelIndex);
        for (Vector.Element element : labelColumn.nonZeroes()) {
            sum += this.gammas[element.index()][componentIndex];
        }
        return sum;
    }

    /** Returns {@link #multiClassClassifierObj()} plus {@link #binaryObj()}. */
    public double getObjective() {
        return this.multiClassClassifierObj() + this.binaryObj();
    }

    /** Sums {@link #binaryObj(int)} over all components. */
    protected double binaryObj() {
        return IntStream.range(0, this.cbm.numComponents).mapToDouble(this::binaryObj).sum();
    }

    /** Sums {@link #binaryObj(int, int)} over all labels of one component, in parallel. */
    protected double binaryObj(int component) {
        return IntStream.range(0, this.cbm.numLabels).parallel().mapToDouble(l -> this.binaryObj(component, l)).sum();
    }

    /** Objective contribution of the binary classifier for ({@code component}, {@code label}). */
    protected abstract double binaryObj(int component, int label);

    /** Objective contribution of the multi-class classifier. */
    protected abstract double multiClassClassifierObj();

    /** Returns the live membership matrix (not a copy); callers can mutate optimizer state through it. */
    public double[][] getGammas() {
        return this.gammas;
    }

    /** Throws if any membership value is NaN. */
    private void checkGamma() {
        for (int i = 0; i < this.gammas.length; ++i) {
            for (int k = 0; k < this.gammas[i].length; ++k) {
                if (Double.isNaN(this.gammas[i][k])) {
                    throw new IllegalStateException("gamma " + i + " " + k + " is NaN");
                }
            }
        }
    }
}

