/*
 * Decompiled with CFR 0.152.
 */
package smile.manifold;

import java.io.Serializable;
import java.util.Collection;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.graph.AdjacencyList;
import smile.graph.Graph;
import smile.manifold.NearestNeighborGraph;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.LevenbergMarquardt;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.math.distance.EuclideanDistance;
import smile.math.matrix.ARPACK;
import smile.math.matrix.IMatrix;
import smile.math.matrix.Matrix;
import smile.math.matrix.SparseMatrix;
import smile.stat.distribution.GaussianDistribution;

public class UMAP
implements Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(UMAP.class);
    public final double[][] coordinates;
    public final int[] index;
    public final AdjacencyList graph;
    private static final DifferentiableMultivariateFunction func = new DifferentiableMultivariateFunction(){

        public double f(double[] x) {
            return 1.0 / (1.0 + x[0] * Math.pow(x[2], x[1]));
        }

        public double g(double[] x, double[] g) {
            double pow = Math.pow(x[2], x[1]);
            double de = 1.0 + x[0] * pow;
            g[0] = -pow / (de * de);
            g[1] = -(x[0] * x[1] * Math.log(x[2]) * pow) / (de * de);
            return 1.0 / de;
        }
    };

    public UMAP(int[] index, double[][] coordinates, AdjacencyList graph) {
        this.index = index;
        this.coordinates = coordinates;
        this.graph = graph;
    }

    public static UMAP of(double[][] data) {
        return UMAP.of(data, 15);
    }

    public static <T> UMAP of(T[] data, Distance<T> distance) {
        return UMAP.of(data, distance, 15);
    }

    public static UMAP of(double[][] data, int k) {
        return UMAP.of(data, new EuclideanDistance(), k);
    }

    public static <T> UMAP of(T[] data, Distance<T> distance, int k) {
        return UMAP.of(data, distance, k, 2, data.length > 10000 ? 200 : 500, 1.0, 0.1, 1.0, 5, 1.0);
    }

    public static UMAP of(double[][] data, int k, int d, int iterations, double learningRate, double minDist, double spread, int negativeSamples, double repulsionStrength) {
        return UMAP.of(data, new EuclideanDistance(), k, d, iterations, learningRate, minDist, spread, negativeSamples, repulsionStrength);
    }

    public static <T> UMAP of(T[] data, Distance<T> distance, int k, int d, int iterations, double learningRate, double minDist, double spread, int negativeSamples, double repulsionStrength) {
        if (d < 2) {
            throw new IllegalArgumentException("d must be greater than 1: " + d);
        }
        if (k < 2) {
            throw new IllegalArgumentException("k must be greater than 1: " + k);
        }
        if (minDist <= 0.0) {
            throw new IllegalArgumentException("minDist must greater than 0: " + minDist);
        }
        if (minDist > spread) {
            throw new IllegalArgumentException("minDist must be less than or equal to spread: " + minDist + ",spread=" + spread);
        }
        if (iterations < 10) {
            throw new IllegalArgumentException("epochs must be a positive integer of at least 10: " + iterations);
        }
        if (learningRate <= 0.0) {
            throw new IllegalArgumentException("learningRate must greater than 0: " + learningRate);
        }
        if (negativeSamples <= 0) {
            throw new IllegalArgumentException("negativeSamples must greater than 0: " + negativeSamples);
        }
        AdjacencyList graph = NearestNeighborGraph.of(data, distance, k, true, null);
        NearestNeighborGraph nng = NearestNeighborGraph.largest(graph);
        graph = UMAP.computeFuzzySimplicialSet(nng.graph, k, 64);
        SparseMatrix conorm = graph.toMatrix();
        double[][] coordinates = UMAP.spectralLayout(graph, d);
        logger.info("Finish initialization with spectral layout");
        double[] curve = UMAP.fitCurve(spread, minDist);
        logger.info("Finish fitting the curve parameters");
        SparseMatrix epochs = UMAP.computeEpochPerSample(conorm, iterations);
        logger.info("Start optimizing the layout");
        UMAP.optimizeLayout(coordinates, curve, epochs, iterations, learningRate, negativeSamples, repulsionStrength);
        return new UMAP(nng.index, coordinates, graph);
    }

    private static double[] fitCurve(double spread, double minDist) {
        int size = 300;
        double[] x = new double[size];
        double[] y = new double[size];
        double end = 3.0 * spread;
        double interval = end / (double)size;
        for (int i = 0; i < x.length; ++i) {
            x[i] = (double)(i + 1) * interval;
            y[i] = x[i] < minDist ? 1.0 : Math.exp(-(x[i] - minDist) / spread);
        }
        double[] p = new double[]{0.5, 0.0};
        LevenbergMarquardt curveFit = LevenbergMarquardt.fit((DifferentiableMultivariateFunction)func, (double[])x, (double[])y, (double[])p);
        return curveFit.parameters;
    }

    private static AdjacencyList computeFuzzySimplicialSet(AdjacencyList nng, int k, int iterations) {
        int i;
        double LogK = MathEx.log2((double)k);
        double EPSILON = 1.0E-8;
        double TOLERANCE = 1.0E-5;
        double MIN_SCALE = 0.001;
        int n = nng.getNumVertices();
        double[] sigma = new double[n];
        double[] rho = new double[n];
        double avg = IntStream.range(0, n).mapToObj(arg_0 -> ((AdjacencyList)nng).getEdges(arg_0)).flatMapToDouble(edges -> edges.stream().mapToDouble(edge -> edge.weight)).filter(w -> !MathEx.isZero((double)w, (double)1.0E-8)).average().orElse(0.0);
        for (i = 0; i < n; ++i) {
            double lo = 0.0;
            double hi = Double.POSITIVE_INFINITY;
            double mid = 1.0;
            Collection knn = nng.getEdges(i);
            rho[i] = knn.stream().mapToDouble(edge -> edge.weight).filter(w -> !MathEx.isZero((double)w, (double)1.0E-8)).min().orElse(0.0);
            for (int iter = 0; iter < iterations; ++iter) {
                double psum = 0.0;
                for (Graph.Edge edge2 : knn) {
                    if (MathEx.isZero((double)edge2.weight, (double)1.0E-8)) continue;
                    double d = edge2.weight - rho[i];
                    psum += d > 0.0 ? Math.exp(-d / mid) : 1.0;
                }
                if (Math.abs(psum - LogK) < 1.0E-5) break;
                if (psum > LogK) {
                    hi = mid;
                    mid = (lo + hi) / 2.0;
                    continue;
                }
                lo = mid;
                if (Double.isInfinite(hi)) {
                    mid *= 2.0;
                    continue;
                }
                mid = (lo + hi) / 2.0;
            }
            sigma[i] = mid;
            if (rho[i] > 0.0) {
                double avgi = knn.stream().mapToDouble(edge -> edge.weight).filter(w -> !MathEx.isZero((double)w, (double)1.0E-8)).average().orElse(0.0);
                sigma[i] = Math.max(sigma[i], 0.001 * avgi);
                continue;
            }
            sigma[i] = Math.max(sigma[i], 0.001 * avg);
        }
        for (i = 0; i < n; ++i) {
            for (Graph.Edge edge3 : nng.getEdges(i)) {
                edge3.weight = Math.exp(-Math.max(0.0, edge3.weight - rho[i]) / sigma[i]);
            }
        }
        AdjacencyList G = new AdjacencyList(n, false);
        for (int i2 = 0; i2 < n; ++i2) {
            for (Graph.Edge edge4 : nng.getEdges(i2)) {
                double w2 = edge4.weight;
                double w22 = nng.getWeight(edge4.v2, edge4.v1);
                G.setWeight(edge4.v1, edge4.v2, w2 + w22 - w2 * w22);
            }
        }
        return G;
    }

    private static double[][] spectralLayout(AdjacencyList nng, int d) {
        int n = nng.getNumVertices();
        double[] D = new double[n];
        for (int i = 0; i < n; ++i) {
            for (Object edge : nng.getEdges(i)) {
                int n2 = i;
                D[n2] = D[n2] + ((Graph.Edge)edge).weight;
            }
            D[i] = 1.0 / Math.sqrt(D[i]);
        }
        AdjacencyList laplacian = new AdjacencyList(n, false);
        for (int i = 0; i < n; ++i) {
            laplacian.setWeight(i, i, 1.0);
            for (Graph.Edge edge : nng.getEdges(i)) {
                double w = -D[edge.v1] * edge.weight * D[edge.v2];
                laplacian.setWeight(edge.v1, edge.v2, w);
            }
        }
        SparseMatrix L = laplacian.toMatrix();
        Matrix.EVD eigen = ARPACK.syev((IMatrix)L, (ARPACK.SymmOption)ARPACK.SymmOption.SM, (int)Math.min(10 * (d + 1), n - 1));
        double absMax = 0.0;
        Matrix V = eigen.Vr;
        double[][] coordinates = new double[n][d];
        int j = d;
        while (--j >= 0) {
            int c = V.ncol() - j - 2;
            for (int i = 0; i < n; ++i) {
                double x;
                coordinates[i][j] = x = V.get(i, c);
                double abs = Math.abs(x);
                if (!(abs > absMax)) continue;
                absMax = abs;
            }
        }
        double expansion = 10.0 / absMax;
        GaussianDistribution gaussian = new GaussianDistribution(0.0, 1.0E-4);
        for (int i = 0; i < n; ++i) {
            for (int j2 = 0; j2 < d; ++j2) {
                coordinates[i][j2] = coordinates[i][j2] * expansion + gaussian.rand();
            }
        }
        double[] colMax = MathEx.colMax((double[][])coordinates);
        double[] colMin = MathEx.colMin((double[][])coordinates);
        double[] de = new double[d];
        for (int j3 = 0; j3 < d; ++j3) {
            de[j3] = colMax[j3] - colMin[j3];
        }
        for (int i = 0; i < n; ++i) {
            for (int j4 = 0; j4 < d; ++j4) {
                coordinates[i][j4] = 10.0 * (coordinates[i][j4] - colMin[j4]) / de[j4];
            }
        }
        return coordinates;
    }

    private static void optimizeLayout(double[][] embedding, double[] curve, SparseMatrix epochsPerSample, int iterations, double initialAlpha, int negativeSamples, double gamma) {
        int n = embedding.length;
        int d = embedding[0].length;
        double a = curve[0];
        double b = curve[1];
        double alpha = initialAlpha;
        SparseMatrix epochsPerNegativeSample = epochsPerSample.clone();
        epochsPerNegativeSample.nonzeros().forEach(w -> w.update(w.x / (double)negativeSamples));
        SparseMatrix epochNextNegativeSample = epochsPerNegativeSample.clone();
        SparseMatrix epochNextSample = epochsPerSample.clone();
        for (int iter = 1; iter <= iterations; ++iter) {
            for (SparseMatrix.Entry edge : epochNextSample) {
                if (!(edge.x > 0.0) || !(edge.x <= (double)iter)) continue;
                int j = edge.i;
                int k = edge.j;
                int index = edge.index;
                double[] current = embedding[j];
                double[] other = embedding[k];
                double distSquared = MathEx.squaredDistance((double[])current, (double[])other);
                if (distSquared > 0.0) {
                    double gradCoeff = -2.0 * a * b * Math.pow(distSquared, b - 1.0);
                    gradCoeff /= a * Math.pow(distSquared, b) + 1.0;
                    int i = 0;
                    while (i < d) {
                        double gradD = UMAP.clamp(gradCoeff * (current[i] - other[i]));
                        int n2 = i;
                        current[n2] = current[n2] + gradD * alpha;
                        int n3 = i++;
                        other[n3] = other[n3] - gradD * alpha;
                    }
                }
                edge.update(edge.x + epochsPerSample.get(index));
                int negSamples = (int)(((double)iter - epochNextNegativeSample.get(index)) / epochsPerNegativeSample.get(index));
                for (int p = 0; p < negSamples; ++p) {
                    k = MathEx.randomInt((int)n);
                    if (j == k) continue;
                    other = embedding[k];
                    distSquared = MathEx.squaredDistance((double[])current, (double[])other);
                    double gradCoeff = 0.0;
                    if (distSquared > 0.0) {
                        gradCoeff = 2.0 * gamma * b;
                        gradCoeff /= (0.001 + distSquared) * (a * Math.pow(distSquared, b) + 1.0);
                    }
                    int i = 0;
                    while (i < d) {
                        double gradD = 4.0;
                        if (gradCoeff > 0.0) {
                            gradD = UMAP.clamp(gradCoeff * (current[i] - other[i]));
                        }
                        int n4 = i++;
                        current[n4] = current[n4] + gradD * alpha;
                    }
                }
                epochNextNegativeSample.set(index, epochNextNegativeSample.get(index) + epochsPerNegativeSample.get(index) * (double)negSamples);
            }
            logger.info(String.format("The learning rate at %3d iterations: %.5f", iter, alpha));
            alpha = initialAlpha * (1.0 - (double)iter / (double)iterations);
        }
    }

    private static SparseMatrix computeEpochPerSample(SparseMatrix strength, int iterations) {
        double max = strength.nonzeros().mapToDouble(w -> w.x).max().orElse(0.0);
        double min = max / (double)iterations;
        strength.nonzeros().forEach(w -> {
            if (w.x < min) {
                w.update(0.0);
            } else {
                w.update(max / w.x);
            }
        });
        return strength;
    }

    private static double clamp(double val) {
        return Math.min(4.0, Math.max(val, -4.0));
    }
}

