/*
 * Decompiled with CFR 0.152.
 */
package smile.projection;

import java.io.Serializable;
import java.util.Properties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.DifferentiableFunction;
import smile.math.MathEx;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EVD;
import smile.math.matrix.Matrix;
import smile.projection.ica.Gaussian;
import smile.projection.ica.LogCosh;
import smile.stat.distribution.GaussianDistribution;

/**
 * Independent Component Analysis fitted with a FastICA fixed-point iteration.
 * Components are extracted one at a time (deflation), and each new component is
 * Gram-Schmidt orthogonalized against the ones already found.
 *
 * <p>Training data is laid out with one mixed signal per row and one observation
 * per column: {@link #whiten(double[][])} centers each row and decorrelates the rows.
 */
public class ICA implements Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(ICA.class);

    /**
     * The fitted components, one per row. Each row is a unit vector with one entry
     * per row (signal) of the training data, expressed in the whitened coordinates.
     */
    public final double[][] components;

    /**
     * Creates an ICA model from precomputed components.
     *
     * @param components the component matrix; the array is stored as-is, not copied.
     */
    public ICA(double[][] components) {
        this.components = components;
    }

    /**
     * Fits an ICA model using the default settings of
     * {@link #fit(double[][], int, Properties)}.
     *
     * @param data training data, one signal per row and one observation per column.
     * @param p the number of independent components to extract.
     * @return the fitted model.
     */
    public static ICA fit(double[][] data, int p) {
        return fit(data, p, new Properties());
    }

    /**
     * Fits an ICA model configured by properties.
     *
     * <p>Recognized properties:
     * <ul>
     *   <li>{@code smile.ica.contrast}: {@code LogCosh} (default) or {@code Gaussian};</li>
     *   <li>{@code smile.ica.tolerance}: convergence tolerance, default {@code 1E-4};</li>
     *   <li>{@code smile.ica.max.iterations}: iteration cap per component, default {@code 100}.</li>
     * </ul>
     *
     * @param data training data, one signal per row and one observation per column.
     * @param p the number of independent components to extract.
     * @param prop the hyper-parameters.
     * @return the fitted model.
     * @throws IllegalArgumentException if the contrast function name is unsupported.
     * @throws NumberFormatException if a numeric property cannot be parsed.
     */
    public static ICA fit(double[][] data, int p, Properties prop) {
        String contrast = prop.getProperty("smile.ica.contrast", "LogCosh");
        DifferentiableFunction f;
        switch (contrast) {
            case "LogCosh":
                f = new LogCosh();
                break;
            case "Gaussian":
                f = new Gaussian();
                break;
            default:
                throw new IllegalArgumentException("Unsupported contrast function: " + contrast);
        }
        double tol = Double.parseDouble(prop.getProperty("smile.ica.tolerance", "1E-4"));
        int maxIter = Integer.parseInt(prop.getProperty("smile.ica.max.iterations", "100"));
        return fit(data, p, f, tol, maxIter);
    }

    /**
     * Fits an ICA model with an explicit contrast function.
     *
     * @param data training data, one signal per row and one observation per column.
     * @param p the number of independent components to extract; must be in {@code [1, rows]}.
     * @param contrast the contrast function; its {@code g} and {@code g2} drive the update.
     * @param tol convergence tolerance on the change of a component between iterations.
     * @param maxIter the maximum number of iterations per component.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code tol} is not positive (or is NaN),
     *     {@code maxIter} is not positive, {@code data} is empty, {@code p} is out of
     *     range, or the data covariance is close to singular.
     * @throws NullPointerException if {@code data} or {@code contrast} is null.
     */
    public static ICA fit(double[][] data, int p, DifferentiableFunction contrast, double tol, int maxIter) {
        // Written as !(tol > 0) so that NaN is rejected as well.
        if (!(tol > 0.0)) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (data.length == 0 || data[0].length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }

        int n = data.length;    // number of signals (rows)
        int m = data[0].length; // number of observations (columns)
        // Components are orthogonal vectors in R^n, so at most n of them exist. Allowing
        // p > n would drive the extra components to zero and produce NaN on normalization.
        if (p < 1 || p > n) {
            throw new IllegalArgumentException("Invalid dimension of feature space: " + p);
        }
        if (contrast == null) {
            throw new NullPointerException("contrast");
        }

        // Random unit-length starting points, one per component.
        GaussianDistribution gaussian = new GaussianDistribution(0.0, 1.0);
        double[][] W = new double[p][n];
        for (int i = 0; i < p; ++i) {
            for (int j = 0; j < n; ++j) {
                W[i][j] = gaussian.rand();
            }
            MathEx.unitize(W[i]);
        }

        DenseMatrix X = whiten(data);

        double[] wold = new double[n];
        double[] scratch = new double[n];
        double[] wX = new double[m];
        double[] gwX = new double[m];
        for (int i = 0; i < p; ++i) {
            double[] w = W[i];
            double diff = Double.MAX_VALUE;
            for (int iter = 0; iter < maxIter && diff > tol; ++iter) {
                System.arraycopy(w, 0, wold, 0, n);

                // Project every observation onto w, then apply the contrast derivatives.
                X.atx(w, wX);
                double g2 = 0.0;
                for (int j = 0; j < m; ++j) {
                    gwX[j] = contrast.g(wX[j]);
                    g2 += contrast.g2(wX[j]);
                }

                // Fixed-point update: w <- (X g(X'w) - sum(g'(X'w)) * w) / m.
                // wold still holds the pre-update w after ax overwrites w.
                X.ax(gwX, w);
                for (int j = 0; j < n; ++j) {
                    w[j] = (w[j] - g2 * wold[j]) / m;
                }

                // Deflation: remove the projections onto the components already found.
                for (int k = 0; k < i; ++k) {
                    double[] wk = W[k];
                    double wkw = MathEx.dot(wk, w);
                    for (int j = 0; j < n; ++j) {
                        w[j] -= wkw * wk[j];
                    }
                }
                MathEx.unitize2(w);

                diff = signInvariantDistance(w, wold, scratch);
            }
            if (diff > tol) {
                logger.warn("Component {} did not converge in {} iterations.", i, maxIter);
            }
        }
        return new ICA(W);
    }

    /**
     * Returns min(|a - b|, |a + b|). The vectors w and -w describe the same component,
     * so convergence is measured against whichever sign of the previous iterate is closer.
     *
     * @param scratch a work buffer of the same length as {@code a}; its contents are overwritten.
     */
    private static double signInvariantDistance(double[] a, double[] b, double[] scratch) {
        for (int j = 0; j < a.length; ++j) {
            scratch[j] = a[j] - b[j];
        }
        double minus = MathEx.norm(scratch);
        for (int j = 0; j < a.length; ++j) {
            scratch[j] = a[j] + b[j];
        }
        double plus = MathEx.norm(scratch);
        return Math.min(minus, plus);
    }

    /**
     * Centers each row of {@code data} and rotates/scales the rows so that their
     * sample covariance is the identity, i.e. {@code Y Y' / m = I} for m columns.
     *
     * @param data one signal per row and one observation per column.
     * @return the whitened data, with the same shape as {@code data}.
     * @throws IllegalArgumentException if an eigenvalue of the centered {@code X X'}
     *     is below {@code 1E-8}.
     */
    private static DenseMatrix whiten(double[][] data) {
        double[] mean = MathEx.rowMeans(data);
        DenseMatrix X = Matrix.of(data);
        int n = X.nrows();
        int m = X.ncols();
        for (int j = 0; j < m; ++j) {
            for (int i = 0; i < n; ++i) {
                X.sub(i, j, mean[i]);
            }
        }

        DenseMatrix XXt = X.aat();
        XXt.setSymmetric(true);
        EVD eigen = XXt.eigen();
        DenseMatrix E = eigen.getEigenVectors();
        DenseMatrix Y = (DenseMatrix) E.atbmm(X);

        double[] d = eigen.getEigenValues();
        for (int i = 0; i < d.length; ++i) {
            if (d[i] < 1.0E-8) {
                throw new IllegalArgumentException(String.format("Covariance matrix (column %d) is close to singular.", i));
            }
            // d[i] is an eigenvalue of X X' (not divided by m). Scaling by sqrt(m / d)
            // gives unit *sample* covariance, which the fixed-point update in fit()
            // assumes because it averages over the m observations. Scaling by 1/sqrt(d)
            // would make projections ~1/sqrt(m) in size, keep the contrast nearly linear,
            // and leave w almost unchanged at every step.
            d[i] = Math.sqrt(m / d[i]);
        }

        for (int j = 0; j < m; ++j) {
            for (int i = 0; i < n; ++i) {
                Y.mul(i, j, d[i]);
            }
        }
        return Y;
    }
}

