/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.Arrays;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.BBDTree;
import smile.clustering.CentroidClustering;
import smile.clustering.Clustering;
import smile.math.MathEx;
import smile.math.distance.EuclideanDistance;
import smile.util.AlgoStatus;
import smile.util.IterativeAlgorithmController;

/**
 * Static entry points for K-Means clustering of dense {@code double[]} vectors.
 *
 * <p>Two fitting strategies are offered:
 * <ul>
 *   <li>{@code fit(...)}: iterates through {@link BBDTree#clustering}, using a
 *       {@link EuclideanDistance};</li>
 *   <li>{@code lloyd(...)}: alternates centroid updates and re-assignment, using
 *       {@link MathEx#distanceWithMissingValues} and skipping {@code NaN} entries
 *       when averaging.</li>
 * </ul>
 *
 * <p>All methods read {@code data[0].length} to obtain the dimensionality, so
 * {@code data} must contain at least one row.
 */
public class KMeans {
    private static final Logger logger = LoggerFactory.getLogger(KMeans.class);

    /** Not instantiable: this class only exposes static methods. */
    private KMeans() {
    }

    /**
     * Fits K-Means with the given number of clusters and iteration cap.
     * Delegates to {@link #fit(double[][], Clustering.Options)} with
     * {@code new Clustering.Options(k, maxIter)}.
     *
     * @param data    the input rows.
     * @param k       the number of clusters.
     * @param maxIter the maximum number of iterations.
     * @return the fitted clustering model.
     */
    public static CentroidClustering<double[], double[]> fit(double[][] data, int k, int maxIter) {
        return fit(data, new Clustering.Options(k, maxIter));
    }

    /**
     * Fits K-Means, building a {@link BBDTree} over {@code data} and delegating to
     * {@link #fit(BBDTree, double[][], Clustering.Options)}.
     *
     * @param data    the input rows.
     * @param options the clustering options.
     * @return the fitted clustering model.
     */
    public static CentroidClustering<double[], double[]> fit(double[][] data, Clustering.Options options) {
        return fit(new BBDTree(data), data, options);
    }

    /**
     * Fits K-Means using a caller-supplied {@link BBDTree}.
     *
     * <p>The loop stops when the decrease in distortion between two consecutive
     * iterations is no larger than {@code options.tol()}, when
     * {@code options.maxIter()} iterations have run, or when the optional
     * controller reports an interruption.
     *
     * @param bbd     the tree used for the assignment/accumulation step;
     *                presumably built over the same {@code data} (as the
     *                two-argument overload does) - confirm for other callers.
     * @param data    the input rows.
     * @param options the clustering options (k, maxIter, tol, controller).
     * @return the fitted clustering model, named {@code "K-Means"}.
     */
    public static CentroidClustering<double[], double[]> fit(BBDTree bbd, double[][] data, Clustering.Options options) {
        int k = options.k();
        int maxIter = options.maxIter();
        double tol = options.tol();
        IterativeAlgorithmController<AlgoStatus> controller = options.controller();
        int n = data.length;
        int d = data[0].length;

        ToDoubleBiFunction<double[], double[]> distance = new EuclideanDistance();
        CentroidClustering<double[], double[]> clustering = CentroidClustering.init("K-Means", data, k, distance);
        double distortion = clustering.distortion();
        logger.info("Initial distortion = {}", distortion);

        // These arrays are shared with `clustering`; the helpers below and
        // bbd.clustering() are handed the very same instances.
        int[] size = clustering.size();
        int[] group = clustering.group();
        double[][] centroids = clustering.centers();

        // Turn the initial assignment into mean centroids before iterating.
        updateCentroids(clustering, data);

        double[][] sum = new double[k][d];
        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; ++iter) {
            double wcss = bbd.clustering(k, centroids, sum, size, group);
            diff = distortion - wcss;
            distortion = wcss;

            logger.info("Iteration {}: distortion = {}", iter, distortion);
            if (controller != null) {
                controller.submit(new AlgoStatus(iter, distortion));
                if (controller.isInterrupted()) {
                    break;
                }
            }
        }

        // Stopped before convergence (iteration cap or interruption): recompute
        // the centroids from the latest group labels.
        if (diff > tol) {
            updateCentroids(clustering, data);
        }

        // Store the squared distance of every point to its assigned centroid.
        double[] proximity = clustering.proximity();
        IntStream.range(0, n).parallel().forEach(i -> {
            double dist = distance.applyAsDouble(data[i], centroids[group[i]]);
            proximity[i] = dist * dist;
        });

        return new CentroidClustering<>("K-Means", centroids, distance, group, proximity);
    }

    /**
     * Fits K-Means with Lloyd-style iterations, tolerating {@code NaN} entries.
     * Delegates to {@link #lloyd(double[][], Clustering.Options)} with
     * {@code new Clustering.Options(k, maxIter)}.
     *
     * @param data    the input rows; entries may be {@code NaN}.
     * @param k       the number of clusters.
     * @param maxIter the maximum number of iterations.
     * @return the fitted clustering model.
     */
    public static CentroidClustering<double[], double[]> lloyd(double[][] data, int k, int maxIter) {
        return lloyd(data, new Clustering.Options(k, maxIter));
    }

    /**
     * Fits K-Means by alternating a centroid update that ignores {@code NaN}
     * entries with a re-assignment of the points via
     * {@code CentroidClustering.assign}.
     *
     * <p>The loop stops when the decrease in distortion between two consecutive
     * iterations is no larger than {@code options.tol()}, when
     * {@code options.maxIter()} iterations have run, or when the optional
     * controller reports an interruption.
     *
     * @param data    the input rows; entries may be {@code NaN}.
     * @param options the clustering options (k, maxIter, tol, controller).
     * @return the fitted clustering model.
     */
    public static CentroidClustering<double[], double[]> lloyd(double[][] data, Clustering.Options options) {
        int k = options.k();
        int maxIter = options.maxIter();
        double tol = options.tol();
        IterativeAlgorithmController<AlgoStatus> controller = options.controller();
        int d = data[0].length;

        ToDoubleBiFunction<double[], double[]> distance = MathEx::distanceWithMissingValues;
        CentroidClustering<double[], double[]> clustering = CentroidClustering.init("K-Means", data, k, distance);
        double distortion = clustering.distortion();
        logger.info("Initial distortion = {}", distortion);

        // Scratch counters of non-NaN values per cluster and dimension,
        // allocated once and cleared by the update helper on every call.
        int[][] notNaN = new int[k][d];
        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; ++iter) {
            updateCentroidsWithMissingValues(clustering, data, notNaN);
            clustering = clustering.assign(data);
            diff = distortion - clustering.distortion();
            distortion = clustering.distortion();

            logger.info("Iteration {}: distortion = {}", iter, distortion);
            if (controller != null) {
                controller.submit(new AlgoStatus(iter, distortion));
                if (controller.isInterrupted()) {
                    break;
                }
            }
        }

        // Stopped before convergence (iteration cap or interruption): recompute
        // the centroids from the latest assignment.
        if (diff > tol) {
            updateCentroidsWithMissingValues(clustering, data, notNaN);
        }
        return clustering;
    }

    /**
     * Recomputes every centroid as the mean of the rows currently labelled with
     * that cluster, and refreshes the cluster sizes. The arrays returned by
     * {@code clustering.size()} and {@code clustering.centers()} are modified in
     * place.
     *
     * <p>Clusters are processed in parallel; each task writes only its own
     * {@code size[cluster]} and {@code centroids[cluster]} slot. A cluster with
     * no members ends up with {@code NaN} coordinates (0.0 divided by 0).
     *
     * @param clustering the model whose centers and sizes are updated.
     * @param data       the input rows, indexed like {@code clustering.group()}.
     */
    static void updateCentroids(CentroidClustering<double[], double[]> clustering, double[][] data) {
        int n = data.length;
        int[] size = clustering.size();
        int[] group = clustering.group();
        double[][] centroids = clustering.centers();
        int k = centroids.length;
        int d = centroids[0].length;

        Arrays.fill(size, 0);
        IntStream.range(0, k).parallel().forEach(cluster -> {
            double[] centroid = new double[d];
            for (int i = 0; i < n; ++i) {
                if (group[i] == cluster) {
                    size[cluster]++;
                    for (int j = 0; j < d; ++j) {
                        centroid[j] += data[i][j];
                    }
                }
            }

            for (int j = 0; j < d; ++j) {
                centroid[j] /= size[cluster];
            }
            centroids[cluster] = centroid;
        });
    }

    /**
     * Same as {@link #updateCentroids}, except that {@code NaN} entries are
     * skipped: each coordinate is averaged over the members that actually have
     * a value in that dimension.
     *
     * <p>Clusters are processed in parallel; each task writes only its own
     * {@code size[cluster]}, {@code notNaN[cluster]} and
     * {@code centroids[cluster]} slot. A coordinate with no observed value in a
     * cluster ends up {@code NaN} (0.0 divided by 0).
     *
     * @param clustering the model whose centers and sizes are updated.
     * @param data       the input rows, indexed like {@code clustering.group()}.
     * @param notNaN     scratch buffer with one row per cluster and one column
     *                   per dimension; overwritten with the count of non-NaN
     *                   values.
     */
    static void updateCentroidsWithMissingValues(CentroidClustering<double[], double[]> clustering, double[][] data, int[][] notNaN) {
        int n = data.length;
        int[] size = clustering.size();
        int[] group = clustering.group();
        double[][] centroids = clustering.centers();
        int k = centroids.length;
        int d = centroids[0].length;

        // Reset before counting, as updateCentroids does; otherwise the sizes
        // would pile up on top of whatever the array already holds.
        Arrays.fill(size, 0);
        IntStream.range(0, k).parallel().forEach(cluster -> {
            double[] centroid = new double[d];
            int[] observed = notNaN[cluster];
            Arrays.fill(observed, 0);
            for (int i = 0; i < n; ++i) {
                if (group[i] == cluster) {
                    size[cluster]++;
                    for (int j = 0; j < d; ++j) {
                        double value = data[i][j];
                        if (!Double.isNaN(value)) {
                            centroid[j] += value;
                            observed[j]++;
                        }
                    }
                }
            }

            for (int j = 0; j < d; ++j) {
                centroid[j] /= observed[j];
            }
            centroids[cluster] = centroid;
        });
    }
}

