/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.cuda.normalization;

import java.util.HashMap;
import java.util.Map;
import org.bytedeco.cuda.cudart.CUstream_st;
import org.bytedeco.cuda.cudnn.cudnnContext;
import org.bytedeco.cuda.cudnn.cudnnTensorStruct;
import org.bytedeco.cuda.global.cudnn;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CudnnBatchNormalizationHelper
extends BaseCudnnHelper
implements BatchNormalizationHelper {
    private static final Logger log = LoggerFactory.getLogger(CudnnBatchNormalizationHelper.class);
    protected final int batchNormMode = 1;
    private CudnnBatchNormalizationContext cudnnContext = new CudnnBatchNormalizationContext();
    private INDArray meanCache;
    private INDArray varCache;
    private double eps;

    public CudnnBatchNormalizationHelper(DataType dataType) {
        super(dataType);
    }

    public boolean checkSupported(double eps, boolean isFixedGammaBeta) {
        boolean supported = this.checkSupported();
        if (eps < 0.0) {
            supported = false;
            log.warn("Not supported: eps < CUDNN_BN_MIN_EPSILON (" + eps + " < " + 0.0 + ")");
        }
        return supported;
    }

    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, CNN2DFormat format, LayerWorkspaceMgr layerWorkspaceMgr) {
        long[] lArray;
        boolean nchw = format == CNN2DFormat.NCHW;
        this.eps = eps;
        int cudnnTensorFormat = nchw ? 0 : 1;
        int chIdx = nchw ? 1 : 3;
        int hIdx = nchw ? 2 : 1;
        int wIdx = nchw ? 3 : 2;
        int miniBatch = (int)input.size(0);
        int depth = (int)input.size(chIdx);
        int inH = (int)input.size(hIdx);
        int inW = (int)input.size(wIdx);
        boolean isHalf = input.dataType() == DataType.HALF;
        INDArray gammaOrig = null;
        INDArray dGammaViewOrig = null;
        INDArray dBetaViewOrig = null;
        if (isHalf) {
            gammaOrig = gamma;
            dGammaViewOrig = dGammaView;
            dBetaViewOrig = dBetaView;
            gamma = gamma.castTo(DataType.FLOAT);
            dGammaView = dGammaView.castTo(DataType.FLOAT);
            dBetaView = dBetaView.castTo(DataType.FLOAT);
        }
        DefaultGradient retGradient = new DefaultGradient();
        if (!Shape.hasDefaultStridesForShape((INDArray)epsilon)) {
            epsilon = epsilon.dup('c');
        }
        int[] srcStride = ArrayUtil.toInts((long[])input.stride());
        int[] deltaStride = ArrayUtil.toInts((long[])epsilon.stride());
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner)Nd4j.getExecutioner()).flushQueue();
        }
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnnTensorStruct)this.cudnnContext.srcTensorDesc, (int)this.dataType, (int)miniBatch, (int)depth, (int)inH, (int)inW, (int)srcStride[0], (int)srcStride[chIdx], (int)srcStride[hIdx], (int)srcStride[wIdx]));
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnnTensorStruct)this.cudnnContext.deltaTensorDesc, (int)this.dataType, (int)miniBatch, (int)depth, (int)inH, (int)inW, (int)deltaStride[0], (int)deltaStride[chIdx], (int)deltaStride[hIdx], (int)deltaStride[wIdx]));
        if (nchw) {
            long[] lArray2 = new long[4];
            lArray2[0] = miniBatch;
            lArray2[1] = depth;
            lArray2[2] = inH;
            lArray = lArray2;
            lArray2[3] = inW;
        } else {
            long[] lArray3 = new long[4];
            lArray3[0] = miniBatch;
            lArray3[1] = inH;
            lArray3[2] = inW;
            lArray = lArray3;
            lArray3[3] = depth;
        }
        long[] nextEpsShape = lArray;
        INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized((Enum)ArrayType.ACTIVATION_GRAD, input.dataType(), nextEpsShape, 'c');
        int[] dstStride = ArrayUtil.toInts((long[])nextEpsilon.stride());
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnnTensorStruct)this.cudnnContext.dstTensorDesc, (int)this.dataType, (int)miniBatch, (int)depth, (int)inH, (int)inW, (int)dstStride[0], (int)dstStride[chIdx], (int)dstStride[hIdx], (int)dstStride[wIdx]));
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptor((cudnnTensorStruct)this.cudnnContext.gammaBetaTensorDesc, (int)cudnnTensorFormat, (int)CudnnBatchNormalizationHelper.toCudnnDataType(gamma.data().dataType()), (int)((int)shape[0]), (int)((int)shape[1]), (int)(shape.length > 2 ? (int)shape[2] : 1), (int)(shape.length > 3 ? (int)shape[3] : 1)));
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareActionAllWrite(new INDArray[]{input, epsilon, nextEpsilon, gamma, dGammaView, dBetaView});
        Pointer srcData = allocator.getPointer(input, context);
        Pointer epsData = allocator.getPointer(epsilon, context);
        Pointer dstData = allocator.getPointer(nextEpsilon, context);
        Pointer gammaData = allocator.getPointer(gamma, context);
        Pointer dGammaData = allocator.getPointer(dGammaView, context);
        Pointer dBetaData = allocator.getPointer(dBetaView, context);
        Pointer meanCacheData = allocator.getPointer(this.meanCache, context);
        Pointer varCacheData = allocator.getPointer(this.varCache, context);
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetStream((cudnnContext)this.cudnnContext, (CUstream_st)new CUstream_st(context.getCublasStream())));
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnBatchNormalizationBackward((cudnnContext)this.cudnnContext, (int)1, (Pointer)this.alpha, (Pointer)this.beta, (Pointer)this.alpha, (Pointer)this.alpha, (cudnnTensorStruct)this.cudnnContext.srcTensorDesc, (Pointer)srcData, (cudnnTensorStruct)this.cudnnContext.deltaTensorDesc, (Pointer)epsData, (cudnnTensorStruct)this.cudnnContext.dstTensorDesc, (Pointer)dstData, (cudnnTensorStruct)this.cudnnContext.gammaBetaTensorDesc, (Pointer)gammaData, (Pointer)dGammaData, (Pointer)dBetaData, (double)eps, (Pointer)meanCacheData, (Pointer)varCacheData));
        allocator.getFlowController().registerActionAllWrite(context, new INDArray[]{input, epsilon, nextEpsilon, gamma, dGammaView, dBetaView});
        retGradient.setGradientFor("gamma", dGammaView);
        retGradient.setGradientFor("beta", dBetaView);
        context.syncOldStream();
        if (isHalf) {
            gammaOrig.assign(gamma.castTo(DataType.HALF));
            dGammaViewOrig.assign(dGammaView.castTo(DataType.HALF));
            dBetaViewOrig.assign(dBetaView.castTo(DataType.HALF));
        }
        return new Pair((Object)retGradient, (Object)nextEpsilon);
    }

    public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
        long[] lArray;
        boolean nchw = format == CNN2DFormat.NCHW;
        int cudnnTensorFormat = nchw ? 0 : 1;
        int chIdx = nchw ? 1 : 3;
        int hIdx = nchw ? 2 : 1;
        int wIdx = nchw ? 3 : 2;
        this.eps = eps;
        boolean isHalf = x.dataType() == DataType.FLOAT16;
        INDArray origGamma = gamma;
        INDArray origBeta = beta;
        INDArray origMean = mean;
        INDArray origVar = var;
        if (isHalf) {
            gamma = gamma.castTo(DataType.FLOAT);
            beta = beta.castTo(DataType.FLOAT);
            mean = mean.castTo(DataType.FLOAT);
            var = var.castTo(DataType.FLOAT);
        }
        decay = 0.0;
        int miniBatch = (int)x.size(0);
        int inDepth = (int)x.size(chIdx);
        int inH = (int)x.size(hIdx);
        int inW = (int)x.size(wIdx);
        int[] srcStride = ArrayUtil.toInts((long[])x.stride());
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnnTensorStruct)this.cudnnContext.srcTensorDesc, (int)this.dataType, (int)miniBatch, (int)inDepth, (int)inH, (int)inW, (int)srcStride[0], (int)srcStride[chIdx], (int)srcStride[hIdx], (int)srcStride[wIdx]));
        if (nchw) {
            long[] lArray2 = new long[4];
            lArray2[0] = miniBatch;
            lArray2[1] = inDepth;
            lArray2[2] = inH;
            lArray = lArray2;
            lArray2[3] = inW;
        } else {
            long[] lArray3 = new long[4];
            lArray3[0] = miniBatch;
            lArray3[1] = inH;
            lArray3[2] = inW;
            lArray = lArray3;
            lArray3[3] = inDepth;
        }
        long[] actShape = lArray;
        INDArray activations = workspaceMgr.createUninitialized((Enum)ArrayType.ACTIVATIONS, x.dataType(), actShape, 'c');
        int[] dstStride = ArrayUtil.toInts((long[])activations.stride());
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx((cudnnTensorStruct)this.cudnnContext.dstTensorDesc, (int)this.dataType, (int)miniBatch, (int)inDepth, (int)inH, (int)inW, (int)dstStride[0], (int)dstStride[chIdx], (int)dstStride[hIdx], (int)dstStride[wIdx]));
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetTensor4dDescriptor((cudnnTensorStruct)this.cudnnContext.gammaBetaTensorDesc, (int)cudnnTensorFormat, (int)CudnnBatchNormalizationHelper.toCudnnDataType(mean.data().dataType()), (int)((int)shape[0]), (int)((int)shape[1]), (int)(shape.length > 2 ? (int)shape[2] : 1), (int)(shape.length > 3 ? (int)shape[3] : 1)));
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareActionAllWrite(new INDArray[]{x, activations, gamma, beta, mean, var});
        Pointer srcData = allocator.getPointer(x, context);
        Pointer dstData = allocator.getPointer(activations, context);
        Pointer gammaData = allocator.getPointer(gamma, context);
        Pointer betaData = allocator.getPointer(beta, context);
        Pointer meanData = allocator.getPointer(mean, context);
        Pointer varData = allocator.getPointer(var, context);
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner)Nd4j.getExecutioner()).flushQueue();
        }
        CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnSetStream((cudnnContext)this.cudnnContext, (CUstream_st)new CUstream_st(context.getCublasStream())));
        if (training) {
            Throwable throwable;
            MemoryWorkspace ws;
            if (this.meanCache == null || this.meanCache.length() < mean.length()) {
                ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces();
                throwable = null;
                try {
                    this.meanCache = Nd4j.createUninitialized((DataType)x.dataType(), (long[])new long[]{mean.length()});
                }
                catch (Throwable throwable2) {
                    throwable = throwable2;
                    throw throwable2;
                }
                finally {
                    if (ws != null) {
                        if (throwable != null) {
                            try {
                                ws.close();
                            }
                            catch (Throwable throwable3) {
                                throwable.addSuppressed(throwable3);
                            }
                        } else {
                            ws.close();
                        }
                    }
                }
                if (x.dataType() == DataType.HALF) {
                    ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces();
                    throwable = null;
                    try {
                        this.meanCache = this.meanCache.castTo(DataType.FLOAT);
                    }
                    catch (Throwable throwable4) {
                        throwable = throwable4;
                        throw throwable4;
                    }
                    finally {
                        if (ws != null) {
                            if (throwable != null) {
                                try {
                                    ws.close();
                                }
                                catch (Throwable throwable5) {
                                    throwable.addSuppressed(throwable5);
                                }
                            } else {
                                ws.close();
                            }
                        }
                    }
                }
            }
            if (this.varCache == null || this.varCache.length() < mean.length()) {
                ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces();
                throwable = null;
                try {
                    this.varCache = Nd4j.createUninitialized((DataType)x.dataType(), (long[])new long[]{mean.length()});
                }
                catch (Throwable throwable6) {
                    throwable = throwable6;
                    throw throwable6;
                }
                finally {
                    if (ws != null) {
                        if (throwable != null) {
                            try {
                                ws.close();
                            }
                            catch (Throwable throwable7) {
                                throwable.addSuppressed(throwable7);
                            }
                        } else {
                            ws.close();
                        }
                    }
                }
                if (this.nd4jDataType == DataType.HALF) {
                    ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces();
                    throwable = null;
                    try {
                        this.varCache = this.varCache.castTo(DataType.FLOAT);
                    }
                    catch (Throwable throwable8) {
                        throwable = throwable8;
                        throw throwable8;
                    }
                    finally {
                        if (ws != null) {
                            if (throwable != null) {
                                try {
                                    ws.close();
                                }
                                catch (Throwable throwable9) {
                                    throwable.addSuppressed(throwable9);
                                }
                            } else {
                                ws.close();
                            }
                        }
                    }
                }
            }
            Pointer meanCacheData = allocator.getPointer(this.meanCache, context);
            Pointer varCacheData = allocator.getPointer(this.varCache, context);
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnBatchNormalizationForwardTraining((cudnnContext)this.cudnnContext, (int)1, (Pointer)this.alpha, (Pointer)this.beta, (cudnnTensorStruct)this.cudnnContext.srcTensorDesc, (Pointer)srcData, (cudnnTensorStruct)this.cudnnContext.dstTensorDesc, (Pointer)dstData, (cudnnTensorStruct)this.cudnnContext.gammaBetaTensorDesc, (Pointer)gammaData, (Pointer)betaData, (double)decay, (Pointer)meanData, (Pointer)varData, (double)eps, (Pointer)meanCacheData, (Pointer)varCacheData));
        } else {
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnBatchNormalizationForwardInference((cudnnContext)this.cudnnContext, (int)1, (Pointer)this.alpha, (Pointer)this.beta, (cudnnTensorStruct)this.cudnnContext.srcTensorDesc, (Pointer)srcData, (cudnnTensorStruct)this.cudnnContext.dstTensorDesc, (Pointer)dstData, (cudnnTensorStruct)this.cudnnContext.gammaBetaTensorDesc, (Pointer)gammaData, (Pointer)betaData, (Pointer)meanData, (Pointer)varData, (double)eps));
        }
        allocator.getFlowController().registerActionAllWrite(context, new INDArray[]{x, activations, gamma, beta, mean, var});
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }
        context.syncOldStream();
        if (training) {
            AtomicAllocator.getInstance().getAllocationPoint(this.meanCache).tickDeviceWrite();
            AtomicAllocator.getInstance().getAllocationPoint(this.varCache).tickDeviceWrite();
        }
        if (training && isHalf) {
            origMean.assign(mean.castTo(DataType.HALF));
            origVar.assign(var.castTo(DataType.HALF));
            origGamma.assign(gamma.castTo(DataType.HALF));
            origBeta.assign(beta.castTo(DataType.HALF));
        }
        return activations;
    }

    public INDArray getMeanCache(DataType dataType) {
        if (dataType == DataType.HALF) {
            return this.meanCache.castTo(DataType.HALF);
        }
        return this.meanCache;
    }

    public INDArray getVarCache(DataType dataType) {
        INDArray ret;
        if (dataType == DataType.HALF) {
            INDArray vc = this.varCache.castTo(DataType.HALF);
            ret = vc.mul(vc).rdivi((Number)1.0).subi((Number)this.eps);
        } else {
            ret = this.varCache.mul(this.varCache).rdivi((Number)1.0).subi((Number)this.eps);
        }
        if (dataType == DataType.HALF) {
            return ret.castTo(DataType.HALF);
        }
        return ret;
    }

    public Map<String, Long> helperMemoryUse() {
        HashMap<String, Long> memUse = new HashMap<String, Long>();
        memUse.put("meanCache", this.meanCache == null ? 0L : this.meanCache.length() * (long)this.meanCache.data().getElementSize());
        memUse.put("varCache", this.varCache == null ? 0L : this.varCache.length() * (long)this.varCache.data().getElementSize());
        return memUse;
    }

    private static class CudnnBatchNormalizationContext
    extends BaseCudnnHelper.CudnnContext {
        private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct();
        private cudnnTensorStruct dstTensorDesc = new cudnnTensorStruct();
        private cudnnTensorStruct deltaTensorDesc = new cudnnTensorStruct();
        private cudnnTensorStruct gammaBetaTensorDesc = new cudnnTensorStruct();

        public CudnnBatchNormalizationContext() {
            this.createHandles();
            this.deallocator(new Deallocator(this));
        }

        public CudnnBatchNormalizationContext(CudnnBatchNormalizationContext c) {
            super(c);
            this.srcTensorDesc = new cudnnTensorStruct((Pointer)c.srcTensorDesc);
            this.dstTensorDesc = new cudnnTensorStruct((Pointer)c.dstTensorDesc);
            this.deltaTensorDesc = new cudnnTensorStruct((Pointer)c.deltaTensorDesc);
            this.gammaBetaTensorDesc = new cudnnTensorStruct((Pointer)c.gammaBetaTensorDesc);
        }

        @Override
        protected void createHandles() {
            super.createHandles();
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnnTensorStruct)this.srcTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnnTensorStruct)this.dstTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnnTensorStruct)this.deltaTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnnTensorStruct)this.gammaBetaTensorDesc));
        }

        @Override
        protected void destroyHandles() {
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnnTensorStruct)this.srcTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnnTensorStruct)this.dstTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnnTensorStruct)this.deltaTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnnTensorStruct)this.gammaBetaTensorDesc));
            super.destroyHandles();
        }

        private static class Deallocator
        extends CudnnBatchNormalizationContext
        implements Pointer.Deallocator {
            Deallocator(CudnnBatchNormalizationContext c) {
                super(c);
            }

            public void deallocate() {
                this.destroyHandles();
            }
        }
    }
}

