/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.imllr;

import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.imllr.IMLLogisticLoss;
import edu.neu.ccs.pyramid.multilabel_classification.imllr.IMLLogisticRegression;
import edu.neu.ccs.pyramid.optimization.LBFGS;
import java.util.List;

/**
 * Trains an {@link IMLLogisticRegression} by minimizing an {@link IMLLogisticLoss} with {@link LBFGS}.
 *
 * <p>Configure instances through {@link #getBuilder()}. A trainer created with the no-arg
 * constructor uses the default values documented on each field.
 */
public class IMLLogisticTrainer {
    /** Variance passed to the {@link IMLLogisticLoss} constructor; defaults to 1.0. */
    private double gaussianPriorVariance = 1.0;
    /** Value given to the LBFGS terminator's relative epsilon; defaults to 1.0. */
    private double epsilon = 1.0;
    /** History size passed to the LBFGS optimizer; defaults to 5. */
    private int history = 5;

    /**
     * Creates a new builder whose settings start from the same defaults as this class.
     *
     * @return a fresh {@link Builder}
     */
    public static Builder getBuilder() {
        return new Builder();
    }

    /**
     * Builds a new model for {@code dataset} and optimizes it with LBFGS.
     *
     * <p>The model is sized from the dataset's class and feature counts. It also receives the
     * dataset's feature list and label translator. It is then optimized in place through an
     * {@link IMLLogisticLoss} wrapping both the model and the dataset.
     *
     * @param dataset     the training data; supplies dimensions, feature list and label translator
     * @param assignments label assignments forwarded unchanged to the model constructor
     * @return the model instance that LBFGS optimized
     */
    public IMLLogisticRegression train(MultiLabelClfDataSet dataset, List<MultiLabel> assignments) {
        final int numClasses = dataset.getNumClasses();
        final int numFeatures = dataset.getNumFeatures();
        final IMLLogisticRegression model = new IMLLogisticRegression(numClasses, numFeatures, assignments);
        model.setFeatureList(dataset.getFeatureList());
        model.setLabelTranslator(dataset.getLabelTranslator());

        // The loss holds a reference to the model, so optimizing the loss updates the model itself.
        final IMLLogisticLoss loss = new IMLLogisticLoss(model, dataset, this.gaussianPriorVariance);
        final LBFGS optimizer = new LBFGS(loss);
        optimizer.getTerminator().setRelativeEpsilon(this.epsilon);
        optimizer.setHistory(this.history);
        optimizer.optimize();

        return model;
    }

    /**
     * Fluent configuration for {@link IMLLogisticTrainer}. Every setter returns this builder so
     * calls can be chained. Values that are never set keep the defaults shown on the fields.
     */
    public static class Builder {
        /** Defaults to 1.0. */
        private double gaussianPriorVariance = 1.0;
        /** Defaults to 1.0. */
        private double epsilon = 1.0;
        /** Defaults to 5. */
        private int history = 5;

        /**
         * Sets the variance handed to the loss function.
         *
         * @param gaussianPriorVariance the prior variance to use
         * @return this builder
         */
        public Builder setGaussianPriorVariance(double gaussianPriorVariance) {
            this.gaussianPriorVariance = gaussianPriorVariance;
            return this;
        }

        /**
         * Sets the relative epsilon for the optimizer's terminator.
         *
         * @param epsilon the relative epsilon to use
         * @return this builder
         */
        public Builder setEpsilon(double epsilon) {
            this.epsilon = epsilon;
            return this;
        }

        /**
         * Sets the optimizer's history size.
         *
         * @param history the history size to use
         * @return this builder
         */
        public Builder setHistory(int history) {
            this.history = history;
            return this;
        }

        /**
         * Creates a trainer carrying this builder's current settings.
         *
         * @return a newly configured {@link IMLLogisticTrainer}
         */
        public IMLLogisticTrainer build() {
            IMLLogisticTrainer configured = new IMLLogisticTrainer();
            configured.history = this.history;
            configured.epsilon = this.epsilon;
            configured.gaussianPriorVariance = this.gaussianPriorVariance;
            return configured;
        }
    }
}

