/*
 * Decompiled with CFR 0.152.
 */
package smile.tensor;

import java.io.Serializable;
import java.lang.foreign.MemorySegment;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.linalg.lapack.clapack_h;
import smile.tensor.DenseMatrix;
import smile.tensor.Vector;

/**
 * The Cholesky decomposition of a symmetric positive-definite matrix.
 * <p>
 * The triangular factor is stored in {@code lu}. The triangle that holds
 * the factor is given by {@code lu.uplo}. Linear systems are solved with
 * LAPACK {@code ?POTRS}, which expects a factor computed by {@code ?POTRF}.
 * NOTE(review): confirm that whatever builds this record uses POTRF.
 *
 * @param lu the square matrix holding the Cholesky factor.
 */
public record Cholesky(DenseMatrix lu) implements Serializable
{
    private static final Logger logger = LoggerFactory.getLogger(Cholesky.class);

    /**
     * Validates the factor.
     *
     * @throws UnsupportedOperationException if {@code lu} is not square.
     */
    public Cholesky {
        if (lu.nrow() != lu.ncol()) {
            throw new UnsupportedOperationException("Cholesky constructor on a non-square matrix");
        }
    }

    /**
     * Returns the determinant of the decomposed matrix.
     * <p>
     * The value is the square of the product of the factor's diagonal.
     * For large matrices the product may overflow or underflow, so prefer
     * {@link #logdet()} there.
     *
     * @return the determinant of the decomposed matrix.
     */
    public double det() {
        int n = lu.n;
        double d = 1.0;
        for (int i = 0; i < n; i++) {
            d *= lu.get(i, i);
        }
        return d * d;
    }

    /**
     * Returns the natural logarithm of the determinant of the decomposed matrix.
     * <p>
     * The value is computed as {@code 2 * sum(log(diag))}. If any diagonal
     * entry of the factor is negative the result is NaN. If any entry is
     * zero the result is negative infinity.
     *
     * @return the log-determinant of the decomposed matrix.
     */
    public double logdet() {
        int n = lu.n;
        double d = 0.0;
        for (int i = 0; i < n; i++) {
            d += Math.log(lu.get(i, i));
        }
        return 2.0 * d;
    }

    /**
     * Returns the inverse of the decomposed matrix.
     * <p>
     * The inverse is obtained by solving {@code A * X = I}.
     *
     * @return the inverse matrix.
     */
    public DenseMatrix inverse() {
        DenseMatrix inv = lu.eye(lu.n);
        solve(inv);
        return inv;
    }

    /**
     * Solves the linear system {@code A * x = b}.
     * <p>
     * The result vector has the factor's scalar type. If the factor is
     * single precision, the entries of {@code b} are narrowed accordingly.
     *
     * @param b the right-hand-side vector. Its length must equal the matrix order.
     * @return the solution vector {@code x}.
     * @throws IllegalArgumentException if {@code b.length} differs from the matrix order.
     */
    public Vector solve(double[] b) {
        int n = lu.n;
        checkRhsLength(b.length, n);
        Vector x = lu.vector(n);
        for (int i = 0; i < n; i++) {
            x.set(i, b[i]);
        }
        solve(x);
        return x;
    }

    /**
     * Solves the linear system {@code A * x = b}.
     *
     * @param b the right-hand-side vector. Its length must equal the matrix order.
     * @return the solution vector {@code x}, with the factor's scalar type.
     * @throws IllegalArgumentException if {@code b.length} differs from the matrix order.
     */
    public Vector solve(float[] b) {
        int n = lu.n;
        checkRhsLength(b.length, n);
        Vector x = lu.vector(n);
        for (int i = 0; i < n; i++) {
            x.set(i, b[i]);
        }
        solve(x);
        return x;
    }

    /**
     * Rejects a right-hand-side array whose length does not match the
     * factor order. Without this check, a short array would fail partway
     * through with an index error and a long one would be silently truncated.
     */
    private static void checkRhsLength(int length, int n) {
        if (length != n) {
            throw new IllegalArgumentException(String.format("Row dimensions do not agree: A is %d x %d, but b has length %d", n, n, length));
        }
    }

    /**
     * Solves the linear system {@code A * X = B} in place.
     * <p>
     * On return, {@code B} is overwritten with the solution {@code X}.
     *
     * @param B the right-hand-side matrix. It is overwritten with the solution.
     * @throws IllegalArgumentException if the scalar types or row dimensions disagree.
     * @throws UnsupportedOperationException if the scalar type is neither Float64 nor Float32.
     * @throws ArithmeticException if LAPACK reports a nonzero {@code info}.
     */
    public void solve(DenseMatrix B) {
        if (lu.scalarType() != B.scalarType()) {
            throw new IllegalArgumentException("Incompatible ScalarType: " + B.scalarType() + " != " + lu.scalarType());
        }
        if (lu.m != B.m) {
            throw new IllegalArgumentException(String.format("Row dimensions do not agree: A is %d x %d, but B is %d x %d", lu.m, lu.n, B.m, B.n));
        }

        // The Fortran-style LAPACK entry points (trailing underscore) take every
        // argument by reference, so each scalar is wrapped in a one-element array.
        byte[] uplo = {lu.uplo.lapack()};
        int[] n = {lu.n};
        int[] nrhs = {B.n};
        int[] lda = {lu.ld};
        int[] ldb = {B.ld};
        int[] info = {0};
        MemorySegment uplo_ = MemorySegment.ofArray(uplo);
        MemorySegment n_ = MemorySegment.ofArray(n);
        MemorySegment nrhs_ = MemorySegment.ofArray(nrhs);
        MemorySegment lda_ = MemorySegment.ofArray(lda);
        MemorySegment ldb_ = MemorySegment.ofArray(ldb);
        MemorySegment info_ = MemorySegment.ofArray(info);

        switch (lu.scalarType()) {
            case Float64 -> clapack_h.dpotrs_(uplo_, n_, nrhs_, lu.memory, lda_, B.memory, ldb_, info_);
            case Float32 -> clapack_h.spotrs_(uplo_, n_, nrhs_, lu.memory, lda_, B.memory, ldb_, info_);
            default -> throw new UnsupportedOperationException("Unsupported scalar type: " + lu.scalarType());
        }

        // For POTRS, a negative info means argument number -info had an illegal value.
        if (info[0] != 0) {
            logger.error("LAPACK POTRS error code: {}", info[0]);
            throw new ArithmeticException("LAPACK POTRS error code: " + info[0]);
        }
    }
}

