/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.classification;

import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.util.Pair;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Comparator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.mahout.math.Vector;

/**
 * A model that maps a feature {@link Vector} to one of {@link #getNumClasses()} class indices.
 *
 * <p>Classifiers are {@link Serializable}; {@link #serialize(File)} writes one to disk with
 * standard Java object serialization.
 */
public interface Classifier
extends Serializable {

    /**
     * Predicts the class index for a single feature vector.
     *
     * @param vector the feature vector to classify
     * @return the predicted class index
     */
    int predict(Vector vector);

    /**
     * Returns the number of classes this classifier distinguishes.
     *
     * @return the number of classes
     */
    int getNumClasses();

    /**
     * Predicts a class index for every row of {@code dataSet}.
     *
     * <p>Rows are processed with a parallel stream, so {@link #predict(Vector)} may be invoked
     * from several threads at once. The returned array is still in row order because the
     * source range is ordered and {@code toArray()} preserves encounter order.
     *
     * @param dataSet the data points to classify
     * @return an array whose {@code i}-th entry is the prediction for row {@code i}
     */
    default int[] predict(DataSet dataSet) {
        return IntStream.range(0, dataSet.getNumDataPoints())
                .parallel()
                .map(i -> this.predict(dataSet.getRow(i)))
                .toArray();
    }

    /**
     * Writes this classifier to {@code file} using Java object serialization.
     *
     * <p>Missing parent directories are created first. A file with no parent component
     * (for example {@code new File("model.ser")}) is written relative to the current working
     * directory.
     *
     * @param file destination file; it is created or overwritten
     * @throws IOException if the parent directory cannot be created
     * @throws Exception if opening or writing the file fails
     */
    default void serialize(File file) throws Exception {
        File parent = file.getParentFile();
        // getParentFile() returns null for a bare file name, so guard before touching it.
        // isDirectory() is re-checked because mkdirs() also returns false when the directory
        // was created concurrently by someone else.
        if (parent != null && !parent.exists() && !parent.mkdirs() && !parent.isDirectory()) {
            throw new IOException("Could not create directory " + parent);
        }
        try (FileOutputStream fileOutputStream = new FileOutputStream(file);
             BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(fileOutputStream);
             ObjectOutputStream objectOutputStream = new ObjectOutputStream(bufferedOutputStream)) {
            objectOutputStream.writeObject(this);
        }
    }

    /**
     * Convenience overload of {@link #serialize(File)} that takes a path string.
     *
     * @param file path of the destination file
     * @throws Exception if serialization fails; see {@link #serialize(File)}
     */
    default void serialize(String file) throws Exception {
        this.serialize(new File(file));
    }

    /**
     * Returns the feature list associated with this classifier.
     *
     * @return the feature list
     */
    FeatureList getFeatureList();

    /**
     * Returns the label translator associated with this classifier.
     *
     * @return the label translator
     */
    LabelTranslator getLabelTranslator();

    /**
     * A classifier that produces a per-class value from {@link #predictClassProbs(Vector)} and
     * predicts the class whose value is largest.
     */
    interface ProbabilityEstimator
    extends Classifier {

        /**
         * Returns one value per class for {@code vector}; index {@code k} corresponds to class
         * {@code k}. NOTE(review): presumably these are probabilities in [0, 1], but that
         * contract is not enforced here.
         *
         * @param vector the feature vector to score
         * @return per-class probabilities
         */
        double[] predictClassProbs(Vector vector);

        /**
         * Returns the natural logarithm of each entry of {@link #predictClassProbs(Vector)}.
         * A probability of exactly zero maps to {@link Double#NEGATIVE_INFINITY}.
         *
         * @param vector the feature vector to score
         * @return per-class log probabilities
         */
        default double[] predictLogClassProbs(Vector vector) {
            double[] probs = this.predictClassProbs(vector);
            double[] logs = new double[probs.length];
            for (int k = 0; k < logs.length; ++k) {
                logs[k] = Math.log(probs[k]);
            }
            return logs;
        }

        /**
         * Returns the probability of a single class. This default computes every class's
         * probability and then picks one entry.
         *
         * @param vector the feature vector to score
         * @param classIndex index of the class of interest
         * @return the probability of {@code classIndex}
         * @throws ArrayIndexOutOfBoundsException if {@code classIndex} is outside the returned array
         */
        default double predictClassProb(Vector vector, int classIndex) {
            return this.predictClassProbs(vector)[classIndex];
        }

        /**
         * Computes {@link #predictClassProbs(Vector)} for every row of {@code dataSet}, using a
         * parallel stream. The result list is in row order.
         *
         * @param dataSet the data points to score
         * @return a list whose {@code i}-th element holds the probabilities for row {@code i}
         */
        default List<double[]> predictClassProbs(DataSet dataSet) {
            return IntStream.range(0, dataSet.getNumDataPoints())
                    .parallel()
                    .mapToObj(i -> this.predictClassProbs(dataSet.getRow(i)))
                    .collect(Collectors.toList());
        }

        /**
         * Predicts the class with the largest value from {@link #predictClassProbs(Vector)}.
         * On ties the lowest index wins.
         *
         * @param vector the feature vector to classify
         * @return index of the most probable class
         * @throws NoSuchElementException if {@code predictClassProbs} returns an empty array
         */
        @Override
        default int predict(Vector vector) {
            double[] probs = this.predictClassProbs(vector);
            if (probs.length == 0) {
                throw new NoSuchElementException("predictClassProbs returned no class probabilities");
            }
            int best = 0;
            for (int k = 1; k < probs.length; ++k) {
                // Double.compare (not '>') so NaN and -0.0/0.0 are ordered as Double.compareTo
                // orders them; strict '> 0' keeps the earliest index on ties.
                if (Double.compare(probs[k], probs[best]) > 0) {
                    best = k;
                }
            }
            return best;
        }
    }

    /**
     * A classifier that can produce a real-valued score for each class.
     */
    interface ScoreEstimator
    extends Classifier {

        /**
         * Returns the score of class {@code classIndex} for {@code vector}.
         *
         * @param vector the feature vector to score
         * @param classIndex index of the class to score
         * @return the class score
         */
        double predictClassScore(Vector vector, int classIndex);

        /**
         * Returns the scores of all {@link #getNumClasses()} classes by calling
         * {@link #predictClassScore(Vector, int)} once per class.
         *
         * @param vector the feature vector to score
         * @return an array whose {@code k}-th entry is the score of class {@code k}
         */
        default double[] predictClassScores(Vector vector) {
            int numClasses = this.getNumClasses();
            double[] scores = new double[numClasses];
            for (int k = 0; k < numClasses; ++k) {
                scores[k] = this.predictClassScore(vector, k);
            }
            return scores;
        }
    }
}

