/*
 * Decompiled with CFR 0.152.
 */
package smile.glm;

import java.io.Serializable;
import java.util.Properties;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.CategoricalEncoder;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.glm.model.Model;
import smile.math.MathEx;
import smile.math.special.Erf;
import smile.stat.Hypothesis;
import smile.tensor.Cholesky;
import smile.tensor.DenseMatrix;
import smile.tensor.QR;
import smile.tensor.Vector;
import smile.validation.ModelSelection;

/**
 * A fitted generalized linear model.
 *
 * <p>An instance holds the coefficient vector together with the
 * goodness-of-fit statistics computed while fitting. Use
 * {@link #fit(Formula, DataFrame, Model, Options)} to estimate a model by
 * repeatedly solving a weighted least squares problem through a QR
 * decomposition.
 */
public class GLM
implements Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(GLM.class);
    /** The formula used to extract the design matrix and the response. */
    protected final Formula formula;
    /** The labels printed next to the coefficients by {@link #toString()}. */
    final String[] predictors;
    /** The model specification supplying link, variance and deviance functions. */
    protected final Model model;
    /** The estimated coefficients. */
    protected final double[] beta;
    /**
     * Per-coefficient test table: each row holds the estimate, its standard
     * error, the z value and the p-value, in that order. May be null.
     */
    protected final double[][] ztest;
    /** The fitted mean values. */
    protected final double[] mu;
    /** The deviance of the null model. */
    protected final double nullDeviance;
    /** The deviance of the fitted model. */
    protected final double deviance;
    /** The deviance residuals. */
    protected final double[] devianceResiduals;
    /** The residual degrees of freedom: {@code mu.length - beta.length}. */
    protected final int df;
    /** The log-likelihood of the fitted model. */
    protected final double logLikelihood;

    /**
     * Creates a model from already computed quantities. The arrays are stored
     * as given, without copying.
     *
     * @param formula the model formula.
     * @param predictors the labels of the coefficients.
     * @param model the model specification.
     * @param beta the coefficients.
     * @param logLikelihood the log-likelihood.
     * @param deviance the deviance of the fitted model.
     * @param nullDeviance the deviance of the null model.
     * @param mu the fitted mean values.
     * @param residuals the deviance residuals.
     * @param ztest the z-test table of the coefficients, or null.
     */
    public GLM(Formula formula, String[] predictors, Model model, double[] beta, double logLikelihood, double deviance, double nullDeviance, double[] mu, double[] residuals, double[][] ztest) {
        this.formula = formula;
        this.model = model;
        this.predictors = predictors;
        this.beta = beta;
        this.logLikelihood = logLikelihood;
        this.deviance = deviance;
        this.nullDeviance = nullDeviance;
        this.mu = mu;
        this.devianceResiduals = residuals;
        this.ztest = ztest;
        this.df = mu.length - beta.length;
    }

    /**
     * Returns the coefficients.
     *
     * @return the internal coefficient array (not a copy).
     */
    public double[] coefficients() {
        return this.beta;
    }

    /**
     * Returns the z-test table of the coefficients.
     *
     * @return the internal table (not a copy); each row is
     *         {estimate, standard error, z value, p-value}. May be null.
     */
    public double[][] ztest() {
        return this.ztest;
    }

    /**
     * Returns the deviance residuals.
     *
     * @return the internal residual array (not a copy).
     */
    public double[] devianceResiduals() {
        return this.devianceResiduals;
    }

    /**
     * Returns the fitted mean values.
     *
     * @return the internal array of fitted values (not a copy).
     */
    public double[] fittedValues() {
        return this.mu;
    }

    /**
     * Returns the deviance of the fitted model.
     *
     * @return the deviance.
     */
    public double deviance() {
        return this.deviance;
    }

    /**
     * Returns the log-likelihood of the fitted model.
     *
     * @return the log-likelihood.
     */
    public double logLikelihood() {
        return this.logLikelihood;
    }

    /**
     * Returns the AIC score, computed by {@code ModelSelection.AIC} from the
     * log-likelihood and the number of coefficients.
     *
     * @return the AIC score.
     */
    public double AIC() {
        return ModelSelection.AIC(this.logLikelihood, this.beta.length);
    }

    /**
     * Returns the BIC score, computed by {@code ModelSelection.BIC} from the
     * log-likelihood, the number of coefficients and the number of fitted
     * values.
     *
     * @return the BIC score.
     */
    public double BIC() {
        return ModelSelection.BIC(this.logLikelihood, this.beta.length, this.mu.length);
    }

    /**
     * Predicts the mean response of a single instance.
     *
     * @param x the instance.
     * @return the inverse link applied to the dot product of the encoded
     *         instance and the coefficients.
     */
    public double predict(Tuple x) {
        // NOTE(review): the leading `true` presumably asks for an intercept term so
        // that the encoded array lines up with beta - confirm against Tuple.toArray.
        double[] a = this.formula.x(x).toArray(true, CategoricalEncoder.DUMMY, new String[0]);
        int p = this.beta.length;
        double dot = 0.0;
        for (int i = 0; i < p; ++i) {
            dot += a[i] * this.beta[i];
        }
        return this.model.invlink(dot);
    }

    /**
     * Predicts the mean response of every row of a data frame.
     *
     * @param data the data frame.
     * @return the inverse link applied to each element of the product of the
     *         design matrix and the coefficients.
     */
    public double[] predict(DataFrame data) {
        DenseMatrix X = this.formula.matrix(data, true);
        double[] y = X.mv(this.beta).toArray(new double[0]);
        int n = y.length;
        for (int i = 0; i < n; ++i) {
            y[i] = this.model.invlink(y[i]);
        }
        return y;
    }

    /**
     * Returns a textual summary: the five-number summary of the deviance
     * residuals, the coefficient table, the null and residual deviance and
     * the AIC/BIC scores.
     *
     * @return the summary text.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append(String.format("Generalized Linear Model - %s:\n", this.model));
        // Work on a copy so that the stored residuals are never touched by the
        // order-statistics helpers.
        double[] r = this.devianceResiduals.clone();
        builder.append("\nDeviance Residuals:\n");
        builder.append("       Min          1Q      Median          3Q         Max\n");
        builder.append(String.format("%10.4f  %10.4f  %10.4f  %10.4f  %10.4f%n", MathEx.min(r), MathEx.q1(r), MathEx.median(r), MathEx.q3(r), MathEx.max(r)));
        int p = this.beta.length - 1;
        builder.append("\nCoefficients:\n");
        if (this.ztest != null) {
            builder.append("                  Estimate Std. Error    z value   Pr(>|z|)\n");
            // NOTE(review): only the first beta.length - 1 rows are printed here,
            // whereas the branch below prints all beta.length coefficients. Which
            // label belongs to which row depends on how `predictors` lines up with
            // `beta`; confirm before changing the row count.
            for (int i = 0; i < p; ++i) {
                builder.append(String.format("%-15s %10.3e %10.3e %10.4f %10.5f %s%n", this.predictors[i], this.ztest[i][0], this.ztest[i][1], this.ztest[i][2], this.ztest[i][3], Hypothesis.significance(this.ztest[i][3])));
            }
            builder.append("---------------------------------------------------------------------\n");
            builder.append("Significance codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n");
        } else {
            builder.append(String.format("Intercept       %10.4f%n", this.beta[p]));
            for (int i = 0; i < p; ++i) {
                builder.append(String.format("%-15s %10.4f%n", this.predictors[i], this.beta[i]));
            }
        }
        builder.append(String.format("%n    Null deviance: %.1f on %d degrees of freedom", this.nullDeviance, this.df + p));
        builder.append(String.format("%nResidual deviance: %.1f on %d degrees of freedom", this.deviance, this.df));
        builder.append(String.format("%nAIC: %.4f     BIC: %.4f%n", this.AIC(), this.BIC()));
        return builder.toString();
    }

    /**
     * Fits a generalized linear model with the default {@link Options}.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param model the model specification.
     * @return the fitted model.
     */
    public static GLM fit(Formula formula, DataFrame data, Model model) {
        return GLM.fit(formula, data, model, new Options());
    }

    /**
     * Fits a generalized linear model.
     *
     * <p>Starting from {@code model.mustart}, each pass recomputes the working
     * response and weights from the current fitted means, then solves the
     * weighted least squares problem with a QR decomposition. Iteration stops
     * once the deviance decreases by less than {@code options.tol()} or after
     * {@code options.maxIter()} passes.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param model the model specification.
     * @param options the convergence tolerance and iteration limit.
     * @return the fitted model.
     * @throws IllegalArgumentException if the design matrix does not have
     *         more rows than columns.
     */
    public static GLM fit(Formula formula, DataFrame data, Model model, Options options) {
        StructType schema = formula.bind(data.schema());
        DenseMatrix X = formula.matrix(data, true);
        DenseMatrix XW = X.zeros(X.nrow(), X.ncol());
        double[] y = formula.y(data).toDoubleArray();
        int n = X.nrow();
        int p = X.ncol();
        if (n <= p) {
            throw new IllegalArgumentException(String.format("The input matrix is not over determined: %d rows, %d columns", n, p));
        }

        double[] eta = new double[n];
        double[] mu = new double[n];
        double[] w = new double[n];
        double[] z = new double[n];
        double[] residuals = new double[n];

        // Initial pass: start from the model's suggested means.
        IntStream.range(0, n).parallel().forEach(i -> {
            mu[i] = model.mustart(y[i]);
            eta[i] = model.link(mu[i]);
            updateWorkingResponse(model, y, mu, eta, z, w, i);
        });
        scaleRows(X, w, XW);
        QR qr = XW.qr();
        Vector beta = qr.solve(z);

        // NOTE(review): the loop reads the linear predictor through `eta` after
        // writing it through `eta_`, so Vector.column presumably wraps the array
        // without copying - confirm.
        Vector eta_ = Vector.column(eta);
        double dev = Double.POSITIVE_INFINITY;
        for (int iter = 0; iter < options.maxIter; ++iter) {
            X.mv(beta, eta_);
            IntStream.range(0, n).parallel().forEach(i -> {
                mu[i] = model.invlink(eta[i]);
                updateWorkingResponse(model, y, mu, eta, z, w, i);
            });
            double newDev = model.deviance(y, mu, residuals);
            if (iter > 0) {
                logger.info("Deviance after {} iterations: {}", iter, dev);
            }
            boolean converged = dev - newDev < options.tol;
            // Record the deviance of the current iterate before leaving the loop so
            // that the reported value matches the returned beta, mu and residuals.
            dev = newDev;
            if (converged) {
                break;
            }
            scaleRows(X, w, XW);
            qr = XW.qr();
            beta = qr.solve(z);
        }

        // Standard errors are the square roots of the diagonal of the inverse
        // obtained from the Cholesky factor of the last QR decomposition.
        Cholesky cholesky = qr.toCholesky();
        DenseMatrix inv = cholesky.inverse();
        double[][] ztest = new double[p][4];
        for (int i = 0; i < p; ++i) {
            ztest[i][0] = beta.get(i);
            ztest[i][1] = Math.sqrt(inv.get(i, i));
            ztest[i][2] = ztest[i][0] / ztest[i][1];
            // 0.7071... is 1/sqrt(2). Since erfc(-x) = 2 - erfc(x), this equals
            // erfc(|z| / sqrt(2)), the two-sided p-value under a standard normal.
            ztest[i][3] = 2.0 - Erf.erfc(-0.7071067811865476 * Math.abs(ztest[i][2]));
        }
        return new GLM(formula, schema.names(), model, beta.toArray(new double[0]), model.logLikelihood(y, mu), dev, model.nullDeviance(y, MathEx.mean(y)), mu, residuals, ztest);
    }

    /**
     * Computes the weight and the weighted working response of observation
     * {@code i} from its current mean {@code mu[i]} and linear predictor
     * {@code eta[i]}.
     */
    private static void updateWorkingResponse(Model model, double[] y, double[] mu, double[] eta, double[] z, double[] w, int i) {
        double g = model.dlink(mu[i]);
        double v = model.variance(mu[i]);
        w[i] = 1.0 / (g * Math.sqrt(v));
        z[i] = (eta[i] + (y[i] - mu[i]) * g) * w[i];
    }

    /** Stores into {@code XW} every row of {@code X} multiplied by the weight of that row. */
    private static void scaleRows(DenseMatrix X, double[] w, DenseMatrix XW) {
        int n = X.nrow();
        int p = X.ncol();
        for (int j = 0; j < p; ++j) {
            for (int i = 0; i < n; ++i) {
                XW.set(i, j, X.get(i, j) * w[i]);
            }
        }
    }

    /**
     * The hyperparameters of the fitting procedure.
     *
     * @param tol the minimum decrease of the deviance required to keep
     *            iterating; must be positive.
     * @param maxIter the maximum number of iterations; must be positive.
     */
    public record Options(double tol, int maxIter) {
        /**
         * Validates the hyperparameters.
         *
         * @throws IllegalArgumentException if {@code tol} or {@code maxIter}
         *         is not positive.
         */
        public Options {
            if (tol <= 0.0) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
            if (maxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }
        }

        /** Creates the default options: tolerance 1E-5 and at most 50 iterations. */
        public Options() {
            this(1.0E-5, 50);
        }

        /**
         * Writes the options to a property set under the keys
         * {@code smile.glm.tolerance} and {@code smile.glm.iterations}.
         *
         * @return the properties.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.glm.tolerance", Double.toString(this.tol));
            props.setProperty("smile.glm.iterations", Integer.toString(this.maxIter));
            return props;
        }

        /**
         * Reads the options from a property set, falling back to the defaults
         * for missing keys.
         *
         * @param props the properties.
         * @return the options.
         * @throws NumberFormatException if a property value cannot be parsed.
         * @throws IllegalArgumentException if a parsed value is not positive.
         */
        public static Options of(Properties props) {
            double tol = Double.parseDouble(props.getProperty("smile.glm.tolerance", "1E-5"));
            int maxIter = Integer.parseInt(props.getProperty("smile.glm.iterations", "50"));
            return new Options(tol, maxIter);
        }
    }
}

