/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.application;

import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.ObjectMapper;
import edu.neu.ccs.pyramid.configuration.Config;
import edu.neu.ccs.pyramid.dataset.DataSetType;
import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.dataset.TRECFormat;
import edu.neu.ccs.pyramid.eval.KLDivergence;
import edu.neu.ccs.pyramid.eval.MLMeasures;
import edu.neu.ccs.pyramid.feature.Feature;
import edu.neu.ccs.pyramid.feature.TopFeatures;
import edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.HammingPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBConfig;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBInspector;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBTrainer;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.InstanceF1Predictor;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.MacroF1Predictor;
import edu.neu.ccs.pyramid.multilabel_classification.imlgb.SubsetAccPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.thresholding.MacroFMeasureTuner;
import edu.neu.ccs.pyramid.multilabel_classification.thresholding.TunedMarginalClassifier;
import edu.neu.ccs.pyramid.optimization.EarlyStopper;
import edu.neu.ccs.pyramid.optimization.Terminator;
import edu.neu.ccs.pyramid.util.Progress;
import edu.neu.ccs.pyramid.util.Serialization;
import edu.neu.ccs.pyramid.util.SetUtil;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.logging.FileHandler;
import java.util.logging.Logger;
import java.util.logging.SimpleFormatter;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.mahout.math.Vector;

/**
 * Command-line driver that trains, threshold-tunes and evaluates an {@link IMLGradientBoosting}
 * multi-label classifier. Every stage is switched on or off, and parameterized, through a
 * properties file read by {@link Config}.
 */
public class App2 {
    // File/folder names read and written by this tool. The values are kept exactly as earlier
    // versions wrote them so that existing models, tuned predictors and reports are still found.
    private static final String MODEL_NAME = "model_app3";
    private static final String REPORTS_FOLDER_NAME = "reports_app3";
    private static final String SAVED_CONFIG_NAME = "saved_config_app2";
    private static final String MACRO_F_PREDICTOR_NAME = "predictor_macro_f";

    /**
     * Entry point.
     *
     * @param args exactly one argument: the path of the properties file
     * @throws IllegalArgumentException if the number of arguments is not one
     */
    public static void main(String[] args) throws Exception {
        if (args.length != 1) {
            throw new IllegalArgumentException("Please specify a properties file.");
        }
        main(new Config(args[0]));
    }

    /**
     * Runs the stages enabled in {@code config} (train, tune, test), optionally logging to the
     * file named by {@code output.log} (console logging is used when it is empty).
     *
     * @param config application configuration
     */
    public static void main(Config config) throws Exception {
        Logger logger = Logger.getAnonymousLogger();
        String logFile = config.getString("output.log");
        FileHandler fileHandler = null;
        if (!logFile.isEmpty()) {
            // resolve against the working directory so a bare file name has a parent folder
            File logParent = new File(logFile).getAbsoluteFile().getParentFile();
            if (logParent != null) {
                logParent.mkdirs();
            }
            fileHandler = new FileHandler(logFile, true);
            fileHandler.setFormatter(new SimpleFormatter());
            logger.addHandler(fileHandler);
            logger.setUseParentHandlers(false);
        }
        try {
            runStages(config, logger);
        } finally {
            // release the log file even when a stage fails
            if (fileHandler != null) {
                fileHandler.close();
            }
        }
    }

    /** Executes the train / tune / test stages selected in {@code config}, in that order. */
    private static void runStages(Config config, Logger logger) throws Exception {
        logger.info(config.toString());
        new File(config.getString("output.folder")).mkdirs();
        if (config.getBoolean("train")) {
            train(config, logger);
            if (config.getString("predict.target").equals("macroFMeasure")) {
                logger.info("predict.target=macroFMeasure,  user needs to run 'tune' before predictions can be made. Reports will be generated after tuning.");
            } else if (config.getBoolean("train.generateReports")) {
                report(config, config.getString("input.trainData"), logger);
            }
            File metaDataFolder = new File(config.getString("input.folder"), "meta_data");
            // config.store writes into this folder; make sure it exists
            metaDataFolder.mkdirs();
            config.store(new File(metaDataFolder, SAVED_CONFIG_NAME));
        }
        if (config.getBoolean("tune")) {
            tuneForMacroF(config, logger);
            File metaDataFolder = new File(config.getString("input.folder"), "meta_data");
            // whether to report on the training data is decided by the config saved at training time
            Config savedConfig = new Config(new File(metaDataFolder, SAVED_CONFIG_NAME));
            if (savedConfig.getBoolean("train.generateReports")) {
                report(config, config.getString("input.trainData"), logger);
            }
        }
        if (config.getBoolean("test")) {
            report(config, config.getString("input.testData"), logger);
        }
    }

    /**
     * Loads {@code input.folder/data_sets/dataName} in TREC format as a sparse multi-label
     * classification data set.
     */
    static MultiLabelClfDataSet loadData(Config config, String dataName) throws Exception {
        File dataFile = new File(new File(config.getString("input.folder"), "data_sets"), dataName);
        return TRECFormat.loadMultiLabelClfDataSet(dataFile, DataSetType.ML_CLF_SPARSE, true);
    }

    /**
     * Trains (or, with {@code train.warmStart}, continues training) the boosting model, with
     * optional per-label early stopping based on test-set KL divergence, then serializes the
     * model to {@code output.folder} and writes the top features of every label.
     */
    static void train(Config config, Logger logger) throws Exception {
        String output = config.getString("output.folder");
        int numIterations = config.getInt("train.numIterations");
        int numLeaves = config.getInt("train.numLeaves");
        double learningRate = config.getDouble("train.learningRate");
        int minDataPerLeaf = config.getInt("train.minDataPerLeaf");
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        MultiLabelClfDataSet dataSet = loadData(config, config.getString("input.trainData"));
        boolean showTestProgress = config.getBoolean("train.showTestProgress");
        MultiLabelClfDataSet testSet = showTestProgress
                ? loadData(config, config.getString("input.testData"))
                : null;
        int numClasses = dataSet.getNumClasses();
        logger.info("number of class = " + numClasses);
        IMLGBConfig imlgbConfig = new IMLGBConfig.Builder(dataSet)
                .learningRate(learningRate)
                .minDataPerLeaf(minDataPerLeaf)
                .numLeaves(numLeaves)
                .numSplitIntervals(config.getInt("train.numSplitIntervals"))
                .usePrior(config.getBoolean("train.usePrior"))
                .build();
        IMLGradientBoosting boosting = config.getBoolean("train.warmStart")
                ? IMLGradientBoosting.deserialize(new File(output, MODEL_NAME))
                : new IMLGradientBoosting(numClasses);
        logger.info("During training, the performance is reported using Hamming loss optimal predictor");
        logger.info("initialing trainer");
        IMLGBTrainer trainer = new IMLGBTrainer(imlgbConfig, boosting);
        boolean earlyStop = config.getBoolean("train.earlyStop");
        List<EarlyStopper> earlyStoppers = new ArrayList<>();
        List<Terminator> terminators = new ArrayList<>();
        if (earlyStop) {
            int patience = config.getInt("train.earlyStop.patience");
            int minIterations = config.getInt("train.earlyStop.minIterations");
            int interval = config.getInt("train.showProgress.interval");
            double absoluteChange = config.getDouble("train.earlyStop.absoluteChange");
            double relativeChange = config.getDouble("train.earlyStop.relativeChange");
            for (int l = 0; l < numClasses; l++) {
                EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, patience);
                earlyStopper.setMinimumIterations(minIterations);
                earlyStoppers.add(earlyStopper);
                Terminator terminator = new Terminator();
                // the terminator only receives a value at progress checkpoints, so its minimum
                // is expressed in checkpoints rather than in boosting iterations
                terminator.setMaxStableIterations(patience)
                        .setMinIterations(minIterations / interval)
                        .setAbsoluteEpsilon(absoluteChange)
                        .setRelativeEpsilon(relativeChange)
                        .setOperation(Terminator.Operation.OR);
                terminators.add(terminator);
            }
            if (!showTestProgress) {
                // early stopping is evaluated only inside the test-progress branch below
                logger.warning("train.earlyStop is enabled but train.showTestProgress is false; early stopping will not be applied");
            }
        }
        logger.info("trainer initialized");
        int numLabelsLeftToTrain = numClasses;
        int progressInterval = config.getInt("train.showProgress.interval");
        boolean showTrainProgress = config.getBoolean("train.showTrainProgress");
        for (int i = 1; i <= numIterations; i++) {
            logger.info("iteration " + i);
            trainer.iterate();
            // evaluated lazily so the modulo is only computed when some progress is reported
            boolean checkpoint = (showTrainProgress || showTestProgress)
                    && (i % progressInterval == 0 || i == numIterations);
            if (showTrainProgress && checkpoint) {
                logger.info("training set performance");
                logger.info(new MLMeasures(boosting, dataSet).toString());
            }
            if (showTestProgress && checkpoint) {
                logger.info("test set performance");
                logger.info(new MLMeasures(boosting, testSet).toString());
                if (earlyStop) {
                    for (int l = 0; l < numClasses; l++) {
                        if (trainer.getShouldStop()[l]) {
                            continue;
                        }
                        EarlyStopper earlyStopper = earlyStoppers.get(l);
                        Terminator terminator = terminators.get(l);
                        double kl = KL(boosting, testSet, l);
                        earlyStopper.add(i, kl);
                        terminator.add(kl);
                        if (earlyStopper.shouldStop() || terminator.shouldTerminate()) {
                            logger.info("training for label " + l + " (" + dataSet.getLabelTranslator().toExtLabel(l) + ") should stop now");
                            logger.info("the best number of training iterations for the label is " + earlyStopper.getBestIteration());
                            trainer.setShouldStop(l);
                            numLabelsLeftToTrain--;
                            logger.info("the number of labels left to be trained on = " + numLabelsLeftToTrain);
                        }
                    }
                }
            }
            if (numLabelsLeftToTrain == 0) {
                logger.info("all label training finished");
                break;
            }
        }
        logger.info("training done");
        boosting.serialize(new File(output, MODEL_NAME));
        logger.info(stopWatch.toString());
        if (earlyStop) {
            for (int l = 0; l < numClasses; l++) {
                logger.info("----------------------------------------------------");
                logger.info("test performance history for label " + l + ": " + earlyStoppers.get(l).history());
                logger.info("model size for label " + l + " = " + (boosting.getRegressors(l).size() - 1));
            }
        }
        writeTopFeatures(config, boosting, dataSet, output, logger);
    }

    /**
     * Writes the top {@code report.topFeatures.limit} features of every label to
     * {@code top_features.json} and a human-readable {@code top_features.txt} (UTF-8).
     */
    private static void writeTopFeatures(Config config, IMLGradientBoosting boosting,
                                         MultiLabelClfDataSet dataSet, String output,
                                         Logger logger) throws Exception {
        logger.info("start writing top features");
        int limit = config.getInt("report.topFeatures.limit");
        List<TopFeatures> topFeaturesList = IntStream.range(0, boosting.getNumClasses())
                .mapToObj(k -> IMLGBInspector.topFeatures(boosting, k, limit))
                .collect(Collectors.toList());
        new ObjectMapper().writeValue(new File(output, "top_features.json"), topFeaturesList);
        StringBuilder sb = new StringBuilder();
        for (int l = 0; l < topFeaturesList.size(); l++) {
            sb.append("-------------------------").append("\n");
            sb.append(dataSet.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
            for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
                sb.append(feature.simpleString()).append(", ");
            }
            sb.append("\n");
        }
        Files.write(new File(output, "top_features.txt").toPath(),
                sb.toString().getBytes(StandardCharsets.UTF_8));
        logger.info("finish writing top features");
    }

    /**
     * Tunes per-label decision thresholds for macro F-beta (beta = {@code tune.FMeasure.beta})
     * on the train or test data (per {@code tune.data}) and serializes the tuned classifier.
     *
     * @throws IllegalArgumentException if {@code tune.data} is neither "train" nor "test"
     */
    static void tuneForMacroF(Config config, Logger logger) throws Exception {
        logger.info("start tuning for macro F measure");
        String output = config.getString("output.folder");
        double beta = config.getDouble("tune.FMeasure.beta");
        IMLGradientBoosting boosting = IMLGradientBoosting.deserialize(new File(output, MODEL_NAME));
        String dataName;
        switch (config.getString("tune.data")) {
            case "train":
                dataName = config.getString("input.trainData");
                break;
            case "test":
                dataName = config.getString("input.testData");
                break;
            default:
                throw new IllegalArgumentException("tune.data should be train or test");
        }
        MultiLabelClfDataSet dataSet = loadData(config, dataName);
        double[] thresholds = MacroFMeasureTuner.tuneThresholds(boosting, dataSet, beta);
        TunedMarginalClassifier tunedMarginalClassifier = new TunedMarginalClassifier(boosting, thresholds);
        Serialization.serialize(tunedMarginalClassifier, new File(output, MACRO_F_PREDICTOR_NAME));
        logger.info("finish tuning for macro F measure");
    }

    /**
     * Evaluates the saved model on {@code dataName} with the predictor selected by
     * {@code predict.target} and writes all reports into a freshly cleaned
     * {@code output.folder/reports_app3/<dataName>_reports} folder.
     */
    static void report(Config config, String dataName, Logger logger) throws Exception {
        logger.info("generating reports for data set " + dataName);
        String output = config.getString("output.folder");
        File analysisFolder = new File(new File(output, REPORTS_FOLDER_NAME), dataName + "_reports");
        analysisFolder.mkdirs();
        FileUtils.cleanDirectory(analysisFolder);
        IMLGradientBoosting boosting = IMLGradientBoosting.deserialize(new File(output, MODEL_NAME));
        PluginPredictor<IMLGradientBoosting> pluginPredictor =
                createPredictor(config.getString("predict.target"), boosting, output);
        MultiLabelClfDataSet dataSet = loadData(config, dataName);
        MLMeasures mlMeasures = new MLMeasures(pluginPredictor, dataSet);
        mlMeasures.getMacroAverage().setLabelTranslator(boosting.getLabelTranslator());
        logger.info("performance on dataset " + dataName);
        logger.info(mlMeasures.toString());

        writeSimpleCsv(config, boosting, pluginPredictor, dataSet, analysisFolder, logger);
        if (config.getBoolean("report.showPredictionDetail")) {
            writePredictionDetails(config, boosting, pluginPredictor, dataSet, analysisFolder, logger);
        }
        writeDataInfo(boosting, dataSet, dataName, analysisFolder, logger);

        ObjectMapper mapper = new ObjectMapper();
        logger.info("start writing model config to json");
        mapper.writeValue(new File(analysisFolder, "model_config.json"), config);
        logger.info("finish writing model config to json");

        logger.info("start writing data config to json");
        File dataConfigFile = Paths.get(config.getString("input.folder"), "data_sets", dataName, "data_config.json").toFile();
        if (dataConfigFile.exists()) {
            FileUtils.copyFileToDirectory(dataConfigFile, analysisFolder);
        }
        logger.info("finish writing data config to json");

        mapper.writeValue(new File(analysisFolder, "performance.json"), mlMeasures);

        logger.info("start writing individual label performance to json");
        mapper.writeValue(new File(analysisFolder, "individual_performance.json"), mlMeasures.getMacroAverage());
        logger.info("finish writing individual label performance to json");
        logger.info("reports generated");
    }

    /**
     * Builds the plug-in predictor for {@code predictTarget}; "macroFMeasure" additionally loads
     * the tuned classifier written by {@link #tuneForMacroF}.
     *
     * @throws IllegalArgumentException for an unknown target
     */
    private static PluginPredictor<IMLGradientBoosting> createPredictor(String predictTarget,
                                                                        IMLGradientBoosting boosting,
                                                                        String output) throws Exception {
        switch (predictTarget) {
            case "subsetAccuracy":
                return new SubsetAccPredictor(boosting);
            case "hammingLoss":
                return new HammingPredictor(boosting);
            case "instanceFMeasure":
                return new InstanceF1Predictor(boosting);
            case "macroFMeasure": {
                TunedMarginalClassifier tunedMarginalClassifier = (TunedMarginalClassifier)
                        Serialization.deserialize(new File(output, MACRO_F_PREDICTOR_NAME));
                return new MacroF1Predictor(boosting, tunedMarginalClassifier);
            }
            default:
                throw new IllegalArgumentException("unknown prediction target measure " + predictTarget);
        }
    }

    /** Writes one analysis row per data point to {@code report.csv} (UTF-8). */
    private static void writeSimpleCsv(Config config, IMLGradientBoosting boosting,
                                       PluginPredictor<IMLGradientBoosting> pluginPredictor,
                                       MultiLabelClfDataSet dataSet, File analysisFolder,
                                       Logger logger) throws Exception {
        logger.info("start generating simple CSV report");
        double probThreshold = config.getDouble("report.classProbThreshold");
        // rows are computed in parallel; collecting an ordered stream keeps data-point order
        List<String> rows = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
                .mapToObj(i -> IMLGBInspector.simplePredictionAnalysis(boosting, pluginPredictor, dataSet, i, probThreshold))
                .collect(Collectors.toList());
        File csv = new File(analysisFolder, "report.csv");
        try (BufferedWriter writer = Files.newBufferedWriter(csv.toPath(), StandardCharsets.UTF_8)) {
            for (String row : rows) {
                writer.write(row);
            }
        }
        logger.info("finish generating simple CSV report");
    }

    /**
     * Writes per-data-point prediction analyses to {@code report_<k>.json} files holding at most
     * {@code report.numDocsPerFile} data points each.
     *
     * @throws IllegalArgumentException if {@code report.numDocsPerFile} is not positive
     */
    private static void writePredictionDetails(Config config, IMLGradientBoosting boosting,
                                               PluginPredictor<IMLGradientBoosting> pluginPredictor,
                                               MultiLabelClfDataSet dataSet, File analysisFolder,
                                               Logger logger) throws Exception {
        logger.info("start writing rules to json");
        int ruleLimit = config.getInt("report.rule.limit");
        int numDocsPerFile = config.getInt("report.numDocsPerFile");
        if (numDocsPerFile <= 0) {
            throw new IllegalArgumentException("report.numDocsPerFile should be positive, got " + numDocsPerFile);
        }
        int numDataPoints = dataSet.getNumDataPoints();
        int numFiles = (int) Math.ceil((double) numDataPoints / (double) numDocsPerFile);
        double probThreshold = config.getDouble("report.classProbThreshold");
        int labelSetLimit = config.getInt("report.labelSetLimit");
        ObjectMapper mapper = new ObjectMapper();
        for (int i = 0; i < numFiles; i++) {
            int start = i * numDocsPerFile;
            int end = Math.min(start + numDocsPerFile, numDataPoints);
            List<?> partition = IntStream.range(start, end).parallel()
                    .mapToObj(a -> IMLGBInspector.analyzePrediction(boosting, pluginPredictor, dataSet, a, ruleLimit, labelSetLimit, probThreshold))
                    .collect(Collectors.toList());
            // a failed write is propagated so an incomplete report is never reported as done
            mapper.writeValue(new File(analysisFolder, "report_" + (i + 1) + ".json"), partition);
            logger.info("progress = " + Progress.percentage(i + 1, numFiles));
        }
        logger.info("finish writing rules to json");
    }

    /**
     * Writes {@code data_info.json}: label counts of model and data set and the labels present
     * in only one of them.
     */
    private static void writeDataInfo(IMLGradientBoosting boosting, MultiLabelClfDataSet dataSet,
                                      String dataName, File analysisFolder,
                                      Logger logger) throws Exception {
        logger.info("start writing data info to json");
        Set<String> modelLabels = IntStream.range(0, boosting.getNumClasses())
                .mapToObj(i -> boosting.getLabelTranslator().toExtLabel(i))
                .collect(Collectors.toSet());
        Set<String> dataSetLabels = DataSetUtil.gatherLabels(dataSet).stream()
                .map(i -> dataSet.getLabelTranslator().toExtLabel((int) i))
                .collect(Collectors.toSet());
        Set<String> modelNotDataLabels = SetUtil.complement(modelLabels, dataSetLabels);
        Set<String> dataNotModelLabels = SetUtil.complement(dataSetLabels, modelLabels);
        File dataInfoFile = new File(analysisFolder, "data_info.json");
        try (JsonGenerator jsonGenerator = new JsonFactory().createGenerator(dataInfoFile, JsonEncoding.UTF8)) {
            jsonGenerator.writeStartObject();
            jsonGenerator.writeStringField("dataSet", dataName);
            jsonGenerator.writeNumberField("numClassesInModel", boosting.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSet", dataSetLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelDataSetCombined", dataSet.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSetButNotModel", dataNotModelLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelButNotDataSet", modelNotDataLabels.size());
            jsonGenerator.writeArrayFieldStart("classesInDataSetButNotModel");
            for (String label : dataNotModelLabels) {
                jsonGenerator.writeString(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeArrayFieldStart("classesInModelButNotDataSet");
            for (String label : modelNotDataLabels) {
                jsonGenerator.writeString(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeNumberField("labelCardinality", dataSet.labelCardinality());
            jsonGenerator.writeEndObject();
        }
        logger.info("finish writing data info to json");
    }

    /**
     * KL divergence between the observed 0/1 target for {@code classIndex} (index 1 = label
     * present) and the model's predicted log class probabilities for that label.
     */
    private static double KL(IMLGradientBoosting boosting, Vector vector, MultiLabel multiLabel, int classIndex) {
        double[] p = multiLabel.matchClass(classIndex)
                ? new double[]{0.0, 1.0}
                : new double[]{1.0, 0.0};
        double[] logQ = boosting.predictLogClassProbs(vector, classIndex);
        return KLDivergence.klGivenPLogQ(p, logQ);
    }

    /**
     * Mean per-data-point KL divergence for {@code classIndex} over {@code dataSet}.
     *
     * @throws IllegalArgumentException if the data set has no data points
     */
    private static double KL(IMLGradientBoosting boosting, MultiLabelClfDataSet dataSet, int classIndex) {
        MultiLabel[] multiLabels = dataSet.getMultiLabels();
        return IntStream.range(0, dataSet.getNumDataPoints()).parallel()
                .mapToDouble(i -> KL(boosting, dataSet.getRow(i), multiLabels[i], classIndex))
                .average()
                .orElseThrow(() -> new IllegalArgumentException("cannot compute KL divergence on an empty data set"));
    }
}

