/*
 * Decompiled with CFR 0.152.
 */
package umontreal.ssj.probdistmulti;

import umontreal.ssj.functions.MathFunction;
import umontreal.ssj.probdistmulti.DiscreteDistributionIntMulti;
import umontreal.ssj.util.Num;
import umontreal.ssj.util.RootFinder;

/**
 * Negative multinomial distribution with parameters {@code n > 0} and
 * {@code p = (p_1, ..., p_d)}, where every {@code p_i} lies in {@code [0, 1)}
 * and {@code p_1 + ... + p_d < 1}.
 *
 * <p>Writing {@code p_0 = 1 - (p_1 + ... + p_d)}, the mass function evaluated
 * by this class is
 * <pre>
 *   P[X = x] = Gamma(n + sum x_i) / (Gamma(n) * prod x_i!) * p_0^n * prod p_i^{x_i}
 * </pre>
 *
 * <p>The static methods validate their arguments on every call; the instance
 * methods rely on the validation done by {@link #setParams(double, double[])}.
 */
public class NegativeMultinomialDist
extends DiscreteDistributionIntMulti {
    /** Parameter {@code n}; {@link #setParams(double, double[])} only stores values {@code > 0}. */
    protected double n;
    /** Private copy of the vector {@code (p_1, ..., p_d)} passed to {@link #setParams(double, double[])}. */
    protected double[] p;

    /**
     * Creates a distribution with the given parameters.
     *
     * @param n parameter {@code n}, must be {@code > 0}
     * @param p vector of the {@code p_i}; each must be in {@code [0, 1)} and their sum {@code < 1}
     * @throws IllegalArgumentException if the parameters are rejected by {@link #setParams(double, double[])}
     */
    public NegativeMultinomialDist(double n, double[] p) {
        this.setParams(n, p);
    }

    /**
     * Returns the probability mass at {@code x} for this distribution.
     *
     * @param x point at which the mass function is evaluated; must have the same length as {@code p}
     * @return the probability of {@code x}
     * @throws IllegalArgumentException if {@code x} and {@code p} have different lengths
     */
    @Override
    public double prob(int[] x) {
        return NegativeMultinomialDist.prob_(this.n, this.p, x);
    }

    /**
     * Returns the mean vector, whose i-th entry is {@code n * p_i / p_0}.
     *
     * @return a newly allocated array of length {@code p.length}
     */
    @Override
    public double[] getMean() {
        return NegativeMultinomialDist.getMean_(this.n, this.p);
    }

    /**
     * Returns the covariance matrix of this distribution.
     *
     * @return a newly allocated {@code d x d} matrix
     */
    @Override
    public double[][] getCovariance() {
        return NegativeMultinomialDist.getCovariance_(this.n, this.p);
    }

    /**
     * Returns the correlation matrix of this distribution.
     *
     * @return a newly allocated {@code d x d} matrix with ones on the diagonal
     */
    @Override
    public double[][] getCorrelation() {
        return NegativeMultinomialDist.getCorrelation_(this.n, this.p);
    }

    /**
     * Validates the distribution parameters.
     *
     * <p>The comparisons are written so that NaN values fail them and are
     * rejected as well.
     *
     * @throws IllegalArgumentException if {@code n} is not {@code > 0}, if some {@code p[i]}
     *         is outside {@code [0, 1)}, or if the {@code p[i]} sum to {@code 1} or more
     */
    private static void verifParam(double n, double[] p) {
        if (!(n > 0.0)) {
            throw new IllegalArgumentException("n <= 0");
        }
        double sumPi = 0.0;
        for (int i = 0; i < p.length; ++i) {
            if (!(p[i] >= 0.0 && p[i] < 1.0)) {
                throw new IllegalArgumentException("p is not a probability vector");
            }
            sumPi += p[i];
        }
        if (sumPi >= 1.0) {
            throw new IllegalArgumentException("p is not a probability vector");
        }
    }

    /** Returns {@code p_0 = 1 - (p[0] + ... + p[d-1])}, summing in index order. */
    private static double complementOfSum(double[] p) {
        double sumPi = 0.0;
        for (int i = 0; i < p.length; ++i) {
            sumPi += p[i];
        }
        return 1.0 - sumPi;
    }

    /**
     * Evaluates the mass function without validating {@code n} and {@code p}.
     * The value is assembled on the log scale and exponentiated once at the end.
     */
    private static double prob_(double n, double[] p, int[] x) {
        if (x.length != p.length) {
            throw new IllegalArgumentException("x and p must have the same size");
        }
        double sumPi = 0.0;
        double sumXi = 0.0;
        double sumLnXiFact = 0.0;
        double sumXiLnPi = 0.0;
        for (int i = 0; i < p.length; ++i) {
            sumPi += p[i];
            sumXi += (double)x[i];
            sumLnXiFact += Num.lnFactorial(x[i]);
            // p_i^0 == 1 even when p_i == 0; evaluating 0 * log(0) would give NaN,
            // so a zero count contributes nothing to the log-probability.
            if (x[i] != 0) {
                sumXiLnPi += (double)x[i] * Math.log(p[i]);
            }
        }
        double p0 = 1.0 - sumPi;
        return Math.exp(Num.lnGamma(n + sumXi) - (Num.lnGamma(n) + sumLnXiFact) + n * Math.log(p0) + sumXiLnPi);
    }

    /**
     * Returns the probability mass at {@code x} for the given parameters.
     *
     * @param n parameter {@code n}, must be {@code > 0}
     * @param p vector of the {@code p_i}; each in {@code [0, 1)} with sum {@code < 1}
     * @param x point at which the mass function is evaluated; same length as {@code p}
     * @return the probability of {@code x}
     * @throws IllegalArgumentException if the parameters are invalid or the lengths differ
     */
    public static double prob(double n, double[] p, int[] x) {
        NegativeMultinomialDist.verifParam(n, p);
        return NegativeMultinomialDist.prob_(n, p, x);
    }

    /** Placeholder: the distribution function is not implemented and this always throws. */
    private static double cdf_(double n, double[] p, int[] x) {
        throw new UnsupportedOperationException("cdf not implemented");
    }

    /**
     * Not implemented. After validating the parameters this method always throws.
     *
     * @throws IllegalArgumentException if the parameters are invalid
     * @throws UnsupportedOperationException always, when the parameters are valid
     */
    public static double cdf(double n, double[] p, int[] x) {
        NegativeMultinomialDist.verifParam(n, p);
        return NegativeMultinomialDist.cdf_(n, p, x);
    }

    /** Computes the mean vector {@code n * p_i / p_0} without validating the parameters. */
    private static double[] getMean_(double n, double[] p) {
        double p0 = NegativeMultinomialDist.complementOfSum(p);
        double[] mean = new double[p.length];
        for (int i = 0; i < p.length; ++i) {
            mean[i] = n * p[i] / p0;
        }
        return mean;
    }

    /**
     * Returns the mean vector, whose i-th entry is {@code n * p_i / p_0}.
     *
     * @param n parameter {@code n}, must be {@code > 0}
     * @param p vector of the {@code p_i}; each in {@code [0, 1)} with sum {@code < 1}
     * @return a newly allocated array of length {@code p.length}
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public static double[] getMean(double n, double[] p) {
        NegativeMultinomialDist.verifParam(n, p);
        return NegativeMultinomialDist.getMean_(n, p);
    }

    /**
     * Computes the covariance matrix without validating the parameters:
     * {@code n * p_i * p_j / p_0^2} off the diagonal and
     * {@code n * p_i * (p_i + p_0) / p_0^2} on it.
     */
    private static double[][] getCovariance_(double n, double[] p) {
        double p0 = NegativeMultinomialDist.complementOfSum(p);
        double[][] cov = new double[p.length][p.length];
        for (int i = 0; i < p.length; ++i) {
            for (int j = 0; j < p.length; ++j) {
                cov[i][j] = n * p[i] * p[j] / (p0 * p0);
            }
            // The diagonal (variance) uses a different formula and overwrites the entry set above.
            cov[i][i] = n * p[i] * (p[i] + p0) / (p0 * p0);
        }
        return cov;
    }

    /**
     * Returns the covariance matrix for the given parameters.
     *
     * @param n parameter {@code n}, must be {@code > 0}
     * @param p vector of the {@code p_i}; each in {@code [0, 1)} with sum {@code < 1}
     * @return a newly allocated {@code d x d} matrix
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public static double[][] getCovariance(double n, double[] p) {
        NegativeMultinomialDist.verifParam(n, p);
        return NegativeMultinomialDist.getCovariance_(n, p);
    }

    /**
     * Computes the correlation matrix without validating the parameters:
     * {@code sqrt(p_i * p_j / ((p_0 + p_i) * (p_0 + p_j)))} off the diagonal, 1 on it.
     * The result does not depend on {@code n}.
     */
    private static double[][] getCorrelation_(double n, double[] p) {
        double p0 = NegativeMultinomialDist.complementOfSum(p);
        double[][] corr = new double[p.length][p.length];
        for (int i = 0; i < p.length; ++i) {
            for (int j = 0; j < p.length; ++j) {
                corr[i][j] = Math.sqrt(p[i] * p[j] / ((p0 + p[i]) * (p0 + p[j])));
            }
            corr[i][i] = 1.0;
        }
        return corr;
    }

    /**
     * Returns the correlation matrix for the given parameters.
     *
     * @param n parameter {@code n}, must be {@code > 0}
     * @param p vector of the {@code p_i}; each in {@code [0, 1)} with sum {@code < 1}
     * @return a newly allocated {@code d x d} matrix with ones on the diagonal
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public static double[][] getCorrelation(double n, double[] p) {
        NegativeMultinomialDist.verifParam(n, p);
        return NegativeMultinomialDist.getCorrelation_(n, p);
    }

    /** Rejects sample dimensions for which the estimators are undefined. */
    private static void checkSampleShape(int m, int d) {
        if (m <= 0) {
            throw new IllegalArgumentException("m <= 0");
        }
        if (d < 0) {
            throw new IllegalArgumentException("d < 0");
        }
    }

    /**
     * Returns, for each of the first {@code m} observations, the sum of its first
     * {@code d} components. Sums are accumulated in {@code long} so that a total
     * which does not fit the later {@code int}-indexed table is reported instead
     * of silently wrapping around.
     */
    private static int[] observationTotals(int[][] x, int m, int d) {
        int[] ups = new int[m];
        for (int j = 0; j < m; ++j) {
            long total = 0L;
            for (int i = 0; i < d; ++i) {
                total += x[j][i];
            }
            if (total >= Integer.MAX_VALUE) {
                throw new IllegalArgumentException("n/p_i too large");
            }
            ups[j] = (int)total;
        }
        return ups;
    }

    /** Returns the largest entry of a non-empty array. */
    private static int maxOf(int[] values) {
        int max = values[0];
        for (int j = 1; j < values.length; ++j) {
            if (values[j] > max) {
                max = values[j];
            }
        }
        return max;
    }

    /**
     * Returns the table {@code Fl} of length {@code M} where {@code Fl[l]} is the
     * fraction of observation totals strictly greater than {@code l}.
     */
    private static double[] tailFrequencies(int[] ups, int M) {
        int m = ups.length;
        double[] Fl = new double[M];
        for (int l = 0; l < M; ++l) {
            int prop = 0;
            for (int j = 0; j < m; ++j) {
                if (ups[j] > l) {
                    ++prop;
                }
            }
            Fl[l] = (double)prop / (double)m;
        }
        return Fl;
    }

    /**
     * Estimates the parameters {@code (n, p_1, ..., p_d)} from {@code m}
     * observations of dimension {@code d}.
     *
     * <p>The estimate of {@code n} is the root, searched in {@code [1e-9, 1e9]}
     * with tolerance {@code 1e-5}, of
     * {@code sum_l Fl[l] / (gamma + l) - log(1 + S / (m * gamma))}, where
     * {@code S} is the sum of all observation totals. Each {@code p_i} is then
     * {@code lambda_i / (1 + sum lambda)} with {@code lambda_i = mean_i / n}.
     *
     * <p>NOTE(review): when every observation sums to zero the function above is
     * identically zero; what the root finder returns in that case should be confirmed.
     *
     * @param x observations; {@code x[j][i]} is component {@code i} of observation {@code j}
     * @param m number of observations, must be {@code > 0}
     * @param d dimension of each observation, must be {@code >= 0}
     * @return an array of length {@code d + 1}: the estimate of {@code n} at index 0,
     *         followed by the estimates of {@code p_1, ..., p_d}
     * @throws IllegalArgumentException if {@code m <= 0}, {@code d < 0}, an observation
     *         total is too large, or an estimated {@code p_i} exceeds 1
     */
    public static double[] getMLE(int[][] x, int m, int d) {
        NegativeMultinomialDist.checkSampleShape(m, d);
        int[] ups = NegativeMultinomialDist.observationTotals(x, m, d);
        int M = NegativeMultinomialDist.maxOf(ups);
        double[] Fl = NegativeMultinomialDist.tailFrequencies(ups, M);

        // Component-wise sample means, accumulated observation by observation.
        double[] mean = new double[d];
        for (int j = 0; j < m; ++j) {
            for (int i = 0; i < d; ++i) {
                mean[i] += (double)x[j][i];
            }
        }
        for (int i = 0; i < d; ++i) {
            mean[i] /= (double)m;
        }

        double[] parameters = new double[d + 1];
        Function f = new Function(m, M, ups, Fl);
        parameters[0] = RootFinder.brentDekker(1.0E-9, 1.0E9, f, 1.0E-5);

        double[] lambda = new double[d];
        double sumLambda = 0.0;
        for (int i = 0; i < d; ++i) {
            lambda[i] = mean[i] / parameters[0];
            sumLambda += lambda[i];
        }
        for (int i = 0; i < d; ++i) {
            parameters[i + 1] = lambda[i] / (1.0 + sumLambda);
            if (parameters[i + 1] > 1.0) {
                throw new IllegalArgumentException("p_i > 1");
            }
        }
        return parameters;
    }

    /**
     * Estimates {@code nu = 1 / n} from {@code m} observations of dimension {@code d}.
     *
     * <p>The value returned is the root, searched in {@code [1e-8, 1e8]} with
     * tolerance {@code 1e-5}, of the same equation solved by
     * {@link #getMLE(int[][], int, int)} after the substitution {@code gamma = 1 / nu}.
     *
     * @param x observations; {@code x[j][i]} is component {@code i} of observation {@code j}
     * @param m number of observations, must be {@code > 0}
     * @param d dimension of each observation, must be {@code >= 0}
     * @return the root {@code nu}
     * @throws IllegalArgumentException if {@code m <= 0}, {@code d < 0}, or an observation
     *         total is too large
     */
    public static double getMLEninv(int[][] x, int m, int d) {
        NegativeMultinomialDist.checkSampleShape(m, d);
        int[] ups = NegativeMultinomialDist.observationTotals(x, m, d);
        int M = NegativeMultinomialDist.maxOf(ups);
        double[] Fl = NegativeMultinomialDist.tailFrequencies(ups, M);
        FuncInv f = new FuncInv(m, M, ups, Fl);
        return RootFinder.brentDekker(1.0E-8, 1.0E8, f, 1.0E-5);
    }

    /**
     * Returns the parameter {@code n} of this distribution.
     *
     * @return the value of {@code n}
     */
    public double getGamma() {
        return this.n;
    }

    /**
     * Returns the vector {@code (p_1, ..., p_d)} of this distribution.
     *
     * <p>The internal array itself is returned, not a copy, so writing to it
     * changes this distribution without any validation.
     *
     * @return the internal parameter array
     */
    public double[] getP() {
        return this.p;
    }

    /**
     * Sets the parameters of this distribution.
     *
     * <p>All validation happens before any field is written, so a rejected call
     * leaves the object unchanged. The vector {@code p} is copied.
     *
     * @param n parameter {@code n}, must be {@code > 0}
     * @param p vector of the {@code p_i}; each in {@code [0, 1)} with sum {@code < 1}
     * @throws IllegalArgumentException if {@code n} is not {@code > 0} or {@code p} is
     *         not a valid parameter vector
     */
    public void setParams(double n, double[] p) {
        NegativeMultinomialDist.verifParam(n, p);
        this.n = n;
        this.dimension = p.length;
        this.p = new double[this.dimension];
        System.arraycopy(p, 0, this.p, 0, this.dimension);
    }

    /**
     * The function of {@link Function} re-expressed in {@code nu = 1 / gamma}:
     * {@code nu * sum_l Fl[l] / (1 + nu * l) - log(1 + sumUps * nu / k)}.
     */
    private static class FuncInv
    extends Function
    implements MathFunction {
        public FuncInv(int k, int m, int[] ups, double[] Fl) {
            super(k, m, ups, Fl);
        }

        @Override
        public double evaluate(double nu) {
            double sum = 0.0;
            for (int l = 0; l < this.M; ++l) {
                sum += this.Fl[l] / (1.0 + nu * (double)l);
            }
            return sum * nu - Math.log1p((double)this.sumUps * nu / (double)this.k);
        }
    }

    /**
     * The function whose root in {@code gamma} is used as the estimate of {@code n}:
     * {@code sum_l Fl[l] / (gamma + l) - log(1 + sumUps / (k * gamma))}.
     */
    private static class Function
    implements MathFunction {
        /** Copy of the tail-frequency table; only the first {@code M} entries are read. */
        protected double[] Fl;
        /** Copy of the per-observation totals. */
        protected int[] ups;
        /** Number of observations. */
        protected int k;
        /** Number of terms in the sum (the largest observation total). */
        protected int M;
        /** Sum of all observation totals; held in a long so it cannot overflow. */
        protected long sumUps;

        /**
         * @param k   number of observations
         * @param m   number of terms of {@code Fl} to sum (stored as {@code M})
         * @param ups per-observation totals; copied
         * @param Fl  tail-frequency table; copied
         */
        public Function(int k, int m, int[] ups, double[] Fl) {
            this.k = k;
            this.M = m;
            this.Fl = new double[Fl.length];
            System.arraycopy(Fl, 0, this.Fl, 0, Fl.length);
            this.ups = new int[ups.length];
            System.arraycopy(ups, 0, this.ups, 0, ups.length);
            this.sumUps = 0L;
            for (int i = 0; i < ups.length; ++i) {
                this.sumUps += ups[i];
            }
        }

        @Override
        public double evaluate(double gamma) {
            double sum = 0.0;
            for (int l = 0; l < this.M; ++l) {
                sum += this.Fl[l] / (gamma + (double)l);
            }
            return sum - Math.log1p((double)this.sumUps / ((double)this.k * gamma));
        }
    }
}

