/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.ArrayList;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.CLARANS;
import smile.clustering.KMeans;
import smile.clustering.PartitionClustering;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.math.distance.EuclideanDistance;
import smile.neighbor.LinearSearch;
import smile.neighbor.Neighbor;
import smile.neighbor.RNNSearch;

/**
 * Clustering that iteratively relabels samples to reduce a neighborhood-based
 * conditional entropy objective.
 *
 * <p>Starting from an initial labeling, every sample is compared against the
 * cluster that dominates its radius neighborhood; the sample is moved to that
 * cluster when doing so lowers the entropy contribution of the affected
 * neighborhoods. Sweeps repeat until the objective improves by no more than
 * the given tolerance.
 *
 * @param <T> the type of the samples.
 */
public class MEC<T> extends PartitionClustering implements Comparable<MEC<T>> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(MEC.class);

    /** The value of the conditional entropy objective for the final partition. */
    public final double entropy;
    /** The neighborhood radius used for range searches. */
    public final double radius;
    /** The range search structure used during fitting; reused by {@link #predict}. */
    private final RNNSearch<T, T> nns;

    /**
     * Constructor.
     *
     * @param entropy the conditional entropy of the partition.
     * @param radius the neighborhood radius.
     * @param nns the range search structure over the training data.
     * @param k the number of clusters.
     * @param y the cluster labels.
     */
    public MEC(double entropy, double radius, RNNSearch<T, T> nns, int k, int[] y) {
        super(k, y);
        this.entropy = entropy;
        this.radius = radius;
        this.nns = nns;
    }

    /**
     * Orders models by their entropy, lower entropy first.
     */
    @Override
    public int compareTo(MEC<T> o) {
        return Double.compare(entropy, o.entropy);
    }

    /**
     * Fits the model using a {@link LinearSearch} over the data and a
     * convergence tolerance of 1E-4.
     *
     * <p>The initial labels are produced by {@link KMeans} when the data is a
     * {@code double[][]} and the distance is an {@link EuclideanDistance};
     * otherwise they are produced by {@link CLARANS} with the given distance.
     *
     * @param data the samples.
     * @param distance the distance between samples.
     * @param k the number of clusters used for the initial labeling; the
     *          fitted model may end up with fewer non-empty clusters.
     * @param radius the neighborhood radius.
     * @param <T> the type of the samples.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 2} or {@code radius <= 0}.
     */
    public static <T> MEC<T> fit(T[] data, Distance<T> distance, int k, double radius) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }
        if (radius <= 0.0) {
            throw new IllegalArgumentException("Invalid radius: " + radius);
        }

        int[] y;
        if (data instanceof double[][] && distance instanceof EuclideanDistance) {
            KMeans kmeans = KMeans.fit((double[][]) data, k);
            y = kmeans.y;
        } else {
            CLARANS<T> clarans = CLARANS.fit(data, distance::d, k);
            y = clarans.y;
        }

        return fit(data, new LinearSearch<T>(data, distance), k, radius, y, 1.0E-4);
    }

    /**
     * Fits the model from an initial labeling.
     *
     * <p>Note that {@code y} is modified in place: it is updated during the
     * optimization, compacted so that labels of non-empty clusters are
     * consecutive from 0, and then handed to the returned model.
     *
     * @param data the samples.
     * @param nns the range search structure used to collect each sample's
     *            neighborhood.
     * @param k the number of clusters in the initial labeling; labels in
     *          {@code y} are used as indices into arrays of length {@code k}.
     * @param radius the neighborhood radius.
     * @param y the initial cluster labels, one per sample.
     * @param tol the sweeps stop once the decrease of the objective between
     *            two consecutive sweeps is not greater than this value.
     * @param <T> the type of the samples.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 2}, {@code radius <= 0},
     *         or {@code y} does not have one label per sample.
     */
    public static <T> MEC<T> fit(T[] data, RNNSearch<T, T> nns, int k, double radius, int[] y, double tol) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }
        if (radius <= 0.0) {
            throw new IllegalArgumentException("Invalid radius: " + radius);
        }

        int n = data.length;
        if (y.length != n) {
            throw new IllegalArgumentException(
                    String.format("The sizes of data and y don't match: %d != %d", n, y.length));
        }

        // px[i] is the fraction of all samples that lie in the neighborhood of sample i.
        double[] px = new double[n];
        // neighbors[i] holds the indices of the samples in the neighborhood of sample i.
        // An index-addressed array is used (rather than appending to a shared list)
        // because the loop below may run in parallel: each task writes only its own
        // slot, so there is no data race and slot i always belongs to sample i.
        int[][] neighbors = new int[n][];

        logger.info("Estimating the probabilities ...");
        IntStream stream = IntStream.range(0, n);
        // NOTE(review): presumably LinearSearch is excluded because it already
        // parallelizes its own scan -- confirm against LinearSearch.
        if (!(nns instanceof LinearSearch)) {
            stream = stream.parallel();
        }
        stream.forEach(i -> {
            ArrayList<Neighbor<T, T>> list = new ArrayList<>();
            // The sample itself is always a member of its own neighborhood.
            // NOTE(review): this assumes range() does not report the query
            // object itself -- confirm against the RNNSearch implementation.
            list.add(Neighbor.of(data[i], i, 0.0));
            nns.range(data[i], radius, list);

            int[] neighborhood = new int[list.size()];
            for (int j = 0; j < neighborhood.length; ++j) {
                neighborhood[j] = list.get(j).index;
            }
            neighbors[i] = neighborhood;
            px[i] = (double) neighborhood.length / n;
        });

        // size[i][c] is the number of samples of cluster c in the neighborhood of
        // sample i; dominantCluster[i] is the cluster with the largest such count
        // (lowest label wins ties).
        int[][] size = new int[n][k];
        int[] dominantCluster = new int[n];
        IntStream.range(0, n).parallel().forEach(i -> {
            int[] counts = size[i];
            for (int j : neighbors[i]) {
                counts[y[j]]++;
            }

            int max = 0;
            for (int c = 0; c < k; ++c) {
                if (counts[c] > max) {
                    dominantCluster[i] = c;
                    max = counts[c];
                }
            }
        });

        double entropy = entropy(k, neighbors, size, px);
        logger.info(String.format("Entropy after initialization: %.4f", entropy));

        double diff = Double.MAX_VALUE;
        for (int iter = 1; diff > tol; ++iter) {
            for (int i = 0; i < n; ++i) {
                int from = y[i];
                int to = dominantCluster[i];
                if (to == from) {
                    continue;
                }

                // Entropy terms of the two affected clusters, summed over the
                // neighborhood of sample i, before and after the tentative move.
                double oldMutual = 0.0;
                double newMutual = 0.0;
                for (int neighbor : neighbors[i]) {
                    double nk = neighbors[neighbor].length;
                    double weight = px[neighbor];

                    double r1 = size[neighbor][from] / nk;
                    double r2 = size[neighbor][to] / nk;
                    if (r1 > 0.0) {
                        oldMutual -= r1 * MathEx.log2(r1) * weight;
                    }
                    if (r2 > 0.0) {
                        oldMutual -= r2 * MathEx.log2(r2) * weight;
                    }

                    r1 = (size[neighbor][from] - 1.0) / nk;
                    r2 = (size[neighbor][to] + 1.0) / nk;
                    if (r1 > 0.0) {
                        newMutual -= r1 * MathEx.log2(r1) * weight;
                    }
                    if (r2 > 0.0) {
                        newMutual -= r2 * MathEx.log2(r2) * weight;
                    }
                }

                if (newMutual < oldMutual) {
                    // Commit the move: update the counts of every neighborhood
                    // listed for sample i and refresh their dominant clusters.
                    for (int neighbor : neighbors[i]) {
                        int[] counts = size[neighbor];
                        --counts[from];
                        ++counts[to];
                        if (counts[to] > counts[dominantCluster[neighbor]]) {
                            dominantCluster[neighbor] = to;
                        }
                    }
                    y[i] = to;
                }
            }

            double prevObj = entropy;
            entropy = entropy(k, neighbors, size, px);
            diff = prevObj - entropy;
            logger.info(String.format("Entropy after %3d iterations: %.4f", iter, entropy));
        }

        // Drop clusters that ended up empty and renumber the remaining labels
        // consecutively from 0, preserving their relative order.
        int[] clusterSize = new int[k];
        for (int i = 0; i < n; ++i) {
            clusterSize[y[i]]++;
        }

        int numClusters = 0;
        int[] relabel = new int[k];
        for (int c = 0; c < k; ++c) {
            if (clusterSize[c] > 0) {
                relabel[c] = numClusters++;
            }
        }

        for (int i = 0; i < n; ++i) {
            y[i] = relabel[y[i]];
        }

        return new MEC<T>(entropy, radius, nns, numClusters, y);
    }

    /**
     * Assigns a sample to the cluster that is most frequent among the training
     * samples found within {@code radius} of it.
     *
     * @param x the sample to classify.
     * @return the cluster label, or {@code Integer.MAX_VALUE} when the range
     *         search finds no training sample within the radius.
     */
    public int predict(T x) {
        ArrayList<Neighbor<T, T>> neighbors = new ArrayList<>();
        nns.range(x, radius, neighbors);
        if (neighbors.isEmpty()) {
            // NOTE(review): Integer.MAX_VALUE looks like an inlined "outlier"
            // sentinel constant -- confirm against PartitionClustering.
            return Integer.MAX_VALUE;
        }

        int[] label = new int[k];
        for (Neighbor<T, T> neighbor : neighbors) {
            label[y[neighbor.index]]++;
        }
        return MathEx.whichMax(label);
    }

    @Override
    public String toString() {
        return String.format("Cluster entropy: %.5f%n", entropy) + super.toString();
    }

    /**
     * Computes the objective: for every sample, the entropy of the cluster
     * distribution within its neighborhood, weighted by {@code px} and summed
     * over all samples.
     *
     * @param k the number of clusters.
     * @param neighbors the neighborhood (sample indices) of each sample.
     * @param size the per-cluster member counts of each neighborhood.
     * @param px the weight of each neighborhood.
     * @return the weighted sum of the neighborhood entropies.
     */
    private static double entropy(int k, int[][] neighbors, int[][] size, double[] px) {
        return IntStream.range(0, neighbors.length).parallel().mapToDouble(i -> {
            double conditionalEntropy = 0.0;
            int ni = neighbors[i].length;
            int[] ci = size[i];
            for (int j = 0; j < k; ++j) {
                if (ci[j] > 0) {
                    double r = (double) ci[j] / ni;
                    conditionalEntropy -= r * MathEx.log2(r);
                }
            }
            return conditionalEntropy * px[i];
        }).sum();
    }
}

