/*
 * Decompiled with CFR 0.152.
 */
package smile.ica;

import java.io.Serializable;
import java.lang.reflect.Constructor;
import java.lang.runtime.SwitchBootstraps;
import java.util.Objects;
import java.util.Properties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.ica.Exp;
import smile.ica.Kurtosis;
import smile.ica.LogCosh;
import smile.math.MathEx;
import smile.stat.distribution.GaussianDistribution;
import smile.tensor.DenseMatrix;
import smile.tensor.EVD;
import smile.tensor.Vector;
import smile.util.function.DifferentiableFunction;

/**
 * Independent component analysis estimated with a deflationary fixed-point
 * iteration: components are extracted one at a time, each one being
 * orthogonalized against the previously extracted ones and re-normalized
 * after every update.
 *
 * @param components the estimated components, one per row. {@link #fit}
 *                   produces {@code p} rows, each of length
 *                   {@code data[0].length} and scaled to unit norm.
 */
public record ICA(double[][] components) implements Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(ICA.class);

    /**
     * Fits the model with the default options: the {@link LogCosh} contrast
     * function and at most 100 iterations per component.
     *
     * @param data the training data; see {@link #fit(double[][], int, Options)}.
     * @param p    the number of components to extract.
     * @return the fitted model.
     */
    public static ICA fit(double[][] data, int p) {
        return fit(data, p, new Options(new LogCosh(), 100));
    }

    /**
     * Fits the model.
     *
     * @param data    the training data with {@code m = data.length} rows and
     *                {@code n = data[0].length} columns.
     * @param p       the number of components to extract,
     *                {@code 1 <= p <= data.length}.
     * @param options the contrast function, iteration limit and tolerance.
     * @return the fitted model holding a {@code p x n} component matrix.
     * @throws IllegalArgumentException if {@code data} is empty, if {@code p}
     *         is out of range, or if whitening finds an eigenvalue below 1E-8.
     */
    public static ICA fit(double[][] data, int p, Options options) {
        if (data.length == 0) {
            // Fail with a clear message instead of an index error on data[0].
            throw new IllegalArgumentException("Empty data");
        }

        int n = data[0].length;
        int m = data.length;
        if (p < 1 || p > m) {
            throw new IllegalArgumentException("Invalid dimension of feature space: " + p);
        }

        DifferentiableFunction contrast = options.contrast();
        int maxIter = options.maxIter();
        double tol = options.tol();

        // Random starting point: every row is drawn from N(0, 1) and then
        // normalized. This happens before whitening, as in the original code,
        // so the sequence of random draws is unchanged.
        GaussianDistribution g = new GaussianDistribution(0.0, 1.0);
        double[][] W = new double[p][n];
        for (int i = 0; i < p; ++i) {
            for (int j = 0; j < n; ++j) {
                W[i][j] = g.rand();
            }
            MathEx.unitize(W[i]);
        }

        // Whitened data, n x m.
        DenseMatrix X = whiten(data);

        // Work buffers shared by all components. The Vector wrappers are used
        // for the matrix products while the raw arrays are read and written
        // directly, so the code relies on both sharing the same storage.
        double[] wold = new double[n];
        double[] wdif = new double[n];
        double[] gwX = new double[m];
        Vector gwXVec = Vector.column(gwX);
        double[] g2w = new double[n];
        double[] wX = new double[m];
        Vector wXVec = Vector.column(wX);

        for (int i = 0; i < p; ++i) {
            double[] w = W[i];
            Vector wVec = Vector.column(w);
            double diff = Double.MAX_VALUE;

            for (int iter = 1; iter <= maxIter && diff > tol; ++iter) {
                System.arraycopy(w, 0, wold, 0, n);

                // wX = X' * w
                X.tv(wVec, wXVec);

                // First derivative of the contrast at every projection, and
                // the sum of its second derivatives.
                double g2 = 0.0;
                for (int j = 0; j < m; ++j) {
                    gwX[j] = contrast.g(wX[j]);
                    g2 += contrast.g2(wX[j]);
                }

                // Must be computed from the current w before it is overwritten.
                for (int j = 0; j < n; ++j) {
                    g2w[j] = w[j] * g2;
                }

                // w = (X * g(X'w) - sum(g'(X'w)) * w) / m
                X.mv(gwXVec, wVec);
                for (int j = 0; j < n; ++j) {
                    w[j] = (w[j] - g2w[j]) / m;
                }

                // Remove the projections onto the components found so far.
                for (int k = 0; k < i; ++k) {
                    double[] wk = W[k];
                    double wkw = MathEx.dot(wk, w);
                    for (int j = 0; j < n; ++j) {
                        w[j] -= wkw * wk[j];
                    }
                }

                MathEx.unitize2(w);

                // A component is determined only up to its sign, so measure
                // the change against both w_old and -w_old and keep the smaller.
                for (int j = 0; j < n; ++j) {
                    wdif[j] = w[j] - wold[j];
                }
                double n1 = MathEx.norm(wdif);

                for (int j = 0; j < n; ++j) {
                    wdif[j] = w[j] + wold[j];
                }
                double n2 = MathEx.norm(wdif);

                diff = Math.min(n1, n2);
            }

            if (diff > tol) {
                logger.warn("Component {} did not converge in {} iterations.", i, maxIter);
            }
        }

        return new ICA(W);
    }

    /**
     * Centers and whitens the data.
     *
     * <p>The input is transposed to an {@code n x m} matrix
     * ({@code n = data[0].length}, {@code m = data.length}), every column has
     * the mean of the corresponding input row subtracted, and the result is
     * projected onto the eigenvectors of {@code X'X} with each projected
     * column scaled by the inverse square root of its eigenvalue.
     *
     * @param data the raw data.
     * @return the whitened {@code n x m} matrix.
     * @throws IllegalArgumentException if an eigenvalue is below 1E-8.
     */
    private static DenseMatrix whiten(double[][] data) {
        double[] mean = MathEx.rowMeans(data);
        DenseMatrix X = DenseMatrix.of(data).transpose();
        int n = X.nrow();
        int m = X.ncol();

        // Column j of X is row j of the input; remove its mean.
        for (int j = 0; j < m; ++j) {
            double mu = mean[j];
            for (int i = 0; i < n; ++i) {
                X.sub(i, j, mu);
            }
        }

        DenseMatrix XtX = X.ata();
        EVD eigen = XtX.eigen();
        DenseMatrix E = eigen.Vr();
        DenseMatrix Y = X.mm(E);

        // Turn the eigenvalues into per-column scaling factors.
        double[] d = eigen.wr().toArray(new double[0]);
        for (int i = 0; i < d.length; ++i) {
            if (d[i] < 1.0E-8) {
                throw new IllegalArgumentException(String.format("Covariance matrix (column %d) is close to singular.", i));
            }
            d[i] = 1.0 / Math.sqrt(d[i]);
        }

        for (int j = 0; j < m; ++j) {
            for (int i = 0; i < n; ++i) {
                Y.mul(i, j, d[j]);
            }
        }

        return Y;
    }

    /**
     * Hyperparameters of {@link ICA#fit(double[][], int, Options)}.
     *
     * @param contrast the contrast function whose first ({@code g}) and
     *                 second ({@code g2}) derivatives drive the update.
     * @param maxIter  the maximum number of iterations per component.
     * @param tol      the convergence tolerance on the change of a component.
     */
    public record Options(DifferentiableFunction contrast, int maxIter, double tol) {
        /**
         * Validates the numeric parameters.
         *
         * @throws IllegalArgumentException if {@code tol <= 0} or
         *         {@code maxIter <= 0}.
         */
        public Options {
            if (tol <= 0.0) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
            if (maxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }
        }

        /**
         * Creates options with the default tolerance of 1E-4.
         *
         * @param contrast the contrast function.
         * @param maxIter  the maximum number of iterations per component.
         */
        public Options(DifferentiableFunction contrast, int maxIter) {
            this(contrast, maxIter, 1.0E-4);
        }

        /**
         * Creates options from the name of a built-in contrast function,
         * with the default tolerance of 1E-4.
         *
         * @param contrast one of {@code "LogCosh"}, {@code "Gaussian"} or
         *                 {@code "Kurtosis"}.
         * @param maxIter  the maximum number of iterations per component.
         * @throws IllegalArgumentException if the name is not recognized.
         */
        public Options(String contrast, int maxIter) {
            this(switch (contrast) {
                case "LogCosh" -> new LogCosh();
                case "Gaussian" -> new Exp();
                case "Kurtosis" -> new Kurtosis();
                default -> throw new IllegalArgumentException("Unsupported contrast function: " + contrast);
            }, maxIter);
        }

        /**
         * Serializes the options to properties under the keys
         * {@code smile.ica.iterations}, {@code smile.ica.tolerance} and
         * {@code smile.ica.contrast}.
         *
         * <p>Built-in contrast functions are stored by short name. Any other
         * implementation is stored by its fully qualified class name, which is
         * what {@link #of(Properties)} expects when it instantiates the
         * function reflectively.
         *
         * @return the properties.
         * @throws NullPointerException if the contrast function is null.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.ica.iterations", Integer.toString(maxIter));
            props.setProperty("smile.ica.tolerance", Double.toString(tol));

            String name = switch (contrast) {
                case LogCosh cosh -> "LogCosh";
                case Exp exp -> "Gaussian";
                case Kurtosis kurtosis -> "Kurtosis";
                // Record the class of the contrast function, not of Options:
                // of(Properties) loads this name and calls its no-arg constructor.
                default -> contrast.getClass().getName();
            };
            props.setProperty("smile.ica.contrast", name);
            return props;
        }

        /**
         * Restores options from properties. Missing keys fall back to
         * {@code "LogCosh"}, 100 iterations and a tolerance of 1E-4.
         *
         * <p>A contrast name other than the three built-in ones is treated as
         * a class name and instantiated through its no-argument constructor.
         * NOTE(review): this loads an arbitrary class named by the property
         * value, so the properties should come from a trusted source.
         *
         * @param props the properties.
         * @return the options.
         * @throws ReflectiveOperationException if a custom contrast class
         *         cannot be loaded or instantiated.
         * @throws ClassCastException if a custom contrast class is not a
         *         {@code DifferentiableFunction}.
         * @throws NumberFormatException if the iteration count or tolerance
         *         cannot be parsed.
         * @throws IllegalArgumentException if the parsed values are invalid.
         */
        public static Options of(Properties props) throws ReflectiveOperationException {
            String name = props.getProperty("smile.ica.contrast", "LogCosh");
            DifferentiableFunction contrast = switch (name) {
                case "LogCosh" -> new LogCosh();
                case "Gaussian" -> new Exp();
                case "Kurtosis" -> new Kurtosis();
                default -> {
                    Class<?> clazz = Class.forName(name);
                    Constructor<?> constructor = clazz.getDeclaredConstructor();
                    yield (DifferentiableFunction) constructor.newInstance();
                }
            };
            int maxIter = Integer.parseInt(props.getProperty("smile.ica.iterations", "100"));
            double tol = Double.parseDouble(props.getProperty("smile.ica.tolerance", "1E-4"));
            return new Options(contrast, maxIter, tol);
        }
    }
}

