/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.adaboostmh;

import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.dataset.ScoreMatrix;
import edu.neu.ccs.pyramid.dataset.WeightMatrix;
import edu.neu.ccs.pyramid.multilabel_classification.MLPriorProbClassifier;
import edu.neu.ccs.pyramid.multilabel_classification.adaboostmh.AdaBoostMH;
import edu.neu.ccs.pyramid.regression.ConstantRegressor;
import edu.neu.ccs.pyramid.regression.Regressor;
import edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree;
import java.util.Arrays;
import java.util.Comparator;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.Vector;

/**
 * Trains an {@link AdaBoostMH} model on a {@link MultiLabelClfDataSet}.
 *
 * <p>The trainer keeps a per-(data point, class) score matrix holding the summed
 * predictions of all regressors added so far, and a weight matrix derived from
 * those scores. Each call to {@link #iterate()} fits one decision stump per class
 * against the current weights, appends it to the model, and then refreshes the
 * weights.
 */
public class AdaBoostMHTrainer {
    private static final Logger logger = LogManager.getLogger();
    private final AdaBoostMH boosting;
    private final ScoreMatrix scoreMatrix;
    private final WeightMatrix weightMatrix;
    private final MultiLabelClfDataSet dataSet;
    /** {@code labels[i][k]} is true when class {@code k} is among the matched labels of data point {@code i}. */
    private final boolean[][] labels;

    /**
     * Prepares the trainer and the model for boosting.
     *
     * <p>Copies the feature list and label translator from the data set into the
     * model. If class 0 has no regressors yet, constant prior regressors are added
     * for every class. The score matrix is then seeded with the predictions of all
     * regressors already in the model, and the initial weights are computed.
     *
     * @param dataSet  training data
     * @param boosting model to be trained; it is modified in place
     */
    public AdaBoostMHTrainer(MultiLabelClfDataSet dataSet, AdaBoostMH boosting) {
        this.dataSet = dataSet;
        this.boosting = boosting;
        this.boosting.setFeatureList(this.dataSet.getFeatureList());
        this.boosting.setLabelTranslator(this.dataSet.getLabelTranslator());
        // Only a model without any regressors for class 0 gets the prior regressors.
        if (boosting.getRegressors(0).size() == 0) {
            this.setPriorProbs(dataSet);
        }
        this.scoreMatrix = new ScoreMatrix(dataSet.getNumDataPoints(), dataSet.getNumClasses());
        this.initStagedClassScoreMatrix();
        this.weightMatrix = new WeightMatrix(dataSet.getNumDataPoints(), dataSet.getNumClasses());
        this.updateDistribution();
        this.labels = new boolean[dataSet.getNumDataPoints()][dataSet.getNumClasses()];
        // Each index i writes only to its own row, so the parallel fill does not race.
        IntStream.range(0, dataSet.getNumDataPoints()).parallel().forEach(i -> {
            for (int k : dataSet.getMultiLabels()[i].getMatchedLabels()) {
                this.labels[i][k] = true;
            }
        });
    }

    /**
     * Runs one boosting round: fits a stump for every class against the current
     * weights, adds it to the model and to the score matrix, and recomputes the
     * weights once all classes have been processed.
     */
    public void iterate() {
        int numClasses = this.boosting.getNumClasses();
        for (int k = 0; k < numClasses; ++k) {
            RegressionTree regressor = this.fitClassK(k);
            this.boosting.addRegressor(regressor, k);
            this.updateStagedClassScores(regressor, k);
        }
        this.updateDistribution();
    }

    /**
     * Adds one constant regressor per class whose output is the log prior of that
     * class, centred by the mean log prior over all classes.
     *
     * <p>A prior of exactly zero makes {@code Math.log} return negative infinity,
     * which results in non-finite scores.
     *
     * @param probs prior probability of each class, indexed by class
     * @throws IllegalArgumentException if {@code probs} does not have one entry per class
     */
    private void setPriorProbs(double[] probs) {
        int numClasses = this.boosting.getNumClasses();
        if (probs.length != numClasses) {
            throw new IllegalArgumentException("probs.length!=this.numClasses");
        }
        double average = Arrays.stream(probs).map(Math::log).average().getAsDouble();
        for (int k = 0; k < numClasses; ++k) {
            // The mean is a mean of logs, so it must be subtracted from log(p),
            // not from p before the log is taken.
            double score = Math.log(probs[k]) - average;
            ConstantRegressor constant = new ConstantRegressor(score);
            this.boosting.addRegressor(constant, k);
        }
    }

    /**
     * Fits an {@link MLPriorProbClassifier} on the data set and installs its class
     * probabilities as constant prior regressors.
     *
     * @param dataSet data used to estimate the class probabilities
     */
    private void setPriorProbs(MultiLabelClfDataSet dataSet) {
        MLPriorProbClassifier priorProbClassifier = new MLPriorProbClassifier(dataSet.getNumClasses());
        priorProbClassifier.fit(dataSet);
        double[] probs = priorProbClassifier.getClassProbs();
        this.setPriorProbs(probs);
    }

    /**
     * Recomputes the weight of every (data point, class) pair as
     * {@code exp(-y * score)}, where {@code y} is +1 if the class is one of the
     * point's matched labels and -1 otherwise, then calls
     * {@link WeightMatrix#normalize()}.
     */
    private void updateDistribution() {
        int numClasses = this.boosting.getNumClasses();
        IntStream.range(0, this.dataSet.getNumDataPoints()).parallel().forEach(i -> {
            double[] y = new double[numClasses];
            Arrays.fill(y, -1.0);
            for (int k : this.dataSet.getMultiLabels()[i].getMatchedLabels()) {
                y[k] = 1.0;
            }
            float[] scores = this.scoreMatrix.getScoresForData(i);
            for (int k = 0; k < numClasses; ++k) {
                double prob = Math.exp(-1.0 * y[k] * (double) scores[k]);
                this.weightMatrix.setProbability(i, k, prob);
            }
        });
        this.weightMatrix.normalize();
    }

    /**
     * Seeds the score matrix with the predictions of every regressor that is
     * already part of the model.
     */
    private void initStagedClassScoreMatrix() {
        int numClasses = this.boosting.getNumClasses();
        for (int k = 0; k < numClasses; ++k) {
            for (Regressor regressor : this.boosting.getRegressors(k)) {
                this.updateStagedClassScores(regressor, k);
            }
        }
    }

    /**
     * Adds the regressor's prediction to the class-{@code k} score of every data point.
     *
     * @param regressor regressor whose predictions are accumulated
     * @param k         class index the regressor belongs to
     */
    private void updateStagedClassScores(Regressor regressor, int k) {
        int numDataPoints = this.dataSet.getNumDataPoints();
        IntStream.range(0, numDataPoints).parallel()
                .forEach(dataIndex -> this.updateStagedClassScore(regressor, k, dataIndex));
    }

    /**
     * Adds the regressor's prediction for a single data point to its class-{@code k} score.
     *
     * @param regressor regressor whose prediction is accumulated
     * @param k         class index
     * @param dataIndex row index of the data point
     */
    private void updateStagedClassScore(Regressor regressor, int k, int dataIndex) {
        Vector vector = this.dataSet.getRow(dataIndex);
        double prediction = regressor.predict(vector);
        this.scoreMatrix.increment(dataIndex, k, prediction);
    }

    /**
     * Fits a single-feature stump for class {@code k} against the current weights.
     *
     * <p>Every feature is scored by {@link StumpInfo#getObjective()} and the one with
     * the smallest objective is chosen. Each of the two stump outputs is
     * {@code 0.5 * log((W+ + s) / (W- + s))}, where {@code W+} and {@code W-} are the
     * weights of points that do and do not have label {@code k} on that side, and
     * {@code s} is a smoothing term.
     *
     * @param k class index
     * @return the fitted stump, with the data set's feature list attached
     */
    private RegressionTree fitClassK(int k) {
        double[] probs = this.weightMatrix.getProbsForClass(k);
        double match = this.totalWeight(k, true);
        double notMatch = this.totalWeight(k, false);
        StumpInfo optimal = IntStream.range(0, this.dataSet.getNumFeatures()).parallel()
                .mapToObj(j -> this.evaluateStump(j, k, probs, match, notMatch))
                .min(Comparator.comparingDouble(StumpInfo::getObjective))
                .get();
        // Multiply in double: the int product of the two counts can overflow on large data sets.
        double smooth = 1.0 / ((double) this.dataSet.getNumDataPoints() * (double) this.dataSet.getNumClasses());
        double leftOutput = 0.5 * Math.log((optimal.matchNotOccur + smooth) / (optimal.notMatchNotOccur + smooth));
        double rightOutput = 0.5 * Math.log((optimal.matchOccur + smooth) / (optimal.notMatchOccur + smooth));
        // NOTE(review): presumably newStump(feature, threshold, leftOutput, rightOutput) with
        // values <= threshold going left; confirm against RegressionTree.
        RegressionTree tree = RegressionTree.newStump(optimal.featureIndex, 0.0, leftOutput, rightOutput);
        tree.setFeatureList(this.dataSet.getFeatureList());
        return tree;
    }

    /**
     * Sums the class-{@code k} weights of all data points whose label-{@code k}
     * membership equals {@code matched}.
     */
    private double totalWeight(int k, boolean matched) {
        return IntStream.range(0, this.dataSet.getNumDataPoints()).parallel()
                .filter(i -> this.labels[i][k] == matched)
                .mapToDouble(i -> this.weightMatrix.getProbsForData(i)[k])
                .sum();
    }

    /**
     * Collects the weight statistics of a stump on feature {@code j} for class {@code k}.
     *
     * <p>"Occur" refers to data points with a non-zero entry in column {@code j}.
     * The weights of the remaining points are obtained by subtracting from the
     * class totals.
     */
    private StumpInfo evaluateStump(int j, int k, double[] probs, double match, double notMatch) {
        double matchOccur = 0.0;
        double notMatchOccur = 0.0;
        Vector column = this.dataSet.getColumn(j);
        for (Vector.Element element : column.nonZeroes()) {
            int i = element.index();
            if (this.labels[i][k]) {
                matchOccur += probs[i];
            } else {
                notMatchOccur += probs[i];
            }
        }
        // The totals and the partial sums are accumulated in different orders, so rounding
        // can push the difference slightly below zero; clamp to keep sqrt/log well defined.
        double matchNotOccur = Math.max(0.0, match - matchOccur);
        double notMatchNotOccur = Math.max(0.0, notMatch - notMatchOccur);
        return new StumpInfo(j, matchOccur, matchNotOccur, notMatchOccur, notMatchNotOccur);
    }

    /** Weight statistics of a candidate stump on one feature. */
    private static class StumpInfo {
        private final int featureIndex;
        private final double matchOccur;
        private final double matchNotOccur;
        private final double notMatchOccur;
        private final double notMatchNotOccur;

        private StumpInfo(int featureIndex, double matchOccur, double matchNotOccur,
                          double notMatchOccur, double notMatchNotOccur) {
            this.featureIndex = featureIndex;
            this.matchOccur = matchOccur;
            this.matchNotOccur = matchNotOccur;
            this.notMatchOccur = notMatchOccur;
            this.notMatchNotOccur = notMatchNotOccur;
        }

        /**
         * Returns the value minimised when choosing the split feature: for each side
         * of the stump, the square root of the product of matched and unmatched weight,
         * summed over both sides.
         */
        double getObjective() {
            return Math.sqrt(this.matchOccur * this.notMatchOccur)
                    + Math.sqrt(this.matchNotOccur * this.notMatchNotOccur);
        }
    }
}

