/*
 * Decompiled with CFR 0.152.
 */
package smile.math;

import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.MathEx;
import smile.math.MultivariateFunction;

/**
 * Minimizers for differentiable multivariate functions: a dense BFGS variant
 * that maintains a full {@code n x n} approximation of the inverse Hessian,
 * and a limited-memory variant (L-BFGS) that keeps only the last {@code m}
 * step/gradient-difference pairs. Both share the same backtracking line
 * search and the same two stopping tests (relative step size and scaled
 * gradient magnitude).
 */
public class BFGS {
    private static final Logger logger = LoggerFactory.getLogger(BFGS.class);

    /** Gradient tolerance used by the no-argument constructor. */
    private static final double DEFAULT_GTOL = 1.0E-5;
    /** Iteration limit used by the no-argument constructor. */
    private static final int DEFAULT_MAX_ITER = 500;
    /** Scale factor for the longest step the line search may take. */
    private static final double STPMX = 100.0;
    /** Fraction of the predicted decrease a line-search step must achieve. */
    private static final double FTOL = 1.0E-4;

    /** Threshold of the scaled-gradient stopping test. */
    private final double gtol;
    /** Maximum number of outer iterations. */
    private final int maxIter;

    /**
     * Creates a minimizer with gradient tolerance 1E-5 and at most 500 iterations.
     */
    public BFGS() {
        this(DEFAULT_GTOL, DEFAULT_MAX_ITER);
    }

    /**
     * Creates a minimizer.
     *
     * @param gtol    threshold of the scaled-gradient stopping test; must be positive.
     * @param maxIter maximum number of iterations; must be positive.
     * @throws IllegalArgumentException if either argument is not positive.
     */
    public BFGS(double gtol, int maxIter) {
        if (gtol <= 0.0) {
            throw new IllegalArgumentException("Invalid gradient tolerance: " + gtol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        this.gtol = gtol;
        this.maxIter = maxIter;
    }

    /**
     * Minimizes {@code func} with the limited-memory variant, starting from
     * {@code x}.
     *
     * @param func the function to minimize; {@code func.g(x, g)} must return
     *             the value at {@code x} and write the gradient into {@code g}.
     * @param m    number of most recent step/gradient-difference pairs to keep;
     *             must be positive.
     * @param x    on input the starting point; overwritten in place with the
     *             last accepted point.
     * @return the function value at the returned {@code x}.
     * @throws IllegalArgumentException if {@code m <= 0}, or if the line search
     *         rejects its arguments (see {@code linesearch}).
     */
    public double minimize(DifferentiableMultivariateFunction func, int m, double[] x) {
        if (m <= 0) {
            throw new IllegalArgumentException("Invalid m: " + m);
        }

        final double tolx = 4.0 * MathEx.EPSILON;
        final int n = x.length;

        double[] xnew = new double[n];
        double[] gnew = new double[n];
        double[] xi = new double[n];

        // Circular buffers of the last m steps (s), gradient differences (y)
        // and 1 / (y . s) (rho); a is scratch space for the recursion below.
        double[][] s = new double[m][n];
        double[][] y = new double[m][n];
        double[] rho = new double[m];
        double[] a = new double[m];

        double[] g = new double[n];
        double f = func.g(x, g);
        info("L-BFGS: initial function value: %.5f", f);

        // The first search direction is the negative gradient.
        for (int i = 0; i < n; i++) {
            xi[i] = -g[i];
        }
        final double stpmax = maxStep(x);

        int k = 0; // slot of the circular buffers written in this iteration
        for (int iter = 1; iter <= maxIter; iter++) {
            linesearch(func, x, f, g, xi, xnew, stpmax);
            f = func.g(xnew, gnew);

            double[] sk = s[k];
            double[] yk = y[k];
            for (int i = 0; i < n; i++) {
                sk[i] = xnew[i] - x[i];
                yk[i] = gnew[i] - g[i];
                x[i] = xnew[i];
                g[i] = gnew[i];
            }

            if (maxRelativeStep(sk, x) < tolx) {
                info("L-BFGS converges on x after %d iterations: %.5f", iter, f);
                return f;
            }

            if (maxScaledGradient(g, x, f) < gtol) {
                info("L-BFGS converges on gradient after %d iterations: %.5f", iter, f);
                return f;
            }

            if (iter % 100 == 0) {
                info("L-BFGS: the function value after %3d iterations: %.5f", iter, f);
            }

            double ys = MathEx.dot(yk, sk);
            double yy = MathEx.dot(yk, yk);
            double diag = ys / yy; // scalar scaling applied between the two passes
            rho[k] = 1.0 / ys;

            for (int i = 0; i < n; i++) {
                xi[i] = -g[i];
            }

            // First pass: walk the stored pairs from newest to oldest.
            int cp = k;
            final int bound = Math.min(iter, m);
            for (int i = 0; i < bound; i++) {
                a[cp] = rho[cp] * MathEx.dot(s[cp], xi);
                MathEx.axpy(-a[cp], y[cp], xi);
                if (--cp == -1) {
                    cp = m - 1;
                }
            }

            for (int i = 0; i < n; i++) {
                xi[i] *= diag;
            }

            // Second pass: walk back from oldest to newest.
            for (int i = 0; i < bound; i++) {
                if (++cp == m) {
                    cp = 0;
                }
                double b = rho[cp] * MathEx.dot(y[cp], xi);
                MathEx.axpy(a[cp] - b, s[cp], xi);
            }

            if (++k == m) {
                k = 0;
            }
        }

        logger.warn("L-BFGS reaches the maximum number of iterations: {}", maxIter);
        return f;
    }

    /**
     * Minimizes {@code func} with the dense variant, starting from {@code x}.
     * Allocates an {@code n x n} matrix where {@code n = x.length}.
     *
     * @param func the function to minimize; {@code func.g(x, g)} must return
     *             the value at {@code x} and write the gradient into {@code g}.
     * @param x    on input the starting point; overwritten in place with the
     *             last accepted point.
     * @return the function value reported by the last line search.
     * @throws IllegalArgumentException if the line search rejects its
     *         arguments (see {@code linesearch}).
     */
    public double minimize(DifferentiableMultivariateFunction func, double[] x) {
        final double tolx = 4.0 * MathEx.EPSILON;
        final int n = x.length;

        double[] dg = new double[n];     // gradient difference between iterations
        double[] g = new double[n];      // current gradient
        double[] hdg = new double[n];    // hessin * dg
        double[] xnew = new double[n];
        double[] xi = new double[n];     // search direction, then the step taken
        double[][] hessin = new double[n][n]; // inverse Hessian approximation

        double f = func.g(x, g);
        info("BFGS: initial function value: %.5f", f);

        // Start from the identity matrix and the negative gradient.
        for (int i = 0; i < n; i++) {
            hessin[i][i] = 1.0;
            xi[i] = -g[i];
        }
        final double stpmax = maxStep(x);

        for (int iter = 1; iter <= maxIter; iter++) {
            f = linesearch(func, x, f, g, xi, xnew, stpmax);

            if (iter % 100 == 0) {
                info("BFGS: the function value after %3d iterations: %.5f", iter, f);
            }

            // xi now holds the step actually taken.
            for (int i = 0; i < n; i++) {
                xi[i] = xnew[i] - x[i];
                x[i] = xnew[i];
            }

            if (maxRelativeStep(xi, x) < tolx) {
                info("BFGS converges on x after %d iterations: %.5f", iter, f);
                return f;
            }

            // Save the old gradient, then evaluate the new one.
            System.arraycopy(g, 0, dg, 0, n);
            func.g(x, g);

            if (maxScaledGradient(g, x, f) < gtol) {
                info("BFGS converges on gradient after %d iterations: %.5f", iter, f);
                return f;
            }

            for (int i = 0; i < n; i++) {
                dg[i] = g[i] - dg[i];
            }

            for (int i = 0; i < n; i++) {
                hdg[i] = 0.0;
                for (int j = 0; j < n; j++) {
                    hdg[i] += hessin[i][j] * dg[j];
                }
            }

            double fac = 0.0;
            double fae = 0.0;
            double sumdg = 0.0;
            double sumxi = 0.0;
            for (int i = 0; i < n; i++) {
                fac += dg[i] * xi[i];
                fae += dg[i] * hdg[i];
                sumdg += dg[i] * dg[i];
                sumxi += xi[i] * xi[i];
            }

            // Update the matrix only when dg . xi is sufficiently positive.
            if (fac > Math.sqrt(MathEx.EPSILON * sumdg * sumxi)) {
                fac = 1.0 / fac;
                double fad = 1.0 / fae;

                // dg is reused to hold the correction vector.
                for (int i = 0; i < n; i++) {
                    dg[i] = fac * xi[i] - fad * hdg[i];
                }

                // Update the upper triangle and mirror it to keep symmetry.
                for (int i = 0; i < n; i++) {
                    for (int j = i; j < n; j++) {
                        hessin[i][j] += fac * xi[i] * xi[j] - fad * hdg[i] * hdg[j] + fae * dg[i] * dg[j];
                        hessin[j][i] = hessin[i][j];
                    }
                }
            }

            // Next search direction: -hessin * g.
            Arrays.fill(xi, 0.0);
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    xi[i] -= hessin[i][j] * g[j];
                }
            }
        }

        logger.warn("BFGS reaches the maximum number of iterations: {}", maxIter);
        return f;
    }

    /**
     * Backtracking line search along direction {@code p} from {@code xold}.
     * Starts with the full step and shrinks it until the function value has
     * decreased by at least {@code FTOL} times the decrease predicted by the
     * initial slope, or until the step becomes smaller than a minimum derived
     * from the machine epsilon.
     *
     * @param func   the function to evaluate.
     * @param xold   the starting point; not modified.
     * @param fold   the function value at {@code xold}.
     * @param g      the gradient at {@code xold}.
     * @param p      the search direction; scaled down in place if its norm
     *               exceeds {@code stpmax}.
     * @param x      output: the accepted point, or a copy of {@code xold} when
     *               the step fell below the minimum.
     * @param stpmax upper bound on the length of the step; must be positive.
     * @return the function value at the last trial point.
     * @throws IllegalArgumentException if {@code stpmax <= 0} or if
     *         {@code g . p >= 0}.
     */
    private double linesearch(MultivariateFunction func, double[] xold, double fold, double[] g, double[] p, double[] x, double stpmax) {
        if (stpmax <= 0.0) {
            throw new IllegalArgumentException("Invalid upper bound of linear search step: " + stpmax);
        }

        final double xtol = MathEx.EPSILON;
        final int n = xold.length;

        // Shrink the direction if the full step would be too long.
        double pnorm = MathEx.norm(p);
        if (pnorm > stpmax) {
            double r = stpmax / pnorm;
            for (int i = 0; i < n; i++) {
                p[i] *= r;
            }
        }

        double slope = 0.0;
        for (int i = 0; i < n; i++) {
            slope += g[i] * p[i];
        }
        if (slope >= 0.0) {
            throw new IllegalArgumentException("Line Search: the search direction is not a descent direction, which may be caused by roundoff problem.");
        }

        // Smallest step worth trying. The scale uses |xold[i]| so that
        // coordinates of large magnitude are treated the same regardless of
        // sign, matching the step test in the minimize methods.
        double alammin = xtol / maxRelativeStep(p, xold);

        double alam = 1.0;  // current step length
        double alam2 = 0.0; // previous step length
        double f2 = 0.0;    // function value at the previous step length
        for (;;) {
            for (int i = 0; i < n; i++) {
                x[i] = xold[i] + alam * p[i];
            }
            double f = func.apply(x);

            if (alam < alammin) {
                // Step too small to matter: hand back the starting point.
                // The returned value is the one just evaluated at the trial point.
                System.arraycopy(xold, 0, x, 0, n);
                return f;
            }

            if (f <= fold + FTOL * alam * slope) {
                return f; // sufficient decrease achieved
            }

            double tmpalam;
            if (alam == 1.0) {
                // First backtrack: minimizer of a quadratic model.
                tmpalam = -slope / (2.0 * (f - fold - slope));
            } else {
                // Later backtracks: minimizer of a cubic model through the
                // two most recent trial points.
                double rhs1 = f - fold - alam * slope;
                double rhs2 = f2 - fold - alam2 * slope;
                double a = (rhs1 / (alam * alam) - rhs2 / (alam2 * alam2)) / (alam - alam2);
                double b = (-alam2 * rhs1 / (alam * alam) + alam * rhs2 / (alam2 * alam2)) / (alam - alam2);
                if (a == 0.0) {
                    tmpalam = -slope / (2.0 * b);
                } else {
                    double disc = b * b - 3.0 * a * slope;
                    if (disc < 0.0) {
                        tmpalam = 0.5 * alam;
                    } else if (b <= 0.0) {
                        tmpalam = (-b + Math.sqrt(disc)) / (3.0 * a);
                    } else {
                        tmpalam = -slope / (b + Math.sqrt(disc));
                    }
                }
                // Never shrink by less than a factor of two.
                if (tmpalam > 0.5 * alam) {
                    tmpalam = 0.5 * alam;
                }
            }

            alam2 = alam;
            f2 = f;
            // Never shrink by more than a factor of ten.
            alam = Math.max(tmpalam, 0.1 * alam);
        }
    }

    /** Returns the step-length bound {@code STPMX * max(||x||, n)} for a start point. */
    private static double maxStep(double[] x) {
        double sum = 0.0;
        for (double v : x) {
            sum += v * v;
        }
        return STPMX * Math.max(Math.sqrt(sum), x.length);
    }

    /** Returns the largest {@code |step[i]| / max(|x[i]|, 1)} over all coordinates. */
    private static double maxRelativeStep(double[] step, double[] x) {
        double test = 0.0;
        for (int i = 0; i < x.length; i++) {
            double temp = Math.abs(step[i]) / Math.max(Math.abs(x[i]), 1.0);
            if (temp > test) {
                test = temp;
            }
        }
        return test;
    }

    /** Returns the largest {@code |g[i]| * max(|x[i]|, 1) / max(f, 1)} over all coordinates. */
    private static double maxScaledGradient(double[] g, double[] x, double f) {
        double den = Math.max(f, 1.0);
        double test = 0.0;
        for (int i = 0; i < x.length; i++) {
            double temp = Math.abs(g[i]) * Math.max(Math.abs(x[i]), 1.0) / den;
            if (temp > test) {
                test = temp;
            }
        }
        return test;
    }

    /** Logs a {@code String.format}-style message at INFO level, formatting only when enabled. */
    private static void info(String format, Object... args) {
        if (logger.isInfoEnabled()) {
            logger.info(String.format(format, args));
        }
    }
}

