/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import java.util.ArrayList;
import java.util.List;
import smile.math.Math;
import smile.stat.distribution.MultivariateExponentialFamilyMixture;
import smile.stat.distribution.MultivariateGaussianDistribution;
import smile.stat.distribution.MultivariateMixture;

/**
 * Finite mixture of multivariate Gaussian distributions.
 *
 * <p>The mixture can be supplied directly, fitted with a fixed number of
 * components, or grown one component at a time until a BIC-style penalized
 * score stops improving. Fitting is delegated to the {@code EM} routine
 * inherited from the superclass.
 */
public class MultivariateGaussianMixture extends MultivariateExponentialFamilyMixture {
    /**
     * Creates a mixture from an already configured list of components.
     *
     * @param mixture the components, passed unchanged to the superclass constructor.
     */
    public MultivariateGaussianMixture(List<MultivariateMixture.Component> mixture) {
        super(mixture);
    }

    /**
     * Fits a mixture of {@code k} Gaussians with full covariance matrices.
     *
     * @param data the training samples, one row per sample.
     * @param k the number of components; must be at least 2.
     * @throws IllegalArgumentException if {@code k < 2} or fewer than two samples are given.
     */
    public MultivariateGaussianMixture(double[][] data, int k) {
        this(data, k, false);
    }

    /**
     * Fits a mixture of {@code k} Gaussians.
     *
     * <p>Every component starts with the sample covariance of the whole data
     * set and an equal prior of {@code 1/k}. The first mean is a uniformly
     * random sample; each further mean is a sample drawn with probability
     * proportional to its squared distance to the nearest mean chosen so far
     * (the k-means++ seeding scheme). The components are then refined by
     * {@code EM}.
     *
     * @param data the training samples, one row per sample; all rows are
     *             read up to the length of the first row.
     * @param k the number of components; must be at least 2.
     * @param diagonal if true, only the diagonal of the initial covariance is
     *                 estimated and every component is flagged as diagonal.
     * @throws IllegalArgumentException if {@code k < 2} or fewer than two samples are given.
     */
    public MultivariateGaussianMixture(double[][] data, int k, boolean diagonal) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of components in the mixture.");
        }
        // The covariance estimate divides by (n - 1), so at least two samples are required.
        if (data.length < 2) {
            throw new IllegalArgumentException("Too few samples.");
        }

        int n = data.length;
        int d = data[0].length;
        double[] mu = columnMeans(data, d);
        double[][] sigma = sampleCovariance(data, mu, diagonal);
        double priori = 1.0 / (double) k;

        double[] centroid = data[Math.randomInt(n)];
        this.components.add(newComponent(centroid, sigma, diagonal, priori));

        // D[j] holds the squared distance from sample j to its nearest chosen centroid.
        double[] D = new double[n];
        for (int j = 0; j < n; ++j) {
            D[j] = Double.MAX_VALUE;
        }

        for (int i = 1; i < k; ++i) {
            // Only the most recently chosen centroid can shrink a distance.
            for (int j = 0; j < n; ++j) {
                double dist = Math.squaredDistance(data[j], centroid);
                if (dist < D[j]) {
                    D[j] = dist;
                }
            }

            // Sample an index with probability proportional to D by walking the cumulative sum.
            double cutoff = Math.random() * Math.sum(D);
            double cost = 0.0;
            int index = 0;
            while (index < n) {
                cost += D[index];
                if (cost >= cutoff) {
                    break;
                }
                ++index;
            }
            // Rounding error (or a NaN distance) can leave the running sum below
            // the cutoff after the last sample; fall back to the last sample
            // instead of indexing past the end of the array.
            if (index >= n) {
                index = n - 1;
            }

            centroid = data[index];
            this.components.add(newComponent(centroid, sigma, diagonal, priori));
        }

        this.EM(this.components, data);
    }

    /**
     * Fits a mixture with full covariance matrices, choosing the number of
     * components automatically.
     *
     * @param data the training samples, one row per sample.
     * @throws IllegalArgumentException if fewer than 20 samples are given.
     */
    public MultivariateGaussianMixture(double[][] data) {
        this(data, false);
    }

    /**
     * Fits a mixture, choosing the number of components automatically.
     *
     * <p>Starting from a single Gaussian estimated from the data, the component
     * with the largest scatter is repeatedly split in two and the mixture is
     * re-fitted with {@code EM}. The score of each candidate is the value
     * returned by {@code EM} (for the initial single Gaussian: the sum of log
     * densities over samples with positive density) minus
     * {@code 0.5 * parameters * log(n)}. The last mixture whose score improved
     * on its predecessor is kept.
     *
     * @param data the training samples, one row per sample.
     * @param diagonal passed to the initial Gaussian estimate and inherited by
     *                 the components produced by splitting.
     * @throws IllegalArgumentException if fewer than 20 samples are given.
     * @throws IllegalStateException if no component has positive scatter when a split is attempted.
     */
    public MultivariateGaussianMixture(double[][] data, boolean diagonal) {
        if (data.length < 20) {
            throw new IllegalArgumentException("Too few samples.");
        }

        ArrayList<MultivariateMixture.Component> mixture = new ArrayList<MultivariateMixture.Component>();
        MultivariateMixture.Component c = new MultivariateMixture.Component();
        c.priori = 1.0;
        c.distribution = new MultivariateGaussianDistribution(data, diagonal);
        mixture.add(c);

        double penaltyScale = 0.5 * Math.log(data.length);

        double bic = 0.0;
        for (double[] x : data) {
            double p = c.distribution.p(x);
            // Skip samples with zero (or NaN) density so the sum stays finite.
            if (p > 0.0) {
                bic += Math.log(p);
            }
        }
        bic -= penaltyScale * (double) parameterCount(mixture);

        double b = Double.NEGATIVE_INFINITY;
        while (bic > b) {
            b = bic;
            // Snapshot the accepted mixture before the candidate list is modified.
            this.components = new ArrayList<MultivariateMixture.Component>(mixture);
            this.split(mixture);
            bic = this.EM(mixture, data);
            bic -= penaltyScale * (double) parameterCount(mixture);
        }
    }

    /**
     * Replaces the component with the largest scatter by two components.
     *
     * <p>Both children keep the parent's covariance and diagonal flag, take
     * half of its prior, and have their means moved half a standard deviation
     * along every axis in opposite directions. The children are appended to
     * the list and the parent is removed.
     *
     * @param mixture the component list to modify in place.
     * @throws IllegalStateException if no component has positive scatter.
     */
    private void split(List<MultivariateMixture.Component> mixture) {
        MultivariateMixture.Component componentToSplit = null;
        double maxSigma = 0.0;
        for (MultivariateMixture.Component candidate : mixture) {
            double scatter = ((MultivariateGaussianDistribution) candidate.distribution).scatter();
            if (scatter > maxSigma) {
                maxSigma = scatter;
                componentToSplit = candidate;
            }
        }
        if (componentToSplit == null) {
            throw new IllegalStateException("No component with positive scatter to split.");
        }

        MultivariateGaussianDistribution parent = (MultivariateGaussianDistribution) componentToSplit.distribution;
        double[][] delta = parent.cov();
        double[] mu = parent.mean();

        double[] mu1 = new double[mu.length];
        double[] mu2 = new double[mu.length];
        for (int i = 0; i < mu.length; ++i) {
            double shift = Math.sqrt(delta[i][i]) / 2.0;
            mu1[i] = mu[i] + shift;
            mu2[i] = mu[i] - shift;
        }

        double halfPriori = componentToSplit.priori / 2.0;
        // Carry the parent's diagonal flag over, as the k-component constructor
        // does for its components; the (mean, cov) constructor is not given it.
        mixture.add(newComponent(mu1, delta, parent.diagonal, halfPriori));
        mixture.add(newComponent(mu2, delta, parent.diagonal, halfPriori));
        mixture.remove(componentToSplit);
    }

    /** Builds a component holding a Gaussian with the given mean, covariance, diagonal flag and prior. */
    private MultivariateMixture.Component newComponent(double[] mean, double[][] cov, boolean diagonal, double priori) {
        MultivariateMixture.Component component = new MultivariateMixture.Component();
        component.priori = priori;
        MultivariateGaussianDistribution gaussian = new MultivariateGaussianDistribution(mean, cov);
        gaussian.diagonal = diagonal;
        component.distribution = gaussian;
        return component;
    }

    /** Returns the mean of each of the first {@code d} columns of {@code data}. */
    private static double[] columnMeans(double[][] data, int d) {
        double[] mu = new double[d];
        for (double[] x : data) {
            for (int j = 0; j < d; ++j) {
                mu[j] += x[j];
            }
        }
        for (int j = 0; j < d; ++j) {
            mu[j] /= (double) data.length;
        }
        return mu;
    }

    /**
     * Returns the sample covariance (divisor {@code n - 1}) of {@code data} around {@code mu};
     * when {@code diagonal} is true only the diagonal entries are estimated and the rest stay zero.
     */
    private static double[][] sampleCovariance(double[][] data, double[] mu, boolean diagonal) {
        int d = mu.length;
        double divisor = (double) (data.length - 1);
        double[][] sigma = new double[d][d];

        if (diagonal) {
            for (double[] x : data) {
                for (int j = 0; j < d; ++j) {
                    double dev = x[j] - mu[j];
                    sigma[j][j] += dev * dev;
                }
            }
            for (int j = 0; j < d; ++j) {
                sigma[j][j] /= divisor;
            }
        } else {
            // Accumulate the lower triangle only, then mirror it.
            for (double[] x : data) {
                for (int j = 0; j < d; ++j) {
                    for (int l = 0; l <= j; ++l) {
                        sigma[j][l] += (x[j] - mu[j]) * (x[l] - mu[l]);
                    }
                }
            }
            for (int j = 0; j < d; ++j) {
                for (int l = 0; l <= j; ++l) {
                    sigma[j][l] /= divisor;
                    sigma[l][j] = sigma[j][l];
                }
            }
        }
        return sigma;
    }

    /** Returns the sum of {@code npara()} over all component distributions. */
    private static int parameterCount(List<MultivariateMixture.Component> mixture) {
        int total = 0;
        for (MultivariateMixture.Component component : mixture) {
            total += component.distribution.npara();
        }
        return total;
    }
}

