/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.io.Serializable;
import java.util.Arrays;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.CentroidClustering;
import smile.clustering.Clustering;
import smile.math.MathEx;
import smile.util.AlgoStatus;
import smile.util.IterativeAlgorithmController;
import smile.util.SparseArray;

/**
 * Sequential Information Bottleneck (SIB) style clustering of sparse vectors.
 * <p>
 * Centroids are dense vectors whose length is one past the largest index that
 * occurs in any input row. Samples are compared to centroids with the
 * Jensen-Shannon divergence ({@link MathEx#JensenShannonDivergence}).
 * <p>
 * Starting from the partition produced by {@link CentroidClustering#init}, each
 * pass visits the samples one at a time. Each sample is moved to the centroid with
 * the smallest divergence, and only the two affected centroids are updated.
 * <p>
 * NOTE(review): Jensen-Shannon divergence is defined on probability
 * distributions, so the rows presumably hold non-negative weights summing to
 * one; confirm against callers.
 */
public class SIB {
    private static final Logger logger = LoggerFactory.getLogger(SIB.class);

    /** Static entry points only; not instantiable. */
    private SIB() {
    }

    /**
     * Convenience overload that clusters with {@code new Clustering.Options(k, maxIter)}.
     *
     * @param data    the sparse input samples.
     * @param k       the number of clusters.
     * @param maxIter the maximum number of reassignment passes.
     * @return the fitted clustering with dense {@code double[]} centroids.
     */
    public static CentroidClustering<double[], SparseArray> fit(SparseArray[] data, int k, int maxIter) {
        return SIB.fit(data, new Clustering.Options(k, maxIter));
    }

    /**
     * Clusters the data.
     * <p>
     * Iteration stops in any of these cases:
     * <ul>
     *   <li>after {@code options.maxIter()} passes;</li>
     *   <li>when the number of samples reassigned in a pass is no greater than
     *       {@code options.tol()};</li>
     *   <li>when the optional controller reports an interruption.</li>
     * </ul>
     *
     * @param data    the sparse input samples.
     * @param options supplies {@code k}, {@code maxIter}, {@code tol} and an
     *                optional (may be {@code null}) controller that receives an
     *                {@link AlgoStatus} after every pass.
     * @return the fitted clustering. The proximity of each sample is the squared
     *         Jensen-Shannon divergence to its centroid, and the logged final
     *         distortion is the mean of those values.
     */
    public static CentroidClustering<double[], SparseArray> fit(SparseArray[] data, Clustering.Options options) {
        int k = options.k();
        int maxIter = options.maxIter();
        double tol = options.tol();
        IterativeAlgorithmController<AlgoStatus> controller = options.controller();
        int n = data.length;
        // Centroid dimension: one past the largest index present in any row.
        int d = 1 + Arrays.stream(data).flatMapToInt(SparseArray::indexStream).max().orElse(0);

        ToDoubleBiFunction<SparseArray, SparseArray> distance = MathEx::JensenShannonDivergence;
        CentroidClustering<SparseArray, SparseArray> clustering = CentroidClustering.init("SIB", data, k, distance);
        logger.info("Initial distortion = {}", clustering.distortion());

        int[] group = clustering.group();
        // The incremental centroid updates in move() need exact member counts.
        // Count from zero rather than accumulating on top of clustering.size(),
        // which would double-count if that array is already populated.
        int[] size = new int[k];
        double[][] centroids = initCentroids(data, group, size, k, d);

        int reassignment = n;
        for (int iter = 1; iter <= maxIter && reassignment > tol; iter++) {
            reassignment = 0;
            for (int i = 0; i < n; i++) {
                int c = nearest(data[i], centroids, group[i]);
                if (c != group[i]) {
                    move(data[i], group[i], c, size, centroids);
                    group[i] = c;
                    reassignment++;
                }
            }

            logger.info("Iteration {}: assignments = {}", iter, reassignment);
            if (controller != null) {
                controller.submit(new AlgoStatus(iter, (double) reassignment));
                if (controller.isInterrupted()) break;
            }
        }

        // Reuse the initial proximity array; every element is overwritten here.
        double[] proximity = clustering.proximity();
        double distortion = IntStream.range(0, n).parallel().mapToDouble(i -> {
            double dist = MathEx.JensenShannonDivergence(data[i], centroids[group[i]]);
            dist *= dist;
            proximity[i] = dist;
            return dist;
        }).sum() / n;
        logger.info("Final distortion: {}", distortion);

        return new CentroidClustering<double[], SparseArray>("SIB", centroids, new JSDistance(), group, proximity);
    }

    /**
     * Builds each cluster's centroid as the mean of its members.
     * <p>
     * Fills {@code size} with the member counts as a side effect. Members are
     * accumulated in ascending sample order.
     */
    private static double[][] initCentroids(SparseArray[] data, int[] group, int[] size, int k, int d) {
        double[][] centroids = new double[k][d];
        for (int i = 0; i < data.length; i++) {
            int cluster = group[i];
            size[cluster]++;
            for (SparseArray.Entry e : data[i]) {
                centroids[cluster][e.index()] += e.value();
            }
        }

        for (int cluster = 0; cluster < k; cluster++) {
            // NOTE(review): an initially empty cluster gets a 0/0 = NaN centroid.
            // Presumably CentroidClustering.init never produces one; confirm.
            for (int j = 0; j < d; j++) {
                centroids[cluster][j] /= size[cluster];
            }
        }
        return centroids;
    }

    /**
     * Returns the index of the centroid with the smallest Jensen-Shannon divergence to x.
     * <p>
     * Ties go to the lowest index. Returns {@code current} if no divergence compares
     * below {@code Double.MAX_VALUE} (e.g. all NaN).
     */
    private static int nearest(SparseArray x, double[][] centroids, int current) {
        int best = current;
        double nearest = Double.MAX_VALUE;
        for (int j = 0; j < centroids.length; j++) {
            double divergence = MathEx.JensenShannonDivergence(x, centroids[j]);
            if (divergence < nearest) {
                nearest = divergence;
                best = j;
            }
        }
        return best;
    }

    /**
     * Moves sample x from cluster {@code from} to cluster {@code to} and updates both
     * centroids and their sizes incrementally.
     * <p>
     * Each centroid is scaled back to its member sum, x is added to the destination
     * and subtracted from the source, and both are re-normalized. If the source
     * becomes empty, its centroid is left as the un-normalized remainder.
     */
    private static void move(SparseArray x, int from, int to, int[] size, double[][] centroids) {
        double[] src = centroids[from];
        double[] dst = centroids[to];
        int d = dst.length;

        for (int j = 0; j < d; j++) {
            dst[j] *= size[to];
            src[j] *= size[from];
        }

        for (SparseArray.Entry e : x) {
            int j = e.index();
            double p = e.value();
            dst[j] += p;
            src[j] -= p;
            // Clamp negative entries (presumably floating-point round-off) to zero.
            if (src[j] < 0.0) {
                src[j] = 0.0;
            }
        }

        size[from]--;
        size[to]++;

        for (int j = 0; j < d; j++) {
            dst[j] /= size[to];
        }
        if (size[from] > 0) {
            for (int j = 0; j < d; j++) {
                src[j] /= size[from];
            }
        }
    }

    /**
     * Jensen-Shannon divergence between a dense centroid and a sparse sample.
     * <p>
     * Serializable, presumably so that the returned clustering model can be serialized.
     */
    private static class JSDistance implements ToDoubleBiFunction<double[], SparseArray>, Serializable {
        private static final long serialVersionUID = 1L;

        private JSDistance() {
        }

        @Override
        public double applyAsDouble(double[] x, SparseArray y) {
            return MathEx.JensenShannonDivergence(x, y);
        }
    }
}

