/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.dataset;

import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.util.ListUtil;
import java.io.Serializable;
import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * A mutable set of class-label indices, stored as a {@link BitSet}.
 *
 * <p>Label indices must be non-negative; {@link BitSet} rejects negative indices with an
 * {@link IndexOutOfBoundsException}. No synchronization is performed, so instances must
 * not be mutated concurrently from several threads without external locking.
 */
public class MultiLabel
implements Serializable {
    private static final long serialVersionUID = 3L;

    /**
     * Bit {@code k} is set iff label {@code k} is present. The field name is part of the
     * serialized form and must not change.
     */
    private final BitSet labels = new BitSet();

    /** Creates an empty label set. */
    public MultiLabel() {
    }

    /**
     * Creates a label set containing the index of every non-zero entry of {@code vector}.
     *
     * @param vector vector whose non-zero positions become the labels
     */
    public MultiLabel(Vector vector) {
        // Write the bits directly instead of calling the overridable addLabel(): an
        // override invoked from a constructor would see a partially constructed subclass.
        for (Vector.Element element : vector.nonZeroes()) {
            this.labels.set(element.index());
        }
    }

    /**
     * Returns an independent copy of this label set; later changes to either instance do
     * not affect the other.
     *
     * @return a new {@code MultiLabel} with the same labels
     */
    public MultiLabel copy() {
        MultiLabel c = new MultiLabel();
        c.labels.or(this.labels);
        return c;
    }

    /**
     * Converts this label set to a dense 0/1 indicator vector.
     *
     * @param length size of the returned vector
     * @return a vector with 1.0 at every label index and 0.0 elsewhere
     */
    // NOTE(review): what happens when a label index is >= length depends on
    // DenseVector.set's bounds checking — confirm against the Mahout version in use.
    public Vector toVector(int length) {
        DenseVector vector = new DenseVector(length);
        this.labels.stream().forEach(i -> vector.set(i, 1.0));
        return vector;
    }

    /**
     * Adds label {@code k} (no-op if already present).
     *
     * @param k label index
     * @return this instance, to allow call chaining
     * @throws IndexOutOfBoundsException if {@code k} is negative
     */
    public MultiLabel addLabel(int k) {
        this.labels.set(k);
        return this;
    }

    /**
     * Removes label {@code k} (no-op if absent).
     *
     * @param k label index
     * @throws IndexOutOfBoundsException if {@code k} is negative
     */
    public void removeLabel(int k) {
        this.labels.clear(k);
    }

    /**
     * Toggles label {@code k}: adds it if absent, removes it if present.
     *
     * @param k label index
     * @throws IndexOutOfBoundsException if {@code k} is negative
     */
    public void flipLabel(int k) {
        this.labels.flip(k);
    }

    /**
     * Tests whether label {@code k} is present.
     *
     * @param k label index
     * @return {@code true} iff {@code k} is in this set
     * @throws IndexOutOfBoundsException if {@code k} is negative
     */
    public boolean matchClass(int k) {
        return this.labels.get(k);
    }

    /**
     * Returns the labels as a new, caller-owned mutable set.
     *
     * @return a fresh {@link HashSet} of the label indices
     */
    public Set<Integer> getMatchedLabels() {
        return this.labels.stream().boxed().collect(Collectors.toCollection(HashSet::new));
    }

    /**
     * Returns the number of labels in this set.
     *
     * @return the label count
     */
    public int getNumMatchedLabels() {
        return this.labels.cardinality();
    }

    /**
     * Returns the labels in ascending order.
     *
     * @return a fresh list of label indices, sorted ascending
     */
    public List<Integer> getMatchedLabelsOrdered() {
        // BitSet.stream() already yields indices from lowest to highest; no sort needed.
        return this.labels.stream().boxed().collect(Collectors.toList());
    }

    /**
     * Returns the labels present in either argument.
     *
     * @param multiLabel1 first label set
     * @param multiLabel2 second label set
     * @return a fresh mutable set holding the union
     */
    public static Set<Integer> union(MultiLabel multiLabel1, MultiLabel multiLabel2) {
        Set<Integer> union = new HashSet<>(multiLabel1.getMatchedLabels());
        union.addAll(multiLabel2.getMatchedLabels());
        return union;
    }

    /**
     * Returns the labels present in both arguments.
     *
     * @param multiLabel1 first label set
     * @param multiLabel2 second label set
     * @return a fresh mutable set holding the intersection
     */
    public static Set<Integer> intersection(MultiLabel multiLabel1, MultiLabel multiLabel2) {
        Set<Integer> intersection = new HashSet<>(multiLabel1.getMatchedLabels());
        intersection.retainAll(multiLabel2.getMatchedLabels());
        return intersection;
    }

    /**
     * Returns the labels present in exactly one of the arguments.
     *
     * @param multiLabel1 first label set
     * @param multiLabel2 second label set
     * @return a fresh mutable set holding the symmetric difference
     */
    public static Set<Integer> symmetricDifference(MultiLabel multiLabel1, MultiLabel multiLabel2) {
        Set<Integer> union = MultiLabel.union(multiLabel1, multiLabel2);
        union.removeAll(MultiLabel.intersection(multiLabel1, multiLabel2));
        return union;
    }

    /**
     * Tests whether any label index is {@code >= numClasses}.
     *
     * @param numClasses number of valid classes (valid labels are {@code 0..numClasses-1})
     * @return {@code true} iff some label is outside {@code [0, numClasses)}
     */
    public boolean outOfBound(int numClasses) {
        // Search from numClasses (clamped to 0, since nextSetBit rejects negatives) instead
        // of comparing against numClasses - 1, which overflows for Integer.MIN_VALUE.
        return this.labels.nextSetBit(Math.max(numClasses, 0)) >= 0;
    }

    /**
     * Tests whether every label of this set is also present in {@code superSet}.
     *
     * @param superSet candidate superset
     * @return {@code true} iff this set is a (not necessarily proper) subset of {@code superSet}
     */
    public boolean isSubsetOf(MultiLabel superSet) {
        // A lambda (not a bound method reference) keeps superSet from being dereferenced
        // when this set is empty, matching the original loop's behaviour.
        return this.labels.stream().allMatch(i -> superSet.matchClass(i));
    }

    /**
     * Renders the labels in ascending order, wrapped in braces.
     *
     * @return the labels formatted by {@code ListUtil.toSimpleString}, enclosed in "{" and "}"
     */
    @Override
    public String toString() {
        return "{" + ListUtil.toSimpleString(this.getMatchedLabelsOrdered()) + "}";
    }

    /**
     * Renders the labels in ascending index order, each mapped to its external label.
     *
     * @param labelTranslator translator from internal indices to external labels
     * @return the {@link List#toString()} form of the translated labels
     */
    public String toStringWithExtLabels(LabelTranslator labelTranslator) {
        return this.getMatchedLabelsOrdered().stream()
                .map(labelTranslator::toExtLabel)
                .collect(Collectors.toList())
                .toString();
    }

    /**
     * Two instances are equal iff they have the same runtime class and the same labels.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        MultiLabel that = (MultiLabel) o;
        return this.labels.equals(that.labels);
    }

    /** Hash code derived from the label bits only; consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        return this.labels.hashCode();
    }
}

