/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.cuda.normalization;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.nd4j.common.util.ArrayUtil;

import java.util.Collections;
import java.util.Map;

import org.bytedeco.cuda.cudart.*;
import org.bytedeco.cuda.cudnn.*;

import static org.bytedeco.cuda.global.cudnn.*;

/**
 * cuDNN-based helper for the local response normalization (LRN) layer.
 * <p>
 * The forward pass ({@link #activate}) stores its output array in this helper, and the backward pass
 * ({@link #backpropGradient}) reads that stored array. The forward pass therefore has to be executed
 * before the backward pass on the same helper instance.
 *
 * @author saudet
 */
@Slf4j
public class CudnnLocalResponseNormalizationHelper extends BaseCudnnHelper implements LocalResponseNormalizationHelper {

    /**
     * Creates the helper.
     *
     * @param dataType data type, forwarded unchanged to the {@link BaseCudnnHelper} constructor
     */
    public CudnnLocalResponseNormalizationHelper(DataType dataType) {
        super(dataType);
    }

    /**
     * cuDNN context extended with the three tensor descriptors and the LRN descriptor used by this helper.
     */
    private static class CudnnLocalResponseNormalizationContext extends CudnnContext {

        /**
         * Copy of a context whose {@link #deallocate()} destroys the cuDNN handles. It is registered as the
         * deallocator of the context created by the no-argument constructor.
         */
        private static class Deallocator extends CudnnLocalResponseNormalizationContext implements Pointer.Deallocator {
            Deallocator(CudnnLocalResponseNormalizationContext c) {
                super(c);
            }

            @Override
            public void deallocate() {
                destroyHandles();
            }
        }

        // Descriptors for the input (src), the output (dst) and the incoming epsilon (delta) tensors.
        private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(),
                        deltaTensorDesc = new cudnnTensorStruct();
        // Descriptor holding the LRN hyperparameters (n, alpha, beta, k).
        private cudnnLRNStruct lrnDesc = new cudnnLRNStruct();

        /**
         * Creates the cuDNN handles and registers a {@link Deallocator} that destroys them.
         */
        public CudnnLocalResponseNormalizationContext() {
            createHandles();
            deallocator(new Deallocator(this));
        }

        /**
         * Copy constructor: wraps the descriptors of {@code c} without creating new cuDNN handles.
         *
         * @param c context to copy
         */
        public CudnnLocalResponseNormalizationContext(CudnnLocalResponseNormalizationContext c) {
            super(c);
            srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc);
            dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc);
            deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc);
            lrnDesc = new cudnnLRNStruct(c.lrnDesc);
        }

        /**
         * Creates the parent's handles first, then the tensor descriptors and the LRN descriptor.
         */
        @Override
        protected void createHandles() {
            super.createHandles();
            checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc));
            checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc));
            checkCudnn(cudnnCreateLRNDescriptor(lrnDesc));
        }

        /**
         * Destroys the LRN descriptor and the tensor descriptors, then the parent's handles.
         */
        @Override
        protected void destroyHandles() {
            checkCudnn(cudnnDestroyLRNDescriptor(lrnDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc));
            checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc));
            super.destroyHandles();
        }
    }

    private final CudnnLocalResponseNormalizationContext cudnnContext = new CudnnLocalResponseNormalizationContext();
    // Output of the most recent activate(...) call; read by backpropGradient(...). Null until activate runs.
    private INDArray activations = null;

    /**
     * Checks whether the given LRN hyperparameters are within the limits accepted by cuDNN, in addition to
     * the general check done by the no-argument {@code checkSupported()}. A warning is logged for every
     * violated limit. {@code alpha} is not checked.
     *
     * @param k     LRN k parameter; must not be below {@code CUDNN_LRN_MIN_K}
     * @param n     LRN n parameter; must lie in {@code [CUDNN_LRN_MIN_N, CUDNN_LRN_MAX_N]}
     * @param alpha LRN alpha parameter (unchecked)
     * @param beta  LRN beta parameter; must not be below {@code CUDNN_LRN_MIN_BETA}
     * @return {@code true} if all checks pass, {@code false} otherwise
     */
    public boolean checkSupported(double k, double n, double alpha, double beta) {
        boolean supported = checkSupported();
        if (n < CUDNN_LRN_MIN_N) {
            supported = false;
            log.warn("Not supported: n < CUDNN_LRN_MIN_N ({} < {})", n, CUDNN_LRN_MIN_N);
        }
        if (n > CUDNN_LRN_MAX_N) {
            supported = false;
            log.warn("Not supported: n > CUDNN_LRN_MAX_N ({} > {})", n, CUDNN_LRN_MAX_N);
        }
        if (k < CUDNN_LRN_MIN_K) {
            supported = false;
            log.warn("Not supported: k < CUDNN_LRN_MIN_K ({} < {})", k, CUDNN_LRN_MIN_K);
        }
        if (beta < CUDNN_LRN_MIN_BETA) {
            supported = false;
            log.warn("Not supported: beta < CUDNN_LRN_MIN_BETA ({} < {})", beta, CUDNN_LRN_MIN_BETA);
        }
        return supported;
    }

    /**
     * Computes the gradient with respect to the layer input using {@code cudnnLRNCrossChannelBackward}.
     * The returned {@link Gradient} is an empty {@link DefaultGradient}.
     *
     * @param input        layer input; its first four dimensions are read as (miniBatch, depth, height, width)
     * @param epsilon      gradient with respect to the layer output; duplicated in 'c' order if it does not
     *                     have default strides
     * @param k            LRN k parameter
     * @param n            LRN n parameter; truncated to {@code int} before being passed to cuDNN
     * @param alpha        LRN alpha parameter
     * @param beta         LRN beta parameter
     * @param workspaceMgr workspace manager used to allocate the returned epsilon array
     * @return pair of the (empty) gradient and the gradient with respect to the input
     * @throws IllegalStateException if {@link #activate} has not been called on this helper beforehand
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, double k, double n, double alpha,
                                                     double beta, LayerWorkspaceMgr workspaceMgr) {
        // The backward kernel needs the forward output; fail early and clearly instead of handing a
        // null array to the allocator / cuDNN.
        if (activations == null) {
            throw new IllegalStateException("Cannot perform cuDNN LRN backward pass: no activations available."
                            + " activate(...) must be called before backpropGradient(...)");
        }

        int miniBatch = (int) input.size(0);
        int depth = (int) input.size(1);
        int inH = (int) input.size(2);
        int inW = (int) input.size(3);

        Gradient retGradient = new DefaultGradient();

        if (!Shape.hasDefaultStridesForShape(epsilon)) {
            // apparently not supported by cuDNN
            epsilon = epsilon.dup('c');
        }

        flushGridExecutionerQueue();

        configureTensorDescriptor(cudnnContext.srcTensorDesc, miniBatch, depth, inH, inW, input);
        configureTensorDescriptor(cudnnContext.deltaTensorDesc, miniBatch, depth, inH, inW, epsilon);
        checkCudnn(cudnnSetLRNDescriptor(cudnnContext.lrnDesc, (int) n, alpha, beta, k));

        INDArray nextEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(),
                        new long[] {miniBatch, depth, inH, inW}, 'c');
        configureTensorDescriptor(cudnnContext.dstTensorDesc, miniBatch, depth, inH, inW, nextEpsilon);

        Allocator allocator = AtomicAllocator.getInstance();
        CudaContext context =
                        allocator.getFlowController().prepareActionAllWrite(input, epsilon, activations, nextEpsilon);
        Pointer srcData = allocator.getPointer(input, context);
        Pointer epsData = allocator.getPointer(epsilon, context);
        Pointer zData = allocator.getPointer(activations, context);
        Pointer dstData = allocator.getPointer(nextEpsilon, context);

        checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
        // NOTE(review): deltaTensorDesc is passed for both the cached activations and epsilon; this presumes
        // the two arrays have identical shape and strides - confirm.
        // this.alpha / this.beta are the inherited fields, not the double parameters of the same name.
        checkCudnn(cudnnLRNCrossChannelBackward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
                        this.alpha, cudnnContext.deltaTensorDesc, zData, cudnnContext.deltaTensorDesc, epsData,
                        cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc, dstData));

        allocator.getFlowController().registerActionAllWrite(context, input, epsilon, activations, nextEpsilon);

        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }

        return new Pair<>(retGradient, nextEpsilon);
    }


    /**
     * Computes the layer output using {@code cudnnLRNCrossChannelForward} and stores it in this helper for
     * use by a subsequent {@link #backpropGradient} call.
     *
     * @param input        layer input; its first four dimensions are read as (miniBatch, depth, height, width);
     *                     duplicated in 'c' order if it does not have default strides
     * @param training     not used by this implementation
     * @param k            LRN k parameter
     * @param n            LRN n parameter; truncated to {@code int} before being passed to cuDNN
     * @param alpha        LRN alpha parameter
     * @param beta         LRN beta parameter
     * @param workspaceMgr workspace manager used to allocate the output array
     * @return the activations array (the same instance that is stored in this helper)
     */
    @Override
    public INDArray activate(INDArray input, boolean training, double k, double n, double alpha, double beta, LayerWorkspaceMgr workspaceMgr) {
        int miniBatch = (int) input.size(0);
        int inDepth = (int) input.size(1);
        int inH = (int) input.size(2);
        int inW = (int) input.size(3);

        if (!Shape.hasDefaultStridesForShape(input)) {
            input = input.dup('c');
        }

        configureTensorDescriptor(cudnnContext.srcTensorDesc, miniBatch, inDepth, inH, inW, input);

        activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(),
                        new long[] {miniBatch, inDepth, inH, inW}, 'c');

        configureTensorDescriptor(cudnnContext.dstTensorDesc, miniBatch, inDepth, inH, inW, activations);
        checkCudnn(cudnnSetLRNDescriptor(cudnnContext.lrnDesc, (int) n, alpha, beta, k));

        Allocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, activations);
        Pointer srcData = allocator.getPointer(input, context);
        Pointer dstData = allocator.getPointer(activations, context);

        flushGridExecutionerQueue();

        checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())));
        // this.alpha / this.beta are the inherited fields, not the double parameters of the same name.
        checkCudnn(cudnnLRNCrossChannelForward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1,
                        this.alpha, cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc,
                        dstData));

        allocator.getFlowController().registerActionAllWrite(context, input, activations);

        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }

        return activations;
    }

    /**
     * Reports the persistent memory used by this helper.
     *
     * @return always an empty map
     */
    @Override
    public Map<String, Long> helperMemoryUse() {
        //No persistent memory use other than the structs (which are small)
        return Collections.emptyMap();
    }

    /**
     * Flushes the executioner's queue when the active executioner is a {@link GridExecutioner}.
     */
    private static void flushGridExecutionerQueue() {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
        }
    }

    /**
     * Configures {@code desc} as a 4d tensor with the given dimensions and the strides of {@code array}.
     *
     * @param desc      descriptor to configure
     * @param miniBatch size of dimension 0
     * @param depth     size of dimension 1
     * @param height    size of dimension 2
     * @param width     size of dimension 3
     * @param array     array whose first four strides are used
     */
    private void configureTensorDescriptor(cudnnTensorStruct desc, int miniBatch, int depth, int height, int width,
                                           INDArray array) {
        int[] stride = ArrayUtil.toInts(array.stride());
        checkCudnn(cudnnSetTensor4dDescriptorEx(desc, dataType, miniBatch, depth, height, width,
                        stride[0], stride[1], stride[2], stride[3]));
    }
}
