/*
 * Decompiled with CFR 0.152.
 */
package smile.base.mlp;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import smile.base.mlp.ActivationFunction;
import smile.base.mlp.Cost;
import smile.base.mlp.HiddenLayerBuilder;
import smile.base.mlp.InputLayer;
import smile.base.mlp.LayerBuilder;
import smile.base.mlp.OutputFunction;
import smile.base.mlp.OutputLayerBuilder;
import smile.linalg.Transpose;
import smile.math.MathEx;
import smile.tensor.DenseMatrix;
import smile.tensor.ScalarType;
import smile.tensor.Vector;

/**
 * Base class for the layers of a multilayer perceptron.
 *
 * <p>A layer holds an {@code n x p} weight matrix and a bias vector, where
 * {@code n} is the number of output units ({@link #getOutputSize()}) and
 * {@code p} is the number of input units ({@link #getInputSize()}). All
 * working buffers (output, gradients, moments, updates, dropout mask) are kept
 * in {@link ThreadLocal}s, so each thread that uses the layer gets its own
 * copies. The buffers are {@code transient} and are recreated after
 * deserialization.
 *
 * <p>{@link #close()} removes the calling thread's buffers only.
 */
public abstract class Layer implements AutoCloseable, Serializable {
    private static final long serialVersionUID = 2L;

    /** Regex fragment for the (signed) integer neuron count of a layer spec. */
    private static final String INTEGER_REGEX = "[-+]?\\d{1,9}";
    /** Regex fragment for an optional floating-point argument of a layer spec. */
    private static final String DOUBLE_REGEX = "[-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?";
    /**
     * Pattern of a single layer spec: {@code name(neurons[, dropout[, param]])}.
     * Compiled once instead of on every call of {@link #of(int, int, String)}.
     */
    private static final Pattern LAYER_SPEC = Pattern.compile(String.format("(\\w+)\\((%s)(,\\s*(%s))?(,\\s*(%s))?\\)", INTEGER_REGEX, DOUBLE_REGEX, DOUBLE_REGEX));

    /** The number of output units (rows of the weight matrix). */
    protected final int n;
    /** The number of input units (columns of the weight matrix). */
    protected final int p;
    /** The dropout rate in {@code [0, 1)}; {@code 0} disables dropout. */
    protected final double dropout;
    /** The {@code n x p} weight matrix; {@code null} for weightless layers. */
    protected DenseMatrix weight;
    /** The bias vector; {@code null} for weightless layers. */
    protected Vector bias;
    /** Per-thread output vector. */
    protected transient ThreadLocal<Vector> output;
    /** Per-thread gradient with respect to the output. */
    protected transient ThreadLocal<Vector> outputGradient;
    /** Per-thread accumulated weight gradient. */
    protected transient ThreadLocal<DenseMatrix> weightGradient;
    /** Per-thread accumulated bias gradient. */
    protected transient ThreadLocal<Vector> biasGradient;
    /** Per-thread first moment of the weight gradient. */
    protected transient ThreadLocal<DenseMatrix> weightGradientMoment1;
    /** Per-thread second moment of the weight gradient. */
    protected transient ThreadLocal<DenseMatrix> weightGradientMoment2;
    /** Per-thread first moment of the bias gradient. */
    protected transient ThreadLocal<Vector> biasGradientMoment1;
    /** Per-thread second moment of the bias gradient. */
    protected transient ThreadLocal<Vector> biasGradientMoment2;
    /** Per-thread weight update (momentum term). */
    protected transient ThreadLocal<DenseMatrix> weightUpdate;
    /** Per-thread bias update (momentum term). */
    protected transient ThreadLocal<Vector> biasUpdate;
    /** Per-thread dropout mask; only created when {@code dropout > 0}. */
    protected transient ThreadLocal<byte[]> mask;

    /**
     * Creates a weightless layer whose input and output sizes are both
     * {@code n}. Only the output buffer (and the dropout mask, if dropout is
     * enabled) is allocated.
     *
     * @param n the number of units.
     * @param dropout the dropout rate in {@code [0, 1)}.
     * @throws IllegalArgumentException if the dropout rate is out of range.
     */
    Layer(int n, double dropout) {
        checkDropout(dropout);
        this.n = n;
        this.p = n;
        this.dropout = dropout;
        this.init();
    }

    /**
     * Creates a layer with randomly initialized weights and no dropout.
     *
     * @param n the number of output units.
     * @param p the number of input units.
     */
    public Layer(int n, int p) {
        this(n, p, 0.0);
    }

    /**
     * Creates a layer with randomly initialized weights and a zero bias.
     *
     * @param n the number of output units.
     * @param p the number of input units.
     * @param dropout the dropout rate in {@code [0, 1)}.
     * @throws IllegalArgumentException if the dropout rate is out of range.
     */
    public Layer(int n, int p, double dropout) {
        this(randomWeight(n, p), Vector.zeros(ScalarType.Float32, n), dropout);
    }

    /**
     * Creates a layer from the given weight and bias, without dropout.
     *
     * @param weight the weight matrix.
     * @param bias the bias vector.
     */
    public Layer(DenseMatrix weight, Vector bias) {
        this(weight, bias, 0.0);
    }

    /**
     * Creates a layer from the given weight and bias. The output and input
     * sizes are taken from the row and column counts of {@code weight}.
     *
     * @param weight the weight matrix.
     * @param bias the bias vector.
     * @param dropout the dropout rate in {@code [0, 1)}.
     * @throws IllegalArgumentException if the dropout rate is out of range.
     */
    public Layer(DenseMatrix weight, Vector bias, double dropout) {
        checkDropout(dropout);
        this.n = weight.nrow();
        this.p = weight.ncol();
        this.weight = weight;
        this.bias = bias;
        this.dropout = dropout;
        this.init();
    }

    /**
     * Validates a dropout rate.
     *
     * @param dropout the rate to check.
     * @throws IllegalArgumentException if the rate is not in {@code [0, 1)}.
     */
    private static void checkDropout(double dropout) {
        // Negated range test on purpose: every comparison with NaN is false,
        // so this form rejects NaN as well as out-of-range values.
        if (!(dropout >= 0.0 && dropout < 1.0)) {
            throw new IllegalArgumentException("Invalid dropout rate: " + dropout);
        }
    }

    /**
     * Creates the initial {@code n x p} weight matrix with bounds
     * {@code -r} and {@code r}, where {@code r = sqrt(6 / (n + p))}, passed to
     * {@code DenseMatrix.rand}.
     */
    private static DenseMatrix randomWeight(int n, int p) {
        double r = Math.sqrt(6.0 / (double) (n + p));
        return DenseMatrix.rand(ScalarType.Float32, n, p, -r, r);
    }

    /**
     * Removes the calling thread's copies of all working buffers.
     */
    @Override
    public void close() {
        release(this.output);
        release(this.outputGradient);
        release(this.weightGradient);
        release(this.biasGradient);
        release(this.weightGradientMoment1);
        release(this.weightGradientMoment2);
        release(this.biasGradientMoment1);
        release(this.biasGradientMoment2);
        release(this.weightUpdate);
        release(this.biasUpdate);
        release(this.mask);
    }

    /** Removes the current thread's value of a buffer that may not exist. */
    private static void release(ThreadLocal<?> workspace) {
        if (workspace != null) {
            workspace.remove();
        }
    }

    /**
     * Restores the transient working buffers after deserialization.
     *
     * @param in the stream to read from.
     * @throws IOException if reading the stream fails.
     * @throws ClassNotFoundException if a serialized class cannot be found.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.init();
    }

    /**
     * Creates the per-thread working buffers.
     *
     * <p>For a weightless layer (package-private constructor) only the output
     * buffer is created, as a {@code Float32} vector. Otherwise every buffer
     * is derived from the weight matrix.
     */
    private void init() {
        if (this.weight == null) {
            // No weight to derive buffers from. This also runs on
            // deserialization, where the weight-based initializers would
            // otherwise fail on first use.
            this.output = ThreadLocal.withInitial(() -> Vector.zeros(ScalarType.Float32, this.n));
        } else {
            this.output = ThreadLocal.withInitial(() -> this.weight.vector(this.n));
            this.outputGradient = ThreadLocal.withInitial(() -> this.weight.vector(this.n));
            this.weightGradient = ThreadLocal.withInitial(() -> this.weight.zeros(this.n, this.p));
            this.biasGradient = ThreadLocal.withInitial(() -> this.weight.vector(this.n));
            this.weightGradientMoment1 = ThreadLocal.withInitial(() -> this.weight.zeros(this.n, this.p));
            this.weightGradientMoment2 = ThreadLocal.withInitial(() -> this.weight.zeros(this.n, this.p));
            this.biasGradientMoment1 = ThreadLocal.withInitial(() -> this.weight.vector(this.n));
            this.biasGradientMoment2 = ThreadLocal.withInitial(() -> this.weight.vector(this.n));
            this.weightUpdate = ThreadLocal.withInitial(() -> this.weight.zeros(this.n, this.p));
            this.biasUpdate = ThreadLocal.withInitial(() -> this.weight.vector(this.n));
        }
        if (this.dropout > 0.0) {
            this.mask = ThreadLocal.withInitial(() -> new byte[this.n]);
        }
    }

    /**
     * Returns the number of output units.
     *
     * @return the output size.
     */
    public int getOutputSize() {
        return this.n;
    }

    /**
     * Returns the number of input units.
     *
     * @return the input size.
     */
    public int getInputSize() {
        return this.p;
    }

    /**
     * Returns the weight matrix.
     *
     * @return the weight matrix, or {@code null} for a weightless layer.
     */
    public DenseMatrix weight() {
        return this.weight;
    }

    /**
     * Returns the bias vector.
     *
     * @return the bias vector, or {@code null} for a weightless layer.
     */
    public Vector bias() {
        return this.bias;
    }

    /**
     * Returns the calling thread's output vector.
     *
     * @return the output vector.
     */
    public Vector output() {
        return this.output.get();
    }

    /**
     * Returns the calling thread's output gradient vector.
     *
     * @return the output gradient vector.
     */
    public Vector gradient() {
        return this.outputGradient.get();
    }

    /**
     * Propagates an input through the layer: the output buffer is loaded with
     * the bias, the weighted input is accumulated onto it via
     * {@code weight.mv}, and {@link #transform(Vector)} is applied in place.
     *
     * @param x the input vector.
     */
    public void propagate(Vector x) {
        Vector out = this.output.get();
        Vector.copy(this.bias, 0, out, 0, this.n);
        this.weight.mv(Transpose.NO_TRANSPOSE, 1.0, x, 1.0, out);
        this.transform(out);
    }

    /**
     * Applies dropout to the output during training. Each unit is dropped
     * when {@code MathEx.random() < dropout}; retained units are scaled by
     * {@code 1 / (1 - dropout)}. The drawn mask is stored for
     * {@link #backpopagateDropout()}. Does nothing when dropout is disabled.
     */
    public void propagateDropout() {
        if (!(this.dropout > 0.0)) {
            return;
        }
        Vector out = this.output.get();
        byte[] retained = this.mask.get();
        double scale = 1.0 / (1.0 - this.dropout);
        for (int i = 0; i < this.n; ++i) {
            byte retain = (byte) (MathEx.random() < this.dropout ? 0 : 1);
            retained[i] = retain;
            out.mul(i, retain * scale);
        }
    }

    /**
     * Applies the layer's activation to the given vector in place.
     *
     * @param var1 the vector to transform.
     */
    public abstract void transform(Vector var1);

    /**
     * Propagates the error backwards through the layer.
     *
     * @param var1 the gradient vector of the adjacent layer that receives the
     *             propagated error (presumably the lower layer; confirm
     *             against the subclasses).
     */
    public abstract void backpropagate(Vector var1);

    /**
     * Applies the stored dropout mask to the output gradient, using the same
     * {@code 1 / (1 - dropout)} scale as the forward pass. Does nothing when
     * dropout is disabled.
     *
     * <p>The method name is misspelled; it is kept as is because it is part
     * of the public API.
     */
    public void backpopagateDropout() {
        if (!(this.dropout > 0.0)) {
            return;
        }
        Vector gradient = this.outputGradient.get();
        byte[] retained = this.mask.get();
        double scale = 1.0 / (1.0 - this.dropout);
        for (int i = 0; i < this.n; ++i) {
            gradient.mul(i, retained[i] * scale);
        }
    }

    /**
     * Computes the gradient for a single sample and immediately updates the
     * weight and bias (online learning).
     *
     * @param x the input vector.
     * @param learningRate the learning rate.
     * @param momentum the momentum factor; used only when in {@code (0, 1)}.
     * @param decay the weight decay factor; the weight is scaled by it only
     *              when it lies in {@code (0.9, 1)}.
     */
    public void computeGradientUpdate(Vector x, double learningRate, double momentum, double decay) {
        Vector gradient = this.outputGradient.get();
        if (isOpenUnitInterval(momentum)) {
            DenseMatrix deltaWeight = this.weightUpdate.get();
            Vector deltaBias = this.biasUpdate.get();
            // update = momentum * previous update + learningRate * gradient
            deltaWeight.scale(momentum);
            deltaWeight.ger(learningRate, gradient, x);
            this.weight.add(deltaWeight);
            deltaBias.add(momentum, (DenseMatrix) deltaBias, learningRate, (DenseMatrix) gradient);
            this.bias.add((DenseMatrix) deltaBias);
        } else {
            this.weight.ger(learningRate, gradient, x);
            this.bias.axpy(learningRate, gradient);
        }
        this.applyWeightDecay(decay);
    }

    /**
     * Accumulates the gradient of a single sample into the calling thread's
     * weight and bias gradient buffers (mini-batch learning).
     *
     * @param x the input vector.
     */
    public void computeGradient(Vector x) {
        Vector gradient = this.outputGradient.get();
        DenseMatrix accumulatedWeight = this.weightGradient.get();
        Vector accumulatedBias = this.biasGradient.get();
        accumulatedWeight.ger(1.0, gradient, x);
        accumulatedBias.add((DenseMatrix) gradient);
    }

    /**
     * Applies the accumulated gradients to the weight and bias, then resets
     * the gradient buffers to zero.
     *
     * <p>When {@code rho} is in {@code (0, 1)}, the gradients are first
     * divided by {@code m} and then by the square root of
     * {@code epsilon + } a running average of their squares; the step size is
     * then {@code learningRate}. Otherwise the step size is
     * {@code learningRate / m}.
     *
     * @param m the divisor applied to the accumulated gradients (presumably
     *          the mini-batch size).
     * @param learningRate the learning rate.
     * @param momentum the momentum factor; used only when in {@code (0, 1)}.
     * @param decay the weight decay factor; the weight is scaled by it only
     *              when it lies in {@code (0.9, 1)}.
     * @param rho the discounting factor of the running average of squared
     *            gradients; used only when in {@code (0, 1)}.
     * @param epsilon the constant added before taking the square root.
     */
    public void update(int m, double learningRate, double momentum, double decay, double rho, double epsilon) {
        DenseMatrix gradWeight = this.weightGradient.get();
        Vector gradBias = this.biasGradient.get();
        double eta = learningRate / m;
        if (isOpenUnitInterval(rho)) {
            eta = learningRate;
            gradWeight.scale(1.0 / m);
            gradBias.scale(1.0 / m);
            DenseMatrix rmsWeight = this.weightGradientMoment2.get();
            Vector rmsBias = this.biasGradientMoment2.get();
            double rho1 = 1.0 - rho;
            // Each element is independent, so the running average is updated
            // and applied in one pass. The stored value is read back so the
            // division sees exactly what the buffer holds.
            for (int j = 0; j < this.p; ++j) {
                for (int i = 0; i < this.n; ++i) {
                    rmsWeight.set(i, j, rho * rmsWeight.get(i, j) + rho1 * MathEx.pow2((double) gradWeight.get(i, j)));
                    gradWeight.div(i, j, Math.sqrt(epsilon + rmsWeight.get(i, j)));
                }
            }
            for (int i = 0; i < this.n; ++i) {
                rmsBias.set(i, rho * rmsBias.get(i) + rho1 * MathEx.pow2((double) gradBias.get(i)));
                gradBias.div(i, Math.sqrt(epsilon + rmsBias.get(i)));
            }
        }
        if (isOpenUnitInterval(momentum)) {
            DenseMatrix deltaWeight = this.weightUpdate.get();
            Vector deltaBias = this.biasUpdate.get();
            deltaWeight.add(momentum, deltaWeight, eta, gradWeight);
            deltaBias.add(momentum, (DenseMatrix) deltaBias, eta, (DenseMatrix) gradBias);
            this.weight.add(deltaWeight);
            this.bias.add((DenseMatrix) deltaBias);
        } else {
            this.weight.axpy(eta, gradWeight);
            this.bias.axpy(eta, gradBias);
        }
        this.applyWeightDecay(decay);
        gradWeight.fill(0.0);
        gradBias.fill(0.0);
    }

    /** Returns true if the value lies strictly between 0 and 1. */
    private static boolean isOpenUnitInterval(double value) {
        return value > 0.0 && value < 1.0;
    }

    /** Scales the weight by {@code decay} when it lies in {@code (0.9, 1)}. */
    private void applyWeightDecay(double decay) {
        if (decay > 0.9 && decay < 1.0) {
            this.weight.scale(decay);
        }
    }

    /**
     * Returns a hidden layer builder for the named activation function.
     * Recognized names (case-insensitive) are {@code relu}, {@code sigmoid},
     * {@code tanh}, {@code linear} and {@code leaky}.
     *
     * @param activation the name of the activation function.
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @param param the optional activation parameter; only used by
     *              {@code leaky}, where {@code NaN} selects the default.
     * @return the layer builder.
     * @throws IllegalArgumentException if the activation name is not supported.
     */
    public static HiddenLayerBuilder builder(String activation, int neurons, double dropout, double param) {
        switch (activation.toLowerCase(Locale.ROOT)) {
            case "relu":
                return Layer.rectifier(neurons, dropout);
            case "sigmoid":
                return Layer.sigmoid(neurons, dropout);
            case "tanh":
                return Layer.tanh(neurons, dropout);
            case "linear":
                return Layer.linear(neurons, dropout);
            case "leaky":
                return Double.isNaN(param)
                        ? Layer.leaky(neurons, dropout)
                        : Layer.leaky(neurons, dropout, param);
            default:
                throw new IllegalArgumentException("Unsupported activation function: " + activation);
        }
    }

    /**
     * Returns an input layer builder without dropout.
     *
     * @param neurons the number of neurons.
     * @return the layer builder.
     */
    public static LayerBuilder input(int neurons) {
        return Layer.input(neurons, 0.0);
    }

    /**
     * Returns an input layer builder.
     *
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @return the layer builder.
     */
    public static LayerBuilder input(int neurons, double dropout) {
        return new LayerBuilder(neurons, dropout){

            @Override
            public InputLayer build(int p) {
                // The argument is ignored: an input layer has no lower layer.
                return new InputLayer(this.neurons, this.dropout);
            }
        };
    }

    /**
     * Returns a hidden layer builder with linear activation and no dropout.
     *
     * @param neurons the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder linear(int neurons) {
        return Layer.linear(neurons, 0.0);
    }

    /**
     * Returns a hidden layer builder with linear activation.
     *
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder linear(int neurons, double dropout) {
        return new HiddenLayerBuilder(neurons, dropout, ActivationFunction.linear());
    }

    /**
     * Returns a hidden layer builder with rectifier activation and no dropout.
     *
     * @param neurons the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder rectifier(int neurons) {
        return Layer.rectifier(neurons, 0.0);
    }

    /**
     * Returns a hidden layer builder with rectifier activation.
     *
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder rectifier(int neurons, double dropout) {
        return new HiddenLayerBuilder(neurons, dropout, ActivationFunction.rectifier());
    }

    /**
     * Returns a hidden layer builder with leaky rectifier activation (default
     * parameter of {@code ActivationFunction.leaky()}) and no dropout.
     *
     * @param neurons the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder leaky(int neurons) {
        // This used to delegate to rectifier(), which built a plain rectifier
        // layer instead of a leaky one.
        return Layer.leaky(neurons, 0.0);
    }

    /**
     * Returns a hidden layer builder with leaky rectifier activation, using
     * the default parameter of {@code ActivationFunction.leaky()}.
     *
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder leaky(int neurons, double dropout) {
        return new HiddenLayerBuilder(neurons, dropout, ActivationFunction.leaky());
    }

    /**
     * Returns a hidden layer builder with leaky rectifier activation.
     *
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @param a the parameter passed to {@code ActivationFunction.leaky(a)}.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder leaky(int neurons, double dropout, double a) {
        return new HiddenLayerBuilder(neurons, dropout, ActivationFunction.leaky(a));
    }

    /**
     * Returns a hidden layer builder with sigmoid activation and no dropout.
     *
     * @param neurons the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder sigmoid(int neurons) {
        return Layer.sigmoid(neurons, 0.0);
    }

    /**
     * Returns a hidden layer builder with sigmoid activation.
     *
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder sigmoid(int neurons, double dropout) {
        return new HiddenLayerBuilder(neurons, dropout, ActivationFunction.sigmoid());
    }

    /**
     * Returns a hidden layer builder with tanh activation and no dropout.
     *
     * @param neurons the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder tanh(int neurons) {
        return Layer.tanh(neurons, 0.0);
    }

    /**
     * Returns a hidden layer builder with tanh activation.
     *
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder tanh(int neurons, double dropout) {
        return new HiddenLayerBuilder(neurons, dropout, ActivationFunction.tanh());
    }

    /**
     * Returns an output layer builder with the mean squared error cost.
     *
     * @param neurons the number of neurons.
     * @param output the output function.
     * @return the layer builder.
     */
    public static OutputLayerBuilder mse(int neurons, OutputFunction output) {
        return new OutputLayerBuilder(neurons, output, Cost.MEAN_SQUARED_ERROR);
    }

    /**
     * Returns an output layer builder with the likelihood cost.
     *
     * @param neurons the number of neurons.
     * @param output the output function.
     * @return the layer builder.
     */
    public static OutputLayerBuilder mle(int neurons, OutputFunction output) {
        return new OutputLayerBuilder(neurons, output, Cost.LIKELIHOOD);
    }

    /**
     * Parses a network specification into layer builders.
     *
     * <p>The spec is a {@code |}-separated list of layers of the form
     * {@code name(neurons[, dropout[, param]])}, e.g.
     * {@code "input(10, 0.2)|relu(50, 0.5)|sigmoid(30)"}. Whitespace around
     * each layer is ignored. If the first layer is not named {@code input}
     * (case-insensitive), an input layer of {@code p} neurons is prepended.
     *
     * <p>An output layer is always appended, chosen by {@code k}:
     * <ul>
     * <li>{@code k < 2}: one linear unit with mean squared error cost;</li>
     * <li>{@code k == 2}: one sigmoid unit with likelihood cost;</li>
     * <li>{@code k > 2}: {@code k} softmax units with likelihood cost.</li>
     * </ul>
     *
     * @param k selects the output layer as described above (presumably the
     *          number of classes).
     * @param p the size of the input layer when the spec does not name one.
     * @param spec the network specification.
     * @return the layer builders, input layer first and output layer last.
     * @throws IllegalArgumentException if a layer cannot be parsed or names an
     *         unsupported activation function.
     */
    public static LayerBuilder[] of(int k, int p, String spec) {
        String[] layers = spec.split("\\|");
        ArrayList<LayerBuilder> builders = new ArrayList<>();
        for (int i = 0; i < layers.length; ++i) {
            // Trimming only affects specs that were rejected before: the
            // pattern itself starts and ends with non-space characters.
            Matcher m = LAYER_SPEC.matcher(layers[i].trim());
            if (!m.matches()) {
                throw new IllegalArgumentException("Invalid layer: " + layers[i]);
            }
            String activation = m.group(1);
            int neurons = Integer.parseInt(m.group(2));
            double dropout = m.group(3) != null ? Double.parseDouble(m.group(4)) : 0.0;
            double param = m.group(5) != null ? Double.parseDouble(m.group(6)) : Double.NaN;
            if (i == 0) {
                if (activation.equalsIgnoreCase("input")) {
                    builders.add(Layer.input(neurons, dropout));
                    continue;
                }
                // No explicit input layer: synthesize one of size p.
                builders.add(Layer.input(p));
            }
            builders.add(Layer.builder(activation, neurons, dropout, param));
        }
        if (k < 2) {
            builders.add(Layer.mse(1, OutputFunction.LINEAR));
        } else if (k == 2) {
            builders.add(Layer.mle(1, OutputFunction.SIGMOID));
        } else {
            builders.add(Layer.mle(k, OutputFunction.SOFTMAX));
        }
        return builders.toArray(new LayerBuilder[0]);
    }
}

