/*
 * Decompiled with CFR 0.152.
 */
package smile.projection;

import java.io.Serializable;
import smile.math.MathEx;
import smile.math.matrix.Cholesky;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EVD;
import smile.math.matrix.Matrix;
import smile.projection.LinearProjection;

/**
 * Probabilistic principal component analysis (PPCA).
 * <p>
 * The model is fitted in closed form from the eigen-decomposition of the
 * sample covariance matrix (normalized by the number of samples). With the
 * top {@code k} eigenvalues/eigenvectors {@code L_k, U_k}, the fitted
 * quantities are:
 * <ul>
 * <li>noise variance {@code sigma^2} = mean of the remaining {@code n - k} eigenvalues;</li>
 * <li>loading matrix {@code W = U_k sqrt(L_k - sigma^2 I)} ({@code n x k});</li>
 * <li>projection matrix {@code P = M^-1 W^T} ({@code k x n}) with
 *     {@code M = W^T W + sigma^2 I}.</li>
 * </ul>
 * {@link #project(double[])} returns {@code P (x - mu)}, computed as
 * {@code P x - P mu} using a precomputed {@code P mu}.
 */
public class PPCA
implements LinearProjection,
Serializable {
    private static final long serialVersionUID = 2L;
    /** Sample mean (center) of the training data; length n. */
    private final double[] mu;
    /** Precomputed {@code projection * mu}; length k. */
    private final double[] pmu;
    /** Estimated isotropic noise variance. */
    private final double noise;
    /** Loading matrix W (n x k). */
    private final DenseMatrix loading;
    /** Projection matrix (k x n) applied by {@link #project(double[])}. */
    private final DenseMatrix projection;

    /**
     * Constructs a PPCA model from already-estimated parameters.
     * The arrays/matrices are stored by reference, not copied.
     *
     * @param noise the noise variance.
     * @param mu the data center; its length is the expected input dimension.
     * @param loading the loading matrix.
     * @param projection the projection matrix; {@code projection.nrows()} is
     *                   the output dimension of {@link #project(double[])}.
     */
    public PPCA(double noise, double[] mu, DenseMatrix loading, DenseMatrix projection) {
        this.noise = noise;
        this.mu = mu;
        this.loading = loading;
        this.projection = projection;
        // Precompute projection * mu so project() can compute P(x - mu) as Px - P*mu.
        this.pmu = new double[projection.nrows()];
        projection.ax(mu, this.pmu);
    }

    /**
     * Returns the loading matrix (the internal instance, not a copy).
     *
     * @return the loading matrix.
     */
    public DenseMatrix getLoadings() {
        return this.loading;
    }

    /**
     * Returns the data center (the internal array, not a copy).
     *
     * @return the center of the input data.
     */
    public double[] getCenter() {
        return this.mu;
    }

    /**
     * Returns the estimated noise variance.
     *
     * @return the noise variance.
     */
    public double getNoiseVariance() {
        return this.noise;
    }

    /**
     * Returns the projection matrix (the internal instance, not a copy).
     *
     * @return the projection matrix.
     */
    @Override
    public DenseMatrix getProjection() {
        return this.projection;
    }

    /**
     * Projects a single vector: returns {@code projection * (x - mu)}.
     *
     * @param x the input vector; must have the same length as the center.
     * @return the projected vector of length {@code getProjection().nrows()}.
     * @throws IllegalArgumentException if {@code x} has the wrong length.
     */
    @Override
    public double[] project(double[] x) {
        if (x.length != this.mu.length) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.mu.length));
        }
        double[] y = new double[this.projection.nrows()];
        this.projection.ax(x, y);
        // y = P x - P mu, i.e. P (x - mu).
        MathEx.sub(y, this.pmu);
        return y;
    }

    /**
     * Projects each row of {@code x}: row {@code i} of the result is
     * {@code projection * (x[i] - mu)}. An empty batch yields an empty result.
     *
     * @param x the input vectors; every row must have the same length as the center.
     * @return the projected vectors.
     * @throws IllegalArgumentException if any row has the wrong length.
     */
    @Override
    public double[][] project(double[][] x) {
        // Validate every row up front (not just the first) so malformed input
        // never reaches the matrix-vector product.
        for (double[] row : x) {
            if (row.length != this.mu.length) {
                throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", row.length, this.mu.length));
            }
        }
        double[][] y = new double[x.length][this.projection.nrows()];
        for (int i = 0; i < x.length; ++i) {
            this.projection.ax(x[i], y[i]);
            MathEx.sub(y[i], this.pmu);
        }
        return y;
    }

    /**
     * Fits a probabilistic PCA model.
     *
     * @param data the training data, one sample per row; all rows must have
     *             the same length {@code n}.
     * @param k the dimension of the latent space; must satisfy {@code 1 <= k < n}
     *          (the noise variance is the mean of the {@code n - k} discarded
     *          eigenvalues, so at least one must be discarded).
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code data} is empty, ragged, or
     *         {@code k} is out of range.
     */
    public static PPCA fit(double[][] data, int k) {
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }
        int m = data.length;
        int n = data[0].length;
        for (int l = 1; l < m; ++l) {
            if (data[l].length != n) {
                throw new IllegalArgumentException(String.format("Invalid row %d size: %d, expected: %d", l, data[l].length, n));
            }
        }
        if (k < 1 || k >= n) {
            throw new IllegalArgumentException(String.format("Invalid dimension of latent space: %d, expected 1 <= k < %d", k, n));
        }

        double[] mu = MathEx.colMeans(data);

        // Sample covariance normalized by m. Only the lower triangle is
        // accumulated; it is mirrored into the upper triangle afterwards.
        DenseMatrix cov = Matrix.zeros(n, n);
        double[] centered = new double[n];
        for (double[] row : data) {
            for (int i = 0; i < n; ++i) {
                centered[i] = row[i] - mu[i];
            }
            for (int i = 0; i < n; ++i) {
                for (int j = 0; j <= i; ++j) {
                    cov.add(i, j, centered[i] * centered[j]);
                }
            }
        }
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j <= i; ++j) {
                cov.div(i, j, (double) m);
                cov.set(j, i, cov.get(i, j));
            }
        }
        cov.setSymmetric(true);

        EVD eigen = cov.eigen();
        double[] evalues = eigen.getEigenValues();
        DenseMatrix evectors = eigen.getEigenVectors();

        // NOTE(review): assumes eigenvalues come back in descending order, so
        // indices [0, k) are the principal components and [k, n) the tail.
        double noise = 0.0;
        for (int i = k; i < n; ++i) {
            noise += evalues[i];
        }
        noise /= (double) (n - k);

        DenseMatrix loading = Matrix.zeros(n, k);
        for (int j = 0; j < k; ++j) {
            // Clamp at 0: when the kept eigenvalue ties with the tail, rounding
            // in the average can make (evalue - noise) a tiny negative number,
            // and sqrt of that would be NaN.
            double scale = Math.sqrt(Math.max(evalues[j] - noise, 0.0));
            for (int i = 0; i < n; ++i) {
                loading.set(i, j, evectors.get(i, j) * scale);
            }
        }

        // M = W^T W + noise * I (k x k), then projection = M^-1 W^T (k x n).
        DenseMatrix mMatrix = loading.ata();
        for (int i = 0; i < k; ++i) {
            mMatrix.add(i, i, noise);
        }
        Cholesky chol = mMatrix.cholesky();
        DenseMatrix mInverse = chol.inverse();
        DenseMatrix projection = (DenseMatrix) mInverse.abtmm(loading);
        return new PPCA(noise, mu, loading, projection);
    }
}

