/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.clustering.bm;

import edu.neu.ccs.pyramid.clustering.bm.BM;
import edu.neu.ccs.pyramid.clustering.bm.BMTrainer;
import edu.neu.ccs.pyramid.dataset.DataSet;
import edu.neu.ccs.pyramid.dataset.DataSetBuilder;
import edu.neu.ccs.pyramid.dataset.Density;
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.util.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Static helpers that train several {@link BM} models with different run indices and keep the
 * one whose trainer reports the smallest objective.
 */
public class BMSelector {
    private static final Logger logger = LogManager.getLogger();

    /** Number of training runs used by the label-based methods when no count is given. */
    private static final int DEFAULT_NUM_RUNS = 10;

    /**
     * Trains {@code numRuns} models on {@code dataSet} and returns the one with the lowest
     * objective.
     *
     * <p>Run {@code i} (0-based) passes {@code i} as the third {@link BMTrainer} constructor
     * argument (presumably a random seed — TODO confirm). Ties keep the earliest run because the
     * comparison is strict.
     *
     * @param dataSet     data to cluster
     * @param numClusters number of clusters passed to each trainer
     * @param numRuns     number of independent training runs
     * @return the model from the run with the smallest objective, or {@code null} if
     *         {@code numRuns <= 0} or no run produced an objective strictly below
     *         {@code +Infinity} (e.g. every objective was NaN)
     */
    public static BM select(DataSet dataSet, int numClusters, int numRuns) {
        // log4j2 checks the level inside debug(String); a guard is unnecessary for a constant.
        logger.debug("start method select");
        BM best = null;
        double bestObjective = Double.POSITIVE_INFINITY;
        for (int run = 0; run < numRuns; ++run) {
            BMTrainer trainer = new BMTrainer(dataSet, numClusters, run);
            BM bm = trainer.train();
            double objective = trainer.getObjective();
            // NaN never compares less-than, so a NaN objective is never selected.
            if (objective < bestObjective) {
                bestObjective = objective;
                best = bm;
            }
        }
        logger.debug("finish method select");
        return best;
    }

    /**
     * Same selection as {@link #select(DataSet, int, int)}, but returns the winning trainer so
     * callers can inspect its state after training (e.g. {@code gammas}, {@code getBm()}).
     *
     * @param dataSet     data to cluster
     * @param numClusters number of clusters passed to each trainer
     * @param numRuns     number of independent training runs
     * @return the trained trainer with the smallest objective, or {@code null} if
     *         {@code numRuns <= 0} or no run produced an objective strictly below
     *         {@code +Infinity}
     */
    public static BMTrainer selectTrainer(DataSet dataSet, int numClusters, int numRuns) {
        BMTrainer best = null;
        double bestObjective = Double.POSITIVE_INFINITY;
        for (int run = 0; run < numRuns; ++run) {
            BMTrainer trainer = new BMTrainer(dataSet, numClusters, run);
            // Called for its effect on the trainer; the returned model is not needed here.
            trainer.train();
            double objective = trainer.getObjective();
            if (objective < bestObjective) {
                bestObjective = objective;
                best = trainer;
            }
        }
        return best;
    }

    /**
     * Encodes the label sets as a binary dataset and returns the {@code gammas} matrix of the
     * best trainer over {@value #DEFAULT_NUM_RUNS} runs.
     *
     * @param numClasses  number of distinct labels (feature columns)
     * @param multiLabels one label set per data point
     * @param numClusters number of clusters
     * @return {@code gammas} of the selected trainer
     * @throws IllegalStateException if no run produced an objective below {@code +Infinity}
     */
    public static double[][] selectGammas(int numClasses, MultiLabel[] multiLabels, int numClusters) {
        return selectGammas(numClasses, multiLabels, numClusters, DEFAULT_NUM_RUNS);
    }

    /**
     * Like {@link #selectGammas(int, MultiLabel[], int)} with a configurable number of runs.
     *
     * @param numClasses  number of distinct labels (feature columns)
     * @param multiLabels one label set per data point
     * @param numClusters number of clusters
     * @param numRuns     number of training runs; must be at least 1
     * @return {@code gammas} of the selected trainer
     * @throws IllegalArgumentException if {@code numRuns < 1}
     * @throws IllegalStateException    if no run produced an objective below {@code +Infinity}
     */
    public static double[][] selectGammas(int numClasses, MultiLabel[] multiLabels, int numClusters, int numRuns) {
        return selectTrainerOnLabels(numClasses, multiLabels, numClusters, numRuns).gammas;
    }

    /**
     * Encodes the label sets as a binary dataset and returns both the model ({@code getBm()})
     * and the {@code gammas} of the best trainer over {@value #DEFAULT_NUM_RUNS} runs.
     *
     * @param numClasses  number of distinct labels (feature columns)
     * @param multiLabels one label set per data point
     * @param numClusters number of clusters
     * @return pair of (selected model, its trainer's {@code gammas})
     * @throws IllegalStateException if no run produced an objective below {@code +Infinity}
     */
    public static Pair<BM, double[][]> selectAll(int numClasses, MultiLabel[] multiLabels, int numClusters) {
        return selectAll(numClasses, multiLabels, numClusters, DEFAULT_NUM_RUNS);
    }

    /**
     * Like {@link #selectAll(int, MultiLabel[], int)} with a configurable number of runs.
     *
     * @param numClasses  number of distinct labels (feature columns)
     * @param multiLabels one label set per data point
     * @param numClusters number of clusters
     * @param numRuns     number of training runs; must be at least 1
     * @return pair of (selected model, its trainer's {@code gammas})
     * @throws IllegalArgumentException if {@code numRuns < 1}
     * @throws IllegalStateException    if no run produced an objective below {@code +Infinity}
     */
    public static Pair<BM, double[][]> selectAll(int numClasses, MultiLabel[] multiLabels, int numClusters, int numRuns) {
        BMTrainer trainer = selectTrainerOnLabels(numClasses, multiLabels, numClusters, numRuns);
        Pair<BM, double[][]> pair = new Pair<>();
        pair.setFirst(trainer.getBm());
        pair.setSecond(trainer.gammas);
        return pair;
    }

    /** Builds the label dataset and runs {@link #selectTrainer}, failing loudly instead of returning null. */
    private static BMTrainer selectTrainerOnLabels(int numClasses, MultiLabel[] multiLabels, int numClusters, int numRuns) {
        if (numRuns < 1) {
            throw new IllegalArgumentException("numRuns must be at least 1, got " + numRuns);
        }
        DataSet dataSet = toBinaryDataSet(numClasses, multiLabels);
        BMTrainer trainer = selectTrainer(dataSet, numClusters, numRuns);
        if (trainer == null) {
            // Only reachable when every run's objective was NaN or +Infinity.
            throw new IllegalStateException("none of the " + numRuns
                    + " BM training runs produced an objective below +Infinity");
        }
        return trainer;
    }

    /**
     * Encodes label sets as a dataset with one row per element of {@code multiLabels} and
     * {@code numClasses} feature columns, setting feature {@code label} of row {@code i} to 1.0
     * for every matched label of {@code multiLabels[i]}.
     */
    private static DataSet toBinaryDataSet(int numClasses, MultiLabel[] multiLabels) {
        DataSet dataSet = DataSetBuilder.getBuilder()
                .numDataPoints(multiLabels.length)
                .numFeatures(numClasses)
                .density(Density.SPARSE_RANDOM)
                .build();
        for (int i = 0; i < multiLabels.length; ++i) {
            for (int label : multiLabels[i].getMatchedLabels()) {
                dataSet.setFeatureValue(i, label, 1.0);
            }
        }
        return dataSet;
    }
}

