/*
 * Decompiled with CFR 0.152.
 */
package smile.base.mlp;

import java.io.Serializable;
import java.util.Arrays;
import java.util.stream.Collectors;
import smile.base.mlp.Layer;
import smile.base.mlp.OutputLayer;

/**
 * Base class of multilayer perceptron networks.
 * <p>
 * A network is a chain of hidden layers ({@link #net}) topped by a single
 * {@link OutputLayer} ({@link #output}). The protected {@link #propagate},
 * {@link #backpropagate} and {@link #update} steps are intended for
 * subclasses, and use the learning rate, momentum and weight decay configured
 * through the setters of this class.
 * <p>
 * NOTE: {@link #propagate} overwrites the shared {@link #x1} buffer, so
 * concurrent calls on the same instance would interfere with each other.
 */
public abstract class MultilayerPerceptron
implements Serializable {
    private static final long serialVersionUID = 2L;
    /** The dimensionality of the input vectors (input size of the first layer). */
    protected int p;
    /**
     * Input buffer of length {@code p + 1}. The last slot is fixed to 1.0
     * (presumably the bias input); the first {@code p} slots are overwritten
     * by {@link #propagate}.
     */
    protected double[] x1;
    /** The output layer (the last layer passed to the constructor). */
    protected OutputLayer output;
    /** The hidden layers, from bottom to top. */
    protected Layer[] net;
    /**
     * Target buffer sized to the output layer's output size, consumed by
     * {@link #backpropagate}. NOTE(review): presumably filled by subclasses
     * before each backpropagation step; verify against callers.
     */
    protected double[] target;
    /** The learning rate. */
    protected double eta = 0.1;
    /** The momentum factor. */
    protected double alpha = 0.0;
    /** The weight decay factor. */
    protected double lambda = 0.0;

    /**
     * Constructor.
     *
     * @param net the layers from bottom to top. The last one must be an
     *            {@link OutputLayer}; all preceding ones become hidden layers.
     * @throws IllegalArgumentException if fewer than two layers are given,
     *         if a layer's input size differs from the output size of the
     *         layer below it, or if the last layer is not an output layer.
     */
    public MultilayerPerceptron(Layer ... net) {
        if (net.length < 2) {
            throw new IllegalArgumentException("Too few layers: " + net.length);
        }

        Layer lower = net[0];
        for (int i = 1; i < net.length; ++i) {
            Layer layer = net[i];
            if (layer.getInputSize() != lower.getOutputSize()) {
                throw new IllegalArgumentException(String.format("Invalid network architecture. Layer %d has %d neurons while layer %d takes %d inputs", i - 1, lower.getOutputSize(), i, layer.getInputSize()));
            }
            lower = layer;
        }

        // Validate explicitly instead of letting the cast fail with an opaque
        // ClassCastException; consistent with the checks above.
        Layer top = net[net.length - 1];
        if (!(top instanceof OutputLayer)) {
            throw new IllegalArgumentException("The last layer is not an output layer: "
                    + (top == null ? "null" : top.getClass().getName()));
        }
        this.output = (OutputLayer) top;
        // New array holding only the hidden layers; the layer objects are shared.
        this.net = Arrays.copyOf(net, net.length - 1);
        this.p = net[0].getInputSize();
        this.x1 = new double[this.p + 1];
        this.x1[this.p] = 1.0;
        this.target = new double[this.output.getOutputSize()];
    }

    @Override
    public String toString() {
        return String.format("x(%d) -> %s -> %s(eta = %.2f, alpha = %.2f, lambda = %.2f)", this.p, Arrays.stream(this.net).map(Object::toString).collect(Collectors.joining(" -> ")), this.output, this.eta, this.alpha, this.lambda);
    }

    /**
     * Sets the learning rate.
     *
     * @param eta the learning rate; must be finite and positive.
     * @throws IllegalArgumentException if {@code eta} is not a finite positive number.
     */
    public void setLearningRate(double eta) {
        // isFinite also rejects NaN, which would slip through a plain "<= 0" test.
        if (!Double.isFinite(eta) || eta <= 0.0) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }
        this.eta = eta;
    }

    /**
     * Sets the momentum factor.
     *
     * @param alpha the momentum factor in {@code [0, 1)}.
     * @throws IllegalArgumentException if {@code alpha} is outside {@code [0, 1)} or NaN.
     */
    public void setMomentum(double alpha) {
        // Written as a negated in-range test so that NaN is rejected too.
        if (!(alpha >= 0.0 && alpha < 1.0)) {
            throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
        }
        this.alpha = alpha;
    }

    /**
     * Sets the weight decay factor.
     *
     * @param lambda the weight decay factor in {@code [0, 0.1]}.
     * @throws IllegalArgumentException if {@code lambda} is outside {@code [0, 0.1]} or NaN.
     */
    public void setWeightDecay(double lambda) {
        // Written as a negated in-range test so that NaN is rejected too.
        if (!(lambda >= 0.0 && lambda <= 0.1)) {
            throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
        }
        this.lambda = lambda;
    }

    /**
     * Returns the learning rate.
     * @return the learning rate.
     */
    public double getLearningRate() {
        return this.eta;
    }

    /**
     * Returns the momentum factor.
     * @return the momentum factor.
     */
    public double getMomentum() {
        return this.alpha;
    }

    /**
     * Returns the weight decay factor.
     * @return the weight decay factor.
     */
    public double getWeightDecay() {
        return this.lambda;
    }

    /**
     * Runs a forward pass: copies {@code x} into {@link #x1}, feeds it through
     * every hidden layer in order, then through the output layer.
     *
     * @param x the input vector of length {@link #p}.
     * @throws IllegalArgumentException if {@code x} has the wrong length.
     */
    protected void propagate(double[] x) {
        if (x.length != this.x1.length - 1) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.x1.length - 1));
        }
        // The trailing constant slot of x1 is left untouched.
        System.arraycopy(x, 0, this.x1, 0, x.length);
        double[] input = this.x1;
        for (Layer layer : this.net) {
            layer.propagate(input);
            input = layer.output();
        }
        this.output.propagate(input);
    }

    /**
     * Runs a backward pass and computes weight updates.
     * <p>
     * The output layer first computes its error against {@link #target}. Each
     * layer then back-propagates into the gradient buffer of the layer below;
     * the bottom hidden layer receives {@code null} since there is no layer
     * under it. Finally every layer computes its update from {@link #eta},
     * {@link #alpha} and its own input.
     */
    protected void backpropagate() {
        this.output.computeError(this.target, 1.0);

        Layer upper = this.output;
        for (int i = this.net.length - 1; i >= 0; --i) {
            upper.backpropagate(this.net[i].gradient());
            upper = this.net[i];
        }
        // Bottom hidden layer: nothing below it to propagate into.
        upper.backpropagate(null);

        // Each layer's input is the previous layer's output (x1 for the first).
        double[] x = this.x1;
        for (Layer layer : this.net) {
            layer.computeUpdate(this.eta, this.alpha, x);
            x = layer.output();
        }
        this.output.computeUpdate(this.eta, this.alpha, x);
    }

    /**
     * Calls {@code update(alpha, decay)} on every hidden layer and then the
     * output layer, with {@code decay = 1 - 2 * eta * lambda}.
     *
     * @throws IllegalStateException if the decay factor is below 0.9 or NaN.
     */
    protected void update() {
        double decay = 1.0 - 2.0 * this.eta * this.lambda;
        // Negated test so that a NaN decay (e.g. infinite eta with zero lambda)
        // is rejected instead of being pushed into every layer.
        if (!(decay >= 0.9)) {
            throw new IllegalStateException(String.format("Invalid learning rate (eta = %.2f) and/or decay (lambda = %.2f)", this.eta, this.lambda));
        }
        for (Layer layer : this.net) {
            layer.update(this.alpha, decay);
        }
        this.output.update(this.alpha, decay);
    }
}

