/*
 * Decompiled with CFR 0.152.
 */
package smile.base.rbf;

import java.io.Serializable;
import java.util.Arrays;
import smile.clustering.CLARANS;
import smile.clustering.KMeans;
import smile.math.MathEx;
import smile.math.distance.EuclideanDistance;
import smile.math.distance.Metric;
import smile.math.rbf.GaussianRadialBasis;
import smile.math.rbf.RadialBasisFunction;

/**
 * A radial basis function anchored at a fixed center. Its value at a point
 * {@code x} is {@code rbf.f(distance.d(x, center))}: the radial basis function
 * applied to the metric distance between {@code x} and the center.
 *
 * <p>The static {@code fit} methods pick the centers by clustering the input
 * data ({@link KMeans} for {@code double[]} data with Euclidean distance,
 * {@link CLARANS} for generic data with a caller-supplied metric). They then
 * attach {@link GaussianRadialBasis} functions whose widths come from one of
 * three heuristics:
 * <ul>
 *   <li>a single shared width derived from the largest center-to-center distance;</li>
 *   <li>per-center widths from the mean distance to the {@code p} nearest other centers;</li>
 *   <li>per-center widths from the RMS distance of each cluster's members, scaled by {@code r}.</li>
 * </ul>
 *
 * @param <T> the type of the points the basis function is evaluated on.
 */
public class RBF<T> implements Serializable {
    private static final long serialVersionUID = 2L;

    /** The center this basis function is anchored at. */
    private final T center;
    /** The radial basis function applied to the distance from the center. */
    private final RadialBasisFunction rbf;
    /** The metric used to measure the distance to the center. */
    private final Metric<T> distance;

    /**
     * Constructor.
     *
     * @param center the center of the basis function.
     * @param rbf the radial basis function applied to the distance.
     * @param distance the metric measuring the distance to {@code center}.
     */
    public RBF(T center, RadialBasisFunction rbf, Metric<T> distance) {
        this.center = center;
        this.rbf = rbf;
        this.distance = distance;
    }

    /**
     * Evaluates the basis function at a point.
     *
     * @param x the point to evaluate.
     * @return {@code rbf.f(distance.d(x, center))}.
     */
    public double f(T x) {
        return rbf.f(distance.d(x, center));
    }

    /**
     * Creates one basis function per center, all sharing the same radial basis
     * function and metric.
     *
     * @param centers the centers.
     * @param basis the radial basis function shared by every center.
     * @param distance the metric.
     * @param <T> the type of the centers.
     * @return an array whose {@code i}-th element is anchored at {@code centers[i]}.
     */
    public static <T> RBF<T>[] of(T[] centers, RadialBasisFunction basis, Metric<T> distance) {
        int k = centers.length;
        RBF<T>[] rbf = newArray(k);
        for (int i = 0; i < k; i++) {
            rbf[i] = new RBF<>(centers[i], basis, distance);
        }
        return rbf;
    }

    /**
     * Creates one basis function per center, pairing {@code centers[i]} with
     * {@code basis[i]}.
     *
     * @param centers the centers.
     * @param basis the radial basis functions; must have at least
     *              {@code centers.length} entries.
     * @param distance the metric.
     * @param <T> the type of the centers.
     * @return an array whose {@code i}-th element is anchored at {@code centers[i]}.
     * @throws IllegalArgumentException if {@code basis} is shorter than {@code centers}.
     */
    public static <T> RBF<T>[] of(T[] centers, RadialBasisFunction[] basis, Metric<T> distance) {
        int k = centers.length;
        if (basis.length < k) {
            throw new IllegalArgumentException(
                "Fewer basis functions than centers: " + basis.length + " < " + k);
        }
        RBF<T>[] rbf = newArray(k);
        for (int i = 0; i < k; i++) {
            rbf[i] = new RBF<>(centers[i], basis[i], distance);
        }
        return rbf;
    }

    /** Allocates a typed array of basis functions (generic array creation is not allowed directly). */
    @SuppressWarnings("unchecked")
    private static <T> RBF<T>[] newArray(int k) {
        // Safe: callers only ever store RBF<T> instances into the returned array.
        return (RBF<T>[]) new RBF<?>[k];
    }

    /**
     * Estimates one width shared by all centers as {@code dmax / sqrt(2k)}, where
     * {@code dmax} is the largest pairwise distance between the {@code k} centers.
     */
    private static <T> double estimateWidth(T[] centers, Metric<T> distance) {
        int k = centers.length;
        double dmax = 0.0;
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < i; j++) {
                double d = distance.d(centers[i], centers[j]);
                if (d > dmax) {
                    dmax = d;
                }
            }
        }
        return dmax / Math.sqrt(2 * k);
    }

    /**
     * Estimates a width per center as the mean distance to its {@code p} nearest
     * other centers. Requires {@code 1 <= p < centers.length}.
     */
    private static <T> double[] estimateWidth(T[] centers, Metric<T> distance, int p) {
        int k = centers.length;
        double[] d = new double[k];
        double[] r = new double[k];
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < k; j++) {
                d[j] = distance.d(centers[i], centers[j]);
            }
            Arrays.sort(d);
            // d[0] is the center's distance to itself, so the p nearest
            // neighbors occupy d[1..p].
            double sum = 0.0;
            for (int j = 1; j <= p; j++) {
                sum += d[j];
            }
            r[i] = sum / p;
        }
        return r;
    }

    /**
     * Estimates a width per center as the RMS distance of the cluster's members
     * to that center, scaled by {@code r}.
     *
     * @param x the samples.
     * @param y the cluster index of each sample.
     * @param centers the cluster centers.
     * @param clusterSize the number of samples in each cluster.
     * @param distance the metric.
     * @param r the scaling factor applied to every width.
     */
    private static <T> double[] estimateWidth(T[] x, int[] y, T[] centers, int[] clusterSize, Metric<T> distance, double r) {
        int k = centers.length;
        // Accumulate the squared distance of each sample to its own cluster center.
        double[] sigma = new double[k];
        for (int i = 0; i < x.length; i++) {
            sigma[y[i]] += MathEx.pow2(distance.d(x[i], centers[y[i]]));
        }

        for (int i = 0; i < k; i++) {
            if (sigma[i] != 0.0) {
                sigma[i] = Math.sqrt(sigma[i] / clusterSize[i]);
            } else {
                // The cluster has no spread (empty, or every member sits on the
                // center), so the RMS distance would yield a zero width. Fall
                // back to half the distance to the nearest other center. With a
                // single center there is no neighbor and the result is infinite.
                double nearest = Double.POSITIVE_INFINITY;
                for (int j = 0; j < k; j++) {
                    if (j != i) {
                        double d = distance.d(centers[i], centers[j]);
                        if (d < nearest) {
                            nearest = d;
                        }
                    }
                }
                sigma[i] = nearest / 2.0;
            }
            sigma[i] *= r;
        }
        return sigma;
    }

    /** Builds one Gaussian radial basis function per width. */
    private static GaussianRadialBasis[] gaussian(double[] width) {
        int k = width.length;
        GaussianRadialBasis[] basis = new GaussianRadialBasis[k];
        for (int i = 0; i < k; i++) {
            basis[i] = new GaussianRadialBasis(width[i]);
        }
        return basis;
    }

    /**
     * Learns Gaussian basis functions with K-Means centers and one shared width
     * (largest center distance divided by {@code sqrt(2k)}).
     *
     * @param x the training data.
     * @param k the number of centers.
     * @return the basis functions.
     */
    public static RBF<double[]>[] fit(double[][] x, int k) {
        KMeans kmeans = KMeans.fit(x, k, 10, 1.0E-4);
        double[][] centers = (double[][]) kmeans.centroids;
        EuclideanDistance distance = new EuclideanDistance();
        GaussianRadialBasis basis = new GaussianRadialBasis(estimateWidth(centers, distance));
        return of(centers, basis, distance);
    }

    /**
     * Learns Gaussian basis functions with K-Means centers, each with a width
     * equal to the mean distance to its {@code p} nearest other centers.
     *
     * @param x the training data.
     * @param k the number of centers.
     * @param p the number of nearest neighbors used for the width.
     * @return the basis functions.
     * @throws IllegalArgumentException if {@code p < 1} or {@code p >= k}.
     */
    public static RBF<double[]>[] fit(double[][] x, int k, int p) {
        if (p < 1 || p >= k) {
            throw new IllegalArgumentException("Invalid number of nearest neighbors: " + p);
        }
        KMeans kmeans = KMeans.fit(x, k, 10, 1.0E-4);
        double[][] centers = (double[][]) kmeans.centroids;
        EuclideanDistance distance = new EuclideanDistance();
        double[] width = estimateWidth(centers, distance, p);
        return of(centers, gaussian(width), distance);
    }

    /**
     * Learns Gaussian basis functions with K-Means centers, each with a width
     * equal to {@code r} times the RMS distance of its cluster's members.
     *
     * @param x the training data.
     * @param k the number of centers.
     * @param r the scaling factor for the widths.
     * @return the basis functions.
     * @throws IllegalArgumentException if {@code r <= 0}.
     */
    public static RBF<double[]>[] fit(double[][] x, int k, double r) {
        if (r <= 0.0) {
            throw new IllegalArgumentException("Invalid scaling parameter: " + r);
        }
        KMeans kmeans = KMeans.fit(x, k, 10, 1.0E-4);
        double[][] centers = (double[][]) kmeans.centroids;
        EuclideanDistance distance = new EuclideanDistance();
        double[] width = estimateWidth(x, kmeans.y, centers, kmeans.size, distance, r);
        return of(centers, gaussian(width), distance);
    }

    /**
     * Learns Gaussian basis functions with CLARANS centers and one shared width
     * (largest center distance divided by {@code sqrt(2k)}).
     *
     * @param x the training data.
     * @param distance the metric.
     * @param k the number of centers.
     * @param <T> the data type.
     * @return the basis functions.
     */
    public static <T> RBF<T>[] fit(T[] x, Metric<T> distance, int k) {
        CLARANS<T> clarans = CLARANS.fit(x, distance, k);
        T[] centers = clarans.centroids;
        GaussianRadialBasis basis = new GaussianRadialBasis(estimateWidth(centers, distance));
        return of(centers, basis, distance);
    }

    /**
     * Learns Gaussian basis functions with CLARANS centers, each with a width
     * equal to the mean distance to its {@code p} nearest other centers.
     *
     * @param x the training data.
     * @param distance the metric.
     * @param k the number of centers.
     * @param p the number of nearest neighbors used for the width.
     * @param <T> the data type.
     * @return the basis functions.
     * @throws IllegalArgumentException if {@code p < 1} or {@code p >= k}.
     */
    public static <T> RBF<T>[] fit(T[] x, Metric<T> distance, int k, int p) {
        if (p < 1 || p >= k) {
            throw new IllegalArgumentException("Invalid number of nearest neighbors: " + p);
        }
        CLARANS<T> clarans = CLARANS.fit(x, distance, k);
        T[] centers = clarans.centroids;
        double[] width = estimateWidth(centers, distance, p);
        return of(centers, gaussian(width), distance);
    }

    /**
     * Learns Gaussian basis functions with CLARANS centers, each with a width
     * equal to {@code r} times the RMS distance of its cluster's members.
     *
     * @param x the training data.
     * @param distance the metric.
     * @param k the number of centers.
     * @param r the scaling factor for the widths.
     * @param <T> the data type.
     * @return the basis functions.
     * @throws IllegalArgumentException if {@code r <= 0}.
     */
    public static <T> RBF<T>[] fit(T[] x, Metric<T> distance, int k, double r) {
        if (r <= 0.0) {
            throw new IllegalArgumentException("Invalid scaling parameter: " + r);
        }
        CLARANS<T> clarans = CLARANS.fit(x, distance, k);
        T[] centers = clarans.centroids;
        double[] width = estimateWidth(x, clarans.y, centers, clarans.size, distance, r);
        return of(centers, gaussian(width), distance);
    }
}

