/*
 * Decompiled with CFR 0.152.
 */
package smile.base.mlp;

import java.util.Locale;
import java.util.Objects;
import smile.base.mlp.ActivationFunction;
import smile.base.mlp.Layer;
import smile.tensor.Vector;

/**
 * A hidden layer of a multilayer perceptron whose raw outputs are passed
 * through an {@link ActivationFunction}.
 */
public class HiddenLayer
extends Layer {
    private static final long serialVersionUID = 2L;
    /** The activation function applied to this layer's outputs. */
    private final ActivationFunction activation;

    /**
     * Constructor.
     *
     * @param n presumably the number of neurons (output size) of this layer;
     *          forwarded to {@code Layer} — TODO confirm against Layer.
     * @param p presumably the number of inputs to this layer; forwarded to
     *          {@code Layer} — TODO confirm against Layer.
     * @param dropout the dropout rate, forwarded to {@code Layer}.
     * @param activation the activation function; must not be null.
     * @throws NullPointerException if {@code activation} is null.
     */
    public HiddenLayer(int n, int p, double dropout, ActivationFunction activation) {
        super(n, p, dropout);
        // Fail fast: a null activation would otherwise only surface later as an
        // NPE inside toString/transform/backpropagate.
        this.activation = Objects.requireNonNull(activation, "activation");
    }

    /**
     * Returns a compact description such as {@code "NAME(n)"}, or
     * {@code "NAME(n, dropout)"} when dropout is positive.
     *
     * @return the layer description.
     */
    @Override
    public String toString() {
        if (this.dropout > 0.0) {
            // Locale.ROOT keeps the decimal separator a '.', so the output does not
            // depend on the JVM default locale and the ", " argument separator stays
            // unambiguous (e.g. "0.50" rather than "0,50").
            return String.format(Locale.ROOT, "%s(%d, %.2f)", this.activation.name(), this.n, this.dropout);
        }
        return String.format(Locale.ROOT, "%s(%d)", this.activation.name(), this.n);
    }

    /**
     * Applies the activation function to {@code x}.
     *
     * @param x the layer's raw output vector; presumably modified in place by
     *          {@code ActivationFunction.f} (this method returns nothing) — TODO confirm.
     */
    @Override
    public void transform(Vector x) {
        this.activation.f(x);
    }

    /**
     * Propagates the error gradient back through this layer.
     *
     * <p>The output gradient is first adjusted by the activation function via
     * {@code activation.g(outputGradient, output)}. This presumably multiplies
     * by the activation derivative in place — TODO confirm. When
     * {@code lowerLayerGradient} is non-null, {@code weight.tv(outputGradient,
     * lowerLayerGradient)} then writes the gradient for the layer below. That
     * call presumably computes the transposed weight-vector product.
     *
     * @param lowerLayerGradient the buffer that receives the gradient for the
     *                           lower layer, or {@code null} to skip that step.
     */
    @Override
    public void backpropagate(Vector lowerLayerGradient) {
        // output/outputGradient are fetched via get(); presumably per-thread
        // buffers declared in Layer — TODO confirm.
        Vector output = (Vector)this.output.get();
        Vector outputGradient = (Vector)this.outputGradient.get();
        this.activation.g(outputGradient, output);
        if (lowerLayerGradient != null) {
            this.weight.tv(outputGradient, lowerLayerGradient);
        }
    }
}

