/*
 * Decompiled with CFR 0.152.
 */
package smile.manifold;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.stat.distribution.GaussianDistribution;
import smile.util.AlgoStatus;
import smile.util.IterativeAlgorithmController;

/**
 * t-distributed stochastic neighbor embedding (t-SNE) of a data set into a
 * low-dimensional space.
 *
 * @param cost the Kullback-Leibler divergence KL(P||Q), in bits, between the
 *             input affinities and the embedding affinities, as of the last
 *             iteration at which it was evaluated.
 * @param coordinates the embedding, one row per input sample, with every
 *                    column centered at zero.
 */
public record TSNE(double cost, double[][] coordinates) implements Serializable {
    private static final long serialVersionUID = 3L;
    private static final Logger logger = LoggerFactory.getLogger(TSNE.class);

    /** Maximum number of bisection steps used to calibrate each row of P. */
    private static final int MAX_PERPLEXITY_ITER = 50;

    /**
     * Fits t-SNE with default options: a 2-dimensional embedding, perplexity 20,
     * learning rate 200, early exaggeration 12 and 1000 iterations.
     *
     * @param X the input data (one row per sample), or a square matrix of
     *          pairwise squared distances.
     * @return the embedding.
     */
    public static TSNE fit(double[][] X) {
        return fit(X, new Options(2, 20.0, 200.0, 12.0, 1000));
    }

    /**
     * Fits t-SNE.
     *
     * <p>If {@code X} is square (as many rows as columns) it is interpreted as a
     * precomputed matrix of pairwise squared distances; otherwise the squared
     * Euclidean distances between its rows are computed. Note that this means a
     * data set with exactly as many features as samples is also treated as a
     * distance matrix.
     *
     * @param X the input data (one row per sample), or a square matrix of
     *          pairwise squared distances.
     * @param options the hyperparameters.
     * @return the embedding.
     * @throws IllegalArgumentException if {@code X} is empty.
     */
    public static TSNE fit(double[][] X, Options options) {
        if (X.length == 0) {
            throw new IllegalArgumentException("Empty input data");
        }

        int n = X.length;
        int d = options.d();
        double eta = options.eta();
        double minGain = options.minGain();

        double[][] D = squaredDistances(X);
        double[][] coordinates = randomCoordinates(n, d);
        double[][] gains = new double[n][d];
        for (double[] g : gains) {
            Arrays.fill(g, 1.0);
        }

        // Input affinities: per-row calibrated conditional probabilities, then
        // symmetrized and multiplied by the early exaggeration factor.
        double[][] P = expd(D, options.perplexity(), 1E-3);
        symmetrize(P, options.earlyExaggeration());

        double[][] Q = new double[n][n];
        double[][] dY = new double[n][d]; // previous update of each coordinate
        double[][] dC = new double[n][d]; // gradient of the cost

        double cost = Double.MAX_VALUE;
        double bestCost = Double.MAX_VALUE;
        int bestIter = 0;
        double momentum = options.momentum();

        for (int iter = 1; iter <= options.maxIter(); iter++) {
            double Qsum = computeQ(coordinates, Q);
            IntStream.range(0, n).parallel()
                    .forEach(i -> sne(i, coordinates, P, Q, gains, dY[i], dC[i], Qsum, minGain));

            // Momentum gradient step with per-coordinate gains. The reduction
            // returns the largest absolute gradient component, i.e. the
            // infinity norm of the gradient used by the tolerance check below.
            double mu = momentum;
            double gradNorm = IntStream.range(0, n).parallel().mapToDouble(i -> {
                double[] Yi = coordinates[i];
                double[] dYi = dY[i];
                double[] dCi = dC[i];
                double[] g = gains[i];
                double norm = 0.0;
                for (int k = 0; k < d; k++) {
                    dYi[k] = mu * dYi[k] - eta * g[k] * dCi[k];
                    Yi[k] += dYi[k];
                    norm = Math.max(norm, Math.abs(dCi[k]));
                }
                return norm;
            }).max().orElse(0.0);

            if (iter == options.momentumSwitchIter()) {
                // End of the early exaggeration phase: switch to the final
                // momentum and undo the factor applied by symmetrize().
                momentum = options.finalMomentum();
                double exaggeration = options.earlyExaggeration();
                for (double[] Pi : P) {
                    for (int j = 0; j < n; j++) {
                        Pi[j] /= exaggeration;
                    }
                }
            }

            // The cost is only evaluated every 10 iterations and at the last one.
            if (iter % 10 != 0 && iter != options.maxIter()) {
                continue;
            }

            cost = computeCost(P, Q, Qsum);
            logger.info("Iteration {}: error = {}", iter, cost);
            if (cost < bestCost) {
                bestCost = cost;
                bestIter = iter;
            }

            // Stopping criteria only apply once early exaggeration is over.
            if (iter > options.momentumSwitchIter()) {
                if (iter - bestIter > options.maxIterWithoutProgress()) {
                    logger.info("Iteration {}: did not make any progress in last {} episodes. Finished", iter, options.maxIterWithoutProgress());
                    break;
                }
                if (gradNorm < options.tol()) {
                    logger.info("Iteration {}: gradient norm = {}. Finished", iter, gradNorm);
                    break;
                }
            }

            IterativeAlgorithmController<AlgoStatus> controller = options.controller();
            if (controller != null) {
                controller.submit(new AlgoStatus(iter, cost));
                if (controller.isInterrupted()) {
                    break;
                }
            }
        }

        center(coordinates);
        return new TSNE(cost, coordinates);
    }

    /**
     * Returns {@code X} itself when it is square (assumed to already hold
     * pairwise squared distances), otherwise the matrix of pairwise squared
     * Euclidean distances between the rows of {@code X}.
     */
    private static double[][] squaredDistances(double[][] X) {
        int n = X.length;
        if (n == X[0].length) {
            return X;
        }
        double[][] D = new double[n][n];
        MathEx.pdist(X, D, MathEx::squaredDistance);
        return D;
    }

    /** Draws an n x d initial embedding from N(0, 1E-4), row by row. */
    private static double[][] randomCoordinates(int n, int d) {
        GaussianDistribution gaussian = new GaussianDistribution(0.0, 1E-4);
        double[][] Y = new double[n][d];
        for (double[] Yi : Y) {
            for (int k = 0; k < d; k++) {
                Yi[k] = gaussian.rand();
            }
        }
        return Y;
    }

    /**
     * Symmetrizes the row-normalized conditional probabilities in place as
     * {@code p_ij = exaggeration * (p_j|i + p_i|j) / (2n)}, clamping NaN or
     * values below 1E-16 to 1E-16. The diagonal is left untouched.
     */
    private static void symmetrize(double[][] P, double exaggeration) {
        int n = P.length;
        // expd() normalizes each row to sum to 1, so the symmetrized matrix
        // (P + P^T) sums to 2n.
        double Psum = 2.0 * n;
        for (int i = 0; i < n; i++) {
            double[] Pi = P[i];
            for (int j = 0; j < i; j++) {
                double p = exaggeration * (Pi[j] + P[j][i]) / Psum;
                if (Double.isNaN(p) || p < 1E-16) {
                    p = 1E-16;
                }
                Pi[j] = p;
                P[j][i] = p;
            }
        }
    }

    /** Subtracts the column means so that the embedding is centered at the origin. */
    private static void center(double[][] Y) {
        double[] colMeans = MathEx.colMeans(Y);
        IntStream.range(0, Y.length).parallel().forEach(i -> {
            double[] Yi = Y[i];
            for (int j = 0; j < Yi.length; j++) {
                Yi[j] -= colMeans[j];
            }
        });
    }

    /**
     * Computes the gradient of the cost with respect to the i-th embedded point
     * into {@code dC} and updates that point's gains.
     *
     * @param i the index of the point.
     * @param Y the current embedding.
     * @param P the symmetric input affinities.
     * @param Q the unnormalized embedding affinities {@code 1 / (1 + |y_i - y_j|^2)}.
     * @param gains the per-coordinate gains; row i is updated in place.
     * @param dY the previous update of point i (read only).
     * @param dC output: the gradient for point i.
     * @param Qsum the normalizer of Q.
     * @param minGain the lower bound of the gains.
     */
    private static void sne(int i, double[][] Y, double[][] P, double[][] Q, double[][] gains, double[] dY, double[] dC, double Qsum, double minGain) {
        int n = Y.length;
        int d = Y[0].length;
        double[] Yi = Y[i];
        double[] Pi = P[i];
        double[] Qi = Q[i];
        double[] g = gains[i];

        Arrays.fill(dC, 0.0);
        for (int j = 0; j < n; j++) {
            if (i == j) {
                continue;
            }
            double[] Yj = Y[j];
            double q = Qi[j];
            double z = (Pi[j] - q / Qsum) * q;
            for (int k = 0; k < d; k++) {
                dC[k] += 4.0 * (Yi[k] - Yj[k]) * z;
            }
        }

        // Increase the gain when the gradient and the previous update disagree
        // in sign, decay it otherwise; never let it drop below minGain.
        for (int k = 0; k < d; k++) {
            g[k] = Math.signum(dC[k]) != Math.signum(dY[k]) ? g[k] + 0.2 : g[k] * 0.8;
            if (g[k] < minGain) {
                g[k] = minGain;
            }
        }
    }

    /**
     * Computes, for every row i, the conditional probabilities
     * {@code p_j|i ∝ exp(-beta_i * D[i][j])} (with {@code p_i|i = 0}), where
     * {@code beta_i} is found by bisection so that the entropy of row i matches
     * {@code ln(perplexity)} within {@code tol} nats. Every row is normalized
     * to sum to 1, also when the bisection does not converge.
     *
     * @param D the pairwise squared distances.
     * @param perplexity the target perplexity.
     * @param tol the tolerance on the entropy, in nats.
     * @return the row-normalized conditional probabilities.
     */
    private static double[][] expd(double[][] D, double perplexity, double tol) {
        int n = D.length;
        double[][] P = new double[n][n];
        double[] DiSum = MathEx.rowSums(D);
        // Target entropy in nats; the entropy below is also in nats.
        double logU = Math.log(perplexity);

        IntStream.range(0, n).parallel().forEach(i -> {
            double[] Pi = P[i];
            double[] Di = D[i];

            // Use sqrt(1 / average distance) as the initial precision; fall
            // back to 1 when all distances are zero.
            double beta = DiSum[i] > 0.0 ? Math.sqrt((n - 1) / DiSum[i]) : 1.0;
            double betamin = 0.0;
            double betamax = Double.POSITIVE_INFINITY;
            logger.debug("initial beta[{}] = {}", i, beta);

            double Pisum = 0.0;
            for (int iter = 0; iter < MAX_PERPLEXITY_ITER; iter++) {
                Pisum = 0.0;
                double H = 0.0;
                for (int j = 0; j < n; j++) {
                    if (j == i) {
                        continue;
                    }
                    double dist = beta * Di[j];
                    double p = Math.exp(-dist);
                    Pi[j] = p;
                    Pisum += p;
                    H += p * dist;
                }

                // Shannon entropy (nats) of the normalized row.
                H = Math.log(Pisum) + H / Pisum;
                double Hdiff = H - logU;
                logger.debug("Hdiff = {}, beta[{}] = {}, H = {}, logU = {}", Hdiff, i, beta, H, logU);

                if (Math.abs(Hdiff) <= tol) {
                    break;
                }

                // Entropy decreases as beta grows: bisect on beta.
                if (Hdiff > 0.0) {
                    betamin = beta;
                    beta = Double.isInfinite(betamax) ? beta * 2.0 : (beta + betamax) / 2.0;
                } else {
                    betamax = beta;
                    beta = (beta + betamin) / 2.0;
                }
            }

            // Normalize the last evaluated row, converged or not.
            for (int j = 0; j < n; j++) {
                if (j != i) {
                    Pi[j] /= Pisum;
                }
            }
        });
        return P;
    }

    /**
     * Fills {@code Q} with the Student-t kernel {@code 1 / (1 + |y_i - y_j|^2)}
     * for {@code i != j} (the diagonal is set to 0) and returns the sum of all
     * off-diagonal entries, i.e. the normalizer of the embedding affinities.
     */
    private static double computeQ(double[][] Y, double[][] Q) {
        int n = Y.length;
        // Per-row sums are collected first and then added with MathEx.sum.
        double[] rowSum = IntStream.range(0, n).parallel().mapToDouble(i -> {
            double[] Yi = Y[i];
            double[] Qi = Q[i];
            double sum = 0.0;
            for (int j = 0; j < n; j++) {
                if (j == i) {
                    Qi[j] = 0.0;
                    continue;
                }
                double q = 1.0 / (1.0 + MathEx.squaredDistance(Yi, Y[j]));
                Qi[j] = q;
                sum += q;
            }
            return sum;
        }).toArray();
        return MathEx.sum(rowSum);
    }

    /**
     * Computes KL(P||Q) in bits over the off-diagonal entries, using the lower
     * triangle of the symmetric matrices and doubling the result. Normalized q
     * values that are NaN or below 1E-16 are clamped to 1E-16.
     */
    private static double computeCost(double[][] P, double[][] Q, double Qsum) {
        return 2.0 * IntStream.range(0, Q.length).parallel().mapToDouble(i -> {
            double[] Pi = P[i];
            double[] Qi = Q[i];
            double C = 0.0;
            for (int j = 0; j < i; j++) {
                double p = Pi[j];
                double q = Qi[j] / Qsum;
                if (Double.isNaN(q) || q < 1E-16) {
                    q = 1E-16;
                }
                C += p * MathEx.log2(p / q);
            }
            return C;
        }).sum();
    }

    /**
     * t-SNE hyperparameters.
     *
     * @param d the dimension of the embedding space (at least 2).
     * @param perplexity the target perplexity of the conditional distributions (at least 2).
     * @param eta the learning rate (positive).
     * @param earlyExaggeration the factor applied to P until {@code momentumSwitchIter} (positive).
     * @param maxIter the maximum number of iterations (at least 250).
     * @param maxIterWithoutProgress the number of iterations without a lower cost after which
     *                               optimization stops; checked only after
     *                               {@code momentumSwitchIter}. Must be in [50, maxIter].
     * @param tol the gradient-norm threshold below which optimization stops (positive).
     * @param momentum the momentum used until {@code momentumSwitchIter} (positive).
     * @param finalMomentum the momentum used afterwards (positive).
     * @param momentumSwitchIter the iteration at which the momentum switches and early
     *                           exaggeration ends; must be in [1, maxIter].
     * @param minGain the lower bound of the per-coordinate gains (positive).
     * @param controller an optional controller that receives progress updates and may
     *                   interrupt the optimization; may be {@code null}.
     */
    public record Options(int d, double perplexity, double eta, double earlyExaggeration, int maxIter, int maxIterWithoutProgress, double tol, double momentum, double finalMomentum, int momentumSwitchIter, double minGain, IterativeAlgorithmController<AlgoStatus> controller) {
        /** Validates the hyperparameters. */
        public Options {
            if (d < 2) {
                throw new IllegalArgumentException("Invalid dimension of feature space: " + d);
            }
            if (perplexity < 2.0) {
                throw new IllegalArgumentException("Invalid perplexity: " + perplexity);
            }
            if (eta <= 0.0) {
                throw new IllegalArgumentException("Invalid learning rate: " + eta);
            }
            if (earlyExaggeration <= 0.0) {
                throw new IllegalArgumentException("Invalid early exaggeration: " + earlyExaggeration);
            }
            if (maxIter < 250) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }
            if (maxIterWithoutProgress < 50 || maxIterWithoutProgress > maxIter) {
                throw new IllegalArgumentException("Invalid maximum number of iterations without progress: " + maxIterWithoutProgress);
            }
            if (tol <= 0.0) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
            if (momentum <= 0.0) {
                throw new IllegalArgumentException("Invalid momentum: " + momentum);
            }
            if (finalMomentum <= 0.0) {
                throw new IllegalArgumentException("Invalid final momentum: " + finalMomentum);
            }
            // Equality with maxIter is allowed so that the default switch
            // iteration (250) works with the minimum maxIter (250).
            if (momentumSwitchIter <= 0 || momentumSwitchIter > maxIter) {
                throw new IllegalArgumentException("Invalid momentum switch iteration: " + momentumSwitchIter);
            }
            if (minGain <= 0.0) {
                throw new IllegalArgumentException("Invalid minimum gain: " + minGain);
            }
        }

        /**
         * Creates options with maxIterWithoutProgress = 50, tol = 1E-7,
         * momentum = 0.5, finalMomentum = 0.8, momentumSwitchIter = 250,
         * minGain = 0.01 and no controller.
         */
        public Options(int d, double perplexity, double eta, double earlyExaggeration, int maxIter) {
            this(d, perplexity, eta, earlyExaggeration, maxIter, 50, 1E-7, 0.5, 0.8, 250, 0.01, null);
        }

        /**
         * Returns the hyperparameters as properties. The controller is not included.
         *
         * @return the properties.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.t_sne.d", Integer.toString(d));
            props.setProperty("smile.t_sne.perplexity", Double.toString(perplexity));
            props.setProperty("smile.t_sne.eta", Double.toString(eta));
            props.setProperty("smile.t_sne.early_exaggeration", Double.toString(earlyExaggeration));
            props.setProperty("smile.t_sne.iterations", Integer.toString(maxIter));
            props.setProperty("smile.t_sne.max_iterations_without_progress", Integer.toString(maxIterWithoutProgress));
            props.setProperty("smile.t_sne.tolerance", Double.toString(tol));
            props.setProperty("smile.t_sne.momentum", Double.toString(momentum));
            props.setProperty("smile.t_sne.final_momentum", Double.toString(finalMomentum));
            props.setProperty("smile.t_sne.momentum_switch", Integer.toString(momentumSwitchIter));
            props.setProperty("smile.t_sne.min_gain", Double.toString(minGain));
            return props;
        }

        /**
         * Returns the options from properties. Missing keys fall back to the
         * defaults used by {@link #Options(int, double, double, double, int)}
         * and {@link TSNE#fit(double[][])}. No controller is set.
         *
         * @param props the hyperparameters.
         * @return the options.
         */
        public static Options of(Properties props) {
            int d = Integer.parseInt(props.getProperty("smile.t_sne.d", "2"));
            double perplexity = Double.parseDouble(props.getProperty("smile.t_sne.perplexity", "20"));
            double eta = Double.parseDouble(props.getProperty("smile.t_sne.eta", "200"));
            double earlyExaggeration = Double.parseDouble(props.getProperty("smile.t_sne.early_exaggeration", "12"));
            int maxIter = Integer.parseInt(props.getProperty("smile.t_sne.iterations", "1000"));
            int maxIterWithoutProgress = Integer.parseInt(props.getProperty("smile.t_sne.max_iterations_without_progress", "50"));
            double tol = Double.parseDouble(props.getProperty("smile.t_sne.tolerance", "1E-7"));
            double momentum = Double.parseDouble(props.getProperty("smile.t_sne.momentum", "0.5"));
            double finalMomentum = Double.parseDouble(props.getProperty("smile.t_sne.final_momentum", "0.8"));
            int momentumSwitchIter = Integer.parseInt(props.getProperty("smile.t_sne.momentum_switch", "250"));
            double minGain = Double.parseDouble(props.getProperty("smile.t_sne.min_gain", "0.01"));
            return new Options(d, perplexity, eta, earlyExaggeration, maxIter, maxIterWithoutProgress, tol, momentum, finalMomentum, momentumSwitchIter, minGain, null);
        }
    }
}

