/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.optimization;

import edu.neu.ccs.pyramid.classification.logistic_regression.RidgeBinaryLogisticLoss;
import edu.neu.ccs.pyramid.util.Pair;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Trust-region Newton optimizer for a {@link RidgeBinaryLogisticLoss}.
 *
 * <p>The structure (outer {@code tron} loop, inner {@code trcg} truncated conjugate gradient,
 * constants ETA0..ETA2 / SIGMA1..SIGMA3) appears to follow LIBLINEAR's TRON implementation.
 * Each outer iteration approximately minimizes the local quadratic model with CG restricted to
 * a ball of radius {@code delta}. It then accepts or rejects the resulting step and resizes the
 * ball based on the ratio of actual to predicted reduction of the objective.
 */
public class TrustRegionNewtonOptimizer {
    /** Objective providing {@code fun}, {@code grad} and Hessian-vector products {@code Hv}. */
    private final RidgeBinaryLogisticLoss loss;
    /** Stop once the gradient norm falls to {@code eps} times the initial gradient norm. */
    private final double eps;
    /** Upper bound on the outer-loop counter; the counter only advances on accepted steps. */
    private final int maxIter;

    /** A step is accepted only if actual reduction exceeds ETA0 * predicted reduction. */
    private static final double ETA0 = 1.0E-4;
    /** Reduction-ratio threshold below which the trust region is shrunk. */
    private static final double ETA1 = 0.25;
    /** Reduction-ratio threshold above which the trust region may grow. */
    private static final double ETA2 = 0.75;
    /** Trust-region resize factors (and alpha lower bound / default, see updateRadius). */
    private static final double SIGMA1 = 0.25;
    private static final double SIGMA2 = 0.5;
    private static final double SIGMA3 = 4.0;
    /** CG stops once ||r|| drops to this fraction of ||g||. */
    private static final double CG_TOLERANCE = 0.1;
    /** Objective values below this abort the optimization. */
    private static final double F_LOWER_BOUND = -1.0E32;
    /** Relative size (w.r.t. |f|) under which both reductions are considered negligible. */
    private static final double REDUCTION_TOLERANCE = 1.0E-12;

    /** Creates an optimizer with {@code eps = 0.1} and {@code maxIter = 1000}. */
    public TrustRegionNewtonOptimizer(RidgeBinaryLogisticLoss loss) {
        this(loss, 0.1);
    }

    /** Creates an optimizer with the given tolerance and {@code maxIter = 1000}. */
    public TrustRegionNewtonOptimizer(RidgeBinaryLogisticLoss loss, double eps) {
        this(loss, eps, 1000);
    }

    /**
     * @param loss    objective to minimize
     * @param eps     relative gradient-norm stopping tolerance
     * @param maxIter maximum number of accepted outer iterations
     */
    public TrustRegionNewtonOptimizer(RidgeBinaryLogisticLoss loss, double eps, int maxIter) {
        this.loss = loss;
        this.eps = eps;
        this.maxIter = maxIter;
    }

    /**
     * Minimizes the loss, writing the solution into {@code w}.
     *
     * <p>{@code w} is overwritten with zeros before optimization starts, so its incoming values
     * are ignored. The objective value is printed to stdout on every outer iteration.
     *
     * @param w output vector; must have exactly {@code loss.getNumColumns()} entries
     * @throws IllegalArgumentException if {@code w} has the wrong size
     */
    void tron(Vector w) {
        int numColumns = this.loss.getNumColumns();
        if (w.size() != numColumns) {
            throw new IllegalArgumentException(
                    "w has size " + w.size() + " but the loss expects " + numColumns);
        }
        Vector wNew = new DenseVector(numColumns);
        Vector g = new DenseVector(numColumns);

        // Optimization always starts from the origin.
        for (int i = 0; i < numColumns; ++i) {
            w.set(i, 0.0);
        }
        double f = this.loss.fun(w);
        this.loss.grad(w, g);
        double gnorm1 = g.norm(2.0);
        double gnorm = gnorm1;
        // The initial trust-region radius is the initial gradient norm.
        double delta = gnorm1;
        // Here gnorm == gnorm1, so for eps < 1 this only skips the search when the gradient is 0.
        boolean search = !(gnorm <= this.eps * gnorm1);

        int iter = 1;
        while (iter <= this.maxIter && search) {
            Pair<Vector, Vector> result = this.trcg(delta, g);
            Vector s = result.getFirst();
            Vector r = result.getSecond();

            // Trial point w + s.
            copy(w, wNew);
            daxpy(1.0, s, wNew);

            double gs = g.dot(s);
            // Reduction predicted by the quadratic model: -(g's + s'Hs/2), using r = -g - Hs.
            double prered = -0.5 * (gs - s.dot(r));
            double fnew = this.loss.fun(wNew);
            double actred = f - fnew;

            double snorm = s.norm(2.0);
            if (iter == 1) {
                delta = Math.min(delta, snorm);
            }
            delta = updateRadius(delta, snorm, f, fnew, gs, actred, prered);
            System.out.println("f = " + f);

            if (actred > ETA0 * prered) {
                // Step accepted: move to the trial point.
                ++iter;
                copy(wNew, w);
                f = fnew;
                this.loss.grad(w, g);
                gnorm = g.norm(2.0);
                if (gnorm <= this.eps * gnorm1) {
                    break;
                }
            }
            if (f < F_LOWER_BOUND) {
                System.out.println("WARNING: f < -1.0e+32");
                break;
            }
            if (Math.abs(actred) <= 0.0 && prered <= 0.0) {
                System.out.println("WARNING: actred and prered <= 0");
                break;
            }
            if (Math.abs(actred) <= REDUCTION_TOLERANCE * Math.abs(f)
                    && Math.abs(prered) <= REDUCTION_TOLERANCE * Math.abs(f)) {
                System.out.println("WARNING: actred and prered too small");
                break;
            }
        }
    }

    /**
     * Computes the next trust-region radius.
     *
     * <p>{@code alpha * snorm} is the minimizer of the 1-D quadratic through f (value),
     * gs (slope) and fnew (value at the full step), clamped below by SIGMA1 and defaulting to
     * SIGMA3 when that quadratic is not convex. The result is then chosen by comparing the
     * actual reduction against ETA0/ETA1/ETA2 multiples of the predicted reduction.
     */
    private static double updateRadius(double delta, double snorm, double f, double fnew,
                                       double gs, double actred, double prered) {
        double curvature = fnew - f - gs;
        double alpha = curvature <= 0.0 ? SIGMA3 : Math.max(SIGMA1, -0.5 * (gs / curvature));
        if (actred < ETA0 * prered) {
            return Math.min(Math.max(alpha, SIGMA1) * snorm, SIGMA2 * delta);
        }
        if (actred < ETA1 * prered) {
            return Math.max(SIGMA1 * delta, Math.min(alpha * snorm, SIGMA2 * delta));
        }
        if (actred < ETA2 * prered) {
            return Math.max(SIGMA1 * delta, Math.min(alpha * snorm, SIGMA3 * delta));
        }
        return Math.max(delta, Math.min(alpha * snorm, SIGMA3 * delta));
    }

    /**
     * Truncated conjugate gradient for the Newton system {@code H s = -g} inside the ball
     * {@code ||s|| <= delta}.
     *
     * @param delta trust-region radius
     * @param g     current gradient
     * @return pair (s, r): the step and the final residual {@code r = -g - H s}
     */
    private Pair<Vector, Vector> trcg(double delta, Vector g) {
        int numColumns = this.loss.getNumColumns();
        Vector d = new DenseVector(numColumns);
        Vector hd = new DenseVector(numColumns);
        Vector s = new DenseVector(numColumns); // zero-initialized
        Vector r = new DenseVector(numColumns);
        for (int i = 0; i < numColumns; ++i) {
            r.set(i, -g.get(i));
            d.set(i, r.get(i));
        }
        double cgtol = CG_TOLERANCE * g.norm(2.0);
        double rTr = r.dot(r);
        while (true) {
            if (r.norm(2.0) <= cgtol) {
                break;
            }
            this.loss.Hv(d, hd);
            double alpha = rTr / d.dot(hd);
            daxpy(alpha, d, s);
            if (s.norm(2.0) > delta) {
                // The full CG step leaves the trust region: undo it and move along d only up
                // to the boundary ||s + t d|| == delta.
                daxpy(-alpha, d, s);
                double std = s.dot(d);
                double sts = s.dot(s);
                double dtd = d.dot(d);
                double dsq = delta * delta;
                double rad = Math.sqrt(std * std + dtd * (dsq - sts));
                // Positive root t, written in the cancellation-free form for each sign of s'd.
                double boundaryStep = std >= 0.0 ? (dsq - sts) / (std + rad) : (rad - std) / dtd;
                daxpy(boundaryStep, d, s);
                daxpy(-boundaryStep, hd, r);
                break;
            }
            daxpy(-alpha, hd, r);
            double rnewTrnew = r.dot(r);
            double beta = rnewTrnew / rTr;
            scale(beta, d);
            daxpy(1.0, r, d);
            rTr = rnewTrnew;
        }
        Pair<Vector, Vector> result = new Pair<Vector, Vector>();
        result.setFirst(s);
        result.setSecond(r);
        return result;
    }

    /** Copies every entry of {@code source} into {@code target}; both must have equal size. */
    private static void copy(Vector source, Vector target) {
        assert source.size() == target.size();
        for (int i = 0; i < source.size(); ++i) {
            target.set(i, source.get(i));
        }
    }

    /** In-place {@code vector2 += constant * vector1}; a no-op when {@code constant == 0}. */
    private static void daxpy(double constant, Vector vector1, Vector vector2) {
        if (constant == 0.0) {
            return;
        }
        assert vector1.size() == vector2.size();
        for (int i = 0; i < vector1.size(); ++i) {
            vector2.set(i, vector2.get(i) + constant * vector1.get(i));
        }
    }

    /** In-place {@code vector *= constant}; a no-op when {@code constant == 1}. */
    private static void scale(double constant, Vector vector) {
        if (constant == 1.0) {
            return;
        }
        for (int i = 0; i < vector.size(); ++i) {
            vector.set(i, vector.get(i) * constant);
        }
    }
}

