/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.cuda.normalization;

import java.util.Collections;
import java.util.Map;
import org.bytedeco.cuda.cudart.CUstream_st;
import org.bytedeco.cuda.cudnn.cudnnContext;
import org.bytedeco.cuda.cudnn.cudnnLRNStruct;
import org.bytedeco.cuda.cudnn.cudnnTensorStruct;
import org.bytedeco.cuda.global.cudnn;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * cuDNN-backed implementation of {@link LocalResponseNormalizationHelper}.
 *
 * <p>The forward pass ({@link #activate}) stores its output in an instance field, and the
 * backward pass ({@link #backpropGradient}) reads that stored output. An instance therefore
 * carries mutable state between the two calls, and {@code activate} must be invoked before
 * {@code backpropGradient}.
 */
public class CudnnLocalResponseNormalizationHelper
extends BaseCudnnHelper
implements LocalResponseNormalizationHelper {
    private static final Logger log = LoggerFactory.getLogger(CudnnLocalResponseNormalizationHelper.class);

    /** Smallest window size accepted (reported as CUDNN_LRN_MIN_N in the log output). */
    private static final int LRN_MIN_N = 1;
    /** Largest window size accepted (reported as CUDNN_LRN_MAX_N in the log output). */
    private static final int LRN_MAX_N = 16;
    /** Smallest k accepted (reported as CUDNN_LRN_MIN_K in the log output). */
    private static final double LRN_MIN_K = 1.0E-5;
    /** Smallest beta accepted (reported as CUDNN_LRN_MIN_BETA in the log output). */
    private static final double LRN_MIN_BETA = 0.01;
    /** LRN mode passed to cuDNN; 0 is CUDNN_LRN_CROSS_CHANNEL_DIM1 in the cuDNN headers. */
    private static final int LRN_MODE_CROSS_CHANNEL = 0;

    private CudnnLocalResponseNormalizationContext cudnnContext = new CudnnLocalResponseNormalizationContext();
    /** Output of the most recent {@link #activate} call; {@code null} until the first call. */
    private INDArray activations = null;

    /**
     * @param dataType data type forwarded unchanged to {@link BaseCudnnHelper}
     */
    public CudnnLocalResponseNormalizationHelper(DataType dataType) {
        super(dataType);
    }

    /**
     * Checks whether the given LRN hyper-parameters can be handled by this helper.
     *
     * <p>Starts from the result of the no-argument {@code checkSupported()} and then rejects
     * (with a warning per violated bound) any {@code n} outside [1, 16], any {@code k} below
     * 1e-5 and any {@code beta} below 0.01. {@code alpha} is not validated.
     *
     * @return {@code true} only if the base check passed and no bound was violated
     */
    public boolean checkSupported(double k, double n, double alpha, double beta) {
        boolean supported = this.checkSupported();
        if (n < LRN_MIN_N) {
            supported = false;
            log.warn("Not supported: n < CUDNN_LRN_MIN_N ({} < {})", n, LRN_MIN_N);
        }
        if (n > LRN_MAX_N) {
            supported = false;
            log.warn("Not supported: n > CUDNN_LRN_MAX_N ({} > {})", n, LRN_MAX_N);
        }
        if (k < LRN_MIN_K) {
            supported = false;
            log.warn("Not supported: k < CUDNN_LRN_MIN_K ({} < {})", k, LRN_MIN_K);
        }
        if (beta < LRN_MIN_BETA) {
            supported = false;
            log.warn("Not supported: beta < CUDNN_LRN_MIN_BETA ({} < {})", beta, LRN_MIN_BETA);
        }
        return supported;
    }

    /**
     * Runs the cuDNN cross-channel LRN backward pass.
     *
     * @param input        rank-4 input of the forward pass; sizes of dimensions 0..3 are read
     * @param epsilon      incoming gradient; copied to 'c' order if it lacks default strides
     * @param k            LRN k parameter
     * @param n            LRN window size; truncated to {@code int} before being given to cuDNN
     * @param alpha        LRN alpha parameter
     * @param beta         LRN beta parameter
     * @param workspaceMgr used to allocate the returned gradient array
     * @return an empty {@link DefaultGradient} paired with the gradient with respect to {@code input}
     * @throws IllegalStateException if {@link #activate} has not been called on this instance yet
     */
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, double k, double n, double alpha, double beta, LayerWorkspaceMgr workspaceMgr) {
        if (this.activations == null) {
            // Fail with a clear message instead of an opaque error from the allocator below.
            throw new IllegalStateException("backpropGradient called before activate: no forward-pass activations are available");
        }
        int miniBatch = (int) input.size(0);
        int depth = (int) input.size(1);
        int inH = (int) input.size(2);
        int inW = (int) input.size(3);
        Gradient retGradient = new DefaultGradient();
        if (!Shape.hasDefaultStridesForShape(epsilon)) {
            epsilon = epsilon.dup('c');
        }
        this.flushGridQueue();

        this.describeTensor(this.cudnnContext.srcTensorDesc, miniBatch, depth, inH, inW, input);
        this.describeTensor(this.cudnnContext.deltaTensorDesc, miniBatch, depth, inH, inW, epsilon);
        checkCudnn(cudnn.cudnnSetLRNDescriptor(this.cudnnContext.lrnDesc, (int) n, alpha, beta, k));

        INDArray nextEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[]{miniBatch, depth, inH, inW}, 'c');
        this.describeTensor(this.cudnnContext.dstTensorDesc, miniBatch, depth, inH, inW, nextEpsilon);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareActionAllWrite(new INDArray[]{input, epsilon, this.activations, nextEpsilon});
        Pointer srcData = allocator.getPointer(input, context);
        Pointer epsData = allocator.getPointer(epsilon, context);
        Pointer zData = allocator.getPointer(this.activations, context);
        Pointer dstData = allocator.getPointer(nextEpsilon, context);

        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new CUstream_st(context.getCublasStream())));
        // NOTE(review): the stored activations (zData) are described with deltaTensorDesc, i.e.
        // with epsilon's strides. This presumes both arrays share the same layout - confirm.
        checkCudnn(cudnn.cudnnLRNCrossChannelBackward(this.cudnnContext, this.cudnnContext.lrnDesc, LRN_MODE_CROSS_CHANNEL, this.alpha, this.cudnnContext.deltaTensorDesc, zData, this.cudnnContext.deltaTensorDesc, epsData, this.cudnnContext.srcTensorDesc, srcData, this.beta, this.cudnnContext.dstTensorDesc, dstData));

        allocator.getFlowController().registerActionAllWrite(context, new INDArray[]{input, epsilon, this.activations, nextEpsilon});
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }
        return new Pair<>(retGradient, nextEpsilon);
    }

    /**
     * Runs the cuDNN cross-channel LRN forward pass and remembers the result for
     * {@link #backpropGradient}.
     *
     * @param input        rank-4 input; copied to 'c' order if it lacks default strides
     * @param training     not used by this implementation
     * @param k            LRN k parameter
     * @param n            LRN window size; truncated to {@code int} before being given to cuDNN
     * @param alpha        LRN alpha parameter
     * @param beta         LRN beta parameter
     * @param workspaceMgr used to allocate the activations array
     * @return the activations array, which is also retained in this helper
     */
    public INDArray activate(INDArray input, boolean training, double k, double n, double alpha, double beta, LayerWorkspaceMgr workspaceMgr) {
        int miniBatch = (int) input.size(0);
        int inDepth = (int) input.size(1);
        int inH = (int) input.size(2);
        int inW = (int) input.size(3);
        if (!Shape.hasDefaultStridesForShape(input)) {
            input = input.dup('c');
        }
        this.describeTensor(this.cudnnContext.srcTensorDesc, miniBatch, inDepth, inH, inW, input);

        this.activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c');
        this.describeTensor(this.cudnnContext.dstTensorDesc, miniBatch, inDepth, inH, inW, this.activations);
        checkCudnn(cudnn.cudnnSetLRNDescriptor(this.cudnnContext.lrnDesc, (int) n, alpha, beta, k));

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareActionAllWrite(new INDArray[]{input, this.activations});
        Pointer srcData = allocator.getPointer(input, context);
        Pointer dstData = allocator.getPointer(this.activations, context);
        this.flushGridQueue();

        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new CUstream_st(context.getCublasStream())));
        checkCudnn(cudnn.cudnnLRNCrossChannelForward(this.cudnnContext, this.cudnnContext.lrnDesc, LRN_MODE_CROSS_CHANNEL, this.alpha, this.cudnnContext.srcTensorDesc, srcData, this.beta, this.cudnnContext.dstTensorDesc, dstData));

        allocator.getFlowController().registerActionAllWrite(context, new INDArray[]{input, this.activations});
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }
        return this.activations;
    }

    /**
     * @return always an empty map; this helper reports no memory use of its own
     */
    public Map<String, Long> helperMemoryUse() {
        return Collections.emptyMap();
    }

    /** Flushes the executioner's queue when the active executioner is a {@link GridExecutioner}. */
    private void flushGridQueue() {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
        }
    }

    /** Configures {@code desc} as a 4d tensor with the given sizes and the strides of {@code array}. */
    private void describeTensor(cudnnTensorStruct desc, int miniBatch, int depth, int height, int width, INDArray array) {
        int[] stride = ArrayUtil.toInts(array.stride());
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(desc, this.dataType, miniBatch, depth, height, width, stride[0], stride[1], stride[2], stride[3]));
    }

    /**
     * Holds the cuDNN descriptors used by the enclosing helper: three tensor descriptors
     * (source, destination, delta) and one LRN descriptor.
     */
    private static class CudnnLocalResponseNormalizationContext
    extends BaseCudnnHelper.CudnnContext {
        private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct();
        private cudnnTensorStruct dstTensorDesc = new cudnnTensorStruct();
        private cudnnTensorStruct deltaTensorDesc = new cudnnTensorStruct();
        private cudnnLRNStruct lrnDesc = new cudnnLRNStruct();

        /** Creates all handles and registers a {@link Deallocator} that destroys them. */
        public CudnnLocalResponseNormalizationContext() {
            this.createHandles();
            this.deallocator(new Deallocator(this));
        }

        /**
         * Builds a context whose descriptor objects are constructed from the descriptors of
         * {@code c}; no new cuDNN descriptors are created here.
         */
        public CudnnLocalResponseNormalizationContext(CudnnLocalResponseNormalizationContext c) {
            super(c);
            this.srcTensorDesc = new cudnnTensorStruct((Pointer) c.srcTensorDesc);
            this.dstTensorDesc = new cudnnTensorStruct((Pointer) c.dstTensorDesc);
            this.deltaTensorDesc = new cudnnTensorStruct((Pointer) c.deltaTensorDesc);
            this.lrnDesc = new cudnnLRNStruct((Pointer) c.lrnDesc);
        }

        /** Creates the parent's handles, then the three tensor descriptors and the LRN descriptor. */
        @Override
        protected void createHandles() {
            super.createHandles();
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.srcTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.dstTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.deltaTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnCreateLRNDescriptor(this.lrnDesc));
        }

        /** Destroys the LRN descriptor and the tensor descriptors, then the parent's handles. */
        @Override
        protected void destroyHandles() {
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnDestroyLRNDescriptor(this.lrnDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.srcTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.dstTensorDesc));
            CudnnLocalResponseNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.deltaTensorDesc));
            super.destroyHandles();
        }

        /** Deallocator registered by the context constructor; destroys the copied handles. */
        private static class Deallocator
        extends CudnnLocalResponseNormalizationContext
        implements Pointer.Deallocator {
            Deallocator(CudnnLocalResponseNormalizationContext c) {
                super(c);
            }

            public void deallocate() {
                this.destroyHandles();
            }
        }
    }
}

