/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import smile.math.Math;
import smile.math.matrix.CholeskyDecomposition;
import smile.stat.distribution.AbstractMultivariateDistribution;
import smile.stat.distribution.GaussianDistribution;
import smile.stat.distribution.MultivariateExponentialFamily;
import smile.stat.distribution.MultivariateMixture;

/**
 * Multivariate Gaussian (normal) distribution N(mu, Sigma).
 * <p>
 * The covariance matrix is factored once by a Cholesky decomposition when the
 * object is constructed; its inverse, determinant and lower-triangular factor
 * are cached and reused by {@link #logp(double[])}, {@link #cdf(double[])} and
 * {@link #rand()}.
 */
public class MultivariateGaussianDistribution
extends AbstractMultivariateDistribution
implements MultivariateExponentialFamily {
    /** Natural logarithm of 2*pi*e (the literal below is 2*pi*e); used by {@link #entropy()}. */
    private static final double LOG2PIE = Math.log(17.079468445347132);
    /** Mean vector. */
    double[] mu;
    /** Covariance matrix, always stored as a full square matrix (off-diagonals are zero when diagonal). */
    double[][] sigma;
    /** True if the covariance matrix is constrained to be diagonal. */
    boolean diagonal;
    /** Dimensionality of the distribution. */
    private int dim;
    /** Cached inverse of {@link #sigma}. */
    private double[][] sigmaInv;
    /** Cached lower-triangular Cholesky factor of {@link #sigma}. */
    private double[][] sigmaL;
    /** Cached determinant of {@link #sigma}. */
    private double sigmaDet;
    /** Cached log-normalizer: (dim * log(2*pi) + log|Sigma|) / 2. */
    private double pdfConstant;
    /** Number of free parameters, as reported by {@link #npara()}. */
    private int numParameters;

    /**
     * Constructs a distribution with isotropic covariance {@code var * I}.
     *
     * @param mean the mean vector.
     * @param var  the variance shared by every dimension.
     * @throws IllegalArgumentException if {@code var} is not positive.
     */
    public MultivariateGaussianDistribution(double[] mean, double var) {
        if (var <= 0.0) {
            throw new IllegalArgumentException("Variance is not positive: " + var);
        }
        int d = mean.length;
        this.mu = mean.clone();
        this.sigma = new double[d][d];
        for (int i = 0; i < d; ++i) {
            this.sigma[i][i] = var;
        }
        this.diagonal = true;
        // d means plus one shared variance.
        this.numParameters = d + 1;
        this.init();
    }

    /**
     * Constructs a distribution with diagonal covariance {@code diag(var)}.
     *
     * @param mean the mean vector.
     * @param var  the per-dimension variances.
     * @throws IllegalArgumentException if the lengths differ or any variance is not positive.
     */
    public MultivariateGaussianDistribution(double[] mean, double[] var) {
        if (mean.length != var.length) {
            throw new IllegalArgumentException("Mean vector and covariance matrix have different dimension");
        }
        int d = mean.length;
        this.mu = mean.clone();
        this.sigma = new double[d][d];
        for (int i = 0; i < d; ++i) {
            if (var[i] <= 0.0) {
                throw new IllegalArgumentException("Variance is not positive: " + var[i]);
            }
            this.sigma[i][i] = var[i];
        }
        this.diagonal = true;
        this.numParameters = countParameters(d, true);
        this.init();
    }

    /**
     * Constructs a distribution with a full covariance matrix.
     *
     * @param mean the mean vector.
     * @param cov  the covariance matrix; it is copied, not referenced.
     * @throws IllegalArgumentException if {@code cov} is not a square matrix
     *         of the same dimension as {@code mean}.
     */
    public MultivariateGaussianDistribution(double[] mean, double[][] cov) {
        this(mean, cov, false);
    }

    /**
     * Shared implementation of the matrix constructor. The {@code diagonal} flag only
     * affects {@link #isDiagonal()} and the parameter count; {@code cov} is copied verbatim.
     */
    private MultivariateGaussianDistribution(double[] mean, double[][] cov, boolean diagonal) {
        if (mean.length != cov.length) {
            throw new IllegalArgumentException("Mean vector and covariance matrix have different dimension");
        }
        int d = mean.length;
        this.mu = mean.clone();
        this.sigma = new double[d][d];
        for (int i = 0; i < d; ++i) {
            if (cov[i].length != d) {
                throw new IllegalArgumentException("Covariance matrix is not square: row " + i
                        + " has length " + cov[i].length + ", expected " + d);
            }
            System.arraycopy(cov[i], 0, this.sigma[i], 0, d);
        }
        this.diagonal = diagonal;
        this.numParameters = countParameters(d, diagonal);
        this.init();
    }

    /**
     * Estimates a full-covariance distribution from samples.
     *
     * @param data the samples, one row per observation.
     */
    public MultivariateGaussianDistribution(double[][] data) {
        this(data, false);
    }

    /**
     * Estimates the distribution from samples.
     *
     * @param data     the samples, one row per observation.
     * @param diagonal if true, only per-dimension variances are estimated and the
     *                 off-diagonal covariances are fixed at zero.
     */
    public MultivariateGaussianDistribution(double[][] data, boolean diagonal) {
        this.diagonal = diagonal;
        this.mu = Math.colMean(data);
        int d = this.mu.length;
        if (diagonal) {
            this.sigma = new double[d][d];
            for (double[] row : data) {
                for (int j = 0; j < d; ++j) {
                    double diff = row[j] - this.mu[j];
                    this.sigma[j][j] += diff * diff;
                }
            }
            // Unbiased (n - 1) sample variance.
            for (int j = 0; j < d; ++j) {
                this.sigma[j][j] /= (data.length - 1);
            }
        } else {
            this.sigma = Math.cov(data, this.mu);
        }
        // Previously always the full-covariance count, even for diagonal estimates.
        this.numParameters = countParameters(d, diagonal);
        this.init();
    }

    /**
     * Returns the number of free parameters of a d-dimensional Gaussian: d means plus
     * either d variances (diagonal) or d(d+1)/2 covariance entries (full).
     */
    private static int countParameters(int d, boolean diagonal) {
        return diagonal ? 2 * d : d + d * (d + 1) / 2;
    }

    /** Factors {@link #sigma} and caches the quantities derived from it. */
    private void init() {
        this.dim = this.mu.length;
        CholeskyDecomposition cholesky = new CholeskyDecomposition(this.sigma);
        this.sigmaInv = cholesky.inverse().array();
        this.sigmaDet = cholesky.det();
        this.sigmaL = cholesky.getL();
        this.pdfConstant = ((double) this.dim * Math.log(java.lang.Math.PI * 2) + Math.log(this.sigmaDet)) / 2.0;
    }

    /**
     * Returns true if the covariance matrix is diagonal.
     *
     * @return true if the covariance matrix is diagonal.
     */
    public boolean isDiagonal() {
        return this.diagonal;
    }

    @Override
    public int npara() {
        return this.numParameters;
    }

    /**
     * Returns the differential entropy {@code (d * log(2*pi*e) + log|Sigma|) / 2}.
     */
    @Override
    public double entropy() {
        return ((double) this.dim * LOG2PIE + Math.log(this.sigmaDet)) / 2.0;
    }

    /**
     * Returns the mean vector. The internal array is returned, not a copy; mutating it
     * bypasses the cached state.
     */
    @Override
    public double[] mean() {
        return this.mu;
    }

    /**
     * Returns the covariance matrix. The internal array is returned, not a copy; mutating
     * it does NOT update the cached inverse/determinant/Cholesky factor.
     */
    @Override
    public double[][] cov() {
        return this.sigma;
    }

    /**
     * Returns the scatter of the distribution, i.e. the determinant of the covariance matrix.
     *
     * @return the determinant of the covariance matrix.
     */
    public double scatter() {
        return this.sigmaDet;
    }

    /**
     * Returns the log density {@code -(x-mu)' Sigma^-1 (x-mu) / 2 - pdfConstant}.
     *
     * @throws IllegalArgumentException if {@code x} has the wrong dimension.
     */
    @Override
    public double logp(double[] x) {
        if (x.length != this.dim) {
            throw new IllegalArgumentException("Sample has different dimension.");
        }
        double[] centered = x.clone();
        Math.minus(centered, this.mu);
        double quadratic = Math.xax(this.sigmaInv, centered);
        return quadratic / -2.0 - this.pdfConstant;
    }

    @Override
    public double p(double[] x) {
        return Math.exp(this.logp(x));
    }

    /**
     * Estimates P(X &lt;= x) (component-wise) with Genz's randomized Monte Carlo algorithm
     * over the Cholesky factor of Sigma. Sampling stops once the estimated error at the
     * 99.9% level drops to 0.001 or after 10000 samples, whichever comes first. Because it
     * draws random numbers, repeated calls may return slightly different values.
     *
     * @throws IllegalArgumentException if {@code x} has the wrong dimension.
     */
    @Override
    public double cdf(double[] x) {
        if (x.length != this.dim) {
            throw new IllegalArgumentException("Sample has different dimension.");
        }
        final int maxSamples = 10000;
        final double alph = GaussianDistribution.getInstance().quantile(0.999);
        final double errMax = 0.001;

        // Upper integration limits relative to the mean.
        double[] b = x.clone();
        Math.minus(b, this.mu);

        double p = 0.0;
        double varSum = 0.0;
        // e[i]: conditional 1-D probability of coordinate i; f[i]: running product e[0..i].
        double[] e = new double[this.dim];
        double[] f = new double[this.dim];
        e[0] = GaussianDistribution.getInstance().cdf(b[0] / this.sigmaL[0][0]);
        f[0] = e[0];
        double[] y = new double[this.dim];

        double err = 2.0 * errMax;
        for (int n = 1; err > errMax && n <= maxSamples; ++n) {
            double[] w = Math.random(this.dim - 1);
            for (int i = 1; i < this.dim; ++i) {
                // Genz: y_{i-1} = Phi^-1(w_{i-1} * e_{i-1}) -- the conditional probability,
                // not the cumulative product f (which the previous code used by mistake).
                y[i - 1] = GaussianDistribution.getInstance().quantile(w[i - 1] * e[i - 1]);
                double q = 0.0;
                for (int j = 0; j < i; ++j) {
                    q += this.sigmaL[i][j] * y[j];
                }
                e[i] = GaussianDistribution.getInstance().cdf((b[i] - q) / this.sigmaL[i][i]);
                f[i] = e[i] * f[i - 1];
            }
            // Running mean of the samples and its variance estimate.
            double del = (f[this.dim - 1] - p) / n;
            p += del;
            varSum = (n - 2) * varSum / n + del * del;
            err = alph * Math.sqrt(varSum);
        }
        return p;
    }

    /**
     * Draws a random sample {@code mu + L z}, where {@code L} is the Cholesky factor of
     * Sigma and {@code z} is a vector of independent standard normal variates.
     *
     * @return a random sample.
     */
    public double[] rand() {
        double[] z = new double[this.mu.length];
        for (int i = 0; i < z.length; ++i) {
            z[i] = standardNormal();
        }
        double[] pt = new double[this.sigmaL.length];
        for (int i = 0; i < pt.length; ++i) {
            for (int j = 0; j <= i; ++j) {
                pt[i] += this.sigmaL[i][j] * z[j];
            }
        }
        Math.plus(pt, this.mu);
        return pt;
    }

    /** Draws one standard normal variate with Leva's ratio-of-uniforms method. */
    private static double standardNormal() {
        double u;
        double v;
        double q;
        do {
            u = Math.random();
            v = 1.7156 * (Math.random() - 0.5);
            double x = u - 0.449871;
            double y = Math.abs(v) + 0.386595;
            q = x * x + y * (0.196 * y - 0.25472 * x);
        } while (q > 0.27597 && (q > 0.27846 || v * v > -4.0 * Math.log(u) * u * u));
        return v / u;
    }

    /**
     * M-step of EM: re-estimates a mixture component from weighted samples.
     *
     * @param x          the samples, one row per observation.
     * @param posteriori the posterior weight of this component for each sample.
     * @return a component whose prior is the total weight and whose distribution has the
     *         same diagonal/full covariance structure as this one.
     */
    @Override
    public MultivariateMixture.Component M(double[][] x, double[] posteriori) {
        int n = x[0].length;

        double alpha = 0.0;
        double[] mean = new double[n];
        for (int k = 0; k < x.length; ++k) {
            alpha += posteriori[k];
            for (int i = 0; i < n; ++i) {
                mean[i] += x[k][i] * posteriori[k];
            }
        }
        for (int i = 0; i < n; ++i) {
            mean[i] /= alpha;
        }

        double[][] cov = new double[n][n];
        MultivariateGaussianDistribution g;
        if (this.diagonal) {
            for (int k = 0; k < x.length; ++k) {
                for (int i = 0; i < n; ++i) {
                    double di = x[k][i] - mean[i];
                    cov[i][i] += di * di * posteriori[k];
                }
            }
            for (int i = 0; i < n; ++i) {
                cov[i][i] /= alpha;
            }
            // Mark the result diagonal so npara() counts 2n parameters, not the full count.
            g = new MultivariateGaussianDistribution(mean, cov, true);
        } else {
            for (int k = 0; k < x.length; ++k) {
                for (int i = 0; i < n; ++i) {
                    double di = x[k][i] - mean[i];
                    for (int j = 0; j < n; ++j) {
                        cov[i][j] += di * (x[k][j] - mean[j]) * posteriori[k];
                    }
                }
            }
            for (int i = 0; i < n; ++i) {
                for (int j = 0; j < n; ++j) {
                    cov[i][j] /= alpha;
                }
                // Slightly inflate the diagonal, presumably to guard against a
                // numerically singular covariance estimate.
                cov[i][i] *= 1.00001;
            }
            g = new MultivariateGaussianDistribution(mean, cov, false);
        }

        MultivariateMixture.Component c = new MultivariateMixture.Component();
        c.priori = alpha;
        c.distribution = g;
        return c;
    }

    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder("Multivariate Gaussian Distribution:\nmu = [");
        for (double m : this.mu) {
            builder.append(m).append(' ');
        }
        // Replace the trailing space with the closing bracket.
        builder.setCharAt(builder.length() - 1, ']');
        builder.append("\nSigma = [\n");
        for (double[] row : this.sigma) {
            builder.append('\t');
            for (double s : row) {
                builder.append(s).append(' ');
            }
            builder.append('\n');
        }
        builder.append("\t]");
        return builder.toString();
    }
}

