/*
 * Decompiled with CFR 0.152.
 */
package smile.manifold;

import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.feature.extraction.PCA;
import smile.graph.AdjacencyList;
import smile.graph.NearestNeighborGraph;
import smile.math.LevenbergMarquardt;
import smile.math.MathEx;
import smile.math.distance.Metric;
import smile.stat.distribution.GaussianDistribution;
import smile.tensor.ARPACK;
import smile.tensor.DenseMatrix;
import smile.tensor.EVD;
import smile.tensor.Matrix;
import smile.tensor.SparseMatrix;
import smile.util.function.DifferentiableMultivariateFunction;

/**
 * Uniform Manifold Approximation and Projection (UMAP).
 * <p>
 * The pipeline is:
 * <ol>
 * <li>build a k-nearest-neighbor graph;</li>
 * <li>convert it into a symmetric fuzzy simplicial set (weighted graph);</li>
 * <li>initialize a low-dimensional embedding (spectral, PCA or random);</li>
 * <li>refine the embedding by stochastic gradient descent, with attractive
 * moves along sampled graph edges and repulsive moves against randomly
 * drawn negative samples.</li>
 * </ol>
 * This is a static utility class and cannot be instantiated.
 */
public class UMAP {
    private static final Logger logger = LoggerFactory.getLogger(UMAP.class);

    /**
     * Size threshold for "large" data sets. Above it, approximate nearest
     * neighbor descent replaces exact k-NN search, the connected-component
     * check is skipped, and the default number of epochs is lowered.
     */
    private static final int LARGE_DATA_SIZE = 10000;

    /** Convergence tolerance of the binary search in {@code smoothKnnDist}. */
    private static final double SMOOTH_K_TOLERANCE = 1.0E-5;

    /** Lower bound of sigma, as a fraction of the mean k-NN distance. */
    private static final double MIN_K_DIST_SCALE = 0.001;

    private UMAP() {
    }

    /**
     * Runs UMAP on real-valued vectors.
     * <p>
     * The k-NN graph comes from {@code NearestNeighborGraph.of(data, k)} for
     * at most {@code LARGE_DATA_SIZE} samples. Larger inputs use
     * {@code NearestNeighborGraph.descent(data, k)}.
     *
     * @param data the input samples, one row per sample.
     * @param options the hyperparameters.
     * @return the embedding coordinates, one row per sample.
     */
    public static double[][] fit(double[][] data, Options options) {
        NearestNeighborGraph nng = data.length <= LARGE_DATA_SIZE
                ? NearestNeighborGraph.of(data, options.k())
                : NearestNeighborGraph.descent(data, options.k());
        return fit(data, nng, options);
    }

    /**
     * Runs UMAP on arbitrary objects with a user-supplied distance.
     * <p>
     * The k-NN graph comes from exact search for at most
     * {@code LARGE_DATA_SIZE} samples, and from nearest neighbor descent
     * otherwise.
     *
     * @param data the input samples.
     * @param distance the distance metric between samples.
     * @param options the hyperparameters.
     * @param <T> the type of input samples.
     * @return the embedding coordinates, one row per sample.
     */
    public static <T> double[][] fit(T[] data, Metric<T> distance, Options options) {
        // No casts here: casting data to Object[] would pin the callee's type
        // variable to Object, which conflicts with Metric<T>.
        NearestNeighborGraph nng = data.length <= LARGE_DATA_SIZE
                ? NearestNeighborGraph.of(data, distance, options.k())
                : NearestNeighborGraph.descent(data, distance, options.k());
        return fit(data, nng, options);
    }

    /**
     * Runs UMAP on a precomputed k-nearest-neighbor graph.
     * <p>
     * The initialization depends on the graph and the data:
     * <ul>
     * <li>spectral, if the graph was checked and found connected (the check
     * runs only for at most {@code LARGE_DATA_SIZE} nodes);</li>
     * <li>otherwise PCA, if {@code data} is a {@code double[][]};</li>
     * <li>otherwise random.</li>
     * </ul>
     *
     * @param data the input samples. They are only used to choose the default
     *             epochs and for the PCA initialization.
     * @param nng the k-nearest-neighbor graph of {@code data}.
     * @param options the hyperparameters. An {@code epochs} value below 10 is
     *                replaced by 200 (large data) or 500.
     * @param <T> the type of input samples.
     * @return the embedding coordinates, one row per graph node.
     */
    public static <T> double[][] fit(T[] data, NearestNeighborGraph nng, Options options) {
        int d = options.d();
        int epochs = options.epochs();
        if (epochs < 10) {
            epochs = data.length > LARGE_DATA_SIZE ? 200 : 500;
            logger.info("Set epochs = {}", epochs);
        }

        SparseMatrix conorm = computeFuzzySimplicialSet(nng, options.localConnectivity());

        int n = nng.size();
        boolean connected = false;
        if (n <= LARGE_DATA_SIZE) {
            int[][] cc = nng.graph(false).bfcc();
            logger.info("The nearest neighbor graph has {} connected component(s).", cc.length);
            connected = cc.length == 1;
        }

        double[][] coordinates;
        if (connected) {
            logger.info("Spectral initialization will be attempted.");
            coordinates = spectralLayout(nng, d);
            noisyScale(coordinates, 10.0, 1.0E-4);
        } else if (data instanceof double[][]) {
            logger.info("PCA-based initialization will be attempted.");
            coordinates = pcaLayout((double[][]) data, d);
            noisyScale(coordinates, 10.0, 1.0E-4);
        } else {
            logger.info("Random initialization will be attempted.");
            coordinates = randomLayout(n, d);
        }
        normalize(coordinates, 10.0);
        logger.info("Finish embedding initialization");

        double[] curve = fitCurve(options.spread(), options.minDist());
        logger.info("Finish fitting the curve parameters: {}", Arrays.toString(curve));

        SparseMatrix epochsPerSample = computeEpochPerSample(conorm, epochs);
        logger.info("Start optimizing the layout");
        optimizeLayout(coordinates, curve, epochsPerSample, epochs,
                options.learningRate(), options.negativeSamples(), options.repulsionStrength());
        return coordinates;
    }

    /**
     * Fits the parameters a and b of the low-dimensional similarity curve.
     * <p>
     * The fitted curve is {@code 1 / (1 + a * x^b)}. The target is 1 for
     * {@code x < minDist} and {@code exp(-(x - minDist) / spread)} otherwise.
     * It is sampled at 300 evenly spaced points in {@code (0, 3 * spread]}.
     *
     * @param spread the effective scale of embedded points.
     * @param minDist the minimum distance between embedded points.
     * @return {@code {a, b}}, with b already halved for use on squared distances.
     */
    private static double[] fitCurve(double spread, double minDist) {
        int size = 300;
        double[] x = new double[size];
        double[] y = new double[size];
        double end = 3.0 * spread;
        double interval = end / size;
        for (int i = 0; i < size; i++) {
            x[i] = (i + 1) * interval;
            y[i] = x[i] < minDist ? 1.0 : Math.exp(-(x[i] - minDist) / spread);
        }

        double[] p = {0.5, 0.0};
        LevenbergMarquardt curveFit = LevenbergMarquardt.fit(new Curve(), x, y, p);
        double[] result = curveFit.parameters();
        // The curve is fitted against distance x, but optimizeLayout evaluates
        // it on squared distance: (x^2)^(b/2) = x^b, so halve b.
        result[1] /= 2.0;
        return result;
    }

    /**
     * Builds the symmetric fuzzy simplicial set from a k-NN graph.
     * <p>
     * Directed membership strengths are combined with the probabilistic
     * t-conorm {@code w = a + b - a * b}, where a is the weight u→v and b is
     * the weight v→u.
     *
     * @param nng the k-nearest-neighbor graph.
     * @param localConnectivity the number of nearest neighbors assumed locally connected.
     * @return the symmetric weight matrix.
     */
    private static SparseMatrix computeFuzzySimplicialSet(NearestNeighborGraph nng, double localConnectivity) {
        double[][] result = smoothKnnDist(nng.distances(), nng.k(), 64, localConnectivity, 1.0);
        double[] sigma = result[0];
        double[] rho = result[1];
        int n = nng.size();
        AdjacencyList strength = computeMembershipStrengths(nng, sigma, rho);

        AdjacencyList conorm = new AdjacencyList(n, false);
        for (int i = 0; i < n; i++) {
            final int u = i;
            strength.forEachEdge(u, (v, a) -> {
                double b = strength.getWeight(v, u);
                double w = a + b - a * b;
                conorm.setWeight(u, v, w);
            });
        }
        return conorm.toMatrix();
    }

    /**
     * Computes the per-point normalization (sigma) and offset (rho) for the
     * membership strengths.
     * <p>
     * rho is the distance to the {@code localConnectivity}-th nonzero
     * neighbor, interpolated for fractional values. sigma is found by binary
     * search so that {@code sum_j exp(-(d_ij - rho_i) / sigma_i)} is close to
     * {@code log2(k) * bandwidth}.
     *
     * @param distances the k-NN distances per point. NOTE(review): assumes
     *                  each row is sorted ascending, as the rho lookup requires.
     * @param k the number of neighbors.
     * @param maxIter the maximum number of binary-search iterations.
     * @param localConnectivity the number of nearest neighbors assumed locally connected.
     * @param bandwidth the multiplier of the target {@code log2(k)}.
     * @return {@code {sigma, rho}}.
     */
    private static double[][] smoothKnnDist(double[][] distances, double k, int maxIter, double localConnectivity, double bandwidth) {
        int n = distances.length;
        double target = MathEx.log2(k) * bandwidth;
        double[] rho = new double[n];
        double[] sigma = new double[n];

        // Global mean k-NN distance: the sigma lower bound for points with rho == 0.
        double total = 0.0;
        int count = 0;
        for (double[] row : distances) {
            total += MathEx.sum(row);
            count += row.length;
        }
        double mu = total / count;

        IntStream.range(0, n).parallel().forEach(i -> {
            double[] row = distances[i];
            double[] nonZeroDists = Arrays.stream(row).filter(x -> x > 0.0).toArray();
            if (nonZeroDists.length >= localConnectivity) {
                int index = (int) Math.floor(localConnectivity);
                double interpolation = localConnectivity - index;
                if (index > 0) {
                    rho[i] = nonZeroDists[index - 1];
                    if (interpolation > SMOOTH_K_TOLERANCE) {
                        // A fractional part implies length >= index + 1,
                        // so nonZeroDists[index] exists.
                        rho[i] += interpolation * (nonZeroDists[index] - nonZeroDists[index - 1]);
                    }
                } else {
                    rho[i] = interpolation * nonZeroDists[0];
                }
            } else if (nonZeroDists.length > 0) {
                rho[i] = MathEx.max(nonZeroDists);
            }

            // Binary search on sigma. While no upper bound exists yet, double it.
            double lo = 0.0;
            double hi = Double.POSITIVE_INFINITY;
            double mid = 1.0;
            for (int iter = 0; iter < maxIter; iter++) {
                double psum = 0.0;
                // NOTE(review): the sum skips column 0. That is right if
                // column 0 is the point itself; confirm, because otherwise
                // the true nearest neighbor is skipped.
                for (int j = 1; j < row.length; j++) {
                    double d = row[j] - rho[i];
                    psum += d > 0.0 ? Math.exp(-d / mid) : 1.0;
                }

                if (Math.abs(psum - target) < SMOOTH_K_TOLERANCE) {
                    break;
                }

                if (psum > target) {
                    hi = mid;
                    mid = (lo + hi) / 2.0;
                } else {
                    lo = mid;
                    mid = Double.isInfinite(hi) ? mid * 2.0 : (lo + hi) / 2.0;
                }
            }

            sigma[i] = mid;
            double scale = rho[i] > 0.0 ? MathEx.mean(row) : mu;
            if (sigma[i] < MIN_K_DIST_SCALE * scale) {
                sigma[i] = MIN_K_DIST_SCALE * scale;
            }
        });
        return new double[][]{sigma, rho};
    }

    /**
     * Computes the directed membership strength of each k-NN edge.
     * <p>
     * The strength of edge i→j is {@code exp(-(d_ij - rho_i) / sigma_i)}.
     * It is 1 when {@code d_ij <= rho_i}.
     *
     * @param nng the k-nearest-neighbor graph.
     * @param sigma the per-point normalization.
     * @param rho the per-point distance offset.
     * @return a directed weighted graph of membership strengths.
     */
    private static AdjacencyList computeMembershipStrengths(NearestNeighborGraph nng, double[] sigma, double[] rho) {
        int n = nng.size();
        int[][] neighbors = nng.neighbors();
        double[][] distances = nng.distances();
        AdjacencyList G = new AdjacencyList(n, true);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < neighbors[i].length; j++) {
                double d = distances[i][j] - rho[i];
                double w = d <= 0.0 ? 1.0 : Math.exp(-d / sigma[i]);
                G.setWeight(i, neighbors[i][j], w);
            }
        }
        return G;
    }

    /**
     * Returns an n-by-d layout drawn uniformly from [-10, 10] per coordinate.
     *
     * @param n the number of points.
     * @param d the embedding dimension.
     * @return the random layout.
     */
    private static double[][] randomLayout(int n, int d) {
        double[][] embedding = new double[n][d];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < d; j++) {
                embedding[i][j] = MathEx.random(-10.0, 10.0);
            }
        }
        return embedding;
    }

    /**
     * Returns the projection of the data onto its first d principal components.
     *
     * @param data the input samples.
     * @param d the embedding dimension.
     * @return the PCA layout.
     */
    private static double[][] pcaLayout(double[][] data, int d) {
        return PCA.fit(data, new String[0]).getProjection(d).apply(data);
    }

    /**
     * Computes a spectral embedding of the k-NN graph.
     * <p>
     * The matrix has 1 on the diagonal and {@code -D_i * dist_ij * D_j} on
     * each k-NN edge, where {@code D_i = 1 / sqrt(sum_j dist_ij)}. Its
     * smallest-magnitude eigenvectors give the coordinates.
     * NOTE(review): edge weights here are raw k-NN distances, not membership
     * strengths; confirm this is intended.
     *
     * @param nng the k-nearest-neighbor graph.
     * @param d the embedding dimension.
     * @return the spectral layout.
     */
    private static double[][] spectralLayout(NearestNeighborGraph nng, int d) {
        int[][] neighbors = nng.neighbors();
        double[][] distances = nng.distances();
        int n = nng.size();

        double[] D = new double[n];
        IntStream.range(0, n).parallel().forEach(i -> D[i] = 1.0 / Math.sqrt(MathEx.sum(distances[i])));

        logger.info("Spectral layout computes Laplacian...");
        AdjacencyList laplacian = new AdjacencyList(n, false);
        for (int i = 0; i < n; i++) {
            laplacian.setWeight(i, i, 1.0);
            int[] v = neighbors[i];
            double[] dist = distances[i];
            for (int j = 0; j < v.length; j++) {
                double w = -D[i] * dist[j] * D[v[j]];
                laplacian.setWeight(i, v[j], w);
            }
        }

        // Ask for a few more eigenpairs than the d + 1 strictly needed,
        // bounded by sqrt(n) and n.
        int k = d + 1;
        int numEigen = Math.min(2 * k + 1, (int) Math.sqrt(n));
        numEigen = Math.max(numEigen, k);
        numEigen = Math.min(numEigen, n);
        SparseMatrix L = laplacian.toMatrix();
        logger.info("Spectral layout computes {} eigen vectors", numEigen);
        EVD eigen = ARPACK.syev(L, ARPACK.SymmOption.SM, numEigen);
        DenseMatrix V = eigen.Vr();

        // NOTE(review): assumes the eigenvectors are ordered by decreasing
        // eigenvalue. The last column is then the trivial eigenvector and is
        // skipped; the d columns before it become coordinates. Confirm
        // against ARPACK.syev.
        double[][] coordinates = new double[n][d];
        for (int j = d - 1; j >= 0; j--) {
            int c = V.ncol() - j - 2;
            for (int i = 0; i < n; i++) {
                coordinates[i][j] = V.get(i, c);
            }
        }
        return coordinates;
    }

    /**
     * Scales coordinates in place so the largest absolute value equals
     * {@code scale}, then adds Gaussian noise N(0, noise) to every coordinate.
     *
     * @param coordinates the layout, modified in place.
     * @param scale the target maximum absolute coordinate.
     * @param noise the standard deviation of the added noise.
     */
    private static void noisyScale(double[][] coordinates, double scale, double noise) {
        int d = coordinates[0].length;
        double max = Double.NEGATIVE_INFINITY;
        for (double[] coordinate : coordinates) {
            for (int j = 0; j < d; j++) {
                max = Math.max(max, Math.abs(coordinate[j]));
            }
        }

        // An all-zero layout cannot be rescaled; dividing would produce NaN.
        double expansion = max > 0.0 ? scale / max : 1.0;
        GaussianDistribution gaussian = new GaussianDistribution(0.0, noise);
        for (double[] coordinate : coordinates) {
            for (int j = 0; j < d; j++) {
                coordinate[j] = expansion * coordinate[j] + gaussian.rand();
            }
        }
    }

    /**
     * Min-max normalizes each column in place to {@code [0, scale]}.
     * A constant column is mapped to 0.
     *
     * @param coordinates the layout, modified in place.
     * @param scale the upper bound of each normalized column.
     */
    private static void normalize(double[][] coordinates, double scale) {
        int d = coordinates[0].length;
        double[] colMax = MathEx.colMax(coordinates);
        double[] colMin = MathEx.colMin(coordinates);
        double[] length = new double[d];
        for (int j = 0; j < d; j++) {
            length[j] = colMax[j] - colMin[j];
            // Avoid 0/0 = NaN for a constant column.
            if (length[j] == 0.0) {
                length[j] = 1.0;
            }
        }

        for (double[] coordinate : coordinates) {
            for (int j = 0; j < d; j++) {
                coordinate[j] = scale * (coordinate[j] - colMin[j]) / length[j];
            }
        }
    }

    /**
     * Optimizes the layout in place by stochastic gradient descent.
     * <p>
     * In each epoch, every edge due for sampling pulls its endpoints together.
     * It then pushes its source away from randomly chosen points. The learning
     * rate decays linearly from {@code initialAlpha} toward 0 over the epochs.
     *
     * @param embedding the layout, modified in place.
     * @param curve the curve parameters {@code {a, b}}.
     * @param epochsPerSample the sampling period of each edge; 0 means never sampled.
     * @param epochs the number of epochs.
     * @param initialAlpha the initial learning rate.
     * @param negativeSamples the number of negative samples per positive sample.
     * @param gamma the weight of the repulsive gradient.
     */
    private static void optimizeLayout(double[][] embedding, double[] curve, SparseMatrix epochsPerSample, int epochs, double initialAlpha, int negativeSamples, double gamma) {
        int n = embedding.length;
        int d = embedding[0].length;
        double a = curve[0];
        double b = curve[1];
        double alpha = initialAlpha;

        SparseMatrix epochsPerNegativeSample = epochsPerSample.copy();
        epochsPerNegativeSample.nonzeros().forEach(w -> w.update(w.x / negativeSamples));
        SparseMatrix epochNextNegativeSample = epochsPerNegativeSample.copy();
        SparseMatrix epochNextSample = epochsPerSample.copy();

        for (int iter = 1; iter <= epochs; iter++) {
            for (SparseMatrix.Entry edge : epochNextSample) {
                // Skip pruned edges (0) and edges whose next sample is in a later epoch.
                if (!(edge.x > 0.0 && edge.x <= iter)) {
                    continue;
                }

                int j = edge.i;
                int k = edge.j;
                int index = edge.index;
                double[] current = embedding[j];
                double[] other = embedding[k];

                // Attractive move along the sampled edge.
                double distSquared = MathEx.squaredDistance(current, other);
                if (distSquared > 0.0) {
                    double gradCoeff = -2.0 * a * b * Math.pow(distSquared, b - 1.0);
                    gradCoeff /= a * Math.pow(distSquared, b) + 1.0;
                    for (int i = 0; i < d; i++) {
                        double gradD = clamp(gradCoeff * (current[i] - other[i]));
                        current[i] += gradD * alpha;
                        other[i] -= gradD * alpha;
                    }
                }
                edge.update(edge.x + epochsPerSample.get(index));

                // Repulsive moves against random points. Only 'current' moves.
                int negSamples = (int) ((iter - epochNextNegativeSample.get(index)) / epochsPerNegativeSample.get(index));
                for (int p = 0; p < negSamples; p++) {
                    int s = MathEx.randomInt(n);
                    if (s == j) {
                        continue;
                    }
                    double[] sample = embedding[s];
                    double sampleDistSquared = MathEx.squaredDistance(current, sample);
                    double gradCoeff = 0.0;
                    if (sampleDistSquared > 0.0) {
                        gradCoeff = 2.0 * gamma * b;
                        gradCoeff /= (0.001 + sampleDistSquared) * (a * Math.pow(sampleDistSquared, b) + 1.0);
                    }
                    for (int i = 0; i < d; i++) {
                        double gradD = gradCoeff > 0.0 ? clamp(gradCoeff * (current[i] - sample[i])) : 4.0;
                        current[i] += gradD * alpha;
                    }
                }
                epochNextNegativeSample.set(index, epochNextNegativeSample.get(index) + epochsPerNegativeSample.get(index) * negSamples);
            }

            logger.info("The learning rate at {} iterations: {}", iter, alpha);
            alpha = initialAlpha * (1.0 - (double) iter / epochs);
        }
    }

    /**
     * Converts edge weights into sampling periods (in epochs).
     * <p>
     * An edge of weight w gets period {@code max / w}, so the heaviest edge is
     * sampled every epoch. Edges lighter than {@code max / epochs} get 0,
     * meaning they are never sampled. Note: {@code strength} is modified in
     * place and returned.
     *
     * @param strength the edge weights; overwritten.
     * @param epochs the number of epochs.
     * @return {@code strength}, now holding sampling periods.
     */
    private static SparseMatrix computeEpochPerSample(SparseMatrix strength, int epochs) {
        double max = strength.nonzeros().mapToDouble(w -> w.x).max().orElse(0.0);
        double min = max / epochs;
        strength.nonzeros().forEach(w -> {
            if (w.x < min) {
                w.update(0.0);
            } else {
                w.update(max / w.x);
            }
        });
        return strength;
    }

    /** Clamps a gradient component to {@code [-4, 4]}. */
    private static double clamp(double val) {
        return Math.min(4.0, Math.max(val, -4.0));
    }

    /**
     * UMAP hyperparameters.
     *
     * @param k the number of nearest neighbors; at least 2.
     * @param d the embedding dimension; at least 2.
     * @param epochs the number of optimization epochs. Values below 10 are
     *               replaced by a size-dependent default (200 or 500).
     * @param learningRate the initial learning rate; positive. It decays
     *                     linearly over the epochs.
     * @param minDist the minimum distance between embedded points; positive
     *                and at most {@code spread}.
     * @param spread the effective scale of embedded points.
     * @param negativeSamples the number of negative samples per positive
     *                        sample; positive.
     * @param repulsionStrength the weight of the repulsive gradient.
     * @param localConnectivity the number of nearest neighbors assumed locally
     *                          connected; at least 1.
     */
    public record Options(int k, int d, int epochs, double learningRate, double minDist, double spread, int negativeSamples, double repulsionStrength, double localConnectivity) {
        /** Validates the hyperparameters. */
        public Options {
            if (k < 2) {
                throw new IllegalArgumentException("Invalid number of nearest neighbors: " + k);
            }
            if (d < 2) {
                throw new IllegalArgumentException("Invalid dimension of feature space: " + d);
            }
            if (minDist <= 0.0) {
                throw new IllegalArgumentException("minDist must greater than 0: " + minDist);
            }
            if (minDist > spread) {
                throw new IllegalArgumentException("minDist must be less than or equal to spread: " + minDist + ", spread=" + spread);
            }
            if (learningRate <= 0.0) {
                throw new IllegalArgumentException("learningRate must greater than 0: " + learningRate);
            }
            if (negativeSamples <= 0) {
                throw new IllegalArgumentException("negativeSamples must greater than 0: " + negativeSamples);
            }
            if (localConnectivity < 1.0) {
                throw new IllegalArgumentException("localConnectivity must be at least 1.0: " + localConnectivity);
            }
        }

        /**
         * Creates options with the given k and defaults for everything else:
         * d = 2, epochs = 0 (auto), learningRate = 1.0, minDist = 0.1,
         * spread = 1.0, negativeSamples = 5, repulsionStrength = 1.0,
         * localConnectivity = 1.0.
         *
         * @param k the number of nearest neighbors.
         */
        public Options(int k) {
            this(k, 2, 0, 1.0, 0.1, 1.0, 5, 1.0, 1.0);
        }

        /**
         * Returns the options as properties under the {@code smile.umap.*} keys.
         *
         * @return the properties.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.umap.k", Integer.toString(this.k));
            props.setProperty("smile.umap.d", Integer.toString(this.d));
            props.setProperty("smile.umap.epochs", Integer.toString(this.epochs));
            props.setProperty("smile.umap.learning_rate", Double.toString(this.learningRate));
            props.setProperty("smile.umap.min_dist", Double.toString(this.minDist));
            props.setProperty("smile.umap.spread", Double.toString(this.spread));
            props.setProperty("smile.umap.negative_samples", Integer.toString(this.negativeSamples));
            props.setProperty("smile.umap.repulsion_strength", Double.toString(this.repulsionStrength));
            props.setProperty("smile.umap.local_connectivity", Double.toString(this.localConnectivity));
            return props;
        }

        /**
         * Reads options from properties. Missing keys take the defaults.
         *
         * @param props the properties.
         * @return the options.
         */
        public static Options of(Properties props) {
            int k = Integer.parseInt(props.getProperty("smile.umap.k", "15"));
            int d = Integer.parseInt(props.getProperty("smile.umap.d", "2"));
            int epochs = Integer.parseInt(props.getProperty("smile.umap.epochs", "0"));
            double learningRate = Double.parseDouble(props.getProperty("smile.umap.learning_rate", "1.0"));
            double minDist = Double.parseDouble(props.getProperty("smile.umap.min_dist", "0.1"));
            double spread = Double.parseDouble(props.getProperty("smile.umap.spread", "1.0"));
            int negativeSamples = Integer.parseInt(props.getProperty("smile.umap.negative_samples", "5"));
            double repulsionStrength = Double.parseDouble(props.getProperty("smile.umap.repulsion_strength", "1.0"));
            double localConnectivity = Double.parseDouble(props.getProperty("smile.umap.local_connectivity", "1.0"));
            return new Options(k, d, epochs, learningRate, minDist, spread, negativeSamples, repulsionStrength, localConnectivity);
        }
    }

    /**
     * The curve {@code 1 / (1 + a * t^b)}, fitted in {@code fitCurve}.
     * <p>
     * The argument vector is read as {@code x[0] = a}, {@code x[1] = b}, and
     * {@code x[2] = t} (the sample point). NOTE(review): this assumes
     * LevenbergMarquardt places the sample point after the parameters; confirm.
     */
    private static class Curve implements DifferentiableMultivariateFunction {
        @Override
        public double f(double[] x) {
            return 1.0 / (1.0 + x[0] * Math.pow(x[2], x[1]));
        }

        /**
         * Evaluates the curve and fills the gradient with respect to a and b:
         * {@code d/da = -t^b / de^2} and {@code d/db = -a * t^b * ln(t) / de^2},
         * where {@code de = 1 + a * t^b}.
         */
        @Override
        public double g(double[] x, double[] g) {
            double pow = Math.pow(x[2], x[1]);
            double de = 1.0 + x[0] * pow;
            double de2 = de * de;
            g[0] = -pow / de2;
            // d(t^b)/db = t^b * ln(t), with no factor of b. An extra factor of
            // b would zero this gradient at the initial guess b = 0 and stall
            // the fit of b.
            g[1] = -(x[0] * Math.log(x[2]) * pow) / de2;
            return 1.0 / de;
        }
    }
}

