package smile.base.mlp;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Arrays;
import smile.math.MathEx;
import smile.math.matrix.Matrix;

/* loaded from: input_file:smile/base/mlp/Layer.class */
/**
 * Base class for a single layer of a multilayer perceptron.
 *
 * <p>A layer owns an {@code n x p} weight matrix and an {@code n}-element bias
 * vector, where {@code n} is the number of outputs and {@code p} the number of
 * inputs. The weights and biases are plain fields shared by every thread that
 * uses the layer; all working buffers (output, output gradient, accumulated
 * gradients, RMS statistics and momentum terms) are held in
 * {@link ThreadLocal}s so each thread gets its own copy on first access.
 *
 * <p>The working buffers are {@code transient}: they are not serialized and are
 * recreated empty (zero-filled) after deserialization.
 *
 * <p>Optional features of the update rules (momentum, weight decay, RMS
 * scaling) are enabled only when the corresponding hyper-parameter lies
 * strictly inside its valid open interval. Any other value, including
 * {@code NaN}, disables the feature.
 */
public abstract class Layer implements Serializable {
    private static final long serialVersionUID = 2L;

    /** Weight decay is applied only for factors strictly between this value and 1. */
    private static final double MIN_DECAY = 0.9;

    /** The number of outputs (rows of {@link #weight}). */
    protected int n;
    /** The number of inputs (columns of {@link #weight}). */
    protected int p;
    /** The {@code n x p} weight matrix. */
    protected Matrix weight;
    /** The bias vector; the first {@code n} entries are used. */
    protected double[] bias;
    /** Per-thread output buffer of length {@code n}. */
    protected transient ThreadLocal<double[]> output;
    /** Per-thread buffer of length {@code n} holding the gradient with respect to the output. */
    protected transient ThreadLocal<double[]> outputGradient;
    /** Per-thread {@code n x p} accumulator of weight gradients (mini-batch mode). */
    protected transient ThreadLocal<Matrix> weightGradient;
    /** Per-thread accumulator of bias gradients (mini-batch mode). */
    protected transient ThreadLocal<double[]> biasGradient;
    /** Per-thread running average of squared weight gradients. */
    protected transient ThreadLocal<Matrix> rmsWeightGradient;
    /** Per-thread running average of squared bias gradients. */
    protected transient ThreadLocal<double[]> rmsBiasGradient;
    /** Per-thread previous weight update, used as the momentum term. */
    protected transient ThreadLocal<Matrix> weightUpdate;
    /** Per-thread previous bias update, used as the momentum term. */
    protected transient ThreadLocal<double[]> biasUpdate;

    /**
     * Creates a layer with randomly initialised weights and zero biases.
     *
     * <p>The weights are produced by {@code Matrix.rand} with the bounds
     * {@code -r} and {@code r}, where {@code r = sqrt(6 / (n + p))}.
     *
     * @param n the number of outputs.
     * @param p the number of inputs.
     */
    public Layer(int n, int p) {
        this(randomWeights(n, p), new double[n]);
    }

    /**
     * Creates a layer from an existing weight matrix and bias vector. Both
     * arguments are stored by reference, not copied.
     *
     * @param weight the weight matrix; its row and column counts become the
     *               output and input sizes of the layer.
     * @param bias   the bias vector.
     */
    public Layer(Matrix weight, double[] bias) {
        this.n = weight.nrows();
        this.p = weight.ncols();
        this.weight = weight;
        this.bias = bias;
        init();
    }

    /** Builds the initial weight matrix for the {@code (int, int)} constructor. */
    private static Matrix randomWeights(int n, int p) {
        double r = Math.sqrt(6.0 / (n + p));
        return Matrix.rand(n, p, -r, r);
    }

    /**
     * Restores the serialized fields and then recreates the transient
     * per-thread buffers, which are not part of the serialized form.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        init();
    }

    /** (Re)creates every per-thread working buffer. */
    private void init() {
        output = vectorBuffer();
        outputGradient = vectorBuffer();
        weightGradient = matrixBuffer();
        biasGradient = vectorBuffer();
        rmsWeightGradient = matrixBuffer();
        rmsBiasGradient = vectorBuffer();
        weightUpdate = matrixBuffer();
        biasUpdate = vectorBuffer();
    }

    /**
     * Returns a thread-local that lazily allocates a zero-filled vector of
     * length {@code n} for each thread. The size is read when the buffer is
     * first requested, not when the thread-local is created.
     */
    private ThreadLocal<double[]> vectorBuffer() {
        return new ThreadLocal<double[]>() {
            // No synchronization needed: this runs on the accessing thread and
            // only allocates a fresh array.
            @Override
            protected double[] initialValue() {
                return new double[Layer.this.n];
            }
        };
    }

    /**
     * Returns a thread-local that lazily allocates an {@code n x p} matrix for
     * each thread. The shape is read when the buffer is first requested.
     */
    private ThreadLocal<Matrix> matrixBuffer() {
        return new ThreadLocal<Matrix>() {
            @Override
            protected Matrix initialValue() {
                return new Matrix(Layer.this.n, Layer.this.p);
            }
        };
    }

    /**
     * Tells whether {@code value} lies strictly between {@code low} and
     * {@code high}. Returns {@code false} for {@code NaN}, so an undefined
     * hyper-parameter never switches a feature on.
     */
    private static boolean strictlyBetween(double value, double low, double high) {
        return value > low && value < high;
    }

    /**
     * Returns the number of outputs of this layer.
     *
     * @return the output dimension {@code n}.
     */
    public int getOutputSize() {
        return n;
    }

    /**
     * Returns the number of inputs of this layer.
     *
     * @return the input dimension {@code p}.
     */
    public int getInputSize() {
        return p;
    }

    /**
     * Returns the calling thread's output buffer. The array is the live
     * buffer, not a copy.
     *
     * @return the output vector of length {@code n}.
     */
    public double[] output() {
        return output.get();
    }

    /**
     * Returns the calling thread's output gradient buffer. The array is the
     * live buffer, not a copy.
     *
     * @return the output gradient vector of length {@code n}.
     */
    public double[] gradient() {
        return outputGradient.get();
    }

    /**
     * Computes this layer's output for the given input and stores it in the
     * calling thread's output buffer.
     *
     * <p>The buffer is first loaded with the bias, then the weighted input is
     * added via {@code weight.mv(1.0, x, 1.0, out)}, and finally {@link #f}
     * is applied to the buffer.
     *
     * @param x the input vector.
     */
    public void propagate(double[] x) {
        double[] out = output.get();
        System.arraycopy(bias, 0, out, 0, n);
        weight.mv(1.0, x, 1.0, out);
        f(out);
    }

    /**
     * Applies the layer's transfer function. {@link #propagate} calls this
     * with the output buffer holding the weighted input plus bias and uses
     * the buffer afterwards as the layer output, so implementations are
     * expected to transform the array in place.
     *
     * @param x the pre-activation values, to be overwritten with the output.
     */
    public abstract void f(double[] x);

    /**
     * Propagates the error to the layer below. The contract is defined by the
     * subclasses; NOTE(review): per the Smile API documentation the argument
     * is the gradient buffer of the lower layer - confirm against subclasses.
     *
     * @param lowerLayerGradient the gradient vector of the lower layer.
     */
    public abstract void backpropagate(double[] lowerLayerGradient);

    /**
     * Updates the weights and biases immediately from the calling thread's
     * current output gradient (single-sample update).
     *
     * <p>Without momentum, {@code learningRate * gradient * x'} is added to
     * the weights and {@code learningRate * gradient} to the biases. With
     * momentum, the previous update scaled by {@code momentum} is added to
     * that step, remembered for the next call, and then applied.
     *
     * @param x            the input vector of this layer.
     * @param learningRate the factor applied to the gradient.
     * @param momentum     the momentum factor; used only when strictly
     *                     between 0 and 1.
     * @param decay        the weight decay factor; the weights (not the
     *                     biases) are multiplied by it only when it is
     *                     strictly between 0.9 and 1.
     */
    public void computeGradientUpdate(double[] x, double learningRate, double momentum, double decay) {
        double[] delta = outputGradient.get();

        if (strictlyBetween(momentum, 0.0, 1.0)) {
            Matrix previousWeightStep = weightUpdate.get();
            double[] previousBiasStep = biasUpdate.get();

            // step = momentum * previous step + learningRate * delta * x'
            previousWeightStep.mul(momentum);
            previousWeightStep.add(learningRate, delta, x);
            weight.add(1.0, previousWeightStep);

            for (int i = 0; i < n; i++) {
                double step = (momentum * previousBiasStep[i]) + (learningRate * delta[i]);
                previousBiasStep[i] = step;
                bias[i] += step;
            }
        } else {
            weight.add(learningRate, delta, x);
            for (int i = 0; i < n; i++) {
                bias[i] += learningRate * delta[i];
            }
        }

        if (strictlyBetween(decay, MIN_DECAY, 1.0)) {
            weight.mul(decay);
        }
    }

    /**
     * Adds the contribution of one sample to the calling thread's accumulated
     * gradients without touching the weights (mini-batch mode). The
     * accumulated values are consumed and reset by {@link #update}.
     *
     * @param x the input vector of this layer.
     */
    public void computeGradient(double[] x) {
        double[] delta = outputGradient.get();
        double[] biasSum = biasGradient.get();

        weightGradient.get().add(1.0, delta, x);
        for (int i = 0; i < n; i++) {
            biasSum[i] += delta[i];
        }
    }

    /**
     * Applies the calling thread's accumulated mini-batch gradients to the
     * weights and biases, then resets the accumulators to zero.
     *
     * @param m            the divisor used to average the accumulated
     *                     gradients (the mini-batch size).
     * @param learningRate the factor applied to the averaged gradient.
     * @param momentum     the momentum factor; used only when strictly
     *                     between 0 and 1.
     * @param decay        the weight decay factor; the weights (not the
     *                     biases) are multiplied by it only when it is
     *                     strictly between 0.9 and 1.
     * @param rho          the discounting factor of the running average of
     *                     squared gradients; RMS scaling is used only when it
     *                     is strictly between 0 and 1.
     * @param epsilon      a constant added to the running average before
     *                     taking the square root.
     */
    public void update(int m, double learningRate, double momentum, double decay, double rho, double epsilon) {
        Matrix gradW = weightGradient.get();
        double[] gradB = biasGradient.get();

        // By default the accumulated gradient is not averaged explicitly;
        // the learning rate is divided by m instead.
        double eta = learningRate / m;

        if (strictlyBetween(rho, 0.0, 1.0)) {
            // The gradient is averaged explicitly below, so the learning rate
            // is used unscaled.
            eta = learningRate;
            gradW.div(m);
            for (int i = 0; i < n; i++) {
                gradB[i] /= m;
            }
            rescaleByRunningRms(gradW, gradB, rho, epsilon);
        }

        if (strictlyBetween(momentum, 0.0, 1.0)) {
            Matrix previousWeightStep = weightUpdate.get();
            double[] previousBiasStep = biasUpdate.get();

            // step = momentum * previous step + eta * gradient
            previousWeightStep.add(momentum, eta, gradW);
            for (int i = 0; i < n; i++) {
                previousBiasStep[i] = (momentum * previousBiasStep[i]) + (eta * gradB[i]);
            }

            weight.add(1.0, previousWeightStep);
            MathEx.add(bias, previousBiasStep);
        } else {
            weight.add(eta, gradW);
            for (int i = 0; i < n; i++) {
                bias[i] += eta * gradB[i];
            }
        }

        if (strictlyBetween(decay, MIN_DECAY, 1.0)) {
            weight.mul(decay);
        }

        // Start the next mini-batch from empty accumulators.
        gradW.fill(0.0);
        Arrays.fill(gradB, 0.0);
    }

    /**
     * Updates the calling thread's running averages of squared gradients and
     * divides each gradient entry, in place, by the square root of
     * {@code epsilon} plus its running average.
     */
    private void rescaleByRunningRms(Matrix gradW, double[] gradB, double rho, double epsilon) {
        Matrix rmsW = rmsWeightGradient.get();
        double[] rmsB = rmsBiasGradient.get();
        double complement = 1.0 - rho;

        // Running average: rms = rho * rms + (1 - rho) * gradient^2
        for (int col = 0; col < p; col++) {
            for (int row = 0; row < n; row++) {
                double squared = MathEx.sqr(gradW.get(row, col));
                rmsW.set(row, col, (rho * rmsW.get(row, col)) + (complement * squared));
            }
        }
        for (int i = 0; i < n; i++) {
            rmsB[i] = (rho * rmsB[i]) + (complement * MathEx.sqr(gradB[i]));
        }

        // Scale: gradient = gradient / sqrt(epsilon + rms)
        for (int col = 0; col < p; col++) {
            for (int row = 0; row < n; row++) {
                gradW.div(row, col, Math.sqrt(epsilon + rmsW.get(row, col)));
            }
        }
        for (int i = 0; i < n; i++) {
            gradB[i] /= Math.sqrt(epsilon + rmsB[i]);
        }
    }

    /**
     * Returns a hidden layer builder using the linear activation function.
     *
     * @param n the layer size forwarded to the builder (number of neurons).
     * @return the hidden layer builder.
     */
    public static HiddenLayerBuilder linear(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.linear());
    }

    /**
     * Returns a hidden layer builder using the rectifier activation function.
     *
     * @param n the layer size forwarded to the builder (number of neurons).
     * @return the hidden layer builder.
     */
    public static HiddenLayerBuilder rectifier(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.rectifier());
    }

    /**
     * Returns a hidden layer builder using the sigmoid activation function.
     *
     * @param n the layer size forwarded to the builder (number of neurons).
     * @return the hidden layer builder.
     */
    public static HiddenLayerBuilder sigmoid(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.sigmoid());
    }

    /**
     * Returns a hidden layer builder using the hyperbolic tangent activation
     * function.
     *
     * @param n the layer size forwarded to the builder (number of neurons).
     * @return the hidden layer builder.
     */
    public static HiddenLayerBuilder tanh(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.tanh());
    }

    /**
     * Returns an output layer builder using the mean squared error cost.
     *
     * @param n the layer size forwarded to the builder (number of neurons).
     * @param f the output function.
     * @return the output layer builder.
     */
    public static OutputLayerBuilder mse(int n, OutputFunction f) {
        return new OutputLayerBuilder(n, f, Cost.MEAN_SQUARED_ERROR);
    }

    /**
     * Returns an output layer builder using the likelihood cost.
     *
     * @param n the layer size forwarded to the builder (number of neurons).
     * @param f the output function.
     * @return the output layer builder.
     */
    public static OutputLayerBuilder mle(int n, OutputFunction f) {
        return new OutputLayerBuilder(n, f, Cost.LIKELIHOOD);
    }
}
