/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.multi_label_logistic_regression;

import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.multi_label_logistic_regression.MLLogisticLoss;
import edu.neu.ccs.pyramid.multilabel_classification.multi_label_logistic_regression.MLLogisticRegression;
import edu.neu.ccs.pyramid.optimization.LBFGS;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Fits {@link MLLogisticRegression} models by minimizing an {@link MLLogisticLoss} with
 * {@link LBFGS}.
 *
 * <p>Configure instances through {@link #getBuilder()}. A trainer created directly with the
 * implicit no-arg constructor uses a Gaussian prior variance of {@code 1.0}.
 */
public class MLLogisticTrainer {
    /** Shared default for the Gaussian prior variance (trainer field and builder field). */
    private static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0;

    /**
     * Variance passed to {@link MLLogisticLoss}.
     *
     * <p>NOTE(review): presumably the variance of a Gaussian prior on the weights, i.e. an
     * inverse regularization strength. Confirm against MLLogisticLoss.
     */
    private double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;

    /**
     * Returns a new builder whose Gaussian prior variance starts at {@code 1.0}.
     *
     * @return a fresh {@link Builder}
     */
    public static Builder getBuilder() {
        return new Builder();
    }

    /**
     * Trains a model using the label sets gathered from {@code dataset} itself.
     *
     * <p>The collection returned by {@link DataSetUtil#gatherMultiLabels} is copied into a new
     * {@link List}. The copy is then handed to {@link #train(MultiLabelClfDataSet, List)}.
     *
     * @param dataset training data
     * @return the trained model
     */
    public MLLogisticRegression train(MultiLabelClfDataSet dataset) {
        final List<MultiLabel> gatheredLabelSets = DataSetUtil.gatherMultiLabels(dataset)
                .stream()
                .collect(Collectors.toList());
        return train(dataset, gatheredLabelSets);
    }

    /**
     * Builds a model sized to {@code dataset}, then optimizes it with L-BFGS.
     *
     * <p>Before optimization, the model receives the dataset's feature list and label
     * translator, and its feature-extraction flag is set to {@code false}. The model is mutated
     * in place by the optimizer, and the same instance is returned.
     *
     * @param dataset     training data; supplies class/feature counts and metadata
     * @param assignments label sets passed to the {@link MLLogisticRegression} constructor
     *                    (presumably the candidate label combinations; verify against that class)
     * @return the trained model
     */
    public MLLogisticRegression train(MultiLabelClfDataSet dataset, List<MultiLabel> assignments) {
        final int classCount = dataset.getNumClasses();
        final int featureCount = dataset.getNumFeatures();
        final MLLogisticRegression model = new MLLogisticRegression(classCount, featureCount, assignments);

        // Copy dataset metadata onto the model; the order is kept identical to the original.
        model.setFeatureList(dataset.getFeatureList());
        model.setLabelTranslator(dataset.getLabelTranslator());
        model.setFeatureExtraction(false);

        final MLLogisticLoss objective = new MLLogisticLoss(model, dataset, gaussianPriorVariance);
        new LBFGS(objective).optimize();
        return model;
    }

    /** Fluent builder for {@link MLLogisticTrainer}. */
    public static class Builder {
        private double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;

        /**
         * Sets the variance that the built trainer will pass to {@link MLLogisticLoss}.
         *
         * @param gaussianPriorVariance the variance value (no validation is performed)
         * @return this builder, for chaining
         */
        public Builder setGaussianPriorVariance(double gaussianPriorVariance) {
            this.gaussianPriorVariance = gaussianPriorVariance;
            return this;
        }

        /**
         * Creates a trainer carrying this builder's current settings.
         *
         * @return a new {@link MLLogisticTrainer}
         */
        public MLLogisticTrainer build() {
            final MLLogisticTrainer result = new MLLogisticTrainer();
            result.gaussianPriorVariance = this.gaussianPriorVariance;
            return result;
        }
    }
}

