/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.Arrays;
import java.util.Properties;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.CentroidClustering;
import smile.math.MathEx;
import smile.math.distance.EuclideanDistance;
import smile.tensor.DenseMatrix;
import smile.tensor.Eigen;
import smile.tensor.Matrix;
import smile.tensor.Vector;
import smile.util.AlgoStatus;
import smile.util.IterativeAlgorithmController;

/**
 * Deterministic annealing clustering.
 * <p>
 * Each cluster is represented by a pair of codevectors: the centroid (even
 * index) and a slightly perturbed "twin" (odd index). At a fixed temperature
 * {@code T} the algorithm runs soft EM where the posterior of codevector j for
 * a sample is proportional to {@code prior_j * exp(-||x - c_j||^2 / T)}. After
 * each temperature step, a cluster whose twin has drifted further than
 * {@code splitTol} (squared Euclidean distance) from its centroid is split
 * into two clusters, up to {@code kmax} clusters. The temperature is then
 * lowered by the factor {@code alpha}. The final model hard-assigns every
 * sample to its nearest annealed centroid and recomputes each centroid as the
 * mean of its members.
 */
public class DeterministicAnnealing {
    private static final Logger logger = LoggerFactory.getLogger(DeterministicAnnealing.class);

    /** Multiplicative factor used to derive a twin codevector from a centroid. */
    private static final double TWIN_SCALE = 1.01;

    private DeterministicAnnealing() {
    }

    /**
     * Clusters the data with the default tolerances of {@link Options#Options(int, double, int)}.
     *
     * @param data the samples, one row per sample.
     * @param kmax the maximum number of clusters.
     * @param alpha the cooling factor in (0, 1).
     * @param maxIter the maximum number of EM iterations at each temperature.
     * @return the fitted clustering.
     */
    public static CentroidClustering<double[], double[]> fit(double[][] data, int kmax, double alpha, int maxIter) {
        return fit(data, new Options(kmax, alpha, maxIter));
    }

    /**
     * Clusters the data by deterministic annealing.
     *
     * @param data the samples, one row per sample; all rows are assumed to
     *             have the same length as the first one.
     * @param options the hyperparameters.
     * @return the fitted clustering with at most {@code options.kmax()} clusters.
     * @throws IllegalArgumentException if {@code data} is empty.
     */
    public static CentroidClustering<double[], double[]> fit(double[][] data, Options options) {
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }
        int kmax = options.kmax();
        double alpha = options.alpha();
        double splitTol = options.splitTol();
        IterativeAlgorithmController<AlgoStatus> controller = options.controller();
        int n = data.length;
        int d = data[0].length;

        // Index 2c holds the centroid of cluster c, index 2c+1 its twin.
        double[][] centroids = new double[2 * kmax][d];
        double[][] posteriori = new double[n][2 * kmax];
        double[] priori = new double[2 * kmax];
        priori[0] = 0.5;
        priori[1] = 0.5;
        centroids[0] = MathEx.colMeans(data);
        // NOTE(review): a purely multiplicative twin cannot separate from a
        // centroid whose coordinates are all zero (e.g. zero-mean data), so
        // the first split may never occur in that case.
        setTwin(centroids[0], centroids[1]);

        // Initial temperature is derived from lambda, the value returned by
        // power iteration on the data covariance (presumably its leading
        // eigenvalue), so annealing starts above the first split.
        DenseMatrix cov = DenseMatrix.of(MathEx.cov(data, centroids[0]));
        Vector ev = cov.vector(d);
        ev.fill(1.0);
        double lambda = Eigen.power((Matrix) cov, (Vector) ev, 0.0, 1.0E-4, Math.max(20, 2 * cov.nrow()));
        double T = 2.0 * lambda + 0.01;

        int k = 2;
        boolean stop = false;
        boolean split = false;
        while (!stop) {
            double distortion = update(data, T, k, centroids, posteriori, priori, options.maxIter(), options.tol());
            if (k >= 2 * kmax && split) {
                stop = true;
            }
            if (controller != null) {
                controller.submit(new AlgoStatus(k / 2, distortion, T));
                if (controller.isInterrupted()) {
                    stop = true;
                }
            }

            int currentK = k;
            for (int i = 0; i < currentK; i += 2) {
                double norm = 0.0;
                for (int j = 0; j < d; j++) {
                    double diff = centroids[i][j] - centroids[i + 1][j];
                    norm += diff * diff;
                }

                if (norm > splitTol) {
                    if (k < 2 * kmax) {
                        // The twin becomes a new cluster (with its own twin).
                        System.arraycopy(centroids[i + 1], 0, centroids[k], 0, d);
                        setTwin(centroids[i + 1], centroids[k + 1]);
                        priori[k] = priori[i + 1] / 2.0;
                        priori[k + 1] = priori[i + 1] / 2.0;
                        // The remaining cluster shares its prior equally with
                        // its fresh twin, so the priors still sum to one.
                        priori[i] /= 2.0;
                        priori[i + 1] = priori[i];
                        k += 2;
                    }
                    if (currentK >= 2 * kmax) {
                        split = true;
                    }
                }

                // Re-seed the twin next to the (possibly moved) centroid.
                setTwin(centroids[i], centroids[i + 1]);
            }

            if (split) {
                // Max cluster count reached: warm up one step and re-run.
                T /= alpha;
            } else if (k - currentK > 2) {
                // More than one split in a single step: undo the cooling and
                // move alpha halfway towards 1 (5 * 10^(log10(1-a) - 1) == (1-a)/2).
                T /= alpha;
                alpha += 5.0 * Math.pow(10.0, Math.log10(1.0 - alpha) - 1.0);
            } else {
                if (k > currentK && k == 2 * kmax - 2) {
                    alpha += 5.0 * Math.pow(10.0, Math.log10(1.0 - alpha) - 1.0);
                }
                T *= alpha;
            }

            // Stop once cooling can no longer progress. A NaN or underflowed
            // temperature would otherwise spin forever when no split occurs.
            if (alpha >= 1.0 || !(T > 0.0)) {
                break;
            }
        }

        return toHardClustering(data, centroids, k / 2);
    }

    /**
     * Writes the twin of {@code centroid} into {@code twin}: every coordinate
     * is scaled by {@link #TWIN_SCALE}.
     */
    private static void setTwin(double[] centroid, double[] twin) {
        for (int j = 0; j < centroid.length; j++) {
            twin[j] = centroid[j] * TWIN_SCALE;
        }
    }

    /**
     * Hard-assigns every sample to its nearest annealed centroid (even rows of
     * {@code centroids}) and recomputes each centroid as its members' mean.
     * A centroid with no members keeps its annealed position instead of
     * becoming {@code 0/0}. Note that the returned centers alias the rows of
     * {@code centroids}.
     */
    private static CentroidClustering<double[], double[]> toHardClustering(double[][] data, double[][] centroids, int k) {
        int n = data.length;
        int d = data[0].length;
        double[][] centers = new double[k][];
        for (int c = 0; c < k; c++) {
            centers[c] = centroids[2 * c];
        }

        int[] group = new int[n];
        IntStream.range(0, n).parallel().forEach(i -> {
            int cluster = -1;
            double nearest = Double.MAX_VALUE;
            for (int c = 0; c < k; c++) {
                double dist = MathEx.squaredDistance(centers[c], data[i]);
                if (nearest > dist) {
                    nearest = dist;
                    cluster = c;
                }
            }
            group[i] = cluster;
        });

        // Counted sequentially: incrementing a shared int[] from a parallel
        // stream loses updates.
        int[] size = new int[k];
        for (int cluster : group) {
            size[cluster]++;
        }

        // Each task writes only its own center, so this is race-free.
        IntStream.range(0, k).parallel().forEach(cluster -> {
            if (size[cluster] == 0) {
                return;
            }
            double[] center = centers[cluster];
            Arrays.fill(center, 0.0);
            for (int i = 0; i < n; i++) {
                if (group[i] == cluster) {
                    double[] x = data[i];
                    for (int j = 0; j < d; j++) {
                        center[j] += x[j];
                    }
                }
            }
            for (int j = 0; j < d; j++) {
                center[j] /= size[cluster];
            }
        });

        double[] proximity = new double[n];
        IntStream.range(0, n).parallel().forEach(i ->
                proximity[i] = MathEx.squaredDistance(centers[group[i]], data[i]));

        EuclideanDistance distance = new EuclideanDistance();
        return new CentroidClustering<double[], double[]>("D.Annealing", centers,
                (ToDoubleBiFunction<double[], double[]>) distance, group, proximity);
    }

    /**
     * Runs soft EM at a fixed temperature over the first {@code k} codevectors.
     * <p>
     * Iterates until the change of {@code D - T*H} drops to {@code tol} or
     * {@code maxIter} iterations have run. {@code posteriori}, {@code priori}
     * and {@code centroids} are updated in place. Note that the whole
     * {@code priori} array is zeroed, including slots at index {@code k} and
     * above; {@code fit} re-initializes those before use.
     *
     * @return {@code D - T*H} of the last iteration, where D is the soft
     *         distortion and H the summed posterior entropy.
     */
    private static double update(double[][] data, double T, int k, double[][] centroids, double[][] posteriori, double[] priori, int maxIter, double tol) {
        int n = data.length;
        int d = data[0].length;
        double distortion = Double.MAX_VALUE;
        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; iter++) {
            // E-step: Gibbs posteriors and soft distortion D.
            double D = IntStream.range(0, n).parallel().mapToDouble(i -> {
                double[] p = posteriori[i];
                double[] dist = new double[k];
                double minDist = Double.POSITIVE_INFINITY;
                for (int j = 0; j < k; j++) {
                    dist[j] = MathEx.squaredDistance(data[i], centroids[j]);
                    minDist = Math.min(minDist, dist[j]);
                }
                // Shifting by the smallest distance cancels after
                // normalization but keeps the nearest weight from
                // underflowing to 0, which would make Z == 0 and p NaN.
                double Z = 0.0;
                for (int j = 0; j < k; j++) {
                    p[j] = priori[j] * Math.exp(-(dist[j] - minDist) / T);
                    Z += p[j];
                }
                double sum = 0.0;
                for (int j = 0; j < k; j++) {
                    p[j] /= Z;
                    sum += p[j] * dist[j];
                }
                return sum;
            }).sum();

            // Entropy of the posteriors, using the convention 0 * log(0) = 0.
            // Without it an underflowed posterior turns H (and diff) into NaN
            // and silently ends EM after one iteration.
            double H = IntStream.range(0, n).parallel().mapToDouble(i -> {
                double[] p = posteriori[i];
                double h = 0.0;
                for (int j = 0; j < k; j++) {
                    if (p[j] > 0.0) {
                        h -= p[j] * Math.log(p[j]);
                    }
                }
                return h;
            }).sum();

            // M-step: priors are the mean posteriors.
            Arrays.fill(priori, 0.0);
            for (int i = 0; i < n; i++) {
                double[] p = posteriori[i];
                for (int j = 0; j < k; j++) {
                    priori[j] += p[j];
                }
            }
            for (int j = 0; j < k; j++) {
                priori[j] /= n;
            }

            // M-step: centroids are posterior-weighted means.
            IntStream.range(0, k).parallel().forEach(c -> {
                double[] centroid = centroids[c];
                double weight = n * priori[c];
                for (int j = 0; j < d; j++) {
                    double s = 0.0;
                    for (int m = 0; m < n; m++) {
                        s += data[m][j] * posteriori[m][c];
                    }
                    centroid[j] = s / weight;
                }
            });

            double DTH = D - T * H;
            diff = distortion - DTH;
            distortion = DTH;
            logger.info("Iterations {}: k = {}, temperature = {}, entropy = {}, soft distortion = {}", iter, k / 2, T, H, D);
        }
        return distortion;
    }

    /**
     * Deterministic annealing hyperparameters.
     *
     * @param kmax the maximum number of clusters (at least 2).
     * @param alpha the cooling factor in (0, 1); the temperature is multiplied
     *              by it each step, and {@code fit} may move it towards 1.
     * @param maxIter the maximum number of EM iterations at each temperature.
     * @param tol the convergence tolerance on the change of {@code D - T*H}.
     * @param splitTol the squared distance between a centroid and its twin
     *                 above which the cluster is split.
     * @param controller optional controller that receives progress and may
     *                   interrupt fitting; may be {@code null}.
     */
    public record Options(int kmax, double alpha, int maxIter, double tol, double splitTol, IterativeAlgorithmController<AlgoStatus> controller) {
        /**
         * Validates the hyperparameters. Comparisons are written so that NaN
         * is rejected as well.
         */
        public Options {
            if (kmax < 2) {
                throw new IllegalArgumentException("Invalid number of clusters: " + kmax);
            }
            if (!(alpha > 0.0 && alpha < 1.0)) {
                throw new IllegalArgumentException("Invalid alpha: " + alpha);
            }
            if (maxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }
            if (!(tol >= 0.0)) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
            if (!(splitTol >= 0.0)) {
                throw new IllegalArgumentException("Invalid split tolerance: " + splitTol);
            }
        }

        /** Creates options with {@code tol = 1E-4}, {@code splitTol = 0.01} and no controller. */
        public Options(int kmax, double alpha, int maxIter) {
            this(kmax, alpha, maxIter, 1.0E-4, 0.01, null);
        }

        /**
         * Returns the hyperparameters as properties. The controller is not
         * included.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.deterministic_annealing.k", Integer.toString(this.kmax));
            props.setProperty("smile.deterministic_annealing.alpha", Double.toString(this.alpha));
            props.setProperty("smile.deterministic_annealing.iterations", Integer.toString(this.maxIter));
            props.setProperty("smile.deterministic_annealing.tolerance", Double.toString(this.tol));
            props.setProperty("smile.deterministic_annealing.split_tolerance", Double.toString(this.splitTol));
            return props;
        }

        /**
         * Parses hyperparameters from properties. Missing keys default to
         * k = 2, alpha = 0.9, iterations = 100, tolerance = 1E-4 and
         * split_tolerance = 1E-2.
         */
        public static Options of(Properties props) {
            int kmax = Integer.parseInt(props.getProperty("smile.deterministic_annealing.k", "2"));
            double alpha = Double.parseDouble(props.getProperty("smile.deterministic_annealing.alpha", "0.9"));
            int maxIter = Integer.parseInt(props.getProperty("smile.deterministic_annealing.iterations", "100"));
            double tol = Double.parseDouble(props.getProperty("smile.deterministic_annealing.tolerance", "1E-4"));
            double splitTol = Double.parseDouble(props.getProperty("smile.deterministic_annealing.split_tolerance", "1E-2"));
            return new Options(kmax, alpha, maxIter, tol, splitTol, null);
        }
    }
}

