package smile.regression;

import java.util.Arrays;
import java.util.Properties;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.math.matrix.DenseMatrix;

/* loaded from: input_file:smile/regression/RidgeRegression.class */
/**
 * Static factory methods that fit a {@link LinearModel} with an L2 (ridge)
 * penalty.
 *
 * <p>The fit is performed on a scaled copy of the design matrix (scaled with
 * the column means and column standard deviations), the penalty is added to
 * the diagonal of the Gram matrix, and the resulting system is solved with a
 * Cholesky decomposition. The coefficients are then mapped back to the
 * original column scale and an intercept is derived from the response mean.
 */
public class RidgeRegression {
    /**
     * Fits a ridge regression using the default settings (an empty
     * {@link Properties}, so lambda falls back to {@code 1}).
     *
     * @param formula the model formula used to build the design matrix and response.
     * @param dataFrame the training data.
     * @return the fitted linear model.
     */
    public static LinearModel fit(Formula formula, DataFrame dataFrame) {
        return fit(formula, dataFrame, new Properties());
    }

    /**
     * Fits a ridge regression, reading the penalty from the property
     * {@code smile.ridge.lambda} (default {@code "1"}).
     *
     * @param formula the model formula used to build the design matrix and response.
     * @param dataFrame the training data.
     * @param properties the hyper-parameters; only {@code smile.ridge.lambda} is read here.
     * @return the fitted linear model.
     * @throws NumberFormatException if the property value is not a parsable double.
     * @throws IllegalArgumentException if the parsed lambda is negative or NaN.
     */
    public static LinearModel fit(Formula formula, DataFrame dataFrame, Properties properties) {
        double lambda = Double.parseDouble(properties.getProperty("smile.ridge.lambda", "1"));
        return fit(formula, dataFrame, lambda);
    }

    /**
     * Fits a ridge regression with the given penalty.
     *
     * @param formula the model formula used to build the design matrix and response.
     * @param dataFrame the training data.
     * @param lambda the shrinkage/regularization parameter; must be non-negative and not NaN.
     * @return the fitted linear model.
     * @throws IllegalArgumentException if {@code lambda} is negative or NaN, if the
     *         design matrix does not have more rows than columns, or if any
     *         column of the design matrix has zero standard deviation.
     */
    public static LinearModel fit(Formula formula, DataFrame dataFrame, double lambda) {
        // NaN must be rejected explicitly: "lambda < 0.0" is false for NaN, which
        // would otherwise flow into the Gram matrix diagonal below.
        if (Double.isNaN(lambda) || lambda < 0.0) {
            throw new IllegalArgumentException("Invalid shrinkage/regularization parameter lambda = " + lambda);
        }

        DenseMatrix x = formula.matrix(dataFrame, false);
        double[] y = formula.y(dataFrame).toDoubleArray();
        int n = x.nrows();
        int p = x.ncols();
        if (n <= p) {
            throw new IllegalArgumentException(String.format("The input matrix is not over determined: %d rows, %d columns", n, p));
        }

        LinearModel model = new LinearModel();
        model.formula = formula;
        model.schema = formula.xschema();
        model.p = p;

        double[] center = x.colMeans();
        double[] sd = x.colSds();
        for (int j = 0; j < sd.length; j++) {
            // A zero standard deviation would make the division by sd[j] below meaningless.
            if (MathEx.isZero(sd[j])) {
                // NOTE(review): j indexes a design-matrix column, but the name is looked up in
                // formula.schema() while the model schema uses formula.xschema() - confirm
                // that both schemas share the same column order/indices.
                throw new IllegalArgumentException(String.format("The column '%s' is constant", formula.schema().fieldName(j)));
            }
        }

        DenseMatrix scaledX = x.scale(center, sd);

        // Right-hand side of the normal equations, written into model.w
        // (presumably scaledX' * y - see DenseMatrix.atx).
        model.w = new double[p];
        scaledX.atx(y, model.w);

        // Gram matrix of the scaled design matrix with the ridge penalty on its diagonal.
        DenseMatrix gram = scaledX.ata();
        for (int j = 0; j < p; j++) {
            gram.add(j, j, lambda);
        }

        // The solve works in place: model.w is read back as the solution afterwards.
        gram.cholesky().solve(model.w);

        // Undo the column scaling so the coefficients apply to the unscaled columns.
        for (int j = 0; j < p; j++) {
            model.w[j] /= sd[j];
        }

        double yMean = MathEx.mean(y);
        model.b = yMean - MathEx.dot(model.w, center);

        // Fitted values: start from the intercept, then let axpy add the contribution
        // of the original (unscaled) design matrix and the coefficients.
        double[] fittedValues = new double[n];
        Arrays.fill(fittedValues, model.b);
        x.axpy(model.w, fittedValues);
        model.fitness(fittedValues, y, yMean);
        return model;
    }
}
