/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Arrays;
import java.util.Properties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.linalg.Transpose;
import smile.math.MathEx;
import smile.regression.LinearModel;
import smile.tensor.BiconjugateGradient;
import smile.tensor.DenseMatrix;
import smile.tensor.Matrix;
import smile.tensor.Preconditioner;
import smile.tensor.ScalarType;
import smile.tensor.Vector;

/**
 * LASSO (least absolute shrinkage and selection operator) regression.
 * <p>
 * The coefficients minimize {@code ||X w - y||^2 + lambda * ||w||_1}. Training
 * standardizes the predictors and centers the response. It then runs a
 * log-barrier interior-point method. Each Newton step is solved approximately
 * with a preconditioned biconjugate-gradient solver and is followed by a
 * backtracking line search. Iteration stops once the relative duality gap falls
 * below the requested tolerance.
 * <p>
 * NOTE(review): the scheme appears to follow the l1_ls method of Kim, Koh,
 * Lustig, Boyd and Gorinevsky (2007); confirm against that reference when
 * changing the solver.
 */
public class LASSO {
    private static final Logger logger = LoggerFactory.getLogger(LASSO.class);

    /** Static utility class; not instantiable. */
    private LASSO() {
    }

    /**
     * Fits a LASSO model with the default options ({@code lambda = 1.0}).
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @return the fitted linear model.
     * @throws IllegalArgumentException if a predictor column is constant.
     */
    public static LinearModel fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Options(1.0));
    }

    /**
     * Fits a LASSO model.
     * <p>
     * The solver works on standardized predictors and a centered response.
     * The returned coefficients are rescaled to the original predictors. The
     * intercept is not estimated by the solver; it is recovered afterwards from
     * the column means and the response mean.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param options the hyperparameters.
     * @return the fitted linear model.
     * @throws IllegalArgumentException if a predictor column is constant.
     */
    public static LinearModel fit(Formula formula, DataFrame data, Options options) {
        formula = formula.expand(data.schema());
        StructType schema = formula.bind(data.schema());
        DenseMatrix X = formula.matrix(data, false);
        double[] y = formula.y(data).toDoubleArray();

        int n = X.nrow();
        int p = X.ncol();
        Vector center = X.colMeans();
        Vector scale = X.colSds();

        // A zero standard deviation cannot be divided out during standardization.
        for (int j = 0; j < p; j++) {
            if (MathEx.isZero((double) scale.get(j))) {
                throw new IllegalArgumentException(String.format("The column '%s' is constant", schema.names()[j]));
            }
        }

        DenseMatrix scaledX = X.standardize(center, scale);
        double ymu = MathEx.mean(y);
        double[] centeredY = new double[n];
        for (int i = 0; i < n; i++) {
            centeredY[i] = y[i] - ymu;
        }

        Vector w = train(scaledX, centeredY, options);

        // Map the coefficients back to the scale of the original predictors.
        for (int j = 0; j < p; j++) {
            w.div(j, scale.get(j));
        }
        double b = ymu - w.dot(center);
        return new LinearModel(formula, schema, X, y, w, b);
    }

    /**
     * Solves the LASSO problem {@code min ||x w - y||^2 + lambda * ||w||_1}
     * with a log-barrier interior-point method.
     * <p>
     * The problem is handled in the equivalent form with bounds
     * {@code |w_i| <= u_i}. The barrier keeps
     * {@code f = [w - u; -w - u]} strictly negative.
     *
     * @param x the design matrix (presumably already standardized, no intercept column).
     * @param y the response (presumably already centered).
     * @param options the hyperparameters.
     * @return the coefficient vector of length {@code x.ncol()}.
     */
    static Vector train(DenseMatrix x, double[] y, Options options) {
        // Factor by which the barrier parameter t may grow per outer iteration.
        final int mu = 2;
        double tol = options.tol();
        double lambda = options.lambda();
        int maxIter = options.maxIter();
        double alpha = options.alpha();
        double beta = options.beta();
        double eta = options.eta();
        int lsMaxIter = options.lsMaxIter();
        int pcgMaxIter = options.pcgMaxIter();

        int n = x.nrow();
        int p = x.ncol();

        // Barrier parameter, best dual objective so far, and last line-search step length.
        double t = Math.min(Math.max(1.0, 1.0 / lambda), (2.0 * p) / 1e-3);
        double dobj = Double.NEGATIVE_INFINITY;
        double s = Double.POSITIVE_INFINITY;
        // Set to pcgMaxIter once a PCG solve misses its tolerance.
        int pitr = 0;

        // Primal variables: coefficients w, bounds u, residual z = xw - y,
        // and the barrier constraints f = [w - u; -w - u].
        double[] w = new double[p];
        double[] u = new double[p];
        double[] z = new double[n];
        double[][] f = new double[2][p];
        Arrays.fill(u, 1.0);
        for (int i = 0; i < p; i++) {
            f[0][i] = w[i] - u[i];
            f[1][i] = -w[i] - u[i];
        }

        // Line-search candidates.
        double[] neww = new double[p];
        double[] newu = new double[p];
        double[] newz = new double[n];
        double[][] newf = new double[2][p];

        // Newton step [dx; du] and the right-hand side grad = -gradphi.
        double[] dx = new double[p];
        double[] du = new double[p];
        double[] dxu = new double[2 * p];
        double[] grad = new double[2 * p];

        // Diagonal term used by the preconditioner; fixed at 2 here.
        double[] diagxtx = new double[p];
        Arrays.fill(diagxtx, 2.0);

        double[] nu = new double[n];
        double[] xnu = new double[p];
        double[] q1 = new double[p];
        double[] q2 = new double[p];
        double[] d1 = new double[p];
        double[] d2 = new double[p];
        double[][] gradphi = new double[2][p];
        double[] prb = new double[p];
        double[] prs = new double[p];

        // The Hessian operator reads d1, d2, prb and prs by reference.
        // Overwriting those arrays below updates it in place.
        PCG pcg = new PCG(x, d1, d2, prb, prs);

        // The code relies on Vector.column wrapping these arrays without copying.
        // Matrix-vector products write straight into the backing arrays, and the
        // returned w_ reflects every update made to w.
        Vector w_ = Vector.column(w);
        Vector z_ = Vector.column(z);
        Vector nu_ = Vector.column(nu);
        Vector xnu_ = Vector.column(xnu);
        Vector neww_ = Vector.column(neww);
        Vector newz_ = Vector.column(newz);
        Vector grad_ = Vector.column(grad);
        Vector dxu_ = Vector.column(dxu);
        Vector gradphi0 = Vector.column(gradphi[0]);

        int iter;
        for (iter = 1; iter <= maxIter; iter++) {
            // Residual z = xw - y and the candidate dual point nu = 2z.
            x.mv(w_, z_);
            for (int i = 0; i < n; i++) {
                z[i] -= y[i];
                nu[i] = 2.0 * z[i];
            }

            // Scale nu so that ||x'nu||_inf <= lambda (dual feasibility).
            x.tv(nu_, xnu_);
            double maxXnu = xnu_.normInf();
            if (maxXnu > lambda) {
                double lnu = lambda / maxXnu;
                for (int i = 0; i < n; i++) {
                    nu[i] *= lnu;
                }
            }

            // Primal objective and the best dual objective found so far.
            double pobj = MathEx.dot(z, z) + lambda * MathEx.norm1(w);
            dobj = Math.max(-0.25 * MathEx.dot(nu, nu) - MathEx.dot(nu, y), dobj);
            double gap = pobj - dobj;

            // Relative duality gap test. A non-positive gap already certifies
            // optimality. It is checked separately because gap == dobj == 0
            // (e.g. an all-zero response) makes gap / dobj NaN, which would
            // never pass the test.
            boolean converged = gap / dobj < tol || gap <= 0.0;
            if (iter % 10 == 0 || converged) {
                logger.info("Iteration {}: primal objective = {}, dual objective = {}", iter, pobj, dobj);
            }
            if (converged) {
                break;
            }

            // Increase the barrier parameter only after a sufficiently long step.
            if (s >= 0.5) {
                t = Math.max(Math.min((double) (2 * p * mu) / gap, mu * t), t);
            }

            // Barrier terms q1 = 1/(u + w), q2 = 1/(u - w) and the Hessian blocks d1, d2.
            for (int i = 0; i < p; i++) {
                double q1i = 1.0 / (u[i] + w[i]);
                double q2i = 1.0 / (u[i] - w[i]);
                q1[i] = q1i;
                q2[i] = q2i;
                d1[i] = (q1i * q1i + q2i * q2i) / t;
                d2[i] = (q1i * q1i - q2i * q2i) / t;
            }

            // gradphi = [2x'z - (q1 - q2)/t; lambda - (q1 + q2)/t]; grad = -gradphi.
            x.tv(z_, gradphi0);
            for (int i = 0; i < p; i++) {
                gradphi[0][i] = 2.0 * gradphi[0][i] - (q1[i] - q2[i]) / t;
                gradphi[1][i] = lambda - (q1[i] + q2[i]) / t;
                grad[i] = -gradphi[0][i];
                grad[i + p] = -gradphi[1][i];
            }

            // Preconditioner coefficients: P = [prb, d2; d2, d1], det = prs.
            for (int i = 0; i < p; i++) {
                prb[i] = diagxtx[i] + d1[i];
                prs[i] = prb[i] * d1[i] - d2[i] * d2[i];
            }

            // Newton step: solve H dxu = grad to a tolerance tied to the duality gap.
            double normg = MathEx.norm(grad);
            double pcgtol = Math.min(0.1, eta * gap / Math.min(1.0, normg));
            // iter >= 1 here, so only pitr decides the tightening.
            // NOTE(review): pitr is never reset to 0 once a solve fails, so the
            // tightening stops permanently after the first PCG miss.
            if (pitr == 0) {
                pcgtol *= 0.1;
            }
            double error = BiconjugateGradient.solve(pcg, grad_, dxu_, pcg, pcgtol, 1, pcgMaxIter);
            if (error > pcgtol) {
                pitr = pcgMaxIter;
            }
            System.arraycopy(dxu, 0, dx, 0, p);
            System.arraycopy(dxu, p, du, 0, p);

            // Backtracking line search on the barrier objective phi.
            double phi = MathEx.dot(z, z) + lambda * MathEx.sum(u) - sumlogneg(f) / t;
            s = 1.0;
            // Directional derivative gradphi'dxu. grad holds -gradphi, so negate
            // it; the result is negative for a descent step, as the
            // sufficient-decrease test below requires.
            double gdx = -MathEx.dot(grad, dxu);
            int lsiter;
            for (lsiter = 0; lsiter < lsMaxIter; lsiter++) {
                for (int i = 0; i < p; i++) {
                    neww[i] = w[i] + s * dx[i];
                    newu[i] = u[i] + s * du[i];
                    newf[0][i] = neww[i] - newu[i];
                    newf[1][i] = -neww[i] - newu[i];
                }
                // phi is only defined where every barrier constraint is strictly negative.
                if (MathEx.max(newf) < 0.0) {
                    x.mv(neww_, newz_);
                    for (int i = 0; i < n; i++) {
                        newz[i] -= y[i];
                    }
                    double newphi = MathEx.dot(newz, newz) + lambda * MathEx.sum(newu) - sumlogneg(newf) / t;
                    // Sufficient decrease: phi(new) <= phi + alpha * s * gradphi'dxu.
                    if (newphi - phi <= alpha * s * gdx) {
                        break;
                    }
                }
                s *= beta;
            }

            if (lsiter == lsMaxIter) {
                logger.warn("Linear search reaches maximum number of iterations: {}", lsMaxIter);
                break;
            }

            System.arraycopy(neww, 0, w, 0, p);
            System.arraycopy(newu, 0, u, 0, p);
            System.arraycopy(newf[0], 0, f[0], 0, p);
            System.arraycopy(newf[1], 0, f[1], 0, p);
        }

        // The loop variable only exceeds maxIter when every iteration ran without a break.
        if (iter > maxIter) {
            logger.warn("IPM reaches maximum number of iterations: {}", maxIter);
        }
        return w_;
    }

    /**
     * Returns the sum of {@code log(-v)} over every entry {@code v} of {@code f}.
     * Entries are expected to be strictly negative; otherwise the result is NaN
     * or infinite.
     */
    private static double sumlogneg(double[][] f) {
        double sum = 0.0;
        for (double[] row : f) {
            for (double v : row) {
                sum += Math.log(-v);
            }
        }
        return sum;
    }

    /**
     * LASSO hyperparameters.
     *
     * @param lambda the shrinkage/regularization parameter (non-negative).
     * @param tol the tolerance on the relative duality gap used to stop the solver.
     * @param maxIter the maximum number of interior-point iterations.
     * @param alpha the fraction of the predicted decrease of the barrier objective
     *              that the backtracking line search requires.
     * @param beta the step-size reduction factor of the backtracking line search, in (0, 1).
     * @param eta the factor relating the PCG tolerance to the duality gap.
     * @param lsMaxIter the maximum number of line search iterations.
     * @param pcgMaxIter the maximum number of PCG iterations.
     */
    public record Options(double lambda, double tol, int maxIter, double alpha, double beta, double eta, int lsMaxIter, int pcgMaxIter) {
        /** Validates the hyperparameters. */
        public Options {
            if (lambda < 0.0) {
                throw new IllegalArgumentException("Invalid shrinkage/regularization parameter lambda = " + lambda);
            }
            if (tol <= 0.0) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
            if (maxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }
            if (alpha <= 0.0) {
                throw new IllegalArgumentException("Invalid alpha: " + alpha);
            }
            // beta >= 1 would never shrink the step, so backtracking could not make progress.
            if (beta <= 0.0 || beta >= 1.0) {
                throw new IllegalArgumentException("Invalid beta: " + beta);
            }
            if (eta <= 0.0) {
                throw new IllegalArgumentException("Invalid eta: " + eta);
            }
            if (lsMaxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of line search iterations: " + lsMaxIter);
            }
            if (pcgMaxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of PCG iterations: " + pcgMaxIter);
            }
        }

        /**
         * Constructor with default tolerance (1E-4) and iteration limit (1000).
         *
         * @param lambda the shrinkage/regularization parameter.
         */
        public Options(double lambda) {
            this(lambda, 1.0E-4, 1000);
        }

        /**
         * Constructor with default line-search and PCG settings.
         * The defaults are alpha = 0.01, beta = 0.5, eta = 0.001,
         * lsMaxIter = 100 and pcgMaxIter = 5000.
         *
         * @param lambda the shrinkage/regularization parameter.
         * @param tol the tolerance on the relative duality gap.
         * @param maxIter the maximum number of interior-point iterations.
         */
        public Options(double lambda, double tol, int maxIter) {
            this(lambda, tol, maxIter, 0.01, 0.5, 0.001, 100, 5000);
        }

        /**
         * Returns the persistent set of hyperparameters.
         *
         * @return the properties keyed by {@code smile.lasso.*}.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty("smile.lasso.lambda", Double.toString(lambda));
            props.setProperty("smile.lasso.tolerance", Double.toString(tol));
            props.setProperty("smile.lasso.iterations", Integer.toString(maxIter));
            props.setProperty("smile.lasso.alpha", Double.toString(alpha));
            props.setProperty("smile.lasso.beta", Double.toString(beta));
            props.setProperty("smile.lasso.eta", Double.toString(eta));
            props.setProperty("smile.lasso.line_search_iterations", Integer.toString(lsMaxIter));
            props.setProperty("smile.lasso.pcg_iterations", Integer.toString(pcgMaxIter));
            return props;
        }

        /**
         * Returns the options from properties; missing keys fall back to the defaults.
         *
         * @param props the hyperparameters keyed by {@code smile.lasso.*}.
         * @return the options.
         * @throws NumberFormatException if a property value cannot be parsed.
         * @throws IllegalArgumentException if a value is out of range.
         */
        public static Options of(Properties props) {
            double lambda = Double.parseDouble(props.getProperty("smile.lasso.lambda", "1"));
            double tol = Double.parseDouble(props.getProperty("smile.lasso.tolerance", "1E-4"));
            int maxIter = Integer.parseInt(props.getProperty("smile.lasso.iterations", "1000"));
            double alpha = Double.parseDouble(props.getProperty("smile.lasso.alpha", "0.01"));
            double beta = Double.parseDouble(props.getProperty("smile.lasso.beta", "0.5"));
            double eta = Double.parseDouble(props.getProperty("smile.lasso.eta", "1E-3"));
            int lsMaxIter = Integer.parseInt(props.getProperty("smile.lasso.line_search_iterations", "100"));
            int pcgMaxIter = Integer.parseInt(props.getProperty("smile.lasso.pcg_iterations", "5000"));
            return new Options(lambda, tol, maxIter, alpha, beta, eta, lsMaxIter, pcgMaxIter);
        }
    }

    /**
     * Hessian operator of the barrier objective, together with a preconditioner.
     * <p>
     * The Hessian is the 2p-by-2p operator
     * {@code H = [2A'A + D1, D2; D2, D1]}. The preconditioner is
     * {@code P = [prb, D2; D2, D1]}, whose four p-by-p blocks are diagonal and
     * which is inverted in closed form with determinant {@code prs}.
     * <p>
     * The arrays d1, d2, prb and prs are shared with the caller and read on
     * every call. The caller therefore updates the operator by overwriting
     * them in place.
     */
    static class PCG
    implements Matrix,
    Preconditioner {
        /** The Gram matrix A'A is precomputed only when A has fewer columns than this. */
        private static final int MAX_GRAM_COLUMNS = 10000;

        final DenseMatrix A;
        /** Precomputed A'A, or null when A has too many columns. */
        Matrix AtA;
        final int p;
        final double[] d1;
        final double[] d2;
        final double[] prb;
        final double[] prs;
        /** Workspace for A x (length n). */
        final Vector ax;
        /** Workspace for A'A x (length p). */
        final Vector atax;

        PCG(DenseMatrix A, double[] d1, double[] d2, double[] prb, double[] prs) {
            this.A = A;
            this.d1 = d1;
            this.d2 = d2;
            this.prb = prb;
            this.prs = prs;
            int n = A.nrow();
            this.p = A.ncol();
            this.ax = A.vector(n);
            this.atax = A.vector(this.p);
            if (A.ncol() < MAX_GRAM_COLUMNS) {
                this.AtA = A.ata();
            }
        }

        /** Returns 2p, the size of the stacked [w; u] variable. */
        public int nrow() {
            return 2 * this.p;
        }

        /** Returns 2p, the size of the stacked [w; u] variable. */
        public int ncol() {
            return 2 * this.p;
        }

        /**
         * Returns the length of the wrapped design matrix.
         * NOTE(review): this is not (2p)^2; confirm whether any caller relies on it.
         */
        public long length() {
            return this.A.length();
        }

        public ScalarType scalarType() {
            return this.A.scalarType();
        }

        /**
         * Computes {@code y = H x}.
         * NOTE(review): x has length 2p. The A'A (or A) product presumably reads
         * only its first p entries; confirm this for the underlying mv.
         */
        public void mv(Vector x, Vector y) {
            if (this.AtA != null) {
                this.AtA.mv(x, this.atax);
            } else {
                this.A.mv(x, this.ax);
                this.A.tv(this.ax, this.atax);
            }
            for (int i = 0; i < this.p; i++) {
                y.set(i, 2.0 * this.atax.get(i) + this.d1[i] * x.get(i) + this.d2[i] * x.get(i + this.p));
                y.set(i + this.p, this.d2[i] * x.get(i) + this.d1[i] * x.get(i + this.p));
            }
        }

        /** H is symmetric, so the transpose product equals {@link #mv(Vector, Vector)}. */
        public void tv(Vector x, Vector y) {
            this.mv(x, y);
        }

        /**
         * Applies the preconditioner inverse, {@code x = P^-1 b}.
         * The closed form is {@code P^-1 = [d1, -d2; -d2, prb] / prs}.
         */
        public void solve(Vector b, Vector x) {
            for (int i = 0; i < this.p; i++) {
                x.set(i, (this.d1[i] * b.get(i) - this.d2[i] * b.get(i + this.p)) / this.prs[i]);
                x.set(i + this.p, (-this.d2[i] * b.get(i) + this.prb[i] * b.get(i + this.p)) / this.prs[i]);
            }
        }

        public void mv(Transpose trans, double alpha, Vector x, double beta, Vector y) {
            throw new UnsupportedOperationException();
        }

        public void tv(Vector work, int inputOffset, int outputOffset) {
            throw new UnsupportedOperationException();
        }

        public double get(int i, int j) {
            throw new UnsupportedOperationException();
        }

        public void set(int i, int j, double x) {
            throw new UnsupportedOperationException();
        }

        public void add(int i, int j, double x) {
            throw new UnsupportedOperationException();
        }

        public void sub(int i, int j, double x) {
            throw new UnsupportedOperationException();
        }

        public void mul(int i, int j, double x) {
            throw new UnsupportedOperationException();
        }

        public void div(int i, int j, double x) {
            throw new UnsupportedOperationException();
        }

        public Matrix scale(double alpha) {
            throw new UnsupportedOperationException();
        }

        public Matrix copy() {
            throw new UnsupportedOperationException();
        }

        public Matrix transpose() {
            throw new UnsupportedOperationException();
        }
    }
}

