/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.classification;

import edu.neu.ccs.pyramid.classification.Classifier;
import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.util.ArgMax;
import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.Arrays;
import org.apache.mahout.math.Vector;

/**
 * A baseline classifier that ignores its input features and predicts from a fixed class prior.
 *
 * <p>{@link #predictClassProbs(Vector)} returns the same distribution for every input, and
 * {@link #predict(Vector)} always returns the index of that distribution's largest entry.
 * The prior is either supplied directly via {@link #PriorProbClassifier(double[])} or
 * estimated from data via one of the {@code fit} methods.
 */
public class PriorProbClassifier implements Classifier.ProbabilityEstimator {
    private static final long serialVersionUID = 2L;
    private final int numClasses;
    /** Per-class prior probabilities; length {@code numClasses}. Never handed out directly. */
    private final double[] probs;
    /** Index of the largest entry of {@code probs}; recomputed whenever {@code probs} is set. */
    private int topClass;
    private FeatureList featureList;
    private LabelTranslator labelTranslator;

    /**
     * Creates an unfitted classifier for {@code numClasses} classes. All probabilities are 0
     * until one of the {@code fit} methods is called.
     *
     * @param numClasses number of classes
     */
    public PriorProbClassifier(int numClasses) {
        this.numClasses = numClasses;
        this.probs = new double[numClasses];
    }

    /**
     * Creates a classifier whose prior is the given distribution.
     *
     * @param probs per-class probabilities; the array is copied, so later changes to the
     *     caller's array do not affect this classifier
     */
    public PriorProbClassifier(double[] probs) {
        this.numClasses = probs.length;
        this.probs = probs.clone();
        // Keep predict() consistent with the supplied prior (it previously always returned 0).
        if (this.probs.length > 0) {
            this.topClass = ArgMax.argMax(this.probs);
        }
    }

    /**
     * Sets the prior to the empirical label frequencies of {@code clfDataSet}.
     *
     * <p>Every label must lie in {@code [0, numClasses)}, otherwise an
     * {@link ArrayIndexOutOfBoundsException} is thrown. The counts are divided by
     * {@code getNumDataPoints()}, so an empty data set yields NaN probabilities.
     *
     * @param clfDataSet labelled data to estimate the prior from
     */
    public void fit(ClfDataSet clfDataSet) {
        double[] counts = new double[this.numClasses];
        for (int label : clfDataSet.getLabels()) {
            counts[label] += 1.0;
        }
        int numDataPoints = clfDataSet.getNumDataPoints();
        for (int k = 0; k < this.numClasses; ++k) {
            this.probs[k] = counts[k] / (double) numDataPoints;
        }
        this.topClass = ArgMax.argMax(this.probs);
    }

    /**
     * Sets the prior to the weighted average of per-instance target distributions:
     * {@code probs[k] = sum_i weights[i] * targetDistribution[i][k] / sum(weights)}.
     *
     * <p>The numerator sums over the first {@code dataSet.getNumDataPoints()} rows. The
     * denominator is the sum of the entire {@code weights} array.
     *
     * @param dataSet only its number of data points is read
     * @param targetDistribution row {@code i} holds the per-class targets of data point
     *     {@code i}; presumably each row sums to 1 (TODO confirm with callers)
     * @param weights per-instance weights
     */
    public void fit(DataSet dataSet, double[][] targetDistribution, double[] weights) {
        double totalCount = MathUtil.arraySum(weights);
        double[] counts = new double[this.numClasses];
        for (int i = 0; i < dataSet.getNumDataPoints(); ++i) {
            double[] target = targetDistribution[i];
            double weight = weights[i];
            for (int k = 0; k < this.numClasses; ++k) {
                counts[k] += target[k] * weight;
            }
        }
        for (int k = 0; k < this.numClasses; ++k) {
            this.probs[k] = counts[k] / totalCount;
        }
        this.topClass = ArgMax.argMax(this.probs);
    }

    /**
     * Computes the residual {@code y_ik - probs[k]} for every data point {@code i}.
     *
     * <p>Here {@code y_ik} is 1 when the label of point {@code i} equals {@code k}, and 0
     * otherwise. This matches the softmax log-loss residual used in multiclass gradient
     * boosting; presumably that is the intended use here (confirm with callers).
     *
     * @param clfDataSet labelled data
     * @param k class index in {@code [0, numClasses)}
     * @return one residual per data point
     */
    public double[] getGradient(ClfDataSet clfDataSet, int k) {
        int numDataPoints = clfDataSet.getNumDataPoints();
        int[] labels = clfDataSet.getLabels();
        // The residual always subtracts the class-k probability; the old code wrongly used
        // probs[label] for points whose label differs from k.
        double probK = this.probs[k];
        double[] gradient = new double[numDataPoints];
        for (int i = 0; i < numDataPoints; ++i) {
            gradient[i] = (labels[i] == k ? 1.0 : 0.0) - probK;
        }
        return gradient;
    }

    /** Returns the class with the highest prior probability, regardless of {@code vector}. */
    @Override
    public int predict(Vector vector) {
        return this.topClass;
    }

    @Override
    public int getNumClasses() {
        return this.numClasses;
    }

    /**
     * Returns a copy of the prior distribution, regardless of {@code vector}. A copy is
     * returned so that callers cannot mutate the model's state.
     */
    @Override
    public double[] predictClassProbs(Vector vector) {
        return this.probs.clone();
    }

    /** Returns a copy of the per-class prior probabilities. */
    public double[] getClassProbs() {
        return this.probs.clone();
    }

    @Override
    public String toString() {
        return "PriorProbClassifier{numClasses=" + this.numClasses + ", probs=" + Arrays.toString(this.probs) + ", topClass=" + this.topClass + '}';
    }

    @Override
    public FeatureList getFeatureList() {
        return this.featureList;
    }

    public void setFeatureList(FeatureList featureList) {
        this.featureList = featureList;
    }

    @Override
    public LabelTranslator getLabelTranslator() {
        return this.labelTranslator;
    }

    public void setLabelTranslator(LabelTranslator labelTranslator) {
        this.labelTranslator = labelTranslator;
    }
}

