/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.multilabel_classification.cbm;

import edu.neu.ccs.pyramid.dataset.SerializableVector;
import edu.neu.ccs.pyramid.util.MathUtil;
import edu.neu.ccs.pyramid.util.Vectors;
import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Objects;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * A binary logistic-regression scorer whose linear score is shifted by a per-component offset.
 *
 * <p>All parameters are stored in a single weight vector of length
 * {@code numFeatures + numComponents + 1}, laid out as:
 * <pre>
 *   [ feature weights (numFeatures) | component weights (numComponents) | bias (1) ]
 * </pre>
 * For component {@code k} and feature vector {@code x}, the augmented score is
 * {@code featureWeights . x + bias + componentWeights[k]}. This is the same as a logistic
 * regression over {@code x} concatenated with a one-hot indicator of component {@code k}.
 *
 * <p>Instances are mutable (see {@link #setWeights(Vector)}) and perform no synchronization.
 */
public class AugmentedLR implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Number of input features; the first {@code numFeatures} weights apply to them. */
    private int numFeatures;
    /** Number of components; each has its own additive offset weight. */
    private int numComponents;
    /**
     * Parameter vector in the layout described on the class. It is transient because it is
     * written by hand through a {@link SerializableVector} wrapper in {@link #writeObject}
     * (presumably because the Mahout vector itself is not {@code Serializable}).
     */
    private transient Vector weights;

    /**
     * Creates a model with all weights (features, components and bias) initialised by
     * {@code new DenseVector(size)}.
     *
     * @param numFeatures   number of input features, must be {@code >= 0}
     * @param numComponents number of components, must be {@code >= 0}
     * @throws IllegalArgumentException if either count is negative
     */
    public AugmentedLR(int numFeatures, int numComponents) {
        if (numFeatures < 0 || numComponents < 0) {
            throw new IllegalArgumentException(
                    "numFeatures and numComponents must be non-negative, got numFeatures="
                            + numFeatures + ", numComponents=" + numComponents);
        }
        this.numFeatures = numFeatures;
        this.numComponents = numComponents;
        this.weights = new DenseVector(expectedSize(numFeatures, numComponents));
    }

    /** Returns the total parameter count: one per feature, one per component, plus the bias. */
    private static int expectedSize(int numFeatures, int numComponents) {
        return numFeatures + numComponents + 1;
    }

    /** Returns the number of components this model scores. */
    public int getNumComponents() {
        return this.numComponents;
    }

    /**
     * Returns the full parameter vector itself, not a copy. Writes through it modify the model.
     */
    Vector getAllWeights() {
        return this.weights;
    }

    /**
     * Replaces the parameter vector. The reference is stored as-is (no defensive copy), so the
     * caller and this model share it afterwards.
     *
     * @param weights new parameters in the class-level layout
     * @throws NullPointerException     if {@code weights} is null
     * @throws IllegalArgumentException if its size is not {@code numFeatures + numComponents + 1}
     */
    void setWeights(Vector weights) {
        Objects.requireNonNull(weights, "weights");
        int expected = expectedSize(this.numFeatures, this.numComponents);
        if (weights.size() != expected) {
            // A wrong-sized vector would make the bias (read at a fixed index) and the
            // fixed-offset feature/component views disagree, silently corrupting scores.
            throw new IllegalArgumentException(
                    "weights must have size " + expected + ", got " + weights.size());
        }
        this.weights = weights;
    }

    /** Returns the offset weight of component {@code k}, stored right after the feature weights. */
    private double getWeightForComponent(int k) {
        return this.weights.get(this.numFeatures + k);
    }

    /** Returns the bias, stored in the last slot of the parameter vector. */
    private double getBias() {
        return this.weights.get(this.numFeatures + this.numComponents);
    }

    /**
     * Returns the first {@code numFeatures} weights, obtained via Mahout {@code viewPart}
     * (a view that shares storage with the parameter vector).
     */
    Vector featureWeights() {
        return this.weights.viewPart(0, this.numFeatures);
    }

    /** Returns the {@code numComponents} component offset weights via {@code viewPart}. */
    Vector componentWeights() {
        return this.weights.viewPart(this.numFeatures, this.numComponents);
    }

    /** Returns every weight except the trailing bias (features then components) via {@code viewPart}. */
    Vector getWeightsWithoutBias() {
        return this.weights.viewPart(0, this.numFeatures + this.numComponents);
    }

    /**
     * Returns the component-independent part of the score: {@code Vectors.dot(featureWeights, x)}
     * plus the bias. {@code Vectors.dot} is a project helper, presumably a plain dot product.
     */
    private double featureScore(Vector featureVector) {
        return Vectors.dot(this.featureWeights(), featureVector) + this.getBias();
    }

    /** Returns one augmented score per component: the shared feature score plus that component's offset. */
    private double[] augmentedScores(Vector featureVector) {
        double[] scores = new double[this.numComponents];
        // The feature part is the same for every component, so compute it once.
        double featureScore = this.featureScore(featureVector);
        for (int k = 0; k < this.numComponents; ++k) {
            scores[k] = featureScore + this.getWeightForComponent(k);
        }
        return scores;
    }

    /**
     * Converts each augmented score {@code s_k} into the pair {@code MathUtil.logSoftmax({0, s_k})}.
     * If {@code logSoftmax} is the standard log-softmax, entry {@code [k][1]} is
     * {@code log sigmoid(s_k)} and {@code [k][0]} is {@code log(1 - sigmoid(s_k))}.
     */
    private double[][] logAugmentedProbs(double[] augmentedScores) {
        // Rows are produced by logSoftmax, so only the outer array needs allocating.
        double[][] logProbs = new double[this.numComponents][];
        for (int k = 0; k < this.numComponents; ++k) {
            logProbs[k] = MathUtil.logSoftmax(new double[]{0.0, augmentedScores[k]});
        }
        return logProbs;
    }

    /**
     * Computes the per-component two-class log-probabilities for {@code featureVector}.
     *
     * @param featureVector input of (presumably) length {@code numFeatures}
     * @return a {@code numComponents x 2} array; see the private overload for the row layout
     */
    double[][] logAugmentedProbs(Vector featureVector) {
        double[] scores = this.augmentedScores(featureVector);
        return this.logAugmentedProbs(scores);
    }

    /**
     * Writes the two dimensions as raw ints followed by the weights wrapped in a
     * {@link SerializableVector}.
     *
     * <p>This deliberately does not call {@code defaultWriteObject()}. Adding it now would change
     * the on-disk format and break reading of models saved with {@code serialVersionUID = 1L}.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(this.numFeatures);
        out.writeInt(this.numComponents);
        SerializableVector serializableVector = new SerializableVector(this.weights);
        out.writeObject(serializableVector);
    }

    /**
     * Mirror of {@link #writeObject}: reads the two ints and the wrapped weight vector, then
     * checks that they are mutually consistent.
     *
     * @throws InvalidObjectException if the dimensions are negative, the stored object is not a
     *                                {@link SerializableVector}, or the vector has the wrong size
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int features = in.readInt();
        int components = in.readInt();
        if (features < 0 || components < 0) {
            throw new InvalidObjectException("negative dimensions in stream: numFeatures="
                    + features + ", numComponents=" + components);
        }
        Object stored = in.readObject();
        if (!(stored instanceof SerializableVector)) {
            throw new InvalidObjectException("expected SerializableVector for weights, got "
                    + (stored == null ? "null" : stored.getClass().getName()));
        }
        Vector restored = ((SerializableVector) stored).getVector();
        int expected = expectedSize(features, components);
        if (restored == null || restored.size() != expected) {
            throw new InvalidObjectException("weights must have size " + expected + ", got "
                    + (restored == null ? "null" : String.valueOf(restored.size())));
        }
        this.numFeatures = features;
        this.numComponents = components;
        this.weights = restored;
    }
}

