/*
 * Decompiled with CFR 0.152.
 */
package smile.math.matrix;

import java.util.Arrays;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.math.matrix.Preconditioner;

/**
 * Iterative solver for {@code A x = b} using the preconditioned biconjugate
 * gradient (BiCG) method. Each iteration multiplies by both {@code A} and its
 * transpose, so {@code A} need not be symmetric, but it must be square.
 * <p>
 * When no {@link Preconditioner} has been set, {@link #solve} uses a diagonal
 * preconditioner built from the matrix being solved.
 */
public class BiconjugateGradient {
    private static final Logger logger = LoggerFactory.getLogger(BiconjugateGradient.class);
    /** Shared default instance returned by {@link #getInstance()}. */
    private static final BiconjugateGradient instance = new BiconjugateGradient();

    /** Convergence tolerance on the error estimate; always positive. */
    private double tol = 1.0E-10;
    /** Convergence test selector in [1, 4]; see {@link #setConvergenceTest(int)}. */
    private int itol = 1;
    /** Maximum number of iterations; {@code <= 0} means {@code 2 * n} for an n x n system. */
    private int maxIter = 0;
    /** User-supplied preconditioner, or {@code null} to use the diagonal of A. */
    private Preconditioner preconditioner;

    /** Creates a solver with tolerance 1E-10, convergence test 1 and the default iteration limit. */
    public BiconjugateGradient() {
    }

    /**
     * Creates a solver with the given settings.
     *
     * @param tol the convergence tolerance; must be positive.
     * @param itol the convergence test, 1 to 4; see {@link #setConvergenceTest(int)}.
     * @param maxIter the maximum number of iterations; {@code <= 0} selects {@code 2 * n}.
     * @throws IllegalArgumentException if {@code tol} or {@code itol} is invalid.
     */
    public BiconjugateGradient(double tol, int itol, int maxIter) {
        this.setTolerance(tol);
        this.setConvergenceTest(itol);
        this.setMaxIter(maxIter);
    }

    /**
     * Returns a shared solver instance. Setters called on it change the
     * configuration seen by every other user of the shared instance.
     *
     * @return the shared instance.
     */
    public static BiconjugateGradient getInstance() {
        return instance;
    }

    /**
     * Sets the convergence tolerance.
     *
     * @param tol the tolerance; must be positive (NaN is rejected).
     * @return this solver.
     * @throws IllegalArgumentException if {@code tol} is not positive.
     */
    public BiconjugateGradient setTolerance(double tol) {
        // Written as !(tol > 0) so that NaN, which would never satisfy err <= tol, is rejected.
        if (!(tol > 0.0)) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        this.tol = tol;
        return this;
    }

    /**
     * Selects the convergence test. With {@code r = b - A x} and M the preconditioner:
     * <ul>
     * <li>1: {@code ||r|| / ||b||};</li>
     * <li>2: {@code ||M^-1 r|| / ||M^-1 b||};</li>
     * <li>3: an estimate of the relative error in x, extrapolated from the change
     *     between successive {@code ||M^-1 r||}, using the Euclidean norm;</li>
     * <li>4: as 3, but with the maximum-absolute-value norm.</li>
     * </ul>
     *
     * @param itol the test, 1 to 4.
     * @return this solver.
     * @throws IllegalArgumentException if {@code itol} is outside [1, 4].
     */
    public BiconjugateGradient setConvergenceTest(int itol) {
        if (itol < 1 || itol > 4) {
            throw new IllegalArgumentException(String.format("Invalid itol: %d", itol));
        }
        this.itol = itol;
        return this;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param maxIter the limit; {@code <= 0} selects {@code 2 * n} for each solve.
     * @return this solver.
     */
    public BiconjugateGradient setMaxIter(int maxIter) {
        this.maxIter = maxIter;
        return this;
    }

    /**
     * Sets the preconditioner.
     *
     * @param preconditioner the preconditioner, or {@code null} to use the
     *        diagonal of the matrix being solved.
     * @return this solver.
     */
    public BiconjugateGradient setPreconditioner(Preconditioner preconditioner) {
        this.preconditioner = preconditioner;
        return this;
    }

    /**
     * Builds a preconditioner that divides by the diagonal of {@code A}. Zero
     * diagonal entries pass the corresponding element through unchanged.
     */
    private static Preconditioner diagonalPreconditioner(Matrix A) {
        // Extract the diagonal once rather than on every preconditioner application.
        double[] diag = A.diag();
        return (b, x) -> {
            for (int i = 0; i < diag.length; ++i) {
                x[i] = diag[i] != 0.0 ? b[i] / diag[i] : b[i];
            }
        };
    }

    /**
     * Solves {@code A x = b} in place, starting from the initial guess in {@code x}.
     * <p>
     * This method does not modify the solver's configuration, so one solver,
     * including the shared instance, can be reused for different matrices.
     *
     * @param A the square coefficient matrix.
     * @param b the right-hand side; not modified.
     * @param x on input the initial guess, on output the computed solution.
     * @return the final error estimate, as defined by the convergence test
     *         (see {@link #setConvergenceTest(int)}).
     * @throws IllegalArgumentException if {@code A} is not square or the vector
     *         lengths do not match its dimension.
     */
    public double solve(Matrix A, double[] b, double[] x) {
        int n = b.length;
        if (A.nrows() != n || A.ncols() != n || x.length != n) {
            throw new IllegalArgumentException(String.format(
                    "Dimension mismatch: A is %d x %d, b has %d elements, x has %d",
                    A.nrows(), A.ncols(), n, x.length));
        }

        if (isZero(b)) {
            // x = 0 solves A x = 0 exactly. Iterating instead would divide every
            // relative error estimate by the zero norm of b and never converge.
            Arrays.fill(x, 0.0);
            return 0.0;
        }

        // Resolve the defaults locally. Writing them back to the fields would bind
        // this (possibly shared) solver to the first matrix it ever solved.
        int maxIterations = maxIter > 0 ? maxIter : (int) Math.min(Integer.MAX_VALUE, 2L * n);
        Preconditioner M = preconditioner != null ? preconditioner : diagonalPreconditioner(A);

        double[] p = new double[n];
        double[] pp = new double[n];
        double[] r = new double[n];
        double[] rr = new double[n];
        double[] z = new double[n];
        double[] zz = new double[n];

        // Residual r = b - A x; the shadow residual rr starts equal to r.
        A.ax(x, r);
        for (int j = 0; j < n; ++j) {
            r[j] = b[j] - r[j];
            rr[j] = r[j];
        }

        // bnrm is the normalizer of the error estimate; z = M^-1 r is the
        // preconditioned residual.
        double bnrm;
        if (itol == 1) {
            bnrm = snorm(b);
        } else {
            M.solve(b, z);
            bnrm = snorm(z);
        }
        M.solve(r, z);
        double znrm = itol >= 3 ? snorm(z) : 0.0;

        double err = 0.0;
        double bkden = 1.0;
        for (int iter = 1; iter <= maxIterations; ++iter) {
            M.solve(rr, zz);
            double bknum = dot(z, rr);
            if (bknum == 0.0) {
                // Either r is exactly zero (already solved) or BiCG has broken down.
                // Continuing would stall (ak = 0) or, for r = 0, compute ak = 0/0
                // and fill x with NaN.
                err = stopEarly(M, r, z, bnrm, iter - 1);
                break;
            }

            // Update the search directions p and pp.
            if (iter == 1) {
                System.arraycopy(z, 0, p, 0, n);
                System.arraycopy(zz, 0, pp, 0, n);
            } else {
                double bk = bknum / bkden;
                for (int j = 0; j < n; ++j) {
                    p[j] = bk * p[j] + z[j];
                    pp[j] = bk * pp[j] + zz[j];
                }
            }
            bkden = bknum;

            A.ax(p, z); // z now holds A p
            double akden = dot(z, pp);
            if (akden == 0.0) {
                // The step length ak would be infinite; stop before corrupting x.
                err = stopEarly(M, r, z, bnrm, iter - 1);
                break;
            }
            double ak = bknum / akden;

            A.atx(pp, zz); // zz now holds A' pp
            for (int j = 0; j < n; ++j) {
                x[j] += ak * p[j];
                r[j] -= ak * z[j];
                rr[j] -= ak * zz[j];
            }
            M.solve(r, z); // z is again the preconditioned residual M^-1 r

            if (itol == 1) {
                err = snorm(r) / bnrm;
            } else if (itol == 2) {
                err = snorm(z) / bnrm;
            } else {
                double zm1nrm = znrm;
                znrm = snorm(z);
                // Comparisons are kept in the !(a > b) form so NaN takes the fallback path.
                if (!(Math.abs(zm1nrm - znrm) > MathEx.EPSILON * znrm)) {
                    // Successive residual norms are too close to extrapolate from.
                    // The fallback estimate is not tested against tol this iteration.
                    err = znrm / bnrm;
                    continue;
                }
                double dxnrm = Math.abs(ak) * snorm(p);
                err = znrm / Math.abs(zm1nrm - znrm) * dxnrm;
                double xnrm = snorm(x);
                if (err <= 0.5 * xnrm) {
                    err /= xnrm;
                } else {
                    err = znrm / bnrm;
                    continue;
                }
            }

            if (iter % 10 == 0) {
                logger.info(String.format("BCG: the error after %3d iterations: %.5g", iter, err));
            }
            if (err <= tol) {
                logger.info(String.format("BCG: the error after %3d iterations: %.5g", iter, err));
                break;
            }
        }
        return err;
    }

    /**
     * Computes the error estimate from the current residual when the iteration
     * must stop because a BiCG denominator is zero. It logs a warning unless the
     * estimate already meets the tolerance.
     *
     * @param M the preconditioner in use.
     * @param r the current residual.
     * @param z scratch vector; overwritten with {@code M^-1 r} when itol != 1.
     * @param bnrm the normalizer computed by {@link #solve}.
     * @param completed the number of completed iterations.
     * @return the error estimate.
     */
    private double stopEarly(Preconditioner M, double[] r, double[] z, double bnrm, int completed) {
        double err;
        if (itol == 1) {
            err = snorm(r) / bnrm;
        } else {
            M.solve(r, z);
            err = snorm(z) / bnrm;
        }
        if (err <= tol) {
            logger.info(String.format("BCG: the error after %3d iterations: %.5g", completed, err));
        } else {
            logger.warn(String.format("BCG: breakdown after %3d iterations, error: %.5g", completed, err));
        }
        return err;
    }

    /** Returns true if every element of {@code v} is exactly zero (true for an empty array). */
    private static boolean isZero(double[] v) {
        for (double e : v) {
            if (e != 0.0) {
                return false;
            }
        }
        return true;
    }

    /** Dot product of two vectors of equal length. */
    private static double dot(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; ++i) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /**
     * Vector norm used by the convergence tests: Euclidean for itol 1 to 3, and
     * the maximum absolute value for itol 4. Returns 0 for an empty vector.
     */
    private double snorm(double[] v) {
        if (itol <= 3) {
            double sum = 0.0;
            for (double e : v) {
                sum += e * e;
            }
            return Math.sqrt(sum);
        }
        double max = 0.0;
        for (double e : v) {
            max = Math.max(max, Math.abs(e));
        }
        return max;
    }
}

