/*
 * Decompiled with CFR 0.152.
 */
package smile.base.mlp;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Properties;
import java.util.stream.Collectors;
import smile.base.mlp.InputLayer;
import smile.base.mlp.Layer;
import smile.base.mlp.OutputLayer;
import smile.tensor.ScalarType;
import smile.tensor.Vector;
import smile.util.function.TimeFunction;

/**
 * Base class for multilayer perceptron networks.
 *
 * <p>It owns the layer stack (an input layer, hidden layers and an output
 * layer), the training hyperparameters (learning rate, momentum, RMSProp
 * settings, weight decay and gradient clipping), and the forward/backward
 * propagation and weight-update routines that subclasses build on.
 *
 * <p>All hyperparameter range checks are written as negated "in range" tests
 * so that {@code NaN} is rejected: every ordered comparison with {@code NaN}
 * is false, so a check of the form {@code x < 0} would silently accept it.
 */
public abstract class MultilayerPerceptron implements AutoCloseable, Serializable {
    private static final long serialVersionUID = 2L;

    /** Input size of the first layer, as reported by its {@code getInputSize()}. */
    protected int p;
    /** The output layer, i.e. the last layer passed to the constructor. */
    protected OutputLayer output;
    /** Every layer except the output layer; element 0 is the input layer. */
    protected Layer[] net;
    /**
     * Per-thread target vector handed to the output layer when computing the
     * output gradient. Transient: it is recreated by {@link #init()} in the
     * constructor and after deserialization.
     */
    protected transient ThreadLocal<Vector> target;
    /** Learning rate schedule, evaluated at {@link #t}. Defaults to a constant 0.01. */
    protected TimeFunction learningRate = TimeFunction.constant(0.01);
    /** Momentum schedule, evaluated at {@link #t}; {@code null} means a momentum of 0. */
    protected TimeFunction momentum = null;
    /** RMSProp factor passed to the layers' update; 0 by default (reported only when non-zero). */
    protected double rho = 0.0;
    /** RMSProp epsilon passed to the layers' update alongside {@link #rho}. */
    protected double epsilon = 1.0E-7;
    /** Weight decay (L2 regularization) factor; 0 by default. */
    protected double lambda = 0.0;
    /** Element-wise gradient clipping bound; clipping by value is skipped unless it is positive. */
    protected double clipValue = 0.0;
    /** Gradient L2-norm clipping bound; when positive it takes precedence over {@link #clipValue}. */
    protected double clipNorm = 0.0;
    /**
     * Time step fed to the learning rate and momentum schedules. This class
     * never modifies it; presumably subclasses advance it during training.
     */
    protected int t = 0;

    /**
     * Builds the network from an ordered list of layers.
     *
     * @param net the layers, from input to output. At least three layers are
     *            required, the first must be an {@link InputLayer}, the last an
     *            {@link OutputLayer}, and each layer's input size must equal
     *            the output size of the layer before it.
     * @throws IllegalArgumentException if any of the above constraints is violated.
     */
    public MultilayerPerceptron(Layer... net) {
        if (net.length <= 2) {
            throw new IllegalArgumentException("Too few layers: " + net.length);
        }

        Layer first = net[0];
        Layer last = net[net.length - 1];
        if (!(first instanceof InputLayer)) {
            throw new IllegalArgumentException("The first layer is not an InputLayer: " + first);
        }
        if (!(last instanceof OutputLayer)) {
            throw new IllegalArgumentException("The last layer is not an OutputLayer: " + last);
        }

        // Adjacent layers must agree on the size of the vector passed between them.
        Layer lower = first;
        for (int i = 1; i < net.length; ++i) {
            Layer layer = net[i];
            if (layer.getInputSize() != lower.getOutputSize()) {
                throw new IllegalArgumentException(String.format("Invalid network architecture. Layer %d has %d neurons while layer %d takes %d inputs", i - 1, lower.getOutputSize(), i, layer.getInputSize()));
            }
            lower = layer;
        }

        this.output = (OutputLayer) last;
        // The output layer is kept separately; this.net holds input + hidden layers only.
        this.net = Arrays.copyOf(net, net.length - 1);
        this.p = first.getInputSize();
        init();
    }

    /**
     * Removes the calling thread's target vector and closes every layer,
     * including the output layer.
     */
    @Override
    public void close() {
        if (target != null) {
            // ThreadLocal.remove() only affects the calling thread's value.
            target.remove();
        }
        for (Layer layer : net) {
            layer.close();
        }
        output.close();
    }

    /**
     * Deserialization hook: restores the non-transient state, then rebuilds
     * the transient thread-local target vector.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        init();
    }

    /**
     * Creates the thread-local target vector; each thread lazily gets a
     * Float32 zero vector whose length is the output layer's output size.
     */
    private void init() {
        // Casts retained from the original source; the declared types live outside this file.
        target = ThreadLocal.withInitial(() -> Vector.zeros((ScalarType)ScalarType.Float32, (int)output.getOutputSize()));
    }

    /**
     * Returns a description of the layer chain followed by the learning rate
     * and, when set, the momentum, weight decay and RMSProp factor.
     */
    @Override
    public String toString() {
        String layers = Arrays.stream(net).map(Object::toString).collect(Collectors.joining(" -> "));
        StringBuilder sb = new StringBuilder(String.format("%s -> %s(learning rate = %s", layers, output, learningRate));
        if (momentum != null) {
            sb.append(String.format(", momentum = %s", momentum));
        }
        if (lambda != 0.0) {
            sb.append(String.format(", weight decay = %f", lambda));
        }
        if (rho != 0.0) {
            sb.append(String.format(", RMSProp = %f", rho));
        }
        return sb.append(")").toString();
    }

    /**
     * Sets the learning rate schedule.
     *
     * @param rate the learning rate as a function of the time step.
     */
    public void setLearningRate(TimeFunction rate) {
        this.learningRate = rate;
    }

    /**
     * Sets the momentum schedule.
     *
     * @param momentum the momentum as a function of the time step, or
     *                 {@code null} for no momentum.
     */
    public void setMomentum(TimeFunction momentum) {
        this.momentum = momentum;
    }

    /**
     * Sets the RMSProp parameters.
     *
     * @param rho     the RMSProp factor, in {@code [0, 1)}.
     * @param epsilon a strictly positive constant.
     * @throws IllegalArgumentException if either value is out of range or NaN.
     */
    public void setRMSProp(double rho, double epsilon) {
        if (!(rho >= 0.0 && rho < 1.0)) {
            throw new IllegalArgumentException("Invalid rho = " + rho);
        }
        if (!(epsilon > 0.0)) {
            throw new IllegalArgumentException("Invalid epsilon = " + epsilon);
        }
        this.rho = rho;
        this.epsilon = epsilon;
    }

    /**
     * Sets the weight decay factor.
     *
     * @param lambda the non-negative weight decay factor.
     * @throws IllegalArgumentException if {@code lambda} is negative or NaN.
     */
    public void setWeightDecay(double lambda) {
        if (!(lambda >= 0.0)) {
            throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
        }
        this.lambda = lambda;
    }

    /**
     * Sets the element-wise gradient clipping bound.
     *
     * @param clipValue the non-negative bound; 0 disables clipping by value.
     * @throws IllegalArgumentException if {@code clipValue} is negative or NaN.
     */
    public void setClipValue(double clipValue) {
        if (!(clipValue >= 0.0)) {
            throw new IllegalArgumentException("Invalid gradient clipping value: " + clipValue);
        }
        this.clipValue = clipValue;
    }

    /**
     * Sets the gradient norm clipping bound.
     *
     * @param clipNorm the non-negative bound; 0 disables clipping by norm.
     * @throws IllegalArgumentException if {@code clipNorm} is negative or NaN.
     */
    public void setClipNorm(double clipNorm) {
        if (!(clipNorm >= 0.0)) {
            throw new IllegalArgumentException("Invalid gradient clipping norm: " + clipNorm);
        }
        this.clipNorm = clipNorm;
    }

    /**
     * Returns the learning rate at the current time step {@link #t}.
     *
     * @return the value of the learning rate schedule at {@code t}.
     */
    public double getLearningRate() {
        return learningRate.apply(t);
    }

    /**
     * Returns the momentum at the current time step {@link #t}.
     *
     * @return the value of the momentum schedule at {@code t}, or 0 when no
     *         momentum schedule is set.
     */
    public double getMomentum() {
        return momentum == null ? 0.0 : momentum.apply(t);
    }

    /**
     * Returns the weight decay factor.
     *
     * @return the weight decay factor.
     */
    public double getWeightDecay() {
        return lambda;
    }

    /**
     * Returns the element-wise gradient clipping bound.
     *
     * @return the gradient clipping value.
     */
    public double getClipValue() {
        return clipValue;
    }

    /**
     * Returns the gradient norm clipping bound.
     *
     * @return the gradient clipping norm.
     */
    public double getClipNorm() {
        return clipNorm;
    }

    /**
     * Converts a raw input array by delegating to {@code vector(double[])} on
     * the input layer's output vector.
     *
     * @param x the input values.
     * @return the vector produced by that call (presumably matching the
     *         network's tensor type; confirm against {@code Vector}).
     */
    protected Vector vector(double[] x) {
        return net[0].output().vector(x);
    }

    /**
     * Runs a forward pass: each layer in {@link #net} is fed the previous
     * layer's output, and the output layer is fed the last one.
     *
     * @param x        the input vector.
     * @param training when true, {@code propagateDropout()} is also invoked
     *                 on every layer in {@link #net} after it propagates.
     */
    protected void propagate(Vector x, boolean training) {
        Vector signal = x;
        for (Layer layer : net) {
            layer.propagate(signal);
            if (training) {
                layer.propagateDropout();
            }
            signal = layer.output();
        }
        output.propagate(signal);
    }

    /**
     * Clips a gradient vector in place. When {@link #clipNorm} is positive the
     * vector is rescaled so its L2 norm does not exceed it; otherwise, when
     * {@link #clipValue} is positive, each element is clamped to
     * {@code [-clipValue, clipValue]}. With neither set the vector is left alone.
     */
    private void clipGradient(Vector gradient) {
        if (clipNorm > 0.0) {
            double norm = gradient.norm2();
            if (norm > clipNorm) {
                gradient.scale(clipNorm / norm);
            }
        } else if (clipValue > 0.0) {
            for (int j = 0; j < gradient.size(); ++j) {
                double g = gradient.get(j);
                if (g > clipValue) {
                    gradient.set(j, clipValue);
                } else if (g < -clipValue) {
                    gradient.set(j, -clipValue);
                }
            }
        }
    }

    /** Returns the current learning rate, rejecting non-positive or NaN values. */
    private double validLearningRate() {
        double eta = getLearningRate();
        if (!(eta > 0.0)) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }
        return eta;
    }

    /** Returns the current momentum, rejecting values outside [0, 1) and NaN. */
    private double validMomentum() {
        double alpha = getMomentum();
        if (!(alpha >= 0.0 && alpha < 1.0)) {
            throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
        }
        return alpha;
    }

    /**
     * Runs a backward pass from the output layer down to the first hidden
     * layer, clipping each layer's gradient, and then either applies or
     * accumulates the weight gradients.
     *
     * @param update when true, each layer's {@code computeGradientUpdate} is
     *               called with the current learning rate, momentum and decay
     *               factor; when false, only {@code computeGradient} is called.
     * @throws IllegalArgumentException if {@code update} is true and the
     *         learning rate is not positive or the momentum is outside [0, 1).
     * @throws IllegalStateException if {@code update} is true and the decay
     *         factor {@code 1 - 2 * eta * lambda} is below 0.9.
     */
    protected void backpropagate(boolean update) {
        output.computeOutputGradient(target.get(), 1.0);
        clipGradient(output.gradient());

        // Walk from the top hidden layer down to index 1; index 0 is the input layer.
        Layer upper = output;
        for (int i = net.length - 1; i > 0; --i) {
            upper.backpropagate(net[i].gradient());
            upper = net[i];
            upper.backpopagateDropout();
            clipGradient(upper.gradient());
        }
        // The first hidden layer has no lower layer to receive its error.
        upper.backpropagate(null);

        if (update) {
            double eta = validLearningRate();
            double alpha = validMomentum();
            double decay = 1.0 - 2.0 * eta * lambda;
            if (!(decay >= 0.9)) {
                throw new IllegalStateException(String.format("Invalid learning rate (eta = %.2f) and/or L2 regularization (lambda = %.2f) such that weight decay = %.2f", eta, lambda, decay));
            }

            Vector x = net[0].output();
            for (int i = 1; i < net.length; ++i) {
                Layer layer = net[i];
                layer.computeGradientUpdate(x, eta, alpha, decay);
                x = layer.output();
            }
            output.computeGradientUpdate(x, eta, alpha, decay);
        } else {
            Vector x = net[0].output();
            for (int i = 1; i < net.length; ++i) {
                Layer layer = net[i];
                layer.computeGradient(x);
                x = layer.output();
            }
            output.computeGradient(x);
        }
    }

    /**
     * Asks every hidden layer and the output layer to update its weights
     * using the current learning rate, momentum, decay factor and RMSProp
     * settings.
     *
     * @param m passed through to each layer's {@code update}; presumably the
     *          mini-batch size (confirm against {@code Layer.update}).
     * @throws IllegalArgumentException if the learning rate is not positive or
     *         the momentum is outside [0, 1).
     * @throws IllegalStateException if the decay factor
     *         {@code 1 - 2 * eta * lambda} is below 0.9.
     */
    protected void update(int m) {
        double eta = validLearningRate();
        double alpha = validMomentum();
        double decay = 1.0 - 2.0 * eta * lambda;
        if (!(decay >= 0.9)) {
            throw new IllegalStateException(String.format("Invalid learning rate (eta = %.2f) and/or decay (lambda = %.2f)", eta, lambda));
        }

        // Index 0 is the input layer and is never updated here.
        for (int i = 1; i < net.length; ++i) {
            net[i].update(m, eta, alpha, decay, rho, epsilon);
        }
        output.update(m, eta, alpha, decay, rho, epsilon);
    }

    /**
     * Configures the hyperparameters from properties. Recognized keys, each
     * applied only when present:
     * <ul>
     * <li>{@code smile.mlp.learning_rate} - learning rate schedule
     * <li>{@code smile.mlp.weight_decay} - weight decay factor
     * <li>{@code smile.mlp.momentum} - momentum schedule
     * <li>{@code smile.mlp.clip_value} - gradient clipping value
     * <li>{@code smile.mlp.clip_norm} - gradient clipping norm
     * <li>{@code smile.mlp.RMSProp.rho} - RMSProp factor; when present,
     *     {@code smile.mlp.RMSProp.epsilon} is also read (default 1E-7)
     * </ul>
     *
     * @param params the properties to read.
     * @throws IllegalArgumentException if a value fails the corresponding
     *         setter's validation ({@link NumberFormatException} for a
     *         malformed number).
     */
    public void setParameters(Properties params) {
        String rateSpec = params.getProperty("smile.mlp.learning_rate");
        if (rateSpec != null) {
            setLearningRate(TimeFunction.of(rateSpec));
        }

        String decaySpec = params.getProperty("smile.mlp.weight_decay");
        if (decaySpec != null) {
            setWeightDecay(Double.parseDouble(decaySpec));
        }

        String momentumSpec = params.getProperty("smile.mlp.momentum");
        if (momentumSpec != null) {
            setMomentum(TimeFunction.of(momentumSpec));
        }

        String clipValueSpec = params.getProperty("smile.mlp.clip_value");
        if (clipValueSpec != null) {
            setClipValue(Double.parseDouble(clipValueSpec));
        }

        String clipNormSpec = params.getProperty("smile.mlp.clip_norm");
        if (clipNormSpec != null) {
            setClipNorm(Double.parseDouble(clipNormSpec));
        }

        String rhoSpec = params.getProperty("smile.mlp.RMSProp.rho");
        if (rhoSpec != null) {
            // Epsilon is only consulted when rho is configured.
            double eps = Double.parseDouble(params.getProperty("smile.mlp.RMSProp.epsilon", "1E-7"));
            setRMSProp(Double.parseDouble(rhoSpec), eps);
        }
    }
}

