package smile.regression;

import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.CategoricalEncoder;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.math.special.Beta;
import smile.stat.Hypothesis;

/* loaded from: input_file:smile/regression/LinearModel.class */
public class LinearModel implements DataFrameRegression {
    private static final long serialVersionUID = 2;
    private static final Logger logger = LoggerFactory.getLogger(LinearModel.class);
    Formula formula;
    StructType schema;
    String[] predictors;
    int p;
    double b;
    double[] w;
    boolean bias;
    double[][] ttest;
    double[] fittedValues;
    double[] residuals;
    double RSS;
    double error;
    int df;
    double RSquared;
    double adjustedRSquared;
    double F;
    double pvalue;
    Matrix V;

    public LinearModel(Formula formula, StructType structType, Matrix matrix, double[] dArr, double[] dArr2, double d) {
        this.formula = formula;
        this.schema = structType;
        this.predictors = matrix.colNames();
        this.p = matrix.ncol();
        this.w = dArr2;
        this.b = d;
        this.bias = this.predictors[0].equals("Intercept");
        int nrow = matrix.nrow();
        this.fittedValues = new double[nrow];
        Arrays.fill(this.fittedValues, d);
        matrix.mv(1.0d, dArr2, 1.0d, this.fittedValues);
        this.residuals = new double[nrow];
        this.RSS = 0.0d;
        double d2 = 0.0d;
        double mean = MathEx.mean(dArr);
        for (int i = 0; i < nrow; i++) {
            this.residuals[i] = dArr[i] - this.fittedValues[i];
            this.RSS += MathEx.pow2(this.residuals[i]);
            d2 += MathEx.pow2(dArr[i] - mean);
        }
        this.error = Math.sqrt(this.RSS / (nrow - this.p));
        this.df = nrow - this.p;
        this.RSquared = 1.0d - (this.RSS / d2);
        this.adjustedRSquared = 1.0d - (((1.0d - this.RSquared) * (nrow - 1)) / (nrow - this.p));
        this.F = ((d2 - this.RSS) * (nrow - this.p)) / (this.RSS * (this.p - 1));
        int i2 = this.p - 1;
        int i3 = nrow - this.p;
        if (i3 > 0 && this.F > 0.0d) {
            this.pvalue = Beta.regularizedIncompleteBetaFunction(0.5d * i3, 0.5d * i2, i3 / (i3 + (i2 * this.F)));
        } else {
            logger.warn("Skip calculating p-value: {}.", this.F <= 0.0d ? "R2 is not positive" : "the linear system is under-determined");
            this.pvalue = Double.NaN;
        }
    }

    @Override // smile.regression.DataFrameRegression
    public Formula formula() {
        return this.formula;
    }

    @Override // smile.regression.DataFrameRegression
    public StructType schema() {
        return this.schema;
    }

    public double[][] ttest() {
        return this.ttest;
    }

    public double[] coefficients() {
        return this.bias ? Arrays.copyOfRange(this.w, 1, this.w.length) : this.w;
    }

    public double intercept() {
        return this.bias ? this.w[0] : this.b;
    }

    public double[] residuals() {
        return this.residuals;
    }

    public double[] fittedValues() {
        return this.fittedValues;
    }

    public double RSS() {
        return this.RSS;
    }

    public double error() {
        return this.error;
    }

    public int df() {
        return this.df;
    }

    public double RSquared() {
        return this.RSquared;
    }

    public double adjustedRSquared() {
        return this.adjustedRSquared;
    }

    public double ftest() {
        return this.F;
    }

    public double pvalue() {
        return this.pvalue;
    }

    public double predict(double[] dArr) {
        double d = this.b;
        if (dArr.length == this.w.length) {
            for (int i = 0; i < dArr.length; i++) {
                d += dArr[i] * this.w[i];
            }
        } else {
            if (!this.bias || dArr.length != this.w.length - 1) {
                throw new IllegalArgumentException("Invalid vector size: " + dArr.length);
            }
            d = this.w[0];
            for (int i2 = 0; i2 < dArr.length; i2++) {
                d += dArr[i2] * this.w[i2 + 1];
            }
        }
        return d;
    }

    @Override // smile.regression.Regression
    public double predict(Tuple tuple) {
        return predict(this.formula.x(tuple).toArray(false, CategoricalEncoder.DUMMY, new String[0]));
    }

    @Override // smile.regression.DataFrameRegression
    public double[] predict(DataFrame dataFrame) {
        if (this.bias) {
            return this.formula.matrix(dataFrame, true).mv(this.w);
        }
        Matrix matrix = this.formula.matrix(dataFrame, false);
        double[] dArr = new double[matrix.nrow()];
        Arrays.fill(dArr, this.b);
        matrix.mv(1.0d, this.w, 1.0d, dArr);
        return dArr;
    }

    public void update(Tuple tuple) {
        update(this.formula.x(tuple).toArray(this.bias, CategoricalEncoder.DUMMY, new String[0]), this.formula.y(tuple));
    }

    public void update(DataFrame dataFrame) {
        int size = dataFrame.size();
        for (int i = 0; i < size; i++) {
            update((Tuple) dataFrame.get(i));
        }
    }

    @Override // smile.regression.Regression
    public boolean online() {
        return this.V != null;
    }

    public void update(double[] dArr, double d) {
        update(dArr, d, 1.0d);
    }

    public void update(double[] dArr, double d, double d2) {
        if (this.V == null) {
            throw new UnsupportedOperationException("The model doesn't support online learning");
        }
        if (d2 <= 0.0d || d2 > 1.0d) {
            throw new IllegalArgumentException("The forgetting factor must be in (0, 1]");
        }
        if (dArr.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", Integer.valueOf(dArr.length), Integer.valueOf(this.p)));
        }
        double xAx = 1.0d + this.V.xAx(dArr);
        if (Double.isNaN(1.0d / xAx)) {
            throw new IllegalStateException("The updated V matrix is no longer invertible.");
        }
        double[] mv = this.V.mv(dArr);
        for (int i = 0; i < this.p; i++) {
            for (int i2 = 0; i2 < this.p; i2++) {
                this.V.set(i2, i, (this.V.get(i2, i) - ((mv[i2] * mv[i]) / xAx)) / d2);
            }
        }
        this.V.mv(dArr, mv);
        double predict = d - predict(dArr);
        for (int i3 = 0; i3 < this.p; i3++) {
            double[] dArr2 = this.w;
            int i4 = i3;
            dArr2[i4] = dArr2[i4] + (mv[i3] * predict);
        }
    }

    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Linear Model:\n");
        double[] dArr = (double[]) this.residuals.clone();
        sb.append("\nResiduals:\n");
        sb.append("       Min          1Q      Median          3Q         Max\n");
        sb.append(String.format("%10.4f  %10.4f  %10.4f  %10.4f  %10.4f%n", Double.valueOf(MathEx.min(dArr)), Double.valueOf(MathEx.q1(dArr)), Double.valueOf(MathEx.median(dArr)), Double.valueOf(MathEx.q3(dArr)), Double.valueOf(MathEx.max(dArr))));
        sb.append("\nCoefficients:\n");
        if (this.ttest != null) {
            sb.append("                  Estimate Std. Error    t value   Pr(>|t|)\n");
            if (!this.bias) {
                sb.append(String.format("Intercept       %10.4f%n", Double.valueOf(this.b)));
            }
            for (int i = 0; i < this.p; i++) {
                sb.append(String.format("%-15s %10.4f %10.4f %10.4f %10.4f %s%n", this.predictors[i], Double.valueOf(this.ttest[i][0]), Double.valueOf(this.ttest[i][1]), Double.valueOf(this.ttest[i][2]), Double.valueOf(this.ttest[i][3]), Hypothesis.significance(this.ttest[i][3])));
            }
            sb.append("---------------------------------------------------------------------\n");
            sb.append("Significance codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n");
        } else {
            if (!this.bias) {
                sb.append(String.format("Intercept       %10.4f%n", Double.valueOf(this.b)));
            }
            for (int i2 = 0; i2 < this.p; i2++) {
                sb.append(String.format("%-15s %10.4f%n", this.predictors[i2], Double.valueOf(this.w[i2])));
            }
        }
        sb.append(String.format("%nResidual standard error: %.4f on %d degrees of freedom%n", Double.valueOf(this.error), Integer.valueOf(this.df)));
        sb.append(String.format("Multiple R-squared: %.4f,    Adjusted R-squared: %.4f%n", Double.valueOf(this.RSquared), Double.valueOf(this.adjustedRSquared)));
        sb.append(String.format("F-statistic: %.4f on %d and %d DF,  p-value: %.4g%n", Double.valueOf(this.F), Integer.valueOf(this.p), Integer.valueOf(this.df), Double.valueOf(this.pvalue)));
        return sb.toString();
    }
}
