/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.lang.reflect.Array;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.CentroidClustering;
import smile.math.MathEx;
import smile.math.distance.Distance;

/**
 * CLARANS (Clustering Large Applications based upon RANdomized Search), a
 * k-medoids style clustering that improves its medoids by randomized local search.
 *
 * <p>Each search step replaces one randomly chosen medoid with a randomly chosen
 * sample that is not currently a medoid. It then reassigns the affected samples and
 * keeps the swap only if the total distortion (the sum over all samples of the
 * distance to their assigned medoid) decreases. The search ends once
 * {@code maxNeighbor} consecutive swaps have failed to improve the distortion.
 *
 * @param <T> the type of the samples; medoids are drawn from the samples themselves.
 */
public class CLARANS<T> extends CentroidClustering<T, T> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(CLARANS.class);

    /** The distance measure between samples (and between samples and medoids). */
    private final Distance<T> distance;

    /**
     * Constructor.
     *
     * @param distortion the total distortion of the clustering.
     * @param medoids the medoid of each cluster.
     * @param y the cluster label of each sample.
     * @param distance the distance measure.
     */
    public CLARANS(double distortion, T[] medoids, int[] y, Distance<T> distance) {
        super(distortion, medoids, y);
        this.distance = distance;
    }

    @Override
    public double distance(T x, T y) {
        return this.distance.d(x, y);
    }

    /**
     * Clusters the data with a default search budget. The budget is 1.25% of
     * {@code k * (n - k)} random neighbors, clamped to {@code [1, n]} so that it is
     * always accepted by {@link #fit(Object[], int, int, Distance)}.
     *
     * @param data the samples.
     * @param k the number of clusters.
     * @param distance the distance measure.
     * @return the clustering.
     * @throws IllegalArgumentException if {@code k} is not in {@code [1, data.length)}.
     */
    public static <T> CLARANS<T> fit(T[] data, int k, Distance<T> distance) {
        int n = data.length;
        long heuristic = Math.round(0.0125 * k * (double) (n - k));
        // Without clamping, the raw heuristic rounds to 0 for small k * (n - k)
        // (e.g. k = 2, n = 10) and exceeds n for large k (e.g. k = 100, n = 1000).
        // Either value would be rejected by the 4-argument fit.
        int maxNeighbor = (int) Math.max(1L, Math.min((long) n, heuristic));
        return fit(data, k, maxNeighbor, distance);
    }

    /**
     * Clusters the data by randomized medoid search.
     *
     * @param data the samples.
     * @param k the number of clusters.
     * @param maxNeighbor the number of consecutive non-improving random neighbors
     *     after which the search stops. Values below {@code min(100, k * (n - k))}
     *     are raised to that minimum.
     * @param distance the distance measure.
     * @return the clustering.
     * @throws IllegalArgumentException if {@code maxNeighbor <= 0} or
     *     {@code maxNeighbor > data.length}, or if {@code k} is not in
     *     {@code [1, data.length)}.
     */
    public static <T> CLARANS<T> fit(T[] data, int k, int maxNeighbor, Distance<T> distance) {
        if (maxNeighbor <= 0) {
            throw new IllegalArgumentException("Invalid maxNeighbors: " + maxNeighbor);
        }
        int n = data.length;
        if (k <= 0) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }
        if (k >= n) {
            throw new IllegalArgumentException("Too large k: " + k);
        }
        if (maxNeighbor > n) {
            throw new IllegalArgumentException("Too large maxNeighbor: " + maxNeighbor);
        }

        // Enforce a minimum search budget of min(100, k * (n - k)). The product is
        // computed in long because it can overflow int for large n.
        long neighborhoodSize = (long) k * (n - k);
        int minmax = (int) Math.min(100L, neighborhoodSize);
        maxNeighbor = Math.max(maxNeighbor, minmax);

        @SuppressWarnings("unchecked") // the array's runtime component type is exactly that of data
        T[] medoids = (T[]) Array.newInstance(data.getClass().getComponentType(), k);
        int[] y = new int[n];
        double[] d = seed(data, medoids, y, distance);
        double distortion = MathEx.sum(d);

        // Working copies on which each random neighbor is tried. They are either
        // promoted to the current state or rolled back from it.
        T[] newMedoids = medoids.clone();
        int[] newY = y.clone();
        double[] newD = d.clone();

        for (int neighborCount = 1; neighborCount <= maxNeighbor; ++neighborCount) {
            double neighborDistortion = getRandomNeighbor(data, newMedoids, newY, newD, distance);
            if (neighborDistortion < distortion) {
                if (logger.isInfoEnabled()) {
                    logger.info(String.format("Distortion reduces to %.4f after %3d random neighbors", neighborDistortion, neighborCount));
                }
                // Accept the swap and restart the count of consecutive failures
                // (the loop increment brings it back to 1).
                neighborCount = 0;
                distortion = neighborDistortion;
                System.arraycopy(newMedoids, 0, medoids, 0, k);
                System.arraycopy(newY, 0, y, 0, n);
                System.arraycopy(newD, 0, d, 0, n);
            } else {
                // Reject the swap: roll the working copies back to the current state.
                System.arraycopy(medoids, 0, newMedoids, 0, k);
                System.arraycopy(y, 0, newY, 0, n);
                System.arraycopy(d, 0, newD, 0, n);
            }
        }

        if (logger.isInfoEnabled()) {
            logger.info(String.format("Final distortion: %.4f", distortion));
        }
        return new CLARANS<>(distortion, medoids, y, distance);
    }

    /**
     * Moves to a random neighbor of the current medoid set by replacing one random
     * medoid with a random non-medoid sample. Labels and distances are updated in place.
     *
     * @param data the samples.
     * @param medoids the medoids; one entry is overwritten.
     * @param y the cluster label of each sample; updated in place.
     * @param d the distance of each sample to its assigned medoid; updated in place.
     * @param distance the distance measure.
     * @return the total distortion (sum of {@code d}) after the swap.
     */
    private static <T> double getRandomNeighbor(T[] data, T[] medoids, int[] y, double[] d, ToDoubleBiFunction<T, T> distance) {
        int n = data.length;
        int k = medoids.length;
        int cluster = MathEx.randomInt(k);
        T medoid = getRandomMedoid(data, medoids);
        medoids[cluster] = medoid;

        // Each task writes only y[i] and d[i] and reads medoids, which is not
        // modified during the stream, so the parallel loop has no shared writes.
        IntStream.range(0, n).parallel().forEach(i -> {
            double dist = distance.applyAsDouble(data[i], medoid);
            if (d[i] > dist) {
                // The new medoid is strictly closer than the current assignment.
                y[i] = cluster;
                d[i] = dist;
            } else if (y[i] == cluster) {
                // The sample's medoid was replaced by one that is no closer, so
                // search the remaining medoids for a better assignment.
                d[i] = dist;
                for (int j = 0; j < k; ++j) {
                    if (j == cluster) {
                        continue;
                    }
                    double dj = distance.applyAsDouble(data[i], medoids[j]);
                    if (d[i] > dj) {
                        d[i] = dj;
                        y[i] = j;
                    }
                }
            }
        });
        return MathEx.sum(d);
    }

    /**
     * Draws uniformly random samples until one is found that is not (by reference)
     * one of the current medoids.
     *
     * <p>NOTE: this never terminates if every element of {@code data} is
     * reference-identical to some medoid. That can happen when {@code data} repeats
     * the same object reference many times.
     *
     * @param data the samples.
     * @param medoids the current medoids.
     * @return a sample that is not currently a medoid.
     */
    private static <T> T getRandomMedoid(T[] data, T[] medoids) {
        int n = data.length;
        T medoid = data[MathEx.randomInt(n)];
        while (contains(medoids, medoid)) {
            medoid = data[MathEx.randomInt(n)];
        }
        return medoid;
    }

    /**
     * Tests whether {@code medoid} is one of {@code medoids}, using reference
     * identity ({@code ==}) rather than {@code equals}. Medoids are taken directly
     * from the data array, so identity is enough to recognize them.
     *
     * @param medoids the current medoids.
     * @param medoid the candidate sample.
     * @return true if the same object reference is already a medoid.
     */
    private static <T> boolean contains(T[] medoids, T medoid) {
        for (T m : medoids) {
            if (m == medoid) {
                return true;
            }
        }
        return false;
    }
}

