/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Arrays;
import java.util.Properties;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.linalg.UPLO;
import smile.math.MathEx;
import smile.regression.LinearModel;
import smile.tensor.Cholesky;
import smile.tensor.DenseMatrix;
import smile.tensor.Vector;
import smile.util.Strings;

/**
 * Ridge regression: weighted linear least squares with a (generalized) L2 penalty
 * on the coefficients.
 *
 * <p>The design matrix is standardized (column means removed, divided by column
 * standard deviations). The normal equations
 * {@code (Xs' W Xs + Lambda) beta = Xs' W y + Lambda beta0} are then solved by
 * Cholesky factorization, where {@code W = diag(weights)} and
 * {@code Lambda = diag(lambda)}. The coefficients are mapped back to the original
 * scale and the intercept is computed as {@code mean(y) - beta . colMeans(X)}.
 */
public class RidgeRegression {
    /** Static utility class; not instantiable. */
    private RidgeRegression() {
    }

    /**
     * Fits a ridge regression model with unit sample weights and a single shrinkage
     * parameter applied to every coefficient (penalty target 0).
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data    the data frame of the explanatory and response variables.
     * @param lambda  the shrinkage/regularization parameter; must be finite and non-negative.
     * @return the fitted linear model.
     * @throws IllegalArgumentException if {@code lambda} is invalid or a predictor column is constant.
     */
    public static LinearModel fit(Formula formula, DataFrame data, double lambda) {
        double[] weights = new double[data.size()];
        Arrays.fill(weights, 1.0);
        return fit(formula, data, weights, new Options(lambda));
    }

    /**
     * Fits a weighted, generalized ridge regression model.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data    the data frame of the explanatory and response variables.
     * @param weights the per-sample weights; one strictly positive value per row of the design matrix.
     * @param options the hyperparameters. {@code lambda} and {@code beta0} may each hold either a
     *                single value (broadcast to every coefficient) or one value per design-matrix column.
     * @return the fitted linear model.
     * @throws IllegalArgumentException if the weights, lambda, or beta0 vectors have the wrong size
     *                                  or invalid values, or if a predictor column is constant.
     */
    public static LinearModel fit(Formula formula, DataFrame data, double[] weights, Options options) {
        formula = formula.expand(data.schema());
        StructType schema = formula.bind(data.schema());
        DenseMatrix X = formula.matrix(data, false);
        double[] y = formula.y(data).toDoubleArray();
        int n = X.nrow();
        int p = X.ncol();

        if (weights.length != n) {
            throw new IllegalArgumentException(String.format("Invalid weights vector size: %d != %d", weights.length, n));
        }
        for (int i = 0; i < n; ++i) {
            // Written as !(w > 0) so that NaN weights are rejected too.
            if (!(weights[i] > 0.0)) {
                throw new IllegalArgumentException(String.format("Invalid weights[%d] = %f", i, weights[i]));
            }
        }

        double[] lambda = broadcast(options.lambda(), p, "lambda");
        for (int i = 0; i < p; ++i) {
            if (lambda[i] < 0.0) {
                throw new IllegalArgumentException(String.format("Invalid lambda[%d] = %f", i, lambda[i]));
            }
        }
        double[] beta0 = broadcast(options.beta0(), p, "beta0");

        Vector center = X.colMeans();
        Vector scale = X.colSds();
        for (int j = 0; j < scale.size(); ++j) {
            if (MathEx.isZero(scale.get(j))) {
                throw new IllegalArgumentException(String.format("The column '%s' is constant", schema.names()[j]));
            }
        }
        DenseMatrix scaledX = X.standardize(center, scale);

        // NOTE(review): centering uses unweighted column means and the intercept below uses the
        // unweighted mean of y. With non-uniform weights this is not the weighted-least-squares
        // intercept — confirm this is intended.

        // XtW = Xs' * diag(weights), materialized as a p x n matrix.
        DenseMatrix XtW = X.zeros(p, n);
        for (int i = 0; i < p; ++i) {
            for (int j = 0; j < n; ++j) {
                XtW.set(i, j, weights[j] * scaledX.get(j, i));
            }
        }

        // Right-hand side: Xs' W y + Lambda * beta0. beta0 is therefore the shrinkage target
        // for the coefficients of the *standardized* design matrix.
        Vector w = XtW.mv(y);
        for (int i = 0; i < p; ++i) {
            w.add(i, lambda[i] * beta0[i]);
        }

        // Left-hand side: Xs' W Xs + Lambda (symmetric), factorized by Cholesky.
        DenseMatrix XtX = XtW.mm(scaledX);
        XtX.withUplo(UPLO.LOWER);
        for (int i = 0; i < XtX.nrow(); ++i) {
            XtX.add(i, i, lambda[i]);
        }
        Cholesky cholesky = XtX.cholesky();
        // Solved in place: afterwards w holds the standardized coefficients.
        cholesky.solve((DenseMatrix) w);

        // Map the coefficients back to the original column scale.
        for (int j = 0; j < p; ++j) {
            w.div(j, scale.get(j));
        }
        double b = MathEx.mean(y) - w.dot(center);
        return new LinearModel(formula, schema, X, y, w, b);
    }

    /**
     * Expands a single-element vector to length {@code p}. Any other vector is returned
     * unchanged if its length is already {@code p}.
     *
     * @param values the user-supplied values.
     * @param p      the required length (number of design-matrix columns).
     * @param name   the parameter name used in the error message.
     * @return a vector of length {@code p}.
     * @throws IllegalArgumentException if {@code values} has neither length 1 nor length {@code p}.
     */
    private static double[] broadcast(double[] values, int p, String name) {
        if (values.length == 1) {
            double[] expanded = new double[p];
            Arrays.fill(expanded, values[0]);
            return expanded;
        }
        if (values.length != p) {
            throw new IllegalArgumentException(String.format("Invalid %s vector size: %d != %d", name, values.length, p));
        }
        return values;
    }

    /**
     * Ridge regression hyperparameters.
     *
     * <p>The arrays are defensively copied on construction.
     *
     * @param lambda the shrinkage/regularization parameter(s): a single value applied to every
     *               coefficient, or one value per coefficient. Each must be finite and non-negative.
     * @param beta0  the generalized ridge penalty target(s): a single value or one value per
     *               coefficient. Each must be finite; negative targets are allowed.
     */
    public record Options(double[] lambda, double[] beta0) {
        /**
         * Validates and copies the hyperparameters.
         *
         * @throws IllegalArgumentException if any lambda is negative or non-finite, or any beta0 is non-finite.
         */
        public Options {
            lambda = lambda.clone();
            beta0 = beta0.clone();
            for (double value : lambda) {
                if (value < 0.0 || !Double.isFinite(value)) {
                    throw new IllegalArgumentException("Invalid shrinkage/regularization parameter lambda = " + value);
                }
            }
            for (double value : beta0) {
                // beta0 is a target value for the coefficients, so any finite real is valid.
                if (!Double.isFinite(value)) {
                    throw new IllegalArgumentException("Invalid generalized ridge penalty target beta0 = " + value);
                }
            }
        }

        /**
         * Constructor with penalty target 0.
         *
         * @param lambda the shrinkage/regularization parameter applied to every coefficient.
         */
        public Options(double lambda) {
            this(lambda, 0.0);
        }

        /**
         * Constructor with scalar hyperparameters applied to every coefficient.
         *
         * @param lambda the shrinkage/regularization parameter.
         * @param beta0  the generalized ridge penalty target.
         */
        public Options(double lambda, double beta0) {
            this(new double[]{lambda}, new double[]{beta0});
        }

        /**
         * Returns the persistent set of hyperparameters.
         *
         * @return the properties {@code smile.ridge.lambda} and {@code smile.ridge.beta0},
         *         each formatted with {@link Arrays#toString(double[])}.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.ridge.lambda", Arrays.toString(this.lambda));
            props.setProperty("smile.ridge.beta0", Arrays.toString(this.beta0));
            return props;
        }

        /**
         * Returns the options from properties. Each of {@code smile.ridge.lambda}
         * (default {@code "1"}) and {@code smile.ridge.beta0} (default {@code "0"}) may
         * independently be a scalar or an array literal.
         *
         * @param props the hyperparameters.
         * @return the options.
         * @throws IllegalArgumentException if a value cannot be parsed or fails validation.
         */
        public static Options of(Properties props) {
            double[] lambda = parseValues(props.getProperty("smile.ridge.lambda", "1"));
            double[] beta0 = parseValues(props.getProperty("smile.ridge.beta0", "0"));
            return new Options(lambda, beta0);
        }

        /**
         * Parses a property value given as either a scalar or an array literal.
         *
         * @param text the property value.
         * @return a one-element array for a scalar, otherwise the parsed array.
         */
        private static double[] parseValues(String text) {
            try {
                return new double[]{Double.parseDouble(text)};
            } catch (NumberFormatException notScalar) {
                // Not a plain number; fall back to the array syntax.
                return Strings.parseDoubleArray(text);
            }
        }
    }
}

