/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.util.Properties;
import smile.base.rbf.RBF;
import smile.regression.Regression;
import smile.tensor.DenseMatrix;
import smile.tensor.QR;
import smile.tensor.ScalarType;

/**
 * Radial basis function (RBF) network for regression.
 *
 * <p>The raw output for an input {@code x} is {@code sum_i w[i] * rbf[i].f(x)}. When the
 * network is normalized, that sum is divided by {@code sum_i rbf[i].f(x)}.
 *
 * @param <T> the type of the input objects.
 */
public class RBFNetwork<T> implements Regression<T> {
    private static final long serialVersionUID = 2L;
    /** Linear output weights; {@code w[i]} multiplies the output of {@code rbf[i]}. */
    private final double[] w;
    /** The basis functions, paired index-by-index with {@code w}. */
    private final RBF<T>[] rbf;
    /** If true, predictions are divided by the sum of the basis function outputs. */
    private final boolean normalized;

    /**
     * Constructor.
     *
     * @param rbf the basis functions.
     * @param w the output weights, one per basis function.
     * @param normalized true if predictions are divided by the sum of the basis outputs.
     * @throws IllegalArgumentException if {@code rbf} and {@code w} differ in length.
     * @throws NullPointerException if {@code rbf} or {@code w} is null.
     */
    public RBFNetwork(RBF<T>[] rbf, double[] w, boolean normalized) {
        // Fail fast: predict() pairs rbf[i] with w[i]. A shorter w would otherwise only
        // surface as an ArrayIndexOutOfBoundsException at prediction time, and a longer
        // w would have its extra entries silently ignored.
        if (rbf.length != w.length) {
            throw new IllegalArgumentException(String.format("The number of RBFs and weights don't match: %d != %d", rbf.length, w.length));
        }
        this.rbf = rbf;
        this.w = w;
        this.normalized = normalized;
    }

    /**
     * Fits a non-normalized RBF network with the given basis functions.
     *
     * @param x the training samples.
     * @param y the response values, one per sample.
     * @param rbf the basis functions.
     * @param <T> the type of the input objects.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public static <T> RBFNetwork<T> fit(T[] x, double[] y, RBF<T>[] rbf) {
        return RBFNetwork.fit(x, y, rbf, false);
    }

    /**
     * Fits an RBF network with the given basis functions by solving for the output
     * weights through the QR decomposition of the design matrix.
     *
     * @param x the training samples.
     * @param y the response values, one per sample.
     * @param rbf the basis functions.
     * @param normalize true to fit a normalized network.
     * @param <T> the type of the input objects.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public static <T> RBFNetwork<T> fit(T[] x, double[] y, RBF<T>[] rbf, boolean normalize) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        int n = x.length;
        int m = rbf.length;
        // Design matrix: G(i, j) = rbf[j].f(x[i]).
        DenseMatrix G = DenseMatrix.zeros(ScalarType.Float64, n, m);
        double[] b = new double[n];
        for (int i = 0; i < n; ++i) {
            double sum = 0.0;
            for (int j = 0; j < m; ++j) {
                double r = rbf[j].f(x[i]);
                G.set(i, j, r);
                sum += r;
            }
            // A normalized network must satisfy sum_j w[j] * G(i, j) / sum = y[i].
            // Multiplying through by sum keeps the system linear in w.
            b[i] = normalize ? sum * y[i] : y[i];
        }
        // NOTE(review): presumably a least-squares solve when n > m; confirm against QR.solve.
        QR qr = G.qr();
        double[] w = qr.solve(b).toArray(new double[0]);
        return new RBFNetwork<T>(rbf, w, normalize);
    }

    /**
     * Fits an RBF network on vector data, with hyperparameters read from {@code params}.
     *
     * <p>Recognized keys:
     * <ul>
     *   <li>{@code smile.rbf.neurons} (default {@code "30"}): the number passed to
     *       {@code RBF.fit(x, neurons)} to build the basis functions.</li>
     *   <li>{@code smile.rbf.normalize} (default {@code "false"}): parsed with
     *       {@link Boolean#parseBoolean(String)}.</li>
     * </ul>
     *
     * @param x the training samples.
     * @param y the response values, one per sample.
     * @param params the hyperparameters.
     * @return the fitted model.
     * @throws NumberFormatException if {@code smile.rbf.neurons} is not an integer.
     */
    public static RBFNetwork<double[]> fit(double[][] x, double[] y, Properties params) {
        int neurons = Integer.parseInt(params.getProperty("smile.rbf.neurons", "30"));
        boolean normalize = Boolean.parseBoolean(params.getProperty("smile.rbf.normalize", "false"));
        return RBFNetwork.fit(x, y, RBF.fit(x, neurons), normalize);
    }

    /**
     * Returns true if this network is normalized.
     *
     * @return true if predictions are divided by the sum of the basis outputs.
     */
    public boolean isNormalized() {
        return this.normalized;
    }

    /**
     * Predicts the response for {@code x}.
     *
     * <p>In normalized mode, if the basis outputs sum to zero the division yields
     * NaN or an infinity, following IEEE 754 floating-point rules.
     *
     * @param x the input.
     * @return the weighted sum of basis outputs, divided by their sum if normalized.
     */
    @Override
    public double predict(T x) {
        double sum = 0.0;
        double sumw = 0.0;
        for (int i = 0; i < this.rbf.length; ++i) {
            double f = this.rbf[i].f(x);
            sumw += this.w[i] * f;
            sum += f;
        }
        return this.normalized ? sumw / sum : sumw;
    }
}

