/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.dataset;

import edu.neu.ccs.pyramid.util.MathUtil;
import java.util.Arrays;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * A dense {@code numDataPoints x numClasses} matrix of weights.
 *
 * <p>Values are stored twice: once indexed by data point, then class, and once transposed,
 * indexed by class, then data point. This lets both a row (one data point across all classes)
 * and a column (one class across all data points) be returned as a plain array without copying.
 *
 * <p>The two views stay consistent only when values are written through
 * {@link #setProbability(int, int, double)} or {@link #normalize()}. The arrays returned by
 * {@link #getProbsForData(int)} and {@link #getProbsForClass(int)} are the live internal
 * storage, not copies. Writing into them updates one view but not the other.
 */
public class WeightMatrix {
    private final int numDataPoints;
    private final int numClasses;
    /** Data-point-major view: {@code dataClass[dataPointIndex][classIndex]}. */
    private final double[][] dataClass;
    /** Transposed view of {@link #dataClass}: {@code classData[classIndex][dataPointIndex]}. */
    private final double[][] classData;

    /**
     * Creates a matrix with every entry set to zero.
     *
     * @param numDataPoints number of data points (rows)
     * @param numClasses number of classes (columns)
     * @throws NegativeArraySizeException if either dimension is negative
     */
    public WeightMatrix(int numDataPoints, int numClasses) {
        this.numDataPoints = numDataPoints;
        this.numClasses = numClasses;
        this.dataClass = new double[numDataPoints][numClasses];
        this.classData = new double[numClasses][numDataPoints];
    }

    /**
     * Sets one entry, updating both internal views so they stay in sync.
     *
     * @param dataPointIndex row index, in {@code [0, numDataPoints)}
     * @param classIndex column index, in {@code [0, numClasses)}
     * @param prob the value to store
     * @throws ArrayIndexOutOfBoundsException if either index is out of range
     */
    public void setProbability(int dataPointIndex, int classIndex, double prob) {
        this.dataClass[dataPointIndex][classIndex] = prob;
        this.classData[classIndex][dataPointIndex] = prob;
    }

    /**
     * Returns the values of one data point across all classes.
     *
     * @param dataPointIndex row index
     * @return the live internal row (length {@code numClasses}); do not modify it, or the
     *     transposed view will no longer match
     */
    public double[] getProbsForData(int dataPointIndex) {
        return this.dataClass[dataPointIndex];
    }

    /**
     * Returns the values of one class across all data points.
     *
     * @param classIndex column index
     * @return the live internal column (length {@code numDataPoints}); do not modify it, or
     *     the data-point view will no longer match
     */
    public double[] getProbsForClass(int classIndex) {
        return this.classData[classIndex];
    }

    /**
     * Divides every entry by the sum of all entries, so that the whole matrix sums to 1.
     *
     * <p>If that total is zero (for example, a freshly constructed matrix), the matrix is
     * left unchanged. Dividing would otherwise fill every cell with {@code NaN}.
     */
    public void normalize() {
        double sum = Arrays.stream(this.dataClass)
                .parallel()
                .mapToDouble(row -> Arrays.stream(row).sum())
                .sum();
        if (sum == 0.0) {
            // 0/0 would turn every cell into NaN; keep the (all-zero) matrix as is instead.
            return;
        }
        // Rows are processed in parallel; each task only touches cells of its own row i
        // (and the matching column slots classData[*][i]), so no two tasks write the same cell.
        IntStream.range(0, this.numDataPoints).parallel().forEach(i -> {
            for (int k = 0; k < this.numClasses; k++) {
                this.setProbability(i, k, this.dataClass[i][k] / sum);
            }
        });
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("WeightMatrix{");
        sb.append("dataClass=").append(Arrays.deepToString(this.dataClass));
        sb.append(", classData=").append(Arrays.deepToString(this.classData));
        sb.append('}');
        return sb.toString();
    }
}

