/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.io.Serializable;
import java.util.Arrays;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.IntStream;
import smile.math.MathEx;

/**
 * The result of a centroid-based clustering: a set of cluster centers plus, for every
 * sample, its cluster label and its proximity to that label's center.
 *
 * <p>With {@code k = centers.length}, the {@code size} and {@code distortions} arrays have
 * {@code k + 1} slots. Slot {@code i < k} describes cluster {@code i}, and slot {@code k}
 * holds the total over all samples.
 *
 * @param name        the clustering algorithm name, printed in the {@link #toString()} header.
 * @param centers     the cluster centers.
 * @param distance    the distance function, called as {@code distance(center, sample)}.
 * @param group       the cluster label of each sample.
 * @param proximity   the proximity of each sample to its cluster center. The values
 *                    produced by {@link #init} and {@code assign} are squared distances.
 * @param size        {@code size[i]} is the number of samples in cluster {@code i};
 *                    {@code size[k]} is the total number of samples.
 * @param distortions {@code distortions[i]} is the mean proximity within cluster {@code i}
 *                    (NaN for an empty cluster); {@code distortions[k]} is the mean over all samples.
 * @param <T> the type of centers.
 * @param <U> the type of samples.
 */
public record CentroidClustering<T, U>(String name, T[] centers, ToDoubleBiFunction<T, U> distance,
                                       int[] group, double[] proximity, int[] size, double[] distortions)
        implements Comparable<CentroidClustering<T, U>>, Serializable {
    private static final long serialVersionUID = 1L;

    /**
     * Creates a clustering and derives the {@code size} and {@code distortions} statistics
     * from the given labels and proximities.
     *
     * @param name      the clustering algorithm name.
     * @param centers   the cluster centers.
     * @param distance  the distance function.
     * @param group     the cluster label of each sample; every label must be in {@code [0, centers.length)}.
     * @param proximity the proximity of each sample to its cluster center.
     */
    public CentroidClustering(String name, T[] centers, ToDoubleBiFunction<T, U> distance, int[] group, double[] proximity) {
        this(name, centers, distance, group, proximity, new int[centers.length + 1], new double[centers.length + 1]);
        tally(group, proximity, this.size, this.distortions, centers.length);
    }

    /**
     * Recomputes per-cluster counts and mean proximities, writing them in place.
     *
     * <p>This runs sequentially on purpose. The counters are shared across samples, so
     * updating them from a parallel stream would race. Slot {@code k} of both output arrays
     * receives the totals. A cluster with no members gets a distortion of {@code 0.0 / 0},
     * i.e. NaN.
     *
     * @param labels the cluster label of each sample.
     * @param prox   the proximity of each sample.
     * @param counts output: per-cluster sizes, length {@code k + 1}.
     * @param means  output: per-cluster mean proximity, length {@code k + 1}.
     * @param k      the number of clusters.
     */
    private static void tally(int[] labels, double[] prox, int[] counts, double[] means, int k) {
        Arrays.fill(counts, 0);
        Arrays.fill(means, 0.0);
        for (int i = 0; i < labels.length; i++) {
            int y = labels[i];
            counts[y]++;
            means[y] += prox[i];
            means[k] += prox[i];
        }
        counts[k] = labels.length;
        for (int i = 0; i <= k; i++) {
            means[i] /= counts[i];
        }
    }

    /**
     * Returns the number of clusters.
     *
     * @return the number of clusters.
     */
    public int k() {
        return centers.length;
    }

    /**
     * Returns the average proximity over all samples.
     *
     * @return {@code distortions[k]}.
     */
    public double distortion() {
        return distortions[centers.length];
    }

    /**
     * Orders clusterings by ascending overall distortion.
     */
    @Override
    public int compareTo(CentroidClustering<T, U> o) {
        return Double.compare(distortion(), o.distortion());
    }

    /**
     * Returns a table of each cluster's size, share of samples, and distortion, followed by a total row.
     */
    @Override
    public String toString() {
        int k = centers.length;
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("%-11s %15s %12s%n", name, "Size (%)", "Distortion"));
        for (int i = 0; i < k; i++) {
            double percent = 100.0 * size[i] / group.length;
            sb.append(String.format("Cluster %-3d %7d (%4.1f%%) %12.4f%n", i + 1, size[i], percent, distortions[i]));
        }
        sb.append(String.format("%-11s %7d (100.%%) %12.4f%n", "Total", group.length, distortions[k]));
        return sb.toString();
    }

    /**
     * Returns the center of the i-th cluster.
     *
     * @param i the cluster index.
     * @return the cluster center.
     */
    public T center(int i) {
        return centers[i];
    }

    /**
     * Returns the cluster label of the i-th sample.
     *
     * @param i the sample index.
     * @return the cluster label.
     */
    public int group(int i) {
        return group[i];
    }

    /**
     * Returns the proximity of the i-th sample to its cluster center.
     *
     * @param i the sample index.
     * @return the proximity.
     */
    public double proximity(int i) {
        return proximity[i];
    }

    /**
     * Returns the number of samples in the i-th cluster.
     *
     * @param i the cluster index; {@code k} returns the total sample count.
     * @return the cluster size.
     */
    public int size(int i) {
        return size[i];
    }

    /**
     * Returns the radius of the i-th cluster.
     *
     * <p>It is the root-mean-square distance of the members to the center, i.e.
     * {@code sqrt(distortions[i])}. This assumes {@code proximity} holds squared distances,
     * which is what {@link #init} and {@code assign} produce. The previous implementation
     * returned {@code size[i]}, the member count, which is not a distance.
     *
     * @param i the cluster index.
     * @return the cluster radius (NaN for an empty cluster).
     */
    public double radius(int i) {
        return Math.sqrt(distortions[i]);
    }

    /**
     * Returns the label of the center nearest to {@code x}.
     *
     * <p>Ties go to the lowest index. If no distance is below {@link Double#MAX_VALUE}
     * (e.g. all are NaN or infinite), the result is 0.
     *
     * @param x the sample.
     * @return the nearest cluster label.
     */
    public int predict(U x) {
        int label = 0;
        double nearest = Double.MAX_VALUE;
        for (int i = 0; i < centers.length; i++) {
            double dist = distance.applyAsDouble(centers[i], x);
            if (dist < nearest) {
                nearest = dist;
                label = i;
            }
        }
        return label;
    }

    /**
     * Assigns every sample to its nearest center and recomputes the statistics.
     *
     * <p>The arrays of this instance ({@code group}, {@code proximity}, {@code size},
     * {@code distortions}) are overwritten in place and shared with the returned object.
     * {@code data.length} must equal {@code group.length}.
     *
     * @param data the samples.
     * @return a clustering with the same centers and refreshed assignments.
     */
    CentroidClustering<T, U> assign(U[] data) {
        int n = data.length;
        int k = centers.length;
        // Each task writes only group[i] and proximity[i] for its own i, so the parallel
        // search is race-free. The shared per-cluster counters are tallied afterwards.
        IntStream.range(0, n).parallel().forEach(i -> {
            int cluster = 0; // same fallback label as predict()
            double nearest = Double.MAX_VALUE;
            for (int j = 0; j < k; j++) {
                double dist = distance.applyAsDouble(centers[j], data[i]);
                if (dist < nearest) {
                    nearest = dist;
                    cluster = j;
                }
            }
            group[i] = cluster;
            proximity[i] = nearest * nearest;
        });
        tally(group, proximity, size, distortions, k);
        return new CentroidClustering<>(name, centers, distance, group, proximity, size, distortions);
    }

    /**
     * Picks {@code k} initial medoids from {@code data} and assigns each sample to its
     * nearest medoid.
     *
     * <p>The first medoid is drawn with {@code MathEx.randomInt}. Each later medoid is drawn
     * from a distribution built from the current squared distances to the nearest chosen
     * medoid, presumably proportional to them (k-means++ seeding). A medoid is never chosen
     * twice; identity is compared by reference.
     *
     * @param name     the clustering algorithm name.
     * @param data     the samples.
     * @param k        the number of clusters, in {@code [1, data.length]}.
     * @param distance the distance function.
     * @param <T>      the sample type.
     * @return the initial clustering, whose proximities are squared distances.
     * @throws IllegalArgumentException if {@code k < 1} or {@code k > data.length}.
     */
    public static <T> CentroidClustering<T, T> init(String name, T[] data, int k, ToDoubleBiFunction<T, T> distance) {
        int n = data.length;
        if (k < 1 || k > n) {
            // Rejected explicitly. k < 1 used to throw ArrayIndexOutOfBoundsException,
            // and k > n used to pad the medoids with nulls and could loop forever below.
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        int[] group = new int[n];
        double[] proximity = new double[n];
        double[] probability = new double[n];
        Arrays.fill(proximity, Double.MAX_VALUE);
        T[] medoids = Arrays.copyOf(data, k);
        medoids[0] = data[MathEx.randomInt(n)];
        for (int j = 1; j <= k; j++) {
            // Fold the most recently chosen medoid into each sample's nearest-medoid state.
            int prev = j - 1;
            T medoid = medoids[prev];
            IntStream.range(0, n).parallel().forEach(i -> {
                double dist = distance.applyAsDouble(data[i], medoid);
                dist *= dist;
                if (dist < proximity[i]) {
                    proximity[i] = dist;
                    group[i] = prev;
                }
            });
            if (j == k) {
                break; // all k medoids placed; this last pass only refreshed group/proximity
            }
            System.arraycopy(proximity, 0, probability, 0, n);
            // NOTE(review): if every proximity is 0 (fewer distinct samples than k), this
            // normalization divides by zero; confirm how MathEx.random handles that case.
            MathEx.unitize1(probability);
            T center = data[MathEx.random(probability)];
            while (contains(center, medoids, j)) {
                center = data[MathEx.random(probability)];
            }
            medoids[j] = center;
        }
        return new CentroidClustering<>(name, medoids, distance, group, proximity);
    }

    /**
     * Returns {@code k} initial centers chosen by {@link #init} with Euclidean distance.
     *
     * <p>The centers are copies, so the caller can mutate them without affecting {@code data}.
     *
     * @param data the samples.
     * @param k    the number of centers.
     * @return {@code k} cloned center vectors.
     */
    public static double[][] seeds(double[][] data, int k) {
        CentroidClustering<double[], double[]> clustering = CentroidClustering.init("K-Means++", data, k, MathEx::distance);
        double[][] medoids = clustering.centers();
        double[][] neurons = new double[k][];
        for (int i = 0; i < k; i++) {
            neurons[i] = medoids[i].clone();
        }
        return neurons;
    }

    /**
     * Tests whether {@code medoid} is one of {@code medoids}, comparing by reference.
     */
    static <T> boolean contains(T medoid, T[] medoids) {
        return contains(medoid, medoids, medoids.length);
    }

    /**
     * Tests whether {@code medoid} is among the first {@code length} entries of
     * {@code medoids}, comparing by reference (not {@code equals}).
     */
    static <T> boolean contains(T medoid, T[] medoids, int length) {
        for (int i = 0; i < length; i++) {
            if (medoids[i] == medoid) {
                return true;
            }
        }
        return false;
    }
}

