/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.BBDTree;
import smile.clustering.CentroidClustering;
import smile.math.MathEx;

/**
 * K-means clustering whose point-to-centroid distance is the squared Euclidean distance
 * ({@link MathEx#squaredDistance}).
 *
 * <p>Two training entry points are provided:
 * <ul>
 *   <li>{@link #fit} iterates with {@link BBDTree#clustering}, which updates the centroids,
 *       cluster sizes and labels in place and returns the within-cluster sum of squares;</li>
 *   <li>{@link #lloyd} alternates a centroid update and an assignment step using the
 *       missing-value-aware distance {@link MathEx#squaredDistanceWithMissingValues}.</li>
 * </ul>
 * Both start from {@code seed(...)} (inherited; presumably k-means++ style seeding — see the
 * superclass) and stop after {@code maxIter} iterations or once the distortion decreases by no
 * more than {@code tol} in one iteration.
 */
public class KMeans extends CentroidClustering<double[], double[]> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(KMeans.class);

    /** Maximum number of iterations used by the short-form factory methods. */
    private static final int DEFAULT_MAX_ITER = 100;
    /** Convergence tolerance used by the short-form factory methods. */
    private static final double DEFAULT_TOL = 1.0E-4;

    /**
     * Constructor.
     *
     * @param distortion the total distortion reported by the final training iteration.
     * @param centroids the cluster centroids, one row per cluster.
     * @param y the cluster label of each training observation.
     */
    public KMeans(double distortion, double[][] centroids, int[] y) {
        // T = double[] in the superclass, so double[][] is already T[]; no cast is needed.
        super(distortion, centroids, y);
    }

    /**
     * Returns the squared Euclidean distance between two vectors.
     *
     * @param x a vector.
     * @param y another vector.
     * @return the squared Euclidean distance between {@code x} and {@code y}.
     */
    @Override
    public double distance(double[] x, double[] y) {
        return MathEx.squaredDistance(x, y);
    }

    /**
     * Clusters the data with at most 100 iterations and tolerance {@code 1e-4}.
     *
     * @param data the observations, one row per observation.
     * @param k the number of clusters; must be at least 2.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 2} or {@code data} is empty.
     */
    public static KMeans fit(double[][] data, int k) {
        return fit(data, k, DEFAULT_MAX_ITER, DEFAULT_TOL);
    }

    /**
     * Clusters the data, building a {@link BBDTree} over it first.
     *
     * @param data the observations, one row per observation.
     * @param k the number of clusters; must be at least 2.
     * @param maxIter the maximum number of iterations; must be positive.
     * @param tol iteration stops once the distortion decreases by no more than this amount.
     * @return the fitted model.
     * @throws IllegalArgumentException if any argument is invalid.
     */
    public static KMeans fit(double[][] data, int k, int maxIter, double tol) {
        // Validate before the tree is built so invalid arguments do not pay for construction.
        checkArguments(data, k, maxIter, tol);
        return fit(new BBDTree(data), data, k, maxIter, tol);
    }

    /**
     * Clusters the data using a caller-supplied {@link BBDTree}.
     *
     * @param bbd the BBD tree to iterate with; presumably it must have been built from
     *            {@code data} (as {@link #fit(double[][], int, int, double)} does) — confirm
     *            before passing a tree built elsewhere.
     * @param data the observations, one row per observation.
     * @param k the number of clusters; must be at least 2.
     * @param maxIter the maximum number of iterations; must be positive.
     * @param tol iteration stops once the distortion decreases by no more than this amount.
     * @return the fitted model.
     * @throws IllegalArgumentException if any argument is invalid.
     */
    public static KMeans fit(BBDTree bbd, double[][] data, int k, int maxIter, double tol) {
        checkArguments(data, k, maxIter, tol);

        int n = data.length;
        int d = data[0].length;
        int[] y = new int[n];
        double[][] medoids = new double[k][];

        // Seeding fills y with initial labels; its per-point values sum to the initial distortion.
        double distortion = MathEx.sum(seed(data, medoids, y, MathEx::squaredDistance));
        logger.info(String.format("Distortion after initialization: %.4f", distortion));

        // Turn the seeded labels into initial centroids (and cluster sizes).
        int[] size = new int[k];
        double[][] centroids = new double[k][d];
        updateCentroids(centroids, data, y, size);

        // Scratch buffer handed to the tree on every iteration.
        double[][] sum = new double[k][d];
        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; ++iter) {
            // Updates centroids, size and y in place; returns the new within-cluster SS.
            double wcss = bbd.clustering(centroids, sum, size, y);
            logger.info(String.format("Distortion after %3d iterations: %.4f", iter, wcss));
            diff = distortion - wcss;
            distortion = wcss;
        }
        return new KMeans(distortion, centroids, y);
    }

    /**
     * Clusters the data with Lloyd's alternating algorithm (at most 100 iterations, tolerance
     * {@code 1e-4}), using distances that tolerate missing values.
     *
     * @param data the observations, one row per observation.
     * @param k the number of clusters; must be at least 2.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 2} or {@code data} is empty.
     */
    public static KMeans lloyd(double[][] data, int k) {
        return lloyd(data, k, DEFAULT_MAX_ITER, DEFAULT_TOL);
    }

    /**
     * Clusters the data with Lloyd's alternating algorithm, using
     * {@link MathEx#squaredDistanceWithMissingValues} for both training and the returned model.
     *
     * @param data the observations, one row per observation; missing entries are presumably
     *             encoded as {@code NaN} (inferred from the {@code notNaN} bookkeeping — confirm).
     * @param k the number of clusters; must be at least 2.
     * @param maxIter the maximum number of iterations; must be positive.
     * @param tol iteration stops once the distortion decreases by no more than this amount.
     * @return the fitted model, whose {@code distance} uses the missing-value-aware metric.
     * @throws IllegalArgumentException if any argument is invalid.
     */
    public static KMeans lloyd(double[][] data, int k, int maxIter, double tol) {
        checkArguments(data, k, maxIter, tol);

        int n = data.length;
        int d = data[0].length;
        int[] y = new int[n];
        double[][] medoids = new double[k][];

        double distortion = MathEx.sum(seed(data, medoids, y, MathEx::squaredDistanceWithMissingValues));
        logger.info(String.format("Distortion after initialization: %.4f", distortion));

        int[] size = new int[k];
        double[][] centroids = new double[k][d];
        // Per-cluster, per-dimension counters used by the missing-value centroid update.
        int[][] notNaN = new int[k][d];
        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; ++iter) {
            // Update step: recompute centroids from the current labels.
            updateCentroidsWithMissingValues(centroids, data, y, size, notNaN);
            // Assignment step: relabel every point and measure the resulting distortion.
            double wcss = assign(y, data, centroids, MathEx::squaredDistanceWithMissingValues);
            logger.info(String.format("Distortion after %3d iterations: %.4f", iter, wcss));
            diff = distortion - wcss;
            distortion = wcss;
        }
        // diff > tol here means the loop stopped on maxIter rather than convergence: the last
        // assignment changed y after the centroids were computed, so bring them back in line.
        // NOTE(review): the reported distortion is not recomputed for these refreshed centroids.
        if (diff > tol) {
            updateCentroidsWithMissingValues(centroids, data, y, size, notNaN);
        }
        // NOTE(review): an anonymous Serializable subclass has a compiler-chosen name (KMeans$1)
        // and no explicit serialVersionUID; kept unchanged so existing serialized models load.
        return new KMeans(distortion, centroids, y) {
            @Override
            public double distance(double[] x, double[] y) {
                return MathEx.squaredDistanceWithMissingValues(x, y);
            }
        };
    }

    /**
     * Validates the arguments shared by every training entry point.
     *
     * @throws IllegalArgumentException if {@code k < 2}, {@code maxIter <= 0}, {@code data} is
     *         empty, or {@code tol} is NaN (which would skip every iteration).
     */
    private static void checkArguments(double[][] data, int k, int maxIter, double tol) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }
        if (Double.isNaN(tol)) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
    }
}

