/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.classification.PriorProbClassifier;
import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer;
import edu.neu.ccs.pyramid.clustering.bm.BM;
import edu.neu.ccs.pyramid.clustering.bm.BMSelector;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM;
import edu.neu.ccs.pyramid.util.Pair;
import java.util.stream.IntStream;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.mahout.math.Vector;

/**
 * Alternating optimizer for the parameters of a {@link CBM} model.
 *
 * <p>The optimizer keeps a responsibility matrix {@code gammas} with one row per data point and
 * one column per component. It exposes independent update steps:
 * <ul>
 *   <li>{@link #updateGamma()}: recomputes every row of {@code gammas} from
 *       {@code cbm.posteriorMembership(x, y)};</li>
 *   <li>{@link #updateMultiClassLR()}: re-fits {@code cbm.multiClassClassifier} with a
 *       {@link RidgeLogisticOptimizer} that is given {@code gammas} (presumably as soft
 *       targets — confirm against RidgeLogisticOptimizer);</li>
 *   <li>{@link #updateAllBinary()}: re-fits each {@code cbm.binaryClassifiers[k][l]}, giving the
 *       optimizer the component's thresholded responsibilities (presumably as instance
 *       weights).</li>
 * </ul>
 * When building binary-training weights, responsibilities at or below {@code activeThreshold} are
 * replaced by 0. NOTE(review): presumably this thresholding is what "Sparse" in the class name
 * refers to.
 */
public class SparseCBMOptimzer {
    private CBM cbm;
    private MultiLabelClfDataSet dataSet;
    /** Responsibilities: {@code gammas[i][k]} is the weight of data point {@code i} on component {@code k}. */
    double[][] gammas;
    /** Model returned by {@link BMSelector#selectAll}; only set by {@link #initalizeGammaByBM()}. */
    BM bm;
    private double priorVarianceMultiClass = 1.0;
    private double priorVarianceBinary = 1.0;
    /** Thresholded copy of one column of {@code gammas}; rebuilt by {@link #updateEffectiveData(int)}. */
    private double[] activeGammas;
    /** Responsibilities must be strictly greater than this to count as "active". */
    private double activeThreshold = 1.0E-5;
    /** Sum of the active responsibilities for the component last passed to updateEffectiveData. */
    private double weightedTotal;
    /** Maximum iterations handed to the optimizer's terminator in {@link #updateMultiClassLR()}. */
    private int numMulticlassUpdates = 50;
    /** Maximum iterations handed to the optimizer's terminator for each binary classifier. */
    private int numBinaryUpdates = 50;

    /**
     * Creates an optimizer for {@code cbm} over {@code dataSet}.
     *
     * <p>{@code gammas} starts as an all-zero matrix. Call {@link #initalizeGammaByBM()} or
     * {@link #updateGamma()} before training classifiers, otherwise every component has zero
     * responsibility mass.
     *
     * @param cbm the model whose classifiers are updated in place
     * @param dataSet the training data
     */
    public SparseCBMOptimzer(CBM cbm, MultiLabelClfDataSet dataSet) {
        this.cbm = cbm;
        this.dataSet = dataSet;
        this.gammas = new double[dataSet.getNumDataPoints()][cbm.numComponents];
    }

    /** @param numMulticlassUpdates max iterations for the multi-class classifier update */
    public void setNumMulticlassUpdates(int numMulticlassUpdates) {
        this.numMulticlassUpdates = numMulticlassUpdates;
    }

    /** @param numBinaryUpdates max iterations for each binary classifier update */
    public void setNumBinaryUpdates(int numBinaryUpdates) {
        this.numBinaryUpdates = numBinaryUpdates;
    }

    /** @param priorVarianceMultiClass prior variance passed to the multi-class ridge optimizer */
    public void setPriorVarianceMultiClass(double priorVarianceMultiClass) {
        this.priorVarianceMultiClass = priorVarianceMultiClass;
    }

    /** @param priorVarianceBinary prior variance passed to each binary ridge optimizer */
    public void setPriorVarianceBinary(double priorVarianceBinary) {
        this.priorVarianceBinary = priorVarianceBinary;
    }

    /**
     * Initializes {@code gammas} (and {@code bm}) from
     * {@link BMSelector#selectAll(int, MultiLabel[], int)} run on the label sets alone.
     *
     * <p>The method name's spelling is kept for backward compatibility with existing callers.
     */
    public void initalizeGammaByBM() {
        Pair<BM, double[][]> pair = BMSelector.selectAll(this.dataSet.getNumClasses(), this.dataSet.getMultiLabels(), this.cbm.getNumComponents());
        this.bm = pair.getFirst();
        this.gammas = pair.getSecond();
    }

    /**
     * Re-fits {@code cbm.multiClassClassifier} (which must be a {@link LogisticRegression}) using
     * the current {@code gammas}, for at most {@code numMulticlassUpdates} iterations.
     */
    public void updateMultiClassLR() {
        RidgeLogisticOptimizer ridgeLogisticOptimizer = new RidgeLogisticOptimizer((LogisticRegression)this.cbm.multiClassClassifier, (DataSet)this.dataSet, this.gammas, this.priorVarianceMultiClass, true);
        ridgeLogisticOptimizer.getOptimizer().getTerminator().setMaxIteration(this.numMulticlassUpdates);
        ridgeLogisticOptimizer.optimize();
    }

    /**
     * Re-fits every binary classifier.
     *
     * <p>Components are processed one at a time, because {@code activeGammas} holds the
     * thresholded responsibilities of a single component. Within a component, all labels are
     * updated in parallel. Each task writes a distinct {@code binaryClassifiers[k][l]} slot and
     * only reads the shared {@code activeGammas}.
     */
    public void updateAllBinary() {
        for (int k = 0; k < this.cbm.getNumComponents(); k++) {
            this.updateEffectiveData(k);
            final int component = k;
            IntStream.range(0, this.cbm.getNumClasses())
                    .parallel()
                    .forEach(label -> this.updateBinaryLogisticRegression(component, label));
        }
    }

    /**
     * Updates {@code cbm.binaryClassifiers[componentIndex][labelIndex]}.
     *
     * <p>If the component's responsibility-weighted count of positives for this label is at most
     * 1.0, the slot is replaced by a {@link PriorProbClassifier} built from {@link #prior}.
     * Otherwise a {@link LogisticRegression} is created if the slot is empty or holds a prior
     * classifier. It is then trained with a {@link RidgeLogisticOptimizer} on binary labels and
     * {@code activeGammas}.
     *
     * <p>Must be called after {@link #updateEffectiveData(int)} for the same component, since it
     * reads {@code activeGammas}.
     */
    private void updateBinaryLogisticRegression(int componentIndex, int labelIndex) {
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        double effectivePositives = this.effectivePositives(componentIndex, labelIndex);
        StringBuilder sb = new StringBuilder();
        sb.append("for component ").append(componentIndex).append(", label ").append(labelIndex);
        sb.append(", effective positives = ").append(effectivePositives);
        if (effectivePositives <= 1.0) {
            // Too little positive evidence to fit a logistic regression; fall back to a constant.
            double positiveProb = this.prior(componentIndex, labelIndex);
            double[] probs = new double[]{1.0 - positiveProb, positiveProb};
            this.cbm.binaryClassifiers[componentIndex][labelIndex] = new PriorProbClassifier(probs);
            sb.append(", skip, use prior = ").append(positiveProb);
            sb.append(", time spent = ").append(stopWatch);
            System.out.println(sb.toString());
            return;
        }
        if (this.cbm.binaryClassifiers[componentIndex][labelIndex] == null || this.cbm.binaryClassifiers[componentIndex][labelIndex] instanceof PriorProbClassifier) {
            this.cbm.binaryClassifiers[componentIndex][labelIndex] = new LogisticRegression(2, this.dataSet.getNumFeatures());
        }
        int[] binaryLabels = DataSetUtil.toBinaryLabels(this.dataSet.getMultiLabels(), labelIndex);
        RidgeLogisticOptimizer ridgeLogisticOptimizer = new RidgeLogisticOptimizer((LogisticRegression)this.cbm.binaryClassifiers[componentIndex][labelIndex], (DataSet)this.dataSet, binaryLabels, this.activeGammas, this.priorVarianceBinary, false);
        ridgeLogisticOptimizer.getOptimizer().getTerminator().setMaxIteration(this.numBinaryUpdates);
        ridgeLogisticOptimizer.optimize();
        sb.append(", time spent = ").append(stopWatch);
        System.out.println(sb.toString());
    }

    /**
     * Rebuilds {@code activeGammas} and {@code weightedTotal} for one component.
     *
     * <p>{@code activeGammas[i]} equals {@code gammas[i][componentIndex]} when that value exceeds
     * {@code activeThreshold}, and 0 otherwise. Progress is printed to stdout.
     */
    private void updateEffectiveData(int componentIndex) {
        System.out.println("computing active dataset for component " + componentIndex);
        this.activeGammas = new double[this.dataSet.getNumDataPoints()];
        this.weightedTotal = 0.0;
        int counter = 0;
        for (int i = 0; i < this.dataSet.getNumDataPoints(); ++i) {
            double v = this.gammas[i][componentIndex];
            if (v > this.activeThreshold) {
                this.activeGammas[i] = v;
                this.weightedTotal += v;
                ++counter;
                continue;
            }
            this.activeGammas[i] = 0.0;
        }
        System.out.println("raw number of data in active dataset = " + counter);
        System.out.println("weighted number of data in active dataset = " + this.weightedTotal);
    }

    /**
     * Returns the sum of (unthresholded) {@code gammas[i][componentIndex]} over the data points
     * whose label set contains {@code labelIndex}.
     */
    private double effectivePositives(int componentIndex, int labelIndex) {
        MultiLabel[] multiLabels = this.dataSet.getMultiLabels();
        double sum = 0.0;
        for (int i = 0; i < this.dataSet.getNumDataPoints(); ++i) {
            if (!multiLabels[i].matchClass(labelIndex)) continue;
            sum += this.gammas[i][componentIndex];
        }
        return sum;
    }

    /** Recomputes the responsibilities of every data point, in parallel over rows. */
    public void updateGamma() {
        IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().forEach(this::updateGamma);
    }

    /**
     * Overwrites row {@code n} of {@code gammas} with the first {@code cbm.numComponents} entries of
     * {@code cbm.posteriorMembership(x_n, y_n)}. Each call touches only its own row, so parallel
     * calls for distinct {@code n} do not write to the same memory.
     */
    private void updateGamma(int n) {
        Vector x = this.dataSet.getRow(n);
        MultiLabel y = this.dataSet.getMultiLabels()[n];
        double[] posterior = this.cbm.posteriorMembership(x, y);
        System.arraycopy(posterior, 0, this.gammas[n], 0, this.cbm.numComponents);
    }

    /**
     * Returns the responsibility-weighted fraction of data points that carry {@code labelIndex}
     * within component {@code componentIndex}, i.e.
     * {@code sum_{i: y_i has l} gamma_ik / sum_i gamma_ik}.
     *
     * <p>If the component has no responsibility mass at all (e.g. {@code gammas} was never
     * initialized), the ratio would be {@code 0/0 = NaN}. In that case 0.0 is returned so that
     * the resulting {@link PriorProbClassifier} has valid probabilities.
     */
    private double prior(int componentIndex, int labelIndex) {
        MultiLabel[] multiLabels = this.dataSet.getMultiLabels();
        double positives = 0.0;
        double total = 0.0;
        for (int i = 0; i < this.dataSet.getNumDataPoints(); ++i) {
            double gamma = this.gammas[i][componentIndex];
            total += gamma;
            if (multiLabels[i].matchClass(labelIndex)) {
                positives += gamma;
            }
        }
        if (total <= 0.0) {
            // Empty component: no evidence for the label, and dividing would produce NaN.
            return 0.0;
        }
        return positives / total;
    }
}

