/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.io.Serializable;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;
import smile.math.matrix.CholeskyDecomposition;
import smile.math.matrix.ColumnMajorMatrix;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.QRDecomposition;
import smile.math.matrix.SingularValueDecomposition;
import smile.math.special.Beta;
import smile.regression.Regression;
import smile.regression.RegressionTrainer;

/**
 * Ordinary least squares (OLS) linear regression with an intercept.
 * <p>
 * Fits {@code y ~ w'x + b} by least squares on the design matrix
 * {@code [x | 1]}. By default the system is solved with a QR decomposition;
 * if that throws, or if the caller asks for it, a singular value
 * decomposition (SVD) is used instead. Besides the fitted weights, the model
 * keeps the training residuals and the usual summary statistics (RSS,
 * residual standard error, R<sup>2</sup>, adjusted R<sup>2</sup>, the overall
 * F-test and a per-coefficient t-test table), which {@link #toString()}
 * prints as a summary report.
 */
public class OLS
implements Regression<double[]>,
Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(OLS.class);
    /** Absolute threshold below which a singular value (or aliasing weight) is treated as zero. */
    private static final double SINGULAR_VALUE_TOLERANCE = 1.0E-10;
    /** Number of explanatory variables (columns of x, intercept excluded). */
    private int p;
    /** Fitted intercept. */
    private double b;
    /** Fitted linear weights, one per explanatory variable. */
    private double[] w;
    /**
     * t-test table with p + 1 rows (variables in column order, intercept in
     * the last row) and columns: estimate, standard error, t value, p-value.
     */
    private double[][] coefficients;
    /** Training residuals y - (w'x + b). */
    private double[] residuals;
    /** Residual sum of squares. */
    private double RSS;
    /** Residual standard error, sqrt(RSS / df). */
    private double error;
    /** Residual degrees of freedom, n - p - 1. */
    private int df;
    /** Coefficient of determination, 1 - RSS / TSS. */
    private double RSquared;
    /** R-squared adjusted for the number of explanatory variables. */
    private double adjustedRSquared;
    /** Overall F-statistic. */
    private double F;
    /** p-value of the overall F-test. */
    private double pvalue;

    /**
     * Fits an OLS model using QR decomposition, falling back to SVD if the
     * QR solve throws.
     *
     * @param x the explanatory variables, one row per sample.
     * @param y the response values, one per row of {@code x}.
     * @throws IllegalArgumentException if the inputs are empty, their sizes
     *         are inconsistent, or the system is not over-determined.
     */
    public OLS(double[][] x, double[] y) {
        this(x, y, false);
    }

    /**
     * Fits an OLS model.
     *
     * @param x the explanatory variables, one row per sample.
     * @param y the response values, one per row of {@code x}.
     * @param SVD if true, solve with singular value decomposition directly;
     *        otherwise try QR decomposition first and fall back to SVD if it
     *        throws.
     * @throws IllegalArgumentException if the inputs are empty, their sizes
     *         are inconsistent, or the system is not over-determined
     *         ({@code n <= p}).
     */
    public OLS(double[][] x, double[] y, boolean SVD) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }
        int n = x.length;
        this.p = x[0].length;
        // Reject ragged input up front instead of failing with an index error later.
        for (int i = 1; i < n; ++i) {
            if (x[i].length != this.p) {
                throw new IllegalArgumentException(String.format("Row %d has %d columns, expected: %d", i, x[i].length, this.p));
            }
        }
        if (n <= this.p) {
            throw new IllegalArgumentException(String.format("The input matrix is not over determined: %d rows, %d columns", n, this.p));
        }

        // w1 holds the p weights followed by the intercept in the last slot.
        double[] w1 = new double[this.p + 1];
        QRDecomposition qr = null;
        SingularValueDecomposition svd = null;
        if (SVD) {
            svd = new SingularValueDecomposition(designMatrix(x, this.p));
            svd.solve(y, w1);
        } else {
            try {
                qr = new QRDecomposition(designMatrix(x, this.p));
                qr.solve(y, w1);
            }
            catch (RuntimeException e) {
                logger.warn("Matrix is not of full rank, try SVD instead");
                SVD = true;
                // Build a fresh design matrix rather than reusing the one given
                // to QR, in case the failed decomposition modified it in place.
                svd = new SingularValueDecomposition(designMatrix(x, this.p));
                Arrays.fill(w1, 0.0);
                svd.solve(y, w1);
            }
        }

        this.b = w1[this.p];
        this.w = Arrays.copyOf(w1, this.p);

        double[] yhat = new double[n];
        Math.ax(x, this.w, yhat);
        double ybar = Math.mean(y);
        double TSS = 0.0;
        this.RSS = 0.0;
        this.residuals = new double[n];
        for (int i = 0; i < n; ++i) {
            double r = y[i] - yhat[i] - this.b;
            this.residuals[i] = r;
            this.RSS += Math.sqr(r);
            TSS += Math.sqr(y[i] - ybar);
        }

        this.df = n - this.p - 1;
        this.error = Math.sqrt(this.RSS / this.df);
        this.RSquared = 1.0 - this.RSS / TSS;
        this.adjustedRSquared = 1.0 - (1.0 - this.RSquared) * (n - 1) / this.df;
        this.F = (TSS - this.RSS) * this.df / (this.RSS * this.p);
        // Upper tail of F(p, df): P(F > f) = I_{df / (df + p f)}(df / 2, p / 2).
        this.pvalue = Beta.regularizedIncompleteBetaFunction(0.5 * this.df, 0.5 * this.p, this.df / (this.df + this.p * this.F));

        this.coefficients = new double[this.p + 1][];
        if (SVD) {
            double[] s = svd.getSingularValues();
            DenseMatrix V = svd.getV();
            for (int i = 0; i <= this.p; ++i) {
                // With X = U S V', Cov(w) = sigma^2 * V S^-2 V' (pseudo-inverse of
                // X'X), so Var(w_i) = sigma^2 * sum_j V[i][j]^2 / s_j^2 over the
                // nonzero s_j. Weight of coefficient i on the null-space directions
                // (s_j ~ 0) is accumulated separately: if it is nonzero the
                // coefficient is not estimable from the data.
                double variance = 0.0;
                double aliased = 0.0;
                for (int j = 0; j < s.length; ++j) {
                    double v = V.get(i, j);
                    if (Math.isZero(s[j], SINGULAR_VALUE_TOLERANCE)) {
                        aliased += v * v;
                    } else {
                        variance += v * v / (s[j] * s[j]);
                    }
                }
                if (Math.isZero(aliased, SINGULAR_VALUE_TOLERANCE)) {
                    this.coefficients[i] = this.tTestRow(w1[i], this.error * Math.sqrt(variance));
                } else {
                    // Not estimable: report an undefined standard error and no significance.
                    this.coefficients[i] = new double[]{w1[i], Double.NaN, 0.0, 1.0};
                }
            }
        } else {
            // toCholesky() is expected to yield the Cholesky decomposition of X'X,
            // whose inverse diagonal gives the coefficient variances / sigma^2.
            DenseMatrix inv = qr.toCholesky().inverse();
            for (int i = 0; i <= this.p; ++i) {
                this.coefficients[i] = this.tTestRow(w1[i], this.error * Math.sqrt(inv.get(i, i)));
            }
        }
    }

    /**
     * Builds the n x (p + 1) design matrix: the columns of {@code x} followed
     * by a column of ones for the intercept.
     */
    private static DenseMatrix designMatrix(double[][] x, int p) {
        int n = x.length;
        ColumnMajorMatrix X = new ColumnMajorMatrix(n, p + 1);
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < p; ++j) {
                X.set(i, j, x[i][j]);
            }
            X.set(i, p, 1.0);
        }
        return X;
    }

    /**
     * Builds one row of the t-test table using {@code df} residual degrees of
     * freedom.
     *
     * @param estimate the coefficient estimate.
     * @param se its standard error.
     * @return {estimate, standard error, t value, two-sided p-value}.
     */
    private double[] tTestRow(double estimate, double se) {
        double t = estimate / se;
        // Two-sided Student-t tail: P(|T| > |t|) = I_{df / (df + t^2)}(df / 2, 1 / 2).
        double pval = Beta.regularizedIncompleteBetaFunction(0.5 * this.df, 0.5, this.df / (this.df + t * t));
        return new double[]{estimate, se, t, pval};
    }

    /**
     * Returns the t-test table: p + 1 rows (variables in column order, then the
     * intercept in the last row) with columns estimate, standard error,
     * t value and p-value. The internal array is returned, not a copy.
     */
    public double[][] ttest() {
        return this.coefficients;
    }

    /**
     * Returns the fitted linear weights, excluding the intercept. The internal
     * array is returned, not a copy.
     */
    public double[] coefficients() {
        return this.w;
    }

    /** Returns the fitted intercept. */
    public double intercept() {
        return this.b;
    }

    /**
     * Returns the training residuals y - (w'x + b). The internal array is
     * returned, not a copy.
     */
    public double[] residuals() {
        return this.residuals;
    }

    /** Returns the residual sum of squares. */
    public double RSS() {
        return this.RSS;
    }

    /** Returns the residual standard error, sqrt(RSS / df). */
    public double error() {
        return this.error;
    }

    /** Returns the residual degrees of freedom, n - p - 1. */
    public int df() {
        return this.df;
    }

    /** Returns the coefficient of determination, 1 - RSS / TSS. */
    public double RSquared() {
        return this.RSquared;
    }

    /** Returns R-squared adjusted for the number of explanatory variables. */
    public double adjustedRSquared() {
        return this.adjustedRSquared;
    }

    /** Returns the overall F-statistic with (p, df) degrees of freedom. */
    public double ftest() {
        return this.F;
    }

    /** Returns the p-value of the overall F-test. */
    public double pvalue() {
        return this.pvalue;
    }

    /**
     * Predicts the response for one sample as {@code b + w'x}.
     *
     * @throws IllegalArgumentException if {@code x} does not have p entries.
     */
    @Override
    public double predict(double[] x) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        return this.b + Math.dot(x, this.w);
    }

    /** Maps a p-value to the significance code used in the summary report. */
    private String significance(double pvalue) {
        if (pvalue < 0.001) {
            return "***";
        }
        if (pvalue < 0.01) {
            return "**";
        }
        if (pvalue < 0.05) {
            return "*";
        }
        if (pvalue < 0.1) {
            return ".";
        }
        return "";
    }

    /** Renders a summary report of the fit: residual quantiles, t-test table and overall statistics. */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("Linear Model:\n");
        // Work on a copy so the quantile helpers cannot disturb the stored residuals.
        double[] r = this.residuals.clone();
        builder.append("\nResiduals:\n");
        builder.append("\t       Min\t        1Q\t    Median\t        3Q\t       Max\n");
        builder.append(String.format("\t%10.4f\t%10.4f\t%10.4f\t%10.4f\t%10.4f%n", Math.min(r), Math.q1(r), Math.median(r), Math.q3(r), Math.max(r)));
        builder.append("\nCoefficients:\n");
        builder.append("            Estimate        Std. Error        t value        Pr(>|t|)\n");
        builder.append(String.format("Intercept%11.4f%18.4f%15.4f%16.4f %s%n", this.coefficients[this.p][0], this.coefficients[this.p][1], this.coefficients[this.p][2], this.coefficients[this.p][3], this.significance(this.coefficients[this.p][3])));
        for (int i = 0; i < this.p; ++i) {
            builder.append(String.format("Var %d\t %11.4f%18.4f%15.4f%16.4f %s%n", i + 1, this.coefficients[i][0], this.coefficients[i][1], this.coefficients[i][2], this.coefficients[i][3], this.significance(this.coefficients[i][3])));
        }
        builder.append("---------------------------------------------------------------------\n");
        builder.append("Significance codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n");
        builder.append(String.format("\nResidual standard error: %.4f on %d degrees of freedom%n", this.error, this.df));
        builder.append(String.format("Multiple R-squared: %.4f,    Adjusted R-squared: %.4f%n", this.RSquared, this.adjustedRSquared));
        builder.append(String.format("F-statistic: %.4f on %d and %d DF,  p-value: %.4g%n", this.F, this.p, this.df, this.pvalue));
        return builder.toString();
    }

    /** Trainer that fits {@link OLS} models with the default (QR, SVD fallback) solver. */
    public static class Trainer
    extends RegressionTrainer<double[]> {
        public OLS train(double[][] x, double[] y) {
            return new OLS(x, y);
        }
    }
}

