package smile.regression;

import java.util.Properties;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.math.matrix.Matrix;

/* loaded from: input_file:smile/regression/ElasticNet.class */
/**
 * Elastic net regression: least squares with a combined L1 (LASSO) and L2 (ridge) penalty.
 *
 * <p>The fit reduces the elastic net to a LASSO problem on augmented data:
 * <ul>
 *   <li>the standardized design matrix is scaled by {@code c = 1 / sqrt(1 + lambda2)};</li>
 *   <li>it is padded with {@code p} extra rows carrying {@code c * sqrt(lambda2)} on the diagonal;</li>
 *   <li>the response is padded with {@code p} zeros;</li>
 *   <li>{@link LASSO} is then solved with L1 weight {@code lambda1 * c}.</li>
 * </ul>
 */
public class ElasticNet {
    /** Convergence tolerance used when none is supplied. */
    private static final double DEFAULT_TOLERANCE = 1E-4;

    /** Iteration cap used when none is supplied. */
    private static final int DEFAULT_MAX_ITER = 1000;

    /**
     * Fits an elastic net model with hyper-parameters read from {@code properties}.
     *
     * <p>Recognized keys:
     * <ul>
     *   <li>{@code smile.elastic.net.lambda1}: the L1 penalty weight (required).</li>
     *   <li>{@code smile.elastic.net.lambda2}: the L2 penalty weight (required).</li>
     *   <li>{@code smile.elastic.net.tolerance}: convergence tolerance, default {@code 1E-4}.</li>
     *   <li>{@code smile.elastic.net.max.iterations}: iteration cap, default {@code 1000}.</li>
     * </ul>
     *
     * @param formula a symbolic description of the model.
     * @param dataFrame the training data.
     * @param properties the hyper-parameters.
     * @return the fitted linear model.
     * @throws IllegalArgumentException if a required key is missing, a value cannot be parsed, or
     *     a value is rejected by {@link #fit(Formula, DataFrame, double, double, double, int)}.
     */
    public static LinearModel fit(Formula formula, DataFrame dataFrame, Properties properties) {
        double lambda1 = Double.parseDouble(requireProperty(properties, "smile.elastic.net.lambda1"));
        double lambda2 = Double.parseDouble(requireProperty(properties, "smile.elastic.net.lambda2"));
        double tol = Double.parseDouble(
                properties.getProperty("smile.elastic.net.tolerance", String.valueOf(DEFAULT_TOLERANCE)));
        int maxIter = Integer.parseInt(
                properties.getProperty("smile.elastic.net.max.iterations", String.valueOf(DEFAULT_MAX_ITER)));
        return fit(formula, dataFrame, lambda1, lambda2, tol, maxIter);
    }

    /**
     * Fits an elastic net model with the default tolerance ({@code 1E-4}) and iteration cap
     * ({@code 1000}).
     *
     * @param formula a symbolic description of the model.
     * @param dataFrame the training data.
     * @param lambda1 the L1 penalty weight; must be positive.
     * @param lambda2 the L2 penalty weight; must be positive.
     * @return the fitted linear model.
     * @throws IllegalArgumentException see {@link #fit(Formula, DataFrame, double, double, double, int)}.
     */
    public static LinearModel fit(Formula formula, DataFrame dataFrame, double lambda1, double lambda2) {
        return fit(formula, dataFrame, lambda1, lambda2, DEFAULT_TOLERANCE, DEFAULT_MAX_ITER);
    }

    /**
     * Fits an elastic net model.
     *
     * @param formula a symbolic description of the model; it is expanded against the data schema.
     * @param dataFrame the training data.
     * @param lambda1 the L1 penalty weight; must be positive (use ridge regression otherwise).
     * @param lambda2 the L2 penalty weight; must be positive (use LASSO otherwise).
     * @param tol the convergence tolerance passed to the LASSO solver.
     * @param maxIter the maximum number of iterations passed to the LASSO solver.
     * @return the fitted linear model. Coefficients are expressed on the scale of the original
     *     design-matrix columns, and the intercept is {@code mean(y) - dot(w, colMeans(X))}.
     * @throws IllegalArgumentException if {@code lambda1} or {@code lambda2} is not positive
     *     (NaN included), or if a design-matrix column has zero or undefined standard deviation.
     */
    public static LinearModel fit(Formula formula, DataFrame dataFrame, double lambda1, double lambda2, double tol, int maxIter) {
        // "!(x > 0)" also rejects NaN, which "x <= 0" would silently let through.
        if (!(lambda1 > 0.0)) {
            throw new IllegalArgumentException("Please use Ridge instead, wrong L1 portion setting: " + lambda1);
        }
        if (!(lambda2 > 0.0)) {
            throw new IllegalArgumentException("Please use LASSO instead, wrong L2 portion setting: " + lambda2);
        }

        double c = 1.0 / Math.sqrt(1.0 + lambda2);
        Formula expanded = formula.expand(dataFrame.schema());
        StructType schema = expanded.bind(dataFrame.schema());

        // NOTE(review): the `false` flag presumably suppresses an intercept column. The intercept
        // is computed explicitly at the end instead. Confirm against Formula.matrix.
        Matrix x = expanded.matrix(dataFrame, false);
        double[] y = expanded.y(dataFrame).toDoubleArray();

        int n = x.nrows();
        int p = x.ncols();
        double[] center = x.colMeans();
        double[] scale = x.colSds();
        requirePositiveScales(scale);

        // Augmented response: the original y followed by p zeros.
        double[] y2 = new double[y.length + p];
        System.arraycopy(y, 0, y2, 0, y.length);

        // Augmented design matrix:
        //   top n rows:    c * standardized X;
        //   bottom p rows: c * sqrt(lambda2) * I (all other entries stay at the Matrix's initial 0).
        Matrix x2 = new Matrix(n + p, p);
        double padding = c * Math.sqrt(lambda2);
        for (int j = 0; j < p; j++) {
            for (int i = 0; i < n; i++) {
                x2.set(i, j, c * (x.get(i, j) - center[j]) / scale[j]);
            }
            x2.set(n + j, j, padding);
        }

        double[] w = LASSO.train(x2, y2, lambda1 * c, tol, maxIter);

        // Map the augmented-problem coefficients back to the original column scale.
        // NOTE(review): multiplying by c = 1/sqrt(1+lambda2) matches what Zou & Hastie (2005) call
        // the "naive" elastic net; their rescaled estimate multiplies by sqrt(1+lambda2) instead.
        // Confirm which is intended (the LASSO objective's scaling convention may also matter).
        for (int j = 0; j < p; j++) {
            w[j] = c * w[j] / scale[j];
        }

        double intercept = MathEx.mean(y) - MathEx.dot(w, center);
        return new LinearModel(expanded, schema, x, y, w, intercept);
    }

    /**
     * Returns the value of a required property.
     *
     * @throws IllegalArgumentException if {@code key} is not set.
     */
    private static String requireProperty(Properties properties, String key) {
        String value = properties.getProperty(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing required property: " + key);
        }
        return value;
    }

    /**
     * Rejects column standard deviations that are zero or NaN.
     *
     * <p>Standardization divides by these values, so a zero or NaN entry would inject
     * NaN/Infinity into the solver. This covers constant columns, or an empty data set.
     */
    private static void requirePositiveScales(double[] scale) {
        for (int j = 0; j < scale.length; j++) {
            if (!(scale[j] > 0.0)) {
                throw new IllegalArgumentException(
                        "Column " + j + " of the design matrix has zero or undefined standard deviation: " + scale[j]);
            }
        }
    }
}
