package smile.math;

import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:smile/math/BFGS.class */
/**
 * Quasi-Newton minimization of a differentiable multivariate function.
 *
 * <p>Two variants are provided:
 * <ul>
 *   <li>{@link #minimize(DifferentiableMultivariateFunction, double[])}: classic BFGS
 *       keeping a dense {@code n x n} inverse-Hessian approximation;</li>
 *   <li>{@link #minimize(DifferentiableMultivariateFunction, int, double[])}: limited-memory
 *       BFGS (L-BFGS) keeping only the last {@code m} correction pairs.</li>
 * </ul>
 * Both drive a backtracking line search (quadratic, then cubic interpolation with a
 * sufficient-decrease test). They stop when the relative change of {@code x} or the
 * scaled gradient falls below its tolerance, or after {@code maxIter} iterations.
 */
public class BFGS {
    private static final Logger logger = LoggerFactory.getLogger(BFGS.class);

    /** Sufficient-decrease constant of the line search (Armijo condition). */
    private static final double ALF = 1.0E-4;

    /** Convergence tolerance on the scaled gradient. */
    private final double gtol;
    /** Maximum number of quasi-Newton iterations. */
    private final int maxIter;

    /** Creates a minimizer with gradient tolerance {@code 1E-5} and at most 500 iterations. */
    public BFGS() {
        this(1.0E-5, 500);
    }

    /**
     * Creates a minimizer.
     *
     * @param gtol convergence tolerance on the scaled gradient; must be positive.
     * @param maxIter maximum number of iterations; must be positive.
     * @throws IllegalArgumentException if either argument is not positive.
     */
    public BFGS(double gtol, int maxIter) {
        if (gtol <= 0.0) {
            throw new IllegalArgumentException("Invalid gradient tolerance: " + gtol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        this.gtol = gtol;
        this.maxIter = maxIter;
    }

    /**
     * Minimizes {@code func} with limited-memory BFGS (two-loop recursion).
     *
     * @param func the function to minimize. As used here, {@code func.g(x, grad)} returns the
     *     function value at {@code x} and writes the gradient into {@code grad}.
     * @param m the number of correction pairs kept in memory; must be positive.
     * @param x on input the starting point; on output the final point (updated in place).
     * @return the function value at the final {@code x}.
     * @throws IllegalArgumentException if {@code m <= 0}, or if the line search is handed a
     *     non-descent direction.
     */
    public double minimize(DifferentiableMultivariateFunction func, int m, double[] x) {
        if (m <= 0) {
            throw new IllegalArgumentException("Invalid m: " + m);
        }

        final double tolx = 4.0 * MathEx.EPSILON;
        final int n = x.length;

        double[] xnew = new double[n];
        double[] gnew = new double[n];
        double[] dir = new double[n];       // current search direction
        double[][] s = new double[m][n];    // ring buffer of steps x_{k+1} - x_k
        double[][] y = new double[m][n];    // ring buffer of gradient differences
        double[] rho = new double[m];       // 1 / (y_k . s_k)
        double[] alpha = new double[m];     // scratch for the two-loop recursion

        double[] grad = new double[n];
        double f = func.g(x, grad);
        logger.info(String.format("L-BFGS: initial function value: %.5f", f));

        double sumsq = 0.0;
        for (int i = 0; i < n; i++) {
            dir[i] = -grad[i];
            sumsq += x[i] * x[i];
        }
        // Upper bound on the line-search step length.
        double stpmax = 100.0 * Math.max(Math.sqrt(sumsq), n);

        int k = 0; // ring-buffer slot written in this iteration
        for (int iter = 1; iter <= maxIter; iter++) {
            linesearch(func, x, f, grad, dir, xnew, stpmax);

            // The line search only yields f(xnew); re-evaluate to obtain the gradient too.
            f = func.g(xnew, gnew);
            for (int i = 0; i < n; i++) {
                s[k][i] = xnew[i] - x[i];
                y[k][i] = gnew[i] - grad[i];
                x[i] = xnew[i];
                grad[i] = gnew[i];
            }

            if (maxRelativeStep(s[k], x) < tolx) {
                logger.info(String.format("L-BFGS converges on x after %d iterations: %.5f", iter, f));
                return f;
            }

            if (maxScaledGradient(grad, x, f) < gtol) {
                logger.info(String.format("L-BFGS converges on gradient after %d iterations: %.5f", iter, f));
                return f;
            }

            if (iter % 100 == 0) {
                logger.info(String.format("L-BFGS: the function value after %3d iterations: %.5f", iter, f));
            }

            double ys = MathEx.dot(y[k], s[k]);
            // Initial inverse-Hessian scaling gamma = (y.s) / (y.y).
            double gamma = ys / MathEx.dot(y[k], y[k]);
            rho[k] = 1.0 / ys;

            for (int i = 0; i < n; i++) {
                dir[i] = -grad[i];
            }

            // Two-loop recursion over the min(iter, m) stored pairs.
            int bound = Math.min(iter, m);
            int j = k;
            for (int t = 0; t < bound; t++) {
                // Backward pass: newest pair first.
                alpha[j] = rho[j] * MathEx.dot(s[j], dir);
                MathEx.axpy(-alpha[j], y[j], dir);
                j = (j == 0) ? m - 1 : j - 1;
            }

            for (int i = 0; i < n; i++) {
                dir[i] *= gamma;
            }

            for (int t = 0; t < bound; t++) {
                // Forward pass: oldest pair first.
                j = (j == m - 1) ? 0 : j + 1;
                double beta = rho[j] * MathEx.dot(y[j], dir);
                MathEx.axpy(alpha[j] - beta, s[j], dir);
            }

            k = (k == m - 1) ? 0 : k + 1;
        }

        logger.warn("L-BFGS reaches the maximum number of iterations: " + this.maxIter);
        return f;
    }

    /**
     * Minimizes {@code func} with classic BFGS, keeping a dense inverse-Hessian approximation.
     *
     * @param func the function to minimize. As used here, {@code func.g(x, grad)} returns the
     *     function value at {@code x} and writes the gradient into {@code grad}.
     * @param x on input the starting point; on output the final point (updated in place).
     * @return the function value at the final {@code x}.
     * @throws IllegalArgumentException if the line search is handed a non-descent direction.
     */
    public double minimize(DifferentiableMultivariateFunction func, double[] x) {
        final double tolx = 4.0 * MathEx.EPSILON;
        final int n = x.length;

        double[] dg = new double[n];      // previous gradient, then gradient difference
        double[] grad = new double[n];
        double[] hdg = new double[n];     // H * dg
        double[] xnew = new double[n];
        double[] xi = new double[n];      // search direction, then the step actually taken
        double[][] hessin = new double[n][n]; // inverse-Hessian approximation H

        double f = func.g(x, grad);
        logger.info(String.format("BFGS: initial function value: %.5f", f));

        double sumsq = 0.0;
        for (int i = 0; i < n; i++) {
            hessin[i][i] = 1.0;
            xi[i] = -grad[i];
            sumsq += x[i] * x[i];
        }
        // Upper bound on the line-search step length.
        double stpmax = 100.0 * Math.max(Math.sqrt(sumsq), n);

        for (int iter = 1; iter <= maxIter; iter++) {
            f = linesearch(func, x, f, grad, xi, xnew, stpmax);

            if (iter % 100 == 0) {
                logger.info(String.format("BFGS: the function value after %3d iterations: %.5f", iter, f));
            }

            for (int i = 0; i < n; i++) {
                xi[i] = xnew[i] - x[i];
                x[i] = xnew[i];
            }

            if (maxRelativeStep(xi, x) < tolx) {
                logger.info(String.format("BFGS converges on x after %d iterations: %.5f", iter, f));
                return f;
            }

            // Keep the old gradient, then evaluate the new one (its function value equals f).
            System.arraycopy(grad, 0, dg, 0, n);
            func.g(x, grad);

            if (maxScaledGradient(grad, x, f) < gtol) {
                logger.info(String.format("BFGS converges on gradient after %d iterations: %.5f", iter, f));
                return f;
            }

            for (int i = 0; i < n; i++) {
                dg[i] = grad[i] - dg[i];
            }

            for (int i = 0; i < n; i++) {
                double sum = 0.0;
                for (int j = 0; j < n; j++) {
                    sum += hessin[i][j] * dg[j];
                }
                hdg[i] = sum;
            }

            double fac = 0.0;   // dg . xi
            double fae = 0.0;   // dg . H dg
            double sumdg = 0.0;
            double sumxi = 0.0;
            for (int i = 0; i < n; i++) {
                fac += dg[i] * xi[i];
                fae += dg[i] * hdg[i];
                sumdg += dg[i] * dg[i];
                sumxi += xi[i] * xi[i];
            }

            // Update H only when the curvature dg.xi is sufficiently positive; this keeps H
            // positive definite.
            if (fac > Math.sqrt(MathEx.EPSILON * sumdg * sumxi)) {
                double invFac = 1.0 / fac;
                double invFae = 1.0 / fae;
                for (int i = 0; i < n; i++) {
                    dg[i] = invFac * xi[i] - invFae * hdg[i];
                }
                for (int i = 0; i < n; i++) {
                    for (int j = i; j < n; j++) {
                        hessin[i][j] = hessin[i][j]
                                + (invFac * xi[i] * xi[j] - invFae * hdg[i] * hdg[j])
                                + fae * dg[i] * dg[j];
                        hessin[j][i] = hessin[i][j];
                    }
                }
            }

            // Next search direction: xi = -H * grad.
            Arrays.fill(xi, 0.0);
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    xi[i] -= hessin[i][j] * grad[j];
                }
            }
        }

        logger.warn("BFGS reaches the maximum number of iterations: " + this.maxIter);
        return f;
    }

    /**
     * Returns max_i |step[i]| / max(|x[i]|, 1), the relative size of the last step.
     * Uses an explicit {@code >} comparison (not {@code Math.max}) so NaN entries are skipped.
     */
    private static double maxRelativeStep(double[] step, double[] x) {
        double test = 0.0;
        for (int i = 0; i < x.length; i++) {
            double temp = Math.abs(step[i]) / Math.max(Math.abs(x[i]), 1.0);
            if (temp > test) {
                test = temp;
            }
        }
        return test;
    }

    /**
     * Returns max_i |grad[i]| * max(|x[i]|, 1) / max(f, 1), a scale-aware gradient measure.
     * Uses an explicit {@code >} comparison (not {@code Math.max}) so NaN entries are skipped.
     */
    private static double maxScaledGradient(double[] grad, double[] x, double f) {
        double den = Math.max(f, 1.0);
        double test = 0.0;
        for (int i = 0; i < x.length; i++) {
            double temp = (Math.abs(grad[i]) * Math.max(Math.abs(x[i]), 1.0)) / den;
            if (temp > test) {
                test = temp;
            }
        }
        return test;
    }

    /**
     * Backtracking line search along {@code p} from {@code xold}.
     *
     * <p>Tries the full step first, then backtracks using quadratic and afterwards cubic
     * interpolation until the sufficient-decrease condition
     * {@code f(x) <= fold + ALF * lambda * (g . p)} holds.
     *
     * @param func the function to evaluate.
     * @param xold the current point (not modified).
     * @param fold the function value at {@code xold}.
     * @param g the gradient at {@code xold}.
     * @param p the search direction; rescaled in place if its norm exceeds {@code stpmax}.
     * @param x output: the accepted new point, or a copy of {@code xold} if the step became
     *     smaller than the minimum step length.
     * @param stpmax upper bound on the norm of the step; must be positive.
     * @return the function value at the returned {@code x}.
     * @throws IllegalArgumentException if {@code stpmax <= 0} or {@code p} is not a descent
     *     direction.
     */
    private double linesearch(MultivariateFunction func, double[] xold, double fold, double[] g,
                              double[] p, double[] x, double stpmax) {
        if (stpmax <= 0.0) {
            throw new IllegalArgumentException("Invalid upper bound of linear search step: " + stpmax);
        }

        final double xtol = MathEx.EPSILON;
        final int n = xold.length;

        // Scale the step if it is too big.
        double pnorm = MathEx.norm(p);
        if (pnorm > stpmax) {
            double r = stpmax / pnorm;
            for (int i = 0; i < n; i++) {
                p[i] *= r;
            }
        }

        double slope = 0.0;
        for (int i = 0; i < n; i++) {
            slope += g[i] * p[i];
        }
        if (slope >= 0.0) {
            throw new IllegalArgumentException("Line Search: the search direction is not a descent direction, which may be caused by roundoff problem.");
        }

        // Minimum step length; the scaling uses |xold| like the convergence tests
        // (the previous code dropped Math.abs here).
        double alamin = xtol / maxRelativeStep(p, xold);

        double alam = 1.0;   // current step length, starting with the full step
        double alam2 = 0.0;  // previous step length
        double f2 = 0.0;     // function value at the previous step
        while (true) {
            for (int i = 0; i < n; i++) {
                x[i] = xold[i] + alam * p[i];
            }
            double f = func.apply(x);

            if (alam < alamin) {
                // Step too small: x reverts to xold, so report f(xold), not the rejected trial value.
                System.arraycopy(xold, 0, x, 0, n);
                return fold;
            }

            if (f <= fold + ALF * alam * slope) {
                return f;
            }

            double tmplam;
            if (alam == 1.0) {
                // First backtrack: minimize a quadratic model of f along p.
                tmplam = -slope / (2.0 * (f - fold - slope));
            } else {
                // Subsequent backtracks: fit a cubic through the last two trial points.
                double rhs1 = f - fold - alam * slope;
                double rhs2 = f2 - fold - alam2 * slope;
                double a = (rhs1 / (alam * alam) - rhs2 / (alam2 * alam2)) / (alam - alam2);
                double b = (-alam2 * rhs1 / (alam * alam) + alam * rhs2 / (alam2 * alam2)) / (alam - alam2);
                if (a == 0.0) {
                    tmplam = -slope / (2.0 * b);
                } else {
                    double disc = b * b - 3.0 * a * slope;
                    if (disc < 0.0) {
                        tmplam = 0.5 * alam;
                    } else if (b <= 0.0) {
                        tmplam = (-b + Math.sqrt(disc)) / (3.0 * a);
                    } else {
                        tmplam = -slope / (b + Math.sqrt(disc));
                    }
                }
                if (tmplam > 0.5 * alam) {
                    tmplam = 0.5 * alam;
                }
            }

            alam2 = alam;
            f2 = f;
            // Do not shrink the step by more than a factor of 10 at once.
            alam = Math.max(tmplam, 0.1 * alam);
        }
    }
}
