/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import smile.base.mlp.Cost;
import smile.base.mlp.Layer;
import smile.base.mlp.LayerBuilder;
import smile.base.mlp.MultilayerPerceptron;
import smile.base.mlp.OutputFunction;
import smile.base.mlp.OutputLayer;
import smile.regression.OnlineRegression;

/**
 * Multilayer perceptron for univariate regression.
 *
 * <p>The network is made of the hidden layers produced by the supplied
 * {@link LayerBuilder}s, in the given order, followed by a single-neuron
 * output layer with a {@link OutputFunction#LINEAR linear} output function
 * and {@link Cost#MEAN_SQUARED_ERROR mean squared error} cost.
 *
 * <p>Both prediction and training call {@code propagate} and then read or
 * write per-instance fields ({@code output}, {@code target}, {@code alpha}).
 * An instance should therefore not be used from multiple threads without
 * external synchronization.
 */
public class MLP extends MultilayerPerceptron implements OnlineRegression<double[]> {
    private static final long serialVersionUID = 2L;

    /**
     * Constructs a regression network.
     *
     * @param p the dimension of the input vectors.
     * @param builders the hidden layer builders, applied in order. The first
     *                 layer takes {@code p} inputs; every later layer (and the
     *                 output layer) takes the previous builder's neuron count.
     */
    public MLP(int p, LayerBuilder ... builders) {
        super(MLP.net(p, builders));
    }

    /**
     * Builds the layer array: one layer per builder plus a trailing
     * single-neuron linear output layer with MSE cost.
     *
     * @param p the input dimension of the first layer.
     * @param builders the hidden layer builders.
     * @return an array of {@code builders.length + 1} layers.
     */
    private static Layer[] net(int p, LayerBuilder ... builders) {
        int l = builders.length;
        Layer[] net = new Layer[l + 1];
        for (int i = 0; i < l; ++i) {
            net[i] = builders[i].build(p);
            // The next layer's input size is this layer's neuron count.
            p = builders[i].neurons();
        }
        net[l] = new OutputLayer(1, p, OutputFunction.LINEAR, Cost.MEAN_SQUARED_ERROR);
        return net;
    }

    /**
     * Predicts the response for one sample.
     *
     * @param x the input vector.
     * @return the value of the single output neuron after forward propagation.
     */
    @Override
    public double predict(double[] x) {
        this.propagate(x);
        return this.output.output()[0];
    }

    /**
     * Updates the model with one training sample.
     *
     * <p>The update forward-propagates {@code x}, sets the target to {@code y},
     * back-propagates, and then applies one weight update.
     *
     * @param x the input vector.
     * @param y the response value.
     */
    @Override
    public void update(double[] x, double y) {
        this.propagate(x);
        this.target[0] = y;
        this.backpropagate();
        this.update();
    }

    /**
     * Updates the model with a mini-batch of samples.
     *
     * <p>{@code backpropagate} is called once per sample, and {@code update}
     * is called once at the end. Presumably this applies the accumulated
     * gradient of the batch.
     *
     * <p>{@code alpha} is set to 1.0 for the duration of the batch. It is
     * restored to its previous value afterwards, including when an exception
     * is thrown.
     *
     * @param x the training samples.
     * @param y the response values; must have the same length as {@code x}.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public void update(double[][] x, double[] y) {
        // Validate before touching any model state. Otherwise a short y would
        // fail midway, after some samples had already been back-propagated.
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format(
                "The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        // NOTE(review): alpha is presumably the momentum factor. Forcing it to
        // 1.0 appears intended to make the batch accumulate; confirm against
        // MultilayerPerceptron.
        double a = this.alpha;
        this.alpha = 1.0;
        try {
            for (int i = 0; i < x.length; ++i) {
                this.propagate(x[i]);
                this.target[0] = y[i];
                this.backpropagate();
            }
            this.update();
        } finally {
            // Always restore alpha so that a failed batch does not leave the
            // model in a modified training configuration.
            this.alpha = a;
        }
    }
}

