/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import java.util.Arrays;
import smile.linalg.UPLO;
import smile.math.MathEx;
import smile.stat.distribution.GaussianDistribution;
import smile.stat.distribution.MultivariateDistribution;
import smile.stat.distribution.MultivariateExponentialFamily;
import smile.stat.distribution.MultivariateMixture;
import smile.tensor.Cholesky;
import smile.tensor.DenseMatrix;
import smile.tensor.Matrix;
import smile.tensor.ScalarType;
import smile.tensor.Vector;

/**
 * Multivariate Gaussian distribution N(mu, sigma).
 * <p>
 * The covariance matrix may be spherical (one variance shared by all
 * dimensions), diagonal (one variance per dimension), or full. On
 * construction the covariance is Cholesky-factorized once; the factor's
 * lower triangle L (with L L' = sigma) is then reused for the log density,
 * sampling and the CDF.
 */
public class MultivariateGaussianDistribution
implements MultivariateDistribution,
MultivariateExponentialFamily {
    private static final long serialVersionUID = 2L;
    /** The mean vector. */
    public final double[] mu;
    /** The covariance matrix. */
    public final DenseMatrix sigma;
    /** True if the covariance matrix was supplied as spherical or diagonal. */
    public final boolean diagonal;
    /** The dimension of the distribution (length of {@link #mu}). */
    private int dim;
    /** The inverse of the covariance matrix. */
    private Matrix sigmaInv;
    /** The Cholesky factorization of sigma; only its lower triangle L is read. */
    private DenseMatrix sigmaL;
    /** The determinant of the covariance matrix (may under/overflow in high dimension). */
    private double sigmaDet;
    /** Log of the normalizing constant: (d * log(2*pi) + log det(sigma)) / 2. */
    private double pdfConstant;
    /** The number of free parameters. */
    private final int length;

    /**
     * Constructor of a Gaussian with spherical covariance {@code variance * I}.
     *
     * @param mean the mean vector.
     * @param variance the variance shared by all dimensions.
     * @throws IllegalArgumentException if {@code variance} is not positive or is NaN.
     */
    public MultivariateGaussianDistribution(double[] mean, double variance) {
        // !(x > 0) also rejects NaN, which a plain "x <= 0" test lets through.
        if (!(variance > 0.0)) {
            throw new IllegalArgumentException("Variance is not positive: " + variance);
        }
        this.mu = mean;
        double[] v = new double[mean.length];
        Arrays.fill(v, variance);
        this.sigma = DenseMatrix.diagflat(v);
        this.diagonal = true;
        // d means plus one shared variance.
        this.length = this.mu.length + 1;
        this.init();
    }

    /**
     * Constructor of a Gaussian with diagonal covariance.
     *
     * @param mean the mean vector.
     * @param variance the per-dimension variances (diagonal of the covariance).
     * @throws IllegalArgumentException if the lengths differ, or any variance
     *         is not positive or is NaN.
     */
    public MultivariateGaussianDistribution(double[] mean, double[] variance) {
        if (mean.length != variance.length) {
            throw new IllegalArgumentException("Mean vector and covariance matrix have different dimension");
        }
        for (double v : variance) {
            // !(v > 0) also rejects NaN, which a plain "v <= 0" test lets through.
            if (!(v > 0.0)) {
                throw new IllegalArgumentException("Variance is not positive: " + v);
            }
        }
        this.mu = mean;
        this.sigma = DenseMatrix.diagflat(variance);
        this.diagonal = true;
        // d means plus d variances.
        this.length = 2 * this.mu.length;
        this.init();
    }

    /**
     * Constructor of a Gaussian with full covariance.
     * <p>
     * Note that the given matrix is flagged as lower-triangular storage
     * ({@code withUplo(UPLO.LOWER)}) and kept by reference, not copied.
     *
     * @param mean the mean vector.
     * @param cov the covariance matrix.
     * @throws IllegalArgumentException if {@code cov} is not square or its
     *         size does not match {@code mean}.
     */
    public MultivariateGaussianDistribution(double[] mean, DenseMatrix cov) {
        if (mean.length != cov.nrow()) {
            throw new IllegalArgumentException("Mean vector and covariance matrix have different dimension");
        }
        if (cov.nrow() != cov.ncol()) {
            throw new IllegalArgumentException("Covariance matrix is not square: " + cov.nrow() + " x " + cov.ncol());
        }
        this.mu = mean;
        this.sigma = cov;
        this.diagonal = false;
        // d means plus the d(d+1)/2 entries of a symmetric covariance.
        this.length = this.mu.length + this.mu.length * (this.mu.length + 1) / 2;
        this.init();
    }

    /**
     * Estimates a full-covariance Gaussian from samples.
     *
     * @param data the samples, one row per observation.
     * @return the fitted distribution.
     */
    public static MultivariateGaussianDistribution fit(double[][] data) {
        return MultivariateGaussianDistribution.fit(data, false);
    }

    /**
     * Estimates a Gaussian from samples using the sample mean and the
     * unbiased (n - 1) covariance estimate.
     *
     * @param data the samples, one row per observation.
     * @param diagonal if true, estimate only per-dimension variances.
     * @return the fitted distribution.
     * @throws IllegalArgumentException if fewer than two samples are given,
     *         since the (n - 1)-normalized covariance is then undefined.
     */
    public static MultivariateGaussianDistribution fit(double[][] data, boolean diagonal) {
        if (data.length < 2) {
            throw new IllegalArgumentException("At least two samples are required to estimate the covariance: " + data.length);
        }
        double[] mu = MathEx.colMeans(data);
        if (!diagonal) {
            return new MultivariateGaussianDistribution(mu, DenseMatrix.of(MathEx.cov(data, mu)));
        }
        int d = mu.length;
        double[] variance = new double[d];
        for (double[] x : data) {
            for (int j = 0; j < d; ++j) {
                double dx = x[j] - mu[j];
                variance[j] += dx * dx;
            }
        }
        double dof = data.length - 1;
        for (int j = 0; j < d; ++j) {
            variance[j] /= dof;
        }
        return new MultivariateGaussianDistribution(mu, variance);
    }

    /**
     * Factorizes the covariance and caches the derived quantities used by
     * the density, sampling and CDF routines.
     */
    private void init() {
        this.dim = this.mu.length;
        this.sigma.withUplo(UPLO.LOWER);
        Cholesky cholesky = this.sigma.copy().cholesky();
        this.sigmaInv = cholesky.inverse();
        this.sigmaDet = cholesky.det();
        this.sigmaL = cholesky.lu();
        // log det(sigma) = 2 * sum_i log(L_ii). Summing logs avoids the
        // underflow/overflow of det(sigma) itself in high dimension, which
        // would otherwise turn the log density into +/-Infinity.
        double logDet = 0.0;
        for (int i = 0; i < this.dim; ++i) {
            logDet += Math.log(this.sigmaL.get(i, i));
        }
        logDet *= 2.0;
        this.pdfConstant = ((double) this.dim * Math.log(Math.PI * 2) + logDet) / 2.0;
    }

    @Override
    public int length() {
        return this.length;
    }

    /**
     * Returns the differential entropy (d * log(2*pi*e) + log det(sigma)) / 2.
     */
    @Override
    public double entropy() {
        // log(2*pi*e) = log(2*pi) + 1, so the entropy equals pdfConstant + d/2.
        return this.pdfConstant + this.dim / 2.0;
    }

    /** Returns the mean vector (the internal array, not a copy). */
    @Override
    public double[] mean() {
        return this.mu;
    }

    /** Returns the covariance matrix (the internal matrix, not a copy). */
    @Override
    public DenseMatrix cov() {
        return this.sigma;
    }

    /**
     * Returns the scatter of the distribution, i.e. the determinant of the
     * covariance matrix.
     *
     * @return the determinant of the covariance matrix.
     */
    public double scatter() {
        return this.sigmaDet;
    }

    /**
     * Returns the log density -(x - mu)' sigma^-1 (x - mu) / 2 - pdfConstant.
     *
     * @throws IllegalArgumentException if {@code x} has the wrong dimension.
     */
    @Override
    public double logp(double[] x) {
        if (x.length != this.dim) {
            throw new IllegalArgumentException("Sample has different dimension.");
        }
        double[] v = x.clone();
        MathEx.sub(v, this.mu);
        double result = this.sigmaInv.xAx(Vector.column(v)) / -2.0;
        return result - this.pdfConstant;
    }

    @Override
    public double p(double[] x) {
        return Math.exp(this.logp(x));
    }

    /**
     * Estimates P(X <= x) (component-wise) by randomized Monte Carlo
     * integration over the Cholesky-transformed region. Sampling stops once
     * the error estimate falls below 0.001 or after 10000 samples, so the
     * result is stochastic.
     *
     * @throws IllegalArgumentException if {@code x} has the wrong dimension.
     */
    @Override
    public double cdf(double[] x) {
        if (x.length != this.dim) {
            throw new IllegalArgumentException("Sample has different dimension.");
        }
        GaussianDistribution normal = GaussianDistribution.getInstance();
        int maxSamples = 10000;
        double errMax = 0.001;
        // Error multiplier: the 0.999 quantile of the standard normal.
        double alph = normal.quantile(0.999);
        double[] v = x.clone();
        MathEx.sub(v, this.mu);

        double[] e = new double[this.dim];
        double[] f = new double[this.dim];
        e[0] = normal.cdf(v[0] / this.sigmaL.get(0, 0));
        f[0] = e[0];
        if (this.dim == 1) {
            // One dimension has a closed form; the sampling loop below would
            // converge to exactly this value without using any randomness.
            return e[0];
        }

        double p = 0.0;
        double varSum = 0.0;
        double[] y = new double[this.dim];
        double err = 2.0 * errMax;
        for (int n = 1; err > errMax && n <= maxSamples; ++n) {
            double[] w = MathEx.random(this.dim - 1);
            for (int i = 1; i < this.dim; ++i) {
                y[i - 1] = normal.quantile(w[i - 1] * e[i - 1]);
                double q = 0.0;
                for (int j = 0; j < i; ++j) {
                    q += this.sigmaL.get(i, j) * y[j];
                }
                e[i] = normal.cdf((v[i] - q) / this.sigmaL.get(i, i));
                f[i] = e[i] * f[i - 1];
            }
            // Running mean of the samples and its variance estimate.
            double del = (f[this.dim - 1] - p) / n;
            p += del;
            varSum = (n - 2) * varSum / n + del * del;
            err = alph * Math.sqrt(varSum);
        }
        return p;
    }

    /**
     * Draws one sample as mu + L z, where z holds independent standard
     * normal variates generated with Leva's ratio-of-uniforms method.
     *
     * @return a random sample.
     */
    public double[] rand() {
        double[] z = new double[this.dim];
        for (int i = 0; i < this.dim; ++i) {
            double u;
            double v;
            double q;
            do {
                u = MathEx.random();
                v = 1.7156 * (MathEx.random() - 0.5);
                double a = u - 0.449871;
                double b = Math.abs(v) + 0.386595;
                q = a * a + b * (0.196 * b - 0.25472 * a);
            } while (q > 0.27597 && (q > 0.27846 || v * v > -4.0 * Math.log(u) * u * u));
            z[i] = v / u;
        }

        // Only the lower triangle of the Cholesky factor is used.
        double[] pt = new double[this.dim];
        for (int i = 0; i < this.dim; ++i) {
            double s = 0.0;
            for (int j = 0; j <= i; ++j) {
                s += this.sigmaL.get(i, j) * z[j];
            }
            pt[i] = s;
        }
        MathEx.add(pt, this.mu);
        return pt;
    }

    /**
     * Draws {@code n} independent samples.
     *
     * @param n the number of samples.
     * @return the samples, one per row.
     */
    public double[][] rand(int n) {
        double[][] data = new double[n][];
        for (int i = 0; i < n; ++i) {
            data[i] = this.rand();
        }
        return data;
    }

    /**
     * M-step of EM: re-estimates a component from data weighted by the
     * posterior probabilities.
     * <p>
     * The covariance kind follows this instance. A spherical or diagonal
     * instance yields per-dimension variances. A full instance yields a full
     * covariance whose diagonal is inflated by a factor of 1.00001.
     *
     * @param data the samples.
     * @param posteriori the posterior probability of each sample.
     * @return the component with prior weight equal to the sum of posteriori.
     */
    @Override
    public MultivariateMixture.Component M(double[][] data, double[] posteriori) {
        int n = data.length;
        int d = data[0].length;

        double alpha = 0.0;
        double[] mean = new double[d];
        for (int k = 0; k < n; ++k) {
            double w = posteriori[k];
            alpha += w;
            double[] x = data[k];
            for (int i = 0; i < d; ++i) {
                mean[i] += x[i] * w;
            }
        }
        for (int i = 0; i < d; ++i) {
            mean[i] /= alpha;
        }

        MultivariateGaussianDistribution gaussian;
        if (this.diagonal) {
            double[] variance = new double[d];
            for (int k = 0; k < n; ++k) {
                double[] x = data[k];
                for (int i = 0; i < d; ++i) {
                    double dx = x[i] - mean[i];
                    variance[i] += dx * dx * posteriori[k];
                }
            }
            for (int i = 0; i < d; ++i) {
                variance[i] /= alpha;
            }
            gaussian = new MultivariateGaussianDistribution(mean, variance);
        } else {
            DenseMatrix cov = DenseMatrix.zeros(ScalarType.Float64, d, d);
            double[] dx = new double[d];
            for (int k = 0; k < n; ++k) {
                double[] x = data[k];
                double w = posteriori[k];
                for (int i = 0; i < d; ++i) {
                    dx[i] = x[i] - mean[i];
                }
                for (int i = 0; i < d; ++i) {
                    for (int j = 0; j < d; ++j) {
                        cov.add(i, j, dx[i] * dx[j] * w);
                    }
                }
            }
            for (int i = 0; i < d; ++i) {
                for (int j = 0; j < d; ++j) {
                    cov.div(i, j, alpha);
                }
                // Slightly inflate the diagonal as regularization.
                cov.mul(i, i, 1.00001);
            }
            gaussian = new MultivariateGaussianDistribution(mean, cov);
        }
        return new MultivariateMixture.Component(alpha, gaussian);
    }

    @Override
    public String toString() {
        return String.format("MultivariateGaussian(mu = %s, sigma = %s)", Arrays.toString(this.mu), this.sigma);
    }
}

