package smile.projection;

import java.io.Serializable;
import smile.math.MathEx;
import smile.math.blas.UPLO;
import smile.math.matrix.Matrix;

/* loaded from: input_file:smile/projection/PCA.class */
/**
 * Principal component analysis: a linear projection onto the leading
 * eigenvectors (loadings) of the sample covariance or correlation matrix.
 *
 * <p>Instances are created either from precomputed quantities through the
 * constructor, or from raw data through {@link #fit(double[][])} (covariance
 * based) and {@link #cor(double[][])} (correlation based). The number of
 * retained components can be changed afterwards with
 * {@link #setProjection(int)} or {@link #setProjection(double)}.
 */
public class PCA implements LinearProjection, Serializable {
    private static final long serialVersionUID = 2;

    /** Fraction of variance retained by the projection built in the constructor. */
    private static final double DEFAULT_VARIANCE_FRACTION = 0.95;

    /** Number of retained principal components (rows of {@link #projection}). */
    private int p;
    /** Dimension of the input space; taken from the length of {@link #mu}. */
    private final int n;
    /** Center (column means) subtracted from the data. */
    private final double[] mu;
    /** The center mapped through the current projection matrix. */
    private double[] pmu;
    /** Loadings; column k holds the k-th principal component. */
    private final Matrix eigvectors;
    /** Variance (eigenvalue) associated with each principal component. */
    private final double[] eigvalues;
    /** Eigenvalues after normalization by {@code MathEx.unitize1}. */
    private final double[] proportion;
    /** Running sum of {@link #proportion}. */
    private final double[] cumulativeProportion;
    /** Current p-by-n projection matrix. */
    private Matrix projection;

    /**
     * Builds a PCA model from precomputed quantities and selects the smallest
     * number of leading components whose cumulative variance proportion
     * reaches 95%.
     *
     * <p>The arrays and the matrix are stored by reference, not copied.
     *
     * @param mu the center of the data; its length defines the input dimension.
     * @param eigvalues the variance of each component, expected in the same
     *                  order as the columns of {@code loadings}.
     * @param loadings the matrix whose columns are the principal components.
     */
    public PCA(double[] mu, double[] eigvalues, Matrix loadings) {
        this.mu = mu;
        this.eigvalues = eigvalues;
        this.eigvectors = loadings;
        this.n = mu.length;

        this.proportion = eigvalues.clone();
        // Presumably scales the copy to unit L1 norm, turning it into fractions.
        MathEx.unitize1(this.proportion);

        this.cumulativeProportion = new double[eigvalues.length];
        this.cumulativeProportion[0] = this.proportion[0];
        for (int i = 1; i < eigvalues.length; i++) {
            this.cumulativeProportion[i] = this.cumulativeProportion[i - 1] + this.proportion[i];
        }

        setProjection(DEFAULT_VARIANCE_FRACTION);
    }

    /**
     * Fits a PCA model on the covariance structure of the data.
     *
     * <p>When there are more rows than columns the centered data is
     * decomposed by SVD; otherwise the covariance matrix is formed explicitly
     * and eigen-decomposed.
     *
     * @param data the training data, one sample per row. Must contain at
     *             least one row.
     * @return the fitted model.
     */
    public static PCA fit(double[][] data) {
        int m = data.length;
        int n = data[0].length;
        double[] mu = MathEx.colMeans(data);
        Matrix x = center(data, mu);

        double[] eigvalues;
        Matrix eigvectors;
        if (m > n) {
            Matrix.SVD svd = x.svd(true, true);
            // The squared singular values of X are the eigenvalues of X'X.
            // NOTE(review): unlike the covariance branch below, they are not
            // divided by the sample count, so getVariance() is scaled
            // differently between the two branches (proportions are not
            // affected). Kept as is to preserve existing behavior.
            eigvalues = svd.s;
            for (int i = 0; i < eigvalues.length; i++) {
                eigvalues[i] *= eigvalues[i];
            }
            eigvectors = svd.V;
        } else {
            Matrix cov = covariance(x, m, n);
            cov.uplo(UPLO.LOWER);
            Matrix.EVD eigen = cov.eigen(false, true, true).sort();
            eigvalues = eigen.wr;
            eigvectors = eigen.Vr;
        }

        return new PCA(mu, eigvalues, eigvectors);
    }

    /**
     * Fits a PCA model on the correlation matrix of the data.
     *
     * <p>The returned loadings are the eigenvectors of the correlation matrix
     * with row {@code i} divided by the standard deviation of variable
     * {@code i}, so that projecting centered raw data is equivalent to
     * projecting standardized data onto the eigenvectors.
     *
     * @param data the training data, one sample per row. Must contain at
     *             least one row.
     * @return the fitted model.
     */
    public static PCA cor(double[][] data) {
        int m = data.length;
        int n = data[0].length;
        double[] mu = MathEx.colMeans(data);
        Matrix x = center(data, mu);
        Matrix cov = covariance(x, m, n);

        // Standard deviations are read off the diagonal before it is rescaled.
        double[] sd = new double[n];
        for (int i = 0; i < n; i++) {
            sd[i] = Math.sqrt(cov.get(i, i));
        }

        // Turn the covariance matrix into the correlation matrix in place.
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= i; j++) {
                cov.div(i, j, sd[i] * sd[j]);
                cov.set(j, i, cov.get(i, j));
            }
        }

        cov.uplo(UPLO.LOWER);
        Matrix.EVD eigen = cov.eigen(false, true, true).sort();

        Matrix loadings = eigen.Vr;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                loadings.div(i, j, sd[i]);
            }
        }

        return new PCA(mu, eigen.wr, loadings);
    }

    /**
     * Wraps the data in a matrix and subtracts the column means from it.
     */
    private static Matrix center(double[][] data, double[] mu) {
        Matrix x = new Matrix(data);
        int m = data.length;
        for (int j = 0; j < mu.length; j++) {
            for (int i = 0; i < m; i++) {
                x.sub(i, j, mu[j]);
            }
        }
        return x;
    }

    /**
     * Computes the full symmetric n-by-n matrix X'X / m of the centered
     * m-by-n data X. The divisor is the sample count m, not m - 1.
     */
    private static Matrix covariance(Matrix x, int m, int n) {
        Matrix cov = new Matrix(n, n);
        // Accumulate the lower triangle only.
        for (int k = 0; k < m; k++) {
            for (int i = 0; i < n; i++) {
                for (int j = 0; j <= i; j++) {
                    cov.add(i, j, x.get(k, i) * x.get(k, j));
                }
            }
        }
        // Normalize and mirror the lower triangle into the upper one.
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= i; j++) {
                cov.div(i, j, m);
                cov.set(j, i, cov.get(i, j));
            }
        }
        return cov;
    }

    /**
     * Returns the center of the data.
     *
     * @return the internal array of column means (not a copy).
     */
    public double[] getCenter() {
        return this.mu;
    }

    /**
     * Returns the loadings.
     *
     * @return the internal matrix whose columns are the principal components.
     */
    public Matrix getLoadings() {
        return this.eigvectors;
    }

    /**
     * Returns the variance of each principal component.
     *
     * @return the internal eigenvalue array (not a copy).
     */
    public double[] getVariance() {
        return this.eigvalues;
    }

    /**
     * Returns the proportion of variance attributed to each component.
     *
     * @return the internal array (not a copy).
     */
    public double[] getVarianceProportion() {
        return this.proportion;
    }

    /**
     * Returns the cumulative proportion of variance of the components.
     *
     * @return the internal array (not a copy).
     */
    public double[] getCumulativeVarianceProportion() {
        return this.cumulativeProportion;
    }

    /**
     * Returns the current projection matrix.
     *
     * @return the p-by-n matrix whose rows are the retained components.
     */
    @Override
    public Matrix getProjection() {
        return this.projection;
    }

    /**
     * Sets the projection to the first {@code p} principal components.
     *
     * @param p the number of components to keep, in {@code [1, n]}.
     * @return this object.
     * @throws IllegalArgumentException if {@code p} is out of range.
     */
    public PCA setProjection(int p) {
        if (p < 1 || p > this.n) {
            throw new IllegalArgumentException("Invalid dimension of feature space: " + p);
        }

        this.p = p;
        // The projection is the transpose of the first p loading columns.
        this.projection = new Matrix(p, this.n);
        for (int j = 0; j < this.n; j++) {
            for (int i = 0; i < p; i++) {
                this.projection.set(i, j, this.eigvectors.get(j, i));
            }
        }

        this.pmu = this.projection.mv(this.mu);
        return this;
    }

    /**
     * Sets the projection to the smallest number of leading components whose
     * cumulative variance proportion reaches {@code p}.
     *
     * <p>If rounding error keeps the cumulative proportion marginally below
     * the requested fraction (typically for {@code p == 1.0}), all components
     * are kept instead of silently leaving the projection unchanged.
     *
     * @param p the required fraction of variance, in {@code (0, 1]}.
     * @return this object.
     * @throws IllegalArgumentException if {@code p} is NaN or out of range.
     */
    public PCA setProjection(double p) {
        // Written as a negated conjunction so that NaN is rejected as well.
        if (!(p > 0.0 && p <= 1.0)) {
            throw new IllegalArgumentException("Invalid percentage of variance: " + p);
        }

        for (int k = 0; k < this.n; k++) {
            if (this.cumulativeProportion[k] >= p) {
                return setProjection(k + 1);
            }
        }

        // No prefix reached the requested fraction: keep every component.
        return setProjection(this.n);
    }

    /**
     * Projects a single vector onto the retained principal components.
     *
     * @param x the input vector of length n.
     * @return the projected vector.
     * @throws IllegalArgumentException if the vector has the wrong length.
     */
    @Override
    public double[] project(double[] x) {
        if (x.length != this.n) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.n));
        }

        // P * (x - mu) evaluated as P * x - P * mu, with P * mu cached in pmu.
        double[] y = this.projection.mv(x);
        MathEx.sub(y, this.pmu);
        return y;
    }

    /**
     * Projects a batch of vectors onto the retained principal components.
     *
     * @param x the input vectors, one per row; an empty batch yields an
     *          empty result.
     * @return the projected vectors, one per row.
     * @throws IllegalArgumentException if the first row has the wrong length.
     */
    @Override
    public double[][] project(double[][] x) {
        if (x.length == 0) {
            return new double[0][this.p];
        }
        if (x[0].length != this.n) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x[0].length, this.n));
        }

        double[][] y = new double[x.length][this.p];
        for (int i = 0; i < x.length; i++) {
            this.projection.mv(x[i], y[i]);
            MathEx.sub(y[i], this.pmu);
        }
        return y;
    }
}
