/*
 * Decompiled with CFR 0.152.
 */
package smile.tensor;

import java.io.Serializable;
import java.lang.foreign.MemorySegment;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.linalg.Transpose;
import smile.linalg.lapack.clapack_h;
import smile.tensor.DenseMatrix;
import smile.tensor.Vector;

/**
 * The LU decomposition with partial pivoting of a matrix, {@code A = P * L * U}, stored in the
 * packed form consumed by LAPACK {@code GETRS}.
 *
 * @param lu the packed factors: {@code L} (unit lower triangular, diagonal not stored) below the
 *     diagonal and {@code U} on and above the diagonal.
 * @param ipiv the pivot indices in LAPACK's 1-based convention: row {@code j + 1} was interchanged
 *     with row {@code ipiv[j]}.
 * @param info the {@code INFO} code of the factorization; a positive value means the matrix is
 *     singular (see {@link #isSingular()}).
 */
public record LU(DenseMatrix lu, int[] ipiv, int info) implements Serializable {
    private static final Logger logger = LoggerFactory.getLogger(LU.class);

    /**
     * Returns whether the factorized matrix is singular.
     *
     * @return {@code true} if {@code info > 0}.
     */
    public boolean isSingular() {
        return info > 0;
    }

    /**
     * Returns the determinant of the factorized matrix.
     *
     * <p>The determinant is the product of the diagonal of the packed factors (the diagonal of
     * {@code U}), with the sign flipped once for every actual row interchange recorded in
     * {@code ipiv}.
     *
     * @return the determinant.
     * @throws IllegalArgumentException if the matrix is not square.
     */
    public double det() {
        int m = lu.m;
        int n = lu.n;
        if (m != n) {
            throw new IllegalArgumentException(String.format("The matrix is not square: %d x %d", m, n));
        }

        double d = 1.0;
        for (int j = 0; j < n; j++) {
            d *= lu.get(j, j);
            // ipiv is 1-based; ipiv[j] == j + 1 means row j was not swapped.
            if (ipiv[j] != j + 1) {
                d = -d;
            }
        }
        return d;
    }

    /**
     * Returns the inverse of the factorized matrix by solving {@code A * X = I}.
     *
     * @return the inverse matrix.
     * @throws RuntimeException if the matrix is singular or the system cannot be solved
     *     (see {@link #solve(DenseMatrix)}).
     */
    public DenseMatrix inverse() {
        DenseMatrix inv = lu.eye(lu.n);
        solve(inv);
        return inv;
    }

    /**
     * Solves {@code A * x = b}.
     *
     * @param b the right-hand side; its length must equal the number of columns of {@code A}.
     * @return a new vector holding the solution {@code x}.
     * @throws IllegalArgumentException if {@code b} has the wrong length, or on the checks
     *     performed by {@link #solve(DenseMatrix)}.
     */
    public Vector solve(double[] b) {
        checkRightHandSideLength(b.length);
        Vector x = lu.vector(lu.n);
        for (int i = 0; i < lu.n; i++) {
            x.set(i, b[i]);
        }
        solve(x);
        return x;
    }

    /**
     * Solves {@code A * x = b}.
     *
     * @param b the right-hand side; its length must equal the number of columns of {@code A}.
     * @return a new vector holding the solution {@code x}.
     * @throws IllegalArgumentException if {@code b} has the wrong length, or on the checks
     *     performed by {@link #solve(DenseMatrix)}.
     */
    public Vector solve(float[] b) {
        checkRightHandSideLength(b.length);
        Vector x = lu.vector(lu.n);
        for (int i = 0; i < lu.n; i++) {
            x.set(i, b[i]);
        }
        solve(x);
        return x;
    }

    /**
     * Rejects a right-hand-side array whose length differs from the number of unknowns, so that a
     * too-long array is not silently truncated and a too-short one fails with a clear message.
     */
    private void checkRightHandSideLength(int length) {
        if (length != lu.n) {
            throw new IllegalArgumentException(String.format(
                    "The length of b does not match the matrix: A is %d x %d, but b has length %d",
                    lu.m, lu.n, length));
        }
    }

    /**
     * Solves {@code A * X = B} in place with LAPACK {@code GETRS}. On return, {@code B} is
     * overwritten with the solution {@code X}.
     *
     * @param B the right-hand sides; must have the same layout and scalar type as the factors and
     *     as many rows as {@code A}.
     * @throws RuntimeException if the matrix is singular.
     * @throws IllegalArgumentException if the layouts, scalar types, or dimensions do not agree.
     * @throws UnsupportedOperationException if the scalar type is neither Float64 nor Float32.
     * @throws ArithmeticException if {@code GETRS} returns a nonzero {@code INFO}.
     */
    public void solve(DenseMatrix B) {
        if (info > 0) {
            throw new RuntimeException("The matrix is singular.");
        }
        if (lu.order() != B.order()) {
            throw new IllegalArgumentException("The matrix layout is inconsistent.");
        }
        if (lu.scalarType() != B.scalarType()) {
            throw new IllegalArgumentException("Incompatible ScalarType: " + B.scalarType() + " != " + lu.scalarType());
        }
        if (lu.m != lu.n) {
            throw new IllegalArgumentException(String.format("The matrix is not square: %d x %d", lu.m, lu.n));
        }
        if (lu.m != B.m) {
            throw new IllegalArgumentException(String.format("Row dimensions do not agree: A is %d x %d, but B is %d x %d", lu.m, lu.n, B.m, B.n));
        }

        // LAPACK's Fortran interface takes every scalar argument by reference.
        byte[] trans = {Transpose.NO_TRANSPOSE.lapack()};
        int[] n = {lu.n};
        int[] nrhs = {B.n};
        int[] lda = {lu.ld};
        int[] ldb = {B.ld};
        int[] info = {0};
        MemorySegment trans_ = MemorySegment.ofArray(trans);
        MemorySegment n_ = MemorySegment.ofArray(n);
        MemorySegment nrhs_ = MemorySegment.ofArray(nrhs);
        MemorySegment lda_ = MemorySegment.ofArray(lda);
        MemorySegment ldb_ = MemorySegment.ofArray(ldb);
        MemorySegment ipiv_ = MemorySegment.ofArray(ipiv);
        MemorySegment info_ = MemorySegment.ofArray(info);

        switch (lu.scalarType()) {
            case Float64 -> clapack_h.dgetrs_(trans_, n_, nrhs_, lu.memory, lda_, ipiv_, B.memory, ldb_, info_);
            case Float32 -> clapack_h.sgetrs_(trans_, n_, nrhs_, lu.memory, lda_, ipiv_, B.memory, ldb_, info_);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + lu.scalarType());
        }

        if (info[0] != 0) {
            logger.error("LAPACK GETRS error code: {}", info[0]);
            throw new ArithmeticException("LAPACK GETRS error code: " + info[0]);
        }
    }
}

