package org.deeplearning4j.nn.layers.normalization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.layers.HelperUtils;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNBatchNormHelper;
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.shade.guava.primitives.Longs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/layers/normalization/BatchNormalization.class */
public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.layers.BatchNormalization> {
    /** Class logger. */
    private static final Logger log = LoggerFactory.getLogger(BatchNormalization.class);
    /**
     * 1 / (2 * ln 10). Multiplying ln(v1 / v2) by this constant gives log10(sqrt(v1 / v2)); it is used in
     * backpropGradient when the global statistics are stored as log10(std) rather than variance.
     */
    protected static final double ONE_ON_2LOGE_10 = 1.0d / (2.0d * Math.log(10.0d));
    /** Optional accelerated implementation created by initializeHelper(); may be null. */
    BatchNormalizationHelper helper;
    /** Incremented each time a helper call throws and the layer falls back to the built-in implementation. */
    protected int helperCountFail;
    /** Layer index, as stored by setIndex(int). */
    protected int index;
    /** Training listeners attached to this layer. */
    protected List<TrainingListener> listeners;
    /** sqrt(variance + eps); written by preOutput, or rebuilt from the helper caches in backpropGradient. */
    protected INDArray std;
    /** Input minus mean; written by preOutput and consumed (and cleared) by backpropGradient. */
    protected INDArray xMu;
    /** xMu divided by std (the normalized input); written by preOutput and consumed (and cleared) by backpropGradient. */
    protected INDArray xHat;
    /** First class name handed to HelperUtils.createHelper in initializeHelper(). */
    public static final String BATCH_NORM_CUDNN_HELPER_CLASS_NAME = "org.deeplearning4j.cuda.normalization.CudnnBatchNormalizationHelper";

    /**
     * Creates the layer, resets its bookkeeping state and tries to set up an accelerated helper.
     *
     * @param neuralNetConfiguration configuration passed through to the base layer
     * @param dataType               data type passed through to the base layer
     */
    public BatchNormalization(NeuralNetConfiguration neuralNetConfiguration, DataType dataType) {
        super(neuralNetConfiguration, dataType);
        this.helper = null;
        this.helperCountFail = 0;
        this.index = 0;
        // Diamond instead of the raw type so the list is checked as List<TrainingListener>.
        this.listeners = new ArrayList<>();
        initializeHelper();
    }

    /**
     * Asks HelperUtils for a helper implementation and discards it again if the helper reports that it
     * does not support this layer's epsilon / lockGammaBeta settings.
     */
    void initializeHelper() {
        this.helper = (BatchNormalizationHelper) HelperUtils.createHelper(BATCH_NORM_CUDNN_HELPER_CLASS_NAME, MKLDNNBatchNormHelper.class.getName(), BatchNormalizationHelper.class, layerConf().getLayerName(), this.dataType);
        if (this.helper == null) {
            return;
        }
        final double eps = layerConf().getEps();
        final boolean lockGammaBeta = layerConf().isLockGammaBeta();
        if (!this.helper.checkSupported(eps, lockGammaBeta)) {
            log.debug("Removed helper {} as not supported with epsilon {}, lockGammaBeta={}", this.helper.getClass(), eps, lockGammaBeta);
            this.helper = null;
        }
    }

    /**
     * @return always {@link Layer.Type#NORMALIZATION}
     */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Layer
    public Layer.Type type() {
        return Layer.Type.NORMALIZATION;
    }

    /**
     * Computes the gradients of this layer for the given output gradient.
     *
     * <p>If a helper is present (and has not failed before, or fallback is disallowed) the helper is tried
     * first; when it returns a result, only the global mean / variance (or log10 std) "gradients" are filled
     * in here. Otherwise the built-in implementation is used for rank 2 and rank 4 input.
     *
     * <p>The global statistics are updated through the gradient views: the value written is
     * {@code (global - batch) * (1 - decay)} for the mean and variance, or the equivalent difference in
     * log10(std) when {@code useLogStd} is set.
     *
     * @param iNDArray          gradient with respect to this layer's output (epsilon)
     * @param layerWorkspaceMgr workspace manager used to place the returned activation gradient
     * @return the parameter gradients and the gradient with respect to the layer input
     * @throws IllegalStateException if the built-in path is used and epsilon is neither rank 2 nor rank 4
     * @throws RuntimeException      if the helper fails and {@code isCudnnAllowFallback()} is false
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        assertInputSet(true);
        final INDArray epsilon = iNDArray;
        final long[] shape = getShape(epsilon);
        final long batchSize = epsilon.size(0);
        final org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf = layerConf();
        final CNN2DFormat format = layerConf.getCnn2DFormat();
        final boolean nchw = format == CNN2DFormat.NCHW;
        final boolean lockGammaBeta = layerConf.isLockGammaBeta();
        final boolean useLogStd = layerConf.isUseLogStd();
        // Axis of the channel dimension in the (unreshaped) activations.
        final int chIdx = (epsilon.rank() == 2 || nchw) ? 1 : 3;
        // getShape(...) always returns a 2-element {1, channels} array, so the channel count lives at
        // index 1 regardless of the data format (indexing with chIdx == 3 would run off the array).
        final long nChannels = shape[1];
        final INDArray in = this.input.castTo(this.dataType);

        final INDArray globalMean = this.params.get(BatchNormalizationParamInitializer.GLOBAL_MEAN);
        final INDArray globalVar = this.params.get(BatchNormalizationParamInitializer.GLOBAL_VAR);
        final INDArray globalLog10Std = this.params.get(BatchNormalizationParamInitializer.GLOBAL_LOG_STD);
        final INDArray dGlobalMeanView = this.gradientViews.get(BatchNormalizationParamInitializer.GLOBAL_MEAN);
        final INDArray dGlobalVarView = this.gradientViews.get(BatchNormalizationParamInitializer.GLOBAL_VAR);
        final INDArray dGlobalLog10StdView = this.gradientViews.get(BatchNormalizationParamInitializer.GLOBAL_LOG_STD);

        INDArray gamma = null;
        INDArray beta = null;
        final INDArray dGammaView;
        final INDArray dBetaView;
        if (lockGammaBeta) {
            // No gamma/beta parameters exist; use scratch arrays to receive the (unused) gradients.
            final long[] scratchShape = {1, nChannels};
            dGammaView = Nd4j.createUninitialized(this.dataType, scratchShape, 'c');
            dBetaView = Nd4j.createUninitialized(this.dataType, scratchShape, 'c');
        } else {
            gamma = getParam(BatchNormalizationParamInitializer.GAMMA);
            beta = getParam(BatchNormalizationParamInitializer.BETA);
            dGammaView = this.gradientViews.get(BatchNormalizationParamInitializer.GAMMA);
            dBetaView = this.gradientViews.get(BatchNormalizationParamInitializer.BETA);
        }

        // ---- Helper path ----
        if (this.helper != null && (this.helperCountFail == 0 || !layerConf.isCudnnAllowFallback())) {
            if (lockGammaBeta) {
                gamma = Nd4j.createUninitialized(this.dataType, new long[]{1, nChannels}).assign(layerConf.getGamma());
            }
            // Channel axis of the 4D arrays handed to / returned by the helper.
            final int helperChIdx = nchw ? 1 : 3;
            final INDArray helperInput;
            final INDArray helperEpsilon;
            if (in.rank() == 2) {
                // The helper is given 4D arrays; add two singleton spatial dimensions.
                final long[] shape4d = nchw
                        ? new long[]{in.size(0), in.size(1), 1, 1}
                        : new long[]{in.size(0), 1, 1, in.size(1)};
                helperInput = in.reshape(in.ordering(), shape4d);
                helperEpsilon = epsilon.reshape(epsilon.ordering(), shape4d);
            } else {
                helperInput = in;
                helperEpsilon = epsilon;
            }

            Pair<Gradient, INDArray> helperResult = null;
            try {
                helperResult = this.helper.backpropGradient(helperInput, helperEpsilon, shape, gamma, beta, dGammaView, dBetaView, layerConf.getEps(), format, layerWorkspaceMgr);
            } catch (ND4JOpProfilerException e) {
                // Profiler exceptions are always propagated unchanged.
                throw e;
            } catch (Throwable t) {
                if (t.getMessage() != null && t.getMessage().contains("Failed to allocate")) {
                    // Allocation failure: do not attempt the built-in implementation.
                    throw t;
                }
                if (!layerConf.isCudnnAllowFallback()) {
                    throw new RuntimeException("Error during BatchNormalization CuDNN helper backprop - isCudnnAllowFallback() is set to false", t);
                }
                this.helperCountFail++;
                log.warn("CuDNN BatchNormalization backprop execution failed - falling back on built-in implementation", t);
            }

            if (helperResult != null) {
                final Gradient helperGradient = helperResult.getFirst();
                helperGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_MEAN, dGlobalMeanView);
                if (useLogStd) {
                    helperGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_LOG_STD, dGlobalLog10StdView);
                } else {
                    helperGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_VAR, dGlobalVarView);
                }
                if (in.rank() == 2) {
                    // Undo the 2D -> 4D reshape. The channel count sits on helperChIdx, not always on
                    // axis 1 (for the [n,1,1,c] layout axis 1 has size 1).
                    final INDArray eps4d = helperResult.getSecond();
                    helperResult.setSecond(eps4d.reshape(eps4d.ordering(), new long[]{eps4d.size(0), eps4d.size(helperChIdx)}));
                }

                final INDArray helperMean = this.helper.getMeanCache(this.dataType);
                final INDArray helperVar = this.helper.getVarCache(this.dataType);
                Nd4j.getExecutioner().exec(new SubOp(globalMean, helperMean, dGlobalMeanView));
                dGlobalMeanView.muli(1.0d - layerConf.getDecay());
                if (useLogStd) {
                    // Recover the current global variance: (10^log10std)^2.
                    final INDArray currentVar = Nd4j.createUninitialized(this.dataType, globalLog10Std.shape()).assign(10.0d);
                    Transforms.pow(currentVar, globalLog10Std, false);
                    currentVar.muli(currentVar);
                    final double decay = layerConf.getDecay();
                    final INDArray updatedVar = currentVar.mul(decay).addi(helperVar.mul(1.0d - decay));
                    // ln(current / updated) / (2 ln 10) == log10(std_current) - log10(std_updated)
                    Nd4j.getExecutioner().exec(new DivOp(currentVar, updatedVar, dGlobalLog10StdView));
                    Transforms.log(dGlobalLog10StdView, false);
                    dGlobalLog10StdView.muli(ONE_ON_2LOGE_10);
                } else {
                    Nd4j.getExecutioner().exec(new SubOp(globalVar, helperVar, dGlobalVarView));
                    dGlobalVarView.muli(1.0d - layerConf.getDecay());
                }
                return helperResult;
            }
        }

        // ---- Built-in path ----
        final Gradient retGradient = new DefaultGradient();
        final INDArray dLdx;
        final INDArray batchMean;
        final INDArray batchVar;
        if (epsilon.rank() == 2) {
            if (this.xHat == null && this.helper != null) {
                // The forward pass ran in the helper, so the cached intermediates must be rebuilt.
                final INDArray cachedMean = this.helper.getMeanCache(this.dataType);
                this.std = Transforms.sqrt(this.helper.getVarCache(this.dataType).addi(layerConf.getEps()));
                this.xMu = Nd4j.createUninitialized(this.dataType, in.shape(), in.ordering());
                this.xMu = Nd4j.getExecutioner().exec(new BroadcastSubOp(in, cachedMean, this.xMu, new int[]{1}));
                this.xHat = Nd4j.createUninitialized(this.dataType, in.shape(), in.ordering());
                this.xHat = Nd4j.getExecutioner().exec(new BroadcastDivOp(this.xMu, this.std, this.xHat, new int[]{1}));
            }

            final INDArray dBeta = epsilon.sum(true, new int[]{0});
            final INDArray dGamma = epsilon.mul(this.xHat).sum(true, new int[]{0});
            final INDArray dxhat = lockGammaBeta ? epsilon.mul(layerConf.getGamma()) : epsilon.mulRowVector(gamma);
            final INDArray dLdVar = dxhat.mul(this.xMu).sum(true, new int[]{0}).muli(-0.5d).muli(Transforms.pow(this.std, -3.0d, true));

            // dL/dmu has to be computed BEFORE the in-place operations below overwrite dxhat, xMu and
            // dLdVar; otherwise it would be derived from already-modified arrays.
            final INDArray dxmu1 = dxhat.sum(true, new int[]{0}).divi(this.std).negi();
            final INDArray dxmu2 = this.xMu.sum(true, new int[]{0}).muli((-2.0d) / batchSize).muli(dLdVar);
            final INDArray dLdmu = dxmu1.addi(dxmu2);

            // Array reuse: dxhat, xMu, dLdVar and dLdmu are all overwritten by this statement.
            dLdx = dxhat.diviRowVector(this.std)
                    .addi(this.xMu.muliRowVector(dLdVar.muli(2.0d / batchSize)))
                    .addiRowVector(dLdmu.muli(1.0d / batchSize));

            dGammaView.assign(dGamma);
            dBetaView.assign(dBeta);
            retGradient.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
            retGradient.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
            batchMean = in.mean(new int[]{0});
            batchVar = in.var(false, new int[]{0});
        } else if (epsilon.rank() == 4) {
            final int[] nonChDims = nchw ? new int[]{0, 2, 3} : new int[]{0, 1, 2};
            final int hIdx = nchw ? 2 : 1;
            final int wIdx = nchw ? 3 : 2;
            if (this.xHat == null && this.helper != null) {
                // The forward pass ran in the helper, so the cached intermediates must be rebuilt.
                final INDArray cachedMean = this.helper.getMeanCache(this.dataType);
                this.std = Transforms.sqrt(this.helper.getVarCache(this.dataType).addi(layerConf.getEps())).detach();
                this.xMu = Nd4j.createUninitialized(this.dataType, in.shape(), in.ordering()).detach();
                this.xMu = Nd4j.getExecutioner().exec(new BroadcastSubOp(in, cachedMean, this.xMu, new int[]{chIdx})).detach();
                this.xHat = Nd4j.createUninitialized(this.dataType, in.shape(), in.ordering()).detach();
                this.xHat = Nd4j.getExecutioner().exec(new BroadcastDivOp(this.xMu, this.std, this.xHat, new int[]{chIdx})).detach();
            }

            final INDArray dBeta = epsilon.sum(nonChDims);
            final INDArray dGamma = epsilon.mul(this.xHat).sum(nonChDims);
            final INDArray dxhat;
            if (lockGammaBeta) {
                dxhat = epsilon.mul(layerConf.getGamma());
            } else {
                final INDArray dxhatBuffer = Nd4j.createUninitialized(epsilon.dataType(), epsilon.shape(), epsilon.ordering());
                dxhat = Nd4j.getExecutioner().exec(new BroadcastMulOp(epsilon, gamma, dxhatBuffer, new int[]{chIdx}));
            }
            final INDArray dLdVar = dxhat.mul(this.xMu).sum(nonChDims).muli(-0.5d).muli(Transforms.pow(this.std, -3.0d, true));

            // Number of values contributing to each channel statistic.
            final long effectiveBatchSize = in.size(0) * in.size(hIdx) * in.size(wIdx);

            // As in the 2D case, dL/dmu is computed before the in-place updates below.
            final INDArray dxmu1 = dxhat.sum(nonChDims).divi(this.std).negi();
            final INDArray dxmu2 = this.xMu.sum(nonChDims).muli((-2.0d) / effectiveBatchSize).muli(dLdVar);
            final INDArray dLdmu = dxmu1.addi(dxmu2);

            // Array reuse: dxhat and xMu are overwritten in place from here on.
            final INDArray dxhatOverStd = Nd4j.getExecutioner().exec(new BroadcastDivOp(dxhat, this.std, dxhat, new int[]{chIdx}));
            final INDArray xMuTimesDVar = Nd4j.getExecutioner().exec(new BroadcastMulOp(this.xMu, dLdVar.muli(2.0d / effectiveBatchSize), this.xMu, new int[]{chIdx}));
            final INDArray partial = dxhatOverStd.addi(xMuTimesDVar);
            Nd4j.getExecutioner().execAndReturn(new BroadcastAddOp(partial, dLdmu.muli(1.0d / effectiveBatchSize), partial, new int[]{chIdx}));
            dLdx = partial;

            dGammaView.assign(dGamma);
            dBetaView.assign(dBeta);
            retGradient.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
            retGradient.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
            batchMean = in.mean(nonChDims);
            batchVar = in.var(false, nonChDims);
        } else {
            throw new IllegalStateException("The layer prior to BatchNorm in the configuration is not currently supported. " + layerId());
        }

        // ---- "Gradients" that move the global statistics towards the batch statistics ----
        Nd4j.getExecutioner().exec(new SubOp(globalMean, batchMean, dGlobalMeanView));
        dGlobalMeanView.muli(1.0d - layerConf.getDecay());
        if (useLogStd) {
            // Recover the current global variance: (10^log10std)^2.
            final INDArray currentVar = Nd4j.valueArrayOf(globalLog10Std.shape(), 10.0d, globalMean.dataType());
            Transforms.pow(currentVar, globalLog10Std, false);
            currentVar.muli(currentVar);
            final double decay = layerConf.getDecay();
            final INDArray updatedVar = currentVar.mul(decay).addi(batchVar.mul(1.0d - decay).reshape(currentVar.shape()));
            // ln(current / updated) / (2 ln 10) == log10(std_current) - log10(std_updated)
            Nd4j.getExecutioner().exec(new DivOp(currentVar, updatedVar, dGlobalLog10StdView));
            Transforms.log(dGlobalLog10StdView, false);
            dGlobalLog10StdView.muli(ONE_ON_2LOGE_10);
        } else {
            Nd4j.getExecutioner().exec(new SubOp(globalVar, batchVar, dGlobalVarView));
            dGlobalVarView.muli(1.0d - layerConf.getDecay());
        }

        retGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_MEAN, dGlobalMeanView);
        if (useLogStd) {
            retGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_LOG_STD, dGlobalLog10StdView);
        } else {
            retGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_VAR, dGlobalVarView);
        }

        final INDArray nextEpsilon = layerWorkspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, dLdx);
        // The cached forward-pass intermediates were overwritten above; drop them.
        this.xHat = null;
        this.xMu = null;
        return new Pair<>(retGradient, nextEpsilon);
    }

    /**
     * Not supported for this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public void fit(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Runs the forward pass on the currently set input.
     *
     * @param z                 true to run in TRAIN mode, false to run in TEST mode
     * @param layerWorkspaceMgr workspace manager forwarded to preOutput
     * @return the result of preOutput for the stored input
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        assertInputSet(false);
        final Layer.TrainingMode mode;
        if (z) {
            mode = Layer.TrainingMode.TRAIN;
        } else {
            mode = Layer.TrainingMode.TEST;
        }
        return preOutput(this.input, mode, layerWorkspaceMgr);
    }

    /**
     * @return the gradient object currently stored on this layer
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public Gradient gradient() {
        final Gradient current = this.gradient;
        return current;
    }

    /**
     * Forward pass of batch normalization.
     *
     * <p>Rank 3 input is temporarily given a leading singleton dimension and processed as rank 4; the
     * leading dimension is removed again before returning. If a helper is available it is tried first;
     * if it fails and fallback is allowed, or it returns null, the built-in implementation is used.
     * In TRAIN mode the batch mean/variance are used, otherwise the stored global statistics.
     *
     * <p>The built-in path stores {@code std}, {@code xMu} and {@code xHat} on the layer for use by
     * backpropGradient.
     *
     * @param iNDArray          input activations
     * @param trainingMode      TRAIN to normalize with batch statistics, anything else to use global statistics
     * @param layerWorkspaceMgr workspace manager used for intermediate and output arrays
     * @return the normalized (and scaled/shifted) activations
     * @throws IllegalArgumentException if the channel dimension does not match the configured nOut
     * @throws IllegalStateException    if the built-in path is used for input that is not rank 2 or 4
     * @throws RuntimeException         if the helper fails and {@code isCudnnAllowFallback()} is false
     */
    public INDArray preOutput(INDArray iNDArray, Layer.TrainingMode trainingMode, LayerWorkspaceMgr layerWorkspaceMgr) {
        INDArray x = iNDArray;
        int chDim = 1;
        boolean rnnInput = false;
        if (x.rank() == 3) {
            // Treat rank 3 input as rank 4 with a leading singleton dimension.
            x = x.reshape(Longs.concat(new long[]{1}, x.shape()));
            rnnInput = true;
        }
        if (x.rank() == 4 && layerConf().getCnn2DFormat() == CNN2DFormat.NHWC) {
            chDim = 3;
        }
        if (x.size(chDim) != layerConf().getNOut()) {
            throw new IllegalArgumentException("input.size(" + chDim + ") does not match expected input size of " + layerConf().getNIn() + " - got input array with shape " + Arrays.toString(x.shape()));
        }
        x = x.castTo(this.dataType);

        final org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf = layerConf();
        final boolean nchw = layerConf.getCnn2DFormat() == CNN2DFormat.NCHW;
        final boolean lockGammaBeta = layerConf.isLockGammaBeta();
        final long[] shape = getShape(x);

        INDArray gamma = null;
        INDArray beta = null;
        final INDArray globalMeanView = getParam(BatchNormalizationParamInitializer.GLOBAL_MEAN);
        INDArray globalVarView = getParam(BatchNormalizationParamInitializer.GLOBAL_VAR);
        if (!lockGammaBeta) {
            gamma = getParam(BatchNormalizationParamInitializer.GAMMA);
            beta = getParam(BatchNormalizationParamInitializer.BETA);
        } else if (this.helper != null && this.input.rank() == 4) {
            // The helper needs explicit gamma/beta arrays even when they are fixed constants.
            final long[] gammaBetaShape = {1, layerConf.getNOut()};
            gamma = Nd4j.valueArrayOf(gammaBetaShape, layerConf.getGamma(), this.dataType);
            beta = Nd4j.valueArrayOf(gammaBetaShape, layerConf.getBeta(), this.dataType);
        }

        // ---- Helper path ----
        if (this.helper != null && (this.helperCountFail == 0 || !layerConf.isCudnnAllowFallback())) {
            // Channel axis of the 4D arrays handed to / returned by the helper.
            final int helperChIdx = nchw ? 1 : 3;
            INDArray helperInput = x;
            if (x.rank() == 2) {
                // Same 2D -> 4D layout that backpropGradient uses, so forward and backward agree.
                final long[] shape4d = nchw
                        ? new long[]{x.size(0), x.size(1), 1, 1}
                        : new long[]{x.size(0), 1, 1, x.size(1)};
                helperInput = x.reshape(x.ordering(), shape4d);
            }
            final double decay = layerConf.getDecay();
            INDArray helperOutput = null;
            try {
                if (globalVarView == null) {
                    // No variance parameter is stored; derive it as (10^log10std)^2.
                    final INDArray log10Std = getParam(BatchNormalizationParamInitializer.GLOBAL_LOG_STD);
                    globalVarView = Transforms.pow(Nd4j.valueArrayOf(log10Std.shape(), 10.0d, this.dataType), log10Std, false);
                    globalVarView.muli(globalVarView);
                }
                // The helper call itself must be inside the try block, otherwise a helper failure
                // could never trigger the fallback handled below.
                helperOutput = this.helper.preOutput(helperInput, trainingMode == Layer.TrainingMode.TRAIN, shape, gamma, beta, globalMeanView, globalVarView, decay, layerConf.getEps(), layerConf.getCnn2DFormat(), layerWorkspaceMgr);
            } catch (ND4JOpProfilerException e) {
                // Profiler exceptions are always propagated unchanged.
                throw e;
            } catch (Throwable t) {
                if (t.getMessage() != null && t.getMessage().contains("Failed to allocate")) {
                    // Allocation failure: do not attempt the built-in implementation.
                    throw t;
                }
                if (!layerConf.isCudnnAllowFallback()) {
                    throw new RuntimeException("Error during BatchNormalization CuDNN helper forward pass - isCudnnAllowFallback() is set to false", t);
                }
                this.helperCountFail++;
                log.warn("CuDNN BatchNormalization forward pass execution failed - falling back on built-in implementation", t);
            }
            if (helperOutput != null) {
                if (this.input.rank() == 2) {
                    return helperOutput.reshape(helperOutput.ordering(), new long[]{helperOutput.size(0), helperOutput.size(helperChIdx)});
                }
                if (rnnInput && helperOutput.rank() == 4) {
                    // Strip the leading singleton dimension added for rank 3 input.
                    return helperOutput.reshape(helperOutput.ordering(), new long[]{helperOutput.size(1), helperOutput.size(2), helperOutput.size(3)});
                }
                return helperOutput;
            }
        }

        // ---- Built-in path ----
        final int chIdx = nchw ? 1 : 3;
        final int[] nonChDims = nchw ? new int[]{0, 2, 3} : new int[]{0, 1, 2};

        final INDArray mean;
        if (trainingMode == Layer.TrainingMode.TRAIN) {
            final INDArray var;
            switch (x.rank()) {
                case 2:
                    mean = x.mean(new int[]{0});
                    var = x.var(false, new int[]{0});
                    break;
                case 4:
                    mean = x.mean(nonChDims);
                    var = x.var(false, nonChDims);
                    break;
                default:
                    throw new IllegalStateException("Batch normalization on activations of rank " + x.rank() + " not supported " + layerId());
            }
            this.std = Transforms.sqrt(layerWorkspaceMgr.dup(ArrayType.INPUT, var).addi(layerConf.getEps()), false);
        } else {
            mean = getParam(BatchNormalizationParamInitializer.GLOBAL_MEAN);
            final INDArray globalVar;
            if (layerConf.isUseLogStd()) {
                // variance = (10^log10std)^2
                final INDArray log10Std = getParam(BatchNormalizationParamInitializer.GLOBAL_LOG_STD);
                globalVar = Transforms.pow(Nd4j.valueArrayOf(log10Std.shape(), 10.0d, this.dataType), log10Std);
                globalVar.muli(globalVar);
            } else {
                globalVar = getParam(BatchNormalizationParamInitializer.GLOBAL_VAR);
            }
            this.std = Transforms.sqrt(globalVar.add(layerConf.getEps()));
        }

        final INDArray activations;
        if (x.rank() == 2) {
            this.xMu = layerWorkspaceMgr.leverageTo(ArrayType.INPUT, x.subRowVector(mean));
            this.xHat = layerWorkspaceMgr.leverageTo(ArrayType.INPUT, this.xMu.divRowVector(this.std));
            if (lockGammaBeta) {
                activations = applyFixedGammaBeta(this.xHat, layerConf.getGamma(), layerConf.getBeta());
            } else {
                activations = this.xHat.mulRowVector(gamma).addiRowVector(beta);
            }
        } else if (x.rank() == 4) {
            if (!Shape.strideDescendingCAscendingF(x)) {
                x = x.dup();
            }
            this.xMu = Nd4j.createUninitialized(x.dataType(), x.shape(), x.ordering());
            this.xMu = Nd4j.getExecutioner().exec(new BroadcastSubOp(x, mean, this.xMu, new int[]{chIdx}));
            this.xHat = Nd4j.createUninitialized(x.dataType(), x.shape(), x.ordering());
            this.xHat = Nd4j.getExecutioner().exec(new BroadcastDivOp(this.xMu, this.std, this.xHat, new int[]{chIdx}));
            if (lockGammaBeta) {
                activations = applyFixedGammaBeta(this.xHat, layerConf.getGamma(), layerConf.getBeta());
            } else {
                final INDArray scaled = Nd4j.getExecutioner().exec(new BroadcastMulOp(this.xHat, gamma, layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), x.shape(), x.ordering()), new int[]{chIdx}));
                activations = Nd4j.getExecutioner().exec(new BroadcastAddOp(scaled, beta, scaled, new int[]{chIdx}));
            }
        } else {
            throw new IllegalStateException("The layer prior to BatchNorm in the configuration is not currently supported. " + layerId());
        }

        INDArray output = layerWorkspaceMgr.leverageTo(ArrayType.ACTIVATIONS, activations);
        this.xHat = this.xHat.detach();
        this.xMu = this.xMu.detach();
        if (rnnInput) {
            // Strip the leading singleton dimension added for rank 3 input.
            output = output.reshape(new long[]{output.size(1), output.size(2), output.size(3)});
        }
        return output;
    }

    /**
     * Applies constant gamma/beta to the normalized activations. The input array is returned as-is only
     * when the transform is the identity (gamma == 1 AND beta == 0); otherwise a new array is produced.
     */
    private static INDArray applyFixedGammaBeta(INDArray normalized, double gamma, double beta) {
        if (gamma == 1.0d && beta == 0.0d) {
            return normalized;
        }
        return normalized.mul(gamma).addi(beta);
    }

    /**
     * @return the listener list held by this layer (the live list, not a copy)
     */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Layer
    public Collection<TrainingListener> getListeners() {
        final Collection<TrainingListener> current = this.listeners;
        return current;
    }

    /**
     * Replaces the layer's listeners with a fresh mutable list containing the given listeners.
     *
     * @param trainingListenerArr listeners to install; a null array is treated as "no listeners"
     */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Layer, org.deeplearning4j.nn.api.Model
    public void setListeners(TrainingListener... trainingListenerArr) {
        if (trainingListenerArr == null) {
            // setListeners((TrainingListener[]) null) would otherwise fail inside Arrays.asList.
            this.listeners = new ArrayList<>();
            return;
        }
        this.listeners = new ArrayList<>(Arrays.asList(trainingListenerArr));
    }

    /**
     * Stores the index of this layer.
     *
     * @param i the index value to store
     */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Layer
    public void setIndex(int i) {
        this.index = i;
    }

    /**
     * @return the index previously stored with setIndex (0 if never set)
     */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Layer
    public int getIndex() {
        final int current = this.index;
        return current;
    }

    /**
     * @return always false
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public boolean isPretrainLayer() {
        return false;
    }

    /**
     * @return the helper in use by this layer, or null if there is none
     */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Layer
    public LayerHelper getHelper() {
        final LayerHelper current = this.helper;
        return current;
    }

    /**
     * Returns a two-element shape {@code {1, n}} for the given array.
     *
     * <ul>
     *   <li>rank 2: n is {@code size(1)}</li>
     *   <li>rank 4: n is {@code size(1)} for NCHW, otherwise {@code size(3)}</li>
     *   <li>rank 3: n is {@code size(1) * size(2)}</li>
     * </ul>
     *
     * @param iNDArray array to derive the shape from
     * @return a new array of length 2 whose first element is 1
     * @throws IllegalStateException    for any rank other than 2, 3 or 4
     * @throws IllegalArgumentException for rank 3 input whose first dimension is greater than 1 while
     *                                  {@code size(1) * size(2)} equals the total length
     */
    public long[] getShape(INDArray iNDArray) {
        final int rank = iNDArray.rank();
        switch (rank) {
            case 2:
                return new long[]{1, iNDArray.size(1)};
            case 4: {
                final int channelAxis = layerConf().getCnn2DFormat() == CNN2DFormat.NCHW ? 1 : 3;
                return new long[]{1, iNDArray.size(channelAxis)};
            }
            case 3: {
                final long dim1 = iNDArray.size(1);
                final long dim2 = iNDArray.size(2);
                final boolean illegalBatch = iNDArray.size(0) > 1 && dim1 * dim2 == iNDArray.length();
                if (illegalBatch) {
                    throw new IllegalArgumentException("Illegal input for batch size " + layerId());
                }
                return new long[]{1, dim1 * dim2};
            }
            default:
                throw new IllegalStateException("Unable to process input of rank " + rank + " " + layerId());
        }
    }

    /**
     * Reports whether the updater should divide the gradient of the named parameter by the minibatch size.
     *
     * @param str parameter name
     * @return false for the global mean, global variance and global log-std parameters; true for
     *         every other name (including null)
     */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Trainable
    public boolean updaterDivideByMinibatch(String str) {
        final boolean isGlobalStatistic = BatchNormalizationParamInitializer.GLOBAL_MEAN.equals(str)
                || BatchNormalizationParamInitializer.GLOBAL_VAR.equals(str)
                || BatchNormalizationParamInitializer.GLOBAL_LOG_STD.equals(str);
        return !isGlobalStatistic;
    }
}
