/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.Arrays;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.CentroidClustering;
import smile.clustering.Clustering;
import smile.math.MathEx;
import smile.math.distance.HammingDistance;
import smile.util.AlgoStatus;
import smile.util.IntSet;
import smile.util.IterativeAlgorithmController;

/**
 * K-Modes clustering for categorical data whose levels are encoded as
 * integers. Dissimilarity between a sample and a centroid is measured with
 * {@link HammingDistance}, and each centroid attribute is the most frequent
 * level of that attribute among the cluster's members.
 */
public class KModes {
    private static final Logger logger = LoggerFactory.getLogger(KModes.class);

    /** Not instantiable: this class only exposes static {@code fit} methods. */
    private KModes() {
    }

    /**
     * Fits K-Modes using options built from {@code k} and {@code maxIter}.
     *
     * @param data the samples, one row per sample and one column per attribute.
     * @param k the number of clusters.
     * @param maxIter the maximum number of iterations.
     * @return the fitted clustering.
     * @throws IllegalArgumentException if {@code data} has no rows.
     */
    public static CentroidClustering<int[], int[]> fit(int[][] data, int k, int maxIter) {
        return fit(data, new Clustering.Options(k, maxIter));
    }

    /**
     * Fits K-Modes. Starting from the seeds chosen by
     * {@code CentroidClustering.init}, the method alternates between
     * recomputing the cluster modes and reassigning samples. It stops when the
     * decrease in distortion is no longer greater than {@code options.tol()},
     * when {@code options.maxIter()} iterations have run, or when the optional
     * controller reports an interrupt.
     *
     * @param data the samples, one row per sample and one column per attribute.
     * @param options the hyperparameters (k, maxIter, tol, optional controller).
     * @return the fitted clustering.
     * @throws IllegalArgumentException if {@code data} has no rows.
     */
    public static CentroidClustering<int[], int[]> fit(int[][] data, Clustering.Options options) {
        if (data.length == 0) {
            // Fail with a clear message instead of an index error on data[0].
            throw new IllegalArgumentException("Empty data: at least one sample is required");
        }

        int k = options.k();
        int maxIter = options.maxIter();
        double tol = options.tol();
        IterativeAlgorithmController<AlgoStatus> controller = options.controller();
        int n = data.length;
        int d = data[0].length;

        // One codec per attribute. Each works on its own copy of the column,
        // so the caller's data is never modified by the re-encoding.
        Codec[] codec = IntStream.range(0, d).parallel().mapToObj(j -> {
            int[] column = new int[n];
            for (int i = 0; i < n; ++i) {
                column[i] = data[i][j];
            }
            return new Codec(column);
        }).toArray(Codec[]::new);

        CentroidClustering<int[], int[]> clustering = CentroidClustering.init("K-Modes", data, k, new HammingDistance());
        double distortion = clustering.distortion();
        logger.info("Initial distortion = {}", distortion);

        // Large sentinel so that the first iteration always runs.
        double diff = Integer.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; ++iter) {
            updateCentroids(clustering, data, codec);
            clustering = clustering.assign(data);
            double current = clustering.distortion();
            diff = distortion - current;
            distortion = current;
            logger.info("Iteration {}: distortion = {}", iter, distortion);

            if (controller != null) {
                controller.submit(new AlgoStatus(iter, distortion));
                if (controller.isInterrupted()) {
                    break;
                }
            }
        }

        // The last assignment changed the partition, so the stored centroids
        // still describe the previous one; refresh them before returning.
        if (diff > 0.0) {
            updateCentroids(clustering, data, codec);
        }
        return clustering;
    }

    /**
     * Replaces every centroid of {@code clustering} with the per-attribute mode
     * of the samples currently assigned to that cluster. The centroid array
     * returned by {@code clustering.centers()} is updated in place, with a
     * freshly allocated row per cluster.
     *
     * @param clustering the clustering whose centroids are recomputed.
     * @param data the samples; only the number of rows is read here.
     * @param codec the per-attribute codecs holding the encoded columns.
     */
    private static void updateCentroids(CentroidClustering<int[], int[]> clustering, int[][] data, Codec[] codec) {
        int n = data.length;
        int[] group = clustering.group();
        int[][] centroids = clustering.centers();
        int k = centroids.length;
        int d = centroids[0].length;

        IntStream.range(0, k).parallel().forEach(cluster -> {
            // New array so the previous centroid row (possibly a seed sample) is not modified.
            int[] centroid = new int[d];
            for (int j = 0; j < d; ++j) {
                if (codec[j].k <= 1) {
                    // Constant column: its only level is the mode. Leaving the
                    // default 0 here would add a spurious mismatch for every
                    // sample whenever the constant is non-zero.
                    centroid[j] = codec[j].valueOf(0);
                    continue;
                }

                int[] count = new int[codec[j].k];
                int[] x = codec[j].x;
                for (int i = 0; i < n; ++i) {
                    if (group[i] == cluster) {
                        count[x[i]]++;
                    }
                }
                centroid[j] = codec[j].valueOf(MathEx.whichMax(count));
            }
            centroids[cluster] = centroid;
        });
    }

    /**
     * Encodes the levels of one attribute as indices into the sorted array of
     * its distinct values, so that level frequencies can be counted in a plain
     * {@code int[]} of length {@link #k}.
     */
    private static class Codec {
        /** The number of distinct values in the column. */
        public final int k;
        /** The column values; rewritten to level indices when re-encoding is needed. */
        public final int[] x;
        /** Maps between original values and their indices. */
        public final IntSet encoder;

        /**
         * Builds the codec for a column.
         *
         * @param x the column values; must be non-empty. The array is kept by
         *          reference and may be overwritten in place with level indices.
         */
        public Codec(int[] x) {
            int[] y = MathEx.unique(x);
            Arrays.sort(y);
            this.x = x;
            this.k = y.length;
            this.encoder = new IntSet(y);

            // Sorted distinct values running from 0 to k-1 are exactly
            // 0, 1, ..., k-1, so the column is already its own encoding.
            boolean alreadyEncoded = y[0] == 0 && y[this.k - 1] == this.k - 1;
            if (!alreadyEncoded) {
                for (int i = 0; i < x.length; ++i) {
                    x[i] = this.encoder.indexOf(x[i]);
                }
            }
        }

        /**
         * Returns the original attribute value for an encoded level index.
         *
         * @param i the level index.
         * @return the original value.
         */
        public int valueOf(int i) {
            return this.encoder.valueOf(i);
        }
    }
}

