/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.CentroidClustering;
import smile.clustering.Clustering;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.util.AlgoStatus;
import smile.util.IterativeAlgorithmController;

/**
 * K-medoids clustering by randomized medoid swapping.
 *
 * <p>Each restart begins from the clustering produced by
 * {@code CentroidClustering.init} and then repeatedly replaces one randomly
 * chosen medoid with a randomly chosen data point, keeping the swap only when
 * it lowers the mean of the per-point values tracked in the proximity array.
 * The best clustering over all restarts is returned.
 *
 * <p>This is a static utility holder and cannot be instantiated. The class
 * type parameter is unused by the static methods, which declare their own.
 *
 * @param <T> the data type of the objects being clustered.
 */
public class KMedoids<T> {
    private static final Logger logger = LoggerFactory.getLogger(KMedoids.class);

    /** Not instantiable: all functionality is exposed through static methods. */
    private KMedoids() {
    }

    /**
     * Clusters the data into {@code k} groups with default settings.
     *
     * <p>Delegates to {@link #fit(Object[], Distance, Clustering.Options)} with
     * {@code new Clustering.Options(k, 2, 0.0125, null)}; presumably these are
     * (k, maxIter, tol, controller) -- TODO confirm against
     * {@code Clustering.Options}.
     *
     * @param data the objects to cluster.
     * @param distance the dissimilarity measure between two objects.
     * @param k the number of clusters; must be smaller than {@code data.length}.
     * @param <T> the data type.
     * @return the fitted clustering.
     * @throws IllegalArgumentException under the same conditions as the
     *         options-based overload.
     */
    public static <T> CentroidClustering<T, T> fit(T[] data, Distance<T> distance, int k) {
        return fit(data, distance, new Clustering.Options(k, 2, 0.0125, null));
    }

    /**
     * Clusters the data with the given options.
     *
     * <p>The number of restarts is {@code min(3, options.maxIter())}. Within a
     * restart, the number of consecutive non-improving random swaps tolerated
     * is {@code max(min(100, k(n-k)), round(tol * k * (n-k)))}; an improving
     * swap resets that counter. After every restart the controller, if any,
     * receives an {@code AlgoStatus(restart, distortion)} and may interrupt
     * the remaining restarts.
     *
     * @param data the objects to cluster.
     * @param distance the dissimilarity measure between two objects.
     * @param options the hyperparameters; {@code k()}, {@code maxIter()},
     *                {@code tol()} and {@code controller()} are read.
     * @param <T> the data type.
     * @return the clustering with the lowest distortion over all restarts, or
     *         {@code null} if no restart yielded a distortion below
     *         {@code Double.MAX_VALUE} (for example when
     *         {@code options.maxIter() < 1}, so that no restart is run).
     * @throws IllegalArgumentException if {@code k >= data.length} or if the
     *         computed neighbor limit exceeds {@code data.length}.
     */
    public static <T> CentroidClustering<T, T> fit(T[] data, Distance<T> distance, Clustering.Options options) {
        int n = data.length;
        int k = options.k();
        if (k >= n) {
            throw new IllegalArgumentException("Too large k: " + k);
        }

        IterativeAlgorithmController<AlgoStatus> controller = options.controller();
        int numLocal = Math.min(3, options.maxIter());

        // Work in long: k * (n - k) overflows int for large inputs, and
        // Math.round returns a long that must not be narrowed before the
        // range check below, otherwise an oversized value could wrap around
        // and slip through.
        long neighborLimit = Math.round(options.tol() * (double) k * (double) (n - k));
        long floor = Math.min(100L, (long) k * (n - k));
        neighborLimit = Math.max(floor, neighborLimit);
        if (neighborLimit > n) {
            throw new IllegalArgumentException("Too large maxNeighbor: " + neighborLimit);
        }
        // Safe narrowing: neighborLimit <= n <= Integer.MAX_VALUE here.
        int maxNeighbor = (int) neighborLimit;

        double best = Double.MAX_VALUE;
        CentroidClustering<T, T> result = null;
        for (int iter = 1; iter <= numLocal; ++iter) {
            CentroidClustering<T, T> clustering = CentroidClustering.init("K-Medoids", data, k, distance);
            // Current accepted state of this restart; updated in place below.
            T[] medoids = clustering.centers();
            double distortion = clustering.distortion();
            int[] group = clustering.group();
            double[] proximity = clustering.proximity();

            // Scratch copies that hold the candidate state of one random swap.
            T[] centers = medoids.clone();
            int[] y = new int[n];
            double[] d = new double[n];

            for (int neighborCount = 1; neighborCount <= maxNeighbor; ++neighborCount) {
                // Start every candidate from the currently accepted state.
                System.arraycopy(medoids, 0, centers, 0, k);
                System.arraycopy(group, 0, y, 0, n);
                System.arraycopy(proximity, 0, d, 0, n);

                double loss = randomSearch(data, centers, y, d, distance);
                if (loss < distortion) {
                    // Accept the swap and promote the candidate state.
                    System.arraycopy(centers, 0, medoids, 0, k);
                    System.arraycopy(y, 0, group, 0, n);
                    System.arraycopy(d, 0, proximity, 0, n);
                    distortion = loss;
                    logger.info("Iteration {}: random search = {}, distortion = {} ", iter, neighborCount, distortion);
                    // Restart the count of consecutive failures; the loop
                    // increment brings it back to 1 for the next attempt.
                    neighborCount = 0;
                }
            }

            if (distortion < best) {
                best = distortion;
                result = new CentroidClustering<>("K-Medoids", medoids, distance, group, proximity);
            }

            if (controller != null) {
                controller.submit(new AlgoStatus(iter, distortion));
                if (controller.isInterrupted()) {
                    break;
                }
            }
        }
        return result;
    }

    /**
     * Tries one random swap: replaces a randomly chosen medoid with a randomly
     * chosen data point that is not currently a medoid, then updates the
     * assignments in place.
     *
     * <p>For each point, the squared distance to the new medoid is computed.
     * If it is smaller than {@code d[i]}, the point moves to the swapped
     * cluster. Otherwise, if the point belonged to the swapped cluster, it is
     * reassigned to whichever medoid (new one included) has the smallest
     * squared distance. Points are processed by a parallel stream; each
     * index only touches its own {@code y[i]} and {@code d[i]}.
     *
     * @param data the objects being clustered.
     * @param medoids the candidate medoids; one entry is overwritten.
     * @param y the cluster label of each point; updated in place.
     * @param d the value compared against squared distances for each point
     *          (presumably the squared distance to its medoid -- TODO confirm
     *          against {@code CentroidClustering.proximity()}); updated in place.
     * @param distance the dissimilarity measure.
     * @param <T> the data type.
     * @return the mean of {@code d} after the swap.
     */
    private static <T> double randomSearch(T[] data, T[] medoids, int[] y, double[] d, Distance<T> distance) {
        int n = data.length;
        int k = medoids.length;

        // Draw the cluster index before the replacement point so the order of
        // random draws stays the same as before.
        int cluster = MathEx.randomInt(k);
        T medoid = getRandomMedoid(data, medoids);
        medoids[cluster] = medoid;

        IntStream.range(0, n).parallel().forEach(i -> {
            double dist = distance.applyAsDouble(data[i], medoid);
            dist *= dist;
            if (d[i] > dist) {
                // The new medoid is closer than the current one.
                y[i] = cluster;
                d[i] = dist;
            } else if (y[i] == cluster) {
                // The point lost its medoid: start from the new medoid and
                // look for a closer one among the others.
                d[i] = dist;
                for (int j = 0; j < k; ++j) {
                    if (j == cluster) {
                        continue;
                    }
                    double other = distance.applyAsDouble(data[i], medoids[j]);
                    other *= other;
                    if (d[i] > other) {
                        d[i] = other;
                        y[i] = j;
                    }
                }
            }
        });

        return MathEx.mean(d);
    }

    /**
     * Picks a random element of {@code data} that
     * {@code CentroidClustering.contains} does not report as being among the
     * current medoids, redrawing until one is found.
     *
     * @param data the objects being clustered.
     * @param medoids the current medoids.
     * @param <T> the data type.
     * @return a randomly selected non-medoid data point.
     */
    private static <T> T getRandomMedoid(T[] data, T[] medoids) {
        int n = data.length;
        T medoid = data[MathEx.randomInt(n)];
        while (CentroidClustering.contains(medoid, medoids)) {
            medoid = data[MathEx.randomInt(n)];
        }
        return medoid;
    }
}

