/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;
import org.apache.mahout.math.DenseVector;

/**
 * Best-first enumerator of label combinations for multi-label prediction.
 *
 * <p>Every label {@code l} carries a two-entry table: index {@code 0} for "label absent"
 * and index {@code 1} for "label present". The score of a {@link MultiLabel} is the sum,
 * over all labels, of the log-probability entry matching whether the label is in the set.
 * The enumerator is seeded with the combination that takes the more probable state of
 * every label and then explores outwards by flipping one label at a time, always handing
 * out the highest-scoring combination that is currently queued.
 *
 * <p>Labels whose "present" probability is exactly {@code 0.0} or {@code 1.0} are never
 * flipped.
 */
public class DynamicProgramming {
    /** Candidates generated but not yet handed out; the head has the highest log-probability. */
    private final PriorityQueue<Candidate> queue;
    /** {@code probs[l][0]} = probability label {@code l} is absent, {@code probs[l][1]} = present. */
    private final double[][] probs;
    /** Log-probabilities laid out like {@link #probs}. */
    private final double[][] logProbs;
    private final int numLabels;
    /** Every combination that has ever been queued, so that none is queued twice. */
    private final Set<MultiLabel> cache;
    /** Indices of labels whose "present" probability is neither 0.0 nor 1.0. */
    private final List<Integer> uncertainLabels;

    /**
     * Builds the enumerator from per-label marginal probabilities.
     *
     * @param probabilities {@code probabilities[l]} is the probability that label {@code l}
     *                      is present; the "absent" probability is taken as
     *                      {@code 1.0 - probabilities[l]}
     */
    public DynamicProgramming(double[] probabilities) {
        this(toBinaryDistributions(probabilities));
    }

    /** Internal chaining step: derives the log table from the probability table. */
    private DynamicProgramming(double[][] probs) {
        this(probs, elementwiseLog(probs));
    }

    /**
     * Builds the enumerator from precomputed tables. The arrays are stored as given,
     * not copied.
     *
     * @param probs    {@code probs[l][1]} is the probability that label {@code l} is present
     * @param logProbs {@code logProbs[l][0]} / {@code logProbs[l][1]} are the log-probabilities
     *                 of label {@code l} being absent / present
     */
    public DynamicProgramming(double[][] probs, double[][] logProbs) {
        this.numLabels = probs.length;
        this.probs = probs;
        this.logProbs = logProbs;
        this.cache = new HashSet<MultiLabel>();
        this.queue = new PriorityQueue<Candidate>();
        this.uncertainLabels = new ArrayList<Integer>();

        // Seed: pick the more probable state of every label (ties at 0.5 go to "present").
        MultiLabel mode = new MultiLabel();
        double modeLogProb = 0.0;
        for (int l = 0; l < this.numLabels; ++l) {
            double present = probs[l][1];
            if (present != 0.0 && present != 1.0) {
                this.uncertainLabels.add(l);
            }
            if (present >= 0.5) {
                mode.addLabel(l);
                modeLogProb += logProbs[l][1];
            } else {
                modeLogProb += logProbs[l][0];
            }
        }
        this.queue.add(new Candidate(mode, modeLogProb));
        this.cache.add(mode);
    }

    /** Expands marginals into rows of {@code {1 - p, p}}. */
    private static double[][] toBinaryDistributions(double[] probabilities) {
        double[][] table = new double[probabilities.length][2];
        for (int l = 0; l < probabilities.length; ++l) {
            table[l][0] = 1.0 - probabilities[l];
            table[l][1] = probabilities[l];
        }
        return table;
    }

    /** Returns a table of the same shape holding {@link Math#log} of the two entries of each row. */
    private static double[][] elementwiseLog(double[][] table) {
        double[][] logs = new double[table.length][2];
        for (int l = 0; l < table.length; ++l) {
            logs[l][0] = Math.log(table[l][0]);
            logs[l][1] = Math.log(table[l][1]);
        }
        return logs;
    }

    /**
     * @return the live internal queue (not a copy); modifying it changes what this
     *         enumerator hands out
     */
    public PriorityQueue<Candidate> getQueue() {
        return this.queue;
    }

    /**
     * @return the probability of the candidate currently at the head of the queue, or
     *         {@code 0.0} when the queue is empty; the queue is not modified
     */
    public double nextHighestProb() {
        Candidate head = this.queue.peek();
        return head == null ? 0.0 : head.probability;
    }

    /**
     * @return the log-probability of the candidate currently at the head of the queue, or
     *         {@link Double#NEGATIVE_INFINITY} when the queue is empty; the queue is not modified
     */
    public double highestLogProb() {
        Candidate head = this.queue.peek();
        return head == null ? Double.NEGATIVE_INFINITY : head.logProbability;
    }

    /**
     * Removes the best queued candidate, queues its unseen single-flip neighbours, and
     * returns its label set.
     *
     * @return the label set of the removed candidate, or a new empty {@link MultiLabel}
     *         when the queue is empty
     */
    public MultiLabel nextHighestVector() {
        if (this.queue.isEmpty()) {
            return new MultiLabel();
        }
        return this.pollAndExpand().multiLabel;
    }

    /**
     * Removes the best queued candidate, queues its unseen single-flip neighbours, and
     * returns it.
     *
     * @return the removed candidate, or a candidate holding a new empty {@link MultiLabel}
     *         with log-probability {@link Double#NEGATIVE_INFINITY} when the queue is empty
     */
    public Candidate nextHighest() {
        if (this.queue.isEmpty()) {
            return new Candidate(new MultiLabel(), Double.NEGATIVE_INFINITY);
        }
        return this.pollAndExpand();
    }

    /**
     * Polls the head and then expands that same candidate. Polling before expanding
     * matters: expansion inserts into the queue, and if a neighbour's incrementally
     * computed score rounds above the old head, a peek-expand-poll sequence would return
     * a candidate that was never expanded. Callers must ensure the queue is non-empty.
     */
    private Candidate pollAndExpand() {
        Candidate best = this.queue.poll();
        this.flipLabels(best);
        return best;
    }

    /**
     * Queues every combination obtained from {@code data} by flipping exactly one
     * uncertain label, skipping combinations that were queued before.
     */
    private void flipLabels(Candidate data) {
        double baseLogProb = data.logProbability;
        MultiLabel base = data.multiLabel;
        for (int l : this.uncertainLabels) {
            MultiLabel flipped = base.copy();
            flipped.flipLabel(l);
            // Set.add reports whether the element is new, so one lookup both tests and records.
            if (!this.cache.add(flipped)) {
                continue;
            }
            // Swap label l's contribution to the score. matchClass(l) is presumed to report
            // whether label l is in the set -- TODO confirm against MultiLabel.
            double logProb = flipped.matchClass(l)
                    ? baseLogProb - this.logProbs[l][0] + this.logProbs[l][1]
                    : baseLogProb - this.logProbs[l][1] + this.logProbs[l][0];
            this.queue.add(new Candidate(flipped, logProb));
        }
    }

    /**
     * Sums, over all labels, {@code logProbs[l][1]} where {@code vector.get(l) == 1.0} and
     * {@code logProbs[l][0]} otherwise. Despite the name, the result is a log-probability.
     * Not called anywhere inside this class.
     */
    private double calculateProb(DenseVector vector) {
        double logProb = 0.0;
        for (int l = 0; l < this.numLabels; ++l) {
            if (vector.get(l) == 1.0) {
                logProb += this.logProbs[l][1];
                continue;
            }
            logProb += this.logProbs[l][0];
        }
        return logProb;
    }

    /** @return the string form of the internal queue */
    @Override
    public String toString() {
        return this.queue.toString();
    }

    /**
     * A label combination paired with its log-probability and the exponentiated
     * probability. The natural ordering puts higher log-probabilities first, so the head
     * of a {@link PriorityQueue} of candidates is the most probable one.
     */
    public class Candidate
    implements Comparable<Candidate> {
        private final MultiLabel multiLabel;
        private final double logProbability;
        private final double probability;

        /**
         * @param multiLabel     the label combination
         * @param logProbability its log-probability; the probability is derived from it
         *                       with {@link Math#exp}
         */
        Candidate(MultiLabel multiLabel, double logProbability) {
            this.multiLabel = multiLabel;
            this.logProbability = logProbability;
            this.probability = Math.exp(logProbability);
        }

        public MultiLabel getMultiLabel() {
            return this.multiLabel;
        }

        public double getLogProbability() {
            return this.logProbability;
        }

        public double getProbability() {
            return this.probability;
        }

        /** Descending order of log-probability: arguments are deliberately reversed. */
        @Override
        public int compareTo(Candidate o) {
            return Double.compare(o.logProbability, this.logProbability);
        }

        @Override
        public String toString() {
            return "prob: " + String.format("%.3f", Math.exp(this.logProbability)) + "\tvetcor: " + this.multiLabel;
        }
    }
}

