package org.deeplearning4j.nn.layers.feedforward;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.PReLULayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.impl.ActivationPReLU;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Parametric ReLU (PReLU) feed-forward layer.
 *
 * <p>The learnable slope parameters are stored under the {@code "W"} parameter key and applied
 * through {@link ActivationPReLU}, together with the shared axes taken from
 * {@link PReLULayer#getSharedAxes()}.
 */
public class PReLU extends BaseLayer<PReLULayer> {

    /** Parameter (and gradient-view) key holding the learnable PReLU slopes. */
    private static final String WEIGHT_KEY = "W";

    /** Shared axes passed to {@link ActivationPReLU}; read from the layer configuration. */
    long[] axes;

    /**
     * Creates the layer and caches the shared axes from its {@link PReLULayer} configuration.
     *
     * @param neuralNetConfiguration network configuration wrapping a {@link PReLULayer}
     * @param dataType data type handed to the base layer
     */
    public PReLU(NeuralNetConfiguration neuralNetConfiguration, DataType dataType) {
        super(neuralNetConfiguration, dataType);
        this.axes = layerConf().getSharedAxes();
    }

    /** @return {@link Layer.Type#FEED_FORWARD} */
    @Override
    public Layer.Type type() {
        return Layer.Type.FEED_FORWARD;
    }

    /**
     * Applies dropout (if configured) and then the PReLU activation to the current input.
     *
     * @param training whether the forward pass is for training
     * @param workspaceMgr workspace manager used to place the activation array
     * @return the activated array in the {@link ArrayType#ACTIVATIONS} workspace
     */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(false);
        applyDropOutIfNecessary(training, workspaceMgr);

        ActivationPReLU prelu = newActivation();
        INDArray in;
        if (training) {
            // backpropGradient reads this.input again, so hand the activation a copy.
            // NOTE(review): assumes getActivation may modify its argument in place; confirm.
            in = workspaceMgr.dup(ArrayType.ACTIVATIONS, this.input, this.input.ordering());
        } else {
            // Inference: no backward pass follows, so avoid the copy.
            in = workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, this.input);
        }
        return prelu.getActivation(in, training);
    }

    /**
     * Computes the gradient of the PReLU slopes and the epsilon to pass to the previous layer.
     *
     * <p>The slope gradient is written into this layer's {@code "W"} gradient view.
     *
     * @param epsilon gradient of the loss with respect to this layer's output
     * @param workspaceMgr workspace manager used for the activation-gradient arrays
     * @return the parameter gradient and the (dropout-adjusted) gradient w.r.t. the input
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(true);

        ActivationPReLU prelu = newActivation();
        INDArray layerInput =
                workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, this.input, this.input.ordering());
        // First = gradient w.r.t. the input, second = gradient w.r.t. the slopes.
        Pair<INDArray, INDArray> deltas = prelu.backprop(layerInput, epsilon);
        INDArray inputGrad = deltas.getFirst();
        INDArray slopeGrad = deltas.getSecond();

        INDArray slopeGradView = this.gradientViews.get(WEIGHT_KEY);
        slopeGradView.assign(slopeGrad);

        INDArray epsilonOut = backpropDropOutIfPresent(
                workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, inputGrad));

        DefaultGradient gradient = new DefaultGradient();
        gradient.setGradientFor(WEIGHT_KEY, slopeGradView, 'c');
        return new Pair<>(gradient, epsilonOut);
    }

    /** @return {@code false}; this layer has no unsupervised pretraining phase */
    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /** Builds the PReLU activation from the current slope parameters and shared axes. */
    private ActivationPReLU newActivation() {
        return new ActivationPReLU(getParam(WEIGHT_KEY), this.axes);
    }
}
