/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.plugin_rule;

import com.google.common.collect.ConcurrentHashMultiset;
import com.google.common.collect.Multiset;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.eval.InstanceAverage;
import edu.neu.ccs.pyramid.eval.KLDivergence;
import edu.neu.ccs.pyramid.multilabel_classification.Enumerator;
import edu.neu.ccs.pyramid.util.ArgSort;
import edu.neu.ccs.pyramid.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

/**
 * Predicts the label set that maximizes expected instance-level F1 under a joint distribution over
 * label sets.
 *
 * <p>The joint distribution is summarized by a matrix {@code p[i][s - 1]}: the total probability of
 * all label sets of size {@code s} that contain label {@code i}. For every candidate prediction size
 * {@code k}, each label gets the score {@code delta[i] = sum_s 2 * p[i][s - 1] / (s + k)}. The best
 * size-{@code k} prediction is the top-{@code k} labels by {@code delta}, scored by the sum of their
 * deltas. The empty prediction is scored by the probability of the empty label set. This appears to
 * follow the General F-measure Maximizer (GFM) formulation.
 *
 * <p>Only label sets of size at most {@code maxSize} (default 20) contribute probability mass, and
 * only predictions of size at most {@code maxSize} are considered. The result is therefore an
 * approximation when the distribution puts mass on larger label sets.
 */
public class GeneralF1Predictor {
    /** Largest label-set size considered, both for probability mass and for the prediction. */
    private int maxSize = 20;

    /**
     * Sets the largest label-set size considered.
     *
     * @param maxSize upper bound on the size of label sets that contribute probability mass and on
     *     the size of the returned prediction
     */
    public void setMaxSize(int maxSize) {
        this.maxSize = maxSize;
    }

    /**
     * Predicts from an explicit distribution over label sets.
     *
     * @param numClasses total number of labels; label indices must lie in {@code [0, numClasses)}
     * @param multiLabels the label sets in the support of the distribution
     * @param probabilities probability of each entry of {@code multiLabels}, aligned by index
     * @return the highest-scoring label set, possibly empty
     */
    public MultiLabel predict(int numClasses, List<MultiLabel> multiLabels, List<Double> probabilities) {
        double[][] p = this.getPMatrix(numClasses, multiLabels, probabilities);
        return this.predictWithPMatrix(p, zeroProbability(multiLabels, probabilities));
    }

    /**
     * Array-based convenience overload of {@link #predict(int, List, List)}.
     *
     * @param numClasses total number of labels
     * @param multiLabels the label sets in the support of the distribution
     * @param probabilities probability of each entry of {@code multiLabels}, aligned by index
     * @return the highest-scoring label set, possibly empty
     */
    public MultiLabel predict(int numClasses, List<MultiLabel> multiLabels, double[] probabilities) {
        List<Double> p = Arrays.stream(probabilities).boxed().collect(Collectors.toList());
        return this.predict(numClasses, multiLabels, p);
    }

    /**
     * Predicts from the empirical distribution of a list of sampled label sets.
     *
     * <p>Each distinct label set gets probability {@code count / samples.size()}.
     *
     * @param numClasses total number of labels
     * @param samples sampled label sets; duplicates are counted
     * @return the highest-scoring label set; empty when {@code samples} is empty
     */
    public MultiLabel predict(int numClasses, List<MultiLabel> samples) {
        ConcurrentHashMultiset<MultiLabel> multiset = ConcurrentHashMultiset.create();
        for (MultiLabel multiLabel : samples) {
            multiset.add(multiLabel);
        }
        int sampleSize = samples.size();
        List<MultiLabel> uniqueOnes = new ArrayList<>();
        List<Double> probs = new ArrayList<>();
        for (Multiset.Entry<MultiLabel> entry : multiset.entrySet()) {
            uniqueOnes.add(entry.getElement());
            probs.add((double) entry.getCount() / (double) sampleSize);
        }
        return this.predict(numClasses, uniqueOnes, probs);
    }

    /**
     * Finds the best prediction given a precomputed probability matrix.
     *
     * @param pMatrix {@code pMatrix[i][s - 1]} is the probability mass of label sets of size
     *     {@code s} containing label {@code i}; one row per label
     * @param zeroProbability score of the empty prediction (the probability of the empty label set)
     * @return the highest-scoring label set. Ties are resolved in favor of the empty set, then the
     *     smaller size, because only strictly better scores replace the current best.
     */
    public MultiLabel predictWithPMatrix(double[][] pMatrix, double zeroProbability) {
        int numLabels = pMatrix.length;
        int maxPredictionSize = Math.min(this.maxSize, numLabels);
        MultiLabel best = new MultiLabel();
        double bestScore = zeroProbability;
        for (int k = 1; k <= maxPredictionSize; ++k) {
            double[] deltaVector = this.getDeltaVector(pMatrix, k);
            Pair<MultiLabel, Double> innerBest = this.bestWithLengthK(deltaVector, k);
            if (innerBest.getSecond() > bestScore) {
                bestScore = innerBest.getSecond();
                best = innerBest.getFirst();
            }
        }
        return best;
    }

    /**
     * Returns the probability of the first empty label set in {@code multiLabels}, or 0 if none.
     */
    private static double zeroProbability(List<MultiLabel> multiLabels, List<Double> probabilities) {
        for (int i = 0; i < multiLabels.size(); ++i) {
            if (multiLabels.get(i).getMatchedLabels().size() == 0) {
                return probabilities.get(i);
            }
        }
        return 0.0;
    }

    /**
     * Picks the {@code k} labels with the largest delta values.
     *
     * @return the chosen labels paired with the sum of their delta values
     */
    private Pair<MultiLabel, Double> bestWithLengthK(double[] deltaVector, int k) {
        int[] sortedIndices = ArgSort.argSortDescending(deltaVector);
        MultiLabel multiLabel = new MultiLabel();
        double score = 0.0;
        for (int i = 0; i < k; ++i) {
            int label = sortedIndices[i];
            multiLabel.addLabel(label);
            score += deltaVector[label];
        }
        return new Pair<MultiLabel, Double>(multiLabel, score);
    }

    /**
     * Computes {@code delta[i] = sum_s 2 * pMatrix[i][s - 1] / (s + size)} for every label.
     *
     * @param pMatrix probability matrix as described in {@link #predictWithPMatrix}
     * @param size candidate prediction size {@code k}
     * @return one delta value per label
     */
    private double[] getDeltaVector(double[][] pMatrix, int size) {
        int numLabels = pMatrix.length;
        int maxSetSize = Math.min(this.maxSize, numLabels);
        double[] d = new double[numLabels];
        for (int i = 0; i < numLabels; ++i) {
            // Also bound by the row length: predictWithPMatrix is public, so the matrix may be
            // narrower than min(maxSize, numLabels) (e.g. built elsewhere or before setMaxSize).
            int limit = Math.min(maxSetSize, pMatrix[i].length);
            double sum = 0.0;
            for (int s = 1; s <= limit; ++s) {
                sum += 2.0 * pMatrix[i][s - 1] / (double) (s + size);
            }
            d[i] = sum;
        }
        return d;
    }

    /**
     * Builds the {@code numClasses x min(maxSize, numClasses)} matrix whose entry {@code [i][s - 1]}
     * accumulates the probability of every label set of size {@code s} containing label {@code i}.
     * Label sets larger than {@code maxSize} are skipped. Empty sets add nothing.
     */
    private double[][] getPMatrix(int numClasses, List<MultiLabel> multiLabels, List<Double> probabilities) {
        int width = Math.min(this.maxSize, numClasses);
        double[][] pMatrix = new double[numClasses][width];
        for (int j = 0; j < multiLabels.size(); ++j) {
            MultiLabel multiLabel = multiLabels.get(j);
            double prob = probabilities.get(j);
            int s = multiLabel.getMatchedLabels().size();
            if (s > this.maxSize) {
                continue;
            }
            for (int i : multiLabel.getMatchedLabels()) {
                pMatrix[i][s - 1] += prob;
            }
        }
        return pMatrix;
    }

    /**
     * Brute-force search for the prediction with minimum expected loss.
     *
     * <p>The expected loss of column {@code j} is {@code lossMatrix.viewColumn(j) . probabilities}.
     * Column {@code j} is taken to correspond to {@code Enumerator.enumerate(numClasses).get(j)}.
     * The expected loss of every column is printed to standard output.
     *
     * @param numClasses total number of labels, used to enumerate candidate label sets
     * @param lossMatrix loss matrix; each column must have {@code probabilities.size()} entries
     * @param probabilities probability of each outcome, aligned with the rows of {@code lossMatrix}
     * @return the label set of the column with the strictly smallest expected loss, or {@code null}
     *     if no column has a finite, non-NaN expected loss
     */
    public static MultiLabel exhaustiveSearch(int numClasses, Matrix lossMatrix, List<Double> probabilities) {
        double bestScore = Double.POSITIVE_INFINITY;
        DenseVector vector = new DenseVector(probabilities.size());
        for (int i = 0; i < vector.size(); ++i) {
            vector.set(i, probabilities.get(i));
        }
        List<MultiLabel> multiLabels = Enumerator.enumerate(numClasses);
        MultiLabel multiLabel = null;
        for (int j = 0; j < lossMatrix.numCols(); ++j) {
            Vector column = lossMatrix.viewColumn(j);
            double score = column.dot(vector);
            System.out.println("column " + j + ", expected loss = " + score);
            if (score < bestScore) {
                bestScore = score;
                multiLabel = multiLabels.get(j);
            }
        }
        return multiLabel;
    }

    /**
     * Computes the expected instance-level F1 of {@code target} under a distribution over label sets.
     *
     * @param combinations support of the distribution
     * @param probs probability of each entry of {@code combinations}, aligned by index
     * @param target the label set being evaluated
     * @param numClasses total number of labels
     * @return {@code sum_i probs[i] * F1(combinations[i], target)}
     */
    public static double expectedF1(List<MultiLabel> combinations, double[] probs, MultiLabel target, int numClasses) {
        double sum = 0.0;
        for (int i = 0; i < combinations.size(); ++i) {
            sum += probs[i] * new InstanceAverage(numClasses, combinations.get(i), target).getF1();
        }
        return sum;
    }

    /**
     * Summarizes a prediction against the ground truth and the predicted joint distribution.
     *
     * @param combinations support of the predicted distribution
     * @param probs probability of each entry of {@code combinations}, aligned by index
     * @param truth the ground-truth label set
     * @param prediction the predicted label set
     * @param numClasses total number of labels
     * @return an {@link Analysis}. Its {@code kl} is {@code Double.POSITIVE_INFINITY} when
     *     {@code truth} is not among {@code combinations}, and its {@code joint} lists the
     *     combinations with probability above 0.01 in descending order of probability.
     */
    public static Analysis showSupportPrediction(List<MultiLabel> combinations, double[] probs, MultiLabel truth, MultiLabel prediction, int numClasses) {
        int truthIndex = -1;
        for (int i = 0; i < combinations.size(); ++i) {
            if (combinations.get(i).equals(truth)) {
                truthIndex = i;
                break;
            }
        }
        double kl;
        if (truthIndex < 0) {
            // truth has zero probability under probs: the KL divergence of a point mass on truth
            // is infinite. Do not silently compare against combinations.get(0) instead.
            kl = Double.POSITIVE_INFINITY;
        } else {
            double[] trueJoint = new double[combinations.size()];
            trueJoint[truthIndex] = 1.0;
            kl = KLDivergence.kl(trueJoint, probs);
        }
        List<Pair<MultiLabel, Double>> list = new ArrayList<>();
        for (int i = 0; i < combinations.size(); ++i) {
            list.add(new Pair<MultiLabel, Double>(combinations.get(i), probs[i]));
        }
        Comparator<Pair<MultiLabel, Double>> byProbabilityDesc =
                Comparator.comparing((Pair<MultiLabel, Double> pair) -> pair.getSecond()).reversed();
        // Stream.sorted is stable, so equal probabilities keep their order in combinations.
        List<Pair<MultiLabel, Double>> sorted = list.stream()
                .sorted(byProbabilityDesc)
                .filter(pair -> pair.getSecond() > 0.01)
                .collect(Collectors.toList());
        double expectedF1Prediction = GeneralF1Predictor.expectedF1(combinations, probs, prediction, numClasses);
        double expectedF1Truth = GeneralF1Predictor.expectedF1(combinations, probs, truth, numClasses);
        double actualF1 = new InstanceAverage(numClasses, truth, prediction).getF1();
        StringBuilder jointString = new StringBuilder();
        for (Pair<MultiLabel, Double> pair : sorted) {
            jointString.append(pair.getFirst()).append(":").append(pair.getSecond()).append(", ");
        }
        Analysis analysis = new Analysis();
        analysis.expectedF1Prediction = expectedF1Prediction;
        analysis.expectedF1Truth = expectedF1Truth;
        analysis.actualF1 = actualF1;
        analysis.kl = kl;
        analysis.prediction = prediction;
        analysis.truth = truth;
        analysis.joint = jointString.toString();
        return analysis;
    }

    /** Diagnostic summary produced by {@link #showSupportPrediction}. */
    public static class Analysis {
        double expectedF1Prediction;
        double expectedF1Truth;
        double actualF1;
        double kl;
        MultiLabel truth;
        MultiLabel prediction;
        String joint;

        /** @return expected F1 of the prediction under the predicted distribution */
        public double getExpectedF1Prediction() {
            return this.expectedF1Prediction;
        }

        /** @return expected F1 of the ground truth under the predicted distribution */
        public double getExpectedF1Truth() {
            return this.expectedF1Truth;
        }

        /** @return F1 between the ground truth and the prediction */
        public double getActualF1() {
            return this.actualF1;
        }

        /** @return KL divergence of the point mass on the truth from the predicted distribution */
        public double getKl() {
            return this.kl;
        }

        /** @return the ground-truth label set */
        public MultiLabel getTruth() {
            return this.truth;
        }

        /** @return the predicted label set */
        public MultiLabel getPrediction() {
            return this.prediction;
        }

        /** @return "labels:prob, " entries for combinations with probability above 0.01 */
        public String getJoint() {
            return this.joint;
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("truth=").append(this.truth).append("\n");
            sb.append("prediction=").append(this.prediction).append("\n");
            sb.append("actual F1=").append(this.actualF1).append("\n");
            sb.append("kl=").append(this.kl).append("\n");
            sb.append("expected F1 of truth=").append(this.expectedF1Truth).append("\n");
            sb.append("expected F1 of prediction=").append(this.expectedF1Prediction).append("\n");
            sb.append("joint=").append(this.joint).append("\n");
            return sb.toString();
        }
    }
}

