/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Arrays;
import java.util.Properties;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.Matrix;
import smile.regression.LASSO;
import smile.regression.LinearModel;

/**
 * Elastic net regression: a linear model fitted with both an L1 (LASSO) and an
 * L2 (ridge) penalty on the coefficients.
 *
 * <p>The fit reduces the problem to a LASSO on augmented data (Zou and Hastie, 2005).
 * The standardized design matrix is scaled by {@code c = 1 / sqrt(1 + lambda2)} and
 * padded with {@code p} extra rows forming {@code c * sqrt(lambda2) * I}. The centered
 * response is padded with {@code p} zeros, and LASSO is solved with penalty
 * {@code lambda1 * c}.
 */
public class ElasticNet {
    /**
     * Fits an elastic net model with hyper-parameters read from properties.
     *
     * <p>Recognized keys:
     * <ul>
     *   <li>{@code smile.elastic.net.lambda1} - L1 penalty (required)</li>
     *   <li>{@code smile.elastic.net.lambda2} - L2 penalty (required)</li>
     *   <li>{@code smile.elastic.net.tolerance} - tolerance, default {@code 1E-4}</li>
     *   <li>{@code smile.elastic.net.max.iterations} - iteration cap, default {@code 1000}</li>
     * </ul>
     *
     * @param formula the model formula defining the response and the predictors.
     * @param data the training data.
     * @param prop the hyper-parameters.
     * @return the fitted model.
     * @throws IllegalArgumentException if a required property is missing, a value is
     *         not a valid number ({@link NumberFormatException}), or a penalty is not
     *         positive.
     */
    public static LinearModel fit(Formula formula, DataFrame data, Properties prop) {
        double lambda1 = Double.parseDouble(requireProperty(prop, "smile.elastic.net.lambda1"));
        double lambda2 = Double.parseDouble(requireProperty(prop, "smile.elastic.net.lambda2"));
        double tol = Double.parseDouble(prop.getProperty("smile.elastic.net.tolerance", "1E-4"));
        int maxIter = Integer.parseInt(prop.getProperty("smile.elastic.net.max.iterations", "1000"));
        return fit(formula, data, lambda1, lambda2, tol, maxIter);
    }

    /**
     * Fits an elastic net model with tolerance {@code 1E-4} and at most 1000 iterations.
     *
     * @param formula the model formula defining the response and the predictors.
     * @param data the training data.
     * @param lambda1 the L1 penalty; must be positive.
     * @param lambda2 the L2 penalty; must be positive.
     * @return the fitted model.
     * @throws IllegalArgumentException if a penalty is not positive or a predictor
     *         column is constant.
     */
    public static LinearModel fit(Formula formula, DataFrame data, double lambda1, double lambda2) {
        return fit(formula, data, lambda1, lambda2, 1.0E-4, 1000);
    }

    /**
     * Fits an elastic net model.
     *
     * @param formula the model formula defining the response and the predictors.
     * @param data the training data.
     * @param lambda1 the L1 penalty; must be positive (use Ridge for a pure L2 fit).
     * @param lambda2 the L2 penalty; must be positive (use LASSO for a pure L1 fit).
     * @param tol the tolerance passed to the underlying LASSO solver.
     * @param maxIter the iteration cap passed to the underlying LASSO solver.
     * @return the fitted model, with coefficients and intercept on the original
     *         (unstandardized) predictor scale.
     * @throws IllegalArgumentException if a penalty is not positive, or a predictor
     *         column has zero or undefined standard deviation.
     */
    public static LinearModel fit(Formula formula, DataFrame data, double lambda1, double lambda2, double tol, int maxIter) {
        if (lambda1 <= 0.0) {
            throw new IllegalArgumentException("Please use Ridge instead, wrong L1 portion setting: " + lambda1);
        }
        if (lambda2 <= 0.0) {
            throw new IllegalArgumentException("Please use LASSO instead, wrong L2 portion setting: " + lambda2);
        }

        double c = 1.0 / Math.sqrt(1.0 + lambda2);

        DenseMatrix X = formula.matrix(data, false);
        double[] y = formula.y(data).toDoubleArray();

        int n = X.nrows();
        int p = X.ncols();
        double[] center = X.colMeans();
        double[] scale = X.colSds();
        for (int j = 0; j < p; j++) {
            // Standardizing a column with zero spread would divide by zero and fill the
            // design matrix with NaN/Infinity. The negated test also rejects NaN.
            if (!(scale[j] > 0.0)) {
                throw new IllegalArgumentException("The column " + j + " is constant or has undefined standard deviation: " + scale[j]);
            }
        }

        // Center the response before padding it with p zeros. The standardized predictor
        // rows below are centered, so a centered response keeps the padding rows exactly
        // zero relative to the data rows. Left uncentered, any intercept/centering the
        // LASSO solver applies over all n + p rows would also shift the padding rows and
        // corrupt the ridge part of the augmented problem.
        double ym = MathEx.mean(y);
        double[] y2 = new double[n + p];
        for (int i = 0; i < n; i++) {
            y2[i] = y[i] - ym;
        }

        // Standardize and scale the predictors by c, then append c * sqrt(lambda2) * I.
        DenseMatrix X2 = Matrix.zeros(n + p, p);
        double padding = c * Math.sqrt(lambda2);
        for (int j = 0; j < p; j++) {
            for (int i = 0; i < n; i++) {
                X2.set(i, j, c * (X.get(i, j) - center[j]) / scale[j]);
            }
            X2.set(n + j, j, padding);
        }

        LinearModel model = LASSO.train((Matrix) X2, y2, lambda1 * c, tol, maxIter);
        model.formula = formula;
        model.schema = formula.xschema();

        // Map the LASSO coefficients back to the original predictor scale.
        // NOTE(review): w = c * w_lasso is the "naive" elastic net of Zou and Hastie;
        // their corrected estimate would use sqrt(1 + lambda2) * w_lasso instead.
        double[] w = new double[p];
        for (int j = 0; j < p; j++) {
            w[j] = c * model.w[j] / scale[j];
        }
        model.w = w;
        // Choose the intercept so the fitted plane passes through (center, ym).
        model.b = ym - MathEx.dot(model.w, center);

        double[] fittedValues = new double[y.length];
        Arrays.fill(fittedValues, model.b);
        // Expected to accumulate X * w onto the intercept-filled array (y += A x).
        X.axpy(model.w, fittedValues);
        model.fitness(fittedValues, y, ym);
        return model;
    }

    /**
     * Returns the value of a required property.
     *
     * @throws IllegalArgumentException if the property is absent.
     */
    private static String requireProperty(Properties prop, String key) {
        String value = prop.getProperty(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing required property: " + key);
        }
        return value;
    }
}

