/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Properties;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.regression.LASSO;
import smile.regression.LinearModel;

/**
 * Elastic net regression: least squares with both an L1 (LASSO) and an L2
 * (ridge) penalty on the coefficients.
 *
 * <p>The fit standardizes the predictors and appends {@code p} extra rows of
 * the form {@code c * sqrt(lambda2) * I} (with zero responses) to the design
 * matrix, where {@code c = 1 / sqrt(1 + lambda2)}. It then solves the
 * augmented problem with {@link LASSO#train} using the L1 weight
 * {@code lambda1 * c}. Finally it maps the solution back to the original
 * scale of the data.
 */
public class ElasticNet {
    /**
     * Predictor columns whose standard deviation does not exceed this value are
     * treated as constant, because standardizing them would divide by (nearly) zero.
     */
    private static final double CONSTANT_COLUMN_SD = Math.ulp(1.0);

    /**
     * Fits an elastic net model with hyper-parameters read from {@code params}.
     *
     * <p>Recognized properties:
     * <ul>
     *   <li>{@code smile.elastic_net.lambda1}: L1 penalty weight (required)</li>
     *   <li>{@code smile.elastic_net.lambda2}: L2 penalty weight (required)</li>
     *   <li>{@code smile.elastic_net.tolerance}: tolerance passed to LASSO, default {@code 1E-4}</li>
     *   <li>{@code smile.elastic_net.iterations}: iteration limit passed to LASSO, default {@code 1000}</li>
     * </ul>
     *
     * @param formula the model formula
     * @param data the training data
     * @param params the hyper-parameters
     * @return the fitted linear model
     * @throws NullPointerException if {@code lambda1} or {@code lambda2} is missing
     * @throws NumberFormatException if a property value cannot be parsed
     * @throws IllegalArgumentException if a penalty is not positive or a predictor
     *     column is constant
     */
    public static LinearModel fit(Formula formula, DataFrame data, Properties params) {
        double lambda1 = Double.parseDouble(requireProperty(params, "smile.elastic_net.lambda1"));
        double lambda2 = Double.parseDouble(requireProperty(params, "smile.elastic_net.lambda2"));
        double tol = Double.parseDouble(params.getProperty("smile.elastic_net.tolerance", "1E-4"));
        int maxIter = Integer.parseInt(params.getProperty("smile.elastic_net.iterations", "1000"));
        return fit(formula, data, lambda1, lambda2, tol, maxIter);
    }

    /**
     * Fits an elastic net model with tolerance {@code 1E-4} and at most
     * {@code 1000} iterations.
     *
     * @param formula the model formula
     * @param data the training data
     * @param lambda1 the L1 penalty weight; must be positive
     * @param lambda2 the L2 penalty weight; must be positive
     * @return the fitted linear model
     * @throws IllegalArgumentException if a penalty is not positive or a predictor
     *     column is constant
     */
    public static LinearModel fit(Formula formula, DataFrame data, double lambda1, double lambda2) {
        return fit(formula, data, lambda1, lambda2, 1.0E-4, 1000);
    }

    /**
     * Fits an elastic net model.
     *
     * @param formula the model formula; it is expanded against the schema of {@code data}
     * @param data the training data
     * @param lambda1 the L1 penalty weight; must be positive (use Ridge otherwise)
     * @param lambda2 the L2 penalty weight; must be positive (use LASSO otherwise)
     * @param tol the tolerance forwarded to {@link LASSO#train}
     * @param maxIter the iteration limit forwarded to {@link LASSO#train}
     * @return the fitted linear model
     * @throws IllegalArgumentException if a penalty is not positive or a predictor
     *     column is constant
     */
    public static LinearModel fit(Formula formula, DataFrame data, double lambda1, double lambda2, double tol, int maxIter) {
        if (lambda1 <= 0.0) {
            throw new IllegalArgumentException("Please use Ridge instead, wrong L1 portion setting: " + lambda1);
        }
        if (lambda2 <= 0.0) {
            throw new IllegalArgumentException("Please use LASSO instead, wrong L2 portion setting: " + lambda2);
        }

        // Scaling factor of the augmented problem.
        double c = 1.0 / Math.sqrt(1.0 + lambda2);

        formula = formula.expand(data.schema());
        StructType schema = formula.bind(data.schema());
        // NOTE(review): the `false` flag presumably omits a bias column; the
        // intercept is computed explicitly at the end of this method.
        Matrix X = formula.matrix(data, false);
        double[] y = formula.y(data).toDoubleArray();

        int n = X.nrow();
        int p = X.ncol();
        double[] center = X.colMeans();
        double[] scale = X.colSds();
        // A constant column would be standardized by dividing by zero and
        // poison the whole fit with NaN, so reject it up front.
        requireNonConstantColumns(scale);

        // Centered responses for the n data rows. The trailing p entries stay 0
        // and serve as the targets of the augmentation rows.
        double ym = MathEx.mean(y);
        double[] y2 = new double[n + p];
        for (int i = 0; i < n; i++) {
            y2[i] = y[i] - ym;
        }

        // Augmented design: c * standardized X stacked on c * sqrt(lambda2) * I_p.
        Matrix X2 = new Matrix(n + p, p);
        double padding = c * Math.sqrt(lambda2);
        for (int j = 0; j < p; j++) {
            for (int i = 0; i < n; i++) {
                X2.set(i, j, c * (X.get(i, j) - center[j]) / scale[j]);
            }
            X2.set(n + j, j, padding);
        }

        double[] w = LASSO.train(X2, y2, lambda1 * c, tol, maxIter);

        // Map the augmented-problem solution back to the original predictor scale.
        // NOTE(review): multiplying by c yields the "naive" elastic net scaling of
        // Zou & Hastie (2005). Their rescaled estimate would be w / c instead.
        // Confirm which one is intended.
        for (int j = 0; j < p; j++) {
            w[j] = c * w[j] / scale[j];
        }

        // Intercept that makes the model pass through the data means.
        double b = ym - MathEx.dot(w, center);
        return new LinearModel(formula, schema, X, y, w, b);
    }

    /**
     * Returns the value of a mandatory property.
     *
     * @throws NullPointerException naming {@code key} if the property is absent
     */
    private static String requireProperty(Properties params, String key) {
        String value = params.getProperty(key);
        if (value == null) {
            throw new NullPointerException("Missing required property: " + key);
        }
        return value;
    }

    /**
     * Rejects predictor columns whose standard deviation is (nearly) zero or undefined.
     *
     * @throws IllegalArgumentException identifying the first offending column
     */
    private static void requireNonConstantColumns(double[] scale) {
        for (int j = 0; j < scale.length; j++) {
            // The negated comparison also rejects NaN, e.g. a single-row input.
            if (!(scale[j] > CONSTANT_COLUMN_SD)) {
                throw new IllegalArgumentException(
                        "Column " + j + " of the design matrix is constant (sd = " + scale[j] + ")");
            }
        }
    }
}

