/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification;

import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.feature.FeatureList;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.mahout.math.Vector;

/**
 * A serializable classifier that maps a feature vector to a {@link MultiLabel}.
 *
 * <p>Implementations supply {@link #getNumClasses()}, {@link #predict(Vector)},
 * {@link #getFeatureList()} and {@link #getLabelTranslator()}; batch prediction
 * and Java-serialization helpers are provided as default methods.
 *
 * <p>Note that {@link #predict(MultiLabelClfDataSet)} evaluates rows on a parallel
 * stream, so {@link #predict(Vector)} may be invoked from several threads at once.
 */
public interface MultiLabelClassifier extends Serializable {
    /**
     * Returns the number of classes. {@link ClassScoreEstimator#predictClassScores(Vector)}
     * treats {@code 0} (inclusive) to this value (exclusive) as the class indices.
     *
     * @return the number of classes
     */
    public int getNumClasses();

    /**
     * Predicts the label set for a single feature vector.
     *
     * @param var1 the feature vector to classify
     * @return the predicted {@link MultiLabel}
     */
    public MultiLabel predict(Vector var1);

    /**
     * Predicts the label set for every row of a data set.
     *
     * <p>Rows are processed with a parallel stream; the returned array is still in
     * row order, i.e. element {@code i} is the prediction for {@code dataSet.getRow(i)}.
     *
     * @param dataSet the data set whose rows are classified
     * @return one prediction per data point, indexed by row
     */
    default public MultiLabel[] predict(MultiLabelClfDataSet dataSet) {
        // toArray on an ordered stream preserves encounter order even when parallel,
        // so no intermediate list is needed to keep predictions aligned with rows.
        return IntStream.range(0, dataSet.getNumDataPoints())
                .parallel()
                .mapToObj(i -> this.predict(dataSet.getRow(i)))
                .toArray(MultiLabel[]::new);
    }

    /**
     * Writes this classifier to {@code file} using Java object serialization.
     *
     * <p>Missing parent directories are created first. A file given without a parent
     * component (e.g. {@code new File("model.ser")}) is written relative to the
     * current working directory.
     *
     * @param file the destination file
     * @throws Exception if the file cannot be opened or the object cannot be written
     */
    default public void serialize(File file) throws Exception {
        File parent = file.getParentFile();
        // getParentFile() returns null for a bare file name; there is nothing to create then.
        if (parent != null && !parent.exists()) {
            parent.mkdirs();
        }
        try (FileOutputStream fileOutputStream = new FileOutputStream(file);
             BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(fileOutputStream);
             ObjectOutputStream objectOutputStream = new ObjectOutputStream(bufferedOutputStream)) {
            objectOutputStream.writeObject(this);
        }
    }

    /**
     * Writes this classifier to the file at the given path; equivalent to
     * {@code serialize(new File(file))}.
     *
     * @param file the destination path
     * @throws Exception if the file cannot be opened or the object cannot be written
     */
    default public void serialize(String file) throws Exception {
        this.serialize(new File(file));
    }

    /**
     * Returns the {@link FeatureList} associated with this classifier.
     *
     * @return the feature list
     */
    public FeatureList getFeatureList();

    /**
     * Returns the {@link LabelTranslator} associated with this classifier.
     *
     * @return the label translator
     */
    public LabelTranslator getLabelTranslator();

    /**
     * A classifier that can estimate the (log) probability of a complete label assignment.
     */
    public static interface AssignmentProbEstimator extends MultiLabelClassifier {
        /**
         * Estimates the log probability of a label assignment for a feature vector.
         *
         * @param var1 the feature vector
         * @param var2 the label assignment to evaluate
         * @return the log probability of the assignment
         */
        public double predictLogAssignmentProb(Vector var1, MultiLabel var2);

        /**
         * Estimates the probability of a label assignment by exponentiating
         * {@link #predictLogAssignmentProb(Vector, MultiLabel)}.
         *
         * @param vector the feature vector
         * @param assignment the label assignment to evaluate
         * @return the probability of the assignment
         */
        default public double predictAssignmentProb(Vector vector, MultiLabel assignment) {
            return Math.exp(this.predictLogAssignmentProb(vector, assignment));
        }

        /**
         * Estimates the probability of each assignment by exponentiating the values from
         * {@link #predictLogAssignmentProbs(Vector, List)}.
         *
         * @param vector the feature vector
         * @param assignments the label assignments to evaluate
         * @return probabilities, in the same order as {@code assignments}
         */
        default public double[] predictAssignmentProbs(Vector vector, List<MultiLabel> assignments) {
            return Arrays.stream(this.predictLogAssignmentProbs(vector, assignments)).map(Math::exp).toArray();
        }

        /**
         * Estimates the log probability of each assignment by calling
         * {@link #predictLogAssignmentProb(Vector, MultiLabel)} once per element.
         *
         * @param vector the feature vector
         * @param assignments the label assignments to evaluate
         * @return log probabilities, in the same order as {@code assignments}
         */
        default public double[] predictLogAssignmentProbs(Vector vector, List<MultiLabel> assignments) {
            double[] logProbs = new double[assignments.size()];
            for (int c = 0; c < assignments.size(); ++c) {
                logProbs[c] = this.predictLogAssignmentProb(vector, assignments.get(c));
            }
            return logProbs;
        }
    }

    /**
     * A classifier that can estimate a probability for each class.
     */
    public static interface ClassProbEstimator extends MultiLabelClassifier {
        /**
         * Estimates the class probabilities for a feature vector.
         *
         * @param var1 the feature vector
         * @return an array of probabilities indexed by class
         */
        public double[] predictClassProbs(Vector var1);

        /**
         * Estimates the probability of a single class.
         *
         * <p>The default implementation computes the full array via
         * {@link #predictClassProbs(Vector)} and reads one entry from it.
         *
         * @param vector the feature vector
         * @param classIndex the index of the class of interest
         * @return the probability of class {@code classIndex}
         */
        default public double predictClassProb(Vector vector, int classIndex) {
            return this.predictClassProbs(vector)[classIndex];
        }
    }

    /**
     * A classifier that can produce a score for each class.
     */
    public static interface ClassScoreEstimator extends MultiLabelClassifier {
        /**
         * Computes the score of a single class for a feature vector.
         *
         * @param var1 the feature vector
         * @param var2 the index of the class to score
         * @return the score of that class
         */
        public double predictClassScore(Vector var1, int var2);

        /**
         * Computes the score of every class, calling
         * {@link #predictClassScore(Vector, int)} for each index in
         * {@code [0, getNumClasses())}.
         *
         * @param vector the feature vector
         * @return scores indexed by class; length equals {@link #getNumClasses()}
         */
        default public double[] predictClassScores(Vector vector) {
            return IntStream.range(0, this.getNumClasses()).mapToDouble(k -> this.predictClassScore(vector, k)).toArray();
        }
    }
}

