/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Arrays;
import java.util.Objects;
import java.util.Properties;
import smile.math.BFGS;
import smile.math.MathEx;
import smile.math.kernel.MercerKernel;
import smile.regression.Regression;
import smile.stat.distribution.MultivariateGaussianDistribution;
import smile.tensor.Cholesky;
import smile.tensor.DenseMatrix;
import smile.tensor.EVD;
import smile.tensor.LU;
import smile.tensor.Vector;
import smile.util.function.DifferentiableMultivariateFunction;

public class GaussianProcessRegression<T>
implements Regression<T> {
    private static final long serialVersionUID = 2L;
    /** The kernel evaluated between a query point and every regressor. */
    public final MercerKernel<T> kernel;
    /** The points on which the kernel expansion is anchored. */
    public final T[] regressors;
    /** The expansion weights; {@code w.get(i)} multiplies {@code kernel.k(x, regressors[i])} in prediction. */
    public final Vector w;
    /** The offset added to predictions; the fit methods store the training target mean here when normalizing, 0.0 otherwise. */
    public final double mean;
    /** The scale applied to predictions; the fit methods store the training target standard deviation here when normalizing, 1.0 otherwise. */
    public final double sd;
    /** The noise variance; the constructor rejects negative values. */
    public final double noise;
    /** The log marginal likelihood, or {@code NaN} when it was not computed (in which case {@code toString} omits it). */
    public final double L;
    /** The Cholesky factorization used for predictive uncertainty; {@code null} when unavailable, which disables {@code predict(x, estimation)} and {@code query}. */
    private final Cholesky cholesky;

    /**
     * Creates a model whose predictions are not rescaled: delegates with
     * a mean of 0.0 and a standard deviation of 1.0.
     *
     * @param kernel the kernel function.
     * @param regressors the points the kernel expansion is anchored on.
     * @param weight the expansion weights, one per regressor.
     * @param noise the noise variance; must not be negative.
     * @throws IllegalArgumentException if {@code noise} is negative.
     */
    public GaussianProcessRegression(
            MercerKernel<T> kernel,
            T[] regressors,
            Vector weight,
            double noise) {
        this(kernel, regressors, weight, noise, 0.0, 1.0);
    }

    /**
     * Creates a model without a Cholesky factorization and without a log
     * marginal likelihood: delegates with {@code null} and {@code NaN}.
     * Models built this way support point prediction only.
     *
     * @param kernel the kernel function.
     * @param regressors the points the kernel expansion is anchored on.
     * @param weight the expansion weights, one per regressor.
     * @param noise the noise variance; must not be negative.
     * @param mean the offset added to predictions.
     * @param sd the scale applied to predictions.
     * @throws IllegalArgumentException if {@code noise} is negative.
     */
    public GaussianProcessRegression(
            MercerKernel<T> kernel,
            T[] regressors,
            Vector weight,
            double noise,
            double mean,
            double sd) {
        this(kernel, regressors, weight, noise, mean, sd, null, Double.NaN);
    }

    /**
     * Creates a fully specified model. The arguments are stored as given;
     * only the noise variance is validated.
     *
     * @param kernel the kernel function.
     * @param regressors the points the kernel expansion is anchored on.
     * @param weight the expansion weights, one per regressor.
     * @param noise the noise variance; must not be negative.
     * @param mean the offset added to predictions.
     * @param sd the scale applied to predictions.
     * @param cholesky the Cholesky factorization used for predictive
     *                 uncertainty, or {@code null} if unavailable.
     * @param L the log marginal likelihood, or {@code NaN} if unknown.
     * @throws IllegalArgumentException if {@code noise} is negative.
     */
    public GaussianProcessRegression(
            MercerKernel<T> kernel,
            T[] regressors,
            Vector weight,
            double noise,
            double mean,
            double sd,
            Cholesky cholesky,
            double L) {
        if (noise < 0.0) {
            throw new IllegalArgumentException("Invalid noise variance: " + noise);
        }

        // Kernel expansion.
        this.kernel = kernel;
        this.regressors = regressors;
        this.w = weight;

        // Output rescaling and noise level.
        this.mean = mean;
        this.sd = sd;
        this.noise = noise;

        // Optional extras for uncertainty estimates and reporting.
        this.cholesky = cholesky;
        this.L = L;
    }

    /**
     * Predicts the response at {@code x} as the weighted sum of kernel
     * evaluations against every regressor, rescaled by {@code sd} and
     * shifted by {@code mean}.
     *
     * @param x the query point.
     * @return the predicted value.
     */
    @Override
    public double predict(T x) {
        double expansion = 0.0;
        int index = 0;
        for (T regressor : regressors) {
            expansion += w.get(index++) * kernel.k(x, regressor);
        }
        return expansion * sd + mean;
    }

    /**
     * Predicts the response at {@code x} together with its standard deviation.
     *
     * @param x the query point.
     * @param estimation an output array of length at least 2; on return
     *                   {@code estimation[0]} holds the predicted mean and
     *                   {@code estimation[1]} the predicted standard deviation,
     *                   both on the scale of the original targets.
     * @return the predicted mean, identical to {@code estimation[0]}.
     * @throws UnsupportedOperationException if the model was built without
     *         a Cholesky factorization.
     */
    public double predict(T x, double[] estimation) {
        if (this.cholesky == null) {
            throw new UnsupportedOperationException("The Cholesky decomposition of kernel matrix is not available.");
        }
        int n = this.regressors.length;
        double[] k_ = new double[n];
        for (int i = 0; i < n; ++i) {
            k_[i] = this.kernel.k(x, this.regressors[i]);
        }
        Vector k = Vector.column(k_);
        Vector Kx = this.cholesky.solve(k_);
        double mu = this.w.dot(k);

        // Mathematically the predictive variance is non-negative, but the
        // subtraction can round to a tiny negative number (for instance when
        // x coincides with a regressor), and sqrt would then return NaN.
        // Clamp at zero; a NaN variance still propagates through Math.max.
        double variance = Math.max(0.0, this.kernel.k(x, x) - Kx.dot(k));
        double sd = Math.sqrt(variance);

        mu = mu * this.sd + this.mean;
        sd *= this.sd;
        estimation[0] = mu;
        estimation[1] = sd;
        return mu;
    }

    /**
     * Evaluates the joint predictive distribution over a set of query points.
     *
     * @param samples the query points.
     * @return the predicted means, standard deviations and covariance
     *         matrix, all on the scale of the original targets.
     * @throws UnsupportedOperationException if the model was built without
     *         a Cholesky factorization.
     */
    public JointPrediction query(T[] samples) {
        if (this.cholesky == null) {
            throw new UnsupportedOperationException("The Cholesky decomposition of kernel matrix is not available.");
        }
        // Cross-kernel between the query points and the regressors.
        DenseMatrix Kt = this.kernel.K(samples, this.regressors);
        DenseMatrix Kv = Kt.transpose();
        // The return value is ignored and Kv is used below, so the solve is
        // relied on to overwrite Kv with the solution.
        this.cholesky.solve(Kv);

        // Prior covariance of the samples minus the part explained by the
        // regressors, rescaled to the variance of the original targets.
        DenseMatrix cov = this.kernel.K(samples);
        cov.sub(Kt.mm(Kv));
        cov.scale(this.sd * this.sd);

        Vector mu = Kt.mv(this.w);
        Vector std = cov.diagonal();
        int m = samples.length;
        for (int i = 0; i < m; ++i) {
            mu.set(i, mu.get(i) * this.sd + this.mean);
            // Round-off in the subtraction above can leave a diagonal entry
            // slightly negative; clamp it so the square root is not NaN.
            std.set(i, Math.sqrt(Math.max(0.0, std.get(i))));
        }
        return new JointPrediction(this, samples, mu.toArray(new double[0]), std.toArray(new double[0]), cov);
    }

    /**
     * Returns a multi-line summary of the model: kernel, number of
     * regressors, mean, standard deviation and noise, followed by the log
     * marginal likelihood when it is available (not {@code NaN}).
     *
     * @return the summary text.
     */
    @Override
    public String toString() {
        return "GaussianProcessRegression {\n"
                + "  kernel: " + this.kernel + ",\n"
                + "  regressors: " + this.regressors.length + ",\n"
                + "  mean: " + String.format("%.4f,\n", this.mean)
                + "  std.dev: " + String.format("%.4f,\n", this.sd)
                + "  noise: " + String.format("%.4f", this.noise)
                + (Double.isNaN(this.L)
                        ? ""
                        : ",\n  log marginal likelihood: " + String.format("%.4f", this.L))
                + "\n}";
    }

    /**
     * Fits a Gaussian process on numeric vectors using settings taken from
     * a property set. The kernel is parsed from
     * {@code smile.gaussian_process.kernel} (default {@code "linear"});
     * the remaining settings are read by {@link Options#of(Properties)}.
     *
     * @param x the training samples.
     * @param y the training targets.
     * @param params the hyperparameters.
     * @return the fitted model.
     */
    public static GaussianProcessRegression<double[]> fit(double[][] x, double[] y, Properties params) {
        String spec = params.getProperty("smile.gaussian_process.kernel", "linear");
        MercerKernel<double[]> kernel = MercerKernel.of(spec);
        Options options = Options.of(params);
        return GaussianProcessRegression.fit(x, y, kernel, options);
    }

    /**
     * Fits a Gaussian process that uses every training sample as a regressor.
     * <p>
     * When {@code options.normalize} is set, the targets are standardized
     * before fitting and the mean and scale are stored in the model so that
     * predictions are mapped back. When {@code options.maxIter} is positive,
     * the kernel hyperparameters and the noise variance are tuned with BFGS
     * against {@link LogMarginalLikelihood}; the noise variance is appended
     * as the last parameter and bounded to [1e-10, 1e5].
     *
     * @param x the training samples.
     * @param y the training targets; must be as long as {@code x}.
     * @param kernel the kernel function.
     * @param options the fitting options.
     * @param <T> the sample type.
     * @return the fitted model, including its Cholesky factorization and log
     *         marginal likelihood.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public static <T> GaussianProcessRegression<T> fit(T[] x, double[] y, MercerKernel<T> kernel, Options options) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        int n = x.length;
        double mean = 0.0;
        double std = 1.0;
        if (options.normalize) {
            mean = MathEx.mean(y);
            std = MathEx.stdev(y);
            // A constant target has zero spread; dividing by it would turn
            // every standardized value into NaN (0/0) and poison the model.
            // Fall back to unit scale so such data is only centered.
            if (!(std > 0.0)) {
                std = 1.0;
            }
            double[] target = new double[n];
            for (int i = 0; i < n; ++i) {
                target[i] = (y[i] - mean) / std;
            }
            y = target;
        }
        double noise = options.noise;
        if (options.maxIter > 0) {
            LogMarginalLikelihood<T> objective = new LogMarginalLikelihood<>(x, y, kernel);
            double[] hp = kernel.hyperparameters();
            double[] lo = kernel.lo();
            double[] hi = kernel.hi();
            int m = lo.length;
            // One extra slot at the end carries the noise variance.
            double[] params = Arrays.copyOf(hp, m + 1);
            double[] l = Arrays.copyOf(lo, m + 1);
            double[] u = Arrays.copyOf(hi, m + 1);
            params[m] = noise;
            l[m] = 1.0E-10;
            u[m] = 100000.0;
            // params is updated by the optimizer and read back below.
            // NOTE(review): the literal 5 is presumably the L-BFGS history
            // size - confirm against BFGS.minimize.
            BFGS.minimize(objective, 5, params, l, u, options.tol, options.maxIter);
            kernel = kernel.of(params);
            noise = params[params.length - 1];
        }
        DenseMatrix K = kernel.K(x);
        for (int i = 0; i < n; ++i) {
            K.add(i, i, noise);
        }
        Cholesky cholesky = K.cholesky();
        Vector w = cholesky.solve(y);
        double logLikelihood = -0.5 * (w.dot(Vector.column(y)) + cholesky.logdet() + (double) n * Math.log(Math.PI * 2));
        return new GaussianProcessRegression<T>(kernel, x, w, noise, mean, std, cholesky, logLikelihood);
    }

    /**
     * Fits an approximate Gaussian process whose kernel expansion is
     * anchored on the inducing points {@code t} rather than on every
     * training sample. The weights solve
     * {@code (G'G + noise * K(t)) w = G'y}, where {@code G = K(x, t)}.
     * The returned model carries no Cholesky factorization, so it supports
     * point prediction only.
     *
     * @param x the training samples.
     * @param y the training targets; must be as long as {@code x}.
     * @param t the inducing points used as regressors.
     * @param kernel the kernel function.
     * @param options the fitting options; {@code tol} and {@code maxIter}
     *                are not used by this method.
     * @param <T> the sample type.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public static <T> GaussianProcessRegression<T> fit(T[] x, double[] y, T[] t, MercerKernel<T> kernel, Options options) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        double mean = 0.0;
        double std = 1.0;
        if (options.normalize) {
            mean = MathEx.mean(y);
            std = MathEx.stdev(y);
            // A constant target has zero spread; dividing by it would turn
            // every standardized value into NaN (0/0) and poison the model.
            // Fall back to unit scale so such data is only centered.
            if (!(std > 0.0)) {
                std = 1.0;
            }
            int n = x.length;
            double[] target = new double[n];
            for (int i = 0; i < n; ++i) {
                target[i] = (y[i] - mean) / std;
            }
            y = target;
        }
        double noise = options.noise;
        DenseMatrix G = kernel.K(x, t);
        DenseMatrix K = G.ata();
        DenseMatrix Kt = kernel.K(t);
        // NOTE(review): presumably K += noise * Kt - confirm against DenseMatrix.axpy.
        K.axpy(noise, Kt);
        LU lu = K.lu();
        Vector w = G.tv(y);
        // The return value is ignored and w is handed to the model, so the
        // solve is relied on to overwrite w with the solution.
        lu.solve((DenseMatrix) w);
        return new GaussianProcessRegression<T>(kernel, t, w, noise, mean, std);
    }

    /**
     * Fits an approximate Gaussian process with a low-rank approximation of
     * the kernel matrix built from the inducing points {@code t}.
     * <p>
     * With {@code L = K(x, t) * K(t)^(-1/2)}, the kernel matrix is replaced
     * by {@code L L'} and the weights are computed as
     * {@code (y - L (L'L + noise I)^-1 L'y) / noise}. All training samples
     * remain the regressors of the returned model, which carries no
     * Cholesky factorization and so supports point prediction only.
     *
     * @param x the training samples.
     * @param y the training targets; must be as long as {@code x}.
     * @param t the inducing points.
     * @param kernel the kernel function.
     * @param options the fitting options; the noise variance must be
     *                strictly positive, and {@code tol} and {@code maxIter}
     *                are not used by this method.
     * @param <T> the sample type.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in
     *         length, or if the noise variance is not positive.
     */
    public static <T> GaussianProcessRegression<T> nystrom(T[] x, double[] y, T[] t, MercerKernel<T> kernel, Options options) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        double noise = options.noise;
        // The weights below are divided by the noise variance. Options
        // permits zero, which would yield infinite or NaN weights, so
        // reject it here before doing any work.
        if (!(noise > 0.0)) {
            throw new IllegalArgumentException("Nystrom approximation requires a positive noise variance: " + noise);
        }
        int n = x.length;
        int m = t.length;
        double mean = 0.0;
        double std = 1.0;
        if (options.normalize) {
            mean = MathEx.mean(y);
            std = MathEx.stdev(y);
            // A constant target has zero spread; dividing by it would turn
            // every standardized value into NaN (0/0) and poison the model.
            // Fall back to unit scale so such data is only centered.
            if (!(std > 0.0)) {
                std = 1.0;
            }
            double[] target = new double[n];
            for (int i = 0; i < n; ++i) {
                target[i] = (y[i] - mean) / std;
            }
            y = target;
        }
        DenseMatrix E = kernel.K(x, t);
        DenseMatrix W = kernel.K(t);
        EVD eigen = W.eigen().sort();
        DenseMatrix U = eigen.Vr();
        DenseMatrix D = eigen.diag();
        // Turn the eigenvalues into their inverse square roots.
        // NOTE(review): assumes every eigenvalue of K(t) is strictly
        // positive; a zero or negative one produces Inf/NaN here.
        for (int i = 0; i < m; ++i) {
            D.set(i, i, 1.0 / Math.sqrt(D.get(i, i)));
        }
        DenseMatrix UD = U.mm(D);
        DenseMatrix UDUt = UD.mt(U);
        DenseMatrix L = E.mm(UDUt);
        DenseMatrix LtL = L.ata();
        for (int i = 0; i < LtL.nrow(); ++i) {
            LtL.add(i, i, noise);
        }
        Cholesky chol = LtL.cholesky();
        DenseMatrix invLtL = chol.inverse();
        DenseMatrix Kinv = L.mm(invLtL).mt(L);
        Vector w = Kinv.tv(y);
        for (int i = 0; i < n; ++i) {
            w.set(i, (y[i] - w.get(i)) / noise);
        }
        return new GaussianProcessRegression<T>(kernel, x, w, noise, mean, std);
    }

    /**
     * The joint predictive distribution over a set of query points, as
     * produced by {@code query}.
     */
    public class JointPrediction {
        /** The query points. */
        public final T[] x;
        /** The predicted mean at each query point. */
        public final double[] mu;
        /** The predicted standard deviation at each query point. */
        public final double[] sd;
        /** The predicted covariance matrix of the query points. */
        public final DenseMatrix cov;
        /** The sampling distribution, created on the first call to {@code sample}. */
        private MultivariateGaussianDistribution dist;

        /**
         * Stores the given prediction arrays without copying them.
         *
         * @param this$0 the enclosing model, as exposed by the decompiler;
         *               it is only checked for {@code null}.
         * @param x the query points.
         * @param mu the predicted means.
         * @param sd the predicted standard deviations.
         * @param cov the predicted covariance matrix.
         */
        public JointPrediction(GaussianProcessRegression this$0, T[] x, double[] mu, double[] sd, DenseMatrix cov) {
            Objects.requireNonNull(this$0);
            this.cov = cov;
            this.sd = sd;
            this.mu = mu;
            this.x = x;
        }

        /**
         * Draws random vectors from the multivariate Gaussian defined by
         * {@code mu} and {@code cov}.
         *
         * @param n the number of draws.
         * @return the draws, as returned by the distribution's {@code rand(n)}.
         */
        public double[][] sample(int n) {
            MultivariateGaussianDistribution joint = this.dist;
            if (joint == null) {
                // Built on first use and kept for subsequent calls.
                joint = new MultivariateGaussianDistribution(this.mu, this.cov);
                this.dist = joint;
            }
            return joint.rand(n);
        }

        /**
         * Formats the means, standard deviations and covariance matrix.
         *
         * @return the formatted text.
         */
        public String toString() {
            String means = Arrays.toString(this.mu);
            String deviations = Arrays.toString(this.sd);
            String covariance = this.cov.toString(true);
            return String.format("GaussianProcessRegression.Prediction {\n  mean    = %s\n  std.dev = %s\n  cov     = %s\n}", means, deviations, covariance);
        }
    }

    /**
     * Gaussian process fitting options.
     *
     * @param noise the noise variance; must be a non-negative number.
     * @param normalize whether the targets are standardized before fitting.
     * @param tol the stopping tolerance of hyperparameter optimization;
     *            must be a positive number.
     * @param maxIter the maximum number of optimization iterations;
     *                hyperparameters are optimized only when it is positive.
     */
    public record Options(double noise, boolean normalize, double tol, int maxIter) {
        private static final String NOISE_KEY = "smile.gaussian_process.noise";
        private static final String NORMALIZE_KEY = "smile.gaussian_process.normalize";
        private static final String TOLERANCE_KEY = "smile.gaussian_process.tolerance";
        private static final String ITERATIONS_KEY = "smile.gaussian_process.iterations";

        /**
         * Validates the noise variance and the tolerance.
         *
         * @throws IllegalArgumentException if {@code noise} is negative or
         *         NaN, or if {@code tol} is not positive or is NaN.
         */
        public Options {
            // The checks are written as negated comparisons on purpose:
            // every ordered comparison with NaN is false, so "noise < 0.0"
            // and "tol <= 0.0" would let NaN slip through.
            if (!(noise >= 0.0)) {
                throw new IllegalArgumentException("Invalid noise variance = " + noise);
            }
            if (!(tol > 0.0)) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
        }

        /**
         * Creates options with target normalization enabled.
         *
         * @param noise the noise variance.
         */
        public Options(double noise) {
            this(noise, true);
        }

        /**
         * Creates options with a tolerance of 1E-5 and no hyperparameter
         * optimization.
         *
         * @param noise the noise variance.
         * @param normalize whether the targets are standardized before fitting.
         */
        public Options(double noise, boolean normalize) {
            this(noise, normalize, 1.0E-5, 0);
        }

        /**
         * Writes the options under the {@code smile.gaussian_process.*} keys.
         *
         * @return a new property set holding the four options.
         */
        public Properties toProperties() {
            Properties props = new Properties();
            props.setProperty(NOISE_KEY, Double.toString(this.noise));
            props.setProperty(NORMALIZE_KEY, Boolean.toString(this.normalize));
            props.setProperty(TOLERANCE_KEY, Double.toString(this.tol));
            props.setProperty(ITERATIONS_KEY, Integer.toString(this.maxIter));
            return props;
        }

        /**
         * Reads the options from the {@code smile.gaussian_process.*} keys.
         * Missing keys default to a noise of 1E-10, normalization on, a
         * tolerance of 1E-5 and 0 iterations.
         *
         * @param props the property set.
         * @return the options.
         * @throws NumberFormatException if a numeric property cannot be parsed.
         * @throws IllegalArgumentException if a parsed value fails validation.
         */
        public static Options of(Properties props) {
            double noise = Double.parseDouble(props.getProperty(NOISE_KEY, "1E-10"));
            boolean normalize = Boolean.parseBoolean(props.getProperty(NORMALIZE_KEY, "true"));
            double tol = Double.parseDouble(props.getProperty(TOLERANCE_KEY, "1E-5"));
            int maxIter = Integer.parseInt(props.getProperty(ITERATIONS_KEY, "0"));
            return new Options(noise, normalize, tol, maxIter);
        }
    }

    /**
     * The negated log marginal likelihood of the training data, as a
     * function of a parameter vector whose last entry is the noise
     * variance. The vector is also handed whole to {@code kernel.of}.
     * Each evaluation replaces {@code kernel} with the re-parameterized kernel.
     */
    private static class LogMarginalLikelihood<T>
    implements DifferentiableMultivariateFunction {
        /** The training samples. */
        final T[] x;
        /** The training targets. */
        final double[] y;
        /** The kernel from the most recent evaluation. */
        MercerKernel<T> kernel;

        /**
         * @param x the training samples.
         * @param y the training targets.
         * @param kernel the initial kernel.
         */
        public LogMarginalLikelihood(T[] x, double[] y, MercerKernel<T> kernel) {
            this.x = x;
            this.y = y;
            this.kernel = kernel;
        }

        /**
         * Evaluates the negated log marginal likelihood.
         *
         * @param params the parameter vector, noise variance last.
         * @return the negated log marginal likelihood.
         */
        public double f(double[] params) {
            this.kernel = this.kernel.of(params);
            double noise = params[params.length - 1];
            DenseMatrix gram = this.kernel.K(this.x);
            addToDiagonal(gram, noise);
            Cholesky factor = gram.cholesky();
            Vector alpha = factor.solve(this.y);
            return negatedLogLikelihood(alpha, factor);
        }

        /**
         * Evaluates the negated log marginal likelihood and its gradient.
         *
         * @param params the parameter vector, noise variance last.
         * @param g the output gradient; the last entry receives the
         *          derivative for the noise variance, the preceding entries
         *          those derived from the matrices returned by {@code KG}.
         * @return the negated log marginal likelihood.
         */
        public double g(double[] params, double[] g) {
            this.kernel = this.kernel.of(params);
            double noise = params[params.length - 1];
            // NOTE(review): KG presumably returns the kernel matrix at index 0
            // followed by one derivative matrix per hyperparameter - confirm.
            DenseMatrix[] grams = this.kernel.KG(this.x);
            DenseMatrix Ky = grams[0];
            addToDiagonal(Ky, noise);
            Cholesky factor = Ky.cholesky();
            DenseMatrix Kinv = factor.inverse();
            Vector alpha = Kinv.mv(this.y);

            int last = g.length - 1;
            g[last] = -(alpha.dot(alpha) - Kinv.trace()) / 2.0;
            for (int j = 0; j < last; ++j) {
                DenseMatrix dK = grams[j + 1];
                double slope = dK.xAx(alpha) - Kinv.mm(dK).trace();
                g[j] = -slope / 2.0;
            }
            return negatedLogLikelihood(alpha, factor);
        }

        /** Adds {@code value} to every diagonal entry of {@code matrix}, row by row. */
        private static void addToDiagonal(DenseMatrix matrix, double value) {
            for (int i = 0; i < matrix.nrow(); ++i) {
                matrix.add(i, i, value);
            }
        }

        /** Combines the data-fit term, the log determinant and the constant term, then negates. */
        private double negatedLogLikelihood(Vector alpha, Cholesky factor) {
            int n = this.x.length;
            double logLikelihood = -0.5 * (alpha.dot(Vector.column(this.y)) + factor.logdet() + (double) n * Math.log(Math.PI * 2));
            return -logLikelihood;
        }
    }
}

