/*
 * Decompiled with CFR 0.152.
 */
package smile.tensor;

import java.io.Serializable;
import java.lang.foreign.MemorySegment;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.linalg.Diag;
import smile.linalg.EVDJob;
import smile.linalg.Order;
import smile.linalg.SVDJob;
import smile.linalg.Side;
import smile.linalg.Transpose;
import smile.linalg.UPLO;
import smile.linalg.blas.cblas_h;
import smile.linalg.lapack.clapack_h;
import smile.math.MathEx;
import smile.stat.distribution.Distribution;
import smile.stat.distribution.GaussianDistribution;
import smile.tensor.Cholesky;
import smile.tensor.DenseMatrix32;
import smile.tensor.DenseMatrix64;
import smile.tensor.EVD;
import smile.tensor.LU;
import smile.tensor.Matrix;
import smile.tensor.QR;
import smile.tensor.SVD;
import smile.tensor.ScalarType;
import smile.tensor.Vector;

public abstract class DenseMatrix
implements Matrix,
Serializable {
    private static final Logger logger = LoggerFactory.getLogger(DenseMatrix.class);
    /** Backing storage handed to the BLAS/LAPACK calls; {@code transient}, so skipped by default Java serialization. */
    transient MemorySegment memory;
    /** Leading dimension: in the base offset formula {@code j * ld + i}, the distance between the starts of consecutive columns. */
    final int ld;
    /** Number of rows (returned by {@code nrow()}). */
    final int m;
    /** Number of columns (returned by {@code ncol()}). */
    final int n;
    /** Stored-triangle flag. Non-null with {@code diag == null} marks a symmetric matrix; non-null with {@code diag} set marks a triangular one. */
    UPLO uplo;
    /** Diagonal flag for triangular matrices; only consulted when {@code uplo} is also set. */
    Diag diag;
    /** Optional row labels; may be null. */
    String[] rowNames;
    /** Optional column labels; may be null. */
    String[] colNames;

    /**
     * Creates an empty matrix: no backing memory, zero dimensions,
     * and neither a triangle nor a diagonal flag.
     */
    DenseMatrix() {
        this.m = 0;
        this.n = 0;
        this.ld = 0;
        this.memory = null;
        this.diag = null;
        this.uplo = null;
    }

    DenseMatrix(MemorySegment memory, int m, int n, int ld, UPLO uplo, Diag diag) {
        if (m <= 0 || n <= 0) {
            throw new IllegalArgumentException(String.format("Invalid matrix size: %d x %d", m, n));
        }
        if (this.order() == Order.COL_MAJOR && ld < m) {
            throw new IllegalArgumentException(String.format("Invalid leading dimension for COL_MAJOR: %d < %d", ld, m));
        }
        if (this.order() == Order.ROW_MAJOR && ld < n) {
            throw new IllegalArgumentException(String.format("Invalid leading dimension for ROW_MAJOR: %d < %d", ld, n));
        }
        this.memory = memory;
        this.m = m;
        this.n = n;
        this.ld = ld;
        this.uplo = uplo;
        this.diag = diag;
    }

    /**
     * Returns the string form produced by {@code toString(false)}.
     */
    @Override
    public String toString() {
        final boolean full = false;
        return toString(full);
    }

    /**
     * Element-wise approximate equality.
     *
     * <p>Returns true when {@code other} is a {@code DenseMatrix} with the same
     * shape and every pair of elements differs by at most
     * {@code 10 * MathEx.FLOAT_EPSILON}. NaN differences never exceed the
     * tolerance, so they do not cause a mismatch.
     *
     * <p>NOTE(review): no {@code hashCode} override is visible in this part of
     * the file. Confirm one exists, or equal matrices may hash differently.
     */
    @Override
    public boolean equals(Object other) {
        if (!(other instanceof DenseMatrix that)) {
            return false;
        }
        if (this.nrow() != that.nrow() || this.ncol() != that.ncol()) {
            return false;
        }
        double tol = 10.0f * MathEx.FLOAT_EPSILON;
        for (int col = 0; col < this.n; col++) {
            for (int row = 0; row < this.m; row++) {
                double diff = Math.abs(this.get(row, col) - that.get(row, col));
                if (diff > tol) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Chooses a padded leading dimension for {@code n} elements of 4 bytes each.
     *
     * <p>Up to 256 bytes (64 elements) no padding is added. Beyond that, the byte
     * length is rounded up to a multiple of 512 and 64 extra bytes are appended.
     * The result is converted back to an element count.
     *
     * @param n the unpadded dimension.
     * @return the padded leading dimension.
     */
    static int ld(int n) {
        final int bytesPerElement = 4;
        if (n > 256 / bytesPerElement) {
            int paddedBytes = (n * bytesPerElement + 511) / 512 * 512 + 64;
            return paddedBytes / bytesPerElement;
        }
        return n;
    }

    /**
     * Returns the linear index of element {@code (i, j)}: column {@code j}
     * starts at {@code j * ld} and row {@code i} is added to it.
     */
    int offset(int i, int j) {
        int columnStart = ld * j;
        return columnStart + i;
    }

    /**
     * Returns the number of stored elements ({@code n * ld}), including any
     * padding between columns.
     */
    int capacity() {
        return n * ld;
    }

    /**
     * Returns a copy of this matrix. Abstract; concrete subclasses implement it.
     * NOTE(review): whether flags (uplo/diag) and row/column names are copied
     * is decided by the subclass — confirm there.
     */
    @Override
    public abstract DenseMatrix copy();

    /**
     * Returns a new {@code n x m} matrix, created with {@code zeros}, whose
     * element {@code (j, i)} equals this matrix's element {@code (i, j)}.
     * This matrix is not modified.
     */
    @Override
    public DenseMatrix transpose() {
        DenseMatrix result = this.zeros(this.n, this.m);
        for (int row = 0; row < this.m; row++) {
            for (int col = 0; col < this.n; col++) {
                result.set(col, row, this.get(row, col));
            }
        }
        return result;
    }

    /**
     * Multiplies every stored element, padding included, by {@code alpha} in
     * place, using BLAS SCAL.
     *
     * @param alpha the scale factor.
     * @return this matrix.
     * @throws UnsupportedOperationException if the scalar type is neither Float64 nor Float32.
     */
    @Override
    public DenseMatrix scale(double alpha) {
        int count = this.capacity();
        var type = this.scalarType();
        switch (type) {
            case Float64 -> cblas_h.cblas_dscal(count, alpha, this.memory, 1);
            case Float32 -> cblas_h.cblas_sscal(count, (float) alpha, this.memory, 1);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + type);
        }
        return this;
    }

    /** Returns the number of rows. */
    @Override
    public int nrow() {
        return m;
    }

    /** Returns the number of columns. */
    @Override
    public int ncol() {
        return n;
    }

    /**
     * Returns the storage order. The base class always reports COL_MAJOR;
     * the constructor consults this method when validating the leading dimension.
     */
    public Order order() {
        return Order.COL_MAJOR;
    }

    /** Returns the backing memory segment (not a copy). */
    public MemorySegment memory() {
        return memory;
    }

    /** Returns the leading dimension of the backing storage. */
    public int ld() {
        return ld;
    }

    /**
     * Returns true if the matrix is flagged symmetric: a triangle flag is set
     * but no diagonal flag. A triangle flag plus a diagonal flag marks a
     * triangular matrix instead.
     */
    public boolean isSymmetric() {
        if (uplo == null) {
            return false;
        }
        return diag == null;
    }

    /**
     * Sets the triangle flag in place.
     *
     * @param uplo the triangle flag, or null to clear it.
     * @return this matrix.
     * @throws IllegalArgumentException if the matrix is not square.
     */
    public DenseMatrix withUplo(UPLO uplo) {
        if (m == n) {
            this.uplo = uplo;
            return this;
        }
        throw new IllegalArgumentException(String.format("The matrix is not square: %d x %d", m, n));
    }

    /** Returns the triangle flag, or null if none is set. */
    public UPLO uplo() {
        return uplo;
    }

    /**
     * Sets the diagonal flag in place.
     *
     * @param diag the diagonal flag, or null to clear it.
     * @return this matrix.
     * @throws IllegalArgumentException if the matrix is not square.
     */
    public DenseMatrix withDiag(Diag diag) {
        if (m == n) {
            this.diag = diag;
            return this;
        }
        throw new IllegalArgumentException(String.format("The matrix is not square: %d x %d", m, n));
    }

    /** Returns the diagonal flag, or null if none is set. */
    public Diag diag() {
        return diag;
    }

    /** Returns the row labels (the internal array, not a copy), or null. */
    public String[] rowNames() {
        return rowNames;
    }

    /**
     * Sets the row labels in place. The array is stored by reference.
     *
     * @param names one label per row, or null to clear.
     * @return this matrix.
     * @throws IllegalArgumentException if {@code names} is non-null and its length differs from {@code nrow()}.
     */
    public DenseMatrix withRowNames(String[] names) {
        if (names == null || names.length == nrow()) {
            rowNames = names;
            return this;
        }
        throw new IllegalArgumentException(String.format("Invalid row names length: %d != %d", names.length, nrow()));
    }

    /** Returns the column labels (the internal array, not a copy), or null. */
    public String[] colNames() {
        return colNames;
    }

    /**
     * Sets the column labels in place. The array is stored by reference.
     *
     * @param names one label per column, or null to clear.
     * @return this matrix.
     * @throws IllegalArgumentException if {@code names} is non-null and its length differs from {@code ncol()}.
     */
    public DenseMatrix withColNames(String[] names) {
        if (names == null || names.length == ncol()) {
            colNames = names;
            return this;
        }
        throw new IllegalArgumentException(String.format("Invalid column names length: %d != %d", names.length, ncol()));
    }

    /**
     * Fills the matrix with {@code value}. Abstract; implemented by subclasses.
     * NOTE(review): presumably assigns every element — confirm in the subclass.
     *
     * @param value the value to assign.
     */
    public abstract void fill(double value);

    /**
     * Copies row {@code i} into a new vector of length {@code n}.
     * A negative index counts from the end ({@code -1} is the last row).
     *
     * @param i the row index.
     * @return a new vector holding the row.
     */
    public Vector row(int i) {
        Vector out = this.vector(this.n);
        int r = i < 0 ? this.m + i : i;
        for (int col = 0; col < this.n; col++) {
            out.set(col, this.get(r, col));
        }
        return out;
    }

    /**
     * Returns column {@code j} as a vector. Abstract; implemented by subclasses.
     * NOTE(review): whether negative indexes are supported (as in {@code row})
     * and whether the result is a copy or a view is decided by the subclass.
     *
     * @param j the column index.
     * @return the column vector.
     */
    public abstract Vector column(int j);

    /**
     * Returns a new matrix made of the selected rows, in the given order.
     * Negative indexes count from the end.
     *
     * @param rows the row indexes.
     * @return a new {@code rows.length x n} matrix.
     */
    public DenseMatrix rows(int ... rows) {
        DenseMatrix out = this.zeros(rows.length, this.n);
        int target = 0;
        for (int src : rows) {
            int r = src < 0 ? src + this.m : src;
            for (int col = 0; col < this.n; col++) {
                out.set(target, col, this.get(r, col));
            }
            target++;
        }
        return out;
    }

    /**
     * Returns a new matrix made of the selected columns, in the given order.
     * Negative indexes count from the end.
     *
     * @param cols the column indexes.
     * @return a new {@code m x cols.length} matrix.
     */
    public DenseMatrix columns(int ... cols) {
        DenseMatrix out = this.zeros(this.m, cols.length);
        int target = 0;
        for (int src : cols) {
            int c = src < 0 ? src + this.n : src;
            for (int row = 0; row < this.m; row++) {
                out.set(row, target, this.get(row, c));
            }
            target++;
        }
        return out;
    }

    /**
     * Returns a new matrix holding rows {@code [from, to)} of this matrix.
     *
     * @param from first row, inclusive.
     * @param to last row, exclusive.
     * @return a new {@code (to - from) x n} matrix.
     * @throws IllegalArgumentException if {@code to <= from}.
     */
    public DenseMatrix rows(int from, int to) {
        if (from >= to) {
            throw new IllegalArgumentException("Invalid row range [" + from + ", " + to + ")");
        }
        int count = to - from;
        DenseMatrix out = this.zeros(count, this.n);
        for (int offset = 0; offset < count; offset++) {
            int src = from + offset;
            for (int col = 0; col < this.n; col++) {
                out.set(offset, col, this.get(src, col));
            }
        }
        return out;
    }

    /**
     * Returns a new matrix holding columns {@code [from, to)} of this matrix.
     *
     * @param from first column, inclusive.
     * @param to last column, exclusive.
     * @return a new {@code m x (to - from)} matrix.
     * @throws IllegalArgumentException if {@code to <= from}.
     */
    public DenseMatrix columns(int from, int to) {
        if (to <= from) {
            throw new IllegalArgumentException("Invalid column range [" + from + ", " + to + ")");
        }
        int k = to - from;
        DenseMatrix x = this.zeros(this.m, k);
        for (int j = 0; j < k; ++j) {
            for (int i = 0; i < this.m; ++i) {
                x.set(i, j, this.get(i, from + j));
            }
        }
        return x;
    }

    /**
     * Submatrix extraction. Not supported by this base class: it always throws.
     * NOTE(review): presumably overridden by subclasses that support it — confirm.
     *
     * @throws UnsupportedOperationException always, in this class.
     */
    public DenseMatrix submatrix(int i, int j, int k, int l) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns a new vector of length {@code n} whose entry {@code j} is the sum
     * of column {@code j}. Each column is accumulated top to bottom.
     */
    public Vector colSums() {
        Vector totals = this.vector(this.n);
        for (int row = 0; row < this.m; row++) {
            for (int col = 0; col < this.n; col++) {
                totals.add(col, this.get(row, col));
            }
        }
        return totals;
    }

    /**
     * Returns a new vector of length {@code m} whose entry {@code i} is the sum
     * of row {@code i}. Each row is accumulated left to right.
     */
    public Vector rowSums() {
        Vector totals = this.vector(this.m);
        for (int row = 0; row < this.m; row++) {
            for (int col = 0; col < this.n; col++) {
                totals.add(row, this.get(row, col));
            }
        }
        return totals;
    }

    /**
     * Returns the per-column means: {@code colSums()} divided by the row count.
     */
    public Vector colMeans() {
        Vector means = this.colSums();
        for (int col = 0; col < this.n; col++) {
            means.div(col, this.m);
        }
        return means;
    }

    /**
     * Returns the per-row means: {@code rowSums()} divided by the column count.
     */
    public Vector rowMeans() {
        Vector means = this.rowSums();
        for (int row = 0; row < this.m; row++) {
            means.div(row, this.n);
        }
        return means;
    }

    /**
     * Returns the per-column population standard deviation (divisor {@code m}).
     * The variance is computed as {@code E[x^2] - E[x]^2} and clamped at zero to
     * absorb negative round-off.
     */
    public Vector colSds() {
        Vector result = this.vector(this.n);
        for (int col = 0; col < this.n; col++) {
            double total = 0.0;
            double totalSq = 0.0;
            for (int row = 0; row < this.m; row++) {
                double v = this.get(row, col);
                total += v;
                totalSq += v * v;
            }
            double count = this.m;
            double mean = total / count;
            double variance = Math.max(totalSq / count - mean * mean, 0.0);
            result.set(col, Math.sqrt(variance));
        }
        return result;
    }

    /**
     * Returns a standardized copy: each column is centered by its mean and
     * divided by its population standard deviation.
     */
    public DenseMatrix standardize() {
        return this.standardize(this.colMeans(), this.colSds());
    }

    /**
     * Returns a copy with every column shifted by {@code center} and/or divided
     * by {@code scale}. This matrix is not modified.
     *
     * @param center per-column values to subtract, or null to skip centering.
     * @param scale per-column divisors, or null to skip scaling.
     * @return the standardized copy.
     * @throws IllegalArgumentException if both arguments are null.
     */
    public DenseMatrix standardize(Vector center, Vector scale) {
        if (center == null && scale == null) {
            throw new IllegalArgumentException("Both center and scale are null");
        }
        DenseMatrix out = this.copy();
        for (int col = 0; col < this.n; col++) {
            for (int row = 0; row < this.m; row++) {
                if (scale == null) {
                    out.sub(row, col, center.get(col));
                } else if (center == null) {
                    out.div(row, col, scale.get(col));
                } else {
                    out.set(row, col, (this.get(row, col) - center.get(col)) / scale.get(col));
                }
            }
        }
        return out;
    }

    /**
     * Overwrites this matrix with {@code alpha * A + beta * B}, element by element.
     *
     * @throws IllegalArgumentException if {@code A} or {@code B} has a different shape.
     */
    public void add(double alpha, DenseMatrix A, double beta, DenseMatrix B) {
        boolean aMismatch = this.nrow() != A.nrow() || this.ncol() != A.ncol();
        if (aMismatch) {
            throw new IllegalArgumentException(String.format("Adds matrix: %d x %d vs %d x %d", A.nrow(), A.ncol(), this.m, this.n));
        }
        boolean bMismatch = this.nrow() != B.nrow() || this.ncol() != B.ncol();
        if (bMismatch) {
            throw new IllegalArgumentException(String.format("Adds matrix: %d x %d vs %d x %d", B.nrow(), B.ncol(), this.m, this.n));
        }
        for (int col = 0; col < this.n; col++) {
            for (int row = 0; row < this.m; row++) {
                double value = alpha * A.get(row, col) + beta * B.get(row, col);
                this.set(row, col, value);
            }
        }
    }

    /**
     * Computes {@code this := this + alpha * x} in place.
     *
     * <p>If both matrices share scalar type, storage order and leading dimension,
     * the whole backing storage (padding included) goes through a single BLAS
     * AXPY call. Otherwise the update is done element by element through
     * {@code get}/{@code set}.
     *
     * @param alpha the scale applied to {@code x}.
     * @param x the matrix to add; must have the same shape as this matrix.
     * @return this matrix.
     * @throws IllegalArgumentException if the shapes differ.
     * @throws UnsupportedOperationException on the BLAS path, if the scalar type is neither Float64 nor Float32.
     */
    public DenseMatrix axpy(double alpha, DenseMatrix x) {
        if (this.nrow() != x.nrow() || this.ncol() != x.ncol()) {
            throw new IllegalArgumentException(String.format("Adds matrix: %d x %d vs %d x %d", this.m, this.n, x.nrow(), x.ncol()));
        }
        // The flat BLAS update is only valid if both buffers use the same
        // element layout: same type, same storage order, same stride.
        if (this.scalarType() == x.scalarType() && this.order() == x.order() && this.ld == x.ld) {
            int length = this.capacity();
            switch (this.scalarType()) {
                case Float64: {
                    cblas_h.cblas_daxpy(length, alpha, x.memory, 1, this.memory, 1);
                    break;
                }
                case Float32: {
                    cblas_h.cblas_saxpy(length, (float)alpha, x.memory, 1, this.memory, 1);
                    break;
                }
                default: {
                    throw new UnsupportedOperationException("Unsupported scalar type: " + String.valueOf((Object)this.scalarType()));
                }
            }
        } else {
            for (int j = 0; j < this.n; ++j) {
                for (int i = 0; i < this.m; ++i) {
                    // y + alpha * x: this matrix's element is not scaled.
                    this.set(i, j, this.get(i, j) + alpha * x.get(i, j));
                }
            }
        }
        return this;
    }

    /** Adds {@code B} to this matrix in place ({@code axpy(1.0, B)}); returns this. */
    public DenseMatrix add(DenseMatrix B) {
        final double plusOne = 1.0;
        return axpy(plusOne, B);
    }

    /** Subtracts {@code B} from this matrix in place ({@code axpy(-1.0, B)}); returns this. */
    public DenseMatrix sub(DenseMatrix B) {
        final double minusOne = -1.0;
        return axpy(minusOne, B);
    }

    /**
     * Matrix-vector product {@code y := alpha * op(A) * x + beta * y}, where
     * {@code A} is this matrix.
     *
     * <p>Routine selection:
     * <ul>
     *   <li>no triangle flag: GEMV;</li>
     *   <li>symmetric (uplo set, diag null): SYMV, with {@code trans} ignored;</li>
     *   <li>triangular (uplo and diag set): TRMV, but only for the in-place case
     *       {@code alpha == 1}, {@code beta == 0}, {@code x == y}; otherwise GEMV.</li>
     * </ul>
     *
     * @throws IllegalArgumentException on a scalar-type mismatch, or if a vector is too short.
     * @throws UnsupportedOperationException if the scalar type is neither Float64 nor Float32.
     */
    @Override
    public void mv(Transpose trans, double alpha, Vector x, double beta, Vector y) {
        var type = this.scalarType();
        if (type != x.scalarType()) {
            throw new IllegalArgumentException("Incompatible ScalarType: " + type + " != " + x.scalarType());
        }
        if (type != y.scalarType()) {
            throw new IllegalArgumentException("Incompatible ScalarType: " + type + " != " + y.scalarType());
        }
        switch (trans) {
            case NO_TRANSPOSE -> {
                if (this.ncol() > x.size()) {
                    throw new IllegalArgumentException("Incompatible x vector size: " + this.ncol() + " != " + x.size());
                }
                if (this.nrow() > y.size()) {
                    throw new IllegalArgumentException("Incompatible y vector size: " + this.nrow() + " != " + y.size());
                }
            }
            case TRANSPOSE -> {
                if (this.nrow() > x.size()) {
                    throw new IllegalArgumentException("Incompatible x vector size: " + this.nrow() + " != " + x.size());
                }
                if (this.ncol() > y.size()) {
                    throw new IllegalArgumentException("Incompatible y vector size: " + this.ncol() + " != " + y.size());
                }
            }
        }

        boolean triangular = this.uplo != null && this.diag != null;
        boolean inPlace = alpha == 1.0 && beta == 0.0 && x == y;
        boolean useGemv = this.uplo == null || (triangular && !inPlace);

        switch (type) {
            case Float64 -> {
                if (useGemv) {
                    cblas_h.cblas_dgemv(this.order().blas(), trans.blas(), this.m, this.n, alpha, this.memory, this.ld, x.memory(), 1, beta, y.memory(), 1);
                } else if (triangular) {
                    cblas_h.cblas_dtrmv(this.order().blas(), this.uplo.blas(), trans.blas(), this.diag.blas(), this.m, this.memory, this.ld, y.memory(), 1);
                } else {
                    cblas_h.cblas_dsymv(this.order().blas(), this.uplo.blas(), this.m, alpha, this.memory, this.ld, x.memory(), 1, beta, y.memory(), 1);
                }
            }
            case Float32 -> {
                if (useGemv) {
                    cblas_h.cblas_sgemv(this.order().blas(), trans.blas(), this.m, this.n, (float) alpha, this.memory, this.ld, x.memory(), 1, (float) beta, y.memory(), 1);
                } else if (triangular) {
                    cblas_h.cblas_strmv(this.order().blas(), this.uplo.blas(), trans.blas(), this.diag.blas(), this.m, this.memory, this.ld, y.memory(), 1);
                } else {
                    cblas_h.cblas_ssymv(this.order().blas(), this.uplo.blas(), this.m, (float) alpha, this.memory, this.ld, x.memory(), 1, (float) beta, y.memory(), 1);
                }
            }
            default -> throw new UnsupportedOperationException("Unsupported ScalarType: " + type);
        }
    }

    /**
     * General matrix-matrix product {@code C := alpha * op(A) * op(B) + beta * C},
     * where {@code op(X)} is {@code X} or its transpose, per {@code transA}/{@code transB}.
     *
     * <p>Routine selection:
     * <ul>
     *   <li>{@code A} symmetric, {@code B} not transposed and in {@code C}'s order:
     *       SYMM with {@code A} on the left;</li>
     *   <li>otherwise, {@code B} symmetric, {@code A} not transposed and in
     *       {@code C}'s order: SYMM with {@code B} on the right;</li>
     *   <li>otherwise GEMM. An operand whose storage order differs from
     *       {@code C}'s is first copied with {@code transpose()} and its transpose
     *       flag flipped. This assumes {@code transpose()} returns a matrix in
     *       {@code C}'s order.</li>
     * </ul>
     * Dimensions are not validated here; callers are expected to check them.
     *
     * @param alpha scale of the product.
     * @param transA whether to transpose {@code A}.
     * @param A left operand.
     * @param transB whether to transpose {@code B}.
     * @param B right operand.
     * @param beta scale of the existing contents of {@code C}.
     * @param C output matrix, updated in place.
     * @throws IllegalArgumentException if {@code A} or {@code B} has a different scalar type than {@code C}.
     * @throws UnsupportedOperationException if the scalar type is neither Float64 nor Float32.
     */
    public static void mm(double alpha, Transpose transA, DenseMatrix A, Transpose transB, DenseMatrix B, double beta, DenseMatrix C) {
        if (C.scalarType() != A.scalarType()) {
            throw new IllegalArgumentException("Incompatible ScalarType: " + C.scalarType() + " != " + A.scalarType());
        }
        if (C.scalarType() != B.scalarType()) {
            throw new IllegalArgumentException("Incompatible ScalarType: " + C.scalarType() + " != " + B.scalarType());
        }
        int m = C.nrow();
        int n = C.ncol();

        // Symmetric A: C := alpha * A * B + beta * C (transA is irrelevant when A == A').
        if (A.isSymmetric() && transB == Transpose.NO_TRANSPOSE && B.order() == C.order()) {
            switch (C.scalarType()) {
                case Float64 -> cblas_h.cblas_dsymm(C.order().blas(), Side.LEFT.blas(), A.uplo().blas(), m, n, alpha, A.memory(), A.ld(), B.memory(), B.ld(), beta, C.memory(), C.ld());
                case Float32 -> cblas_h.cblas_ssymm(C.order().blas(), Side.LEFT.blas(), A.uplo().blas(), m, n, (float) alpha, A.memory(), A.ld(), B.memory(), B.ld(), (float) beta, C.memory(), C.ld());
                default -> throw new UnsupportedOperationException("Unsupported ScalarType: " + A.scalarType());
            }
            return;
        }

        // Symmetric B: C := alpha * A * B + beta * C with B applied from the right.
        if (B.isSymmetric() && transA == Transpose.NO_TRANSPOSE && A.order() == C.order()) {
            switch (C.scalarType()) {
                case Float64 -> cblas_h.cblas_dsymm(C.order().blas(), Side.RIGHT.blas(), B.uplo().blas(), m, n, alpha, B.memory(), B.ld(), A.memory(), A.ld(), beta, C.memory(), C.ld());
                case Float32 -> cblas_h.cblas_ssymm(C.order().blas(), Side.RIGHT.blas(), B.uplo().blas(), m, n, (float) alpha, B.memory(), B.ld(), A.memory(), A.ld(), (float) beta, C.memory(), C.ld());
                default -> throw new UnsupportedOperationException("Unsupported ScalarType: " + A.scalarType());
            }
            return;
        }

        // General case. GEMM takes a single layout, so bring each operand into
        // C's order; transposing the data and flipping the flag keeps op(X) unchanged.
        Transpose opA = transA;
        DenseMatrix a = A;
        if (C.order() != a.order()) {
            opA = Transpose.flip(opA);
            a = a.transpose();
        }
        Transpose opB = transB;
        DenseMatrix b = B;
        if (C.order() != b.order()) {
            opB = Transpose.flip(opB);
            b = b.transpose();
        }
        int k = opA == Transpose.NO_TRANSPOSE ? a.ncol() : a.nrow();
        switch (C.scalarType()) {
            case Float64 -> cblas_h.cblas_dgemm(C.order().blas(), opA.blas(), opB.blas(), m, n, k, alpha, a.memory(), a.ld(), b.memory(), b.ld(), beta, C.memory(), C.ld());
            case Float32 -> cblas_h.cblas_sgemm(C.order().blas(), opA.blas(), opB.blas(), m, n, k, (float) alpha, a.memory(), a.ld(), b.memory(), b.ld(), (float) beta, C.memory(), C.ld());
            default -> throw new UnsupportedOperationException("Unsupported ScalarType: " + a.scalarType());
        }
    }

    /**
     * Returns the new matrix {@code A * B}, where {@code A} is this matrix.
     *
     * @throws IllegalArgumentException if {@code ncol() != B.nrow()}.
     */
    public DenseMatrix mm(DenseMatrix B) {
        if (ncol() != B.nrow()) {
            throw new IllegalArgumentException(String.format("Matrix multiplication A * B: %d x %d vs %d x %d", m, n, B.nrow(), B.ncol()));
        }
        DenseMatrix product = zeros(nrow(), B.ncol());
        DenseMatrix.mm(1.0, Transpose.NO_TRANSPOSE, this, Transpose.NO_TRANSPOSE, B, 0.0, product);
        return product;
    }

    /**
     * Returns the new matrix {@code A' * B}, where {@code A} is this matrix.
     *
     * @throws IllegalArgumentException if {@code nrow() != B.nrow()}.
     */
    public DenseMatrix tm(DenseMatrix B) {
        if (nrow() != B.nrow()) {
            throw new IllegalArgumentException(String.format("Matrix multiplication A' * B: %d x %d vs %d x %d", m, n, B.nrow(), B.ncol()));
        }
        DenseMatrix product = zeros(ncol(), B.ncol());
        DenseMatrix.mm(1.0, Transpose.TRANSPOSE, this, Transpose.NO_TRANSPOSE, B, 0.0, product);
        return product;
    }

    /**
     * Returns the new matrix {@code A * B'}, where {@code A} is this matrix.
     *
     * @throws IllegalArgumentException if {@code ncol() != B.ncol()}.
     */
    public DenseMatrix mt(DenseMatrix B) {
        if (ncol() != B.ncol()) {
            throw new IllegalArgumentException(String.format("Matrix multiplication A * B': %d x %d vs %d x %d", m, n, B.nrow(), B.ncol()));
        }
        DenseMatrix product = zeros(nrow(), B.nrow());
        DenseMatrix.mm(1.0, Transpose.NO_TRANSPOSE, this, Transpose.TRANSPOSE, B, 0.0, product);
        return product;
    }

    /**
     * Returns {@code A' * A} as a new {@code ncol x ncol} matrix flagged with
     * {@code UPLO.LOWER}.
     */
    public DenseMatrix ata() {
        int size = ncol();
        DenseMatrix gram = zeros(size, size).withUplo(UPLO.LOWER);
        DenseMatrix.mm(1.0, Transpose.TRANSPOSE, this, Transpose.NO_TRANSPOSE, this, 0.0, gram);
        return gram;
    }

    /**
     * Returns {@code A * A'} as a new {@code nrow x nrow} matrix flagged with
     * {@code UPLO.LOWER}.
     */
    public DenseMatrix aat() {
        int size = nrow();
        DenseMatrix gram = zeros(size, size).withUplo(UPLO.LOWER);
        DenseMatrix.mm(1.0, Transpose.NO_TRANSPOSE, this, Transpose.TRANSPOSE, this, 0.0, gram);
        return gram;
    }

    /**
     * Rank-1 update in place: {@code A := A + alpha * x * y'} (BLAS GER).
     *
     * @param alpha scale of the outer product.
     * @param x vector of length {@code m}.
     * @param y vector of length {@code n}.
     * @throws IllegalArgumentException if a vector length does not match.
     * @throws UnsupportedOperationException if the scalar type is neither Float64 nor Float32.
     */
    public void ger(double alpha, Vector x, Vector y) {
        if (this.m != x.size()) {
            throw new IllegalArgumentException(String.format("Dimensions do not match for rank-1 update: %d x %d vs %d x 1", this.nrow(), this.ncol(), x.size()));
        }
        if (this.n != y.size()) {
            throw new IllegalArgumentException(String.format("Dimensions do not match for rank-1 update: %d x %d vs 1 x %d", this.nrow(), this.ncol(), y.size()));
        }
        var type = this.scalarType();
        switch (type) {
            case Float64 -> cblas_h.cblas_dger(this.order().blas(), this.m, this.n, alpha, x.memory, 1, y.memory, 1, this.memory, this.ld);
            case Float32 -> cblas_h.cblas_sger(this.order().blas(), this.m, this.n, (float) alpha, x.memory, 1, y.memory, 1, this.memory, this.ld);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + type);
        }
    }

    /**
     * Returns the inverse of this square matrix, computed by solving
     * {@code A * X = eye(n)} with LAPACK. SYSV is used when the matrix is
     * flagged symmetric, GESV otherwise. This matrix is not modified: the
     * factorization works on {@code copy()}.
     *
     * @return a new matrix holding the inverse.
     * @throws IllegalArgumentException if the matrix is not square.
     * @throws ArithmeticException if LAPACK returns a nonzero info code
     *         (for example, a singular matrix).
     * @throws UnsupportedOperationException if the scalar type is neither Float64 nor Float32.
     */
    public DenseMatrix inverse() {
        if (this.m != this.n) {
            throw new IllegalArgumentException(String.format("The matrix is not square: %d x %d", this.m, this.n));
        }
        // Work on a copy: the LAPACK solvers overwrite the coefficient matrix with its factors.
        DenseMatrix lu = this.copy();
        // Right-hand side; overwritten with the solution X, i.e. the inverse.
        // NOTE(review): eye(n) is presumably the n x n identity — confirm.
        DenseMatrix inv = this.eye(lu.n);
        // LAPACK's Fortran interface takes every scalar by reference, so each one
        // is wrapped in a one-element array exposed as a MemorySegment.
        int[] n = new int[]{lu.n};
        int[] lda = new int[]{lu.ld};
        int[] ldb = new int[]{inv.ld};
        int[] ipiv = new int[lu.n];
        int[] info = new int[]{0};
        MemorySegment n_ = MemorySegment.ofArray(n);
        MemorySegment lda_ = MemorySegment.ofArray(lda);
        MemorySegment ldb_ = MemorySegment.ofArray(ldb);
        MemorySegment ipiv_ = MemorySegment.ofArray(ipiv);
        MemorySegment info_ = MemorySegment.ofArray(info);
        if (this.isSymmetric()) {
            // Phase 1: workspace query (lwork = -1). LAPACK writes the optimal size into work[0].
            Vector work = lu.vector(1);
            int[] lwork = new int[]{-1};
            // NOTE(review): the symmetry check is on `this` but uplo is read from
            // the copy `lu`; this assumes copy() preserves the uplo flag.
            byte[] uplo = new byte[]{lu.uplo.lapack()};
            MemorySegment lwork_ = MemorySegment.ofArray(lwork);
            MemorySegment uplo_ = MemorySegment.ofArray(uplo);
            switch (this.scalarType()) {
                case Float64: {
                    clapack_h.dsysv_(uplo_, n_, n_, lu.memory, lda_, ipiv_, inv.memory, ldb_, work.memory, lwork_, info_);
                    break;
                }
                case Float32: {
                    clapack_h.ssysv_(uplo_, n_, n_, lu.memory, lda_, ipiv_, inv.memory, ldb_, work.memory, lwork_, info_);
                    break;
                }
                default: {
                    throw new UnsupportedOperationException("Unsupported scalar type: " + String.valueOf((Object)this.scalarType()));
                }
            }
            if (info[0] != 0) {
                throw new ArithmeticException("SYSV fails: " + info[0]);
            }
            // Phase 2: allocate the reported workspace and run the actual solve.
            work = lu.vector((int)work.get(0));
            lwork[0] = work.size();
            switch (this.scalarType()) {
                case Float64: {
                    clapack_h.dsysv_(uplo_, n_, n_, lu.memory, lda_, ipiv_, inv.memory, ldb_, work.memory, lwork_, info_);
                    break;
                }
                case Float32: {
                    clapack_h.ssysv_(uplo_, n_, n_, lu.memory, lda_, ipiv_, inv.memory, ldb_, work.memory, lwork_, info_);
                    break;
                }
                default: {
                    throw new UnsupportedOperationException("Unsupported scalar type: " + String.valueOf((Object)this.scalarType()));
                }
            }
            if (info[0] != 0) {
                throw new ArithmeticException("SYSV fails: " + info[0]);
            }
        } else {
            // General matrix: LU-based solve with n right-hand sides (the columns of inv).
            switch (this.scalarType()) {
                case Float64: {
                    clapack_h.dgesv_(n_, n_, lu.memory, lda_, ipiv_, inv.memory, ldb_, info_);
                    break;
                }
                case Float32: {
                    clapack_h.sgesv_(n_, n_, lu.memory, lda_, ipiv_, inv.memory, ldb_, info_);
                    break;
                }
                default: {
                    throw new UnsupportedOperationException("Unsupported scalar type: " + String.valueOf((Object)this.scalarType()));
                }
            }
            if (info[0] != 0) {
                throw new ArithmeticException("GESV fails: " + info[0]);
            }
        }
        return inv;
    }

    /**
     * Computes the LU decomposition with partial pivoting (LAPACK GETRF).
     *
     * <p>The factorization is done in place: this matrix's memory is overwritten
     * by the factors and its {@code uplo} flag is cleared. A positive LAPACK info
     * code (an exactly zero pivot) is not treated as an error here; it is passed
     * on to the {@link LU} result.
     *
     * @return the LU decomposition, backed by this matrix.
     * @throws ArithmeticException if GETRF reports an illegal argument (negative info).
     * @throws UnsupportedOperationException if the scalar type is neither Float64 nor Float32.
     */
    public LU lu() {
        DenseMatrix lu = this;
        // LAPACK's Fortran interface takes scalars by reference: wrap each in a one-element array.
        int[] m = new int[]{lu.m};
        int[] n = new int[]{lu.n};
        int[] lda = new int[]{lu.ld};
        int[] ipiv = new int[Math.min(lu.m, lu.n)];
        int[] info = new int[]{0};
        MemorySegment m_ = MemorySegment.ofArray(m);
        MemorySegment n_ = MemorySegment.ofArray(n);
        MemorySegment lda_ = MemorySegment.ofArray(lda);
        MemorySegment ipiv_ = MemorySegment.ofArray(ipiv);
        MemorySegment info_ = MemorySegment.ofArray(info);
        switch (this.scalarType()) {
            case Float64: {
                clapack_h.dgetrf_(m_, n_, lu.memory, lda_, ipiv_, info_);
                break;
            }
            case Float32: {
                clapack_h.sgetrf_(m_, n_, lu.memory, lda_, ipiv_, info_);
                break;
            }
            default: {
                throw new UnsupportedOperationException("Unsupported scalar type: " + String.valueOf((Object)this.scalarType()));
            }
        }
        if (info[0] < 0) {
            // Log the code itself, not the int[] holder, which would print as an array identity.
            logger.error("LAPACK GETRF error code: {}", (Object)info[0]);
            throw new ArithmeticException("LAPACK GETRF error code: " + info[0]);
        }
        // The storage now holds L and U factors, so the symmetric/triangular flag no longer applies.
        lu.uplo = null;
        return new LU(lu, ipiv, info[0]);
    }

    /**
     * Computes the Cholesky decomposition in place (LAPACK POTRF), using the
     * triangle named by this matrix's {@code uplo} flag.
     *
     * @return the Cholesky decomposition, backed by this matrix.
     * @throws IllegalArgumentException if no {@code uplo} flag is set.
     * @throws ArithmeticException if POTRF returns a nonzero info code (for
     *         example, the matrix is not positive definite).
     * @throws UnsupportedOperationException if the scalar type is neither Float64 nor Float32.
     */
    public Cholesky cholesky() {
        if (this.uplo == null) {
            throw new IllegalArgumentException("The matrix is not symmetric");
        }
        DenseMatrix a = this;
        int[] status = {0};
        // Fortran-style by-reference arguments.
        MemorySegment uploArg = MemorySegment.ofArray(new byte[]{a.uplo.lapack()});
        MemorySegment orderArg = MemorySegment.ofArray(new int[]{a.n});
        MemorySegment ldaArg = MemorySegment.ofArray(new int[]{a.ld});
        MemorySegment statusArg = MemorySegment.ofArray(status);
        var type = this.scalarType();
        switch (type) {
            case Float64 -> clapack_h.dpotrf_(uploArg, orderArg, a.memory, ldaArg, statusArg);
            case Float32 -> clapack_h.spotrf_(uploArg, orderArg, a.memory, ldaArg, statusArg);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + type);
        }
        int code = status[0];
        if (code != 0) {
            logger.error("LAPACK POTRF error code: {}", (Object) code);
            throw new ArithmeticException("LAPACK POTRF error code: " + code);
        }
        return new Cholesky(a);
    }

    /**
     * Computes the QR decomposition in place (LAPACK GEQRF). This matrix's
     * memory is overwritten and its {@code uplo} flag is cleared.
     *
     * <p>NOTE(review): a failure in the workspace query throws
     * IllegalArgumentException, while a failure in the factorization throws
     * ArithmeticException. Kept as-is because callers may depend on the types.
     *
     * @return the QR decomposition, backed by this matrix and {@code tau}.
     * @throws IllegalArgumentException if the workspace query reports a nonzero info code.
     * @throws ArithmeticException if the factorization reports a nonzero info code.
     * @throws UnsupportedOperationException if the scalar type is neither Float64 nor Float32.
     */
    public QR qr() {
        DenseMatrix qr = this;
        // Scalar factors of the elementary reflectors, min(m, n) of them.
        Vector tau = qr.vector(Math.min(this.m, this.n));
        Vector work = this.vector(1);
        int[] m = new int[]{qr.m};
        int[] n = new int[]{qr.n};
        int[] lda = new int[]{qr.ld};
        // lwork = -1 requests a workspace-size query: the optimum is written to work[0].
        int[] lwork = new int[]{-1};
        int[] info = new int[]{0};
        MemorySegment m_ = MemorySegment.ofArray(m);
        MemorySegment n_ = MemorySegment.ofArray(n);
        MemorySegment lda_ = MemorySegment.ofArray(lda);
        MemorySegment lwork_ = MemorySegment.ofArray(lwork);
        MemorySegment info_ = MemorySegment.ofArray(info);
        switch (this.scalarType()) {
            case Float64: {
                clapack_h.dgeqrf_(m_, n_, qr.memory, lda_, tau.memory, work.memory, lwork_, info_);
                break;
            }
            case Float32: {
                clapack_h.sgeqrf_(m_, n_, qr.memory, lda_, tau.memory, work.memory, lwork_, info_);
                break;
            }
            default: {
                throw new UnsupportedOperationException("Unsupported scalar type: " + String.valueOf((Object)this.scalarType()));
            }
        }
        if (info[0] != 0) {
            logger.error("LAPACK GEQRF error code: {}", (Object)info[0]);
            throw new IllegalArgumentException("LAPACK GEQRF error code: " + info[0]);
        }
        // Phase 2: allocate the reported workspace and factorize.
        work = qr.vector((int)work.get(0));
        lwork[0] = work.size();
        switch (this.scalarType()) {
            case Float64: {
                clapack_h.dgeqrf_(m_, n_, qr.memory, lda_, tau.memory, work.memory, lwork_, info_);
                break;
            }
            case Float32: {
                clapack_h.sgeqrf_(m_, n_, qr.memory, lda_, tau.memory, work.memory, lwork_, info_);
                break;
            }
            default: {
                throw new UnsupportedOperationException("Unsupported scalar type: " + String.valueOf((Object)this.scalarType()));
            }
        }
        if (info[0] != 0) {
            logger.error("LAPACK GEQRF error code: {}", (Object)info[0]);
            throw new ArithmeticException("LAPACK GEQRF error code: " + info[0]);
        }
        // The storage now holds R and the reflectors, so the triangle flag no longer applies.
        qr.uplo = null;
        return new QR(qr, tau);
    }

    /** Computes the SVD with singular vectors; equivalent to {@code svd(true)}. */
    public SVD svd() {
        final boolean withVectors = true;
        return svd(withVectors);
    }

    /**
     * Computes the singular value decomposition of this matrix with the LAPACK
     * divide-and-conquer driver {@code ?gesdd}.
     *
     * <p>This matrix's own memory is passed as LAPACK's input/output argument
     * {@code A}. Per the LAPACK documentation, its contents are destroyed on
     * exit, so decompose a copy if the original values are still needed.
     *
     * @param vectors if true, also compute the compact singular vectors:
     *                {@code U} is {@code m x k} and {@code V'} is {@code k x n},
     *                where {@code k = min(m, n)}. Otherwise only the singular
     *                values are computed.
     * @return the singular value decomposition.
     * @throws UnsupportedOperationException if the scalar type is neither
     *         {@code Float64} nor {@code Float32}.
     * @throws ArithmeticException if LAPACK reports a nonzero info code.
     */
    public SVD svd(boolean vectors) {
        int k = Math.min(this.m, this.n);
        Vector s = this.vector(k);
        DenseMatrix A = this;
        // When no vectors are requested, U/VT are not referenced. LAPACK still
        // requires LDU >= 1 and LDVT >= 1, hence the 1 x 1 placeholders.
        DenseMatrix U = vectors ? this.zeros(A.m, k) : this.zeros(1, 1);
        DenseMatrix Vt = vectors ? this.zeros(k, A.n) : this.zeros(1, 1);
        SVDJob job = vectors ? SVDJob.COMPACT : SVDJob.NO_VECTORS;
        Vector work = this.vector(1);
        byte[] jobz = new byte[]{job.lapack()};
        int[] m = new int[]{A.m};
        int[] n = new int[]{A.n};
        int[] lda = new int[]{A.ld};
        int[] lwork = new int[]{-1};
        // ?gesdd needs an integer workspace of 8 * min(m, n).
        int[] iwork = new int[8 * k];
        int[] ldu = new int[]{U.ld};
        int[] ldvt = new int[]{Vt.ld};
        int[] info = new int[]{0};
        MemorySegment m_ = MemorySegment.ofArray(m);
        MemorySegment n_ = MemorySegment.ofArray(n);
        MemorySegment lda_ = MemorySegment.ofArray(lda);
        MemorySegment jobz_ = MemorySegment.ofArray(jobz);
        MemorySegment info_ = MemorySegment.ofArray(info);
        MemorySegment ldu_ = MemorySegment.ofArray(ldu);
        MemorySegment ldvt_ = MemorySegment.ofArray(ldvt);
        MemorySegment lwork_ = MemorySegment.ofArray(lwork);
        MemorySegment iwork_ = MemorySegment.ofArray(iwork);

        // First call: lwork = -1 is a workspace query. LAPACK stores the
        // optimal workspace size in work[0].
        this.lapackGesdd(job, jobz_, m_, n_, A.memory, lda_, s.memory, U.memory, ldu_,
                Vt.memory, ldvt_, work.memory, lwork_, iwork_, info_, info);

        // Second call: the actual decomposition with the optimal workspace.
        work = this.vector((int) work.get(0));
        lwork[0] = work.size();
        this.lapackGesdd(job, jobz_, m_, n_, A.memory, lda_, s.memory, U.memory, ldu_,
                Vt.memory, ldvt_, work.memory, lwork_, iwork_, info_, info);

        return vectors ? new SVD(s, U, Vt) : new SVD(A.m, A.n, s);
    }

    /**
     * Calls {@code dgesdd_} or {@code sgesdd_} according to this matrix's
     * scalar type. A nonzero LAPACK info code is converted into an exception.
     *
     * @param job   the job encoded in {@code jobz}; used only in error messages.
     * @param info_ the native view of {@code info}.
     * @param info  the Java array backing {@code info_}, read after the call.
     * @throws UnsupportedOperationException if the scalar type is not supported.
     * @throws ArithmeticException if LAPACK reports a nonzero info code.
     */
    private void lapackGesdd(SVDJob job, MemorySegment jobz, MemorySegment m, MemorySegment n,
                             MemorySegment a, MemorySegment lda, MemorySegment s,
                             MemorySegment u, MemorySegment ldu, MemorySegment vt, MemorySegment ldvt,
                             MemorySegment work, MemorySegment lwork, MemorySegment iwork,
                             MemorySegment info_, int[] info) {
        switch (this.scalarType()) {
            case Float64 -> clapack_h.dgesdd_(jobz, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, iwork, info_);
            case Float32 -> clapack_h.sgesdd_(jobz, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, iwork, info_);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + String.valueOf((Object) this.scalarType()));
        }
        if (info[0] != 0) {
            // Report the job that was actually requested (COMPACT or NO_VECTORS).
            logger.error("LAPACK GESDD ({}) error code: {}", job, info[0]);
            throw new ArithmeticException("LAPACK GESDD (" + job + ") error code: " + info[0]);
        }
    }

    /**
     * Computes the eigenvalue decomposition with right eigenvectors only.
     * Shorthand for {@code eigen(false, true)}.
     *
     * @return the eigenvalue decomposition.
     */
    public EVD eigen() {
        final boolean leftVectors = false;
        final boolean rightVectors = true;
        return this.eigen(leftVectors, rightVectors);
    }

    /**
     * Computes the eigenvalue decomposition of this square matrix.
     *
     * <p>This matrix's memory is passed directly to LAPACK as the input/output
     * matrix {@code A}. If {@link #isSymmetric()} is false, the general driver
     * {@code ?geev} is used. Real and imaginary parts of the eigenvalues are
     * returned separately, and left/right eigenvector matrices are allocated
     * only when requested. Otherwise the symmetric driver {@code ?syevd} is
     * used. In that case {@code vl} is not consulted, this matrix's
     * {@code uplo} selects the referenced triangle, and when {@code vr} is
     * true this matrix itself is returned as the eigenvector matrix. After a
     * successful symmetric decomposition {@code uplo} is cleared.
     *
     * @param vl whether to compute left eigenvectors (general matrices only).
     * @param vr whether to compute right eigenvectors.
     * @return the eigenvalue decomposition.
     * @throws IllegalArgumentException if the matrix is not square.
     * @throws UnsupportedOperationException if the scalar type is neither
     *         {@code Float64} nor {@code Float32}.
     * @throws ArithmeticException if LAPACK reports a nonzero info code.
     */
    public EVD eigen(boolean vl, boolean vr) {
        if (this.m != this.n) {
            throw new IllegalArgumentException(String.format("The matrix is not square: %d x %d", this.m, this.n));
        }
        DenseMatrix eig = this;
        int order = eig.n;
        int[] info = {0};
        MemorySegment computeVectors_ = MemorySegment.ofArray(new byte[]{EVDJob.VECTORS.lapack()});
        MemorySegment skipVectors_ = MemorySegment.ofArray(new byte[]{EVDJob.NO_VECTORS.lapack()});
        MemorySegment order_ = MemorySegment.ofArray(new int[]{order});
        MemorySegment lda_ = MemorySegment.ofArray(new int[]{eig.ld});
        MemorySegment info_ = MemorySegment.ofArray(info);
        MemorySegment jobvl_ = vl ? computeVectors_ : skipVectors_;
        MemorySegment jobvr_ = vr ? computeVectors_ : skipVectors_;

        if (!this.isSymmetric()) {
            // General (nonsymmetric) matrix: ?geev.
            Vector realParts = this.vector(order);
            Vector imagParts = this.vector(order);
            // LAPACK requires LDVL/LDVR >= 1 even when vectors are not
            // computed, hence the 1 x 1 placeholders.
            DenseMatrix left = vl ? this.zeros(order, order) : this.zeros(1, 1);
            DenseMatrix right = vr ? this.zeros(order, order) : this.zeros(1, 1);
            Vector work = this.vector(1);
            int[] lwork = {-1};
            MemorySegment ldvl_ = MemorySegment.ofArray(new int[]{left.ld});
            MemorySegment ldvr_ = MemorySegment.ofArray(new int[]{right.ld});
            MemorySegment lwork_ = MemorySegment.ofArray(lwork);

            // Workspace query (lwork = -1): optimal size comes back in work[0].
            this.lapackGeev(jobvl_, jobvr_, order_, eig.memory, lda_, realParts.memory, imagParts.memory,
                    left.memory, ldvl_, right.memory, ldvr_, work.memory, lwork_, info_, info);
            work = this.vector((int) work.get(0));
            lwork[0] = work.size();
            this.lapackGeev(jobvl_, jobvr_, order_, eig.memory, lda_, realParts.memory, imagParts.memory,
                    left.memory, ldvl_, right.memory, ldvr_, work.memory, lwork_, info_, info);
            return new EVD(realParts, imagParts, vl ? left : null, vr ? right : null);
        }

        // Symmetric matrix: ?syevd, with vr acting as JOBZ.
        Vector values = this.vector(order);
        Vector work = this.vector(1);
        int[] lwork = {-1};
        int[] iwork = new int[1];
        int[] liwork = {-1};
        MemorySegment uplo_ = MemorySegment.ofArray(new byte[]{eig.uplo.lapack()});
        MemorySegment lwork_ = MemorySegment.ofArray(lwork);
        MemorySegment liwork_ = MemorySegment.ofArray(liwork);

        // Workspace query: optimal lwork lands in work[0], optimal liwork in iwork[0].
        this.lapackSyevd(jobvr_, uplo_, order_, eig.memory, lda_, values.memory, work.memory,
                lwork_, MemorySegment.ofArray(iwork), liwork_, info_, info);
        work = this.vector((int) work.get(0));
        iwork = new int[iwork[0]];
        lwork[0] = work.size();
        liwork[0] = iwork.length;
        this.lapackSyevd(jobvr_, uplo_, order_, eig.memory, lda_, values.memory, work.memory,
                lwork_, MemorySegment.ofArray(iwork), liwork_, info_, info);
        eig.uplo = null;
        return new EVD(values, vr ? eig : null);
    }

    /**
     * Calls {@code dsyevd_} or {@code ssyevd_} according to this matrix's
     * scalar type. A nonzero LAPACK info code becomes an exception.
     *
     * @param info_ the native view of {@code info}.
     * @param info  the Java array backing {@code info_}, read after the call.
     */
    private void lapackSyevd(MemorySegment jobz, MemorySegment uplo, MemorySegment n, MemorySegment a,
                             MemorySegment lda, MemorySegment w, MemorySegment work, MemorySegment lwork,
                             MemorySegment iwork, MemorySegment liwork, MemorySegment info_, int[] info) {
        switch (this.scalarType()) {
            case Float64 -> clapack_h.dsyevd_(jobz, uplo, n, a, lda, w, work, lwork, iwork, liwork, info_);
            case Float32 -> clapack_h.ssyevd_(jobz, uplo, n, a, lda, w, work, lwork, iwork, liwork, info_);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + String.valueOf((Object) this.scalarType()));
        }
        if (info[0] != 0) {
            logger.error("LAPACK SYEV error code: {}", (Object) info[0]);
            throw new ArithmeticException("LAPACK SYEV error code: " + info[0]);
        }
    }

    /**
     * Calls {@code dgeev_} or {@code sgeev_} according to this matrix's
     * scalar type. A nonzero LAPACK info code becomes an exception.
     *
     * @param info_ the native view of {@code info}.
     * @param info  the Java array backing {@code info_}, read after the call.
     */
    private void lapackGeev(MemorySegment jobvl, MemorySegment jobvr, MemorySegment n, MemorySegment a,
                            MemorySegment lda, MemorySegment wr, MemorySegment wi,
                            MemorySegment vl, MemorySegment ldvl, MemorySegment vr, MemorySegment ldvr,
                            MemorySegment work, MemorySegment lwork, MemorySegment info_, int[] info) {
        switch (this.scalarType()) {
            case Float64 -> clapack_h.dgeev_(jobvl, jobvr, n, a, lda, wr, wi, vl, ldvl, vr, ldvr, work, lwork, info_);
            case Float32 -> clapack_h.sgeev_(jobvl, jobvr, n, a, lda, wr, wi, vl, ldvl, vr, ldvr, work, lwork, info_);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + String.valueOf((Object) this.scalarType()));
        }
        if (info[0] != 0) {
            logger.error("LAPACK GEEV error code: {}", (Object) info[0]);
            throw new ArithmeticException("LAPACK GEEV error code: " + info[0]);
        }
    }

    /**
     * Returns a new {@code Float64} matrix holding a copy of a two-dimensional array.
     *
     * @param A the matrix elements, indexed {@code A[row][column]}. All rows
     *          must have the same length.
     * @return a new {@code A.length x A[0].length} matrix.
     * @throws IllegalArgumentException if {@code A} has no rows or its rows
     *         differ in length.
     */
    public static DenseMatrix of(double[][] A) {
        if (A.length == 0) {
            throw new IllegalArgumentException("The array has no rows");
        }
        int m = A.length;
        int n = A[0].length;
        // Reject ragged input up front. Otherwise a longer row would be silently
        // truncated and a shorter one would fail midway through the copy.
        for (int i = 1; i < m; ++i) {
            if (A[i].length != n) {
                throw new IllegalArgumentException(String.format("Row %d has %d columns, but row 0 has %d", i, A[i].length, n));
            }
        }
        DenseMatrix matrix = DenseMatrix.zeros(ScalarType.Float64, m, n);
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                matrix.set(i, j, A[i][j]);
            }
        }
        return matrix;
    }

    /**
     * Returns a new {@code Float32} matrix holding a copy of a two-dimensional array.
     *
     * @param A the matrix elements, indexed {@code A[row][column]}. All rows
     *          must have the same length.
     * @return a new {@code A.length x A[0].length} matrix.
     * @throws IllegalArgumentException if {@code A} has no rows or its rows
     *         differ in length.
     */
    public static DenseMatrix of(float[][] A) {
        if (A.length == 0) {
            throw new IllegalArgumentException("The array has no rows");
        }
        int m = A.length;
        int n = A[0].length;
        // Reject ragged input up front. Otherwise a longer row would be silently
        // truncated and a shorter one would fail midway through the copy.
        for (int i = 1; i < m; ++i) {
            if (A[i].length != n) {
                throw new IllegalArgumentException(String.format("Row %d has %d columns, but row 0 has %d", i, A[i].length, n));
            }
        }
        DenseMatrix matrix = DenseMatrix.zeros(ScalarType.Float32, m, n);
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                matrix.set(i, j, A[i][j]);
            }
        }
        return matrix;
    }

    /**
     * Returns a new zero-filled {@code m x n} matrix of the given scalar type.
     * It is backed by a Java array of {@code ld(m) * n} elements.
     *
     * @param scalarType the element type: {@code Float64} or {@code Float32}.
     * @param m the number of rows.
     * @param n the number of columns.
     * @return the zero matrix.
     * @throws IllegalArgumentException if the backing array would need more
     *         than {@link Integer#MAX_VALUE} elements.
     * @throws UnsupportedOperationException if the scalar type is not supported.
     */
    public static DenseMatrix zeros(ScalarType scalarType, int m, int n) {
        int ld = DenseMatrix.ld(m);
        // Compute the element count in long arithmetic: int multiplication can
        // wrap around to a positive but far too small array length.
        long size = (long) ld * n;
        if (size > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(String.format("Matrix too large: %d x %d (leading dimension %d)", m, n, ld));
        }
        int length = (int) size;
        return switch (scalarType) {
            case ScalarType.Float64 -> {
                double[] array = new double[length];
                yield new DenseMatrix64(array, m, n, ld, null, null);
            }
            case ScalarType.Float32 -> {
                float[] array = new float[length];
                yield new DenseMatrix32(array, m, n, ld, null, null);
            }
            default -> throw new UnsupportedOperationException("Unsupported ScalarType: " + String.valueOf((Object)scalarType));
        };
    }

    /**
     * Returns a new zero-filled {@code m x n} matrix with the same scalar
     * type as this matrix.
     *
     * @param m the number of rows.
     * @param n the number of columns.
     * @return the zero matrix.
     */
    public DenseMatrix zeros(int m, int n) {
        ScalarType type = this.scalarType();
        return DenseMatrix.zeros(type, m, n);
    }

    /**
     * Returns an {@code n x n} identity matrix of the given scalar type.
     *
     * @param scalarType the element type.
     * @param n the order of the matrix.
     * @return the identity matrix.
     */
    public static DenseMatrix eye(ScalarType scalarType, int n) {
        final int rows = n;
        final int cols = n;
        return DenseMatrix.eye(scalarType, rows, cols);
    }

    /**
     * Returns an {@code m x n} matrix with ones on the main diagonal and
     * zeros elsewhere.
     *
     * @param scalarType the element type.
     * @param m the number of rows.
     * @param n the number of columns.
     * @return the (possibly rectangular) identity matrix.
     */
    public static DenseMatrix eye(ScalarType scalarType, int m, int n) {
        // The unused leading-dimension local is gone; zeros() computes it itself.
        DenseMatrix matrix = DenseMatrix.zeros(scalarType, m, n);
        int k = Math.min(m, n);
        for (int i = 0; i < k; ++i) {
            matrix.set(i, i, 1.0);
        }
        return matrix;
    }

    /**
     * Returns an {@code n x n} identity matrix with the same scalar type as
     * this matrix.
     *
     * @param n the order of the matrix.
     * @return the identity matrix.
     */
    public DenseMatrix eye(int n) {
        final int rows = n;
        final int cols = n;
        return this.eye(rows, cols);
    }

    /**
     * Returns an {@code m x n} matrix with ones on the main diagonal. The
     * result has the same scalar type as this matrix.
     *
     * @param m the number of rows.
     * @param n the number of columns.
     * @return the (possibly rectangular) identity matrix.
     */
    public DenseMatrix eye(int m, int n) {
        ScalarType type = this.scalarType();
        return DenseMatrix.eye(type, m, n);
    }

    /**
     * Returns a square {@code Float64} matrix with the given values on its
     * main diagonal and zeros elsewhere.
     *
     * @param diag the diagonal elements; the matrix order is {@code diag.length}.
     * @return the diagonal matrix.
     */
    public static DenseMatrix diagflat(double[] diag) {
        int size = diag.length;
        DenseMatrix result = DenseMatrix.zeros(ScalarType.Float64, size, size);
        int k = 0;
        while (k < size) {
            result.set(k, k, diag[k]);
            k++;
        }
        return result;
    }

    /**
     * Returns a square {@code Float32} matrix with the given values on its
     * main diagonal and zeros elsewhere.
     *
     * @param diag the diagonal elements; the matrix order is {@code diag.length}.
     * @return the diagonal matrix.
     */
    public static DenseMatrix diagflat(float[] diag) {
        int size = diag.length;
        DenseMatrix result = DenseMatrix.zeros(ScalarType.Float32, size, size);
        int k = 0;
        while (k < size) {
            result.set(k, k, diag[k]);
            k++;
        }
        return result;
    }

    /**
     * Returns an {@code m x n} matrix filled with samples from the
     * {@code GaussianDistribution} singleton instance.
     *
     * @param scalarType the element type.
     * @param m the number of rows.
     * @param n the number of columns.
     * @return the random matrix.
     */
    public static DenseMatrix randn(ScalarType scalarType, int m, int n) {
        Distribution gaussian = GaussianDistribution.getInstance();
        return DenseMatrix.rand(scalarType, m, n, gaussian);
    }

    /**
     * Returns an {@code m x n} matrix filled with samples drawn from a distribution.
     * Samples are drawn column by column (all rows of column 0 first, then
     * column 1, ...).
     *
     * @param scalarType the element type.
     * @param m the number of rows.
     * @param n the number of columns.
     * @param distribution the distribution to sample from.
     * @return the random matrix.
     */
    public static DenseMatrix rand(ScalarType scalarType, int m, int n, Distribution distribution) {
        DenseMatrix result = DenseMatrix.zeros(scalarType, m, n);
        // Column-outer order fixes the sequence of draws; keep it stable.
        for (int col = 0; col < n; col++) {
            for (int row = 0; row < m; row++) {
                double sample = distribution.rand();
                result.set(row, col, sample);
            }
        }
        return result;
    }

    /**
     * Returns an {@code m x n} matrix filled with values from
     * {@code MathEx.random()}, drawn column by column.
     *
     * @param scalarType the element type.
     * @param m the number of rows.
     * @param n the number of columns.
     * @return the random matrix.
     */
    public static DenseMatrix rand(ScalarType scalarType, int m, int n) {
        DenseMatrix result = DenseMatrix.zeros(scalarType, m, n);
        // Column-outer order fixes the sequence of draws; keep it stable.
        for (int col = 0; col < n; col++) {
            for (int row = 0; row < m; row++) {
                double sample = MathEx.random();
                result.set(row, col, sample);
            }
        }
        return result;
    }

    /**
     * Returns an {@code m x n} matrix filled with values from
     * {@code MathEx.random(lo, hi)}, drawn column by column. The values are
     * presumably uniform over the range; see MathEx for the exact bounds.
     *
     * @param scalarType the element type.
     * @param m the number of rows.
     * @param n the number of columns.
     * @param lo the lower bound passed to {@code MathEx.random}.
     * @param hi the upper bound passed to {@code MathEx.random}.
     * @return the random matrix.
     */
    public static DenseMatrix rand(ScalarType scalarType, int m, int n, double lo, double hi) {
        DenseMatrix result = DenseMatrix.zeros(scalarType, m, n);
        // Column-outer order fixes the sequence of draws; keep it stable.
        for (int col = 0; col < n; col++) {
            for (int row = 0; row < m; row++) {
                double sample = MathEx.random(lo, hi);
                result.set(row, col, sample);
            }
        }
        return result;
    }

    /**
     * Returns a symmetric {@code Float64} Toeplitz matrix with
     * {@code A[i][j] = a[|i - j|]}. Before filling, the matrix is tagged with
     * {@code withUplo(UPLO.LOWER)}.
     *
     * @param a the first row (equivalently, the first column); the matrix
     *          order is {@code a.length}.
     * @return the Toeplitz matrix.
     */
    public static DenseMatrix toeplitz(double[] a) {
        int size = a.length;
        DenseMatrix result = DenseMatrix.zeros(ScalarType.Float64, size, size);
        result.withUplo(UPLO.LOWER);
        for (int row = 0; row < size; ++row) {
            for (int col = 0; col < size; ++col) {
                result.set(row, col, a[Math.abs(row - col)]);
            }
        }
        return result;
    }

    /**
     * Returns a symmetric {@code Float32} Toeplitz matrix with
     * {@code A[i][j] = a[|i - j|]}. Before filling, the matrix is tagged with
     * {@code withUplo(UPLO.LOWER)}.
     *
     * @param a the first row (equivalently, the first column); the matrix
     *          order is {@code a.length}.
     * @return the Toeplitz matrix.
     */
    public static DenseMatrix toeplitz(float[] a) {
        int size = a.length;
        DenseMatrix result = DenseMatrix.zeros(ScalarType.Float32, size, size);
        result.withUplo(UPLO.LOWER);
        for (int row = 0; row < size; ++row) {
            for (int col = 0; col < size; ++col) {
                result.set(row, col, a[Math.abs(row - col)]);
            }
        }
        return result;
    }

    /**
     * Returns a {@code Float64} Toeplitz matrix of order {@code kl.length}:
     * {@code A[i][j] = kl[i - j]} for {@code i > j} and {@code A[i][j] = ku[j - i]}
     * for {@code i <= j}.
     *
     * <p>NOTE(review): because the order is {@code kl.length}, {@code kl[0]}
     * and {@code ku[kl.length]} are never read, even though {@code ku} must be
     * exactly one element longer than {@code kl}. Confirm the intended contract.
     *
     * @param kl the sub-diagonal values.
     * @param ku the diagonal and super-diagonal values.
     * @return the Toeplitz matrix.
     * @throws IllegalArgumentException if {@code kl.length != ku.length - 1}.
     */
    public static DenseMatrix toeplitz(double[] kl, double[] ku) {
        if (kl.length + 1 != ku.length) {
            throw new IllegalArgumentException(String.format("Invalid sub-diagonals and super-diagonals size: %d != %d - 1", kl.length, ku.length));
        }
        int size = kl.length;
        DenseMatrix result = DenseMatrix.zeros(ScalarType.Float64, size, size);
        for (int row = 0; row < size; ++row) {
            for (int col = 0; col < size; ++col) {
                double value = row > col ? kl[row - col] : ku[col - row];
                result.set(row, col, value);
            }
        }
        return result;
    }

    /**
     * Returns a {@code Float32} Toeplitz matrix of order {@code kl.length}:
     * {@code A[i][j] = kl[i - j]} for {@code i > j} and {@code A[i][j] = ku[j - i]}
     * for {@code i <= j}.
     *
     * <p>NOTE(review): because the order is {@code kl.length}, {@code kl[0]}
     * and {@code ku[kl.length]} are never read, even though {@code ku} must be
     * exactly one element longer than {@code kl}. Confirm the intended contract.
     *
     * @param kl the sub-diagonal values.
     * @param ku the diagonal and super-diagonal values.
     * @return the Toeplitz matrix.
     * @throws IllegalArgumentException if {@code kl.length != ku.length - 1}.
     */
    public static DenseMatrix toeplitz(float[] kl, float[] ku) {
        if (kl.length + 1 != ku.length) {
            throw new IllegalArgumentException(String.format("Invalid sub-diagonals and super-diagonals size: %d != %d - 1", kl.length, ku.length));
        }
        int size = kl.length;
        DenseMatrix result = DenseMatrix.zeros(ScalarType.Float32, size, size);
        for (int row = 0; row < size; ++row) {
            for (int col = 0; col < size; ++col) {
                float value = row > col ? kl[row - col] : ku[col - row];
                result.set(row, col, value);
            }
        }
        return result;
    }
}

