/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.application;

import edu.neu.ccs.pyramid.configuration.Config;
import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.dataset.TRECFormat;
import edu.neu.ccs.pyramid.eval.MLMeasures;
import edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.AccPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBMInspector;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.LRCBMOptimizer;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.MarginalPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.PluginF1;
import edu.neu.ccs.pyramid.optimization.EarlyStopper;
import edu.neu.ccs.pyramid.util.ListUtil;
import edu.neu.ccs.pyramid.util.Pair;
import edu.neu.ccs.pyramid.util.PrintUtil;
import edu.neu.ccs.pyramid.util.Serialization;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.time.StopWatch;

public class CBMLR {
    private static boolean VERBOSE = false;

    public static void main(String[] args) throws Exception {
        if (args.length != 1) {
            throw new IllegalArgumentException("Please specify a properties file.");
        }
        Config config = new Config(args[0]);
        System.out.println(config);
        VERBOSE = config.getBoolean("output.verbose");
        new File(config.getString("output.dir")).mkdirs();
        if (config.getBoolean("tune")) {
            TuneResult best;
            System.out.println("============================================================");
            System.out.println("Start hyper parameter tuning");
            StopWatch stopWatch = new StopWatch();
            stopWatch.start();
            ArrayList<TuneResult> tuneResults = new ArrayList<TuneResult>();
            List<MultiLabelClfDataSet> dataSets = CBMLR.loadTrainValidData(config);
            List<Double> variances = config.getDoubles("tune.variance.candidates");
            List<Integer> components = config.getIntegers("tune.numComponents.candidates");
            for (double variance : variances) {
                Object object = components.iterator();
                while (object.hasNext()) {
                    int component = (Integer)object.next();
                    StopWatch stopWatch1 = new StopWatch();
                    stopWatch1.start();
                    HyperParameters hyperParameters = new HyperParameters();
                    hyperParameters.numComponents = component;
                    hyperParameters.variance = variance;
                    System.out.println("---------------------------");
                    System.out.println("Trying hyper parameters:");
                    System.out.println("train.numComponents = " + hyperParameters.numComponents);
                    System.out.println("train.variance = " + hyperParameters.variance);
                    TuneResult tuneResult = CBMLR.tune(config, hyperParameters, dataSets.get(0), dataSets.get(1));
                    System.out.println("Found optimal train.iterations = " + tuneResult.hyperParameters.iterations);
                    System.out.println("Validation performance = " + tuneResult.performance);
                    tuneResults.add(tuneResult);
                    System.out.println("Time spent on trying this set of hyper parameters = " + stopWatch1);
                }
            }
            Comparator<TuneResult> comparator = Comparator.comparing(res -> res.performance);
            String predictTarget = config.getString("tune.targetMetric");
            switch (predictTarget) {
                case "instance_set_accuracy": {
                    best = tuneResults.stream().max(comparator).get();
                    break;
                }
                case "instance_f1": {
                    best = tuneResults.stream().max(comparator).get();
                    break;
                }
                case "instance_hamming_loss": {
                    best = tuneResults.stream().min(comparator).get();
                    break;
                }
                default: {
                    throw new IllegalArgumentException("tune.targetMetric should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
                }
            }
            System.out.println("---------------------------");
            System.out.println("Hyper parameter tuning done.");
            System.out.println("Time spent on entire hyper parameter tuning = " + stopWatch);
            System.out.println("Best validation performance = " + best.performance);
            System.out.println("Best hyper parameters:");
            System.out.println("train.numComponents = " + best.hyperParameters.numComponents);
            System.out.println("train.variance = " + best.hyperParameters.variance);
            System.out.println("train.iterations = " + best.hyperParameters.iterations);
            Config tunedHypers = best.hyperParameters.asConfig();
            tunedHypers.store(new File(config.getString("output.dir"), "tuned_hyper_parameters.properties"));
            System.out.println("Tuned hyper parameters saved to " + new File(config.getString("output.dir"), "tuned_hyper_parameters.properties").getAbsolutePath());
            System.out.println("============================================================");
        }
        if (config.getBoolean("train")) {
            System.out.println("============================================================");
            if (config.getBoolean("train.useTunedHyperParameters")) {
                File hyperFile = new File(config.getString("output.dir"), "tuned_hyper_parameters.properties");
                if (!hyperFile.exists()) {
                    System.out.println("train.useTunedHyperParameters is set to true. But no tuned hyper parameters can be found in the output directory.");
                    System.out.println("Please either run hyper parameter tuning, or provide hyper parameters manually and set train.useTunedHyperParameters=false.");
                    System.exit(1);
                }
                Config tunedHypers = new Config(hyperFile);
                HyperParameters hyperParameters = new HyperParameters(tunedHypers);
                System.out.println("Start training with tuned hyper parameters:");
                System.out.println("train.numComponents = " + hyperParameters.numComponents);
                System.out.println("train.variance = " + hyperParameters.variance);
                System.out.println("train.iterations = " + hyperParameters.iterations);
                MultiLabelClfDataSet trainSet = CBMLR.loadTrainData(config);
                CBMLR.train(config, hyperParameters, trainSet);
            } else {
                HyperParameters hyperParameters = new HyperParameters(config);
                System.out.println("Start training with given hyper parameters:");
                System.out.println("train.numComponents = " + hyperParameters.numComponents);
                System.out.println("train.variance = " + hyperParameters.variance);
                System.out.println("train.iterations = " + hyperParameters.iterations);
                MultiLabelClfDataSet trainSet = CBMLR.loadTrainData(config);
                CBMLR.train(config, hyperParameters, trainSet);
            }
            System.out.println("============================================================");
        }
        if (config.getBoolean("test")) {
            System.out.println("============================================================");
            CBMLR.test(config);
            System.out.println("============================================================");
        }
    }

    private static TuneResult tune(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet validSet) throws Exception {
        PluginPredictor<CBM> classifier;
        String predictTarget;
        CBM cbm = CBMLR.newCBM(config, trainSet, hyperParameters);
        EarlyStopper earlyStopper = CBMLR.loadNewEarlyStopper(config);
        LRCBMOptimizer optimizer = CBMLR.getOptimizer(config, hyperParameters, cbm, trainSet);
        optimizer.initialize();
        switch (predictTarget = config.getString("tune.targetMetric")) {
            case "instance_set_accuracy": {
                AccPredictor accPredictor = new AccPredictor(cbm);
                accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
                classifier = accPredictor;
                break;
            }
            case "instance_f1": {
                PluginF1 pluginF1 = new PluginF1(cbm);
                List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
                pluginF1.setSupport(support);
                pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
                classifier = pluginF1;
                break;
            }
            case "instance_hamming_loss": {
                MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
                marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
                classifier = marginalPredictor;
                break;
            }
            default: {
                throw new IllegalArgumentException("predictTarget should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
            }
        }
        int interval = config.getInt("tune.monitorInterval");
        int iter = 1;
        while (true) {
            if (VERBOSE) {
                System.out.println("iteration " + iter);
            }
            optimizer.iterate();
            if (iter % interval == 0) {
                MLMeasures validMeasures = new MLMeasures(classifier, validSet);
                if (VERBOSE) {
                    System.out.println("validation performance with " + predictTarget + " optimal predictor:");
                    System.out.println(validMeasures);
                }
                switch (predictTarget) {
                    case "instance_set_accuracy": {
                        earlyStopper.add(iter, validMeasures.getInstanceAverage().getAccuracy());
                        break;
                    }
                    case "instance_f1": {
                        earlyStopper.add(iter, validMeasures.getInstanceAverage().getF1());
                        break;
                    }
                    case "instance_hamming_loss": {
                        earlyStopper.add(iter, validMeasures.getInstanceAverage().getHammingLoss());
                        break;
                    }
                    default: {
                        throw new IllegalArgumentException("predictTarget should be instance_set_accuracy or instance_f1");
                    }
                }
                if (earlyStopper.shouldStop()) {
                    if (!VERBOSE) break;
                    System.out.println("Early Stopper: the training should stop now!");
                    break;
                }
            }
            ++iter;
        }
        if (VERBOSE) {
            System.out.println("done!");
        }
        hyperParameters.iterations = earlyStopper.getBestIteration();
        TuneResult tuneResult = new TuneResult();
        tuneResult.hyperParameters = hyperParameters;
        tuneResult.performance = earlyStopper.getBestValue();
        return tuneResult;
    }

    private static void train(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet) throws Exception {
        List<Integer> unobservedLabels = DataSetUtil.unobservedLabels(trainSet);
        if (!unobservedLabels.isEmpty()) {
            System.out.println("The following labels do not actually appear in the training set and therefore cannot be learned:");
            System.out.println(ListUtil.toSimpleString(unobservedLabels));
        }
        String output = config.getString("output.dir");
        FileUtils.writeStringToFile((File)new File(output, "unobserved_labels.txt"), (String)ListUtil.toSimpleString(unobservedLabels));
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        CBM cbm = CBMLR.newCBM(config, trainSet, hyperParameters);
        LRCBMOptimizer optimizer = CBMLR.getOptimizer(config, hyperParameters, cbm, trainSet);
        System.out.println("Initializing the model");
        optimizer.initialize();
        System.out.println("Initialization done");
        for (int iter = 1; iter <= hyperParameters.iterations; ++iter) {
            System.out.println("Training progress: iteration " + iter);
            optimizer.iterate();
        }
        System.out.println("training done!");
        System.out.println("time spent on training = " + stopWatch);
        Serialization.serialize((Object)cbm, new File(output, "model"));
        List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
        Serialization.serialize(support, new File(output, "support"));
    }

    private static void test(Config config) throws Exception {
        MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSetAutoSparseSequential(config.getString("input.testData"));
        String output = config.getString("output.dir");
        CBM cbm = (CBM)Serialization.deserialize(new File(output, "model"));
        System.out.println();
        System.out.println("Making predictions on test set with 3 different predictors designed for different metrics:");
        CBMLR.reportAccPrediction(config, cbm, testSet);
        CBMLR.reportF1Prediction(config, cbm, testSet);
        CBMLR.reportHammingPrediction(config, cbm, testSet);
        CBMLR.reportGeneral(config, cbm, testSet);
        System.out.println();
    }

    private static void reportAccPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
        System.out.println("============================================================");
        System.out.println("Making predictions on test set with the instance set accuracy optimal predictor");
        String output = config.getString("output.dir");
        AccPredictor accPredictor = new AccPredictor(cbm);
        accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
        MultiLabel[] predictions = accPredictor.predict(dataSet);
        MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
        System.out.println("test performance with the instance set accuracy optimal predictor");
        System.out.println(mlMeasures);
        File performanceFile = Paths.get(output, "test_predictions", "instance_accuracy_optimal", "performance.txt").toFile();
        FileUtils.writeStringToFile((File)performanceFile, (String)mlMeasures.toString());
        System.out.println("test performance is saved to " + performanceFile.toString());
        double[] setProbs = IntStream.range(0, predictions.length).parallel().mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i])).toArray();
        File predictionFile = Paths.get(output, "test_predictions", "instance_accuracy_optimal", "predictions.txt").toFile();
        try (BufferedWriter br = new BufferedWriter(new FileWriter(predictionFile));){
            for (int i2 = 0; i2 < dataSet.getNumDataPoints(); ++i2) {
                br.write(predictions[i2].toString());
                br.write(":");
                br.write("" + setProbs[i2]);
                br.newLine();
            }
        }
        System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
        System.out.println("============================================================");
    }

    private static void reportF1Prediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
        System.out.println("============================================================");
        System.out.println("Making predictions on test set with the instance F1 optimal predictor");
        String output = config.getString("output.dir");
        PluginF1 pluginF1 = new PluginF1(cbm);
        List support = (List)Serialization.deserialize(new File(output, "support"));
        pluginF1.setSupport(support);
        pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
        MultiLabel[] predictions = pluginF1.predict(dataSet);
        MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
        System.out.println("test performance with the instance F1 optimal predictor");
        System.out.println(mlMeasures);
        File performanceFile = Paths.get(output, "test_predictions", "instance_f1_optimal", "performance.txt").toFile();
        FileUtils.writeStringToFile((File)performanceFile, (String)mlMeasures.toString());
        System.out.println("test performance is saved to " + performanceFile.toString());
        double[] setProbs = IntStream.range(0, predictions.length).parallel().mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i])).toArray();
        File predictionFile = Paths.get(output, "test_predictions", "instance_f1_optimal", "predictions.txt").toFile();
        try (BufferedWriter br = new BufferedWriter(new FileWriter(predictionFile));){
            for (int i2 = 0; i2 < dataSet.getNumDataPoints(); ++i2) {
                br.write(predictions[i2].toString());
                br.write(":");
                br.write("" + setProbs[i2]);
                br.newLine();
            }
        }
        System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
        System.out.println("============================================================");
    }

    private static void reportHammingPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
        System.out.println("============================================================");
        System.out.println("Making predictions on test set with the instance Hamming loss optimal predictor");
        String output = config.getString("output.dir");
        MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
        marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
        MultiLabel[] predictions = marginalPredictor.predict(dataSet);
        MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
        System.out.println("test performance with the instance Hamming loss optimal predictor");
        System.out.println(mlMeasures);
        File performanceFile = Paths.get(output, "test_predictions", "instance_hamming_loss_optimal", "performance.txt").toFile();
        FileUtils.writeStringToFile((File)performanceFile, (String)mlMeasures.toString());
        System.out.println("test performance is saved to " + performanceFile.toString());
        double[] setProbs = IntStream.range(0, predictions.length).parallel().mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i])).toArray();
        File predictionFile = Paths.get(output, "test_predictions", "instance_hamming_loss_optimal", "predictions.txt").toFile();
        try (BufferedWriter br = new BufferedWriter(new FileWriter(predictionFile));){
            for (int i2 = 0; i2 < dataSet.getNumDataPoints(); ++i2) {
                br.write(predictions[i2].toString());
                br.write(":");
                br.write("" + setProbs[i2]);
                br.newLine();
            }
        }
        System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
        System.out.println("============================================================");
    }

    private static void reportGeneral(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
        System.out.println("============================================================");
        System.out.println("computing other predictor-independent metrics");
        String output = config.getString("output.dir");
        File labelProbFile = Paths.get(output, "test_predictions", "label_probabilities.txt").toFile();
        double labelProbThreshold = config.getDouble("report.labelProbThreshold");
        try (BufferedWriter br = new BufferedWriter(new FileWriter(labelProbFile));){
            for (int i2 = 0; i2 < dataSet.getNumDataPoints(); ++i2) {
                br.write(CBMInspector.topLabels(cbm, dataSet.getRow(i2), labelProbThreshold));
                br.newLine();
            }
        }
        System.out.println("individual label probabilities are saved to " + labelProbFile.getAbsolutePath());
        List unobservedLabels = Arrays.stream(FileUtils.readFileToString((File)new File(output, "unobserved_labels.txt")).split(",")).map(s -> s.trim()).filter(s -> !s.isEmpty()).map(s -> Integer.parseInt(s)).collect(Collectors.toList());
        double[] logLikelihoods = IntStream.range(0, dataSet.getNumDataPoints()).parallel().mapToDouble(i -> cbm.predictLogAssignmentProb(dataSet.getRow(i), dataSet.getMultiLabels()[i])).toArray();
        double average = IntStream.range(0, dataSet.getNumDataPoints()).filter(i -> !CBMLR.containsNovelClass(dataSet.getMultiLabels()[i], unobservedLabels)).mapToDouble(i -> logLikelihoods[i]).average().getAsDouble();
        File logLikelihoodFile = Paths.get(output, "test_predictions", "ground_truth_log_likelihood.txt").toFile();
        FileUtils.writeStringToFile((File)logLikelihoodFile, (String)PrintUtil.toMutipleLines(logLikelihoods));
        System.out.println("individual log likelihood of the test ground truth label set is written to " + logLikelihoodFile.getAbsolutePath());
        System.out.println("average log likelihood of the test ground truth label sets = " + average);
        if (!unobservedLabels.isEmpty()) {
            System.out.println("This is computed by ignoring test instances with new labels unobserved during training");
            System.out.println("The following labels do not actually appear in the training set and therefore cannot be learned:");
            System.out.println(ListUtil.toSimpleString(unobservedLabels));
        }
    }

    private static LRCBMOptimizer getOptimizer(Config config, HyperParameters hyperParameters, CBM cbm, MultiLabelClfDataSet trainSet) {
        LRCBMOptimizer lrcbmOptimizer = new LRCBMOptimizer(cbm, trainSet);
        lrcbmOptimizer.setPriorVarianceBinary(hyperParameters.variance);
        lrcbmOptimizer.setPriorVarianceMultiClass(hyperParameters.variance);
        lrcbmOptimizer.setBinaryUpdatesPerIter(config.getInt("train.updatesPerIteration"));
        lrcbmOptimizer.setBinaryUpdatesPerIter(config.getInt("train.updatesPerIteration"));
        lrcbmOptimizer.setSkipDataThreshold(config.getDouble("train.skipDataThreshold"));
        lrcbmOptimizer.setSkipLabelThreshold(config.getDouble("train.skipLabelThreshold"));
        lrcbmOptimizer.setSmoothingStrength(config.getDouble("train.smoothStrength"));
        return lrcbmOptimizer;
    }

    private static CBM newCBM(Config config, MultiLabelClfDataSet trainSet, HyperParameters hyperParameters) {
        String allowEmpty;
        CBM cbm = CBM.getBuilder().setNumClasses(trainSet.getNumClasses()).setNumFeatures(trainSet.getNumFeatures()).setNumComponents(hyperParameters.numComponents).setMultiClassClassifierType("lr").setBinaryClassifierType("lr").build();
        switch (allowEmpty = config.getString("predict.allowEmpty")) {
            case "true": {
                cbm.setAllowEmpty(true);
                break;
            }
            case "false": {
                cbm.setAllowEmpty(false);
                break;
            }
            case "auto": {
                Set seen = DataSetUtil.gatherMultiLabels(trainSet).stream().collect(Collectors.toSet());
                MultiLabel empty = new MultiLabel();
                if (seen.contains(empty)) {
                    cbm.setAllowEmpty(true);
                    if (!VERBOSE) break;
                    System.out.println("training set contains empty labels, automatically set predict.allowEmpty = true");
                    break;
                }
                cbm.setAllowEmpty(false);
                if (!VERBOSE) break;
                System.out.println("training set does not contain empty labels, automatically set predict.allowEmpty = false");
                break;
            }
            default: {
                throw new IllegalArgumentException("unknown value for predict.allowEmpty");
            }
        }
        return cbm;
    }

    private static EarlyStopper loadNewEarlyStopper(Config config) {
        String earlyStopMetric = config.getString("tune.targetMetric");
        int patience = config.getInt("tune.earlyStop.patience");
        EarlyStopper.Goal earlyStopGoal = null;
        switch (earlyStopMetric) {
            case "instance_set_accuracy": {
                earlyStopGoal = EarlyStopper.Goal.MAXIMIZE;
                break;
            }
            case "instance_f1": {
                earlyStopGoal = EarlyStopper.Goal.MAXIMIZE;
                break;
            }
            case "instance_hamming_loss": {
                earlyStopGoal = EarlyStopper.Goal.MINIMIZE;
                break;
            }
            default: {
                throw new IllegalArgumentException("unsupported tune.targetMetric " + earlyStopMetric);
            }
        }
        EarlyStopper earlyStopper = new EarlyStopper(earlyStopGoal, patience);
        earlyStopper.setMinimumIterations(config.getInt("tune.earlyStop.minIterations"));
        return earlyStopper;
    }

    private static List<MultiLabelClfDataSet> loadTrainValidData(Config config) throws Exception {
        String validPath = config.getString("input.validData");
        ArrayList<MultiLabelClfDataSet> datasets = new ArrayList<MultiLabelClfDataSet>();
        MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSetAutoSparseSequential(config.getString("input.trainData"));
        if (validPath.isEmpty()) {
            System.out.println("No external validation data is provided. Use random 20% of the training data for validation.");
            Pair<MultiLabelClfDataSet, MultiLabelClfDataSet> dataSetPair = DataSetUtil.splitToTrainValidation(trainSet, 0.8);
            MultiLabelClfDataSet subTrain = dataSetPair.getFirst();
            MultiLabelClfDataSet validSet = dataSetPair.getSecond();
            datasets.add(subTrain);
            datasets.add(validSet);
        } else {
            MultiLabelClfDataSet validSet = TRECFormat.loadMultiLabelClfDataSetAutoSparseSequential(config.getString("input.validData"));
            datasets.add(trainSet);
            datasets.add(validSet);
        }
        return datasets;
    }

    private static MultiLabelClfDataSet loadTrainData(Config config) throws Exception {
        String validPath = config.getString("input.validData");
        MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSetAutoSparseSequential(config.getString("input.trainData"));
        if (validPath.isEmpty()) {
            return trainSet;
        }
        MultiLabelClfDataSet validSet = TRECFormat.loadMultiLabelClfDataSetAutoSparseSequential(config.getString("input.validData"));
        return DataSetUtil.concatenateByRow(trainSet, validSet);
    }

    private static boolean containsNovelClass(MultiLabel multiLabel, List<Integer> novelLabels) {
        for (int l : novelLabels) {
            if (!multiLabel.matchClass(l)) continue;
            return true;
        }
        return false;
    }

    private static class TuneResult {
        HyperParameters hyperParameters;
        double performance;

        private TuneResult() {
        }
    }

    private static class HyperParameters {
        double variance;
        int iterations;
        int numComponents;

        HyperParameters() {
        }

        HyperParameters(Config config) {
            this.variance = config.getDouble("train.variance");
            this.iterations = config.getInt("train.iterations");
            this.numComponents = config.getInt("train.numComponents");
        }

        Config asConfig() {
            Config config = new Config();
            config.setDouble("train.variance", this.variance);
            config.setInt("train.iterations", this.iterations);
            config.setInt("train.numComponents", this.numComponents);
            return config;
        }
    }
}

