/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.dataset;

import edu.neu.ccs.pyramid.dataset.AbstractDataSet;
import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.ClfDataSetBuilder;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.DataSetBuilder;
import edu.neu.ccs.pyramid.dataset.DenseClfDataSet;
import edu.neu.ccs.pyramid.dataset.DenseMLClfDataSet;
import edu.neu.ccs.pyramid.dataset.IdTranslator;
import edu.neu.ccs.pyramid.dataset.LabelTranslator;
import edu.neu.ccs.pyramid.dataset.MLClfDataSetBuilder;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.dataset.RegDataSet;
import edu.neu.ccs.pyramid.dataset.RegDataSetBuilder;
import edu.neu.ccs.pyramid.dataset.SparseClfDataSet;
import edu.neu.ccs.pyramid.dataset.SparseMLClfDataSet;
import edu.neu.ccs.pyramid.feature.Feature;
import edu.neu.ccs.pyramid.feature.FeatureList;
import edu.neu.ccs.pyramid.util.Pair;
import edu.neu.ccs.pyramid.util.Sampling;
import edu.neu.ccs.pyramid.util.SetUtil;
import edu.neu.ccs.pyramid.util.Translator;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.mahout.math.Vector;

public class DataSetUtil {
    /**
     * Creates a new classification data set with the same dimensions, density and
     * feature values as {@code dataSet}, but constructed with {@code numClasses} classes.
     *
     * <p>Only feature values are copied here. Labels, translators and the feature list
     * are not transferred by this method.
     *
     * @param dataSet    source of the dimensions, density flag and feature values
     * @param numClasses number of classes passed to the new data set's constructor
     * @return a new dense or sparse data set, following {@code dataSet.isDense()}
     */
    public static ClfDataSet changeLabels(ClfDataSet dataSet, int numClasses) {
        int numDataPoints = dataSet.getNumDataPoints();
        int numFeatures = dataSet.getNumFeatures();
        boolean missingValue = dataSet.hasMissingValue();
        // Declared as ClfDataSet (not a broader base class) because that is the
        // type this method promises to return.
        ClfDataSet dataSet1 = dataSet.isDense()
                ? new DenseClfDataSet(numDataPoints, numFeatures, missingValue, numClasses)
                : new SparseClfDataSet(numDataPoints, numFeatures, missingValue, numClasses);
        for (int i = 0; i < numDataPoints; ++i) {
            // Only the non-zero entries of each row are visited and copied.
            Vector vector = dataSet.getRow(i);
            for (Vector.Element element : vector.nonZeroes()) {
                int featureIndex = element.index();
                double value = element.get();
                // Guard kept from the original: ignore anything outside the feature range.
                if (featureIndex >= numFeatures) {
                    continue;
                }
                dataSet1.setFeatureValue(i, featureIndex, value);
            }
        }
        return dataSet1;
    }

    /**
     * Builds a new classification data set that contains only the requested feature columns.
     *
     * <p>Column {@code j} of the result is column {@code columnsToKeep.get(j)} of
     * {@code dataSet}. Labels, the label translator and the id translator are carried
     * over, and a new feature list is assembled from the kept features.
     *
     * <p>NOTE(review): the kept {@link Feature} objects are taken from the source feature
     * list and re-indexed in place via {@code setIndex}; unless {@code getAll()} hands out
     * copies, the source data set's features will observe the new indices too. Confirm
     * this is intended.
     *
     * @param dataSet       data set to take columns from
     * @param columnsToKeep indices of the source columns, in the order they should appear
     * @return the trimmed data set, dense or sparse following {@code dataSet.isDense()}
     */
    public static ClfDataSet sampleFeatures(ClfDataSet dataSet, List<Integer> columnsToKeep) {
        int numClasses = dataSet.getNumClasses();
        boolean missingValue = dataSet.hasMissingValue();
        // Declared as ClfDataSet: label setters are called on it and it is the return value.
        ClfDataSet trimmed = dataSet.isDense()
                ? new DenseClfDataSet(dataSet.getNumDataPoints(), columnsToKeep.size(), missingValue, numClasses)
                : new SparseClfDataSet(dataSet.getNumDataPoints(), columnsToKeep.size(), missingValue, numClasses);
        for (int j = 0; j < trimmed.getNumFeatures(); ++j) {
            int oldColumnIndex = columnsToKeep.get(j);
            // Copy only the non-zero entries of the source column.
            Vector vector = dataSet.getColumn(oldColumnIndex);
            for (Vector.Element element : vector.nonZeroes()) {
                int dataPointIndex = element.index();
                double value = element.get();
                trimmed.setFeatureValue(dataPointIndex, j, value);
            }
        }
        int[] labels = dataSet.getLabels();
        for (int i = 0; i < trimmed.getNumDataPoints(); ++i) {
            trimmed.setLabel(i, labels[i]);
        }
        trimmed.setLabelTranslator(dataSet.getLabelTranslator());
        trimmed.setIdTranslator(dataSet.getIdTranslator());
        List<Feature> oldFeatures = dataSet.getFeatureList().getAll();
        List<Feature> newFeatures = columnsToKeep.stream().map(oldFeatures::get).collect(Collectors.toList());
        // Re-number the kept features so their index matches their new column position.
        for (int i = 0; i < newFeatures.size(); ++i) {
            newFeatures.get(i).setIndex(i);
        }
        trimmed.setFeatureList(new FeatureList(newFeatures));
        return trimmed;
    }

    /**
     * Produces a regression data set restricted to the given feature columns.
     *
     * <p>The j-th column of the result is the source column {@code columnsToKeep.get(j)}.
     * Labels and the id translator are copied across; the feature list is rebuilt from
     * the kept features, which are re-indexed (in place) to their new positions.
     *
     * @param dataSet       regression data set to take columns from
     * @param columnsToKeep source column indices, in output order
     * @return the trimmed regression data set
     */
    public static RegDataSet sampleFeatures(RegDataSet dataSet, List<Integer> columnsToKeep) {
        RegDataSet trimmed = RegDataSetBuilder.getBuilder()
                .numDataPoints(dataSet.getNumDataPoints())
                .numFeatures(columnsToKeep.size())
                .missingValue(dataSet.hasMissingValue())
                .dense(dataSet.isDense())
                .build();

        // Feature values: walk the non-zero cells of every kept source column.
        for (int newColumn = 0; newColumn < trimmed.getNumFeatures(); newColumn++) {
            int sourceColumn = columnsToKeep.get(newColumn);
            for (Vector.Element cell : dataSet.getColumn(sourceColumn).nonZeroes()) {
                trimmed.setFeatureValue(cell.index(), newColumn, cell.get());
            }
        }

        // Labels are row-aligned, so they transfer one to one.
        double[] sourceLabels = dataSet.getLabels();
        for (int row = 0; row < trimmed.getNumDataPoints(); row++) {
            trimmed.setLabel(row, sourceLabels[row]);
        }
        trimmed.setIdTranslator(dataSet.getIdTranslator());

        // Feature metadata: pick the kept features first, then renumber them.
        List<Feature> sourceFeatures = dataSet.getFeatureList().getAll();
        List<Feature> keptFeatures = columnsToKeep.stream()
                .map(sourceFeatures::get)
                .collect(Collectors.toList());
        int position = 0;
        for (Feature kept : keptFeatures) {
            kept.setIndex(position);
            position++;
        }
        trimmed.setFeatureList(new FeatureList(keptFeatures));
        return trimmed;
    }

    /**
     * Builds a new multi-label data set that contains only the requested feature columns.
     *
     * <p>Column {@code j} of the result is column {@code columnsToKeep.get(j)} of
     * {@code dataSet}. The matched labels of every data point, the label translator and
     * the id translator are carried over, and a new feature list is assembled from the
     * kept features.
     *
     * <p>NOTE(review): the kept {@link Feature} objects come from the source feature list
     * and are re-indexed in place via {@code setIndex}; unless {@code getAll()} hands out
     * copies, the source data set's features will observe the new indices too.
     *
     * @param dataSet       data set to take columns from
     * @param columnsToKeep indices of the source columns, in the order they should appear
     * @return the trimmed data set, dense or sparse following {@code dataSet.isDense()}
     */
    public static MultiLabelClfDataSet sampleFeatures(MultiLabelClfDataSet dataSet, List<Integer> columnsToKeep) {
        boolean missingValue = dataSet.hasMissingValue();
        int numClasses = dataSet.getNumClasses();
        // Declared as MultiLabelClfDataSet: label methods are called on it and it is the return value.
        MultiLabelClfDataSet trimmed = dataSet.isDense()
                ? new DenseMLClfDataSet(dataSet.getNumDataPoints(), columnsToKeep.size(), missingValue, numClasses)
                : new SparseMLClfDataSet(dataSet.getNumDataPoints(), columnsToKeep.size(), missingValue, numClasses);
        for (int j = 0; j < trimmed.getNumFeatures(); ++j) {
            int oldColumnIndex = columnsToKeep.get(j);
            // Copy only the non-zero entries of the source column.
            Vector vector = dataSet.getColumn(oldColumnIndex);
            for (Vector.Element element : vector.nonZeroes()) {
                int dataPointIndex = element.index();
                double value = element.get();
                trimmed.setFeatureValue(dataPointIndex, j, value);
            }
        }
        MultiLabel[] multiLabels = dataSet.getMultiLabels();
        for (int i = 0; i < trimmed.getNumDataPoints(); ++i) {
            trimmed.addLabels(i, multiLabels[i].getMatchedLabels());
        }
        trimmed.setLabelTranslator(dataSet.getLabelTranslator());
        trimmed.setIdTranslator(dataSet.getIdTranslator());
        List<Feature> oldFeatures = dataSet.getFeatureList().getAll();
        List<Feature> newFeatures = columnsToKeep.stream().map(oldFeatures::get).collect(Collectors.toList());
        // Re-number the kept features so their index matches their new column position.
        for (int i = 0; i < newFeatures.size(); ++i) {
            newFeatures.get(i).setIndex(i);
        }
        trimmed.setFeatureList(new FeatureList(newFeatures));
        return trimmed;
    }

    /**
     * Keeps the leading {@code numFeatures} columns (indices {@code 0 .. numFeatures-1})
     * of a classification data set by delegating to the list-based overload.
     *
     * @param clfDataSet  data set to trim
     * @param numFeatures number of leading columns to keep
     * @return the trimmed data set
     */
    public static ClfDataSet sampleFeatures(ClfDataSet clfDataSet, int numFeatures) {
        List<Integer> leadingColumns = new ArrayList<>();
        for (int column = 0; column < numFeatures; column++) {
            leadingColumns.add(column);
        }
        return DataSetUtil.sampleFeatures(clfDataSet, leadingColumns);
    }

    /**
     * Copies columns {@code start..end} (inclusive, zero-based) of every line of a
     * delimited text file into a new file, using the same delimiter between the
     * copied columns.
     *
     * <p>The delimiter is matched literally, not as a regular expression.
     *
     * @param inputFile  path of the file to read
     * @param outputFile path of the file to write (overwritten)
     * @param start      index of the first column to keep
     * @param end        index of the last column to keep
     * @param delimiter  column separator used for both reading and writing
     * @throws IOException                    if reading or writing fails
     * @throws ArrayIndexOutOfBoundsException if a line has fewer than {@code end + 1} columns
     */
    public static void extractColumns(String inputFile, String outputFile, int start, int end, String delimiter) throws IOException {
        // Quote once up front so characters such as '|' or '.' are treated literally.
        Pattern separator = Pattern.compile(Pattern.quote(delimiter));
        try (BufferedReader br = new BufferedReader(new FileReader(new File(inputFile)));
             BufferedWriter bw = new BufferedWriter(new FileWriter(new File(outputFile)))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] split = separator.split(line);
                for (int i = start; i <= end; ++i) {
                    bw.write(split[i]);
                    // No trailing delimiter after the last kept column.
                    if (i < end) {
                        bw.write(delimiter);
                    }
                }
                bw.newLine();
            }
        }
    }

    /**
     * Copies columns {@code start..end} (inclusive, zero-based) of every line of a text
     * file into a new file. Each line is trimmed, split with {@code pattern}, and the
     * kept columns are written separated by commas.
     *
     * <p>The compiled pattern is used directly, so any flags it was compiled with
     * (for example {@link Pattern#CASE_INSENSITIVE}) are honoured.
     *
     * @param inputFile  path of the file to read
     * @param outputFile path of the file to write (overwritten)
     * @param start      index of the first column to keep
     * @param end        index of the last column to keep
     * @param pattern    separator pattern used to split each input line
     * @throws IOException                    if reading or writing fails
     * @throws ArrayIndexOutOfBoundsException if a line has fewer than {@code end + 1} columns
     */
    public static void extractColumns(String inputFile, String outputFile, int start, int end, Pattern pattern) throws IOException {
        try (BufferedReader br = new BufferedReader(new FileReader(new File(inputFile)));
             BufferedWriter bw = new BufferedWriter(new FileWriter(new File(outputFile)))) {
            String line;
            while ((line = br.readLine()) != null) {
                // Split with the Pattern itself; going through pattern.pattern() would
                // recompile the regex for every line and lose the pattern's flags.
                String[] split = pattern.split(line.trim());
                for (int i = start; i <= end; ++i) {
                    bw.write(split[i]);
                    // No trailing comma after the last kept column.
                    if (i < end) {
                        bw.write(",");
                    }
                }
                bw.newLine();
            }
        }
    }

    /**
     * Draws a per-class bootstrap sample of a classification data set.
     *
     * <p>Row indices are grouped by label; for every label,
     * {@code Sampling.sampleWithReplacement} is asked for as many draws as the class has
     * rows, and the union of all draws is materialised with
     * {@link #sampleData(ClfDataSet, List)}. The result therefore has the same number of
     * rows per class as the input (assuming the sampler returns the requested number of
     * draws — confirm against {@code Sampling}).
     *
     * @param clfDataSet data set to resample
     * @return a new data set built from the sampled rows
     */
    public static ClfDataSet bootstrap(ClfDataSet clfDataSet) {
        int numDataPoints = clfDataSet.getNumDataPoints();
        int[] labels = clfDataSet.getLabels();
        // label -> indices of the rows carrying that label
        Map<Integer, List<Integer>> labelIndicesMap = new HashMap<>();
        for (int i = 0; i < numDataPoints; ++i) {
            labelIndicesMap.computeIfAbsent(labels[i], label -> new ArrayList<>()).add(i);
        }
        List<Integer> sampledIndices = new ArrayList<>(numDataPoints);
        for (List<Integer> indices : labelIndicesMap.values()) {
            int[] sampleForClass = Sampling.sampleWithReplacement(indices.size(), indices).toArray();
            for (int index : sampleForClass) {
                sampledIndices.add(index);
            }
        }
        return DataSetUtil.sampleData(clfDataSet, sampledIndices);
    }

    /**
     * Builds a new classification data set from the given rows of {@code dataSet}.
     *
     * <p>Row {@code i} of the result is row {@code indices.get(i)} of the source, so an
     * index may appear several times. The label translator and feature list are shared
     * with the source; no id translator is set on the result by this method.
     *
     * @param dataSet data set to take rows from
     * @param indices source row indices, in output order
     * @return the sampled data set; dense when {@code dataSet} is a {@link DenseClfDataSet},
     *         sparse otherwise
     */
    public static ClfDataSet sampleData(ClfDataSet dataSet, List<Integer> indices) {
        int numClasses = dataSet.getNumClasses();
        boolean missingValue = dataSet.hasMissingValue();
        // Declared as ClfDataSet: label setters are called on it and it is the return value.
        ClfDataSet sample = dataSet instanceof DenseClfDataSet
                ? new DenseClfDataSet(indices.size(), dataSet.getNumFeatures(), missingValue, numClasses)
                : new SparseClfDataSet(indices.size(), dataSet.getNumFeatures(), missingValue, numClasses);
        int[] labels = dataSet.getLabels();
        for (int i = 0; i < indices.size(); ++i) {
            int indexInOld = indices.get(i);
            Vector oldVector = dataSet.getRow(indexInOld);
            int label = labels[indexInOld];
            sample.setLabel(i, label);
            // Copy only the non-zero entries of the source row.
            for (Vector.Element element : oldVector.nonZeroes()) {
                sample.setFeatureValue(i, element.index(), element.get());
            }
        }
        sample.setLabelTranslator(dataSet.getLabelTranslator());
        sample.setFeatureList(dataSet.getFeatureList());
        return sample;
    }

    /**
     * Assembles a regression data set out of selected rows of {@code dataSet}.
     *
     * <p>Output row {@code i} mirrors source row {@code indices.get(i)} (label and
     * non-zero feature values). The feature list is shared with the source.
     *
     * @param dataSet regression data set to take rows from
     * @param indices source row indices, in output order
     * @return the sampled regression data set
     */
    public static RegDataSet sampleData(RegDataSet dataSet, List<Integer> indices) {
        RegDataSet sample = RegDataSetBuilder.getBuilder()
                .numDataPoints(indices.size())
                .numFeatures(dataSet.getNumFeatures())
                .missingValue(dataSet.hasMissingValue())
                .dense(dataSet.isDense())
                .build();
        double[] sourceLabels = dataSet.getLabels();
        int targetRow = 0;
        while (targetRow < indices.size()) {
            int sourceRow = indices.get(targetRow);
            Vector sourceVector = dataSet.getRow(sourceRow);
            sample.setLabel(targetRow, sourceLabels[sourceRow]);
            for (Vector.Element cell : sourceVector.nonZeroes()) {
                sample.setFeatureValue(targetRow, cell.index(), cell.get());
            }
            targetRow++;
        }
        sample.setFeatureList(dataSet.getFeatureList());
        return sample;
    }

    /**
     * Samples rows from a data set together with the matching rows of a per-row target
     * distribution.
     *
     * <p>Output row {@code i} holds the non-zero feature values of source row
     * {@code indices.get(i)}, and {@code sampledTargets[i]} is a copy of
     * {@code targetDistribution[indices.get(i)]}. The feature list is shared with the
     * source.
     *
     * @param dataSet            data set to take rows from
     * @param targetDistribution one target row per data point of {@code dataSet}
     * @param indices            source row indices, in output order
     * @return the sampled data set paired with the sampled target rows
     */
    public static Pair<DataSet, double[][]> sampleData(DataSet dataSet, double[][] targetDistribution, List<Integer> indices) {
        // Every row is replaced by a copy below, so only the outer dimension is allocated.
        // This also avoids reading targetDistribution[0], which fails for an empty input.
        double[][] sampledTargets = new double[indices.size()][];
        DataSet sample = DataSetBuilder.getBuilder().dense(dataSet.isDense()).missingValue(dataSet.hasMissingValue()).numDataPoints(indices.size()).numFeatures(dataSet.getNumFeatures()).build();
        for (int i = 0; i < indices.size(); ++i) {
            int indexInOld = indices.get(i);
            Vector oldVector = dataSet.getRow(indexInOld);
            double[] targets = targetDistribution[indexInOld];
            // Defensive copy so the sample does not alias the caller's rows.
            sampledTargets[i] = Arrays.copyOf(targets, targets.length);
            for (Vector.Element element : oldVector.nonZeroes()) {
                sample.setFeatureValue(i, element.index(), element.get());
            }
        }
        sample.setFeatureList(dataSet.getFeatureList());
        return new Pair<DataSet, double[][]>(sample, sampledTargets);
    }

    /**
     * Assembles a multi-label data set out of selected rows of {@code dataSet}.
     *
     * <p>Output row {@code i} mirrors source row {@code indices.get(i)}: its matched
     * labels, its non-zero feature values, and its external id (registered under the
     * new row number in a fresh {@link IdTranslator}). The feature list and label
     * translator are shared with the source.
     *
     * @param dataSet multi-label data set to take rows from
     * @param indices source row indices, in output order
     * @return the sampled multi-label data set
     */
    public static MultiLabelClfDataSet sampleData(MultiLabelClfDataSet dataSet, List<Integer> indices) {
        MultiLabelClfDataSet sample = MLClfDataSetBuilder.getBuilder()
                .numClasses(dataSet.getNumClasses())
                .numDataPoints(indices.size())
                .numFeatures(dataSet.getNumFeatures())
                .missingValue(dataSet.hasMissingValue())
                .density(dataSet.density())
                .build();
        MultiLabel[] sourceLabels = dataSet.getMultiLabels();
        IdTranslator sampledIds = new IdTranslator();
        int targetRow = 0;
        for (Integer boxedSourceRow : indices) {
            int sourceRow = boxedSourceRow;
            // Re-register the external id under the row's new position.
            sampledIds.addData(targetRow, dataSet.getIdTranslator().toExtId(sourceRow));
            Vector sourceVector = dataSet.getRow(sourceRow);
            sample.addLabels(targetRow, sourceLabels[sourceRow].getMatchedLabels());
            for (Vector.Element cell : sourceVector.nonZeroes()) {
                sample.setFeatureValue(targetRow, cell.index(), cell.get());
            }
            targetRow++;
        }
        sample.setFeatureList(dataSet.getFeatureList());
        sample.setIdTranslator(sampledIds);
        sample.setLabelTranslator(dataSet.getLabelTranslator());
        return sample;
    }

    /**
     * Places the feature columns of {@code dataSet2} to the right of those of
     * {@code dataSet1} in a newly built classification data set.
     *
     * <p>Row count, class count, density, missing-value flag, labels, label translator
     * and id translator all come from {@code dataSet1}. The feature list is the features
     * of {@code dataSet1} followed by those of {@code dataSet2}, added as they are.
     *
     * @param dataSet1 left-hand data set, also the source of labels and translators
     * @param dataSet2 right-hand data set
     * @return the column-wise concatenation
     */
    public static ClfDataSet concatenateByColumn(ClfDataSet dataSet1, ClfDataSet dataSet2) {
        int rowCount = dataSet1.getNumDataPoints();
        int leftWidth = dataSet1.getNumFeatures();
        int rightWidth = dataSet2.getNumFeatures();
        ClfDataSet merged = ClfDataSetBuilder.getBuilder()
                .numDataPoints(rowCount)
                .numFeatures(leftWidth + rightWidth)
                .numClasses(dataSet1.getNumClasses())
                .dense(dataSet1.isDense())
                .missingValue(dataSet1.hasMissingValue())
                .build();

        // Left block keeps its column numbers.
        for (int column = 0; column < leftWidth; column++) {
            for (Vector.Element cell : dataSet1.getColumn(column).nonZeroes()) {
                merged.setFeatureValue(cell.index(), column, cell.get());
            }
        }
        // Right block is shifted past the left block.
        for (int column = 0; column < rightWidth; column++) {
            int shiftedColumn = leftWidth + column;
            for (Vector.Element cell : dataSet2.getColumn(column).nonZeroes()) {
                merged.setFeatureValue(cell.index(), shiftedColumn, cell.get());
            }
        }

        int[] sourceLabels = dataSet1.getLabels();
        for (int row = 0; row < rowCount; row++) {
            merged.setLabel(row, sourceLabels[row]);
        }

        FeatureList combined = new FeatureList();
        dataSet1.getFeatureList().getAll().forEach(combined::add);
        dataSet2.getFeatureList().getAll().forEach(combined::add);
        merged.setFeatureList(combined);
        merged.setLabelTranslator(dataSet1.getLabelTranslator());
        merged.setIdTranslator(dataSet1.getIdTranslator());
        return merged;
    }

    /**
     * Stacks the rows of {@code dataSet2} below those of {@code dataSet1} in a newly
     * built classification data set.
     *
     * <p>Feature count, class count, density, missing-value flag, feature list and label
     * translator come from {@code dataSet1}. No id translator is set on the result by
     * this method.
     *
     * @param dataSet1 data set providing the first rows and the shared metadata
     * @param dataSet2 data set providing the remaining rows
     * @return the row-wise concatenation
     */
    public static ClfDataSet concatenateByRow(ClfDataSet dataSet1, ClfDataSet dataSet2) {
        int numDataPoints1 = dataSet1.getNumDataPoints();
        int numDataPoints2 = dataSet2.getNumDataPoints();
        int numDataPoints = numDataPoints1 + numDataPoints2;
        int numFeatures = dataSet1.getNumFeatures();
        ClfDataSet dataSet = ClfDataSetBuilder.getBuilder().numDataPoints(numDataPoints).numFeatures(numFeatures).numClasses(dataSet1.getNumClasses()).dense(dataSet1.isDense()).missingValue(dataSet1.hasMissingValue()).build();
        // Fetch each label array once instead of once per row.
        int[] labels1 = dataSet1.getLabels();
        int[] labels2 = dataSet2.getLabels();
        int dataIndex = 0;
        for (int i = 0; i < numDataPoints1; ++i) {
            for (Vector.Element element : dataSet1.getRow(i).nonZeroes()) {
                dataSet.setFeatureValue(dataIndex, element.index(), element.get());
            }
            dataSet.setLabel(dataIndex, labels1[i]);
            ++dataIndex;
        }
        // dataIndex keeps running so the second block lands right after the first.
        for (int i = 0; i < numDataPoints2; ++i) {
            for (Vector.Element element : dataSet2.getRow(i).nonZeroes()) {
                dataSet.setFeatureValue(dataIndex, element.index(), element.get());
            }
            dataSet.setLabel(dataIndex, labels2[i]);
            ++dataIndex;
        }
        dataSet.setFeatureList(dataSet1.getFeatureList());
        dataSet.setLabelTranslator(dataSet1.getLabelTranslator());
        return dataSet;
    }

    /**
     * Stacks the rows of {@code dataSet2} below those of {@code dataSet1} in a newly
     * built multi-label data set.
     *
     * <p>Feature count, class count, density, missing-value flag, feature list and label
     * translator come from {@code dataSet1}. No id translator is set on the result by
     * this method.
     *
     * @param dataSet1 data set providing the first rows and the shared metadata
     * @param dataSet2 data set providing the remaining rows
     * @return the row-wise concatenation
     */
    public static MultiLabelClfDataSet concatenateByRow(MultiLabelClfDataSet dataSet1, MultiLabelClfDataSet dataSet2) {
        int numDataPoints1 = dataSet1.getNumDataPoints();
        int numDataPoints2 = dataSet2.getNumDataPoints();
        int numDataPoints = numDataPoints1 + numDataPoints2;
        int numFeatures = dataSet1.getNumFeatures();
        MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder().numDataPoints(numDataPoints).numFeatures(numFeatures).numClasses(dataSet1.getNumClasses()).density(dataSet1.density()).missingValue(dataSet1.hasMissingValue()).build();
        // Fetch each multi-label array once instead of once per row.
        MultiLabel[] labels1 = dataSet1.getMultiLabels();
        MultiLabel[] labels2 = dataSet2.getMultiLabels();
        int dataIndex = 0;
        for (int i = 0; i < numDataPoints1; ++i) {
            for (Vector.Element element : dataSet1.getRow(i).nonZeroes()) {
                dataSet.setFeatureValue(dataIndex, element.index(), element.get());
            }
            dataSet.setLabels(dataIndex, labels1[i]);
            ++dataIndex;
        }
        // dataIndex keeps running so the second block lands right after the first.
        for (int i = 0; i < numDataPoints2; ++i) {
            for (Vector.Element element : dataSet2.getRow(i).nonZeroes()) {
                dataSet.setFeatureValue(dataIndex, element.index(), element.get());
            }
            dataSet.setLabels(dataIndex, labels2[i]);
            ++dataIndex;
        }
        dataSet.setFeatureList(dataSet1.getFeatureList());
        dataSet.setLabelTranslator(dataSet1.getLabelTranslator());
        return dataSet;
    }

    /**
     * Places the feature columns of {@code dataSet2} to the right of those of
     * {@code dataSet1} in a newly built multi-label data set.
     *
     * <p>Row count, class count, density, missing-value flag, multi-labels, label
     * translator and id translator all come from {@code dataSet1}. The feature list is
     * the features of {@code dataSet1} followed by those of {@code dataSet2}, added as
     * they are.
     *
     * @param dataSet1 left-hand data set, also the source of labels and translators
     * @param dataSet2 right-hand data set
     * @return the column-wise concatenation
     */
    public static MultiLabelClfDataSet concatenateByColumn(MultiLabelClfDataSet dataSet1, MultiLabelClfDataSet dataSet2) {
        int rowCount = dataSet1.getNumDataPoints();
        int leftWidth = dataSet1.getNumFeatures();
        int rightWidth = dataSet2.getNumFeatures();
        MultiLabelClfDataSet merged = MLClfDataSetBuilder.getBuilder()
                .numDataPoints(rowCount)
                .numFeatures(leftWidth + rightWidth)
                .numClasses(dataSet1.getNumClasses())
                .density(dataSet1.density())
                .missingValue(dataSet1.hasMissingValue())
                .build();

        // Left block keeps its column numbers.
        for (int column = 0; column < leftWidth; column++) {
            for (Vector.Element cell : dataSet1.getColumn(column).nonZeroes()) {
                merged.setFeatureValue(cell.index(), column, cell.get());
            }
        }
        // Right block is shifted past the left block.
        for (int column = 0; column < rightWidth; column++) {
            int shiftedColumn = leftWidth + column;
            for (Vector.Element cell : dataSet2.getColumn(column).nonZeroes()) {
                merged.setFeatureValue(cell.index(), shiftedColumn, cell.get());
            }
        }

        MultiLabel[] sourceLabels = dataSet1.getMultiLabels();
        for (int row = 0; row < rowCount; row++) {
            merged.setLabels(row, sourceLabels[row]);
        }

        FeatureList combined = new FeatureList();
        dataSet1.getFeatureList().getAll().forEach(combined::add);
        dataSet2.getFeatureList().getAll().forEach(combined::add);
        merged.setFeatureList(combined);
        merged.setLabelTranslator(dataSet1.getLabelTranslator());
        merged.setIdTranslator(dataSet1.getIdTranslator());
        return merged;
    }

    /**
     * Selects the rows of a classification data set that fall into the given folds.
     *
     * <p>Folds are numbered from 1 and assigned round-robin: row {@code i} belongs to
     * fold {@code (i % numFolds) + 1}.
     *
     * @param dataSet     data set to take rows from
     * @param numFolds    total number of folds
     * @param foldIndices folds to keep, each in {@code [1, numFolds]}
     * @return a data set made of the rows in the requested folds
     * @throws IllegalArgumentException if any requested fold is outside {@code [1, numFolds]}
     */
    public static ClfDataSet sampleByFold(ClfDataSet dataSet, int numFolds, Set<Integer> foldIndices) {
        for (int requested : foldIndices) {
            if (requested < 1 || requested > numFolds) {
                throw new IllegalArgumentException("should have fold>=1 && fold<=numFolds");
            }
        }
        List<Integer> selectedRows = IntStream.range(0, dataSet.getNumDataPoints())
                .filter(row -> foldIndices.contains(row % numFolds + 1))
                .boxed()
                .collect(Collectors.toList());
        return DataSetUtil.sampleData(dataSet, selectedRows);
    }

    /**
     * Splits a classification data set into {@code numBatches} batches, one per fold,
     * using the round-robin fold assignment of {@code sampleByFold}.
     *
     * @param dataSet    data set to partition
     * @param numBatches number of batches to produce
     * @return the batches, ordered by fold number 1..{@code numBatches}
     */
    public static List<ClfDataSet> partitionToBatches(ClfDataSet dataSet, int numBatches) {
        List<ClfDataSet> result = new ArrayList<>();
        int fold = 1;
        while (fold <= numBatches) {
            // Each batch is exactly one fold.
            Set<Integer> onlyThisFold = new HashSet<>();
            onlyThisFold.add(fold);
            ClfDataSet batch = DataSetUtil.sampleByFold(dataSet, numBatches, onlyThisFold);
            result.add(batch);
            fold++;
        }
        return result;
    }

    /**
     * Selects the rows of a multi-label data set that fall into the given folds.
     *
     * <p>Folds are numbered from 1 and assigned round-robin: row {@code i} belongs to
     * fold {@code (i % numFolds) + 1}.
     *
     * @param dataSet     data set to take rows from
     * @param numFolds    total number of folds
     * @param foldIndices folds to keep, each in {@code [1, numFolds]}
     * @return a data set made of the rows in the requested folds
     * @throws IllegalArgumentException if any requested fold is outside {@code [1, numFolds]}
     */
    public static MultiLabelClfDataSet sampleByFold(MultiLabelClfDataSet dataSet, int numFolds, Set<Integer> foldIndices) {
        for (int requested : foldIndices) {
            if (requested < 1 || requested > numFolds) {
                throw new IllegalArgumentException("should have fold>=1 && fold<=numFolds");
            }
        }
        List<Integer> selectedRows = IntStream.range(0, dataSet.getNumDataPoints())
                .filter(row -> foldIndices.contains(row % numFolds + 1))
                .boxed()
                .collect(Collectors.toList());
        return DataSetUtil.sampleData(dataSet, selectedRows);
    }

    /**
     * Splits a multi-label data set into {@code numBatches} batches, one per fold,
     * using the round-robin fold assignment of {@code sampleByFold}.
     *
     * @param dataSet    data set to partition
     * @param numBatches number of batches to produce
     * @return the batches, ordered by fold number 1..{@code numBatches}
     */
    public static List<MultiLabelClfDataSet> partitionToBatches(MultiLabelClfDataSet dataSet, int numBatches) {
        List<MultiLabelClfDataSet> result = new ArrayList<>();
        int fold = 1;
        while (fold <= numBatches) {
            // Each batch is exactly one fold.
            Set<Integer> onlyThisFold = new HashSet<>();
            onlyThisFold.add(fold);
            MultiLabelClfDataSet batch = DataSetUtil.sampleByFold(dataSet, numBatches, onlyThisFold);
            result.add(batch);
            fold++;
        }
        return result;
    }

    /**
     * Selects the rows of a regression data set that fall into the given folds.
     *
     * <p>Folds are numbered from 1 and assigned round-robin: row {@code i} belongs to
     * fold {@code (i % numFolds) + 1}.
     *
     * @param dataSet     data set to take rows from
     * @param numFolds    total number of folds
     * @param foldIndices folds to keep, each in {@code [1, numFolds]}
     * @return a data set made of the rows in the requested folds
     * @throws IllegalArgumentException if any requested fold is outside {@code [1, numFolds]}
     */
    public static RegDataSet sampleByFold(RegDataSet dataSet, int numFolds, Set<Integer> foldIndices) {
        for (int requested : foldIndices) {
            if (requested < 1 || requested > numFolds) {
                throw new IllegalArgumentException("should have fold>=1 && fold<=numFolds");
            }
        }
        List<Integer> selectedRows = IntStream.range(0, dataSet.getNumDataPoints())
                .filter(row -> foldIndices.contains(row % numFolds + 1))
                .boxed()
                .collect(Collectors.toList());
        return DataSetUtil.sampleData(dataSet, selectedRows);
    }

    /**
     * Selects the rows of a data set, and the matching rows of a per-row target
     * distribution, that fall into the given folds.
     *
     * <p>Folds are numbered from 1 and assigned round-robin: row {@code i} belongs to
     * fold {@code (i % numFolds) + 1}.
     *
     * @param dataSet            data set to take rows from
     * @param targetDistribution one target row per data point of {@code dataSet}
     * @param numFolds           total number of folds
     * @param foldIndices        folds to keep, each in {@code [1, numFolds]}
     * @return the sampled data set paired with the sampled target rows
     * @throws IllegalArgumentException if any requested fold is outside {@code [1, numFolds]}
     */
    public static Pair<DataSet, double[][]> sampleByFold(DataSet dataSet, double[][] targetDistribution, int numFolds, Set<Integer> foldIndices) {
        for (int requested : foldIndices) {
            if (requested < 1 || requested > numFolds) {
                throw new IllegalArgumentException("should have fold>=1 && fold<=numFolds");
            }
        }
        List<Integer> selectedRows = IntStream.range(0, dataSet.getNumDataPoints())
                .filter(row -> foldIndices.contains(row % numFolds + 1))
                .boxed()
                .collect(Collectors.toList());
        return DataSetUtil.sampleData(dataSet, targetDistribution, selectedRows);
    }

    /**
     * Splits a classification data set into a training part and a validation part.
     *
     * <p>Training rows are chosen by {@code Sampling.stratified} on the labels; every row
     * not chosen goes to the validation part.
     *
     * @param clfDataSet      data set to split
     * @param trainPercentage value handed to {@code Sampling.stratified} to size the training part
     * @return a pair whose first element is the training set and second the validation set
     */
    public static Pair<ClfDataSet, ClfDataSet> splitToTrainValidation(ClfDataSet clfDataSet, double trainPercentage) {
        int rowCount = clfDataSet.getNumDataPoints();
        List<Integer> trainRows = Sampling.stratified(clfDataSet.getLabels(), trainPercentage);

        // Start from all rows and strike out the training ones.
        Set<Integer> leftover = IntStream.range(0, rowCount)
                .boxed()
                .collect(Collectors.toCollection(HashSet::new));
        leftover.removeAll(trainRows);
        List<Integer> validationRows = new ArrayList<>(leftover);

        ClfDataSet trainPart = DataSetUtil.sampleData(clfDataSet, trainRows);
        ClfDataSet validationPart = DataSetUtil.sampleData(clfDataSet, validationRows);
        Pair<ClfDataSet, ClfDataSet> split = new Pair<>();
        split.setFirst(trainPart);
        split.setSecond(validationPart);
        return split;
    }

    /**
     * Splits a regression data set into a training part and a validation part.
     *
     * <p>Training rows are chosen by {@code Sampling.sampleByPercentage} over all row
     * indices; every row not chosen goes to the validation part.
     *
     * @param dataSet         data set to split
     * @param trainPercentage value handed to {@code Sampling.sampleByPercentage} to size the training part
     * @return a pair whose first element is the training set and second the validation set
     */
    public static Pair<RegDataSet, RegDataSet> splitToTrainValidation(RegDataSet dataSet, double trainPercentage) {
        int rowCount = dataSet.getNumDataPoints();
        List<Integer> allRows = IntStream.range(0, dataSet.getNumDataPoints())
                .boxed()
                .collect(Collectors.toList());
        List<Integer> trainRows = Sampling.sampleByPercentage(allRows, trainPercentage);

        // Start from all rows and strike out the training ones.
        Set<Integer> leftover = IntStream.range(0, rowCount)
                .boxed()
                .collect(Collectors.toCollection(HashSet::new));
        leftover.removeAll(trainRows);
        List<Integer> validationRows = new ArrayList<>(leftover);

        RegDataSet trainPart = DataSetUtil.sampleData(dataSet, trainRows);
        RegDataSet validationPart = DataSetUtil.sampleData(dataSet, validationRows);
        Pair<RegDataSet, RegDataSet> split = new Pair<>();
        split.setFirst(trainPart);
        split.setSecond(validationPart);
        return split;
    }

    /**
     * Splits a multi-label data set into a training part and a validation part.
     *
     * <p>Training rows are chosen by {@code Sampling.sampleByPercentage} over all row
     * indices; every row not chosen goes to the validation part.
     *
     * @param multiLabelClfDataSet data set to split
     * @param trainPercentage      value handed to {@code Sampling.sampleByPercentage} to size the training part
     * @return a pair whose first element is the training set and second the validation set
     */
    public static Pair<MultiLabelClfDataSet, MultiLabelClfDataSet> splitToTrainValidation(MultiLabelClfDataSet multiLabelClfDataSet, double trainPercentage) {
        int rowCount = multiLabelClfDataSet.getNumDataPoints();
        List<Integer> allRows = IntStream.range(0, rowCount)
                .boxed()
                .collect(Collectors.toList());
        List<Integer> trainRows = Sampling.sampleByPercentage(allRows, trainPercentage);

        // Start from all rows and strike out the training ones.
        Set<Integer> leftover = IntStream.range(0, rowCount)
                .boxed()
                .collect(Collectors.toCollection(HashSet::new));
        leftover.removeAll(trainRows);
        List<Integer> validationRows = new ArrayList<>(leftover);

        MultiLabelClfDataSet trainPart = DataSetUtil.sampleData(multiLabelClfDataSet, trainRows);
        MultiLabelClfDataSet validationPart = DataSetUtil.sampleData(multiLabelClfDataSet, validationRows);
        Pair<MultiLabelClfDataSet, MultiLabelClfDataSet> split = new Pair<>();
        split.setFirst(trainPart);
        split.setSecond(validationPart);
        return split;
    }

    /**
     * Path-based convenience overload: writes the data point settings of a
     * classification data set to the file at {@code file}.
     *
     * @param dataSet data set to describe
     * @param file    path of the output file
     * @throws IOException if writing fails
     */
    public static void dumpDataPointSettings(ClfDataSet dataSet, String file) throws IOException {
        File target = new File(file);
        DataSetUtil.dumpDataPointSettings(dataSet, target);
    }

    /**
     * Writes one line per data point of a classification data set, in the form
     * {@code intId=<row>,extId=<external id>,intLabel=<label>,extLabel=<external label>}.
     *
     * @param dataSet data set to describe
     * @param file    output file (overwritten)
     * @throws IOException if writing fails
     */
    public static void dumpDataPointSettings(ClfDataSet dataSet, File file) throws IOException {
        int rowCount = dataSet.getNumDataPoints();
        int[] intLabels = dataSet.getLabels();
        IdTranslator ids = dataSet.getIdTranslator();
        LabelTranslator labelNames = dataSet.getLabelTranslator();
        try (BufferedWriter out = new BufferedWriter(new FileWriter(file))) {
            int row = 0;
            while (row < rowCount) {
                // Translator results are written on their own so they reach the writer unchanged.
                out.write("intId=" + row + ",extId=");
                out.write(ids.toExtId(row));
                out.write(",intLabel=" + intLabels[row] + ",extLabel=");
                out.write(labelNames.toExtLabel(intLabels[row]));
                out.newLine();
                row++;
            }
        }
    }

    /**
     * Path-based convenience overload: writes the data point settings of a multi-label
     * data set to the file at {@code file}.
     *
     * @param dataSet data set to describe
     * @param file    path of the output file
     * @throws IOException if writing fails
     */
    public static void dumpDataPointSettings(MultiLabelClfDataSet dataSet, String file) throws IOException {
        File target = new File(file);
        DataSetUtil.dumpDataPointSettings(dataSet, target);
    }

    /**
     * Writes one line per data point of a multi-label data set, in the form
     * {@code intId=<row>,extId=<external id>,intLabel=<ordered labels>,extLabel=<ordered external labels>}.
     * Both label lists are rendered with their {@code toString()} form.
     *
     * @param dataSet data set to describe
     * @param file    output file (overwritten)
     * @throws IOException if writing fails
     */
    public static void dumpDataPointSettings(MultiLabelClfDataSet dataSet, File file) throws IOException {
        IdTranslator ids = dataSet.getIdTranslator();
        LabelTranslator labelNames = dataSet.getLabelTranslator();
        int rowCount = dataSet.getNumDataPoints();
        MultiLabel[] rowLabels = dataSet.getMultiLabels();
        try (BufferedWriter out = new BufferedWriter(new FileWriter(file))) {
            int row = 0;
            while (row < rowCount) {
                out.write("intId=" + row + ",extId=");
                out.write(ids.toExtId(row));
                out.write(",intLabel=");
                out.write("" + rowLabels[row].getMatchedLabelsOrdered());
                out.write(",extLabel=");
                // Same ordered labels, mapped to their external names.
                String externalNames = rowLabels[row].getMatchedLabelsOrdered().stream()
                        .map(labelNames::toExtLabel)
                        .collect(Collectors.toList())
                        .toString();
                out.write(externalNames);
                out.newLine();
                row++;
            }
        }
    }

    /**
     * Path-based convenience overload: writes the feature settings of a data set to the
     * file at {@code file}.
     *
     * @param dataSet data set whose features are described
     * @param file    path of the output file
     * @throws IOException if writing fails
     */
    public static void dumpFeatureSettings(DataSet dataSet, String file) throws IOException {
        File target = new File(file);
        DataSetUtil.dumpFeatureSettings(dataSet, target);
    }

    /**
     * Writes the {@code toString()} form of the first {@code getNumFeatures()} entries
     * of the data set's feature list, one per line.
     *
     * @param dataSet data set whose features are described
     * @param file    output file (overwritten)
     * @throws IOException if writing fails
     */
    public static void dumpFeatureSettings(DataSet dataSet, File file) throws IOException {
        int featureCount = dataSet.getNumFeatures();
        List<Feature> allFeatures = dataSet.getFeatureList().getAll();
        try (BufferedWriter out = new BufferedWriter(new FileWriter(file))) {
            int next = 0;
            while (next < featureCount) {
                out.write(allFeatures.get(next).toString());
                out.newLine();
                next++;
            }
        }
    }

    /**
     * Collects the distinct multi-labels that occur in a data set.
     *
     * <p>Distinctness is decided by a {@link HashSet}, so the order of the returned list
     * is that set's iteration order.
     *
     * @param dataSet data set to scan
     * @return a list holding each distinct multi-label once
     */
    public static List<MultiLabel> gatherMultiLabels(MultiLabelClfDataSet dataSet) {
        Set<MultiLabel> distinct = new HashSet<>();
        distinct.addAll(Arrays.asList(dataSet.getMultiLabels()));
        return new ArrayList<>(distinct);
    }

    /**
     * Collects every label index that is matched by at least one data point.
     *
     * @param dataSet data set to scan
     * @return the union of the matched labels of all data points
     */
    public static Set<Integer> gatherLabels(MultiLabelClfDataSet dataSet) {
        return Arrays.stream(dataSet.getMultiLabels())
                .flatMap(multiLabel -> multiLabel.getMatchedLabels().stream())
                .collect(Collectors.toCollection(HashSet::new));
    }

    /**
     * Turns multi-labels into 0/1 labels for a single class.
     *
     * @param multiLabels one multi-label per data point
     * @param k           class to test for
     * @return an array with 1 where the multi-label matches class {@code k} and 0 elsewhere
     */
    public static int[] toBinaryLabels(MultiLabel[] multiLabels, int k) {
        return Arrays.stream(multiLabels)
                .mapToInt(multiLabel -> multiLabel.matchClass(k) ? 1 : 0)
                .toArray();
    }

    /**
     * Builds a two-class data set from a multi-label data set by testing every data
     * point against a single class.
     *
     * <p>The non-zero feature values of each row are copied. A row receives label 1
     * when its multi-label matches class {@code k} and label 0 otherwise. The result
     * is dense or sparse according to {@code dataSet.isDense()}, receives a label
     * translator built from {@code "NOT " + extLabel} followed by {@code extLabel},
     * and is handed the source's feature list.
     *
     * @param dataSet source multi-label data set
     * @param k       index of the class that becomes the positive class
     * @return the new binary classification data set
     */
    public static ClfDataSet toBinary(MultiLabelClfDataSet dataSet, int k) {
        int numDataPoints = dataSet.getNumDataPoints();
        int numFeatures = dataSet.getNumFeatures();
        boolean missingValue = dataSet.hasMissingValue();
        // Declared as ClfDataSet (as in toMultiClass) so that setLabel and
        // setLabelTranslator resolve and the value matches the return type.
        ClfDataSet clfDataSet;
        if (dataSet.isDense()) {
            clfDataSet = new DenseClfDataSet(numDataPoints, numFeatures, missingValue, 2);
        } else {
            clfDataSet = new SparseClfDataSet(numDataPoints, numFeatures, missingValue, 2);
        }
        // Fetched once instead of once per row.
        MultiLabel[] multiLabels = dataSet.getMultiLabels();
        for (int i = 0; i < numDataPoints; ++i) {
            Vector vector = dataSet.getRow(i);
            for (Vector.Element element : vector.nonZeroes()) {
                clfDataSet.setFeatureValue(i, element.index(), element.get());
            }
            clfDataSet.setLabel(i, multiLabels[i].matchClass(k) ? 1 : 0);
        }
        String extLabel = dataSet.getLabelTranslator().toExtLabel(k);
        List<String> extLabels = new ArrayList<String>();
        extLabels.add("NOT " + extLabel);
        extLabels.add(extLabel);
        clfDataSet.setLabelTranslator(new LabelTranslator(extLabels));
        clfDataSet.setFeatureList(dataSet.getFeatureList());
        return clfDataSet;
    }

    /**
     * Converts a multi-label data set into a single-label one by treating every
     * distinct label combination as its own class.
     *
     * <p>The distinct combinations come from {@link #gatherMultiLabels} and are
     * registered in a {@link Translator}; each row's class is the translator index of
     * its multi-label. Non-zero feature values are copied, the external label names
     * are the {@code toString()} of each combination, and the source's feature list
     * is handed to the result.
     *
     * @param dataSet source multi-label data set
     * @return the new classification data set paired with the translator that maps
     *         label combinations to class indices
     */
    public static Pair<ClfDataSet, Translator<MultiLabel>> toMultiClass(MultiLabelClfDataSet dataSet) {
        int rows = dataSet.getNumDataPoints();
        int columns = dataSet.getNumFeatures();
        List<MultiLabel> combinations = DataSetUtil.gatherMultiLabels(dataSet);
        Translator<MultiLabel> translator = new Translator<MultiLabel>();
        translator.addAll(combinations);

        ClfDataSet converted = ClfDataSetBuilder.getBuilder()
                .numDataPoints(rows)
                .numFeatures(columns)
                .dense(dataSet.isDense())
                .missingValue(dataSet.hasMissingValue())
                .numClasses(translator.size())
                .build();

        for (int row = 0; row < rows; row++) {
            for (Vector.Element cell : dataSet.getRow(row).nonZeroes()) {
                converted.setFeatureValue(row, cell.index(), cell.get());
            }
            converted.setLabel(row, translator.getIndex(dataSet.getMultiLabels()[row]));
        }

        List<String> extLabels = new ArrayList<String>();
        for (MultiLabel combination : combinations) {
            extLabels.add(combination.toString());
        }
        converted.setLabelTranslator(new LabelTranslator(extLabels));
        converted.setFeatureList(dataSet.getFeatureList());
        return new Pair<ClfDataSet, Translator<MultiLabel>>(converted, translator);
    }

    /**
     * Calls {@code allowMissingValue()} on the data set when it is an
     * {@link AbstractDataSet}; any other implementation is left untouched.
     *
     * @param dataSet data set to update
     */
    public static void allowMissingValue(DataSet dataSet) {
        if (!(dataSet instanceof AbstractDataSet)) {
            return;
        }
        AbstractDataSet target = (AbstractDataSet) dataSet;
        target.allowMissingValue();
    }

    /**
     * Counts how many data points carry each class label.
     *
     * @param dataSet single-label classification data set
     * @return an array of length {@code getNumClasses()} whose entry {@code c} is the
     *         number of data points labelled {@code c}
     */
    public static int[] getCountPerClass(ClfDataSet dataSet) {
        int[] tally = new int[dataSet.getNumClasses()];
        int[] labels = dataSet.getLabels();
        int total = dataSet.getNumDataPoints();
        for (int row = 0; row < total; row++) {
            tally[labels[row]]++;
        }
        return tally;
    }

    /**
     * Counts, for every class, the number of data points whose multi-label contains
     * that class.
     *
     * @param dataSet multi-label classification data set
     * @return an array of length {@code getNumClasses()} whose entry {@code c} is the
     *         number of data points matching class {@code c}
     */
    public static int[] getCountPerClass(MultiLabelClfDataSet dataSet) {
        int[] tally = new int[dataSet.getNumClasses()];
        MultiLabel[] assigned = dataSet.getMultiLabels();
        int total = dataSet.getNumDataPoints();
        for (int row = 0; row < total; row++) {
            for (int matched : assigned[row].getMatchedLabels()) {
                tally[matched]++;
            }
        }
        return tally;
    }

    /**
     * Groups data point indices by their class label.
     *
     * @param dataSet single-label classification data set
     * @return a list with one entry per class; entry {@code c} holds, in ascending
     *         order, the indices of the data points labelled {@code c}
     */
    public static List<List<Integer>> labelToDataPoints(ClfDataSet dataSet) {
        int numClasses = dataSet.getNumClasses();
        int[] labels = dataSet.getLabels();
        List<List<Integer>> pointsByLabel = new ArrayList<List<Integer>>();
        for (int k = 0; k < numClasses; ++k) {
            // Parameterized instead of the raw ArrayList used previously.
            pointsByLabel.add(new ArrayList<Integer>());
        }
        int numDataPoints = dataSet.getNumDataPoints();
        for (int i = 0; i < numDataPoints; ++i) {
            pointsByLabel.get(labels[i]).add(i);
        }
        return pointsByLabel;
    }

    /**
     * Groups data point indices by the classes their multi-labels contain. A data
     * point appears in the list of every class it matches.
     *
     * @param dataSet multi-label classification data set
     * @return a list with one entry per class; entry {@code c} holds, in ascending
     *         order, the indices of the data points matching class {@code c}
     */
    public static List<List<Integer>> labelToDataPoints(MultiLabelClfDataSet dataSet) {
        int numClasses = dataSet.getNumClasses();
        MultiLabel[] labels = dataSet.getMultiLabels();
        List<List<Integer>> pointsByLabel = new ArrayList<List<Integer>>();
        for (int k = 0; k < numClasses; ++k) {
            // Parameterized instead of the raw ArrayList used previously.
            pointsByLabel.add(new ArrayList<Integer>());
        }
        int numDataPoints = dataSet.getNumDataPoints();
        for (int i = 0; i < numDataPoints; ++i) {
            for (int label : labels[i].getMatchedLabels()) {
                pointsByLabel.get(label).add(i);
            }
        }
        return pointsByLabel;
    }

    /**
     * Computes the fraction of non-zero cells in the feature matrix.
     *
     * <p>The per-row non-zero counts are summed as {@code long} and the cell count is
     * computed in {@code double}, so matrices with more than
     * {@code Integer.MAX_VALUE} cells or non-zeros no longer overflow {@code int}
     * arithmetic. There is no guard for an empty matrix: a zero cell count goes
     * through ordinary floating-point division.
     *
     * @param dataSet data set to inspect
     * @return number of non-zero elements divided by
     *         {@code numDataPoints * numFeatures}
     */
    public static double density(DataSet dataSet) {
        int numDataPoints = dataSet.getNumDataPoints();
        long nonZeros = IntStream.range(0, numDataPoints)
                .parallel()
                .mapToLong(i -> dataSet.getRow(i).getNumNonZeroElements())
                .sum();
        // Multiply as doubles: an int product wraps around for large matrices.
        double totalCells = (double) numDataPoints * (double) dataSet.getNumFeatures();
        return (double) nonZeros / totalCells;
    }

    /**
     * Renames the features of a data set, assigning {@code names.get(i)} to the
     * feature at index {@code i}.
     *
     * @param dataSet data set whose features are renamed in place
     * @param names   new names, one per feature
     * @throws IllegalArgumentException if the number of names differs from the
     *                                  number of features
     */
    public static void setFeatureNames(DataSet dataSet, List<String> names) {
        int expected = dataSet.getNumFeatures();
        if (expected != names.size()) {
            throw new IllegalArgumentException("dataSet.getNumFeatures()!=names.size()");
        }
        int index = 0;
        while (index < names.size()) {
            Feature feature = dataSet.getFeatureList().get(index);
            feature.setName(names.get(index));
            index++;
        }
    }

    /**
     * Replaces every non-zero feature value in the data set with 1.0, in place.
     *
     * @param dataSet data set to binarize
     */
    public static void binarizeFeature(DataSet dataSet) {
        int numDataPoints = dataSet.getNumDataPoints();
        for (int i = 0; i < numDataPoints; ++i) {
            // Snapshot the indices first so the row is not written to while its
            // non-zero elements are being iterated.
            List<Integer> nonZeroIndices = new ArrayList<Integer>();
            Vector row = dataSet.getRow(i);
            for (Vector.Element element : row.nonZeroes()) {
                nonZeroIndices.add(element.index());
            }
            for (int j : nonZeroIndices) {
                dataSet.setFeatureValue(i, j, 1.0);
            }
        }
    }

    /**
     * Builds a one-hot label matrix for a single-label data set.
     *
     * @param dataSet classification data set
     * @return a {@code numDataPoints x numClasses} matrix in which row {@code i} has
     *         1.0 at the column of data point {@code i}'s label and 0.0 elsewhere
     */
    public static double[][] labelDistribution(ClfDataSet dataSet) {
        int rows = dataSet.getNumDataPoints();
        int classes = dataSet.getNumClasses();
        double[][] oneHot = new double[rows][classes];
        int[] labels = dataSet.getLabels();
        int row = 0;
        while (row < rows) {
            oneHot[row][labels[row]] = 1.0;
            row++;
        }
        return oneHot;
    }

    /**
     * Renders the label assignments as text: one line per data point, one
     * space-separated {@code 0}/{@code 1} flag per class, each line terminated by
     * {@code "\n"}.
     *
     * @param dataSet multi-label data set to render
     * @return the textual indicator matrix
     */
    public static String multiLabelToBinaryString(MultiLabelClfDataSet dataSet) {
        int rows = dataSet.getNumDataPoints();
        int classes = dataSet.getNumClasses();
        StringBuilder text = new StringBuilder();
        for (int row = 0; row < rows; row++) {
            MultiLabel assigned = dataSet.getMultiLabels()[row];
            for (int cls = 0; cls < classes; cls++) {
                if (assigned.matchClass(cls)) {
                    text.append("1");
                } else {
                    text.append("0");
                }
                // Separator between flags only; none after the last flag.
                if (cls < classes - 1) {
                    text.append(" ");
                }
            }
            text.append("\n");
        }
        return text.toString();
    }

    /**
     * Prints to standard output which test rows also occur as training rows.
     * Rows are compared through {@link HashSet} membership of the row vectors, so
     * the result depends on the vectors' {@code equals}/{@code hashCode}.
     *
     * @param train training data set
     * @param test  test data set
     */
    public static void detectDuplicate(MultiLabelClfDataSet train, MultiLabelClfDataSet test) {
        HashSet<Vector> trainRows = new HashSet<Vector>();
        int trainSize = train.getNumDataPoints();
        for (int row = 0; row < trainSize; row++) {
            trainRows.add(train.getRow(row));
        }

        ArrayList<Integer> duplicateRows = new ArrayList<Integer>();
        int testSize = test.getNumDataPoints();
        for (int row = 0; row < testSize; row++) {
            if (trainRows.contains(test.getRow(row))) {
                duplicateRows.add(row);
            }
        }

        System.out.println("number of test data points which occur in training set = " + duplicateRows.size());
        System.out.println("duplicates = " + duplicateRows);
    }

    /**
     * Prints a comparison report of a training and a test set to standard output:
     * sizes, label cardinality and density, the number of distinct label
     * combinations in each set, their union and intersection, and how many test data
     * points carry a combination taken from
     * {@code SetUtil.complement(testCombinations, trainCombinations)}.
     *
     * @param trainSet training data set
     * @param testSet  test data set
     */
    public static void dataComparasion(MultiLabelClfDataSet trainSet, MultiLabelClfDataSet testSet) {
        System.out.println("---------------------------------Data Comparasion------------------------------");
        System.out.println("Number of Features: " + trainSet.getNumFeatures());
        System.out.println("Number of Labels: " + trainSet.getNumClasses());
        System.out.println("Number of Training: " + trainSet.getNumDataPoints());
        System.out.println("Number of Testing: " + testSet.getNumDataPoints());

        // Only sizes and membership of these sets are used below.
        HashSet<MultiLabel> trainCombinations = new HashSet<MultiLabel>(Arrays.asList(trainSet.getMultiLabels()));
        HashSet<MultiLabel> testCombinations = new HashSet<MultiLabel>(Arrays.asList(testSet.getMultiLabels()));

        System.out.println("Train label Cardinality: " + trainSet.labelCardinality());
        System.out.println("Test label Cardinality: " + testSet.labelCardinality());
        System.out.println("Train label Density: " + trainSet.labelDensity());
        System.out.println("Test label Density: " + testSet.labelDensity());
        System.out.println();
        System.out.println("Train distinct label num: " + trainCombinations.size());
        System.out.println("Test distinct label num: " + testCombinations.size());

        // NOTE(review): results stay raw Sets because SetUtil's generic signatures
        // are not visible from here; confirm and parameterize.
        Set union = SetUtil.union(trainCombinations, testCombinations);
        System.out.println("Union distinct label num: " + union.size());
        Set intersection = SetUtil.intersect(trainCombinations, testCombinations);
        System.out.println("Intersect distinct label num: " + intersection.size());
        Set unseenCombinations = SetUtil.complement(testCombinations, trainCombinations);
        System.out.println("New label combination number in test: " + unseenCombinations.size());

        int unseenDataPoints = 0;
        for (MultiLabel assigned : testSet.getMultiLabels()) {
            if (unseenCombinations.contains(assigned)) {
                unseenDataPoints++;
            }
        }
        System.out.println("New label combination data counts: " + unseenDataPoints);
        System.out.println("New label combination data rate: " + (double) unseenDataPoints / (double) testSet.getNumDataPoints());
        System.out.println("---------------------------------------------------------------");
        System.out.println();
        System.out.println();
    }

    /**
     * Reads a comma-separated feature matrix into a dense data set.
     *
     * <p>Line {@code i} of the file supplies the first {@code numFeatures} values of
     * row {@code i}. The data set is created as a two-class
     * {@link DenseClfDataSet} without missing-value support; no labels are set.
     *
     * <p>Loading is best-effort: an {@link IOException} (including a missing file)
     * is printed to standard error and the data set filled so far is returned.
     *
     * @param filename    path of the CSV file
     * @param numData     number of rows to allocate
     * @param numFeatures number of columns to read from each line
     * @return the data set holding the parsed values
     * @throws NumberFormatException          if a cell is not a valid double
     * @throws ArrayIndexOutOfBoundsException if a line has fewer than
     *                                        {@code numFeatures} cells
     */
    public static DataSet loadFeatureMatrixFromCSV(String filename, int numData, int numFeatures) {
        DenseClfDataSet clfDataSet = new DenseClfDataSet(numData, numFeatures, false, 2);
        // try-with-resources closes the reader on every path; it used to leak.
        try (BufferedReader br = new BufferedReader(new FileReader(filename))) {
            int i = 0;
            String line;
            while ((line = br.readLine()) != null) {
                String[] lineSplit = line.split(",");
                for (int j = 0; j < numFeatures; ++j) {
                    clfDataSet.setFeatureValue(i, j, Double.parseDouble(lineSplit[j]));
                }
                ++i;
            }
        }
        catch (IOException e) {
            // Deliberately non-fatal, as before; FileNotFoundException is an
            // IOException and is handled here too.
            e.printStackTrace();
        }
        return clfDataSet;
    }

    /**
     * Writes the feature matrix of a data set as comma-separated text: one line per
     * data point, every feature value (zeros included) in index order, each line
     * terminated by {@code "\n"}.
     *
     * @param filename path of the output file; existing content is replaced
     * @param dataSet  data set whose feature values are written
     * @throws IOException if the file cannot be opened or written
     */
    public static void saveFeatureMatrixToCSV(String filename, DataSet dataSet) throws IOException {
        // try-with-resources closes the writer even when a write fails.
        try (FileWriter writer = new FileWriter(filename)) {
            int numDataPoints = dataSet.getNumDataPoints();
            int numFeatures = dataSet.getNumFeatures();
            for (int i = 0; i < numDataPoints; ++i) {
                // Look the row up once instead of once per column.
                Vector row = dataSet.getRow(i);
                StringBuilder sb = new StringBuilder();
                for (int j = 0; j < numFeatures; ++j) {
                    if (j != 0) {
                        sb.append(",");
                    }
                    sb.append(row.get(j));
                }
                sb.append("\n");
                writer.append(sb.toString());
            }
        }
    }

    /**
     * Converts class labels into one-hot distributions.
     *
     * @param labels   class index of each data point; each must lie in
     *                 {@code [0, numClass)}
     * @param numClass number of classes, i.e. the width of each row
     * @return a {@code labels.length x numClass} matrix in which row {@code i} has
     *         1.0 at column {@code labels[i]} and 0.0 elsewhere
     */
    public static double[][] labelsToDistributions(int[] labels, int numClass) {
        double[][] oneHot = new double[labels.length][numClass];
        int row = 0;
        for (int label : labels) {
            oneHot[row][label] = 1.0;
            row++;
        }
        return oneHot;
    }

    /**
     * Lists the classes that no data point in the data set is labelled with.
     *
     * @param dataSet multi-label data set to scan
     * @return the indices in {@code [0, getNumClasses())} that are matched by no
     *         data point, in ascending order
     */
    public static List<Integer> unobservedLabels(MultiLabelClfDataSet dataSet) {
        boolean[] seen = new boolean[dataSet.getNumClasses()];
        int total = dataSet.getNumDataPoints();
        for (int row = 0; row < total; row++) {
            for (int matched : dataSet.getMultiLabels()[row].getMatchedLabels()) {
                seen[matched] = true;
            }
        }
        List<Integer> unobserved = new ArrayList<Integer>();
        int numClasses = dataSet.getNumClasses();
        // Ascending scan, so the result is already sorted.
        for (int cls = 0; cls < numClasses; cls++) {
            if (!seen[cls]) {
                unobserved.add(cls);
            }
        }
        return unobserved;
    }
}

