/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.normalization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNBatchNormHelper;
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.OneTimeLogger;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class BatchNormalization
extends BaseLayer<org.deeplearning4j.nn.conf.layers.BatchNormalization> {
    private static final Logger log = LoggerFactory.getLogger(BatchNormalization.class);
    protected static final double ONE_ON_2LOGE_10 = 1.0 / (2.0 * Math.log(10.0));
    BatchNormalizationHelper helper = null;
    protected int helperCountFail = 0;
    protected int index = 0;
    protected List<TrainingListener> listeners = new ArrayList<TrainingListener>();
    protected INDArray std;
    protected INDArray xMu;
    protected INDArray xHat;

    public BatchNormalization(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
        this.initializeHelper();
    }

    void initializeHelper() {
        block7: {
            String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
            if ("CUDA".equalsIgnoreCase(backend)) {
                try {
                    this.helper = Class.forName("org.deeplearning4j.cuda.normalization.CudnnBatchNormalizationHelper").asSubclass(BatchNormalizationHelper.class).getConstructor(DataType.class).newInstance(this.dataType);
                    log.debug("CudnnBatchNormalizationHelper successfully initialized");
                }
                catch (Throwable t) {
                    if (!(t instanceof ClassNotFoundException)) {
                        log.warn("Could not initialize CudnnBatchNormalizationHelper", t);
                        break block7;
                    }
                    OneTimeLogger.info((Logger)log, (String)"cuDNN not found: use cuDNN for better GPU performance by including the deeplearning4j-cuda module. For more information, please refer to: https://deeplearning4j.konduit.ai/config/backends/config-cudnn", (Object[])new Object[]{t});
                }
            } else if ("CPU".equalsIgnoreCase(backend)) {
                this.helper = new MKLDNNBatchNormHelper(this.dataType);
                log.trace("Created MKLDNNBatchNormHelper, layer {}", (Object)((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getLayerName());
            }
        }
        if (this.helper != null && !this.helper.checkSupported(((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getEps(), ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).isLockGammaBeta())) {
            log.debug("Removed helper {} as not supported with epsilon {}, lockGammaBeta={}", new Object[]{this.helper.getClass(), ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getEps(), ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).isLockGammaBeta()});
            this.helper = null;
        }
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.NORMALIZATION;
    }

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        INDArray batchVar;
        INDArray batchMean;
        INDArray nextEpsilon;
        INDArray dBetaView;
        INDArray dGammaView;
        this.assertInputSet(true);
        long[] shape = this.getShape(epsilon);
        long batchSize = epsilon.size(0);
        org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf = (org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf();
        CNN2DFormat format = ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getCnn2DFormat();
        boolean nchw = format == CNN2DFormat.NCHW;
        int chIdx = epsilon.rank() == 2 || nchw ? 1 : 3;
        INDArray input = this.input.castTo(this.dataType);
        INDArray globalMean = (INDArray)this.params.get("mean");
        INDArray globalVar = (INDArray)this.params.get("var");
        INDArray globalLog10Std = (INDArray)this.params.get("log10stdev");
        INDArray gamma = null;
        INDArray beta = null;
        INDArray dGlobalMeanView = (INDArray)this.gradientViews.get("mean");
        INDArray dGlobalVarView = (INDArray)this.gradientViews.get("var");
        INDArray dGlobalLog10StdView = (INDArray)this.gradientViews.get("log10stdev");
        if (layerConf.isLockGammaBeta()) {
            long[] tempShape = new long[]{1L, shape[chIdx]};
            dGammaView = Nd4j.createUninitialized((DataType)this.dataType, (long[])tempShape, (char)'c');
            dBetaView = Nd4j.createUninitialized((DataType)this.dataType, (long[])tempShape, (char)'c');
        } else {
            gamma = this.getParam("gamma");
            beta = this.getParam("beta");
            dGammaView = (INDArray)this.gradientViews.get("gamma");
            dBetaView = (INDArray)this.gradientViews.get("beta");
        }
        DefaultGradient retGradient = new DefaultGradient();
        if (!(this.helper == null || this.helperCountFail != 0 && ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).isCudnnAllowFallback())) {
            INDArray eps;
            INDArray in;
            if (layerConf.isLockGammaBeta()) {
                gamma = Nd4j.createUninitialized((DataType)this.dataType, (long[])new long[]{1L, shape[chIdx]}).assign((Number)layerConf.getGamma());
            }
            if (input.rank() == 2) {
                long[] lArray;
                if (nchw) {
                    long[] lArray2 = new long[4];
                    lArray2[0] = input.size(0);
                    lArray2[1] = input.size(1);
                    lArray2[2] = 1L;
                    lArray = lArray2;
                    lArray2[3] = 1L;
                } else {
                    long[] lArray3 = new long[4];
                    lArray3[0] = input.size(0);
                    lArray3[1] = 1L;
                    lArray3[2] = 1L;
                    lArray = lArray3;
                    lArray3[3] = input.size(1);
                }
                long[] shapeTemp = lArray;
                in = input.reshape(input.ordering(), shapeTemp);
                eps = epsilon.reshape(epsilon.ordering(), shapeTemp);
            } else {
                in = input;
                eps = epsilon;
            }
            Pair<Gradient, INDArray> ret = null;
            try {
                ret = this.helper.backpropGradient(in, eps, shape, gamma, beta, dGammaView, dBetaView, layerConf.getEps(), format, workspaceMgr);
            }
            catch (ND4JOpProfilerException e) {
                throw e;
            }
            catch (Throwable t) {
                if (t.getMessage() != null && t.getMessage().contains("Failed to allocate")) {
                    throw t;
                }
                if (((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).isCudnnAllowFallback()) {
                    ++this.helperCountFail;
                    log.warn("CuDNN BatchNormalization backprop execution failed - falling back on built-in implementation", t);
                }
                throw new RuntimeException("Error during BatchNormalization CuDNN helper backprop - isCudnnAllowFallback() is set to false", t);
            }
            if (ret != null) {
                ((Gradient)ret.getFirst()).setGradientFor("mean", dGlobalMeanView);
                if (((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).isUseLogStd()) {
                    ((Gradient)ret.getFirst()).setGradientFor("log10stdev", dGlobalLog10StdView);
                } else {
                    ((Gradient)ret.getFirst()).setGradientFor("var", dGlobalVarView);
                }
                if (input.rank() == 2) {
                    INDArray e = (INDArray)ret.getSecond();
                    ret.setSecond((Object)e.reshape(e.ordering(), new long[]{e.size(0), e.size(1)}));
                }
                INDArray batchMean2 = this.helper.getMeanCache(this.dataType);
                INDArray batchVar2 = this.helper.getVarCache(this.dataType);
                Nd4j.getExecutioner().exec((CustomOp)new SubOp(globalMean, batchMean2, dGlobalMeanView));
                dGlobalMeanView.muli((Number)(1.0 - ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getDecay()));
                if (((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).isUseLogStd()) {
                    INDArray vari = Nd4j.createUninitialized((DataType)this.dataType, (long[])globalLog10Std.shape()).assign((Number)10.0);
                    Transforms.pow((INDArray)vari, (INDArray)globalLog10Std, (boolean)false);
                    vari.muli(vari);
                    double decay = ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getDecay();
                    INDArray varip1 = vari.mul((Number)decay).addi(batchVar2.mul((Number)(1.0 - decay)));
                    Nd4j.getExecutioner().exec((CustomOp)new DivOp(vari, varip1, dGlobalLog10StdView));
                    Transforms.log((INDArray)dGlobalLog10StdView, (boolean)false);
                    dGlobalLog10StdView.muli((Number)ONE_ON_2LOGE_10);
                } else {
                    Nd4j.getExecutioner().exec((CustomOp)new SubOp(globalVar, batchVar2, dGlobalVarView));
                    dGlobalVarView.muli((Number)(1.0 - ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getDecay()));
                }
                return ret;
            }
        }
        if (epsilon.rank() == 2) {
            if (this.xHat == null && this.helper != null) {
                INDArray mean = this.helper.getMeanCache(this.dataType);
                this.std = Transforms.sqrt((INDArray)this.helper.getVarCache(this.dataType).addi((Number)((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getEps()));
                this.xMu = Nd4j.createUninitialized((DataType)this.dataType, (long[])input.shape(), (char)input.ordering());
                this.xMu = Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastSubOp(input, mean, this.xMu, new int[]{1}));
                this.xHat = Nd4j.createUninitialized((DataType)this.dataType, (long[])input.shape(), (char)input.ordering());
                this.xHat = Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastDivOp(this.xMu, this.std, this.xHat, new int[]{1}));
            }
            INDArray dBeta = epsilon.sum(true, new int[]{0});
            INDArray dGamma = epsilon.mul(this.xHat).sum(true, new int[]{0});
            INDArray dxhat = layerConf.isLockGammaBeta() ? epsilon.mul((Number)layerConf.getGamma()) : epsilon.mulRowVector(gamma);
            INDArray dLdVar = dxhat.mul(this.xMu).sum(true, new int[]{0}).muli((Number)-0.5).muli(Transforms.pow((INDArray)this.std, (Number)-3.0, (boolean)true));
            INDArray dxmu1 = dxhat.sum(true, new int[]{0}).divi(this.std).negi();
            INDArray dxmu2 = this.xMu.sum(true, new int[]{0}).muli((Number)(-2.0 / (double)batchSize)).muli(dLdVar);
            INDArray dLdmu = dxmu1.addi(dxmu2);
            INDArray dLdx = dxhat.diviRowVector(this.std).addi(this.xMu.muliRowVector(dLdVar.muli((Number)(2.0 / (double)batchSize)))).addiRowVector(dLdmu.muli((Number)(1.0 / (double)batchSize)));
            dGammaView.assign(dGamma);
            dBetaView.assign(dBeta);
            retGradient.setGradientFor("gamma", dGammaView);
            retGradient.setGradientFor("beta", dBetaView);
            nextEpsilon = dLdx;
            batchMean = input.mean(new int[]{0});
            batchVar = input.var(false, new int[]{0});
        } else if (epsilon.rank() == 4) {
            int wIdx;
            int[] nArray;
            if (nchw) {
                int[] nArray2 = new int[3];
                nArray2[0] = 0;
                nArray2[1] = 2;
                nArray = nArray2;
                nArray2[2] = 3;
            } else {
                int[] nArray3 = new int[3];
                nArray3[0] = 0;
                nArray3[1] = 1;
                nArray = nArray3;
                nArray3[2] = 2;
            }
            int[] nonChDims = nArray;
            int hIdx = nchw ? 2 : 1;
            int n = wIdx = nchw ? 3 : 2;
            if (this.xHat == null && this.helper != null) {
                INDArray mean = this.helper.getMeanCache(this.dataType);
                this.std = Transforms.sqrt((INDArray)this.helper.getVarCache(this.dataType).addi((Number)((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getEps()));
                this.xMu = Nd4j.createUninitialized((DataType)this.dataType, (long[])input.shape(), (char)input.ordering());
                this.xMu = Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastSubOp(input, mean, this.xMu, new int[]{chIdx}));
                this.xHat = Nd4j.createUninitialized((DataType)this.dataType, (long[])input.shape(), (char)input.ordering());
                this.xHat = Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastDivOp(this.xMu, this.std, this.xHat, new int[]{chIdx}));
            }
            INDArray dBeta = epsilon.sum(nonChDims);
            INDArray dGamma = epsilon.mul(this.xHat).sum(nonChDims);
            INDArray dxhat = layerConf.isLockGammaBeta() ? epsilon.mul((Number)layerConf.getGamma()) : Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastMulOp(epsilon, gamma, Nd4j.createUninitialized((DataType)epsilon.dataType(), (long[])epsilon.shape(), (char)epsilon.ordering()), new int[]{chIdx}));
            INDArray dLdVar = dxhat.mul(this.xMu).sum(nonChDims).muli((Number)-0.5).muli(Transforms.pow((INDArray)this.std, (Number)-3.0, (boolean)true));
            long effectiveBatchSize = input.size(0) * input.size(hIdx) * input.size(wIdx);
            INDArray dxmu1 = dxhat.sum(nonChDims).divi(this.std).negi();
            INDArray dxmu2 = this.xMu.sum(nonChDims).muli((Number)(-2.0 / (double)effectiveBatchSize)).muli(dLdVar);
            INDArray dLdmu = dxmu1.addi(dxmu2);
            INDArray dLdx = Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastDivOp(dxhat, this.std, dxhat, new int[]{chIdx})).addi(Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastMulOp(this.xMu, dLdVar.muli((Number)(2.0 / (double)effectiveBatchSize)), this.xMu, new int[]{chIdx})));
            Nd4j.getExecutioner().execAndReturn((BroadcastOp)new BroadcastAddOp(dLdx, dLdmu.muli((Number)(1.0 / (double)effectiveBatchSize)), dLdx, new int[]{chIdx}));
            dGammaView.assign(dGamma);
            dBetaView.assign(dBeta);
            retGradient.setGradientFor("gamma", dGammaView);
            retGradient.setGradientFor("beta", dBetaView);
            nextEpsilon = dLdx;
            batchMean = input.mean(nonChDims);
            batchVar = input.var(false, nonChDims);
        } else {
            throw new IllegalStateException("The layer prior to BatchNorm in the configuration is not currently supported. " + this.layerId());
        }
        Nd4j.getExecutioner().exec((CustomOp)new SubOp(globalMean, batchMean, dGlobalMeanView));
        dGlobalMeanView.muli((Number)(1.0 - ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getDecay()));
        if (((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).isUseLogStd()) {
            INDArray vari = Nd4j.valueArrayOf((long[])globalLog10Std.shape(), (double)10.0, (DataType)globalMean.dataType());
            Transforms.pow((INDArray)vari, (INDArray)globalLog10Std, (boolean)false);
            vari.muli(vari);
            double decay = ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getDecay();
            INDArray varip1 = vari.mul((Number)decay).addi(batchVar.mul((Number)(1.0 - decay)));
            Nd4j.getExecutioner().exec((CustomOp)new DivOp(vari, varip1, dGlobalLog10StdView));
            Transforms.log((INDArray)dGlobalLog10StdView, (boolean)false);
            dGlobalLog10StdView.muli((Number)ONE_ON_2LOGE_10);
        } else {
            Nd4j.getExecutioner().exec((CustomOp)new SubOp(globalVar, batchVar, dGlobalVarView));
            dGlobalVarView.muli((Number)(1.0 - ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getDecay()));
        }
        retGradient.setGradientFor("mean", dGlobalMeanView);
        if (((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).isUseLogStd()) {
            retGradient.setGradientFor("log10stdev", dGlobalLog10StdView);
        } else {
            retGradient.setGradientFor("var", dGlobalVarView);
        }
        nextEpsilon = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, nextEpsilon);
        this.xHat = null;
        this.xMu = null;
        return new Pair((Object)retGradient, (Object)nextEpsilon);
    }

    @Override
    public void fit(INDArray input, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(false);
        return this.preOutput(this.input, training ? Layer.TrainingMode.TRAIN : Layer.TrainingMode.TEST, workspaceMgr);
    }

    @Override
    public Gradient gradient() {
        return this.gradient;
    }

    public INDArray preOutput(INDArray x, Layer.TrainingMode training, LayerWorkspaceMgr workspaceMgr) {
        INDArray activations;
        INDArray mean;
        int wIdx;
        int[] nArray;
        int chIdx;
        CNN2DFormat format;
        int dim = 1;
        if (x.rank() == 4 && ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getCnn2DFormat() == CNN2DFormat.NHWC) {
            dim = 3;
        }
        if (x.size(dim) != ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getNOut()) {
            throw new IllegalArgumentException("input.size(" + dim + ") does not match expected input size of " + ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getNIn() + " - got input array with shape " + Arrays.toString(x.shape()));
        }
        x = x.castTo(this.dataType);
        org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf = (org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf();
        long[] shape = this.getShape(x);
        INDArray gamma = null;
        INDArray beta = null;
        INDArray globalMeanView = this.getParam("mean");
        INDArray globalVarView = this.getParam("var");
        if (layerConf.isLockGammaBeta()) {
            if (this.helper != null && this.input.rank() == 4) {
                long[] gammaBetaShape = new long[]{1L, ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getNOut()};
                gamma = Nd4j.valueArrayOf((long[])gammaBetaShape, (double)((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getGamma(), (DataType)this.dataType);
                beta = Nd4j.valueArrayOf((long[])gammaBetaShape, (double)((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getBeta(), (DataType)this.dataType);
            }
        } else {
            gamma = this.getParam("gamma");
            beta = this.getParam("beta");
        }
        if (!(this.helper == null || this.helperCountFail != 0 && ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).isCudnnAllowFallback())) {
            INDArray in = x;
            if (x.rank() == 2) {
                in = x.reshape(x.ordering(), new long[]{in.size(0), in.size(1), 1L, 1L});
            }
            double decay = layerConf.getDecay();
            INDArray ret = null;
            try {
                if (globalVarView == null) {
                    INDArray log10s = this.getParam("log10stdev");
                    globalVarView = Transforms.pow((INDArray)Nd4j.valueArrayOf((long[])log10s.shape(), (double)10.0, (DataType)this.dataType), (INDArray)log10s, (boolean)false);
                    globalVarView.muli(globalVarView);
                }
                ret = this.helper.preOutput(in, training == Layer.TrainingMode.TRAIN, shape, gamma, beta, globalMeanView, globalVarView, decay, layerConf.getEps(), ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getCnn2DFormat(), workspaceMgr);
            }
            catch (ND4JOpProfilerException e) {
                throw e;
            }
            catch (Throwable t) {
                if (t.getMessage() != null && t.getMessage().contains("Failed to allocate")) {
                    throw t;
                }
                if (((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).isCudnnAllowFallback()) {
                    ++this.helperCountFail;
                    log.warn("CuDNN BatchNormalization forward pass execution failed - falling back on built-in implementation", t);
                }
                throw new RuntimeException("Error during BatchNormalization CuDNN helper backprop - isCudnnAllowFallback() is set to false", t);
            }
            if (ret != null) {
                if (this.input.rank() == 2) {
                    return ret.reshape(ret.ordering(), new long[]{ret.size(0), ret.size(1)});
                }
                return ret;
            }
        }
        boolean nchw = (format = ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getCnn2DFormat()) == CNN2DFormat.NCHW;
        int n = chIdx = nchw ? 1 : 3;
        if (nchw) {
            int[] nArray2 = new int[3];
            nArray2[0] = 0;
            nArray2[1] = 2;
            nArray = nArray2;
            nArray2[2] = 3;
        } else {
            int[] nArray3 = new int[3];
            nArray3[0] = 0;
            nArray3[1] = 1;
            nArray = nArray3;
            nArray3[2] = 2;
        }
        int[] nonChDims = nArray;
        int hIdx = nchw ? 2 : 1;
        int n2 = wIdx = nchw ? 3 : 2;
        if (training == Layer.TrainingMode.TRAIN) {
            INDArray var;
            switch (x.rank()) {
                case 2: {
                    mean = x.mean(new int[]{0});
                    var = x.var(false, new int[]{0});
                    break;
                }
                case 4: {
                    mean = x.mean(nonChDims);
                    var = x.var(false, nonChDims);
                    break;
                }
                default: {
                    throw new IllegalStateException("Batch normalization on activations of rank " + x.rank() + " not supported " + this.layerId());
                }
            }
            this.std = Transforms.sqrt((INDArray)workspaceMgr.dup(ArrayType.INPUT, var).addi((Number)((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getEps()), (boolean)false);
        } else {
            INDArray var;
            mean = this.getParam("mean");
            if (((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).isUseLogStd()) {
                INDArray log10s = this.getParam("log10stdev");
                var = Transforms.pow((INDArray)Nd4j.valueArrayOf((long[])log10s.shape(), (double)10.0, (DataType)this.dataType), (INDArray)log10s);
                var.muli(var);
            } else {
                var = this.getParam("var");
            }
            this.std = Transforms.sqrt((INDArray)workspaceMgr.dup(ArrayType.INPUT, var).addi((Number)((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getEps()), (boolean)false);
        }
        if (x.rank() == 2) {
            this.xMu = workspaceMgr.leverageTo(ArrayType.INPUT, x.subRowVector(mean));
            this.xHat = workspaceMgr.leverageTo(ArrayType.INPUT, this.xMu.divRowVector(this.std));
            if (layerConf.isLockGammaBeta()) {
                double g = layerConf.getGamma();
                double b = layerConf.getBeta();
                activations = g != 1.0 && b != 0.0 ? this.xHat.mul((Number)g).addi((Number)b) : this.xHat;
            } else {
                activations = this.xHat.mulRowVector(gamma).addiRowVector(beta);
            }
        } else if (x.rank() == 4) {
            if (!Shape.strideDescendingCAscendingF((INDArray)x)) {
                x = x.dup();
            }
            this.xMu = workspaceMgr.createUninitialized(ArrayType.INPUT, x.dataType(), x.shape(), x.ordering());
            this.xMu = Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastSubOp(x, mean, this.xMu, new int[]{chIdx}));
            this.xHat = workspaceMgr.createUninitialized(ArrayType.INPUT, x.dataType(), x.shape(), x.ordering());
            this.xHat = Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastDivOp(this.xMu, this.std, this.xHat, new int[]{chIdx}));
            if (layerConf.isLockGammaBeta()) {
                double g = layerConf.getGamma();
                double b = layerConf.getBeta();
                activations = g != 1.0 && b != 0.0 ? this.xHat.mul((Number)g).addi((Number)b) : this.xHat;
            } else {
                activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), x.shape(), x.ordering());
                activations = Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastMulOp(this.xHat, gamma, activations, new int[]{chIdx}));
                activations = Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastAddOp(activations, beta, activations, new int[]{chIdx}));
            }
        } else {
            throw new IllegalStateException("The layer prior to BatchNorm in the configuration is not currently supported. " + this.layerId());
        }
        activations = workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, activations);
        return activations;
    }

    @Override
    public Collection<TrainingListener> getListeners() {
        return this.listeners;
    }

    @Override
    public void setListeners(TrainingListener ... listeners) {
        this.listeners = new ArrayList<TrainingListener>(Arrays.asList(listeners));
    }

    @Override
    public void setIndex(int index) {
        this.index = index;
    }

    @Override
    public int getIndex() {
        return this.index;
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    @Override
    public LayerHelper getHelper() {
        return this.helper;
    }

    public long[] getShape(INDArray x) {
        if (x.rank() == 2) {
            return new long[]{1L, x.size(1)};
        }
        if (x.rank() == 4) {
            int chIdx = ((org.deeplearning4j.nn.conf.layers.BatchNormalization)this.layerConf()).getCnn2DFormat() == CNN2DFormat.NCHW ? 1 : 3;
            return new long[]{1L, x.size(chIdx)};
        }
        if (x.rank() == 3) {
            long wDim = x.size(1);
            long hdim = x.size(2);
            if (x.size(0) > 1L && wDim * hdim == x.length()) {
                throw new IllegalArgumentException("Illegal input for batch size " + this.layerId());
            }
            return new long[]{1L, wDim * hdim};
        }
        throw new IllegalStateException("Unable to process input of rank " + x.rank() + " " + this.layerId());
    }

    @Override
    public boolean updaterDivideByMinibatch(String paramName) {
        return !"mean".equals(paramName) && !"var".equals(paramName) && !"log10stdev".equals(paramName);
    }
}

