/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.application;

import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.ObjectMapper;
import edu.neu.ccs.pyramid.configuration.Config;
import edu.neu.ccs.pyramid.dataset.DataSetType;
import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.dataset.TRECFormat;
import edu.neu.ccs.pyramid.eval.MLMeasures;
import edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor;
import edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF;
import edu.neu.ccs.pyramid.multilabel_classification.crf.CRFInspector;
import edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss;
import edu.neu.ccs.pyramid.multilabel_classification.crf.InstanceF1Predictor;
import edu.neu.ccs.pyramid.multilabel_classification.crf.SubsetAccPredictor;
import edu.neu.ccs.pyramid.optimization.LBFGS;
import edu.neu.ccs.pyramid.util.PrintUtil;
import edu.neu.ccs.pyramid.util.Serialization;
import edu.neu.ccs.pyramid.util.SetUtil;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.io.FileUtils;

/**
 * Command-line driver that trains and/or evaluates a CMLCRF multi-label classifier.
 *
 * <p>All settings come from a properties file given as the single program argument. The
 * boolean properties {@code train} and {@code test} select which phases run. The model,
 * prediction files and reports are written under {@code output.folder}.
 */
public class App6 {
    /** File name, inside {@code output.folder}, under which the trained model is stored. */
    private static final String MODEL_NAME = "model_crf";

    public static void main(String[] args) throws Exception {
        if (args.length != 1) {
            throw new IllegalArgumentException("Please specify a properties file.");
        }
        Config config = new Config(args[0]);
        System.out.println(config);
        if (config.getBoolean("train")) {
            App6.train(config);
        }
        if (config.getBoolean("test")) {
            App6.test(config);
        }
    }

    /**
     * Trains a CMLCRF model with pairwise label features using L-BFGS.
     *
     * <p>Every {@code train.showProgress.interval} iterations, and once more after training
     * ends, it prints the objective and the train/test performance. It then serializes the
     * model, writes training-set predictions and optionally generates reports.
     *
     * @param config reads input.trainData, input.testData, train.gaussianVariance,
     *     train.maxIteration, predict.target, train.showProgress.interval, output.folder and
     *     train.generateReports
     * @throws IllegalArgumentException if predict.target is unsupported or the progress
     *     interval is not positive
     */
    private static void train(Config config) throws Exception {
        MultiLabelClfDataSet trainSet = loadDataSet(config, "input.trainData");
        MultiLabelClfDataSet testSet = loadDataSet(config, "input.testData");

        CMLCRF cmlcrf = new CMLCRF(trainSet);
        double gaussianVariance = config.getDouble("train.gaussianVariance");
        // Pairwise features must be enabled before the loss is built from the model.
        cmlcrf.setConsiderPair(true);
        CRFLoss crfLoss = new CRFLoss(cmlcrf, trainSet, gaussianVariance);
        int maxIteration = config.getInt("train.maxIteration");
        crfLoss.setRegularizeAll(true);
        LBFGS optimizer = new LBFGS(crfLoss);
        optimizer.getTerminator().setMaxIteration(maxIteration);

        PluginPredictor<CMLCRF> predictor = createPredictor(config, cmlcrf);
        int progressInterval = config.getInt("train.showProgress.interval");
        // Validate before training: a zero interval would crash the modulo below mid-run.
        if (progressInterval <= 0) {
            throw new IllegalArgumentException(
                    "train.showProgress.interval must be positive, got " + progressInterval);
        }

        System.out.println("start training");
        int iteration = 0;
        do {
            optimizer.iterate();
            iteration++;
            if (iteration % progressInterval == 0) {
                printProgress(optimizer, iteration, predictor, trainSet, testSet);
            }
        } while (!optimizer.getTerminator().shouldTerminate());
        // Always report the final state, even when the last iteration was already printed.
        printProgress(optimizer, iteration, predictor, trainSet, testSet);
        System.out.println("training done!");

        String output = config.getString("output.folder");
        new File(output).mkdirs();
        cmlcrf.serialize(new File(output, MODEL_NAME));
        writePredictions(cmlcrf.predict(trainSet), new File(output, "train_predictions.txt"), "training");

        if (config.getBoolean("train.generateReports")) {
            App6.report(config, trainSet, "trainSet");
        }
    }

    /**
     * Loads the serialized model from {@code output.folder} and evaluates it on input.testData.
     *
     * <p>It prints the test performance, writes test-set predictions and always generates
     * reports for the test set.
     *
     * @param config reads input.testData, output.folder and predict.target, plus whatever
     *     {@link #report} reads
     */
    private static void test(Config config) throws Exception {
        MultiLabelClfDataSet testSet = loadDataSet(config, "input.testData");
        String output = config.getString("output.folder");
        CMLCRF cmlcrf = (CMLCRF) Serialization.deserialize(new File(output, MODEL_NAME));
        PluginPredictor<CMLCRF> predictor = createPredictor(config, cmlcrf);
        System.out.println("test performance:");
        System.out.println(new MLMeasures(predictor, testSet));
        writePredictions(cmlcrf.predict(testSet), new File(output, "test_predictions.txt"), "test");
        App6.report(config, testSet, "testSet");
    }

    /**
     * Writes the analysis reports for one data set into
     * {@code output.folder/reports_crf/<dataName>_reports}. The folder is cleared first.
     *
     * <p>The model is re-loaded from disk, so it must have been serialized beforehand. The
     * reports produced are report.csv, data_info.json, model_config.json, performance.json
     * and individual_performance.json.
     *
     * @param config reads output.folder, predict.target and report.classProbThreshold
     * @param dataSet data set to analyse
     * @param dataName name used in the report folder and in data_info.json
     */
    static void report(Config config, MultiLabelClfDataSet dataSet, String dataName) throws Exception {
        System.out.println("generating reports for data set " + dataName);
        String output = config.getString("output.folder");
        File analysisFolder = new File(new File(output, "reports_crf"), dataName + "_reports");
        analysisFolder.mkdirs();
        FileUtils.cleanDirectory(analysisFolder);

        CMLCRF crf = (CMLCRF) Serialization.deserialize(new File(output, MODEL_NAME));
        PluginPredictor<CMLCRF> predictor = createPredictor(config, crf);
        MLMeasures mlMeasures = new MLMeasures(predictor, dataSet);
        mlMeasures.getMacroAverage().setLabelTranslator(crf.getLabelTranslator());
        System.out.println("performance on dataset " + dataName);
        System.out.println(mlMeasures);

        writeSimpleCsv(config, crf, predictor, dataSet, new File(analysisFolder, "report.csv"));
        writeDataInfo(crf, dataSet, dataName, new File(analysisFolder, "data_info.json"));

        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.writeValue(new File(analysisFolder, "model_config.json"), config);
        objectMapper.writeValue(new File(analysisFolder, "performance.json"), mlMeasures);
        objectMapper.writeValue(new File(analysisFolder, "individual_performance.json"),
                mlMeasures.getMacroAverage());
        System.out.println("reports generated");
    }

    /** Loads the sequential-sparse multi-label data set whose path is stored under {@code key}. */
    private static MultiLabelClfDataSet loadDataSet(Config config, String key) throws Exception {
        return TRECFormat.loadMultiLabelClfDataSet(config.getString(key), DataSetType.ML_CLF_SEQ_SPARSE, true);
    }

    /**
     * Builds the decoding strategy named by {@code predict.target}.
     *
     * @throws IllegalArgumentException if the target is neither subsetAccuracy nor instanceFMeasure
     */
    private static PluginPredictor<CMLCRF> createPredictor(Config config, CMLCRF crf) {
        String predictTarget = config.getString("predict.target");
        switch (predictTarget) {
            case "subsetAccuracy":
                return new SubsetAccPredictor(crf);
            case "instanceFMeasure":
                return new InstanceF1Predictor(crf);
            default:
                throw new IllegalArgumentException("predict.target must be subsetAccuracy or instanceFMeasure");
        }
    }

    /** Prints the iteration number, the current objective and the train/test performance. */
    private static void printProgress(LBFGS optimizer, int iteration, PluginPredictor<CMLCRF> predictor,
                                      MultiLabelClfDataSet trainSet, MultiLabelClfDataSet testSet) {
        System.out.println("iteration " + iteration);
        System.out.println("training objective = " + optimizer.getTerminator().getLastValue());
        System.out.println("training performance:");
        System.out.println(new MLMeasures(predictor, trainSet));
        System.out.println("test performance:");
        System.out.println(new MLMeasures(predictor, testSet));
    }

    /**
     * Writes one prediction per line (via PrintUtil) to {@code file} in UTF-8 and logs the path.
     *
     * @param setDescription word used in the log message, e.g. "training" or "test"
     */
    private static void writePredictions(Object[] predictions, File file, String setDescription) throws Exception {
        FileUtils.writeStringToFile(file, PrintUtil.toMutipleLines(predictions), StandardCharsets.UTF_8);
        System.out.println("predictions on the " + setDescription + " set are written to " + file.getAbsolutePath());
    }

    /**
     * Writes the per-instance CRFInspector analysis for every data point to {@code csv},
     * overwriting any existing file.
     */
    private static void writeSimpleCsv(Config config, CMLCRF crf, PluginPredictor<CMLCRF> predictor,
                                       MultiLabelClfDataSet dataSet, File csv) throws Exception {
        double probThreshold = config.getDouble("report.classProbThreshold");
        // IntStream.range is ordered, so joining keeps rows in data-point order despite parallel().
        String content = IntStream.range(0, dataSet.getNumDataPoints())
                .parallel()
                .mapToObj(i -> CRFInspector.simplePredictionAnalysis(crf, predictor, dataSet, i, probThreshold))
                .collect(Collectors.joining());
        FileUtils.writeStringToFile(csv, content, StandardCharsets.UTF_8);
    }

    /**
     * Writes data_info.json. It compares the external labels known to the model with those
     * actually present in the data set, and also records the data set's label cardinality.
     */
    private static void writeDataInfo(CMLCRF crf, MultiLabelClfDataSet dataSet, String dataName, File jsonFile)
            throws Exception {
        Set<String> modelLabels = IntStream.range(0, crf.getNumClasses())
                .mapToObj(i -> crf.getLabelTranslator().toExtLabel(i))
                .collect(Collectors.toSet());
        Set<String> dataSetLabels = DataSetUtil.gatherLabels(dataSet).stream()
                .map(i -> dataSet.getLabelTranslator().toExtLabel((int) i))
                .collect(Collectors.toSet());
        Set<String> modelNotDataLabels = SetUtil.complement(modelLabels, dataSetLabels);
        Set<String> dataNotModelLabels = SetUtil.complement(dataSetLabels, modelLabels);

        // try-with-resources guarantees the file handle is released even if a write fails.
        try (JsonGenerator jsonGenerator = new JsonFactory().createGenerator(jsonFile, JsonEncoding.UTF8)) {
            jsonGenerator.writeStartObject();
            jsonGenerator.writeStringField("dataSet", dataName);
            jsonGenerator.writeNumberField("numClassesInModel", crf.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSet", dataSetLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelDataSetCombined", dataSet.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSetButNotModel", dataNotModelLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelButNotDataSet", modelNotDataLabels.size());
            jsonGenerator.writeArrayFieldStart("classesInDataSetButNotModel");
            for (String label : dataNotModelLabels) {
                jsonGenerator.writeString(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeArrayFieldStart("classesInModelButNotDataSet");
            for (String label : modelNotDataLabels) {
                jsonGenerator.writeString(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeNumberField("labelCardinality", dataSet.labelCardinality());
            jsonGenerator.writeEndObject();
        }
    }
}

