/*
 * Decompiled with CFR 0.152.
 */
package smile.tensor;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.tensor.Matrix;
import smile.tensor.Preconditioner;
import smile.tensor.ScalarType;
import smile.tensor.Vector;

/**
 * Preconditioned biconjugate gradient (BiCG) solver for linear systems {@code A * x = b}.
 * <p>
 * The iteration keeps a residual {@code r = b - A * x} and a shadow residual {@code rr},
 * and updates {@code x} in place until the selected error measure drops to {@code tol}
 * or the iteration limit is reached. Its structure and variable names closely mirror
 * the {@code linbcg} routine of Numerical Recipes.
 */
public interface BiconjugateGradient {
    /** Logger used to report iteration progress and breakdowns. */
    public static final Logger logger = LoggerFactory.getLogger(BiconjugateGradient.class);

    /**
     * Solves {@code A * x = b} with default settings. These are a Jacobi preconditioner,
     * tolerance {@code 1E-6}, error measure {@code itol = 1} and at most
     * {@code 2 * A.nrow()} iterations.
     *
     * @param A the coefficient matrix.
     * @param b the right-hand side vector.
     * @param x on input, the initial guess; on output, the approximate solution.
     * @return the error estimate after the last iteration.
     */
    public static double solve(Matrix A, Vector b, Vector x) {
        return solve(A, b, x, Preconditioner.Jacobi(A), 1.0E-6, 1, 2 * A.nrow());
    }

    /**
     * Solves {@code A * x = b} by the preconditioned biconjugate gradient method.
     * <p>
     * {@code itol} selects the error measure tested against {@code tol}. Here
     * {@code P(v)} denotes the vector written by {@code P.solve(v, out)}.
     * <ul>
     * <li>1: {@code ||r|| / ||b||}.</li>
     * <li>2: {@code ||P(r)|| / ||P(b)||}.</li>
     * <li>3: an estimate of the relative change in {@code x}, built from successive
     * {@code ||P(r)||} values and the step {@code |ak| * ||p||}, using the L2 norm.</li>
     * <li>4: same as 3 but using the L-infinity norm.</li>
     * </ul>
     * Options 1 to 3 use the L2 norm.
     * <p>
     * If the initial residual is exactly zero, {@code x} is returned untouched with error 0.
     * If the recurrence breaks down, that is a zero inner product or a non-finite step
     * length, iteration stops before {@code x} is modified. The last error estimate is
     * then returned and a warning is logged.
     *
     * @param A the coefficient matrix. It must be square with {@code A.nrow() == b.size()}.
     * @param b the right-hand side vector.
     * @param x on input, the initial guess; on output, the approximate solution.
     * @param P the preconditioner.
     * @param tol the convergence tolerance; must be positive.
     * @param itol the error measure, in {@code [1, 4]}.
     * @param maxIter the maximum number of iterations; must be positive.
     * @return the error estimate after the last iteration.
     * @throws IllegalArgumentException if {@code tol}, {@code itol} or {@code maxIter}
     *         is out of range, or if the dimensions of {@code A}, {@code b} and {@code x}
     *         disagree.
     */
    public static double solve(Matrix A, Vector b, Vector x, Preconditioner P, double tol, int itol, int maxIter) {
        if (tol <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (itol < 1 || itol > 4) {
            throw new IllegalArgumentException("Invalid itol: " + itol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum iterations: " + maxIter);
        }

        int n = b.size();
        if (A.nrow() != n) {
            throw new IllegalArgumentException("Matrix row count " + A.nrow() + " does not match b.size() = " + n);
        }
        if (x.size() != n) {
            throw new IllegalArgumentException("x.size() = " + x.size() + " does not match b.size() = " + n);
        }

        // Relative threshold used by itol 3/4 to decide whether two successive
        // preconditioned residual norms differ meaningfully.
        double eps = A.scalarType() == ScalarType.Float64 ? MathEx.EPSILON : (double) MathEx.FLOAT_EPSILON;

        Vector p = A.vector(n);
        Vector pp = A.vector(n);
        Vector r = A.vector(n);
        Vector rr = A.vector(n);
        Vector z = A.vector(n);
        Vector zz = A.vector(n);

        // r = b - A * x; the shadow residual rr starts out equal to r.
        A.mv(x, r);
        for (int j = 0; j < n; j++) {
            double br = b.get(j) - r.get(j);
            r.set(j, br);
            rr.set(j, br);
        }

        // The initial guess is already exact. Iterating would compute ak = 0 / 0
        // and overwrite x with NaN, so stop here.
        double rnrm = norm(r, itol);
        if (rnrm == 0.0) {
            logger.info("BCG: the error after {} iterations: {}", 0, 0.0);
            return 0.0;
        }

        double bnrm;
        double znrm = 0.0;
        // err starts as the initial relative residual. Every completed iteration
        // overwrites it, so this value is only returned if iteration 1 breaks down.
        double err;
        if (itol == 1) {
            bnrm = norm(b, itol);
            P.solve(r, z);
            err = rnrm / bnrm;
        } else {
            P.solve(b, z);
            bnrm = norm(z, itol);
            P.solve(r, z);
            znrm = norm(z, itol);
            err = znrm / bnrm;
        }

        double bkden = 1.0;
        for (int iter = 1; iter <= maxIter; iter++) {
            // NOTE(review): the shadow system is preconditioned with P rather than P^T.
            // These agree only for a symmetric preconditioner; confirm for custom ones.
            P.solve(rr, zz);
            double bknum = dot(z, rr, n);
            if (bknum == 0.0) {
                // A zero inner product would make the next bk = 0 / 0.
                logger.warn("BCG: breakdown (zero inner product) at iteration {}, error: {}", iter, err);
                break;
            }

            // Update the search direction p and the shadow direction pp.
            if (iter == 1) {
                for (int j = 0; j < n; j++) {
                    p.set(j, z.get(j));
                    pp.set(j, zz.get(j));
                }
            } else {
                double bk = bknum / bkden;
                for (int j = 0; j < n; j++) {
                    p.set(j, bk * p.get(j) + z.get(j));
                    pp.set(j, bk * pp.get(j) + zz.get(j));
                }
            }
            bkden = bknum;

            // z is reused as scratch for A * p.
            A.mv(p, z);
            double akden = dot(z, pp, n);
            double ak = bknum / akden;
            if (!Double.isFinite(ak)) {
                // akden vanished or underflowed: stop before x is corrupted.
                logger.warn("BCG: breakdown (non-finite step length) at iteration {}, error: {}", iter, err);
                break;
            }

            // zz is reused for the transpose product of pp (presumably A^T * pp).
            A.tv(pp, zz);
            for (int j = 0; j < n; j++) {
                x.add(j, ak * p.get(j));
                r.sub(j, ak * z.get(j));
                rr.sub(j, ak * zz.get(j));
            }

            P.solve(r, z);
            if (itol == 1) {
                err = norm(r, itol) / bnrm;
            } else if (itol == 2) {
                err = norm(z, itol) / bnrm;
            } else {
                double zm1nrm = znrm;
                znrm = norm(z, itol);
                if (Math.abs(zm1nrm - znrm) > eps * znrm) {
                    double dxnrm = Math.abs(ak) * norm(p, itol);
                    err = znrm / Math.abs(zm1nrm - znrm) * dxnrm;
                } else {
                    // The residual norm barely changed, so the estimate is unreliable.
                    // Fall back to the relative residual and, as in the original scheme,
                    // skip the convergence test for this iteration.
                    err = znrm / bnrm;
                    continue;
                }

                double xnrm = norm(x, itol);
                if (err <= 0.5 * xnrm) {
                    err /= xnrm;
                } else {
                    err = znrm / bnrm;
                    continue;
                }
            }

            if (iter % 10 == 0 || err <= tol) {
                logger.info("BCG: the error after {} iterations: {}", iter, err);
            }
            if (err <= tol) {
                break;
            }
        }
        return err;
    }

    /** Returns the sum of {@code x[i] * y[i]} over the first {@code n} entries. */
    private static double dot(Vector x, Vector y, int n) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += x.get(i) * y.get(i);
        }
        return sum;
    }

    /** Returns the L-infinity norm of {@code x} when {@code itol == 4}, otherwise its L2 norm. */
    private static double norm(Vector x, int itol) {
        if (itol <= 3) {
            return x.norm2();
        }
        return x.normInf();
    }
}

