/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.simulation;

import edu.neu.ccs.pyramid.dataset.MLClfDataSetBuilder;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.Enumerator;
import edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF;
import edu.neu.ccs.pyramid.multilabel_classification.crf.SamplingPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.crf.SubsetAccPredictor;
import edu.neu.ccs.pyramid.util.Sampling;
import java.util.List;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Factory methods that build small synthetic multi-label data sets for
 * experiments and tests.
 *
 * <p>Most generators follow the same recipe: draw every feature value
 * uniformly from [-1, 1], score each class with a linear weight vector, and
 * attach the class label when the score is non-negative. The individual
 * methods differ in how the weights are chosen and in what kind of label
 * noise (if any) is applied afterwards.
 *
 * <p>All randomness comes from {@link Sampling}, {@link Math#random()} and the
 * Commons Math distributions; none of the generators takes a seed, so results
 * are not reproducible from here.
 */
public class MultiLabelSynthesizer {
    /** Fixed 2-D weight vectors (one row per class) shared by the four-class generators. */
    private static final double[][] PLANAR_WEIGHTS = {
        {0.0, 1.0},
        {1.0, 1.0},
        {1.0, 0.0},
        {1.0, -1.0},
    };

    /** Per-class weights written into the first two features of the CRF used by the crf* generators. */
    private static final double[][] CRF_WEIGHTS = {
        {0.0, 10.0},
        {10.0, 10.0},
        {10.0, 0.0},
        {-10.0, -10.0},
    };

    /**
     * Builds 100 single-feature points over 2 classes: points 0-59 get label 0,
     * points 60-99 get label 1. Despite the name, nothing is random and feature
     * values are never set here.
     *
     * @return the populated data set
     */
    public static MultiLabelClfDataSet randomBinary() {
        MultiLabelClfDataSet dataSet = newDataSet(100, 1, 2);
        addLabelToRange(dataSet, 0, 60, 0);
        addLabelToRange(dataSet, 60, 100, 1);
        return dataSet;
    }

    /**
     * Builds 100 single-feature points over 2 classes: points 0-29 get both
     * labels, points 30-69 get label 0 only, points 70-99 get label 1 only.
     * Nothing is random and feature values are never set here.
     *
     * @return the populated data set
     */
    public static MultiLabelClfDataSet randomTwoLabels() {
        MultiLabelClfDataSet dataSet = newDataSet(100, 1, 2);
        addLabelToRange(dataSet, 0, 30, 0);
        addLabelToRange(dataSet, 0, 30, 1);
        addLabelToRange(dataSet, 30, 70, 0);
        addLabelToRange(dataSet, 70, 100, 1);
        return dataSet;
    }

    /**
     * Builds 100 single-feature points over 3 classes with exactly one label
     * each: 0-29 get label 0, 30-69 get label 1, 70-99 get label 2.
     * Nothing is random and feature values are never set here.
     *
     * @return the populated data set
     */
    public static MultiLabelClfDataSet randomMultiClass() {
        MultiLabelClfDataSet dataSet = newDataSet(100, 1, 3);
        addLabelToRange(dataSet, 0, 30, 0);
        addLabelToRange(dataSet, 30, 70, 1);
        addLabelToRange(dataSet, 70, 100, 2);
        return dataSet;
    }

    /**
     * Generates linearly labelled data with random weights, then toggles
     * exactly one uniformly chosen label on every data point.
     *
     * @param numData    number of data points
     * @param numFeature number of features
     * @param numClass   number of classes
     * @return the populated data set
     */
    public static MultiLabelClfDataSet flipOne(int numData, int numFeature, int numClass) {
        MultiLabelClfDataSet dataSet = newDataSet(numData, numFeature, numClass);
        // Weights are drawn before the features, matching the original draw order.
        Vector[] weights = randomWeights(numClass, numFeature);
        fillUniformFeatures(dataSet, numData, numFeature);
        labelBySign(dataSet, weights, numData);
        MultiLabel[] labels = dataSet.getMultiLabels();
        for (int i = 0; i < numData; ++i) {
            toggle(labels[i], Sampling.intUniform(0, numClass - 1));
        }
        return dataSet;
    }

    /**
     * Like {@link #flipOne(int, int, int)}, but whenever the toggled label is
     * class 0 a second label, drawn uniformly from classes 1..numClass-1, is
     * toggled as well.
     *
     * @param numData    number of data points
     * @param numFeature number of features
     * @param numClass   number of classes; must be at least 2 because the
     *                   second flip samples from the range [1, numClass - 1]
     * @return the populated data set
     * @throws IllegalArgumentException if {@code numClass < 2}
     */
    public static MultiLabelClfDataSet flipTwo(int numData, int numFeature, int numClass) {
        if (numClass < 2) {
            // The companion flip draws from [1, numClass - 1], which is empty otherwise.
            throw new IllegalArgumentException("flipTwo requires numClass >= 2, got " + numClass);
        }
        MultiLabelClfDataSet dataSet = newDataSet(numData, numFeature, numClass);
        Vector[] weights = randomWeights(numClass, numFeature);
        fillUniformFeatures(dataSet, numData, numFeature);
        labelBySign(dataSet, weights, numData);
        MultiLabel[] labels = dataSet.getMultiLabels();
        for (int i = 0; i < numData; ++i) {
            int toChange = Sampling.intUniform(0, numClass - 1);
            toggle(labels[i], toChange);
            if (toChange == 0) {
                toggle(labels[i], Sampling.intUniform(1, numClass - 1));
            }
        }
        return dataSet;
    }

    /**
     * Generates 4-class, 2-feature data labelled by the fixed planar weights,
     * then toggles one label per point, choosing class 0 with probability 0.4
     * and each of classes 1-3 with probability 0.2.
     *
     * @param numData number of data points
     * @return the populated data set
     */
    public static MultiLabelClfDataSet flipOneNonUniform(int numData) {
        int numClass = 4;
        int numFeature = 2;
        MultiLabelClfDataSet dataSet = newDataSet(numData, numFeature, numClass);
        Vector[] weights = denseWeights(PLANAR_WEIGHTS);
        fillUniformFeatures(dataSet, numData, numFeature);
        labelBySign(dataSet, weights, numData);
        int[] classes = new int[]{0, 1, 2, 3};
        double[] flipProbs = new double[]{0.4, 0.2, 0.2, 0.2};
        EnumeratedIntegerDistribution flipChoice = new EnumeratedIntegerDistribution(classes, flipProbs);
        MultiLabel[] labels = dataSet.getMultiLabels();
        for (int i = 0; i < numData; ++i) {
            toggle(labels[i], flipChoice.sample());
        }
        return dataSet;
    }

    /**
     * Generates 10000 points with 2 features and 2 classes from a mixture of
     * two linear labellers. Each point is assigned to cluster 0 with
     * probability 0.4 or cluster 1 with probability 0.6 and is labelled with
     * that cluster's weight vectors.
     *
     * <p>The per-sample trace output that earlier versions wrote to standard
     * output (several lines per point and class) has been removed.
     *
     * @return the populated data set
     */
    public static MultiLabelClfDataSet sampleFromMix() {
        int numData = 10000;
        int numClass = 2;
        int numFeature = 2;
        double[] proportions = new double[]{0.4, 0.6};
        int[] clusterIds = new int[]{0, 1};
        MultiLabelClfDataSet dataSet = newDataSet(numData, numFeature, numClass);
        // Cluster 0 uses the first two planar weight vectors, cluster 1 the last two.
        Vector[][] weights = new Vector[][]{
            denseWeights(new double[][]{PLANAR_WEIGHTS[0], PLANAR_WEIGHTS[1]}),
            denseWeights(new double[][]{PLANAR_WEIGHTS[2], PLANAR_WEIGHTS[3]}),
        };
        fillUniformFeatures(dataSet, numData, numFeature);
        EnumeratedIntegerDistribution clusterChoice = new EnumeratedIntegerDistribution(clusterIds, proportions);
        for (int i = 0; i < numData; ++i) {
            int cluster = clusterChoice.sample();
            for (int l = 0; l < numClass; ++l) {
                double dot = weights[cluster][l].dot(dataSet.getRow(i));
                if (dot >= 0.0) {
                    dataSet.addLabel(i, l);
                }
            }
        }
        return dataSet;
    }

    /**
     * Generates 10000 points with 2 features and 4 classes. Each class score is
     * the planar linear score plus independent N(0, 0.1) noise (second argument
     * is the standard deviation), and the label is attached when the noisy
     * score is non-negative. Prints how many (point, class) scores changed sign
     * because of the noise.
     *
     * @return the populated data set
     */
    public static MultiLabelClfDataSet independentNoise() {
        int numData = 10000;
        int numClass = 4;
        int numFeature = 2;
        MultiLabelClfDataSet dataSet = newDataSet(numData, numFeature, numClass);
        Vector[] weights = denseWeights(PLANAR_WEIGHTS);
        fillUniformFeatures(dataSet, numData, numFeature);
        // One distribution object per class, as in the original, so each class has its own generator.
        NormalDistribution[] noises = new NormalDistribution[numClass];
        for (int k = 0; k < numClass; ++k) {
            noises[k] = new NormalDistribution(0.0, 0.1);
        }
        int numFlipped = 0;
        for (int i = 0; i < numData; ++i) {
            for (int k = 0; k < numClass; ++k) {
                double dot = weights[k].dot(dataSet.getRow(i));
                if (labelWithNoise(dataSet, i, k, dot, noises[k].sample())) {
                    ++numFlipped;
                }
            }
        }
        System.out.println("number of flipped = " + numFlipped);
        return dataSet;
    }

    /**
     * Generates 3-class, 3-feature data where class 0 reads feature 1, class 1
     * reads feature 0 and class 2 reads feature 2. One zero-mean correlated
     * Gaussian noise vector is drawn per point and added to the class scores
     * before thresholding at zero. Prints how many (point, class) scores
     * changed sign because of the noise.
     *
     * @param numData number of data points
     * @return the populated data set
     */
    public static MultiLabelClfDataSet gaussianNoise(int numData) {
        int numClass = 3;
        int numFeature = 3;
        MultiLabelClfDataSet dataSet = newDataSet(numData, numFeature, numClass);
        Vector[] weights = denseWeights(new double[][]{
            {0.0, 1.0, 0.0},
            {1.0, 0.0, 0.0},
            {0.0, 0.0, 1.0},
        });
        fillUniformFeatures(dataSet, numData, numFeature);
        double[] means = new double[numClass];
        double[][] covars = new double[][]{
            {0.5, 0.02, -0.03},
            {0.02, 0.2, -0.03},
            {-0.03, -0.03, 0.3},
        };
        MultivariateNormalDistribution noiseDistribution = new MultivariateNormalDistribution(means, covars);
        int numFlipped = 0;
        for (int i = 0; i < numData; ++i) {
            double[] noises = noiseDistribution.sample();
            for (int k = 0; k < numClass; ++k) {
                double dot = weights[k].dot(dataSet.getRow(i));
                if (labelWithNoise(dataSet, i, k, dot, noises[k])) {
                    ++numFlipped;
                }
            }
        }
        System.out.println("number of flipped bits = " + numFlipped);
        return dataSet;
    }

    /**
     * Generates 10000 noise-free points with 2 features and 4 classes, each
     * class labelled independently by the sign of its planar linear score.
     *
     * @return the populated data set
     */
    public static MultiLabelClfDataSet independent() {
        int numData = 10000;
        int numClass = 4;
        int numFeature = 2;
        MultiLabelClfDataSet dataSet = newDataSet(numData, numFeature, numClass);
        Vector[] weights = denseWeights(PLANAR_WEIGHTS);
        fillUniformFeatures(dataSet, numData, numFeature);
        labelBySign(dataSet, weights, numData);
        return dataSet;
    }

    /**
     * Generates 10000 points with 2 features and 4 classes whose label sets are
     * produced by a {@link SamplingPredictor} wrapped around the fixed CRF.
     *
     * @return the populated data set
     */
    public static MultiLabelClfDataSet crfSample() {
        int numData = 10000;
        int numClass = 4;
        int numFeature = 2;
        MultiLabelClfDataSet dataSet = newDataSet(numData, numFeature, numClass);
        CMLCRF cmlcrf = buildFixedCrf(numClass, numFeature);
        fillUniformFeatures(dataSet, numData, numFeature);
        SamplingPredictor samplingPredictor = new SamplingPredictor(cmlcrf);
        for (int i = 0; i < numData; ++i) {
            dataSet.setLabels(i, samplingPredictor.predict(dataSet.getRow(i)));
        }
        return dataSet;
    }

    /**
     * Generates 1000 points with 10 features and 4 classes whose label sets are
     * the predictions of a {@link SubsetAccPredictor} wrapped around the fixed
     * CRF. Only the first two features carry explicitly assigned weights.
     *
     * @return the populated data set
     */
    public static MultiLabelClfDataSet crfArgmax() {
        int numData = 1000;
        int numClass = 4;
        int numFeature = 10;
        MultiLabelClfDataSet dataSet = newDataSet(numData, numFeature, numClass);
        CMLCRF cmlcrf = buildFixedCrf(numClass, numFeature);
        fillUniformFeatures(dataSet, numData, numFeature);
        SubsetAccPredictor predictor = new SubsetAccPredictor(cmlcrf);
        for (int i = 0; i < numData; ++i) {
            dataSet.setLabels(i, predictor.predict(dataSet.getRow(i)));
        }
        return dataSet;
    }

    /**
     * Generates 10000 points with 2 features and 4 classes labelled by a
     * {@link SubsetAccPredictor} on the fixed CRF, then overwrites feature 0 of
     * every point with 0.0 so that the feature the labels depended on is hidden
     * from consumers of the data set.
     *
     * @return the populated data set
     */
    public static MultiLabelClfDataSet crfArgmaxHide() {
        int numData = 10000;
        int numClass = 4;
        int numFeature = 2;
        MultiLabelClfDataSet dataSet = newDataSet(numData, numFeature, numClass);
        CMLCRF cmlcrf = buildFixedCrf(numClass, numFeature);
        fillUniformFeatures(dataSet, numData, numFeature);
        SubsetAccPredictor predictor = new SubsetAccPredictor(cmlcrf);
        for (int i = 0; i < numData; ++i) {
            dataSet.setLabels(i, predictor.predict(dataSet.getRow(i)));
        }
        // Labels were computed from the full feature vector; only now is feature 0 erased.
        for (int i = 0; i < numData; ++i) {
            dataSet.setFeatureValue(i, 0, 0.0);
        }
        return dataSet;
    }

    /**
     * Generates 1000 points with 10 features and 4 classes labelled by a
     * {@link SubsetAccPredictor} on the fixed CRF, then randomly drops predicted
     * labels. For class {@code l} a predicted label is removed when
     * {@code Math.random() > alphas[l]} with alphas {1.0, 0.9, 0.8, 0.7}, so
     * class 0 is never dropped.
     *
     * @return the populated data set
     */
    public static MultiLabelClfDataSet crfArgmaxDrop() {
        int numData = 1000;
        int numClass = 4;
        int numFeature = 10;
        MultiLabelClfDataSet dataSet = newDataSet(numData, numFeature, numClass);
        CMLCRF cmlcrf = buildFixedCrf(numClass, numFeature);
        fillUniformFeatures(dataSet, numData, numFeature);
        SubsetAccPredictor predictor = new SubsetAccPredictor(cmlcrf);
        double[] alphas = new double[]{1.0, 0.9, 0.8, 0.7};
        for (int i = 0; i < numData; ++i) {
            // Work on a copy so the predictor's returned label object is not mutated.
            MultiLabel label = predictor.predict(dataSet.getRow(i)).copy();
            for (int l = 0; l < numClass; ++l) {
                // Math.random() is evaluated first for every class, as in the original.
                if (Math.random() > alphas[l] && label.matchClass(l)) {
                    label.removeLabel(l);
                }
            }
            dataSet.setLabels(i, label);
        }
        return dataSet;
    }

    /** Creates an empty data set of the requested shape through the builder. */
    private static MultiLabelClfDataSet newDataSet(int numData, int numFeature, int numClass) {
        return MLClfDataSetBuilder.getBuilder()
                .numFeatures(numFeature)
                .numClasses(numClass)
                .numDataPoints(numData)
                .build();
    }

    /** Adds {@code label} to every data point with index in [{@code from}, {@code to}). */
    private static void addLabelToRange(MultiLabelClfDataSet dataSet, int from, int to, int label) {
        for (int i = from; i < to; ++i) {
            dataSet.addLabel(i, label);
        }
    }

    /** Sets every feature of the first {@code numData} points to a uniform draw from [-1, 1], row by row. */
    private static void fillUniformFeatures(MultiLabelClfDataSet dataSet, int numData, int numFeature) {
        for (int i = 0; i < numData; ++i) {
            for (int j = 0; j < numFeature; ++j) {
                dataSet.setFeatureValue(i, j, Sampling.doubleUniform(-1.0, 1.0));
            }
        }
    }

    /** Draws one weight vector per class with entries uniform in [-1, 1]. */
    private static Vector[] randomWeights(int numClass, int numFeature) {
        Vector[] weights = new Vector[numClass];
        for (int k = 0; k < numClass; ++k) {
            DenseVector vector = new DenseVector(numFeature);
            for (int j = 0; j < numFeature; ++j) {
                vector.set(j, Sampling.doubleUniform(-1.0, 1.0));
            }
            weights[k] = vector;
        }
        return weights;
    }

    /** Copies a table of per-class weights into freshly allocated dense vectors, one per row. */
    private static Vector[] denseWeights(double[][] rows) {
        Vector[] weights = new Vector[rows.length];
        for (int k = 0; k < rows.length; ++k) {
            DenseVector vector = new DenseVector(rows[k].length);
            for (int j = 0; j < rows[k].length; ++j) {
                vector.set(j, rows[k][j]);
            }
            weights[k] = vector;
        }
        return weights;
    }

    /** Attaches class k to point i whenever {@code weights[k] . row(i) >= 0}. */
    private static void labelBySign(MultiLabelClfDataSet dataSet, Vector[] weights, int numData) {
        for (int i = 0; i < numData; ++i) {
            for (int k = 0; k < weights.length; ++k) {
                double dot = weights[k].dot(dataSet.getRow(i));
                if (dot >= 0.0) {
                    dataSet.addLabel(i, k);
                }
            }
        }
    }

    /**
     * Attaches class k to point i when {@code dot + noise >= 0}.
     *
     * @return {@code true} when the noise changed the sign of the score
     *         (i.e. {@code dot * (dot + noise) < 0})
     */
    private static boolean labelWithNoise(MultiLabelClfDataSet dataSet, int i, int k, double dot, double noise) {
        double score = dot + noise;
        if (score >= 0.0) {
            dataSet.addLabel(i, k);
        }
        return dot * score < 0.0;
    }

    /** Removes {@code classIndex} from the label set if present, otherwise adds it. */
    private static void toggle(MultiLabel label, int classIndex) {
        if (label.matchClass(classIndex)) {
            label.removeLabel(classIndex);
        } else {
            label.addLabel(classIndex);
        }
    }

    /**
     * Builds a CRF over every label combination returned by
     * {@link Enumerator#enumerate(int)} and writes the fixed weight table into
     * features 0 and 1 of classes 0-3. All callers in this class pass
     * {@code numClass == 4}.
     */
    private static CMLCRF buildFixedCrf(int numClass, int numFeature) {
        List<MultiLabel> support = Enumerator.enumerate(numClass);
        CMLCRF cmlcrf = new CMLCRF(numClass, numFeature, support);
        for (int k = 0; k < CRF_WEIGHTS.length; ++k) {
            for (int j = 0; j < CRF_WEIGHTS[k].length; ++j) {
                cmlcrf.getWeights().getWeightsWithoutBiasForClass(k).set(j, CRF_WEIGHTS[k][j]);
            }
        }
        return cmlcrf;
    }
}

