/*
 * Decompiled with CFR 0.152.
 */
package smile.base.mlp;

import java.io.Serializable;
import smile.base.mlp.ActivationFunction;
import smile.base.mlp.Cost;
import smile.base.mlp.HiddenLayerBuilder;
import smile.base.mlp.OutputFunction;
import smile.base.mlp.OutputLayerBuilder;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.Matrix;
import smile.stat.distribution.GaussianDistribution;

/**
 * Base class of a fully-connected network layer with {@code n} output units
 * fed by {@code p} inputs plus one bias/intercept input.
 *
 * <p>The weight matrix has {@code n} rows and {@code p + 1} columns. Column
 * {@code p} holds the bias weights and is paired with an input element that
 * must equal 1 (see {@link #propagate(double[])}). Concrete subclasses provide
 * the element-wise function {@link #f(double[])} and the back-propagation step.
 * This class never allocates the {@code output} or {@code gradient} arrays.
 * NOTE(review): presumably subclasses allocate them; confirm.
 */
public abstract class Layer implements Serializable {
    private static final long serialVersionUID = 2L;

    /** The number of output units (rows of {@code weight}). */
    protected int n;
    /** The number of inputs, not counting the bias input. */
    protected int p;
    /** The output vector written by {@link #propagate(double[])}; not allocated here. */
    protected double[] output;
    /** The per-unit gradient read by {@link #computeUpdate}; not allocated here. */
    protected double[] gradient;
    /** The {@code n x (p + 1)} weight matrix; the last column is the bias. */
    protected DenseMatrix weight;
    /** The most recent raw step {@code eta * gradient[i] * x[j]}, excluding momentum. */
    protected DenseMatrix delta;
    /** The pending weight change: the raw step plus any momentum term carried over. */
    protected DenseMatrix update;

    /**
     * Creates a layer and randomly initializes its non-bias weights.
     *
     * <p>Each weight in columns {@code 0..p-1} is a draw from
     * {@code GaussianDistribution.getInstance()} scaled by {@code sqrt(2 / p)}.
     * The bias column and the {@code delta} and {@code update} matrices start
     * at zero.
     *
     * @param n the number of output units.
     * @param p the number of inputs, not counting the bias.
     */
    public Layer(int n, int p) {
        this.n = n;
        this.p = p;

        // One extra column stores the bias/intercept weights.
        final int columns = p + 1;
        weight = Matrix.zeros(n, columns);
        delta = Matrix.zeros(n, columns);
        update = Matrix.zeros(n, columns);

        final GaussianDistribution normal = GaussianDistribution.getInstance();
        final double scale = Math.sqrt(2.0 / p);
        // Column-outer, row-inner order is kept deliberately, so that successive
        // draws from the shared generator land on the same weights as before.
        for (int col = 0; col < p; col++) {
            for (int row = 0; row < n; row++) {
                weight.set(row, col, scale * normal.rand());
            }
        }
    }

    /**
     * Returns the number of output units.
     *
     * @return {@code n}.
     */
    public int getOutputSize() {
        return n;
    }

    /**
     * Returns the number of inputs, not counting the bias input.
     *
     * @return {@code p}.
     */
    public int getInputSize() {
        return p;
    }

    /**
     * Returns the output vector. This is the internal array itself, not a copy.
     *
     * @return the output array, or {@code null} if no subclass allocated it.
     */
    public double[] output() {
        return output;
    }

    /**
     * Returns the gradient vector. This is the internal array itself, not a copy.
     *
     * @return the gradient array, or {@code null} if no subclass allocated it.
     */
    public double[] gradient() {
        return gradient;
    }

    /**
     * Computes this layer's output for the given input.
     *
     * <p>The weight matrix is applied to {@code x} with the result stored in
     * {@code output}. NOTE(review): this assumes {@code DenseMatrix.ax(x, y)}
     * computes {@code y = A * x}. {@link #f(double[])} is then applied to
     * {@code output}.
     *
     * @param x the input vector indexed {@code 0..p}. Element {@code p} is the
     *          bias input and must be 1; this is checked only when assertions
     *          are enabled.
     */
    public void propagate(double[] x) {
        assert x[p] == 1.0 : "bias/intercept is not 1";
        weight.ax(x, output);
        f(output);
    }

    /**
     * Applies this layer's activation/output function to {@code x}.
     * {@link #propagate(double[])} calls it on the freshly computed output
     * vector and ignores any return, so implementations must transform
     * {@code x} in place.
     *
     * @param x the vector to transform.
     */
    public abstract void f(double[] x);

    /**
     * Back-propagates through this layer. The meaning of the argument is
     * defined by subclasses. NOTE(review): presumably it is the lower layer's
     * gradient buffer; confirm against implementations.
     *
     * @param vector the subclass-defined vector.
     */
    public abstract void backpropagate(double[] vector);

    /**
     * Computes the pending weight change from the current {@code gradient}.
     *
     * <p>For every weight {@code (i, j)}, the raw step
     * {@code eta * gradient[i] * x[j]} is stored in {@code delta}. When
     * {@code alpha > 0}, {@code alpha} times the previous {@code update} entry is
     * added to the step before it is stored in {@code update}.
     *
     * @param eta   the learning rate.
     * @param alpha the momentum factor; values {@code <= 0} disable momentum.
     * @param x     the input vector indexed {@code 0..p}, including the bias element.
     */
    public void computeUpdate(double eta, double alpha, double[] x) {
        final boolean momentum = alpha > 0.0;
        for (int col = 0; col <= p; col++) {
            final double input = x[col];
            for (int row = 0; row < n; row++) {
                final double step = eta * gradient[row] * input;
                delta.set(row, col, step);
                // A conditional (not "+ 0.0") keeps the stored value bit-identical when momentum is off.
                update.set(row, col, momentum ? step + alpha * update.get(row, col) : step);
            }
        }
    }

    /**
     * Applies the pending weight change and then optionally decays the weights.
     *
     * <p>{@code update} is added to {@code weight}. If {@code lambda < 1}, every
     * non-bias weight is multiplied by {@code lambda}. If {@code alpha == 1.0},
     * {@code update} is reset to zero afterwards. NOTE(review): presumably
     * {@code alpha == 1.0} marks updates accumulated across a mini-batch by
     * {@link #computeUpdate}, which must not be re-applied; confirm with callers.
     *
     * @param alpha  the momentum factor used in {@link #computeUpdate}.
     * @param lambda the weight-decay factor; values {@code >= 1} disable decay.
     */
    public void update(double alpha, double lambda) {
        weight.add(update);
        if (lambda < 1.0) {
            decayWeights(lambda);
        }
        if (alpha == 1.0) {
            update.fill(0.0);
        }
    }

    /** Multiplies every non-bias weight (columns {@code 0..p-1}) by {@code factor}. */
    private void decayWeights(double factor) {
        for (int col = 0; col < p; col++) {
            for (int row = 0; row < n; row++) {
                weight.mul(row, col, factor);
            }
        }
    }

    /**
     * Returns a builder for a hidden layer with the linear activation function.
     *
     * @param n the number of units.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder linear(int n) {
        final ActivationFunction activation = ActivationFunction.linear();
        return new HiddenLayerBuilder(n, activation);
    }

    /**
     * Returns a builder for a hidden layer with the rectifier activation function.
     *
     * @param n the number of units.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder rectifier(int n) {
        final ActivationFunction activation = ActivationFunction.rectifier();
        return new HiddenLayerBuilder(n, activation);
    }

    /**
     * Returns a builder for a hidden layer with the sigmoid activation function.
     *
     * @param n the number of units.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder sigmoid(int n) {
        final ActivationFunction activation = ActivationFunction.sigmoid();
        return new HiddenLayerBuilder(n, activation);
    }

    /**
     * Returns a builder for a hidden layer with the tanh activation function.
     *
     * @param n the number of units.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder tanh(int n) {
        final ActivationFunction activation = ActivationFunction.tanh();
        return new HiddenLayerBuilder(n, activation);
    }

    /**
     * Returns a builder for an output layer that uses the
     * {@code Cost.MEAN_SQUARED_ERROR} cost.
     *
     * @param n the number of output units.
     * @param f the output function.
     * @return the layer builder.
     */
    public static OutputLayerBuilder mse(int n, OutputFunction f) {
        final Cost cost = Cost.MEAN_SQUARED_ERROR;
        return new OutputLayerBuilder(n, f, cost);
    }

    /**
     * Returns a builder for an output layer that uses the {@code Cost.LIKELIHOOD}
     * cost.
     *
     * @param n the number of output units.
     * @param f the output function.
     * @return the layer builder.
     */
    public static OutputLayerBuilder mle(int n, OutputFunction f) {
        final Cost cost = Cost.LIKELIHOOD;
        return new OutputLayerBuilder(n, f, cost);
    }
}

