/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.Properties;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.CentroidClustering;
import smile.clustering.Clustering;
import smile.clustering.KMeans;
import smile.data.SparseDataset;
import smile.linalg.Transpose;
import smile.linalg.UPLO;
import smile.math.MathEx;
import smile.tensor.ARPACK;
import smile.tensor.DenseMatrix;
import smile.tensor.EVD;
import smile.tensor.Matrix;
import smile.tensor.ScalarType;
import smile.tensor.Vector;
import smile.util.AlgoStatus;
import smile.util.IterativeAlgorithmController;
import smile.util.SparseArray;
import smile.util.SparseIntArray;

public class SpectralClustering {
    private static final Logger logger = LoggerFactory.getLogger(SpectralClustering.class);

    private SpectralClustering() {
    }

    public static CentroidClustering<double[], double[]> fit(SparseIntArray[] data, int p, Clustering.Options options) {
        double[][] Y = SpectralClustering.embed(data, p, options.k());
        return KMeans.fit(Y, options);
    }

    public static CentroidClustering<double[], double[]> fit(double[][] data, Options options) {
        if (options.l >= options.k) {
            return SpectralClustering.nystrom(data, options);
        }
        double[][] Y = SpectralClustering.embed(data, options.k, options.sigma);
        return KMeans.fit(Y, new Clustering.Options(options.k, options.maxIter, options.tol, options.controller));
    }

    public static CentroidClustering<double[], double[]> nystrom(double[][] data, Options options) {
        int i2;
        int n = data.length;
        int k = options.k;
        int l = options.l;
        double sigma = options.sigma;
        double gamma = -0.5 / (sigma * sigma);
        if (l < k || l >= n) {
            throw new IllegalArgumentException("Invalid number of random samples: " + l);
        }
        int[] index = MathEx.permutate((int)n);
        double[][] x = new double[n][];
        for (int i3 = 0; i3 < n; ++i3) {
            x[i3] = data[index[i3]];
        }
        DenseMatrix C = DenseMatrix.zeros((ScalarType)ScalarType.Float64, (int)n, (int)l);
        double[] D = new double[n];
        IntStream.range(0, n).parallel().forEach(i -> {
            for (int j = 0; j < n; ++j) {
                if (i == j) continue;
                double w = Math.exp(gamma * MathEx.squaredDistance((double[])x[i], (double[])x[j]));
                int n2 = i;
                D[n2] = D[n2] + w;
                if (j >= l) continue;
                C.set(i, j, w);
            }
        });
        for (i2 = 0; i2 < n; ++i2) {
            if (D[i2] < 1.0E-4) {
                logger.error("Small D[{}] = {}. The data may contain outliers.", (Object)i2, (Object)D[i2]);
            }
            D[i2] = 1.0 / Math.sqrt(D[i2]);
        }
        for (i2 = 0; i2 < n; ++i2) {
            for (int j = 0; j < l; ++j) {
                C.set(i2, j, D[i2] * C.get(i2, j) * D[j]);
            }
        }
        DenseMatrix W = C.submatrix(0, 0, l, l);
        W.withUplo(UPLO.LOWER);
        EVD eigen = ARPACK.syev((Matrix)W, (ARPACK.SymmOption)ARPACK.SymmOption.LA, (int)k);
        double[] e = eigen.wr().toArray(new double[0]);
        double scale = Math.sqrt((double)l / (double)n);
        for (int i4 = 0; i4 < k; ++i4) {
            if (e[i4] <= 1.0E-8) {
                throw new IllegalStateException("Non-positive eigen value: " + e[i4]);
            }
            e[i4] = scale / e[i4];
        }
        DenseMatrix U = eigen.Vr();
        for (int i5 = 0; i5 < l; ++i5) {
            for (int j = 0; j < k; ++j) {
                U.mul(i5, j, e[j]);
            }
        }
        double[][] Y = C.mm(U).toArray((double[][])new double[0][]);
        for (int i6 = 0; i6 < n; ++i6) {
            MathEx.unitize2((double[])Y[i6]);
        }
        double[][] features = new double[n][];
        for (int i7 = 0; i7 < n; ++i7) {
            features[index[i7]] = Y[i7];
        }
        return KMeans.fit(features, new Clustering.Options(k, options.maxIter, options.tol, options.controller));
    }

    public static double[][] embed(DenseMatrix W, int d) {
        int i;
        int n = W.nrow();
        double[] D = W.colSums().toArray(new double[0]);
        for (i = 0; i < n; ++i) {
            if (D[i] == 0.0) {
                throw new IllegalArgumentException("Isolated vertex: " + i);
            }
            D[i] = 1.0 / Math.sqrt(D[i]);
        }
        for (i = 0; i < n; ++i) {
            for (int j = 0; j < i; ++j) {
                double w = D[i] * W.get(i, j) * D[j];
                W.set(i, j, w);
                W.set(j, i, w);
            }
        }
        W.withUplo(UPLO.LOWER);
        EVD eigen = ARPACK.syev((Matrix)W, (ARPACK.SymmOption)ARPACK.SymmOption.LA, (int)d);
        double[][] Y = eigen.Vr().toArray((double[][])new double[0][]);
        for (int i2 = 0; i2 < n; ++i2) {
            MathEx.unitize2((double[])Y[i2]);
        }
        return Y;
    }

    public static double[][] embed(double[][] data, int d, double sigma) {
        int n = data.length;
        double gamma = -0.5 / (sigma * sigma);
        DenseMatrix W = DenseMatrix.zeros((ScalarType)ScalarType.Float64, (int)n, (int)n);
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < i; ++j) {
                double w = Math.exp(gamma * MathEx.squaredDistance((double[])data[i], (double[])data[j]));
                W.set(i, j, w);
                W.set(j, i, w);
            }
        }
        return SpectralClustering.embed(W, d);
    }

    public static double[][] embed(SparseIntArray[] data, int p, int d) {
        int n = data.length;
        double[] idf = new double[p];
        for (int i2 = 0; i2 < n; ++i2) {
            data[i2].forEach((j, count) -> {
                int n = j;
                idf[n] = idf[n] + (count > 0 ? 1.0 : 0.0);
            });
        }
        for (int j2 = 0; j2 < p; ++j2) {
            idf[j2] = Math.log((double)n / (1.0 + idf[j2]));
        }
        SparseArray[] X = new SparseArray[n];
        IntStream.range(0, n).parallel().forEach(i -> {
            double[] x = new double[p];
            data[i].forEach((j, count) -> {
                x[j] = (double)count / idf[j];
            });
            MathEx.normalize((double[])x);
            SparseArray Xi = new SparseArray(data[i].size());
            for (int j2 = 0; j2 < p; ++j2) {
                if (!(x[j2] > 0.0)) continue;
                Xi.set(j2, x[j2]);
            }
            X[i] = Xi;
        });
        double[] D = new double[n];
        IntStream.range(0, n).parallel().forEach(i -> {
            double Di = -1.0;
            for (int j2 = 0; j2 < n; ++j2) {
                Di += MathEx.dot((SparseArray)X[i], (SparseArray)X[j2]);
            }
            D[i] = Di;
            double di = Math.sqrt(Di);
            X[i].update((j, xj) -> xj / di);
        });
        CountMatrix W = new CountMatrix((Matrix)SparseDataset.of((SparseArray[])X, (int)p).toMatrix(), D);
        EVD eigen = ARPACK.syev((Matrix)W, (ARPACK.SymmOption)ARPACK.SymmOption.LA, (int)d);
        double[][] Y = eigen.Vr().toArray((double[][])new double[0][]);
        for (int i3 = 0; i3 < n; ++i3) {
            MathEx.unitize2((double[])Y[i3]);
        }
        return Y;
    }

    public record Options(int k, int l, double sigma, int maxIter, double tol, IterativeAlgorithmController<AlgoStatus> controller) {
        public Options {
            if (k < 2) {
                throw new IllegalArgumentException("Invalid number of clusters: " + k);
            }
            if (l < k && l > 0) {
                throw new IllegalArgumentException("Invalid number of random samples: " + l);
            }
            if (sigma <= 0.0) {
                throw new IllegalArgumentException("Invalid standard deviation of Gaussian kernel: " + sigma);
            }
            if (maxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }
            if (tol < 0.0) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
        }

        public Options(int k, double sigma, int maxIter) {
            this(k, 0, sigma, maxIter);
        }

        public Options(int k, int l, double sigma, int maxIter) {
            this(k, l, sigma, maxIter, 1.0E-4, null);
        }

        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.spectral_clustering.k", Integer.toString(this.k));
            props.setProperty("smile.spectral_clustering.l", Integer.toString(this.l));
            props.setProperty("smile.spectral_clustering.sigma", Double.toString(this.sigma));
            props.setProperty("smile.spectral_clustering.iterations", Integer.toString(this.maxIter));
            props.setProperty("smile.spectral_clustering.tolerance", Double.toString(this.tol));
            return props;
        }

        public static Options of(Properties props) {
            int k = Integer.parseInt(props.getProperty("smile.spectral_clustering.k", "2"));
            int l = Integer.parseInt(props.getProperty("smile.spectral_clustering.l", "0"));
            double sigma = Double.parseDouble(props.getProperty("smile.spectral_clustering.sigma", "1.0"));
            int maxIter = Integer.parseInt(props.getProperty("smile.spectral_clustering.iterations", "100"));
            double tol = Double.parseDouble(props.getProperty("smile.spectral_clustering.tolerance", "1E-4"));
            return new Options(k, l, sigma, maxIter, tol, null);
        }
    }

    static class CountMatrix
    implements Matrix {
        final Matrix X;
        final double[] D;
        final Vector x;
        final Vector ax;
        final Vector y;

        CountMatrix(Matrix X, double[] D) {
            this.X = X;
            this.D = D;
            int n = X.nrow();
            int p = X.ncol();
            this.x = X.vector(n);
            this.y = X.vector(n);
            this.ax = X.vector(p);
        }

        public int nrow() {
            return this.X.nrow();
        }

        public int ncol() {
            return this.X.nrow();
        }

        public long length() {
            return this.X.length();
        }

        public ScalarType scalarType() {
            return this.X.scalarType();
        }

        public void mv(Vector x, Vector y) {
            this.X.tv(x, this.ax);
            this.X.mv(this.ax, y);
            for (int i = 0; i < y.size(); ++i) {
                y.sub(i, x.get(i) / this.D[i]);
            }
        }

        public void tv(Vector x, Vector y) {
            this.mv(x, y);
        }

        public void mv(Transpose trans, double alpha, Vector x, double beta, Vector y) {
            throw new UnsupportedOperationException();
        }

        public void mv(Vector work, int inputOffset, int outputOffset) {
            Vector.copy((Vector)work, (int)inputOffset, (Vector)this.x, (int)0, (int)this.x.size());
            this.X.tv(work, this.ax);
            this.X.mv(this.ax, this.y);
            for (int i = 0; i < this.y.size(); ++i) {
                this.y.sub(i, this.x.get(i) / this.D[i]);
            }
            Vector.copy((Vector)this.y, (int)0, (Vector)work, (int)outputOffset, (int)this.y.size());
        }

        public void tv(Vector work, int inputOffset, int outputOffset) {
            throw new UnsupportedOperationException();
        }

        public double get(int i, int j) {
            throw new UnsupportedOperationException();
        }

        public void set(int i, int j, double x) {
            throw new UnsupportedOperationException();
        }

        public void add(int i, int j, double x) {
            throw new UnsupportedOperationException();
        }

        public void sub(int i, int j, double x) {
            throw new UnsupportedOperationException();
        }

        public void mul(int i, int j, double x) {
            throw new UnsupportedOperationException();
        }

        public void div(int i, int j, double x) {
            throw new UnsupportedOperationException();
        }

        public Matrix scale(double alpha) {
            throw new UnsupportedOperationException();
        }

        public Matrix copy() {
            throw new UnsupportedOperationException();
        }

        public Matrix transpose() {
            throw new UnsupportedOperationException();
        }
    }
}

