/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.eval;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.eval.SafeDivide;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.IntStream;

/**
 * Static helpers for computing the overlap (intersection-over-union) measure
 * between label sets, either from aggregate counts or from {@link MultiLabel}s.
 */
public class Overlap {
    /**
     * Value reported when there is nothing to compare. Every {@code SafeDivide.divide}
     * call in this class passes it as the third argument (presumably the result used
     * when the denominator is zero -- confirm against {@code SafeDivide}).
     */
    private static final double DEGENERATE_OVERLAP = 1.0;

    /**
     * Computes {@code tp / (tp + fp + fn)} through {@code SafeDivide.divide}.
     *
     * @param tp true-positive count
     * @param fp false-positive count
     * @param fn false-negative count
     * @return the result of {@code SafeDivide.divide(tp, tp + fp + fn, 1.0)}
     */
    public static double overlap(double tp, double fp, double fn) {
        return SafeDivide.divide(tp, tp + fp + fn, DEGENERATE_OVERLAP);
    }

    /**
     * Averages the per-instance overlap between the data set's labels and the
     * classifier's predictions for that data set.
     *
     * @param classifier classifier whose {@code predict(dataSet)} output is evaluated
     * @param dataSet    data set supplying the reference labels via {@code getMultiLabels()}
     * @return the mean per-instance overlap, as computed by
     *         {@link #overlap(MultiLabel[], MultiLabel[])}
     */
    @Deprecated
    public static double overlap(MultiLabelClassifier classifier, MultiLabelClfDataSet dataSet) {
        return Overlap.overlap(dataSet.getMultiLabels(), classifier.predict(dataSet));
    }

    /**
     * Averages {@link #overlap(MultiLabel, MultiLabel)} over index-aligned pairs.
     *
     * <p>Only indices {@code 0 .. multiLabels.length - 1} are visited, so
     * {@code predictions} must contain at least {@code multiLabels.length} elements;
     * any further predictions are not read.
     *
     * @param multiLabels reference labels, one per instance
     * @param predictions predicted labels, index-aligned with {@code multiLabels}
     * @return the mean pairwise overlap, or {@code 1.0} when {@code multiLabels} is empty
     */
    @Deprecated
    public static double overlap(MultiLabel[] multiLabels, MultiLabel[] predictions) {
        return IntStream.range(0, multiLabels.length)
                .parallel()
                .mapToDouble(i -> Overlap.overlap(multiLabels[i], predictions[i]))
                .average()
                // An empty stream has no average; report the same degenerate value used
                // elsewhere in this class instead of throwing NoSuchElementException.
                .orElse(DEGENERATE_OVERLAP);
    }

    /**
     * Computes the size of the intersection of the two matched-label sets divided
     * by the size of their union, through {@code SafeDivide.divide}.
     *
     * @param multiLabel1 first label set
     * @param multiLabel2 second label set
     * @return the result of {@code SafeDivide.divide(|intersection|, |union|, 1.0)}
     */
    public static double overlap(MultiLabel multiLabel1, MultiLabel multiLabel2) {
        Set<Integer> set1 = multiLabel1.getMatchedLabels();
        Set<Integer> set2 = multiLabel2.getMatchedLabels();
        // Copy before retainAll so the sets owned by the MultiLabels are never mutated.
        Set<Integer> intersection = new HashSet<>(set1);
        intersection.retainAll(set2);
        // Inclusion-exclusion: |A u B| = |A| + |B| - |A n B|; no second set needed.
        int unionSize = set1.size() + set2.size() - intersection.size();
        return SafeDivide.divide(intersection.size(), unionSize, DEGENERATE_OVERLAP);
    }
}

