/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.clustering.bm.BMSelector;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.CBMOptimizer;
import edu.neu.ccs.pyramid.multilabel_classification.cbm.SparkCBMOptimizer;
import edu.neu.ccs.pyramid.util.MathUtil;

/**
 * Static helpers that seed a {@link CBM} optimizer's responsibilities before training.
 *
 * <p>Each entry point overwrites the optimizer's {@code gammas} matrix (indexed
 * {@code [dataPoint][component]}) together with its transpose {@code gammasT} (indexed
 * {@code [component][dataPoint]}). It then calls the optimizer's {@code mStep()}.
 *
 * <p>The {@link CBMOptimizer} and {@link SparkCBMOptimizer} overloads differ only in the
 * optimizer type. The actual filling is done by private helpers that work on the raw arrays,
 * so both optimizer flavours share one implementation.
 */
public class CBMInitializer {

    /**
     * Seeds responsibilities with the assignment returned by {@link BMSelector#selectGammas},
     * prints a progress line, and runs one M step.
     *
     * @param cbm       the model; only its number of components is read here
     * @param dataSet   training data whose class count and multi-labels are passed to the selector
     * @param optimizer optimizer whose {@code gammas} and {@code gammasT} are overwritten
     */
    public static void initialize(CBM cbm, MultiLabelClfDataSet dataSet, CBMOptimizer optimizer) {
        fillWithSelectedGammas(cbm, dataSet, optimizer.gammas, optimizer.gammasT);
        System.out.println("performing M step");
        optimizer.mStep();
    }

    /**
     * Seeds each data point with an independent random distribution over the components
     * (values drawn from {@link Math#random()} and normalized to sum to 1), then runs one M step.
     *
     * @param cbm       the model; only its number of components is read here
     * @param dataSet   training data; only its number of data points is read here
     * @param optimizer optimizer whose {@code gammas} and {@code gammasT} are overwritten
     */
    public static void randInitialize(CBM cbm, MultiLabelClfDataSet dataSet, CBMOptimizer optimizer) {
        fillWithRandomGammas(cbm.getNumComponents(), dataSet.getNumDataPoints(),
                optimizer.gammas, optimizer.gammasT);
        optimizer.mStep();
    }

    /**
     * {@link SparkCBMOptimizer} variant of
     * {@link #initialize(CBM, MultiLabelClfDataSet, CBMOptimizer)}.
     *
     * @param cbm       the model; only its number of components is read here
     * @param dataSet   training data whose class count and multi-labels are passed to the selector
     * @param optimizer optimizer whose {@code gammas} and {@code gammasT} are overwritten
     */
    public static void initialize(CBM cbm, MultiLabelClfDataSet dataSet, SparkCBMOptimizer optimizer) {
        fillWithSelectedGammas(cbm, dataSet, optimizer.gammas, optimizer.gammasT);
        System.out.println("performing M step");
        optimizer.mStep();
    }

    /**
     * {@link SparkCBMOptimizer} variant of
     * {@link #randInitialize(CBM, MultiLabelClfDataSet, CBMOptimizer)}.
     *
     * @param cbm       the model; only its number of components is read here
     * @param dataSet   training data; only its number of data points is read here
     * @param optimizer optimizer whose {@code gammas} and {@code gammasT} are overwritten
     */
    public static void randInitialize(CBM cbm, MultiLabelClfDataSet dataSet, SparkCBMOptimizer optimizer) {
        fillWithRandomGammas(cbm.getNumComponents(), dataSet.getNumDataPoints(),
                optimizer.gammas, optimizer.gammasT);
        optimizer.mStep();
    }

    /**
     * Seeds every data point with the uniform distribution {@code 1/K} over the {@code K}
     * components, then runs one M step.
     *
     * @param cbm       the model; only its number of components is read here
     * @param dataSet   training data; only its number of data points is read here
     * @param optimizer optimizer whose {@code gammas} and {@code gammasT} are overwritten
     */
    public static void avgInitialize(CBM cbm, MultiLabelClfDataSet dataSet, CBMOptimizer optimizer) {
        fillWithUniformGammas(cbm.getNumComponents(), dataSet.getNumDataPoints(),
                optimizer.gammas, optimizer.gammasT);
        optimizer.mStep();
    }

    /** Copies the {@link BMSelector} assignment into {@code gammas} and its transpose. */
    private static void fillWithSelectedGammas(CBM cbm, MultiLabelClfDataSet dataSet,
                                               double[][] gammas, double[][] gammasT) {
        int numComponents = cbm.getNumComponents();
        double[][] selected = BMSelector.selectGammas(
                dataSet.getNumClasses(), dataSet.getMultiLabels(), numComponents);
        int numDataPoints = dataSet.getNumDataPoints();
        for (int i = 0; i < numDataPoints; ++i) {
            for (int k = 0; k < numComponents; ++k) {
                setGamma(gammas, gammasT, i, k, selected[i][k]);
            }
        }
    }

    /** Gives every data point its own random distribution over the components. */
    private static void fillWithRandomGammas(int numComponents, int numDataPoints,
                                             double[][] gammas, double[][] gammasT) {
        for (int i = 0; i < numDataPoints; ++i) {
            double[] dist = new double[numComponents];
            for (int k = 0; k < numComponents; ++k) {
                dist[k] = Math.random();
            }
            // Normalize so each data point's responsibilities sum to 1.
            double sum = MathUtil.arraySum(dist);
            for (int k = 0; k < numComponents; ++k) {
                setGamma(gammas, gammasT, i, k, dist[k] / sum);
            }
        }
    }

    /** Assigns the same weight {@code 1/numComponents} to every (data point, component) pair. */
    private static void fillWithUniformGammas(int numComponents, int numDataPoints,
                                              double[][] gammas, double[][] gammasT) {
        double avgValue = 1.0 / (double) numComponents;
        for (int i = 0; i < numDataPoints; ++i) {
            for (int k = 0; k < numComponents; ++k) {
                setGamma(gammas, gammasT, i, k, avgValue);
            }
        }
    }

    /** Writes one responsibility into both {@code gammas[i][k]} and its transpose {@code gammasT[k][i]}. */
    private static void setGamma(double[][] gammas, double[][] gammasT, int i, int k, double value) {
        gammas[i][k] = value;
        gammasT[k][i] = value;
    }
}

