/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.cuda.convolution.subsampling;

import java.util.Collections;
import java.util.Map;
import org.bytedeco.cuda.cudart.CUstream_st;
import org.bytedeco.cuda.cudnn.cudnnContext;
import org.bytedeco.cuda.cudnn.cudnnPoolingStruct;
import org.bytedeco.cuda.cudnn.cudnnTensorStruct;
import org.bytedeco.cuda.global.cudnn;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * cuDNN-backed {@link SubsamplingHelper} for 2D MAX and AVG pooling on NCHW or NHWC arrays.
 *
 * <p>{@link #activate} and {@link #backpropGradient} return {@code null} for configurations this
 * helper does not handle: any dilation other than 1, or a pooling type other than MAX/AVG.
 * NOTE(review): presumably the caller then falls back to the built-in implementation - confirm
 * against the {@link SubsamplingHelper} contract.
 *
 * <p>Every call reconfigures the descriptors held in one per-instance
 * {@link CudnnSubsamplingContext}, so concurrent calls on the same instance would race on them.
 */
public class CudnnSubsamplingHelper extends BaseCudnnHelper implements SubsamplingHelper {
    private static final Logger log = LoggerFactory.getLogger(CudnnSubsamplingHelper.class);

    /** Returned by {@link #toCudnnPoolingMode} for pooling types this helper does not support. */
    private static final int UNSUPPORTED_POOLING_MODE = -1;

    /** cuDNN handle plus the tensor/pooling descriptors that every call reconfigures and reuses. */
    private final CudnnSubsamplingContext cudnnContext = new CudnnSubsamplingContext();

    public CudnnSubsamplingHelper(DataType dataType) {
        super(dataType);
    }

    /**
     * Computes the gradient of the pooling operation with respect to its input using cuDNN.
     *
     * <p>The forward output is recomputed via {@link #activate} because cuDNN's pooling backward
     * pass takes it (y) as an argument. Pooling has no parameters, so the returned
     * {@link Gradient} is always empty.
     *
     * @param input   forward-pass input, laid out according to {@code format}
     * @param epsilon gradient with respect to the pooled output
     * @return a pair of (empty gradient, gradient w.r.t. {@code input}), where the latter is
     *         allocated as {@link ArrayType#ACTIVATION_GRAD} (and cropped to the original input size
     *         if manual padding was applied), or {@code null} if the configuration is unsupported
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides, int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
        if (dilation[0] != 1 || dilation[1] != 1) {
            return null;
        }
        // Resolve the mode first so unsupported pooling types bail out before any GPU work.
        int poolingMode = toCudnnPoolingMode(poolingType);
        if (poolingMode == UNSUPPORTED_POOLING_MODE) {
            return null;
        }

        boolean nchw = format == CNN2DFormat.NCHW;
        int chIdx = nchw ? 1 : 3;
        int hIdx = nchw ? 2 : 1;
        int wIdx = nchw ? 3 : 2;

        // cudnnPoolingBackward needs the forward output y, so recompute it.
        INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, format, workspaceMgr);

        long miniBatch = input.size(0);
        long depth = input.size(chIdx);

        CudnnConvolutionHelper.CudnnForwardArgs args = CudnnConvolutionHelper.getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
        // Work on a fresh copy of the (possibly manually padded) input. Its strides must be read
        // AFTER the dup: the descriptor has to describe the buffer actually passed to cuDNN, and a
        // dup of a view or non-'c' array generally has different strides than its source.
        input = args.getInput().dup();
        long inH = input.size(hIdx);
        long inW = input.size(wIdx);
        long[] srcStride = input.stride();
        int[] outSize = args.getOutSize();
        int outH = outSize[0];
        int outW = outSize[1];

        Gradient retGradient = new DefaultGradient();

        // deltaTensorDesc is used below for BOTH dy (epsilon) and y (reduced, a dense 'c' array),
        // so epsilon must be a dense 'c' array too. hasDefaultStridesForShape alone also accepts
        // default 'f' strides, hence the explicit ordering check.
        if (epsilon.ordering() != 'c' || !Shape.hasDefaultStridesForShape(epsilon) || epsilon.isView()) {
            epsilon = epsilon.dup('c');
        }

        flushGridExecutioner();
        setTensorDescriptor(cudnnContext.srcTensorDesc, miniBatch, depth, inH, inW, srcStride, chIdx, hIdx, wIdx);
        setTensorDescriptor(cudnnContext.deltaTensorDesc, miniBatch, depth, outH, outW, epsilon.stride(), chIdx, hIdx, wIdx);
        setPoolingDescriptor(poolingMode, kernel, pad, strides);

        INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), shape4d(nchw, miniBatch, depth, inH, inW), 'c');
        setTensorDescriptor(cudnnContext.dstTensorDesc, miniBatch, depth, inH, inW, outEpsilon.stride(), chIdx, hIdx, wIdx);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(input, epsilon, reduced, outEpsilon);
        Pointer srcData = allocator.getPointer(input, context);
        Pointer epsData = allocator.getPointer(epsilon, context);
        Pointer zData = allocator.getPointer(reduced, context);
        Pointer dstData = allocator.getPointer(outEpsilon, context);
        checkCudnn(cudnn.cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
        checkCudnn(cudnn.cudnnPoolingBackward(cudnnContext, cudnnContext.poolingDesc, alpha,
                cudnnContext.deltaTensorDesc, zData,
                cudnnContext.deltaTensorDesc, epsData,
                cudnnContext.srcTensorDesc, srcData,
                beta, cudnnContext.dstTensorDesc, dstData));
        allocator.registerAction(context, outEpsilon, input, epsilon, reduced);
        syncIfDebug(context);

        return new Pair<>(retGradient, cropManualPadding(outEpsilon, args, nchw));
    }

    /**
     * Runs the pooling forward pass with cuDNN.
     *
     * @param training unused by this implementation
     * @return the pooled output, allocated as {@link ArrayType#ACTIVATIONS} in {@code format}
     *         layout, or {@code null} if the configuration is unsupported
     */
    @Override
    public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
        if (dilation[0] != 1 || dilation[1] != 1) {
            return null;
        }
        int poolingMode = toCudnnPoolingMode(poolingType);
        if (poolingMode == UNSUPPORTED_POOLING_MODE) {
            return null;
        }

        boolean nchw = format == CNN2DFormat.NCHW;
        int chIdx = nchw ? 1 : 3;
        int hIdx = nchw ? 2 : 1;
        int wIdx = nchw ? 3 : 2;
        long miniBatch = input.size(0);
        long inDepth = input.size(chIdx);

        CudnnConvolutionHelper.CudnnForwardArgs args = CudnnConvolutionHelper.getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format);
        input = args.getInput();
        long inH = input.size(hIdx);
        long inW = input.size(wIdx);
        int[] outSize = args.getOutSize();
        int outH = outSize[0];
        int outW = outSize[1];

        flushGridExecutioner();
        setPoolingDescriptor(poolingMode, kernel, pad, strides);
        setTensorDescriptor(cudnnContext.srcTensorDesc, miniBatch, inDepth, inH, inW, input.stride(), chIdx, hIdx, wIdx);

        INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), shape4d(nchw, miniBatch, inDepth, outH, outW), 'c');
        setTensorDescriptor(cudnnContext.dstTensorDesc, miniBatch, inDepth, outH, outW, reduced.stride(), chIdx, hIdx, wIdx);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(input, reduced);
        Pointer srcData = allocator.getPointer(input, context);
        Pointer dstData = allocator.getPointer(reduced, context);
        checkCudnn(cudnn.cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
        checkCudnn(cudnn.cudnnPoolingForward(cudnnContext, cudnnContext.poolingDesc, alpha,
                cudnnContext.srcTensorDesc, srcData,
                beta, cudnnContext.dstTensorDesc, dstData));
        allocator.registerAction(context, reduced, input);
        syncIfDebug(context);
        return reduced;
    }

    /** This helper reports no additional memory use. */
    @Override
    public Map<String, Long> helperMemoryUse() {
        return Collections.emptyMap();
    }

    /**
     * Maps a DL4J pooling type to the matching cuDNN pooling mode.
     *
     * @return the cuDNN mode, or {@link #UNSUPPORTED_POOLING_MODE} for anything but MAX/AVG
     */
    private static int toCudnnPoolingMode(PoolingType poolingType) {
        switch (poolingType) {
            case MAX:
                return cudnn.CUDNN_POOLING_MAX;
            case AVG:
                // Padded cells count towards the averaging denominator.
                return cudnn.CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
            default:
                return UNSUPPORTED_POOLING_MODE;
        }
    }

    /** Configures the shared pooling descriptor, with NaN propagation enabled. */
    private void setPoolingDescriptor(int poolingMode, int[] kernel, int[] pad, int[] strides) {
        checkCudnn(cudnn.cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, cudnn.CUDNN_PROPAGATE_NAN,
                kernel[0], kernel[1], pad[0], pad[1], strides[0], strides[1]));
    }

    /**
     * Describes a 4D array to cuDNN as (n, c, h, w) using the array's real strides.
     *
     * <p>{@code chIdx}/{@code hIdx}/{@code wIdx} give the positions of the channel/height/width
     * dimensions in {@code stride}, so the same call describes both NCHW and NHWC arrays. Values
     * that do not fit cuDNN's {@code int} arguments fail with {@link ArithmeticException} instead
     * of being silently truncated.
     */
    private void setTensorDescriptor(cudnnTensorStruct desc, long n, long c, long h, long w, long[] stride, int chIdx, int hIdx, int wIdx) {
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(desc, dataType,
                Math.toIntExact(n), Math.toIntExact(c), Math.toIntExact(h), Math.toIntExact(w),
                Math.toIntExact(stride[0]), Math.toIntExact(stride[chIdx]),
                Math.toIntExact(stride[hIdx]), Math.toIntExact(stride[wIdx])));
    }

    /** Builds a [n, c, h, w] shape for NCHW or a [n, h, w, c] shape for NHWC. */
    private static long[] shape4d(boolean nchw, long n, long c, long h, long w) {
        return nchw ? new long[]{n, c, h, w} : new long[]{n, h, w, c};
    }

    /** Flushes queued ops when a grid executioner is active, so inputs are ready before cuDNN runs. */
    private static void flushGridExecutioner() {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
        }
    }

    /** Synchronizes the stream when the CUDA environment is configured for debugging. */
    private static void syncIfDebug(CudaContext context) {
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }
    }

    /**
     * Drops one trailing row and/or column from the input gradient when {@code args} reports
     * manual bottom/right padding, so the result matches the size of the un-padded input.
     * Returns {@code epsilon} unchanged when no manual padding was applied.
     */
    private static INDArray cropManualPadding(INDArray epsilon, CudnnConvolutionHelper.CudnnForwardArgs args, boolean nchw) {
        if (!args.isManualPadBottom() && !args.isManualPadRight()) {
            return epsilon;
        }
        int hIdx = nchw ? 2 : 1;
        int wIdx = nchw ? 3 : 2;
        INDArrayIndex rows = NDArrayIndex.interval(0L, epsilon.size(hIdx) - (args.isManualPadBottom() ? 1L : 0L));
        INDArrayIndex cols = NDArrayIndex.interval(0L, epsilon.size(wIdx) - (args.isManualPadRight() ? 1L : 0L));
        return nchw
                ? epsilon.get(NDArrayIndex.all(), NDArrayIndex.all(), rows, cols)
                : epsilon.get(NDArrayIndex.all(), rows, cols, NDArrayIndex.all());
    }

    /**
     * cuDNN handle plus the source, destination and delta tensor descriptors and the pooling
     * descriptor used by this helper. The descriptors are created on construction and destroyed
     * by the registered {@link Deallocator}.
     */
    private static class CudnnSubsamplingContext extends BaseCudnnHelper.CudnnContext {
        private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct();
        private cudnnTensorStruct dstTensorDesc = new cudnnTensorStruct();
        private cudnnTensorStruct deltaTensorDesc = new cudnnTensorStruct();
        private cudnnPoolingStruct poolingDesc = new cudnnPoolingStruct();

        public CudnnSubsamplingContext() {
            createHandles();
            deallocator(new Deallocator(this));
        }

        /** Wraps the same native handle and descriptors as {@code c}; used by {@link Deallocator}. */
        public CudnnSubsamplingContext(CudnnSubsamplingContext c) {
            super(c);
            this.srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc);
            this.dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
            this.deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
            this.poolingDesc = new cudnnPoolingStruct(c.poolingDesc);
        }

        @Override
        protected void createHandles() {
            super.createHandles();
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(srcTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(dstTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(deltaTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnCreatePoolingDescriptor(poolingDesc));
        }

        @Override
        protected void destroyHandles() {
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnDestroyPoolingDescriptor(poolingDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(srcTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(dstTensorDesc));
            CudnnSubsamplingHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(deltaTensorDesc));
            super.destroyHandles();
        }

        /** JavaCPP deallocator that destroys the wrapped context's native handles. */
        private static class Deallocator extends CudnnSubsamplingContext implements Pointer.Deallocator {
            Deallocator(CudnnSubsamplingContext c) {
                super(c);
            }

            @Override
            public void deallocate() {
                destroyHandles();
            }
        }
    }
}

