/*
 * Decompiled with CFR 0.152.
 */
package smile.tensor;

import java.io.Serializable;
import java.lang.foreign.MemorySegment;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.linalg.Order;
import smile.linalg.Transpose;
import smile.linalg.UPLO;
import smile.linalg.blas.cblas_h;
import smile.linalg.lapack.clapack_h;
import smile.math.MathEx;
import smile.tensor.BandMatrix32;
import smile.tensor.BandMatrix64;
import smile.tensor.DenseMatrix;
import smile.tensor.Matrix;
import smile.tensor.ScalarType;
import smile.tensor.Vector;

/**
 * A band matrix with {@code kl} subdiagonals and {@code ku} superdiagonals.
 * <p>
 * The band is kept in column-major general band storage with leading dimension
 * {@code ld = kl + ku + 1}: column {@code j} of the matrix occupies the {@code ld}
 * consecutive storage elements starting at {@code j * ld}. This is the layout handed
 * directly to BLAS {@code ?gbmv}, and its diagonal sits in row {@code ku} of each column.
 * When {@link #withUplo(UPLO)} has been called, the matrix is treated as symmetric and
 * matrix-vector products use BLAS {@code ?sbmv} instead.
 */
public abstract class BandMatrix implements Matrix, Serializable {
    private static final Logger logger = LoggerFactory.getLogger(BandMatrix.class);
    /** The band storage (column-major, leading dimension {@link #ld}). */
    transient MemorySegment memory;
    /** The number of rows. */
    final int m;
    /** The number of columns. */
    final int n;
    /** The number of subdiagonals. */
    final int kl;
    /** The number of superdiagonals. */
    final int ku;
    /** The leading dimension of the band storage, {@code kl + ku + 1}. */
    final int ld;
    /** The triangle used by symmetric operations, or {@code null} for a general band matrix. */
    UPLO uplo = null;

    /**
     * Creates an empty placeholder with no storage.
     * NOTE(review): presumably only for serialization/subclass use — confirm.
     */
    BandMatrix() {
        this.memory = null;
        this.m = 0;
        this.n = 0;
        this.kl = 0;
        this.ku = 0;
        this.ld = 0;
        this.uplo = null;
    }

    /**
     * Wraps existing band storage.
     *
     * @param memory the band storage with leading dimension {@code kl + ku + 1}.
     * @param m the number of rows.
     * @param n the number of columns.
     * @param kl the number of subdiagonals.
     * @param ku the number of superdiagonals.
     * @throws IllegalArgumentException if the dimensions are invalid.
     */
    BandMatrix(MemorySegment memory, int m, int n, int kl, int ku) {
        checkBand(m, n, kl, ku);
        this.memory = memory;
        this.m = m;
        this.n = n;
        this.kl = kl;
        this.ku = ku;
        this.ld = kl + ku + 1;
    }

    /**
     * Validates band matrix dimensions.
     * Shared by the constructor and the factories so that bad sizes are rejected
     * before any storage is allocated.
     *
     * @throws IllegalArgumentException if the size or bandwidths are invalid.
     */
    private static void checkBand(int m, int n, int kl, int ku) {
        if (m <= 0 || n <= 0) {
            throw new IllegalArgumentException(String.format("Invalid matrix size: %d x %d", m, n));
        }
        if (kl < 0 || ku < 0) {
            throw new IllegalArgumentException(String.format("Invalid subdiagonals or superdiagonals: kl = %d, ku = %d", kl, ku));
        }
        if (kl >= m) {
            throw new IllegalArgumentException(String.format("Invalid subdiagonals %d >= %d", kl, m));
        }
        if (ku >= n) {
            throw new IllegalArgumentException(String.format("Invalid superdiagonals %d >= %d", ku, n));
        }
    }

    /**
     * Returns a zero band matrix.
     *
     * @param scalarType the element type ({@code Float64} or {@code Float32}).
     * @param m the number of rows.
     * @param n the number of columns.
     * @param kl the number of subdiagonals.
     * @param ku the number of superdiagonals.
     * @return the zero matrix.
     * @throws IllegalArgumentException if the dimensions are invalid.
     * @throws UnsupportedOperationException if the scalar type is not supported.
     * @throws ArithmeticException if the storage size overflows an int.
     */
    public static BandMatrix zeros(ScalarType scalarType, int m, int n, int kl, int ku) {
        checkBand(m, n, kl, ku);
        int size = Math.multiplyExact(kl + ku + 1, n);
        return switch (scalarType) {
            case Float64 -> new BandMatrix64(m, n, kl, ku, new double[size]);
            case Float32 -> new BandMatrix32(m, n, kl, ku, new float[size]);
            default -> throw new UnsupportedOperationException("Unsupported ScalarType: " + scalarType);
        };
    }

    /**
     * Returns a double precision band matrix from its band representation.
     *
     * @param m the number of rows.
     * @param n the number of columns.
     * @param kl the number of subdiagonals.
     * @param ku the number of superdiagonals.
     * @param ab the band, {@code kl + ku + 1} rows by at least {@code n} columns;
     *           row {@code i} of {@code ab} becomes row {@code i} of the band storage.
     * @return the band matrix.
     */
    public static BandMatrix of(int m, int n, int kl, int ku, double[][] ab) {
        checkBand(m, n, kl, ku);
        int ld = kl + ku + 1;
        double[] AB = new double[Math.multiplyExact(ld, n)];
        for (int j = 0; j < n; ++j) {
            for (int i = 0; i < ld; ++i) {
                AB[j * ld + i] = ab[i][j];
            }
        }
        // Fill first so the result is correct whether the subclass wraps or copies AB.
        return new BandMatrix64(m, n, kl, ku, AB);
    }

    /**
     * Returns a single precision band matrix from its band representation.
     *
     * @param m the number of rows.
     * @param n the number of columns.
     * @param kl the number of subdiagonals.
     * @param ku the number of superdiagonals.
     * @param ab the band, {@code kl + ku + 1} rows by at least {@code n} columns;
     *           row {@code i} of {@code ab} becomes row {@code i} of the band storage.
     * @return the band matrix.
     */
    public static BandMatrix of(int m, int n, int kl, int ku, float[][] ab) {
        checkBand(m, n, kl, ku);
        int ld = kl + ku + 1;
        float[] AB = new float[Math.multiplyExact(ld, n)];
        for (int j = 0; j < n; ++j) {
            for (int i = 0; i < ld; ++i) {
                AB[j * ld + i] = ab[i][j];
            }
        }
        // Fill first so the result is correct whether the subclass wraps or copies AB.
        return new BandMatrix32(m, n, kl, ku, AB);
    }

    @Override
    public int nrow() {
        return this.m;
    }

    @Override
    public int ncol() {
        return this.n;
    }

    /** Returns the number of subdiagonals. */
    public int kl() {
        return this.kl;
    }

    /** Returns the number of superdiagonals. */
    public int ku() {
        return this.ku;
    }

    /** Returns the storage layout, always column-major. */
    public Order layout() {
        return Order.COL_MAJOR;
    }

    /** Returns the leading dimension of the band storage, {@code kl + ku + 1}. */
    public int ld() {
        return this.ld;
    }

    /** Returns true if the matrix is square, has kl == ku, and has been marked with an UPLO. */
    public boolean isSymmetric() {
        return this.m == this.n && this.kl == this.ku && this.uplo != null;
    }

    /**
     * Scales the matrix in place, {@code A = alpha * A}.
     *
     * @param alpha the scaling factor.
     * @return this matrix.
     */
    @Override
    public BandMatrix scale(double alpha) {
        // Scale exactly the ld * n band storage elements. The unused corner slots of
        // band storage are never referenced, so scaling them is harmless.
        int size = this.ld * this.n;
        switch (this.scalarType()) {
            case Float64 -> cblas_h.cblas_dscal(size, alpha, this.memory, 1);
            case Float32 -> cblas_h.cblas_sscal(size, (float) alpha, this.memory, 1);
            default -> throw this.unsupportedScalarType();
        }
        return this;
    }

    @Override
    public abstract BandMatrix copy();

    /**
     * Returns a new band matrix holding the transpose.
     * The transpose of an {@code m x n} matrix with bandwidths (kl, ku) is an
     * {@code n x m} matrix with bandwidths (ku, kl). The result has no UPLO set.
     */
    @Override
    public BandMatrix transpose() {
        BandMatrix trans = BandMatrix.zeros(this.scalarType(), this.n, this.m, this.ku, this.kl);
        for (int j = 0; j < this.n; ++j) {
            // Rows j - ku .. j + kl of column j are inside the band; clip to the matrix.
            int first = Math.max(0, j - this.ku);
            int last = Math.min(this.m - 1, j + this.kl);
            for (int i = first; i <= last; ++i) {
                trans.set(j, i, this.get(i, j));
            }
        }
        return trans;
    }

    /**
     * Marks this matrix as symmetric and selects which triangle symmetric operations use.
     *
     * @param uplo the triangle to use, or {@code null} to treat the matrix as general.
     * @return this matrix.
     * @throws IllegalArgumentException if the matrix is not square or kl != ku.
     */
    public BandMatrix withUplo(UPLO uplo) {
        if (this.m != this.n) {
            throw new IllegalArgumentException(String.format("The matrix is not square: %d x %d", this.m, this.n));
        }
        if (this.kl != this.ku) {
            throw new IllegalArgumentException(String.format("kl != ku: %d != %d", this.kl, this.ku));
        }
        this.uplo = uplo;
        return this;
    }

    /** Returns the triangle used by symmetric operations, or {@code null} if not set. */
    public UPLO uplo() {
        return this.uplo;
    }

    /**
     * Compares two band matrices element-wise within the band.
     * Entries are compared with absolute tolerance {@code 10 * FLOAT_EPSILON}. The UPLO
     * flag is not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof BandMatrix b)) {
            return false;
        }
        if (this.nrow() != b.nrow() || this.ncol() != b.ncol() || this.kl != b.kl || this.ku != b.ku) {
            return false;
        }
        double tol = 10.0f * MathEx.FLOAT_EPSILON;
        for (int j = 0; j < this.n; ++j) {
            int first = Math.max(0, j - this.ku);
            int last = Math.min(this.m - 1, j + this.kl);
            for (int i = first; i <= last; ++i) {
                if (Math.abs(this.get(i, j) - b.get(i, j)) > tol) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Hashes only the dimensions and bandwidths.
     * A tolerance-based {@link #equals(Object)} cannot be matched by a hash over element
     * values, so the shape is the only safe input that keeps the equals/hashCode contract.
     */
    @Override
    public int hashCode() {
        int h = this.nrow();
        h = 31 * h + this.ncol();
        h = 31 * h + this.kl;
        h = 31 * h + this.ku;
        return h;
    }

    /**
     * Computes {@code y = alpha * op(A) * x + beta * y}.
     * When an UPLO is set, the symmetric routine {@code ?sbmv} is used and {@code trans} is
     * irrelevant. Otherwise {@code ?gbmv} is used.
     *
     * @throws IllegalArgumentException if x or y has a different scalar type than A.
     * @throws UnsupportedOperationException if the scalar type is not supported.
     */
    @Override
    public void mv(Transpose trans, double alpha, Vector x, double beta, Vector y) {
        if (this.scalarType() != x.scalarType()) {
            throw new IllegalArgumentException("Incompatible ScalarType: " + this.scalarType() + " != " + x.scalarType());
        }
        if (this.scalarType() != y.scalarType()) {
            throw new IllegalArgumentException("Incompatible ScalarType: " + this.scalarType() + " != " + y.scalarType());
        }

        if (this.uplo != null) {
            switch (this.scalarType()) {
                case Float64 -> cblas_h.cblas_dsbmv(this.layout().blas(), this.uplo.blas(), this.n, this.kl,
                        alpha, this.symmetricBand(Double.BYTES), this.ld, x.memory, 1, beta, y.memory, 1);
                case Float32 -> cblas_h.cblas_ssbmv(this.layout().blas(), this.uplo.blas(), this.n, this.kl,
                        (float) alpha, this.symmetricBand(Float.BYTES), this.ld, x.memory, 1, (float) beta, y.memory, 1);
                default -> throw this.unsupportedScalarType();
            }
            return;
        }

        switch (this.scalarType()) {
            case Float64 -> cblas_h.cblas_dgbmv(this.layout().blas(), trans.blas(), this.m, this.n, this.kl, this.ku,
                    alpha, this.memory, this.ld, x.memory, 1, beta, y.memory, 1);
            case Float32 -> cblas_h.cblas_sgbmv(this.layout().blas(), trans.blas(), this.m, this.n, this.kl, this.ku,
                    (float) alpha, this.memory, this.ld, x.memory, 1, (float) beta, y.memory, 1);
            default -> throw this.unsupportedScalarType();
        }
    }

    /**
     * Returns the band storage positioned the way BLAS {@code ?sbmv} expects for {@link #uplo}.
     * <p>
     * The general band layout keeps the diagonal in row {@code ku} of each column.
     * <ul>
     *   <li>With UPLO=U, {@code ?sbmv} reads the diagonal from row {@code k}. Since
     *       {@code k == kl == ku}, the storage is usable as-is.</li>
     *   <li>With UPLO=L, {@code ?sbmv} reads the diagonal from row 0. The segment is
     *       therefore advanced by {@code ku} elements so the lower band starts at row 0.</li>
     * </ul>
     *
     * @param elementBytes the size in bytes of one stored element.
     */
    private MemorySegment symmetricBand(long elementBytes) {
        return this.uplo == UPLO.LOWER ? this.memory.asSlice(this.ku * elementBytes) : this.memory;
    }

    /** Builds the exception thrown for scalar types this class does not handle. */
    private UnsupportedOperationException unsupportedScalarType() {
        return new UnsupportedOperationException("Unsupported scalar type: " + this.scalarType());
    }

    /**
     * Solves {@code A * x = b}.
     *
     * @param b the right-hand side; its first {@code n} elements are used.
     * @return the solution vector.
     */
    public Vector solve(double[] b) {
        Vector x = this.vector(this.n);
        for (int i = 0; i < this.n; ++i) {
            x.set(i, b[i]);
        }
        this.solve(x);
        return x;
    }

    /**
     * Solves {@code A * x = b}.
     *
     * @param b the right-hand side; its first {@code n} elements are used.
     * @return the solution vector.
     */
    public Vector solve(float[] b) {
        Vector x = this.vector(this.n);
        for (int i = 0; i < this.n; ++i) {
            x.set(i, b[i]);
        }
        this.solve(x);
        return x;
    }

    /**
     * Solves {@code A * X = B} with LAPACK {@code ?gbtrf} (band LU factorization) followed by
     * {@code ?gbtrs}. On return, {@code B} is overwritten with the solution.
     *
     * @param B the right-hand side(s), overwritten with the solution.
     * @throws IllegalArgumentException if A is not square, if the row counts disagree, or if
     *         B has a different scalar type than A.
     * @throws ArithmeticException if LAPACK reports an argument error.
     * @throws RuntimeException if the matrix is singular.
     */
    public void solve(DenseMatrix B) {
        if (this.m != this.n) {
            throw new IllegalArgumentException(String.format("The matrix is not square: %d x %d", this.m, this.n));
        }
        if (B.m != this.m) {
            throw new IllegalArgumentException(String.format("Row dimensions do not agree: A is %d x %d, but B is %d x %d", this.m, this.n, B.m, B.n));
        }
        // ?gbtrs writes elements of A's type into B.memory; a mismatched B would be corrupted.
        if (B.scalarType() != this.scalarType()) {
            throw new IllegalArgumentException("Incompatible ScalarType: " + this.scalarType() + " != " + B.scalarType());
        }

        // The Fortran LAPACK interface takes every scalar argument by reference.
        int[] info = {0};
        MemorySegment m_ = MemorySegment.ofArray(new int[]{this.m});
        MemorySegment n_ = MemorySegment.ofArray(new int[]{this.n});
        MemorySegment kl_ = MemorySegment.ofArray(new int[]{this.kl});
        MemorySegment ku_ = MemorySegment.ofArray(new int[]{this.ku});
        MemorySegment lda_ = MemorySegment.ofArray(new int[]{this.lda()});
        MemorySegment ipiv_ = MemorySegment.ofArray(new int[this.n]);
        MemorySegment info_ = MemorySegment.ofArray(info);
        MemorySegment lu = this.lua();

        switch (this.scalarType()) {
            case Float64 -> clapack_h.dgbtrf_(m_, n_, kl_, ku_, lu, lda_, ipiv_, info_);
            case Float32 -> clapack_h.sgbtrf_(m_, n_, kl_, ku_, lu, lda_, ipiv_, info_);
            default -> throw this.unsupportedScalarType();
        }
        if (info[0] < 0) {
            logger.error("LAPACK GBTRF error code: {}", info[0]);
            throw new ArithmeticException("LAPACK GBTRF error code: " + info[0]);
        }
        if (info[0] > 0) {
            // info > 0: U(info, info) is exactly zero.
            throw new RuntimeException("The matrix is singular.");
        }

        MemorySegment trans_ = MemorySegment.ofArray(new byte[]{Transpose.NO_TRANSPOSE.lapack()});
        MemorySegment nrhs_ = MemorySegment.ofArray(new int[]{B.n});
        MemorySegment ldb_ = MemorySegment.ofArray(new int[]{B.ld});
        switch (this.scalarType()) {
            case Float64 -> clapack_h.dgbtrs_(trans_, n_, kl_, ku_, nrhs_, lu, lda_, ipiv_, B.memory, ldb_, info_);
            case Float32 -> clapack_h.sgbtrs_(trans_, n_, kl_, ku_, nrhs_, lu, lda_, ipiv_, B.memory, ldb_, info_);
            default -> throw this.unsupportedScalarType();
        }
        if (info[0] != 0) {
            logger.error("LAPACK GBTRS error code: {}", info[0]);
            throw new ArithmeticException("LAPACK GBTRS error code: " + info[0]);
        }
    }

    /**
     * Returns the leading dimension required by LAPACK {@code ?gbtrf}, {@code 2*kl + ku + 1}.
     * The extra {@code kl} rows hold fill-in produced by row pivoting.
     */
    int lda() {
        return 2 * this.kl + this.ku + 1;
    }

    /**
     * Returns band storage with leading dimension {@link #lda()} for {@code ?gbtrf} to
     * factorize in place.
     * NOTE(review): presumably a fresh copy so that solve() leaves this matrix intact — confirm
     * in subclasses.
     */
    abstract MemorySegment lua();
}

