package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Quad;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/SimpleRnn.class */
/**
 * Layer implementation of a simple recurrent network.
 *
 * <p>For each time step {@code t} the forward pass computes
 * {@code out[t] = activation(in[t] * W + out[t-1] * RW + b)}. The recurrent term is omitted for the
 * first step unless a previous state is supplied. When layer normalization is enabled, the input
 * contribution is layer-normalized with the first half of gain "g" plus bias "b", and the recurrent
 * contribution is layer-normalized (no bias) with the second half of "g". The two normalized terms
 * are summed before the activation is applied.
 *
 * <p>Internally every sequence is processed in NCW layout ([minibatch, size, seqLength]).
 * {@code permuteIfNWC} converts between that layout and the configured data format at the
 * boundaries.
 */
public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn> {
    /** Key under which the last time step's activations are kept in {@code stateMap} / {@code tBpttStateMap}. */
    public static final String STATE_KEY_PREV_ACTIVATION = "prevAct";

    public SimpleRnn(NeuralNetConfiguration neuralNetConfiguration, DataType dataType) {
        super(neuralNetConfiguration, dataType);
    }

    /**
     * Runs the forward pass for {@code input}, continuing from the activations stored by the previous
     * call (if any), and stores this call's final-time-step activations for the next call.
     *
     * @param input        rank-3 input in the layer's configured data format
     * @param workspaceMgr workspace manager used for the forward pass
     * @return activations for every time step, in the layer's configured data format
     */
    @Override // org.deeplearning4j.nn.api.layers.RecurrentLayer
    public INDArray rnnTimeStep(INDArray input, LayerWorkspaceMgr workspaceMgr) {
        setInput(input, workspaceMgr);
        INDArray prevAct = stateMap.get(STATE_KEY_PREV_ACTIVATION);
        INDArray out = activateHelper(prevAct, false, false, workspaceMgr).getFirst();
        // The stored state is copied outside of any workspace scope.
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            stateMap.put(STATE_KEY_PREV_ACTIVATION, lastTimeStepActivations(out).dup());
        }
        return out;
    }

    /**
     * Runs the forward pass starting from the truncated-BPTT state stored in {@code tBpttStateMap}.
     *
     * @param input             rank-3 input in the layer's configured data format
     * @param training          whether this is a training-mode forward pass
     * @param storeLastForTBPTT if true, the final time step's activations replace the stored TBPTT state
     * @param workspaceMgr      workspace manager used for the forward pass
     * @return activations for every time step, in the layer's configured data format
     */
    @Override // org.deeplearning4j.nn.api.layers.RecurrentLayer
    public INDArray rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT,
                                                LayerWorkspaceMgr workspaceMgr) {
        setInput(input, workspaceMgr);
        INDArray prevAct = tBpttStateMap.get(STATE_KEY_PREV_ACTIVATION);
        INDArray out = activateHelper(prevAct, training, false, workspaceMgr).getFirst();
        if (storeLastForTBPTT) {
            try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, lastTimeStepActivations(out).dup());
            }
        }
        return out;
    }

    /** Full backpropagation through time; equivalent to TBPTT with no truncation. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        return tbpttBackpropGradient(epsilon, -1, workspaceMgr);
    }

    /**
     * Backpropagates {@code epsilon} (dL/dOut, in the configured data format) through time.
     *
     * <p>The forward pass is recomputed to recover pre-activations and, with layer norm, the
     * pre-normalization values. {@code gradientsFlattened} is zeroed first, and the per-parameter
     * gradient views are then accumulated step by step; those views are presumably views into
     * {@code gradientsFlattened} (confirm in the parameter initializer).
     * NOTE(review): the forward recomputation always starts from a null state and ignores
     * {@code tBpttStateMap}. Confirm this is intended for TBPTT segments after the first.
     *
     * @param epsilon         gradient w.r.t. this layer's output, rank 3, configured data format
     * @param tbpttBackLength if &gt; 0, only the last {@code tbpttBackLength} time steps are
     *                        backpropagated; otherwise all time steps are
     * @param workspaceMgr    workspace manager for working memory and the returned gradient
     * @return the parameter gradients and dL/dIn (configured data format)
     */
    @Override // org.deeplearning4j.nn.api.layers.RecurrentLayer
    public Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength,
                                                          LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(true);
        if (epsilon.ordering() != 'f' || !Shape.hasDefaultStridesForShape(epsilon)) {
            epsilon = epsilon.dup('f');
        }
        final boolean layerNorm = hasLayerNorm();
        final long nOut = ((org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn) layerConf()).getNOut();

        // Recompute the forward pass, keeping per-step pre-activations (and pre-norm values).
        // activateHelper calls applyDropOutIfNecessary before it reads this.input. The input is read
        // only afterwards so the weight gradients use the same input the forward pass used.
        Quad<INDArray, INDArray, INDArray, INDArray> fwd = activateHelper(null, true, true, workspaceMgr);
        INDArray inputNcw = permuteIfNWC(this.input.castTo(dataType));

        INDArray w = getParamWithNoise("W", true, workspaceMgr);
        INDArray rw = getParamWithNoise("RW", true, workspaceMgr);
        INDArray b = getParamWithNoise("b", true, workspaceMgr);
        INDArray g = layerNorm ? getParamWithNoise("g", true, workspaceMgr) : null;
        // Gain "g" holds [input-side gain | recurrent-side gain], each of length nOut.
        INDArray gx = g != null ? g.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(0L, nOut)) : null;
        INDArray gr = g != null ? g.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(nOut, nOut * 2)) : null;

        INDArray wGrad = gradientViews.get("W");
        INDArray rwGrad = gradientViews.get("RW");
        INDArray bGrad = gradientViews.get("b");
        INDArray gGrad = layerNorm ? gradientViews.get("g") : null;
        INDArray gxGrad = gGrad != null ? gGrad.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(0L, nOut)) : null;
        INDArray grGrad = gGrad != null ? gGrad.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(nOut, nOut * 2)) : null;

        gradientsFlattened.assign(0);

        IActivation activationFn = ((org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn) layerConf()).getActivationFn();
        long tsLength = inputNcw.size(2);
        INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, inputNcw.dataType(), inputNcw.shape(), 'f');
        long end = tbpttBackLength > 0 ? Math.max(0L, tsLength - tbpttBackLength) : 0L;
        INDArray epsilonNcw = permuteIfNWC(epsilon);

        // dL/dz of the step after the current one (i.e. t+1), propagated through RW.
        INDArray dldzNext = null;
        for (long t = tsLength - 1; t >= end; t--) {
            INDArray dldaCurrent = timeStepView(epsilonNcw, t).dup();
            INDArray aCurrent = timeStepView(fwd.getFirst(), t);
            INDArray zCurrent = timeStepView(fwd.getSecond(), t);
            INDArray nCurrent = layerNorm ? timeStepView(fwd.getThird(), t) : null;
            INDArray rCurrent = layerNorm ? timeStepView(fwd.getFourth(), t) : null;
            INDArray inCurrent = timeStepView(inputNcw, t);
            INDArray epsOutCurrent = timeStepView(epsOut, t);

            if (dldzNext != null) {
                // Add the recurrent contribution to dL/da[t], and accumulate the RW gradient a[t]^T * dL/dz[t+1].
                Nd4j.gemm(dldzNext, rw, dldaCurrent, false, true, 1.0, 1.0);
                Nd4j.gemm(aCurrent, dldzNext, rwGrad, true, false, 1.0, 1.0);
            }
            INDArray dldzCurrent = activationFn.backprop(zCurrent.dup(), dldaCurrent).getFirst();

            INDArray maskCol = null;
            if (maskArray != null) {
                // Zero errors for masked steps so they do not contribute to parameter gradients.
                maskCol = maskArray.getColumn(t, true).castTo(dataType);
                dldzCurrent.muliColumnVector(maskCol);
            }

            INDArray dldnCurrent;
            if (layerNorm) {
                dldnCurrent = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, dldzCurrent.dataType(), dldzCurrent.shape());
                INDArray gxGradCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, gGrad.dataType(), gxGrad.shape());
                INDArray bGradCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, bGrad.dataType(), bGrad.shape());
                Nd4j.getExecutioner().exec(new LayerNormBp(nCurrent, gx, b, dldzCurrent, dldnCurrent, gxGradCur, bGradCur, true, new int[]{1}));
                gxGrad.addi(gxGradCur);
                bGrad.addi(bGradCur);
            } else {
                dldnCurrent = dldzCurrent;
                bGrad.addi(dldzCurrent.sum(new int[]{0}));
            }

            // Input weight gradient (accumulated) and dL/dIn[t] (overwritten: epsOut is uninitialized, so beta = 0).
            Nd4j.gemm(inCurrent, dldnCurrent, wGrad, true, false, 1.0, 1.0);
            Nd4j.gemm(dldnCurrent, w, epsOutCurrent, false, true, 1.0, 0.0);

            if (layerNorm && t > end) {
                // Backprop through the recurrent-side layer norm before feeding into step t-1.
                dldzNext = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, dldzCurrent.dataType(), dldzCurrent.shape());
                INDArray grGradCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, gGrad.dataType(), grGrad.shape());
                Nd4j.getExecutioner().exec(new LayerNormBp(rCurrent, gr, dldzCurrent, dldzNext, grGradCur, true, new int[]{1}));
                grGrad.addi(grGradCur);
            } else {
                dldzNext = dldzCurrent;
            }

            if (maskArray != null) {
                // Send zeros to the layer below for masked steps.
                epsOutCurrent.muliColumnVector(maskCol);
            }
        }

        weightNoiseParams.clear();
        DefaultGradient gradient = new DefaultGradient(gradientsFlattened);
        gradient.gradientForVariable().put("W", wGrad);
        gradient.gradientForVariable().put("RW", rwGrad);
        gradient.gradientForVariable().put("b", bGrad);
        if (layerNorm) {
            gradient.gradientForVariable().put("g", gGrad);
        }
        return new Pair<>(gradient, permuteIfNWC(backpropDropOutIfPresent(epsOut)));
    }

    @Override // org.deeplearning4j.nn.api.Layer
    public boolean isPretrainLayer() {
        return false;
    }

    /** Forward pass from a null initial state; returns activations in the configured data format. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return activateHelper(null, training, false, workspaceMgr).getFirst();
    }

    /**
     * Core forward pass.
     *
     * @param prevStepOut  activations of the step preceding the first input step ([minibatch, nOut]),
     *                     or null to omit the recurrent term at step 0
     * @param training     training mode (affects dropout, weight noise and the activation function)
     * @param forBackprop  if true, also fills the backprop buffers. Their contents are:
     *                     second = pre-activations z, third = input-side pre-norm values,
     *                     fourth = recurrent-side pre-norm values (third and fourth only with layer norm).
     *                     All four arrays are then returned in internal NCW layout.
     * @param workspaceMgr workspace manager for activations and working memory
     * @return (activations, z, input pre-norm, recurrent pre-norm). Unused entries are null. The
     *         arrays are permuted to the configured data format when {@code forBackprop} is false.
     */
    private Quad<INDArray, INDArray, INDArray, INDArray> activateHelper(INDArray prevStepOut, boolean training,
                                                                         boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(false);
        Preconditions.checkState(this.input.rank() == 3, "3D input expected to RNN layer expected, got " + this.input.rank());
        Preconditions.checkState(prevStepOut == null || prevStepOut.size(0) == this.input.size(0),
                "Invalid RNN previous state (last time step activations/initialization): rnnTimeStep with different minibatch size, or forgot to call rnnClearPreviousState between batches?"
                        + " Previous step output = [batch, nOut] = %ndShape, current input = [batch, nIn, seqLength] = %ndShape",
                prevStepOut, this.input);
        applyDropOutIfNecessary(training, workspaceMgr);

        INDArray inputNcw = permuteIfNWC(this.input.castTo(dataType));
        long miniBatch = inputNcw.size(0);
        long tsLength = inputNcw.size(2);
        long nOut = ((org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn) layerConf()).getNOut();
        final boolean layerNorm = hasLayerNorm();

        INDArray w = getParamWithNoise("W", training, workspaceMgr);
        INDArray rw = getParamWithNoise("RW", training, workspaceMgr);
        INDArray b = getParamWithNoise("b", training, workspaceMgr);
        INDArray g = layerNorm ? getParamWithNoise("g", training, workspaceMgr) : null;
        INDArray gx = g != null ? g.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(0L, nOut)) : null;
        INDArray gr = g != null ? g.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(nOut, nOut * 2)) : null;

        INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, w.dataType(), new long[]{miniBatch, nOut, tsLength}, 'f');
        INDArray outZ = forBackprop ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, w.dataType(), out.shape()) : null;
        INDArray outPreNorm = (forBackprop && layerNorm) ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, w.dataType(), out.shape(), 'f') : null;
        INDArray recPreNorm = (forBackprop && layerNorm) ? workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, w.dataType(), out.shape(), 'f') : null;

        if (inputNcw.ordering() != 'f' || Shape.strideDescendingCAscendingF(inputNcw)) {
            inputNcw = workspaceMgr.dup(ArrayType.ACTIVATIONS, inputNcw, 'f');
        }
        if (!layerNorm) {
            // Pre-fill every step with the bias; the gemms below accumulate onto it (beta = 1).
            Nd4j.getExecutioner().exec(new BroadcastCopyOp(out, b, out, new int[]{1}));
        }

        IActivation activationFn = ((org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn) layerConf()).getActivationFn();
        for (long t = 0; t < tsLength; t++) {
            INDArray currOut = timeStepView(out, t);
            INDArray currIn = timeStepView(inputNcw, t);
            if (layerNorm) {
                INDArray currOutPreNorm = timeStepView(forBackprop ? outPreNorm : out, t);
                // Destination is uninitialized: beta = 0 overwrites it.
                Nd4j.gemm(currIn, w, currOutPreNorm, false, false, 1.0, 0.0);
                Nd4j.getExecutioner().exec(new LayerNorm(currOutPreNorm, gx, b, currOut, true, new int[]{1}));
            } else {
                Nd4j.gemm(currIn, w, currOut, false, false, 1.0, 1.0);
            }

            if (t > 0 || prevStepOut != null) {
                if (layerNorm) {
                    INDArray currRecPreNorm = forBackprop
                            ? timeStepView(recPreNorm, t)
                            : workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f');
                    Nd4j.gemm(prevStepOut, rw, currRecPreNorm, false, false, 1.0, 0.0);
                    INDArray recNorm = workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f');
                    Nd4j.getExecutioner().exec(new LayerNorm(currRecPreNorm, gr, recNorm, true, new int[]{1}));
                    currOut.addi(recNorm);
                } else {
                    Nd4j.gemm(prevStepOut, rw, currOut, false, false, 1.0, 1.0);
                }
            }

            if (forBackprop) {
                timeStepView(outZ, t).assign(currOut);
            }
            activationFn.getActivation(currOut, training);
            if (maskArray != null) {
                currOut.muliColumnVector(maskArray.getColumn(t, true).castTo(dataType));
            }
            prevStepOut = currOut;
        }

        if (maskArray != null) {
            INDArray mask = maskArray.castTo(dataType);
            Nd4j.getExecutioner().exec(new BroadcastMulOp(out, mask, out, new int[]{0, 2}));
            if (forBackprop) {
                Nd4j.getExecutioner().exec(new BroadcastMulOp(outZ, mask, outZ, new int[]{0, 2}));
            }
        }
        if (!forBackprop) {
            out = permuteIfNWC(out);
            outZ = permuteIfNWC(outZ);
            outPreNorm = permuteIfNWC(outPreNorm);
            recPreNorm = permuteIfNWC(recPreNorm);
        }
        return new Quad<>(out, outZ, outPreNorm, recPreNorm);
    }

    /**
     * Returns the final time step ([minibatch, nOut]) of an output array produced by
     * {@link #activateHelper} with {@code forBackprop == false}, i.e. in the configured data format.
     * The array is first permuted back to the internal NCW layout. Indexing dimension 2 of an
     * NWC-formatted output directly would select the last output unit instead of the last step.
     */
    private INDArray lastTimeStepActivations(INDArray out) {
        INDArray outNcw = permuteIfNWC(out);
        return timeStepView(outNcw, outNcw.size(2) - 1);
    }

    /** Returns the view {@code arr[:, :, t]} of a rank-3 array. */
    private static INDArray timeStepView(INDArray arr, long t) {
        return arr.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(t));
    }

    @Override // org.deeplearning4j.nn.layers.BaseLayer
    public boolean hasLayerNorm() {
        return ((org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn) layerConf()).hasLayerNorm();
    }
}
