/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.stat.distribution.MultivariateExponentialFamily;
import smile.stat.distribution.MultivariateMixture;

/**
 * A finite mixture whose components are all {@link MultivariateExponentialFamily}
 * distributions, with an iterative fitting procedure.
 *
 * <p>Fitting alternates two steps: computing per-point posterior weights (E-step), and
 * re-estimating each component through {@link MultivariateExponentialFamily#M}
 * (M-step).
 */
public class MultivariateExponentialFamilyMixture extends MultivariateMixture {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(MultivariateExponentialFamilyMixture.class);

    /**
     * The log-likelihood value supplied at construction.
     *
     * <p>For {@link #fit} this is the data log-likelihood after the last iteration. It is 0
     * for the public constructor and when no iteration runs.
     */
    public final double L;

    /**
     * The score {@code L - 0.5 * length() * log(n)}.
     *
     * <p>This is presumably the Bayesian information criterion, with {@code length()}
     * counting free parameters (see {@link MultivariateMixture}).
     */
    public final double bic;

    /**
     * Constructs a mixture from the given components without fitting.
     *
     * <p>The log-likelihood is set to 0 and the sample size to 1. Since {@code log(1) == 0},
     * {@link #bic} is 0 as well.
     *
     * @param components the mixture components.
     * @throws IllegalArgumentException if a component's distribution is not a
     *         {@link MultivariateExponentialFamily}.
     */
    public MultivariateExponentialFamilyMixture(MultivariateMixture.Component... components) {
        this(0.0, 1, components);
    }

    /**
     * Constructs a mixture with a known log-likelihood and sample size.
     *
     * @param L the log-likelihood of the data under this mixture.
     * @param n the number of samples, used in the {@link #bic} penalty term.
     * @param components the mixture components.
     * @throws IllegalArgumentException if a component's distribution is not a
     *         {@link MultivariateExponentialFamily}.
     */
    MultivariateExponentialFamilyMixture(double L, int n, MultivariateMixture.Component... components) {
        super(components);

        for (MultivariateMixture.Component component : components) {
            if (!(component.distribution instanceof MultivariateExponentialFamily)) {
                throw new IllegalArgumentException("Component " + component + " is not of multivariate exponential family.");
            }
        }

        this.L = L;
        this.bic = L - 0.5 * length() * Math.log(n);
    }

    /**
     * Fits the mixture with default settings.
     *
     * <p>The defaults are {@code gamma = 0.2}, {@code maxIter = 500} and {@code tol = 1E-4}.
     *
     * @param x the data; each row is one observation.
     * @param components the initial components. The array is overwritten in place.
     * @return the fitted mixture.
     * @throws IllegalArgumentException if there are too many components or the components
     *         are not of the exponential family.
     */
    public static MultivariateExponentialFamilyMixture fit(double[][] x, MultivariateMixture.Component... components) {
        return fit(x, components, 0.2, 500, 1.0E-4);
    }

    /**
     * Fits the mixture by alternating E-steps and M-steps.
     *
     * <p>Iteration stops after {@code maxIter} passes, or when the change in
     * log-likelihood between consecutive passes is at most {@code tol}. A decrease also
     * counts as at most {@code tol}.
     *
     * @param x the data; each row is one observation.
     * @param components the initial components. The array is overwritten in place with the
     *        re-estimated components on every iteration.
     * @param gamma the regularization factor in [0, 0.2]. With {@code gamma > 0}, each
     *        posterior {@code p} is shrunk by the factor {@code 1 + gamma * log2(p)}.
     * @param maxIter the maximum number of iterations.
     * @param tol the convergence tolerance on the log-likelihood improvement.
     * @return the fitted mixture, carrying the final log-likelihood.
     * @throws IllegalArgumentException if {@code x.length < components.length / 2}, if
     *         {@code gamma} is outside [0, 0.2], or if a component is not of the
     *         exponential family.
     */
    public static MultivariateExponentialFamilyMixture fit(double[][] x, MultivariateMixture.Component[] components, double gamma, int maxIter, double tol) {
        if (x.length < components.length / 2) {
            throw new IllegalArgumentException("Too many components");
        }
        if (gamma < 0.0 || gamma > 0.2) {
            throw new IllegalArgumentException("Invalid regularization factor gamma.");
        }

        int n = x.length;
        int k = components.length;
        double[][] posteriori = new double[k][n];

        double L = 0.0;
        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; iter++) {
            expectationStep(x, components, gamma, posteriori);
            maximizationStep(x, components, posteriori);
            double loglikelihood = totalLogLikelihood(x, components);

            // There is no previous estimate on the first pass. Comparing against the 0.0
            // placeholder would stop after one pass whenever the log-likelihood is
            // negative, which is the common case.
            diff = (iter == 1) ? Double.MAX_VALUE : loglikelihood - L;
            L = loglikelihood;

            if (iter % 10 == 0) {
                logger.info(String.format("The log-likelihood after %d iterations: %.4f", iter, L));
            }
        }

        return new MultivariateExponentialFamilyMixture(L, n, components);
    }

    /**
     * E-step: fills {@code posteriori[i][j]} with the (optionally regularized) weight of
     * component {@code i} for point {@code x[j]}.
     */
    private static void expectationStep(double[][] x, MultivariateMixture.Component[] components, double gamma, double[][] posteriori) {
        int n = x.length;
        int k = components.length;

        for (int i = 0; i < k; i++) {
            MultivariateMixture.Component c = components[i];
            for (int j = 0; j < n; j++) {
                posteriori[i][j] = c.priori * c.distribution.p(x[j]);
            }
        }

        for (int j = 0; j < n; j++) {
            double p = 0.0;
            for (int i = 0; i < k; i++) {
                p += posteriori[i][j];
            }

            if (!(p > 0.0)) {
                // No component gives x[j] positive weight. Normalizing would compute 0/0 = NaN,
                // which would propagate into the M-step when gamma == 0. Give the point zero
                // weight instead; the gamma > 0 path below already resets such NaNs to 0.
                for (int i = 0; i < k; i++) {
                    posteriori[i][j] = 0.0;
                }
                continue;
            }

            for (int i = 0; i < k; i++) {
                posteriori[i][j] /= p;
            }

            if (gamma > 0.0) {
                for (int i = 0; i < k; i++) {
                    // log2(w) <= 0 for w in (0, 1], so this factor shrinks the weight.
                    // The weight can go negative, and w == 0 gives 0 * -Infinity = NaN;
                    // both are clamped to 0.
                    posteriori[i][j] *= 1.0 + gamma * MathEx.log2(posteriori[i][j]);
                    if (Double.isNaN(posteriori[i][j]) || posteriori[i][j] < 0.0) {
                        posteriori[i][j] = 0.0;
                    }
                }
            }
        }
    }

    /**
     * M-step: re-estimates each component in place from its posterior weights, then
     * rescales the priors so they sum to 1.
     */
    private static void maximizationStep(double[][] x, MultivariateMixture.Component[] components, double[][] posteriori) {
        double Z = 0.0;
        for (int i = 0; i < components.length; i++) {
            components[i] = ((MultivariateExponentialFamily) components[i].distribution).M(x, posteriori[i]);
            Z += components[i].priori;
        }

        for (int i = 0; i < components.length; i++) {
            components[i] = new MultivariateMixture.Component(components[i].priori / Z, components[i].distribution);
        }
    }

    /**
     * Returns the sum of log mixture densities over {@code x}.
     *
     * <p>Points with a non-positive (or NaN) mixture density are skipped.
     */
    private static double totalLogLikelihood(double[][] x, MultivariateMixture.Component[] components) {
        double loglikelihood = 0.0;
        for (double[] xi : x) {
            double p = 0.0;
            for (MultivariateMixture.Component c : components) {
                p += c.priori * c.distribution.p(xi);
            }
            if (p > 0.0) {
                loglikelihood += Math.log(p);
            }
        }
        return loglikelihood;
    }
}

