/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.ArrayList;
import java.util.Properties;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.CentroidClustering;
import smile.clustering.KMeans;
import smile.clustering.KMedoids;
import smile.clustering.Partitioning;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.math.distance.EuclideanDistance;
import smile.neighbor.LinearSearch;
import smile.neighbor.Neighbor;
import smile.neighbor.RNNSearch;
import smile.util.AlgoStatus;
import smile.util.IterativeAlgorithmController;

/**
 * Entropy-driven partitioning of a data set.
 *
 * <p>Starting from an initial labelling, the fit repeatedly moves an
 * observation to the cluster that dominates its radius-neighborhood whenever
 * that move lowers the neighborhood entropy terms affected by it. The
 * objective reported by {@link #entropy()} is the sum over all observations of
 * the entropy of the cluster distribution inside the observation's
 * neighborhood, weighted by the fraction of the data that lies in that
 * neighborhood.
 *
 * <p>Instances are ordered by their entropy (lower first).
 *
 * @param <T> the type of the observations.
 */
public class MEC<T> extends Partitioning implements Comparable<MEC<T>> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(MEC.class);

    /** The value of the entropy objective for this partitioning. */
    private final double entropy;
    /** The neighborhood radius used for fitting and prediction. */
    private final double radius;
    /** The range search structure over the training data. */
    private final RNNSearch<T, T> nns;

    /**
     * Creates a model from an already computed partitioning.
     *
     * @param k the number of clusters.
     * @param group the cluster label of each training observation.
     * @param entropy the entropy objective of the partitioning.
     * @param radius the neighborhood radius.
     * @param nns the range search structure over the training data.
     */
    public MEC(int k, int[] group, double entropy, double radius, RNNSearch<T, T> nns) {
        super(k, group);
        this.entropy = entropy;
        this.radius = radius;
        this.nns = nns;
    }

    /**
     * Returns the entropy objective of this partitioning.
     *
     * @return the entropy.
     */
    public double entropy() {
        return entropy;
    }

    /**
     * Returns the neighborhood radius.
     *
     * @return the radius.
     */
    public double radius() {
        return radius;
    }

    /** Orders models by ascending entropy. */
    @Override
    public int compareTo(MEC<T> o) {
        return Double.compare(entropy, o.entropy);
    }

    /**
     * Fits a model, deriving the initial labels automatically: k-means
     * (10 iterations) when the data is {@code double[][]} and the distance is
     * an {@link EuclideanDistance}, k-medoids otherwise. Neighborhoods are
     * found with a {@link LinearSearch} and the default {@link Options}.
     *
     * @param data the observations.
     * @param distance the distance function.
     * @param k the number of initial clusters, at least 2.
     * @param radius the neighborhood radius, strictly positive.
     * @param <T> the type of the observations.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 2} or {@code radius <= 0}.
     */
    public static <T> MEC<T> fit(T[] data, Distance<T> distance, int k, double radius) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }
        if (radius <= 0.0) {
            throw new IllegalArgumentException("Invalid radius: " + radius);
        }

        int[] group;
        if (data instanceof double[][] matrix && distance instanceof EuclideanDistance) {
            group = KMeans.fit(matrix, k, 10).group();
        } else {
            group = KMedoids.fit(data, distance, k).group();
        }

        return fit(data, LinearSearch.of(data, distance), group, new Options(k, radius));
    }

    /**
     * Fits a model from a caller-supplied initial labelling.
     *
     * <p>Note that {@code group} is modified in place: it holds the final,
     * compacted labels on return and becomes the label array of the result.
     *
     * @param data the observations.
     * @param nns the range search structure used to find neighborhoods.
     * @param group the initial cluster label of each observation; every label
     *              must lie in {@code [0, options.k())}.
     * @param options the hyperparameters.
     * @param <T> the type of the observations.
     * @return the fitted model; its number of clusters counts only the
     *         clusters that are non-empty after optimization.
     * @throws IllegalArgumentException if {@code group} does not match
     *         {@code data} in length or contains an out-of-range label.
     */
    public static <T> MEC<T> fit(T[] data, RNNSearch<T, T> nns, int[] group, Options options) {
        final int k = options.k();
        final int maxIter = options.maxIter();
        final double radius = options.radius();
        final double tol = options.tol();
        final IterativeAlgorithmController<AlgoStatus> controller = options.controller();
        final int n = data.length;

        // Fail fast instead of hitting an index error inside a parallel stream.
        if (group.length != n) {
            throw new IllegalArgumentException(
                    "The sizes of data and group don't match: " + n + " != " + group.length);
        }
        for (int label : group) {
            if (label < 0 || label >= k) {
                throw new IllegalArgumentException("Invalid cluster label: " + label);
            }
        }

        // px[i]: fraction of the data inside the neighborhood of observation i.
        final double[] px = new double[n];
        // neighbors[i]: indices of the observations in the neighborhood of i.
        final int[][] neighbors = new int[n][];

        logger.info("Estimating the probabilities ...");
        IntStream stream = IntStream.range(0, n);
        // NOTE(review): presumably LinearSearch already parallelises its own
        // scan, hence the sequential outer loop in that case -- confirm.
        if (!(nns instanceof LinearSearch)) {
            stream = stream.parallel();
        }
        stream.forEach(i -> {
            ArrayList<Neighbor<T, T>> list = new ArrayList<>();
            // The observation itself always belongs to its own neighborhood.
            list.add(Neighbor.of(data[i], i, 0.0));
            nns.search(data[i], radius, list);

            int[] neighborhood = new int[list.size()];
            for (int j = 0; j < neighborhood.length; j++) {
                neighborhood[j] = list.get(j).index();
            }
            neighbors[i] = neighborhood;
            px[i] = (double) neighborhood.length / n;
        });

        // size[i][c]: number of neighbors of i currently labelled c.
        final int[][] size = new int[n][k];
        // dominantCluster[i]: the most frequent label in the neighborhood of i.
        final int[] dominantCluster = new int[n];
        // Each task only touches row i, so the rows can be filled in parallel.
        IntStream.range(0, n).parallel().forEach(i -> {
            int[] counts = size[i];
            for (int j : neighbors[i]) {
                counts[group[j]]++;
            }
            int max = 0;
            for (int c = 0; c < k; c++) {
                if (counts[c] > max) {
                    dominantCluster[i] = c;
                    max = counts[c];
                }
            }
        });

        double entropy = MEC.entropy(k, neighbors, size, px);
        logger.info("Initial entropy = {}", entropy);

        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; iter++) {
            for (int i = 0; i < n; i++) {
                // Neither value changes while observation i is processed: the
                // only dominantCluster update that could hit index i compares
                // a count with itself and therefore never fires.
                final int from = group[i];
                final int to = dominantCluster[i];
                if (from == to) {
                    continue;
                }
                if (!moveLowersEntropy(neighbors[i], neighbors, size, px, from, to)) {
                    continue;
                }

                // Commit the move: update the label counts of every
                // neighborhood that contains observation i.
                for (int neighbor : neighbors[i]) {
                    int[] counts = size[neighbor];
                    counts[from]--;
                    counts[to]++;
                    if (counts[to] > counts[dominantCluster[neighbor]]) {
                        dominantCluster[neighbor] = to;
                    }
                }
                group[i] = to;
            }

            double previous = entropy;
            entropy = MEC.entropy(k, neighbors, size, px);
            diff = previous - entropy;
            logger.info("Iteration {}: entropy = {}", iter, entropy);

            if (controller != null) {
                controller.submit(new AlgoStatus(iter, entropy));
                if (controller.isInterrupted()) {
                    break;
                }
            }
        }

        int numClusters = compactLabels(group, k);
        return new MEC<>(numClusters, group, entropy, radius, nns);
    }

    /**
     * Tells whether relabelling one observation from {@code from} to
     * {@code to} lowers the sum of the two affected entropy terms over the
     * given neighborhood.
     */
    private static boolean moveLowersEntropy(int[] neighborhood, int[][] neighbors, int[][] size,
                                             double[] px, int from, int to) {
        double oldMutual = 0.0;
        double newMutual = 0.0;
        for (int neighbor : neighborhood) {
            double nk = neighbors[neighbor].length;
            int[] counts = size[neighbor];
            double weight = px[neighbor];

            oldMutual -= plogp(counts[from] / nk) * weight;
            oldMutual -= plogp(counts[to] / nk) * weight;
            newMutual -= plogp((counts[from] - 1.0) / nk) * weight;
            newMutual -= plogp((counts[to] + 1.0) / nk) * weight;
        }
        return newMutual < oldMutual;
    }

    /** Returns {@code r * log2(r)}, or 0 when {@code r} is not positive. */
    private static double plogp(double r) {
        return r > 0.0 ? r * MathEx.log2(r) : 0.0;
    }

    /**
     * Renumbers the labels in place so that the non-empty clusters are
     * numbered consecutively from 0, keeping their relative order.
     *
     * @return the number of non-empty clusters.
     */
    private static int compactLabels(int[] group, int k) {
        // First used as a per-label counter, then as the old -> new label map.
        int[] remap = new int[k];
        for (int label : group) {
            remap[label]++;
        }
        int numClusters = 0;
        for (int c = 0; c < k; c++) {
            if (remap[c] > 0) {
                remap[c] = numClusters++;
            }
        }
        for (int i = 0; i < group.length; i++) {
            group[i] = remap[group[i]];
        }
        return numClusters;
    }

    /**
     * Assigns an observation to the cluster that is most frequent among the
     * training observations within {@link #radius()} of it.
     *
     * @param x the observation to classify.
     * @return the cluster label, or {@code Integer.MAX_VALUE} if the range
     *         search finds no training observation within the radius.
     */
    public int predict(T x) {
        ArrayList<Neighbor<T, T>> neighbors = new ArrayList<>();
        nns.search(x, radius, neighbors);
        if (neighbors.isEmpty()) {
            return Integer.MAX_VALUE;
        }

        int[] votes = new int[k];
        for (Neighbor<T, T> neighbor : neighbors) {
            votes[group[neighbor.index()]]++;
        }
        return MathEx.whichMax(votes);
    }

    @Override
    public String toString() {
        return String.format("%sEntropy %11.5f%n", super.toString(), entropy);
    }

    /**
     * Computes the entropy objective: for every observation, the entropy of
     * the label distribution in its neighborhood weighted by {@code px}.
     *
     * @param k the number of label columns in {@code size}.
     * @param neighbors the neighborhood of each observation.
     * @param size the per-neighborhood label counts.
     * @param px the neighborhood weight of each observation.
     * @return the weighted sum of the neighborhood entropies.
     */
    private static double entropy(int k, int[][] neighbors, int[][] size, double[] px) {
        return IntStream.range(0, neighbors.length).parallel().mapToDouble(i -> {
            double conditionalEntropy = 0.0;
            int ni = neighbors[i].length;
            int[] ci = size[i];
            for (int j = 0; j < k; j++) {
                if (ci[j] > 0) {
                    double r = (double) ci[j] / ni;
                    conditionalEntropy -= r * MathEx.log2(r);
                }
            }
            return conditionalEntropy * px[i];
        }).sum();
    }

    /**
     * Hyperparameters of the fit.
     *
     * @param k the number of initial clusters, at least 2.
     * @param radius the neighborhood radius, strictly positive.
     * @param maxIter the maximum number of iterations, strictly positive.
     * @param tol the fit stops once an iteration lowers the entropy by no more
     *            than this amount; must not be negative.
     * @param controller optional receiver of per-iteration status that can
     *                   also interrupt the fit; may be {@code null}.
     */
    public record Options(int k, double radius, int maxIter, double tol,
                          IterativeAlgorithmController<AlgoStatus> controller) {
        /** Validates the hyperparameters. */
        public Options {
            if (k < 2) {
                throw new IllegalArgumentException("Invalid k: " + k);
            }
            if (radius <= 0.0) {
                throw new IllegalArgumentException("Invalid radius: " + radius);
            }
            if (maxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }
            if (tol < 0.0) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
        }

        /**
         * Creates options with 500 iterations, a tolerance of 1E-4 and no
         * controller.
         *
         * @param k the number of initial clusters.
         * @param radius the neighborhood radius.
         */
        public Options(int k, double radius) {
            this(k, radius, 500, 1.0E-4, null);
        }

        /**
         * Exports the numeric hyperparameters under the {@code smile.mec.*}
         * keys. The controller is not exported.
         *
         * @return the properties.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.mec.k", Integer.toString(k));
            props.setProperty("smile.mec.radius", Double.toString(radius));
            props.setProperty("smile.mec.iterations", Integer.toString(maxIter));
            props.setProperty("smile.mec.tolerance", Double.toString(tol));
            return props;
        }

        /**
         * Reads the hyperparameters from the {@code smile.mec.*} keys, using
         * k = 2, radius = 1.0, 500 iterations and a tolerance of 1E-4 for
         * missing keys. The resulting options have no controller.
         *
         * @param props the properties.
         * @return the options.
         */
        public static Options of(Properties props) {
            int k = Integer.parseInt(props.getProperty("smile.mec.k", "2"));
            double radius = Double.parseDouble(props.getProperty("smile.mec.radius", "1.0"));
            int maxIter = Integer.parseInt(props.getProperty("smile.mec.iterations", "500"));
            double tol = Double.parseDouble(props.getProperty("smile.mec.tolerance", "1E-4"));
            return new Options(k, radius, maxIter, tol, null);
        }
    }
}

