/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Properties;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.regression.LASSO;
import smile.regression.LinearModel;
import smile.tensor.DenseMatrix;
import smile.tensor.Vector;

/**
 * Elastic net regression, fitted by reducing the problem to a LASSO on an
 * augmented data set.
 *
 * <p>The design matrix is standardized column-wise (centered by the column mean,
 * divided by the column standard deviation) and multiplied by
 * {@code c = 1 / sqrt(1 + lambda2)}. It is then padded with {@code p} extra rows
 * holding the diagonal matrix {@code c * sqrt(lambda2) * I}. The response is
 * centered and padded with {@code p} zeros. The LASSO solution of this augmented
 * problem, with L1 penalty {@code lambda1 * c}, is multiplied by {@code c} and
 * mapped back to the original scale of the predictors.
 */
public class ElasticNet {
    /** Static-only entry point; not instantiable. */
    private ElasticNet() {
    }

    /**
     * Fits an elastic net model using default solver options.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param lambda1 the L1 regularization parameter; must be positive.
     * @param lambda2 the L2 regularization parameter; must be positive.
     * @return the fitted linear model.
     * @throws IllegalArgumentException if {@code lambda1} or {@code lambda2} is not
     *         a positive number, or if a predictor column is constant.
     */
    public static LinearModel fit(Formula formula, DataFrame data, double lambda1, double lambda2) {
        return ElasticNet.fit(formula, data, new Options(lambda1, lambda2));
    }

    /**
     * Fits an elastic net model.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param options the regularization and solver options.
     * @return the fitted linear model.
     * @throws IllegalArgumentException if a column of the design matrix is constant.
     */
    public static LinearModel fit(Formula formula, DataFrame data, Options options) {
        double c = 1.0 / Math.sqrt(1.0 + options.lambda2);
        formula = formula.expand(data.schema());
        StructType schema = formula.bind(data.schema());
        // Design matrix without a bias column; the intercept is recovered at the end.
        DenseMatrix X = formula.matrix(data, false);
        double[] y = formula.y(data).toDoubleArray();
        int n = X.nrow();
        int p = X.ncol();
        Vector center = X.colMeans();
        Vector scale = X.colSds();

        // A constant column has zero standard deviation. Standardizing it would
        // divide by zero and feed NaN/Infinity into the LASSO solver.
        for (int j = 0; j < p; ++j) {
            if (MathEx.isZero(scale.get(j))) {
                throw new IllegalArgumentException("The column " + j + " of the design matrix is constant");
            }
        }

        // Centered response followed by p zeros (Java zero-initializes the array).
        double[] centeredY = new double[n + p];
        double ymu = MathEx.mean(y);
        for (int i = 0; i < n; ++i) {
            centeredY[i] = y[i] - ymu;
        }

        // Rows [0, n): c * standardized X. Rows [n, n + p): c * sqrt(lambda2) * I.
        DenseMatrix scaledX = X.zeros(n + p, p);
        double padding = c * Math.sqrt(options.lambda2);
        for (int j = 0; j < p; ++j) {
            for (int i = 0; i < n; ++i) {
                scaledX.set(i, j, c * (X.get(i, j) - center.get(j)) / scale.get(j));
            }
            scaledX.set(j + n, j, padding);
        }

        // The augmented problem uses the L1 penalty lambda1 * c.
        LASSO.Options lasso = new LASSO.Options(options.lambda1 * c, options.tol, options.maxIter,
                options.alpha, options.beta, options.eta, options.lsMaxIter, options.pcgMaxIter);
        Vector w = LASSO.train(scaledX, centeredY, lasso);

        // Undo the c-scaling and the column standardization so that w applies to raw X.
        for (int i = 0; i < p; ++i) {
            w.set(i, c * w.get(i) / scale.get(i));
        }
        // Intercept from the response mean and the column means.
        double b = ymu - w.dot(center);
        return new LinearModel(formula, schema, X, y, w, b);
    }

    /**
     * Elastic net hyperparameters.
     *
     * @param lambda1 the L1 regularization parameter; must be positive.
     * @param lambda2 the L2 regularization parameter; must be positive.
     * @param tol the tolerance passed to the LASSO solver; must be positive.
     * @param maxIter the maximum number of iterations passed to the LASSO solver; must be positive.
     * @param alpha forwarded unchanged to {@code LASSO.Options}; must be positive.
     * @param beta forwarded unchanged to {@code LASSO.Options}; must be positive.
     * @param eta forwarded unchanged to {@code LASSO.Options}; must be positive.
     * @param lsMaxIter the maximum number of line search iterations; must be positive.
     * @param pcgMaxIter the maximum number of PCG iterations; must be positive.
     */
    public record Options(double lambda1, double lambda2, double tol, int maxIter, double alpha, double beta, double eta, int lsMaxIter, int pcgMaxIter) {
        /**
         * Validates the options.
         *
         * @throws IllegalArgumentException if any parameter is non-positive,
         *         or if a floating-point parameter is NaN.
         */
        public Options {
            // The negated form !(x > 0.0) rejects NaN as well as non-positive values;
            // a plain "x <= 0.0" test would let NaN through.
            if (!(lambda1 > 0.0)) {
                throw new IllegalArgumentException("Please use Ridge instead, wrong L1 portion setting: " + lambda1);
            }
            if (!(lambda2 > 0.0)) {
                throw new IllegalArgumentException("Please use LASSO instead, wrong L2 portion setting: " + lambda2);
            }
            if (!(tol > 0.0)) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
            if (maxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }
            if (!(alpha > 0.0)) {
                throw new IllegalArgumentException("Invalid alpha: " + alpha);
            }
            if (!(beta > 0.0)) {
                throw new IllegalArgumentException("Invalid beta: " + beta);
            }
            if (!(eta > 0.0)) {
                throw new IllegalArgumentException("Invalid eta: " + eta);
            }
            if (lsMaxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of line search iterations: " + lsMaxIter);
            }
            if (pcgMaxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of PCG iterations: " + pcgMaxIter);
            }
        }

        /**
         * Creates options with tol = 1E-4 and maxIter = 1000, plus the
         * solver defaults of {@link #Options(double, double, double, int)}.
         *
         * @param lambda1 the L1 regularization parameter.
         * @param lambda2 the L2 regularization parameter.
         */
        public Options(double lambda1, double lambda2) {
            this(lambda1, lambda2, 1.0E-4, 1000);
        }

        /**
         * Creates options with alpha = 0.01, beta = 0.5, eta = 0.001,
         * lsMaxIter = 100 and pcgMaxIter = 5000.
         *
         * @param lambda1 the L1 regularization parameter.
         * @param lambda2 the L2 regularization parameter.
         * @param tol the tolerance.
         * @param maxIter the maximum number of iterations.
         */
        public Options(double lambda1, double lambda2, double tol, int maxIter) {
            this(lambda1, lambda2, tol, maxIter, 0.01, 0.5, 0.001, 100, 5000);
        }

        /**
         * Returns these options as properties under the {@code smile.elastic_net.*} keys.
         *
         * @return the properties.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.elastic_net.lambda1", Double.toString(this.lambda1));
            props.setProperty("smile.elastic_net.lambda2", Double.toString(this.lambda2));
            props.setProperty("smile.elastic_net.tolerance", Double.toString(this.tol));
            props.setProperty("smile.elastic_net.iterations", Integer.toString(this.maxIter));
            props.setProperty("smile.elastic_net.alpha", Double.toString(this.alpha));
            props.setProperty("smile.elastic_net.beta", Double.toString(this.beta));
            props.setProperty("smile.elastic_net.eta", Double.toString(this.eta));
            props.setProperty("smile.elastic_net.line_search_iterations", Integer.toString(this.lsMaxIter));
            props.setProperty("smile.elastic_net.pcg_iterations", Integer.toString(this.pcgMaxIter));
            return props;
        }

        /**
         * Reads options from properties written by {@link #toProperties()}.
         * Optional keys fall back to the same defaults as the convenience constructors.
         *
         * @param props the properties.
         * @return the options.
         * @throws NullPointerException if {@code smile.elastic_net.lambda1} or
         *         {@code smile.elastic_net.lambda2} is missing.
         * @throws NumberFormatException if a property value cannot be parsed.
         * @throws IllegalArgumentException if a parsed value fails validation.
         */
        public static Options of(Properties props) {
            double lambda1 = Double.parseDouble(props.getProperty("smile.elastic_net.lambda1"));
            double lambda2 = Double.parseDouble(props.getProperty("smile.elastic_net.lambda2"));
            double tol = Double.parseDouble(props.getProperty("smile.elastic_net.tolerance", "1E-4"));
            int maxIter = Integer.parseInt(props.getProperty("smile.elastic_net.iterations", "1000"));
            double alpha = Double.parseDouble(props.getProperty("smile.elastic_net.alpha", "0.01"));
            double beta = Double.parseDouble(props.getProperty("smile.elastic_net.beta", "0.5"));
            double eta = Double.parseDouble(props.getProperty("smile.elastic_net.eta", "1E-3"));
            int lsMaxIter = Integer.parseInt(props.getProperty("smile.elastic_net.line_search_iterations", "100"));
            int pcgMaxIter = Integer.parseInt(props.getProperty("smile.elastic_net.pcg_iterations", "5000"));
            return new Options(lambda1, lambda2, tol, maxIter, alpha, beta, eta, lsMaxIter, pcgMaxIter);
        }
    }
}

