/*
 * Decompiled with CFR 0.152.
 */
package smile.base.mlp;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.stream.Collectors;
import smile.base.mlp.Layer;
import smile.base.mlp.OutputLayer;
import smile.math.TimeFunction;

/**
 * Abstract base class for a multilayer perceptron. It keeps the hidden
 * layers and the output layer, the training hyper-parameters, and the
 * forward/backward passes that concrete subclasses drive.
 *
 * <p>The thread-local {@link #target} buffer is transient and is rebuilt
 * after deserialization.
 */
public abstract class MultilayerPerceptron
implements Serializable {
    private static final long serialVersionUID = 2L;
    /** Input size of the first layer, i.e. the dimensionality of the input vector. */
    protected int p;
    /** The last layer of the network. */
    protected OutputLayer output;
    /** All layers except the output layer, in forward order. */
    protected Layer[] net;
    /**
     * Per-thread buffer with one slot per output unit. It is handed to the
     * output layer when the output gradient is computed; presumably subclasses
     * fill it with the target values before calling {@link #backpropagate}.
     */
    protected transient ThreadLocal<double[]> target;
    /** Learning rate schedule, evaluated at the time step {@link #t}. */
    protected TimeFunction learningRate = TimeFunction.constant((double)0.01);
    /** Momentum schedule, evaluated at the time step {@link #t}. */
    protected TimeFunction momentum = TimeFunction.constant((double)0.0);
    /** RMSProp parameter rho, forwarded to the layers in {@link #update(int)}. */
    protected double rho = 0.0;
    /** RMSProp parameter epsilon, forwarded to the layers in {@link #update(int)}. */
    protected double epsilon = 1.0E-7;
    /** L2 regularization (weight decay) factor. */
    protected double lambda = 0.0;
    /**
     * Time step passed to the learning rate and momentum schedules. This class
     * never modifies it; presumably subclasses advance it during training.
     */
    protected int t = 0;

    /**
     * Builds a network from the given layers, listed in forward order.
     *
     * @param net the layers; at least two are required, the input size of each
     *            layer must equal the output size of its predecessor, and the
     *            last one must be an {@link OutputLayer}.
     * @throws IllegalArgumentException if there are fewer than two layers, if
     *         adjacent layer sizes do not match, or if the last layer is not
     *         an {@link OutputLayer}.
     */
    public MultilayerPerceptron(Layer ... net) {
        if (net.length < 2) {
            throw new IllegalArgumentException("Too few layers: " + net.length);
        }
        Layer lower = net[0];
        for (int i = 1; i < net.length; ++i) {
            Layer layer = net[i];
            if (layer.getInputSize() != lower.getOutputSize()) {
                throw new IllegalArgumentException(String.format("Invalid network architecture. Layer %d has %d neurons while layer %d takes %d inputs", i - 1, lower.getOutputSize(), i, layer.getInputSize()));
            }
            lower = layer;
        }
        // Report a wrong last layer as an argument error, like the other
        // architecture checks, rather than as a bare ClassCastException.
        Layer last = net[net.length - 1];
        if (!(last instanceof OutputLayer)) {
            throw new IllegalArgumentException("The last layer must be an output layer: " + last);
        }
        this.output = (OutputLayer)last;
        this.net = Arrays.copyOf(net, net.length - 1);
        this.p = net[0].getInputSize();
        this.init();
    }

    /**
     * Custom deserialization hook: restores the serialized fields and then
     * recreates the transient thread-local target buffer.
     *
     * @param in the stream to read from.
     * @throws IOException if the default read fails.
     * @throws ClassNotFoundException if a serialized class cannot be resolved.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.init();
    }

    /**
     * Creates the thread-local target buffer. Each thread lazily receives its
     * own array sized to the output layer, so no locking is required.
     */
    private void init() {
        this.target = ThreadLocal.withInitial(() -> new double[this.output.getOutputSize()]);
    }

    /**
     * Returns a one-line description of the architecture and the main
     * hyper-parameters.
     */
    @Override
    public String toString() {
        return String.format("x(%d) -> %s -> %s(learning rate = %s, momentum = %s, weight decay = %.2f)", this.p, Arrays.stream(this.net).map(Object::toString).collect(Collectors.joining(" -> ")), this.output, this.learningRate, this.momentum, this.lambda);
    }

    /**
     * Sets the learning rate schedule.
     *
     * @param rate the learning rate as a function of the time step.
     */
    public void setLearningRate(TimeFunction rate) {
        this.learningRate = rate;
    }

    /**
     * Sets the momentum schedule.
     *
     * @param momentum the momentum factor as a function of the time step.
     */
    public void setMomentum(TimeFunction momentum) {
        this.momentum = momentum;
    }

    /**
     * Sets the RMSProp parameters.
     *
     * @param rho must lie in [0, 1).
     * @param epsilon must be strictly positive.
     * @throws IllegalArgumentException if either value is out of range.
     */
    public void setRMSProp(double rho, double epsilon) {
        if (rho < 0.0 || rho >= 1.0) {
            throw new IllegalArgumentException("Invalid rho = " + rho);
        }
        if (epsilon <= 0.0) {
            throw new IllegalArgumentException("Invalid epsilon = " + epsilon);
        }
        this.rho = rho;
        this.epsilon = epsilon;
    }

    /**
     * Sets the weight decay (L2 regularization) factor.
     *
     * @param lambda must be non-negative.
     * @throws IllegalArgumentException if {@code lambda} is negative.
     */
    public void setWeightDecay(double lambda) {
        if (lambda < 0.0) {
            throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
        }
        this.lambda = lambda;
    }

    /**
     * Returns the learning rate at the current time step.
     *
     * @return the value of the learning rate schedule at {@link #t}.
     */
    public double getLearningRate() {
        return this.learningRate.apply(this.t);
    }

    /**
     * Returns the momentum factor at the current time step.
     *
     * @return the value of the momentum schedule at {@link #t}.
     */
    public double getMomentum() {
        return this.momentum.apply(this.t);
    }

    /**
     * Returns the weight decay factor.
     *
     * @return the current value of lambda.
     */
    public double getWeightDecay() {
        return this.lambda;
    }

    /**
     * Runs the forward pass: feeds {@code x} through every hidden layer in
     * order and finally through the output layer.
     *
     * @param x the input vector.
     */
    protected void propagate(double[] x) {
        double[] input = x;
        for (Layer layer : this.net) {
            layer.propagate(input);
            input = layer.output();
        }
        this.output.propagate(input);
    }

    /**
     * Runs the backward pass and then lets every layer compute its gradient.
     *
     * <p>The output gradient is computed from this thread's {@link #target}
     * buffer and pushed down layer by layer. Afterwards each layer is visited
     * in forward order, receiving the output of the layer below it (or
     * {@code x} for the first layer) as its input.
     *
     * @param x the input vector that was used for the preceding forward pass.
     * @param update if {@code true}, each layer's {@code computeGradientUpdate}
     *               is called with the current learning rate, momentum and
     *               decay; otherwise each layer's {@code computeGradient} is
     *               called.
     * @throws IllegalArgumentException if {@code update} is set and the current
     *         learning rate or momentum is out of range.
     * @throws IllegalStateException if {@code update} is set and the resulting
     *         decay factor is below 0.9.
     */
    protected void backpropagate(double[] x, boolean update) {
        this.output.computeOutputGradient(this.target.get(), 1.0);
        Layer upper = this.output;
        for (int i = this.net.length - 1; i >= 0; --i) {
            upper.backpropagate(this.net[i].gradient());
            upper = this.net[i];
        }
        // The bottom layer has no lower layer to receive a gradient.
        upper.backpropagate(null);

        double[] input = x;
        if (update) {
            double eta = this.checkedLearningRate();
            double alpha = this.checkedMomentum();
            double decay = 1.0 - 2.0 * eta * this.lambda;
            if (decay < 0.9) {
                throw new IllegalStateException(String.format("Invalid learning rate (eta = %.2f) and/or L2 regularization (lambda = %.2f) such that weight decay = %.2f", eta, this.lambda, decay));
            }
            for (Layer layer : this.net) {
                layer.computeGradientUpdate(input, eta, alpha, decay);
                input = layer.output();
            }
            this.output.computeGradientUpdate(input, eta, alpha, decay);
        } else {
            for (Layer layer : this.net) {
                layer.computeGradient(input);
                input = layer.output();
            }
            this.output.computeGradient(input);
        }
    }

    /**
     * Asks every layer, hidden layers first and the output layer last, to
     * apply its update using the current hyper-parameters.
     *
     * @param m forwarded unchanged to each layer's {@code update}; presumably
     *          the mini-batch size (confirm against {@code Layer.update}).
     * @throws IllegalArgumentException if the current learning rate or
     *         momentum is out of range.
     * @throws IllegalStateException if the resulting decay factor is below 0.9.
     */
    protected void update(int m) {
        double eta = this.checkedLearningRate();
        double alpha = this.checkedMomentum();
        double decay = 1.0 - 2.0 * eta * this.lambda;
        if (decay < 0.9) {
            throw new IllegalStateException(String.format("Invalid learning rate (eta = %.2f) and/or decay (lambda = %.2f)", eta, this.lambda));
        }
        for (Layer layer : this.net) {
            layer.update(m, eta, alpha, decay, this.rho, this.epsilon);
        }
        this.output.update(m, eta, alpha, decay, this.rho, this.epsilon);
    }

    /**
     * Evaluates the learning rate schedule at the current time step and
     * verifies that the result is strictly positive.
     */
    private double checkedLearningRate() {
        double eta = this.learningRate.apply(this.t);
        if (eta <= 0.0) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }
        return eta;
    }

    /**
     * Evaluates the momentum schedule at the current time step and verifies
     * that the result lies in [0, 1).
     */
    private double checkedMomentum() {
        double alpha = this.momentum.apply(this.t);
        if (alpha < 0.0 || alpha >= 1.0) {
            throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
        }
        return alpha;
    }
}

