/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.math.matrix.DenseMatrix;
import smile.math.special.Beta;
import smile.regression.DataFrameRegression;
import smile.regression.OnlineRegression;

/**
 * Linear regression model of the form {@code y = b + w . x}.
 *
 * <p>Besides the intercept and the coefficient vector, the model carries the
 * goodness-of-fit statistics filled in by
 * {@link #fitness(double[], double[], double)} and, when the matrix
 * {@code V} is set, supports incremental updates through
 * {@link #update(double[], double, double)}.
 *
 * <p>All state is held in package-private fields; this class itself only
 * assigns the fit statistics (in {@code fitness}) and {@code w}, {@code b},
 * {@code V} (in {@code update}). The remaining fields are expected to be
 * populated by code outside this class.
 */
public class LinearModel
implements OnlineRegression<double[]>,
DataFrameRegression {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(LinearModel.class);

    /** Model formula; used to turn tuples and data frames into predictor values. */
    Formula formula;
    /** Schema whose field names label the coefficients in {@link #toString()}. */
    StructType schema;
    /** Number of predictors, i.e. the required length of an input vector. */
    int p;
    /** Intercept. */
    double b;
    /** Coefficients, one per predictor. */
    double[] w;
    /**
     * Optional coefficient table. As rendered by {@link #toString()}, each row
     * holds {estimate, std. error, t value, {@code Pr(>|t|)}}; row {@code p},
     * when present, describes the intercept. May be {@code null}.
     */
    double[][] ttest;
    /** Fitted values passed to {@code fitness}. */
    double[] fittedValues;
    /** Residuals {@code y[i] - fittedValues[i]}. */
    double[] residuals;
    /** Residual sum of squares. */
    double RSS;
    /** Residual standard error, {@code sqrt(RSS / (n - p - 1))}. */
    double error;
    /** Residual degrees of freedom, {@code n - p - 1}. */
    int df;
    /** Coefficient of determination, {@code 1 - RSS / TSS}. */
    double RSquared;
    /** R-squared adjusted for the number of predictors. */
    double adjustedRSquared;
    /** F-statistic of the fit. */
    double F;
    /** p-value associated with the F-statistic, or NaN when {@code df <= 0}. */
    double pvalue;
    /**
     * (p+1) x (p+1) matrix driving the online update; the last row/column
     * corresponds to the intercept. {@code null} disables online learning.
     * NOTE(review): presumably the inverse of the augmented Gram matrix
     * supplied by the trainer - confirm against the code that sets it.
     */
    DenseMatrix V;

    /** Creates an empty model; fields are filled in by code in this package. */
    LinearModel() {
    }

    /** Returns the model formula. */
    @Override
    public Formula formula() {
        return this.formula;
    }

    /** Returns the schema associated with the predictors. */
    @Override
    public StructType schema() {
        return this.schema;
    }

    /**
     * Returns the coefficient table (see the {@code ttest} field), or
     * {@code null} if none was set. The internal array is returned, not a copy.
     */
    public double[][] ttest() {
        return this.ttest;
    }

    /** Returns the coefficient vector (the internal array, not a copy). */
    public double[] coefficients() {
        return this.w;
    }

    /** Returns the intercept. */
    public double intercept() {
        return this.b;
    }

    /** Returns the residuals (the internal array, not a copy). */
    public double[] residuals() {
        return this.residuals;
    }

    /** Returns the fitted values (the internal array, not a copy). */
    public double[] fittedValues() {
        return this.fittedValues;
    }

    /** Returns the residual sum of squares. */
    public double RSS() {
        return this.RSS;
    }

    /** Returns the residual standard error. */
    public double error() {
        return this.error;
    }

    /** Returns the residual degrees of freedom. */
    public int df() {
        return this.df;
    }

    /** Returns R-squared. */
    public double RSquared() {
        return this.RSquared;
    }

    /** Returns the adjusted R-squared. */
    public double adjustedRSquared() {
        return this.adjustedRSquared;
    }

    /** Returns the F-statistic. */
    public double ftest() {
        return this.F;
    }

    /** Returns the p-value of the F-statistic (NaN if it was not computed). */
    public double pvalue() {
        return this.pvalue;
    }

    /**
     * Computes and stores the goodness-of-fit statistics.
     *
     * <p>Stores {@code fittedValues} by reference and fills in the residuals,
     * RSS, residual standard error, degrees of freedom, R-squared, adjusted
     * R-squared, F-statistic and its p-value. When {@code n - p - 1 <= 0} the
     * p-value is set to NaN and a warning is logged; the other statistics are
     * still computed with that non-positive divisor and may be NaN or infinite.
     *
     * @param fittedValues the model's predictions for the training samples
     * @param y the observed responses, at least as long as {@code fittedValues}
     * @param ym the value subtracted from each response to form the total sum
     *           of squares (presumably the mean of {@code y} - confirm with callers)
     */
    void fitness(double[] fittedValues, double[] y, double ym) {
        int n = fittedValues.length;
        this.fittedValues = fittedValues;
        this.residuals = new double[n];
        this.RSS = 0.0;
        double TSS = 0.0;
        for (int i = 0; i < n; ++i) {
            double r = y[i] - fittedValues[i];
            this.residuals[i] = r;
            this.RSS += MathEx.sqr(r);
            TSS += MathEx.sqr(y[i] - ym);
        }

        // Residual degrees of freedom, shared by every statistic below.
        this.df = n - this.p - 1;
        this.error = Math.sqrt(this.RSS / this.df);
        this.RSquared = 1.0 - this.RSS / TSS;
        this.adjustedRSquared = 1.0 - (1.0 - this.RSquared) * (n - 1) / this.df;
        this.F = (TSS - this.RSS) * this.df / (this.RSS * this.p);

        if (this.df > 0) {
            this.pvalue = Beta.regularizedIncompleteBetaFunction(0.5 * this.df, 0.5 * this.p, this.df / (this.df + this.p * this.F));
        } else {
            logger.warn("Skip calculating p-value: the linear system is under-determined.");
            this.pvalue = Double.NaN;
        }
    }

    /**
     * Predicts the response for one input vector as {@code b + w . x}.
     *
     * @param x the predictor values; must have length {@code p}
     * @return the predicted response
     * @throws IllegalArgumentException if {@code x.length != p}
     */
    @Override
    public double predict(double[] x) {
        checkDimension(x);
        return this.b + MathEx.dot(x, this.w);
    }

    /**
     * Predicts the response for a tuple by converting it to a predictor
     * array with {@code formula.xarray}.
     *
     * @param x the input tuple
     * @return the predicted response
     */
    @Override
    public double predict(Tuple x) {
        return this.predict(this.formula.xarray(x));
    }

    /**
     * Predicts the responses for every row of a data frame.
     *
     * @param df the input data
     * @return one prediction per row of the matrix built by the formula
     */
    @Override
    public double[] predict(DataFrame df) {
        // NOTE(review): the 'false' flag presumably omits a bias column - confirm.
        DenseMatrix X = this.formula.matrix(df, false);
        double[] y = new double[X.nrows()];
        // Start from the intercept; axpy presumably accumulates X * w into y.
        Arrays.fill(y, this.b);
        X.axpy(this.w, y);
        return y;
    }

    /**
     * Updates the model with one sample taken from a tuple, using a
     * forgetting factor of 1.
     *
     * @param data the tuple holding both predictors and response
     */
    public void update(Tuple data) {
        this.update(this.formula.xarray(data), this.formula.y(data));
    }

    /**
     * Updates the model with every row of a data frame, in row order.
     *
     * @param data the new samples
     */
    public void update(DataFrame data) {
        int n = data.size();
        for (int i = 0; i < n; ++i) {
            this.update((Tuple)data.get(i));
        }
    }

    /**
     * Updates the model with one sample, using a forgetting factor of 1.
     *
     * @param x the predictor values; must have length {@code p}
     * @param y the observed response
     */
    @Override
    public void update(double[] x, double y) {
        this.update(x, y, 1.0);
    }

    /**
     * Updates the model with one sample.
     *
     * <p>With {@code x1 = [x, 1]} and {@code v = 1 + x1' V x1}, the matrix is
     * replaced by {@code (V - (V x1)(V x1)' / v) / lambda} (the
     * Sherman-Morrison form, assuming V is symmetric), after which the
     * coefficients and intercept are moved along {@code V x1} scaled by the
     * prediction error of the current model.
     *
     * @param x the predictor values; must have length {@code p}
     * @param y the observed response
     * @param lambda the forgetting factor in (0, 1]; all arguments are
     *               validated before any state is modified
     * @throws UnsupportedOperationException if the model has no {@code V} matrix
     * @throws IllegalArgumentException if {@code lambda} is not in (0, 1]
     *         (including NaN) or {@code x.length != p}
     * @throws IllegalStateException if the divisor {@code v} is zero or NaN,
     *         in which case {@code V} is left untouched
     */
    public void update(double[] x, double y, double lambda) {
        if (this.V == null) {
            throw new UnsupportedOperationException("The model doesn't support online learning");
        }
        // Written as a negated range test so that NaN, for which every
        // comparison is false, is rejected instead of slipping through.
        if (!(lambda > 0.0 && lambda <= 1.0)) {
            throw new IllegalArgumentException("The forgetting factor must be in (0, 1]");
        }
        checkDimension(x);

        // Augment the input with a constant 1 for the intercept term.
        double[] x1 = new double[this.p + 1];
        System.arraycopy(x, 0, x1, 0, this.p);
        x1[this.p] = 1.0;

        double v = 1.0 + this.V.xax(x1);
        // v divides the rank-one correction below. A zero divisor would fill
        // V with Inf/NaN (1/0 is Infinity, not NaN), so reject it together
        // with NaN before V is modified.
        if (Double.isNaN(v) || v == 0.0) {
            throw new IllegalStateException("The updated V matrix is no longer invertible.");
        }

        double[] Vx = new double[this.p + 1];
        this.V.ax(x1, Vx);
        for (int j = 0; j <= this.p; ++j) {
            for (int i = 0; i <= this.p; ++i) {
                double tmp = this.V.get(i, j) - Vx[i] * Vx[j] / v;
                this.V.set(i, j, tmp / lambda);
            }
        }

        // V has changed, so V * x1 must be recomputed before it is used as the gain.
        this.V.ax(x1, Vx);
        double err = y - this.predict(x);
        for (int i = 0; i < this.p; ++i) {
            this.w[i] += Vx[i] * err;
        }
        this.b += Vx[this.p] * err;
    }

    /**
     * Verifies that an input vector has exactly {@code p} entries.
     *
     * @throws IllegalArgumentException if {@code x.length != p}
     */
    private void checkDimension(double[] x) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
    }

    /** Maps a p-value to the significance marker used in the coefficient table. */
    private static String significance(double pvalue) {
        if (pvalue < 0.001) {
            return "***";
        }
        if (pvalue < 0.01) {
            return "**";
        }
        if (pvalue < 0.05) {
            return "*";
        }
        if (pvalue < 0.1) {
            return ".";
        }
        return "";
    }

    /**
     * Renders a textual summary: residual quantiles, the coefficient table
     * (with standard errors, t values and p-values when {@code ttest} is set)
     * and the overall fit statistics.
     *
     * <p>Requires {@code residuals} to have been set, i.e.
     * {@code fitness} must have been called.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("Linear Model:\n");

        // Work on a copy: the quantile helpers may reorder their input
        // (presumably selection-based - confirm), and the model's residuals
        // must keep their sample order.
        double[] r = this.residuals.clone();
        builder.append("\nResiduals:\n");
        builder.append("\t       Min\t        1Q\t    Median\t        3Q\t       Max\n");
        builder.append(String.format("\t%10.4f\t%10.4f\t%10.4f\t%10.4f\t%10.4f%n", MathEx.min(r), MathEx.q1(r), MathEx.median(r), MathEx.q3(r), MathEx.max(r)));

        builder.append("\nCoefficients:\n");
        if (this.ttest != null) {
            builder.append("                  Estimate Std. Error    t value   Pr(>|t|)\n");
            if (this.ttest.length > this.p) {
                // Row p of the table, when present, describes the intercept.
                double[] t = this.ttest[this.p];
                builder.append(String.format("Intercept       %10.4f %10.4f %10.4f %10.4f %s%n", t[0], t[1], t[2], t[3], significance(t[3])));
            } else {
                builder.append(String.format("Intercept       %10.4f%n", this.b));
            }
            for (int i = 0; i < this.p; ++i) {
                double[] t = this.ttest[i];
                builder.append(String.format("%-15s %10.4f %10.4f %10.4f %10.4f %s%n", this.schema.fieldName(i), t[0], t[1], t[2], t[3], significance(t[3])));
            }
            builder.append("---------------------------------------------------------------------\n");
            builder.append("Significance codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n");
        } else {
            builder.append(String.format("Intercept       %10.4f%n", this.b));
            for (int i = 0; i < this.p; ++i) {
                builder.append(String.format("%-15s %10.4f%n", this.schema.fieldName(i), this.w[i]));
            }
        }

        builder.append(String.format("\nResidual standard error: %.4f on %d degrees of freedom%n", this.error, this.df));
        builder.append(String.format("Multiple R-squared: %.4f,    Adjusted R-squared: %.4f%n", this.RSquared, this.adjustedRSquared));
        builder.append(String.format("F-statistic: %.4f on %d and %d DF,  p-value: %.4g%n", this.F, this.p, this.df, this.pvalue));
        return builder.toString();
    }
}

