/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.BBDTree;
import smile.clustering.CentroidClustering;
import smile.clustering.Clustering;
import smile.clustering.KMeans;
import smile.math.MathEx;
import smile.math.distance.EuclideanDistance;
import smile.sort.QuickSort;
import smile.stat.distribution.GaussianDistribution;
import smile.util.AlgoStatus;
import smile.util.IterativeAlgorithmController;

/**
 * G-Means clustering: grows the number of clusters from one up to a
 * caller-supplied maximum by repeatedly testing whether each cluster looks
 * Gaussian along its most likely split direction.
 *
 * <p>Each round, every sufficiently large cluster is tentatively split in two
 * with {@link KMeans}. The cluster's points are projected onto the vector
 * joining the two child centers, standardized, and scored with an adjusted
 * Anderson-Darling statistic. Clusters whose score exceeds
 * {@link #CRITICAL_VALUE} are replaced by their two children (as long as the
 * cluster budget allows); the others are kept. The resulting centers are then
 * refined over the full data set through a {@link BBDTree}.
 *
 * <p>This class is a static utility and cannot be instantiated.
 */
public class GMeans {
    private static final Logger logger = LoggerFactory.getLogger(GMeans.class);

    /** Threshold on the adjusted Anderson-Darling statistic above which a cluster is split. */
    private static final double CRITICAL_VALUE = 1.8692;

    /** Clusters with fewer observations than this are never considered for splitting. */
    private static final int MIN_SPLIT_SIZE = 25;

    private GMeans() {
    }

    /**
     * Clusters the data with default tolerance and no progress controller.
     *
     * @param data    the observations, one row per point.
     * @param kmax    the maximum number of clusters.
     * @param maxIter the maximum number of refinement iterations per round
     *                (also passed to the inner 2-means fits).
     * @return the fitted clustering.
     */
    public static CentroidClustering<double[], double[]> fit(double[][] data, int kmax, int maxIter) {
        return GMeans.fit(data, new Clustering.Options(kmax, maxIter));
    }

    /**
     * Clusters the data, growing the number of clusters until no cluster
     * fails the normality test, {@code options.k()} clusters are reached,
     * or the controller (if any) reports an interruption.
     *
     * @param data    the observations, one row per point; must be non-empty.
     * @param options supplies the maximum number of clusters ({@code k()}),
     *                the iteration cap ({@code maxIter()}), the convergence
     *                tolerance ({@code tol()}) and an optional controller
     *                that receives an {@link AlgoStatus} after each round.
     * @return a clustering named "G-Means" holding the final centroids, the
     *         per-point cluster labels and each point's squared distance to
     *         its assigned centroid.
     * @throws IllegalArgumentException if {@code data} is empty or the
     *         maximum number of clusters is less than 1.
     */
    public static CentroidClustering<double[], double[]> fit(double[][] data, Clustering.Options options) {
        int kmax = options.k();
        int maxIter = options.maxIter();
        double tol = options.tol();
        IterativeAlgorithmController<AlgoStatus> controller = options.controller();

        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }
        if (kmax < 1) {
            throw new IllegalArgumentException("Invalid maximum number of clusters: " + kmax);
        }

        int n = data.length;
        int d = data[0].length;
        // All labels start at 0: a single cluster centered on the column means.
        int[] group = new int[n];
        double[][] sum = new double[kmax][d];
        double[][] centroids = new double[kmax][];
        int[] size = new int[kmax];
        centroids[0] = MathEx.colMeans(data);
        size[0] = n;

        BBDTree bbd = new BBDTree(data);
        // kmeans.get(i) is the tentative 2-way split of cluster i, or null if it was too small.
        ArrayList<CentroidClustering<double[], double[]>> kmeans = new ArrayList<>(kmax);
        ArrayList<double[]> centers = new ArrayList<>();

        int k = 1;
        while (k < kmax) {
            kmeans.clear();
            centers.clear();
            double[] score = new double[k];

            for (int i = 0; i < k; ++i) {
                int ni = size[i];
                if (ni < MIN_SPLIT_SIZE) {
                    logger.info("Cluster {} too small to split: {} observations", i, ni);
                    score[i] = 0.0;
                    kmeans.add(null);
                    continue;
                }

                // Gather the members of cluster i.
                double[][] subset = new double[ni][];
                int l = 0;
                for (int j = 0; j < n; ++j) {
                    if (group[j] == i) {
                        subset[l++] = data[j];
                    }
                }

                CentroidClustering<double[], double[]> clustering =
                        KMeans.fit(subset, new Clustering.Options(2, maxIter, tol, null));
                kmeans.add(clustering);

                score[i] = splitScore(subset, clustering.center(0), clustering.center(1));
                logger.info("Cluster {} Anderson-Darling adjusted test statistic: {}", i, score[i]);
            }

            // The sort reorders score in place; index[i] is the original
            // cluster id of the i-th smallest score.
            int[] index = QuickSort.sort(score);

            // Clusters that pass the normality test are kept unchanged.
            for (int i = 0; i < k; ++i) {
                if (score[i] <= CRITICAL_VALUE) {
                    centers.add(centroids[index[i]]);
                }
            }

            // Split the remaining clusters, worst score first, while the
            // total (already added + still pending) stays within kmax.
            int m = centers.size();
            for (int i = k - 1; i >= 0; --i) {
                if (!(score[i] > CRITICAL_VALUE)) {
                    continue;
                }
                if (centers.size() + i - m + 1 < kmax) {
                    logger.info("Split cluster {}", index[i]);
                    CentroidClustering<double[], double[]> split = kmeans.get(index[i]);
                    centers.add(split.center(0));
                    centers.add(split.center(1));
                } else {
                    centers.add(centroids[index[i]]);
                }
            }

            if (centers.size() == k) {
                logger.info("No more split. Finish with {} clusters", k);
                break;
            }

            k = centers.size();
            // centroids has room for kmax entries, so this fills it in place.
            centers.toArray(centroids);

            // Refine the new centers on the full data until the improvement
            // reported by the tree drops to tol or the iteration cap is hit.
            double diff = Double.MAX_VALUE;
            double distortion = Double.MAX_VALUE;
            for (int iter = 1; iter <= maxIter && diff > tol; ++iter) {
                double wcss = bbd.clustering(k, centroids, sum, size, group);
                diff = distortion - wcss;
                distortion = wcss;
                logger.info("Iteration {}: {}-cluster distortion = {}", iter, k, distortion);
            }

            if (controller != null) {
                controller.submit(new AlgoStatus(k, distortion));
                if (controller.isInterrupted()) {
                    break;
                }
            }
        }

        // Squared distance of every point to the centroid of its cluster.
        // Each point belongs to exactly one cluster, so parallel writes never collide.
        double[] proximity = new double[n];
        IntStream.range(0, k).parallel().forEach(cluster -> {
            double[] centroid = centroids[cluster];
            for (int i = 0; i < n; ++i) {
                if (group[i] == cluster) {
                    proximity[i] = MathEx.squaredDistance(data[i], centroid);
                }
            }
        });

        EuclideanDistance distance = new EuclideanDistance();
        return new CentroidClustering<double[], double[]>(
                "G-Means", Arrays.copyOf(centroids, k), distance, group, proximity);
    }

    /**
     * Scores how non-Gaussian a cluster looks along the direction joining
     * the two centers of its tentative split.
     *
     * @param subset the members of the cluster.
     * @param c0     the first child center.
     * @param c1     the second child center.
     * @return the adjusted Anderson-Darling statistic of the standardized
     *         projections, or 0 when no usable statistic can be computed.
     */
    private static double splitScore(double[][] subset, double[] c0, double[] c1) {
        int d = c0.length;
        double[] v = new double[d];
        for (int j = 0; j < d; ++j) {
            v[j] = c0[j] - c1[j];
        }

        double vp = MathEx.dot(v, v);
        // Coincident (or non-finite) child centers give no split direction.
        // Dividing by vp would yield NaN, and a NaN score fails both the
        // "keep" and the "split" comparison, silently dropping the cluster.
        if (!(vp > 0.0)) {
            return 0.0;
        }

        double[] x = new double[subset.length];
        for (int j = 0; j < x.length; ++j) {
            x[j] = MathEx.dot(subset[j], v) / vp;
        }
        MathEx.standardize(x);

        double a = GMeans.AndersonDarling(x);
        // Same reasoning as above: never let a NaN reach the comparisons.
        return Double.isNaN(a) ? 0.0 : a;
    }

    /**
     * Computes the adjusted Anderson-Darling statistic of a sample against
     * the Gaussian returned by {@code GaussianDistribution.getInstance()}.
     * The input array is sorted and overwritten with CDF values.
     *
     * @param x the (standardized) sample; modified in place.
     * @return the statistic scaled by {@code 1 + 4/n - 25/n^2}.
     */
    private static double AndersonDarling(double[] x) {
        int n = x.length;
        GaussianDistribution gaussian = GaussianDistribution.getInstance();
        Arrays.sort(x);
        for (int i = 0; i < n; ++i) {
            double p = gaussian.cdf(x[i]);
            // Keep p strictly inside (0, 1) so both logarithms below stay finite.
            if (p == 0.0) {
                p = 1.0E-7;
            } else if (p == 1.0) {
                p = 0.9999999;
            }
            x[i] = p;
        }

        double A = 0.0;
        for (int i = 0; i < n; ++i) {
            A -= (double) (2 * i + 1) * (Math.log(x[i]) + Math.log(1.0 - x[n - i - 1]));
        }
        A = A / (double) n - (double) n;

        // Multiply in double: n * n as int overflows for n > 46340 and is
        // exactly 0 at n = 65536, which would make the factor -Infinity.
        double nn = (double) n * (double) n;
        return A * (1.0 + 4.0 / (double) n - 25.0 / nn);
    }
}

