/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.classification.logistic_regression;

import edu.neu.ccs.pyramid.dataset.ClfDataSet;
import edu.neu.ccs.pyramid.dataset.DataSet;
import java.util.stream.IntStream;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * L2-regularized ("ridge") binary logistic loss with per-row loss weights.
 *
 * <p>The objective is
 *
 * <pre>
 *   f(w) = 0.5 * (w . w) + sum_i C_i * log(1 + exp(-y_i * z_i)),
 *   z_i  = w[0] + x_i . w[1..numFeatures]
 * </pre>
 *
 * where {@code C_i = regularization.get(i)} and {@code y_i} is the row label mapped to -1/+1 (see
 * {@link #changeLabels}). Index 0 of every parameter-space vector is the bias, and indices
 * 1..numFeatures are the feature weights. The bias is included in the quadratic penalty.
 *
 * <p>The methods share internal buffers and must be called in this order for a given {@code w}:
 * {@link #fun} (caches X*w), then {@link #grad} (consumes that cache, overwrites it, and caches
 * the Hessian diagonal), then any number of {@link #Hv} calls (which read that diagonal). Because
 * of this shared mutable state, an instance must not be used from several threads at once.
 */
public class RidgeBinaryLogisticLoss {
    /** Per-row loss weights C_i; must have at least {@code numRows} entries. */
    private final Vector regularization;
    /** Length-numRows scratch buffer: X*w after fun(), gradient coefficients after grad(). */
    private final Vector scores;
    /** Length-numRows buffer of sigma_i * (1 - sigma_i), written by grad() and read by Hv(). */
    private final Vector diagonals;
    private final DataSet dataSet;
    /** Row labels mapped to -1/+1. */
    private final int[] labels;
    private final int numRows;
    /** Number of parameters: numFeatures + 1 (leading bias column). */
    private final int numColumns;

    /**
     * Returns the number of parameters.
     *
     * @return the number of features plus one for the bias term
     */
    public int getNumColumns() {
        return this.numColumns;
    }

    /**
     * Builds the loss for a binary data set.
     *
     * @param clfDataSet data set with exactly two classes; label 0 is treated as -1, any other as +1
     * @param regularization per-row loss weights C_i, with at least one entry per data point
     * @throws IllegalArgumentException if the data set is not binary or {@code regularization} has
     *     fewer entries than there are data points
     */
    public RidgeBinaryLogisticLoss(ClfDataSet clfDataSet, Vector regularization) {
        this.dataSet = clfDataSet;
        this.numRows = this.dataSet.getNumDataPoints();
        this.numColumns = this.dataSet.getNumFeatures() + 1;
        // Fail fast: fun()/grad()/Hv() read regularization.get(i) for every row i.
        if (regularization.size() < this.numRows) {
            throw new IllegalArgumentException("regularization has " + regularization.size()
                    + " entries but the data set has " + this.numRows + " data points");
        }
        this.scores = new DenseVector(this.numRows);
        this.diagonals = new DenseVector(this.numRows);
        this.regularization = regularization;
        this.labels = RidgeBinaryLogisticLoss.changeLabels(clfDataSet);
    }

    /**
     * Returns {@code bias + x_rowIndex . weights}, one entry of X*v for the bias-augmented X.
     */
    private double rowDot(int rowIndex, double bias, Vector weights) {
        return bias + this.dataSet.getRow(rowIndex).dot(weights);
    }

    /**
     * Computes {@code Xv = X * v}, where X is the data matrix with a leading all-ones bias column.
     *
     * @param v parameter-space vector of length numColumns (index 0 = bias)
     * @param Xv output vector of length numRows, overwritten entry by entry
     */
    private void Xv(Vector v, Vector Xv) {
        // Split v into bias and feature weights once per call instead of once per row.
        final double bias = v.get(0);
        final Vector weights = v.viewPart(1, v.size() - 1);
        if (Xv.isDense()) {
            IntStream.range(0, this.numRows).parallel().forEach(i -> Xv.set(i, this.rowDot(i, bias, weights)));
        } else {
            // Non-dense outputs are written sequentially; presumably their set() is not safe for
            // concurrent writers.
            for (int i = 0; i < this.numRows; ++i) {
                Xv.set(i, this.rowDot(i, bias, weights));
            }
        }
    }

    /**
     * Returns column {@code columnIndex} of the bias-augmented X dotted with {@code vector}.
     * Column 0 is the all-ones bias column, so its dot product is the plain sum.
     */
    private double columnDot(int columnIndex, Vector vector) {
        if (columnIndex == 0) {
            return vector.zSum();
        }
        return this.dataSet.getColumn(columnIndex - 1).dot(vector);
    }

    /**
     * Computes {@code XTv = X^T * v} for the bias-augmented data matrix.
     *
     * @param v row-space vector of length numRows
     * @param XTv output vector of length numColumns, overwritten entry by entry
     */
    private void XTv(Vector v, Vector XTv) {
        if (XTv.isDense()) {
            IntStream.range(0, this.numColumns).parallel().forEach(j -> XTv.set(j, this.columnDot(j, v)));
        } else {
            for (int j = 0; j < this.numColumns; ++j) {
                XTv.set(j, this.columnDot(j, v));
            }
        }
    }

    /**
     * Evaluates the objective at {@code w} and caches {@code X*w} for a following {@link #grad}.
     *
     * @param w parameter vector of length numColumns (index 0 = bias)
     * @return {@code 0.5 * w.w + sum_i C_i * log(1 + exp(-y_i z_i))}
     */
    public double fun(Vector w) {
        this.Xv(w, this.scores);
        double f = 0.5 * w.dot(w);
        for (int i = 0; i < this.numRows; ++i) {
            double yz = (double) this.labels[i] * this.scores.get(i);
            // Split by sign so exp() never overflows. log1p keeps precision when exp(-|yz|) is
            // tiny, where log(1 + tiny) would round to exactly 0.
            double loss = yz >= 0.0
                    ? Math.log1p(Math.exp(-yz))
                    : -yz + Math.log1p(Math.exp(yz));
            f += this.regularization.get(i) * loss;
        }
        return f;
    }

    /**
     * Computes the gradient {@code g = w + X^T c}, with {@code c_i = C_i * (sigma_i - 1) * y_i} and
     * {@code sigma_i = 1 / (1 + exp(-y_i z_i))}.
     *
     * <p>Must be called after {@link #fun} with the same {@code w}, because it reads the cached
     * {@code z = X*w}. It overwrites that cache with {@code c} and stores
     * {@code sigma_i * (1 - sigma_i)} for use by {@link #Hv}.
     *
     * @param w parameter vector of length numColumns
     * @param g output gradient of length numColumns
     */
    public void grad(Vector w, Vector g) {
        for (int i = 0; i < this.numRows; ++i) {
            double y = this.labels[i];
            double sigma = 1.0 / (1.0 + Math.exp(-y * this.scores.get(i)));
            this.diagonals.set(i, sigma * (1.0 - sigma));
            this.scores.set(i, this.regularization.get(i) * (sigma - 1.0) * y);
        }
        this.XTv(this.scores, g);
        // Add the derivative of the 0.5 * w.w penalty.
        for (int j = 0; j < this.numColumns; ++j) {
            g.set(j, w.get(j) + g.get(j));
        }
    }

    /**
     * Computes the Hessian-vector product {@code Hs = s + X^T D X s}, with
     * {@code D_ii = C_i * sigma_i * (1 - sigma_i)}.
     *
     * <p>Requires a prior {@link #grad} call, which fills the {@code sigma_i * (1 - sigma_i)} buffer.
     *
     * @param s direction vector of length numColumns
     * @param Hs output vector of length numColumns
     */
    public void Hv(Vector s, Vector Hs) {
        Vector wa = new DenseVector(this.numRows);
        this.Xv(s, wa);
        for (int i = 0; i < this.numRows; ++i) {
            wa.set(i, this.regularization.get(i) * this.diagonals.get(i) * wa.get(i));
        }
        this.XTv(wa, Hs);
        // Identity contribution from the 0.5 * w.w penalty.
        for (int j = 0; j < this.numColumns; ++j) {
            Hs.set(j, s.get(j) + Hs.get(j));
        }
    }

    /**
     * Maps binary class labels to -1/+1: label 0 becomes -1, and any other label becomes +1.
     *
     * @param clfDataSet data set whose labels are converted
     * @return a new array with one -1/+1 entry per data point
     * @throws IllegalArgumentException if the data set does not have exactly two classes
     */
    public static int[] changeLabels(ClfDataSet clfDataSet) {
        int numClasses = clfDataSet.getNumClasses();
        if (numClasses != 2) {
            throw new IllegalArgumentException("clfDataSet.getNumClasses()!=2, got " + numClasses);
        }
        int[] labels = clfDataSet.getLabels();
        int[] changed = new int[labels.length];
        for (int i = 0; i < labels.length; ++i) {
            changed[i] = labels[i] == 0 ? -1 : 1;
        }
        return changed;
    }
}

