/*
 * Decompiled with CFR 0.152.
 */
package edu.neu.ccs.pyramid.optimization;

import edu.neu.ccs.pyramid.optimization.Optimizable;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class BackTrackingLineSearcher {
    private static final Logger logger = LogManager.getLogger();
    /** Objective being minimized; its parameters are overwritten by each search. */
    private final Optimizable.ByGradientValue function;
    /** Step length tried first on every call to {@link #moveAlongDirection(Vector)}. */
    private double initialStepLength = 1.0;
    /** Factor applied to the step length after each rejected trial; kept in (0, 1). */
    private double shrinkage = 0.5;
    /** Sufficient-decrease constant used in the Armijo condition. */
    private double c = 1.0E-4;

    /**
     * Creates a line searcher for the given objective.
     *
     * @param function objective whose value and gradient are queried and whose parameters
     *     are moved by {@link #moveAlongDirection(Vector)}
     */
    public BackTrackingLineSearcher(Optimizable.ByGradientValue function) {
        this.function = function;
    }

    /**
     * Moves the objective's parameters along a search direction using backtracking.
     *
     * <p>Starting from {@code initialStepLength}, the trial step length {@code t} is multiplied
     * by {@code shrinkage} until the Armijo sufficient-decrease condition
     * {@code f(x + t*d) <= f(x) + c * t * (grad f(x) . d)} holds (and {@code f(x)} is below
     * positive infinity), or until {@code t} has shrunk to zero.
     *
     * <p>If {@code searchDirection} is not a descent direction (its dot product with the
     * gradient is not negative), the negative gradient is searched instead, and the Armijo
     * test uses the directional derivative along that negative gradient.
     *
     * <p>Side effect: on return, the objective's parameters are set to the last trial point,
     * which is the accepted point.
     *
     * @param searchDirection proposed direction {@code d}; not modified
     * @return the old value, the new value, the accepted step vector and the accepted step length
     */
    public MoveInfo moveAlongDirection(Vector searchDirection) {
        if (logger.isDebugEnabled()) {
            logger.debug("start line search");
            if (searchDirection.size() < 100) {
                logger.debug("direction=" + searchDirection);
            }
        }
        MoveInfo moveInfo = new MoveInfo();
        double value = this.function.getValue();
        moveInfo.setOldValue(value);
        Vector gradient = this.function.getGradient();
        double product = gradient.dot(searchDirection);
        Vector localSearchDir;
        if (product < 0.0) {
            localSearchDir = searchDirection;
        } else {
            if (logger.isWarnEnabled()) {
                logger.warn("Bad search direction! Use negative gradient instead. Product of gradient and search direction = " + product);
            }
            localSearchDir = gradient.times(-1.0);
            // The Armijo test must use the directional derivative along the direction actually
            // searched. Keeping the old non-negative product would let the search accept
            // steps that increase the objective.
            product = gradient.dot(localSearchDir);
        }
        // Copy the starting point so that every trial is measured from x, even if
        // setParameters(...) mutates or replaces the vector returned by getParameters().
        Vector currentParameters = this.function.getParameters();
        Vector initialPosition = currentParameters.isDense()
                ? new DenseVector(currentParameters)
                : new RandomAccessSparseVector(currentParameters);
        double stepLength = this.initialStepLength;
        Vector step;
        double targetValue;
        while (true) {
            step = localSearchDir.times(stepLength);
            Vector target = initialPosition.plus(step);
            this.function.setParameters(target);
            targetValue = this.function.getValue();
            if (logger.isDebugEnabled()) {
                logger.debug("step length = " + stepLength + ", target value = " + targetValue);
            }
            // The extra finiteness check refuses to declare sufficient decrease from a starting
            // value of +Infinity (or NaN), where the right-hand side is meaningless.
            boolean sufficientDecrease = targetValue <= value + this.c * stepLength * product
                    && value < Double.POSITIVE_INFINITY;
            // The setters keep shrinkage in (0, 1) and the initial step positive and finite,
            // so repeated shrinking eventually underflows stepLength to 0.0 and ends the loop.
            if (sufficientDecrease || stepLength == 0.0) {
                break;
            }
            stepLength *= this.shrinkage;
        }
        moveInfo.setStep(step);
        moveInfo.setStepLength(stepLength);
        moveInfo.setNewValue(targetValue);
        if (logger.isDebugEnabled()) {
            logger.debug("line search done. " + moveInfo);
        }
        return moveInfo;
    }

    /**
     * Sets the first step length tried by each search.
     *
     * @param initialStepLength a positive, finite step length
     * @throws IllegalArgumentException if {@code initialStepLength} is not positive and finite.
     *     NaN or infinite values would keep the backtracking loop from ever terminating, and
     *     non-positive values would not move forward along the search direction.
     */
    public void setInitialStepLength(double initialStepLength) {
        if (!(initialStepLength > 0.0) || Double.isInfinite(initialStepLength)) {
            throw new IllegalArgumentException(
                    "initialStepLength must be positive and finite, got " + initialStepLength);
        }
        this.initialStepLength = initialStepLength;
    }

    /**
     * Sets the factor by which the step length is multiplied after each rejected trial.
     *
     * @param shrinkage a value strictly between 0 and 1
     * @throws IllegalArgumentException if {@code shrinkage} is not in (0, 1). Values of 1 or
     *     more (or NaN) would keep the backtracking loop from ever terminating.
     */
    public void setShrinkage(double shrinkage) {
        if (!(shrinkage > 0.0 && shrinkage < 1.0)) {
            throw new IllegalArgumentException("shrinkage must be in (0, 1), got " + shrinkage);
        }
        this.shrinkage = shrinkage;
    }

    /**
     * Sets the sufficient-decrease constant {@code c} used in the Armijo condition.
     *
     * <p>Conventionally {@code 0 < c < 1}; this value is not validated.
     *
     * @param c the sufficient-decrease constant
     */
    public void setC(double c) {
        this.c = c;
    }

    /** Result of one line search: objective values before and after, and the step taken. */
    public static class MoveInfo {
        private double oldValue;
        private double newValue;
        private Vector step;
        private double stepLength;

        /** @return the accepted step vector (direction scaled by the step length) */
        public Vector getStep() {
            return this.step;
        }

        /** @param step the accepted step vector */
        public void setStep(Vector step) {
            this.step = step;
        }

        /** @return the accepted step length */
        public double getStepLength() {
            return this.stepLength;
        }

        /** @param stepLength the accepted step length */
        public void setStepLength(double stepLength) {
            this.stepLength = stepLength;
        }

        /** @return the objective value before the move */
        public double getOldValue() {
            return this.oldValue;
        }

        /** @param oldValue the objective value before the move */
        public void setOldValue(double oldValue) {
            this.oldValue = oldValue;
        }

        /** @return the objective value at the accepted point */
        public double getNewValue() {
            return this.newValue;
        }

        /** @param newValue the objective value at the accepted point */
        public void setNewValue(double newValue) {
            this.newValue = newValue;
        }

        /** Summarizes the old value, new value and step length; the step vector is omitted. */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("oldValue=").append(this.oldValue);
            sb.append(", newValue=").append(this.newValue);
            sb.append(", stepLength=").append(this.stepLength);
            return sb.toString();
        }
    }
}

