/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.classification;

import edu.neu.ccs.pyramid.classification.Classifier;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.feature.FeatureList;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.apache.mahout.math.Vector;

/**
 * An ensemble {@link Classifier.ProbabilityEstimator} that combines its member estimators by
 * averaging their predicted class probabilities with equal weight.
 *
 * <p>Every member must report the same number of classes as the ensemble; this is enforced in
 * {@link #add}. The feature list and label translator are only stored and returned through their
 * accessors; they are not used when predicting.
 */
public class ProbabilityVoting
implements Classifier.ProbabilityEstimator {
    private final int numClasses;
    private final List<Classifier.ProbabilityEstimator> estimatorList;
    private FeatureList featureList;
    private LabelTranslator labelTranslator;

    /**
     * Creates an empty ensemble.
     *
     * @param numClasses number of classes that every member estimator must predict
     */
    public ProbabilityVoting(int numClasses) {
        this.numClasses = numClasses;
        this.estimatorList = new ArrayList<>();
    }

    /**
     * Adds a member estimator to the ensemble.
     *
     * @param estimator the estimator to add; must not be null
     * @throws NullPointerException if {@code estimator} is null
     * @throws IllegalArgumentException if the estimator's class count differs from this
     *     ensemble's class count
     */
    public void add(Classifier.ProbabilityEstimator estimator) {
        Objects.requireNonNull(estimator, "estimator");
        if (estimator.getNumClasses() != this.numClasses) {
            throw new IllegalArgumentException("illegal number of classes");
        }
        this.estimatorList.add(estimator);
    }

    /**
     * Predicts the class whose averaged probability is highest.
     *
     * <p>Ties are broken in favor of the lowest class index.
     *
     * @param vector the input to classify
     * @return index of the winning class
     * @throws IllegalStateException if no estimators have been added
     */
    @Override
    public int predict(Vector vector) {
        double[] averageProbs = averageProbabilities(vector);
        int pred = 0;
        double maxProb = averageProbs[0];
        // Strict '>' keeps the first (lowest-index) class on ties.
        for (int k = 1; k < averageProbs.length; k++) {
            if (averageProbs[k] > maxProb) {
                maxProb = averageProbs[k];
                pred = k;
            }
        }
        return pred;
    }

    /**
     * Returns the number of member estimators added so far.
     *
     * @return the ensemble size
     */
    public int size() {
        return this.estimatorList.size();
    }

    /**
     * Returns the member estimator at position {@code i}, in insertion order.
     *
     * @param i zero-based index of the estimator
     * @return the estimator at that position
     * @throws IndexOutOfBoundsException if {@code i} is out of range
     */
    public Classifier.ProbabilityEstimator getProbEstimator(int i) {
        return this.estimatorList.get(i);
    }

    @Override
    public int getNumClasses() {
        return this.numClasses;
    }

    /**
     * Returns, for each class, the mean of the probabilities predicted by all member estimators.
     *
     * @param vector the input to classify
     * @return a fresh array of length {@link #getNumClasses()} holding the averaged probabilities
     * @throws IllegalStateException if no estimators have been added
     */
    @Override
    public double[] predictClassProbs(Vector vector) {
        return averageProbabilities(vector);
    }

    /**
     * Sums each member's class probabilities and divides by the member count.
     * Shared by {@link #predict} and {@link #predictClassProbs} so the two cannot drift apart.
     */
    private double[] averageProbabilities(Vector vector) {
        int numEstimators = this.estimatorList.size();
        if (numEstimators == 0) {
            // Without this guard the division below yields 0.0/0 = NaN for every class.
            throw new IllegalStateException("no estimators have been added");
        }
        double[] averageProbs = new double[this.numClasses];
        for (Classifier.ProbabilityEstimator estimator : this.estimatorList) {
            double[] probs = estimator.predictClassProbs(vector);
            for (int k = 0; k < this.numClasses; k++) {
                averageProbs[k] += probs[k];
            }
        }
        for (int k = 0; k < this.numClasses; k++) {
            averageProbs[k] /= numEstimators;
        }
        return averageProbs;
    }

    @Override
    public FeatureList getFeatureList() {
        return this.featureList;
    }

    /**
     * Stores the feature list returned by {@link #getFeatureList()}.
     *
     * @param featureList the feature list to store
     */
    public void setFeatureList(FeatureList featureList) {
        this.featureList = featureList;
    }

    @Override
    public LabelTranslator getLabelTranslator() {
        return this.labelTranslator;
    }

    /**
     * Stores the label translator returned by {@link #getLabelTranslator()}.
     *
     * @param labelTranslator the label translator to store
     */
    public void setLabelTranslator(LabelTranslator labelTranslator) {
        this.labelTranslator = labelTranslator;
    }
}

