/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.crf;

import edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression;
import edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer;
import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.ClfDataSetBuilder;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.Enumerator;
import edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF;
import edu.neu.ccs.pyramid.multilabel_classification.crf.KLLoss;
import edu.neu.ccs.pyramid.optimization.GradientDescent;
import edu.neu.ccs.pyramid.optimization.GradientValueOptimizer;
import edu.neu.ccs.pyramid.optimization.LBFGS;
import edu.neu.ccs.pyramid.optimization.Terminator;
import edu.neu.ccs.pyramid.util.ArgSort;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Alternating optimizer that trains a {@link CMLCRF} jointly with one binary
 * {@link LogisticRegression} "transform" per class.
 *
 * <p>For every data point {@code i} and every label combination {@code c} (all {@code 2^L}
 * combinations from {@link Enumerator#enumerate}) it maintains:
 * <ul>
 *   <li>{@code probabilities[i][c]}: the CRF's predicted probability of combination {@code c};</li>
 *   <li>{@code transformProbs[i][c]}: the product over classes {@code l} of the probability that
 *       transform {@code l} assigns to the observed value of label {@code l}, given the vector of
 *       {@code c} shifted by {@code -0.5} per entry;</li>
 *   <li>{@code targets[i][c]}: {@code probabilities[i][c] * transformProbs[i][c]}, normalized over
 *       {@code c}.</li>
 * </ul>
 * One {@link #iterate()} recomputes the targets, refits each transform using the targets as instance
 * weights, recomputes the transform probabilities, refits the CRF against the targets through
 * {@link KLLoss}, recomputes the CRF probabilities, and records the objective
 * {@code -sum_i log(sum_c probabilities[i][c] * transformProbs[i][c]) + penalty} in the terminator.
 */
public class NoiseOptimizerLR {
    private static final Logger logger = LogManager.getLogger();

    /**
     * Magnitude of the per-class feature used for the transform models: a class present in the
     * candidate combination maps to {@code +0.5}, an absent one to {@code -0.5}.
     */
    private static final double CENTERING_OFFSET = 0.5;

    private final MultiLabelClfDataSet dataSet;
    private final CMLCRF crf;
    /** Per data point, normalized weight of each label combination (see class doc). */
    private final double[][] targets;
    /** Per data point, probability of the observed labels under each candidate combination. */
    private final double[][] transformProbs;
    /** All label combinations; index {@code c} in every per-combination array refers to this list. */
    private final List<MultiLabel> combinations;
    private final double variance;
    /** Per data point, the CRF's probability for each combination. */
    private final double[][] probabilities;
    private final Terminator terminator;
    /** Name of the optimizer used to refit the CRF: {@code "LBFGS"} or {@code "GD"}. */
    private String optimizer = "LBFGS";
    /** One binary logistic regression per class, fed the centred candidate-combination vector. */
    public List<LogisticRegression> lrTransforms;
    /** Per class, the expanded training set with one row per (data point, combination) pair. */
    private final List<ClfDataSet> lrDataSet;
    /** Per class, the one-hot targets matching the rows of {@link #lrDataSet}. */
    private final double[][][] lrTargets;

    /**
     * Builds the optimizer and computes the initial transform and CRF probabilities.
     *
     * @param dataSet  the multi-label training data
     * @param crf      the CRF model to be refit; it is updated in place by {@link #optimize()}
     * @param variance passed unchanged to {@link KLLoss}
     * @throws IllegalArgumentException if the enumerated combinations do not number {@code 2^L}
     */
    public NoiseOptimizerLR(MultiLabelClfDataSet dataSet, CMLCRF crf, double variance) {
        this.dataSet = dataSet;
        this.variance = variance;
        this.crf = crf;
        this.combinations = Enumerator.enumerate(dataSet.getNumClasses());
        int numDataPoints = dataSet.getNumDataPoints();
        int numClasses = dataSet.getNumClasses();
        int numCombination = (int) Math.pow(2.0, numClasses);
        if (numCombination != this.combinations.size()) {
            throw new IllegalArgumentException("number of combination should equal!");
        }
        this.targets = new double[numDataPoints][numCombination];
        this.probabilities = new double[numDataPoints][numCombination];
        this.transformProbs = new double[numDataPoints][numCombination];
        this.lrTransforms = new ArrayList<>();
        this.lrDataSet = new ArrayList<>();
        // Only the outer dimension is allocated here; every per-class slice comes from buildLrTargets,
        // so pre-allocating the inner arrays would just be thrown away.
        this.lrTargets = new double[numClasses][][];
        for (int i = 0; i < numClasses; ++i) {
            this.lrTransforms.add(new LogisticRegression(2, numClasses, true));
            this.lrDataSet.add(this.buildLrData(i));
            this.lrTargets[i] = this.buildLrTargets(i);
        }
        this.updateTransformProbs();
        this.updateProbabilities();
        this.terminator = new Terminator();
        logger.debug("finish constructor");
    }

    /**
     * Builds the one-hot targets for class {@code classIndex}'s transform model.
     *
     * <p>Row {@code i * numCombination + k} is {@code {0, 1}} when data point {@code i} carries the
     * class, otherwise {@code {1, 0}}; it is the same for every combination {@code k}.
     */
    private double[][] buildLrTargets(int classIndex) {
        int numCombination = this.combinations.size();
        int numDataPoints = this.dataSet.getNumDataPoints();
        double[][] rows = new double[numDataPoints * numCombination][];
        for (int i = 0; i < numDataPoints; ++i) {
            boolean observed = this.dataSet.getMultiLabels()[i].matchClass(classIndex);
            for (int k = 0; k < numCombination; ++k) {
                // A fresh array per row, so rows never alias each other.
                rows[i * numCombination + k] = observed ? new double[]{0.0, 1.0} : new double[]{1.0, 0.0};
            }
        }
        return rows;
    }

    /**
     * Builds the expanded training set for class {@code classIndex}'s transform model.
     *
     * <p>Row {@code i * numCombination + k} has one feature per class ({@code +0.5} if combination
     * {@code k} contains that class, else {@code -0.5}). Its label is 1 if data point {@code i} is
     * observed with class {@code classIndex}, else 0.
     */
    private ClfDataSet buildLrData(int classIndex) {
        int numCombination = this.combinations.size();
        int numClasses = this.dataSet.getNumClasses();
        int numDataPoints = this.dataSet.getNumDataPoints();
        ClfDataSet data = ClfDataSetBuilder.getBuilder()
                .numDataPoints(numDataPoints * numCombination)
                .numFeatures(numClasses)
                .numClasses(2)
                .dense(true)
                .missingValue(false)
                .build();
        for (int i = 0; i < numDataPoints; ++i) {
            int observedLabel = this.dataSet.getMultiLabels()[i].matchClass(classIndex) ? 1 : 0;
            for (int k = 0; k < numCombination; ++k) {
                int row = i * numCombination + k;
                MultiLabel candidate = this.combinations.get(k);
                for (int j = 0; j < numClasses; ++j) {
                    data.setFeatureValue(row, j, candidate.matchClass(j) ? CENTERING_OFFSET : -CENTERING_OFFSET);
                }
                data.setLabel(row, observedLabel);
            }
        }
        return data;
    }

    /**
     * Selects the optimizer used to refit the CRF.
     *
     * @param optimizer {@code "LBFGS"} or {@code "GD"}; any other value makes the next model update
     *                  throw {@link IllegalArgumentException}
     */
    public void setOptimizer(String optimizer) {
        this.optimizer = optimizer;
    }

    /** Replaces row {@code dataPointIndex} of {@link #probabilities} with the CRF's current prediction. */
    private void updateProbabilities(int dataPointIndex) {
        this.probabilities[dataPointIndex] = this.crf.predictCombinationProbs(this.dataSet.getRow(dataPointIndex));
    }

    /** Recomputes the CRF combination probabilities for every data point (in parallel). */
    private void updateProbabilities() {
        logger.debug("start updateProbabilities()");
        IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().forEach(this::updateProbabilities);
        logger.debug("finish updateProbabilities()");
    }

    /** Recomputes the transform probabilities for every data point (in parallel). */
    private void updateTransformProbs() {
        IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().forEach(this::updateTransformProbs);
    }

    /** Recomputes the transform probabilities of one data point for every combination. */
    private void updateTransformProbs(int dataPoint) {
        for (int c = 0; c < this.combinations.size(); ++c) {
            this.updateTransformProb(dataPoint, c);
        }
    }

    /**
     * Sets {@code transformProbs[dataPoint][comIndex]} to the product over classes of the probability
     * each transform model assigns to the observed label, given the candidate combination's features.
     */
    private void updateTransformProb(int dataPoint, int comIndex) {
        int numClasses = this.dataSet.getNumClasses();
        MultiLabel labels = this.dataSet.getMultiLabels()[dataPoint];
        DenseVector offset = new DenseVector(numClasses);
        for (int i = 0; i < numClasses; ++i) {
            offset.set(i, CENTERING_OFFSET);
        }
        // The features depend only on the candidate, so build them once instead of once per class.
        Vector features = this.combinations.get(comIndex).toVector(numClasses).minus(offset);
        double prod = 1.0;
        for (int l = 0; l < numClasses; ++l) {
            int observedLabel = labels.matchClass(l) ? 1 : 0;
            prod *= this.lrTransforms.get(l).predictClassProb(features, observedLabel);
        }
        this.transformProbs[dataPoint][comIndex] = prod;
    }

    /** Refits every class's transform model (classes are processed in parallel). */
    private void updateAlphas() {
        IntStream.range(0, this.dataSet.getNumClasses()).parallel().forEach(this::updateAlpha);
    }

    /**
     * Refits the transform model of {@code classIndex} with {@link RidgeLogisticOptimizer}.
     *
     * <p>Instance {@code i * numCombination + j} is weighted by {@code targets[i][j]}.
     */
    private void updateAlpha(int classIndex) {
        int numCombination = this.combinations.size();
        int numDataPoints = this.dataSet.getNumDataPoints();
        double[] weights = new double[numDataPoints * numCombination];
        for (int i = 0; i < numDataPoints; ++i) {
            System.arraycopy(this.targets[i], 0, weights, i * numCombination, numCombination);
        }
        // NOTE(review): 1000.0 and `true` are RidgeLogisticOptimizer settings whose meaning is not
        // visible here (presumably prior variance / parallelism) -- confirm against that class.
        RidgeLogisticOptimizer lrOptimizer = new RidgeLogisticOptimizer(this.lrTransforms.get(classIndex),
                (DataSet) this.lrDataSet.get(classIndex), weights, this.lrTargets[classIndex], 1000.0, true);
        lrOptimizer.getOptimizer().getTerminator().setMaxIteration(10000).setMode(Terminator.Mode.STANDARD);
        lrOptimizer.optimize();
    }

    /**
     * Sets {@code targets[dataPointIndex]} to the normalized elementwise product of the CRF and
     * transform probabilities. If every product is 0 the division yields NaN targets.
     */
    private void updateTargets(int dataPointIndex) {
        double[] probs = this.probabilities[dataPointIndex];
        double[] s = this.transformProbs[dataPointIndex];
        double[] product = new double[probs.length];
        for (int j = 0; j < probs.length; ++j) {
            product[j] = probs[j] * s[j];
        }
        double denominator = MathUtil.arraySum(product);
        for (int j = 0; j < probs.length; ++j) {
            this.targets[dataPointIndex][j] = product[j] / denominator;
        }
    }

    /** Runs {@link #iterate()} until the terminator says to stop. */
    public void optimize() {
        while (!this.terminator.shouldTerminate()) {
            this.iterate();
        }
    }

    /** @return the terminator that receives one objective value per iteration */
    public Terminator getTerminator() {
        return this.terminator;
    }

    /** Prints every combination whose value is at least {@code thresh}, largest first. */
    private void printProbWithThreshold(double[] probs, double thresh) {
        int[] indices = ArgSort.argSortDescending(probs);
        for (int index : indices) {
            // Written as !(>=) so NaN entries are skipped as well.
            if (!(probs[index] >= thresh)) {
                continue;
            }
            System.out.println(index + ":" + this.combinations.get(index).toString() + ":" + probs[index]);
        }
    }

    /** Prints, for every data point, its label and its targets, transform probs and CRF probs >= 0.1. */
    public void printInfo() {
        for (int i = 0; i < this.dataSet.getNumDataPoints(); ++i) {
            System.out.println("index=" + i + ",label=" + this.dataSet.getMultiLabels()[i].toString());
            System.out.println("printing targets ..");
            this.printProbWithThreshold(this.targets[i], 0.1);
            System.out.println("printing transformProbs ..");
            this.printProbWithThreshold(this.transformProbs[i], 0.1);
            System.out.println("printing probability ..");
            this.printProbWithThreshold(this.probabilities[i], 0.1);
        }
    }

    /** Performs one full iteration, refitting the CRF until its own optimizer terminates. */
    public void iterate() {
        this.runIteration(this::updateModel);
    }

    /**
     * Performs one iteration in which the CRF refit is capped at {@code modelIterations} steps.
     *
     * @param modelIterations maximum iterations of the CRF optimizer in this pass
     */
    public void iteratePartial(int modelIterations) {
        this.runIteration(() -> this.updateModelPartial(modelIterations));
    }

    /**
     * Shared body of {@link #iterate()} and {@link #iteratePartial(int)}.
     *
     * <p>Runs every update step, prints the objective after each one, and adds the final objective
     * to the terminator.
     *
     * @param modelUpdate the CRF refit step to run
     */
    private void runIteration(Runnable modelUpdate) {
        this.updateTargets();
        this.reportStep("updateTargets");
        this.updateAlphas();
        this.reportStep("updateAlphas");
        this.updateTransformProbs();
        this.reportStep("updateTransformProbs");
        modelUpdate.run();
        this.reportStep("updateModel");
        this.updateProbabilities();
        this.terminator.add(this.reportStep("updateProbabilities"));
    }

    /** Prints that {@code step} finished, followed by the current objective, and returns that objective. */
    private double reportStep(String step) {
        System.out.println("finish " + step + " ");
        double objective = this.objective();
        System.out.println("objective = " + objective);
        return objective;
    }

    /** Recomputes the targets for every data point (in parallel). */
    private void updateTargets() {
        logger.debug("start updateTargets()");
        IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().forEach(this::updateTargets);
        logger.debug("finish updateTargets()");
    }

    /**
     * Creates the configured CRF optimizer over a {@link KLLoss} built from the current targets.
     *
     * @throws IllegalArgumentException if {@link #optimizer} is neither {@code "LBFGS"} nor {@code "GD"}
     */
    private GradientValueOptimizer newModelOptimizer() {
        KLLoss klLoss = new KLLoss(this.crf, this.dataSet, this.targets, this.variance);
        switch (this.optimizer) {
            case "LBFGS":
                return new LBFGS(klLoss);
            case "GD":
                return new GradientDescent(klLoss);
            default:
                throw new IllegalArgumentException("unknown optimizer: " + this.optimizer);
        }
    }

    /** Refits the CRF against the current targets, running the optimizer to its own termination. */
    private void updateModel() {
        logger.debug("start updateModel()");
        this.newModelOptimizer().optimize();
        logger.debug("finish updateModel()");
    }

    /** Refits the CRF against the current targets for at most {@code modelIterations} iterations. */
    private void updateModelPartial(int modelIterations) {
        logger.debug("start updateModelPartial()");
        GradientValueOptimizer opt = this.newModelOptimizer();
        opt.getTerminator().setMaxIteration(modelIterations);
        opt.optimize();
        logger.debug("finish updateModelPartial()");
    }

    /** Returns {@code -log(sum_c probabilities[i][c] * transformProbs[i][c])} for data point {@code i}. */
    private double objective(int dataPointIndex) {
        double[] p = this.probabilities[dataPointIndex];
        double[] s = this.transformProbs[dataPointIndex];
        double sum = 0.0;
        for (int j = 0; j < p.length; ++j) {
            sum += p[j] * s[j];
        }
        return -Math.log(sum);
    }

    /** Sum of the per-data-point objective over the whole data set. */
    private double empiricalLoss() {
        return IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().mapToDouble(this::objective).sum();
    }

    /** @return the empirical loss plus the {@link KLLoss} penalty for the current CRF */
    public double objective() {
        logger.debug("start objective()");
        double obj = this.empiricalLoss();
        logger.debug("finish obj");
        double penalty = this.penalty();
        logger.debug("finish penalty");
        logger.debug("finish objective()");
        return obj + penalty;
    }

    /** @return a three-line breakdown: empirical loss, regularization penalty and total objective */
    public String objectiveDetail() {
        double obj = this.empiricalLoss();
        double penalty = this.penalty();
        StringBuilder sb = new StringBuilder();
        sb.append("empirical loss = ").append(obj).append("\n");
        sb.append("regularization penalty = ").append(penalty).append("\n");
        sb.append("total objective = ").append(obj + penalty).append("\n");
        return sb.toString();
    }

    /** Regularization penalty of the current CRF, as reported by {@link KLLoss#getPenalty()}. */
    private double penalty() {
        KLLoss klLoss = new KLLoss(this.crf, this.dataSet, this.targets, this.variance);
        return klLoss.getPenalty();
    }
}

