/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.optimization;

import edu.neu.ccs.pyramid.optimization.BackTrackingLineSearcher;
import edu.neu.ccs.pyramid.optimization.GradientValueOptimizer;
import edu.neu.ccs.pyramid.optimization.Optimizable;
import edu.neu.ccs.pyramid.optimization.Optimizer;
import java.util.Iterator;
import java.util.LinkedList;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

/**
 * Limited-memory BFGS (L-BFGS) optimizer.
 *
 * <p>Each call to {@link #iterate()} does four things:
 * <ul>
 *   <li>computes a search direction with the two-loop recursion over the most recent curvature
 *       pairs {@code (s, y)};</li>
 *   <li>moves along that direction with a {@link BackTrackingLineSearcher};</li>
 *   <li>records the new pair;</li>
 *   <li>reports the new objective value to the inherited terminator.</li>
 * </ul>
 * At most {@code m} pairs are kept (5 by default; see {@link #setHistory(double)}).
 */
public class LBFGS extends GradientValueOptimizer implements Optimizer {
    private static final Logger logger = LogManager.getLogger();

    private final BackTrackingLineSearcher lineSearcher;
    /** Maximum number of curvature pairs retained; a double to match {@link #setHistory(double)}. */
    private double m = 5.0;
    /** Steps returned by the line searcher (presumably x_{k+1} - x_k), oldest first. */
    private final LinkedList<Vector> sQueue = new LinkedList<>();
    /** Gradient differences y_k = g_{k+1} - g_k, oldest first. */
    private final LinkedList<Vector> yQueue = new LinkedList<>();
    /** rho_k = 1 / (y_k . s_k) for each stored pair, oldest first. */
    private final LinkedList<Double> rhoQueue = new LinkedList<>();

    /**
     * Creates an L-BFGS optimizer.
     *
     * @param function the objective whose value and gradient are queried every iteration
     */
    public LBFGS(Optimizable.ByGradientValue function) {
        super(function);
        this.lineSearcher = new BackTrackingLineSearcher(function);
        // A unit step is the natural first trial for a quasi-Newton direction.
        this.lineSearcher.setInitialStepLength(1.0);
    }

    /** Discards all stored curvature pairs. */
    private void reset() {
        sQueue.clear();
        yQueue.clear();
        rhoQueue.clear();
    }

    /** Drops the oldest curvature pairs until no more than {@code m} remain. */
    private void trimHistory() {
        // The isEmpty guard keeps a negative m from calling removeFirst() on an empty list.
        while (!sQueue.isEmpty() && sQueue.size() > m) {
            sQueue.removeFirst();
            yQueue.removeFirst();
            rhoQueue.removeFirst();
        }
    }

    /**
     * Performs one L-BFGS iteration.
     *
     * <p>The iteration computes a direction, line-searches along it, records the resulting
     * curvature pair and adds the new objective value to the terminator.
     *
     * <p>The curvature condition is {@code y . s > 0}. If it fails, including when the product is
     * NaN, the optimizer is force-terminated and the pair is not stored. Such a pair would give
     * {@code rho = 0} and contribute nothing to the two-loop recursion. It would also evict a
     * useful pair and, as the newest pair, make {@link #gamma()} non-positive.
     */
    @Override
    public void iterate() {
        if (logger.isDebugEnabled()) {
            logger.debug("start one iteration");
        }
        // Copy the gradient: the function's gradient vector may change once the parameters move
        // (TODO confirm whether getGradient() returns a shared instance).
        Vector oldGradient = new DenseVector(this.function.getGradient());
        Vector direction = findDirection();
        if (logger.isDebugEnabled()) {
            logger.debug("norm of direction = " + direction.norm(2.0));
        }
        BackTrackingLineSearcher.MoveInfo moveInfo = lineSearcher.moveAlongDirection(direction);
        Vector s = moveInfo.getStep();
        // Assumes the line search left the function at the new point, so this is g_{k+1}.
        Vector newGradient = this.function.getGradient();
        Vector y = newGradient.minus(oldGradient);
        double denominator = y.dot(s);
        boolean curvatureOk = denominator > 0.0;
        double rho = 0.0;
        if (curvatureOk) {
            rho = 1.0 / denominator;
        } else {
            this.terminator.forceTerminate();
            if (logger.isWarnEnabled()) {
                logger.warn("denominator <= 0, force to terminate");
            }
        }
        if (logger.isDebugEnabled()) {
            if (y.size() < 100) {
                logger.debug("y= " + y);
                logger.debug("s= " + s);
            }
            logger.debug("denominator = " + denominator);
            logger.debug("rho = " + rho);
        }
        if (curvatureOk) {
            sQueue.addLast(s);
            yQueue.addLast(y);
            rhoQueue.addLast(rho);
        }
        trimHistory();
        double value = this.function.getValue();
        this.terminator.add(value);
        if (logger.isDebugEnabled()) {
            logger.debug("finish one iteration. loss = " + value);
        }
    }

    /**
     * Computes the L-BFGS search direction {@code -H g} with the two-loop recursion.
     *
     * <p>{@code g} is the current gradient and {@code H} is the implicit inverse-Hessian
     * approximation built from the stored pairs, starting from {@code gamma() * I}. With an empty
     * history, {@code gamma()} is 1, so the result is {@code -g}.
     *
     * @return the search direction
     */
    Vector findDirection() {
        Vector q = new DenseVector(this.function.getGradient());

        // First loop: newest pair to oldest.
        Iterator<Double> rhoDesc = rhoQueue.descendingIterator();
        Iterator<Vector> sDesc = sQueue.descendingIterator();
        Iterator<Vector> yDesc = yQueue.descendingIterator();
        LinkedList<Double> alphas = new LinkedList<>();
        while (rhoDesc.hasNext()) {
            double rho = rhoDesc.next();
            Vector s = sDesc.next();
            Vector y = yDesc.next();
            double alpha = s.dot(q) * rho;
            // Prepend so the second (oldest-to-newest) loop can iterate alphas forward.
            alphas.addFirst(alpha);
            q = q.minus(y.times(alpha));
        }

        // Apply the initial inverse-Hessian approximation H0 = gamma * I.
        Vector r = q.times(gamma());

        // Second loop: oldest pair to newest.
        Iterator<Double> rhoAsc = rhoQueue.iterator();
        Iterator<Vector> sAsc = sQueue.iterator();
        Iterator<Vector> yAsc = yQueue.iterator();
        Iterator<Double> alphaAsc = alphas.iterator();
        while (rhoAsc.hasNext()) {
            double rho = rhoAsc.next();
            Vector s = sAsc.next();
            Vector y = yAsc.next();
            double alpha = alphaAsc.next();
            double beta = y.dot(r) * rho;
            r = r.plus(s.times(alpha - beta));
        }
        return r.times(-1.0);
    }

    /**
     * Returns the scale of the initial inverse-Hessian approximation {@code H0 = gamma * I}.
     *
     * <p>The scale is {@code (s . y) / (y . y)}, computed from the newest stored pair. It falls
     * back to 1.0 in three cases:
     * <ul>
     *   <li>the history is empty;</li>
     *   <li>{@code y . y} is not positive, i.e. zero or NaN;</li>
     *   <li>{@code s . y} is not positive, which would make H0 non-positive-definite.</li>
     * </ul>
     *
     * @return a positive scaling factor, or NaN/infinity only if the dot products themselves
     *     overflow
     */
    double gamma() {
        if (sQueue.isEmpty()) {
            return 1.0;
        }
        Vector s = sQueue.getLast();
        Vector y = yQueue.getLast();
        double yy = y.dot(y);
        double sy = s.dot(y);
        if (!(yy > 0.0) || !(sy > 0.0)) {
            return 1.0;
        }
        return sy / yy;
    }

    /**
     * Sets the maximum number of curvature pairs kept. Any excess pairs are dropped immediately,
     * oldest first.
     *
     * <p>A value below 1 keeps no history, so every direction is steepest descent.
     *
     * @param m the maximum history size
     */
    public void setHistory(double m) {
        this.m = m;
        trimHistory();
    }
}

