/*
 * Decompiled with CFR 0.152.
 */
package smile.manifold;

import java.util.Arrays;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.stat.distribution.GaussianDistribution;

/**
 * t-distributed stochastic neighbor embedding (t-SNE).
 * <p>
 * Embeds n input points into a d-dimensional space. It fits Student-t affinities
 * {@code Q_ij ∝ 1 / (1 + ||y_i - y_j||^2)} in the embedding to Gaussian-kernel affinities
 * {@code P} computed from (squared) input distances. The optimiser is gradient descent with
 * momentum and per-coordinate adaptive gains. For the first 250 global iterations, {@code P}
 * is multiplied by an early-exaggeration factor of 12; at that point the exaggeration is removed
 * and the momentum switches from 0.5 to 0.8.
 * <p>
 * The class performs no synchronization. {@link #update(int)} mutates {@link #coordinates} and
 * internal state in place, so concurrent calls on the same instance are unsafe.
 */
public class TSNE {
    private static final Logger logger = LoggerFactory.getLogger(TSNE.class);

    /** Factor by which the input affinities are inflated during early exaggeration. */
    private static final double EXAGGERATION = 12.0;
    /** Floor applied to symmetrized input affinities and to normalized Q in the cost. */
    private static final double MIN_PROB = 1.0E-16;
    /** Momentum used once early exaggeration has ended. */
    private static final double FINAL_MOMENTUM = 0.8;
    /** Global iteration at which early exaggeration ends and momentum switches. */
    private static final int MOMENTUM_SWITCH_ITER = 250;
    /** Lower bound for the adaptive per-coordinate gains. */
    private static final double MIN_GAIN = 0.01;
    /** Maximum number of bisection steps per point when calibrating the Gaussian bandwidth. */
    private static final int MAX_BETA_SEARCH = 50;

    /** The embedding: row i is the low-dimensional position of input point i. */
    public final double[][] coordinates;
    /** Learning rate. */
    private final double eta;
    /** Current momentum; raised to FINAL_MOMENTUM at MOMENTUM_SWITCH_ITER. */
    private double momentum = 0.5;
    /** 1-based global iteration counter, carried across calls to {@link #update(int)}. */
    private int totalIter = 1;
    /** Last update step (velocity) of every embedding coordinate. */
    private final double[][] dY;
    /** Adaptive per-coordinate gains. */
    private final double[][] gains;
    /** Symmetrized input affinities (exaggerated until MOMENTUM_SWITCH_ITER). */
    private final double[][] P;
    /** Unnormalized embedding affinities 1 / (1 + squared distance). */
    private final double[][] Q;
    /** Sum of all off-diagonal entries of Q. */
    private double Qsum;

    /**
     * Runs t-SNE with perplexity 20, learning rate 200 and 1000 iterations.
     *
     * @param X the input data, or a square precomputed distance matrix (see the 5-arg constructor).
     * @param d the dimension of the embedding space.
     */
    public TSNE(double[][] X, int d) {
        this(X, d, 20.0, 200.0, 1000);
    }

    /**
     * Runs t-SNE.
     *
     * @param X the input data, one row per point. If X is square ({@code X.length == X[0].length})
     *          it is used directly as the distance matrix instead. NOTE(review): it is presumably
     *          expected to hold squared distances, like the {@code pdist(..., true, ...)} call
     *          below produces; confirm.
     * @param d the dimension of the embedding space.
     * @param perplexity the target perplexity of each point's Gaussian neighbourhood.
     * @param eta the learning rate.
     * @param iterations the number of gradient-descent iterations to run.
     * @throws IllegalArgumentException if X is empty, d < 1 or perplexity <= 0.
     */
    public TSNE(double[][] X, int d, double perplexity, double eta, int iterations) {
        if (X.length == 0) {
            throw new IllegalArgumentException("Empty input data");
        }
        if (d < 1) {
            throw new IllegalArgumentException("Invalid dimension of embedding space: " + d);
        }
        if (!(perplexity > 0.0)) {
            throw new IllegalArgumentException("Invalid perplexity: " + perplexity);
        }

        this.eta = eta;
        int n = X.length;

        // Kept local: the n x n distance matrix is only needed to build P.
        double[][] D;
        if (n == X[0].length) {
            D = X;
        } else {
            D = new double[n][n];
            MathEx.pdist(X, D, true, false);
        }

        coordinates = new double[n][d];
        dY = new double[n][d];
        gains = new double[n][d];
        double[][] Y = coordinates;
        GaussianDistribution gaussian = new GaussianDistribution(0.0, 1.0E-4);
        for (int i = 0; i < n; i++) {
            Arrays.fill(gains[i], 1.0);
            double[] Yi = Y[i];
            for (int k = 0; k < d; k++) {
                Yi[k] = gaussian.rand();
            }
        }

        P = expd(D, perplexity, 0.001);
        Q = new double[n][n];

        // Every row of P sums to 1 (normalized in expd), so P + P^T sums to 2n.
        double Psum = 2.0 * n;
        for (int i = 0; i < n; i++) {
            double[] Pi = P[i];
            for (int j = 0; j < i; j++) {
                double p = EXAGGERATION * (Pi[j] + P[j][i]) / Psum;
                if (Double.isNaN(p) || p < MIN_PROB) {
                    p = MIN_PROB;
                }
                Pi[j] = p;
                P[j][i] = p;
            }
        }

        // Private helper so that a subclass override of update() is not invoked mid-construction.
        optimize(iterations);
    }

    /**
     * Continues the optimization for more iterations from the current state, then mean-centers
     * {@link #coordinates}.
     *
     * @param iterations the number of additional iterations.
     */
    public void update(int iterations) {
        optimize(iterations);
    }

    /** Gradient-descent loop shared by the constructor and {@link #update(int)}. */
    private void optimize(int iterations) {
        double[][] Y = coordinates;
        int n = Y.length;
        int d = Y[0].length;
        // Per-point gradient buffer, reused across iterations.
        double[][] dC = new double[n][d];

        for (int iter = 1; iter <= iterations; iter++, totalIter++) {
            computeQ();

            // Two phases avoid a data race: all gradients are computed while Y is only read,
            // then each task writes only its own row of Y.
            IntStream.range(0, n).parallel().forEach(i -> gradient(i, dC[i]));
            IntStream.range(0, n).parallel().forEach(i -> step(i, dC[i]));

            if (totalIter == MOMENTUM_SWITCH_ITER) {
                // End of early exaggeration.
                momentum = FINAL_MOMENTUM;
                for (double[] Pi : P) {
                    for (int j = 0; j < n; j++) {
                        Pi[j] /= EXAGGERATION;
                    }
                }
            }

            if (iter % 50 == 0) {
                logger.info("Error after {} iterations: {}", totalIter, 2.0 * cost());
            }
        }

        double[] colMeans = MathEx.colMeans(Y);
        for (double[] Yi : Y) {
            for (int k = 0; k < d; k++) {
                Yi[k] -= colMeans[k];
            }
        }
    }

    /** Fills Q with 1 / (1 + squared embedding distance) and stores its off-diagonal sum in Qsum. */
    private void computeQ() {
        MathEx.pdist(coordinates, Q, true, false);
        int n = Q.length;
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            double[] Qi = Q[i];
            for (int j = 0; j < i; j++) {
                double q = 1.0 / (1.0 + Qi[j]);
                Qi[j] = q;
                Q[j][i] = q;
                sum += q;
            }
        }
        // Only the lower triangle was summed; Q is symmetric.
        Qsum = 2.0 * sum;
    }

    /** Writes the cost gradient with respect to point i into dCi; reads Y but does not modify it. */
    private void gradient(int i, double[] dCi) {
        double[][] Y = coordinates;
        int n = Y.length;
        int d = dCi.length;
        double[] Yi = Y[i];
        double[] Pi = P[i];
        double[] Qi = Q[i];

        Arrays.fill(dCi, 0.0);
        for (int j = 0; j < n; j++) {
            if (j == i) {
                continue;
            }
            double[] Yj = Y[j];
            double q = Qi[j];
            double z = (Pi[j] - q / Qsum) * q;
            for (int k = 0; k < d; k++) {
                dCi[k] += 4.0 * (Yi[k] - Yj[k]) * z;
            }
        }
    }

    /** Updates the gains and velocity of point i from its gradient dCi, then moves the point. */
    private void step(int i, double[] dCi) {
        double[] Yi = coordinates[i];
        double[] dYi = dY[i];
        double[] g = gains[i];
        for (int k = 0; k < dCi.length; k++) {
            // Grow the gain when the gradient sign differs from the last step's sign, else shrink it.
            g[k] = Math.signum(dCi[k]) != Math.signum(dYi[k]) ? g[k] + 0.2 : g[k] * 0.8;
            if (g[k] < MIN_GAIN) {
                g[k] = MIN_GAIN;
            }
            // Compute the new step first, then apply it (no one-iteration lag).
            dYi[k] = momentum * dYi[k] - eta * g[k] * dCi[k];
            Yi[k] += dYi[k];
        }
    }

    /** Returns half the KL divergence KL(P || Q/Qsum) in bits, summed over the lower triangle. */
    private double cost() {
        int n = P.length;
        double C = 0.0;
        for (int i = 0; i < n; i++) {
            double[] Pi = P[i];
            double[] Qi = Q[i];
            for (int j = 0; j < i; j++) {
                double p = Pi[j];
                double q = Qi[j] / Qsum;
                if (Double.isNaN(q) || q < MIN_PROB) {
                    q = MIN_PROB;
                }
                C += p * MathEx.log2(p / q);
            }
        }
        return C;
    }

    /**
     * Computes row-normalized Gaussian affinities {@code P[i][j] ∝ exp(-beta_i * D[i][j])}.
     * Each beta_i is found by bisection so that row i's entropy matches log(perplexity).
     *
     * @param D the (squared) distance matrix.
     * @param perplexity the target perplexity.
     * @param tol the tolerance on the entropy difference, in nats.
     * @return the affinity matrix; each row sums to 1 and the diagonal is 0.
     */
    private double[][] expd(double[][] D, double perplexity, double tol) {
        int n = D.length;
        double[][] P = new double[n][n];
        if (n < 2) {
            // A single point has no neighbours.
            return P;
        }
        double[] DiSum = MathEx.rowSums(D);
        // Target entropy in nats. The entropy below is also in nats (natural exp/log).
        double logU = Math.log(perplexity);

        IntStream.range(0, n).parallel().forEach(i -> {
            double[] Pi = P[i];
            double[] Di = D[i];

            // Shift by the nearest-neighbour distance: the nearest neighbour then gets exp(0) = 1,
            // so the normalizer is >= 1 and cannot underflow. The shift cancels after normalization.
            double minD = Double.POSITIVE_INFINITY;
            for (int j = 0; j < n; j++) {
                if (j != i && Di[j] < minD) {
                    minD = Di[j];
                }
            }

            // Initialize beta from the average distance; fall back to 1 if the row sum is 0.
            double beta = DiSum[i] > 0.0 ? Math.sqrt((n - 1) / DiSum[i]) : 1.0;
            double betamin = 0.0;
            double betamax = Double.POSITIVE_INFINITY;
            logger.debug("initial beta[{}] = {}", i, beta);

            double Pisum = 0.0;
            for (int iter = 0; iter < MAX_BETA_SEARCH; iter++) {
                Pisum = 0.0;
                double H = 0.0;
                for (int j = 0; j < n; j++) {
                    if (j == i) {
                        Pi[j] = 0.0;
                        continue;
                    }
                    double dist = beta * (Di[j] - minD);
                    double p = Math.exp(-dist);
                    Pi[j] = p;
                    Pisum += p;
                    H += p * dist;
                }

                // Shannon entropy (nats) of the normalized row: ln(Z) + beta * <D - minD>.
                H = Math.log(Pisum) + H / Pisum;
                double Hdiff = H - logU;
                boolean converged = Math.abs(Hdiff) <= tol;
                if (!converged) {
                    if (Hdiff > 0.0) {
                        betamin = beta;
                        beta = Double.isInfinite(betamax) ? beta * 2.0 : (beta + betamax) / 2.0;
                    } else {
                        betamax = beta;
                        beta = (beta + betamin) / 2.0;
                    }
                }
                logger.debug("Hdiff = {}, beta[{}] = {}, H = {}, logU = {}", new Object[]{Hdiff, i, beta, H, logU});
                if (converged) {
                    break;
                }
            }

            // Always normalize, even if the bisection did not converge, so each row sums to 1.
            for (int j = 0; j < n; j++) {
                Pi[j] /= Pisum;
            }
        });
        return P;
    }
}

