/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.cuda;

import lombok.NonNull;
import org.bytedeco.cuda.cudnn.cudnnContext;
import org.bytedeco.cuda.cudnn.cudnnTensorStruct;
import org.bytedeco.cuda.global.cudart;
import org.bytedeco.cuda.global.cudnn;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class BaseCudnnHelper {
    private static final Logger log = LoggerFactory.getLogger(BaseCudnnHelper.class);
    protected final DataType nd4jDataType;
    protected final int dataType;
    protected final int dataTypeSize;
    protected final Pointer alpha;
    protected final Pointer beta;
    protected SizeTPointer sizeInBytes = new SizeTPointer(1L);

    protected static void checkCuda(int error) {
        if (error != 0) {
            throw new RuntimeException("CUDA error = " + error + ": " + cudart.cudaGetErrorString((int)error).getString());
        }
    }

    protected static void checkCudnn(int status) {
        if (status != 0) {
            throw new RuntimeException("cuDNN status = " + status + ": " + cudnn.cudnnGetErrorString((int)status).getString());
        }
    }

    public BaseCudnnHelper(@NonNull DataType dataType) {
        if (dataType == null) {
            throw new NullPointerException("dataType is marked non-null but is null");
        }
        this.nd4jDataType = dataType;
        int n = dataType == DataType.DOUBLE ? 1 : (this.dataType = dataType == DataType.FLOAT ? 0 : 2);
        this.dataTypeSize = dataType == DataType.DOUBLE ? 8 : (dataType == DataType.FLOAT ? 4 : 2);
        this.alpha = this.dataType == 1 ? new DoublePointer(new double[]{1.0}) : new FloatPointer(new float[]{1.0f});
        this.beta = this.dataType == 1 ? new DoublePointer(new double[]{0.0}) : new FloatPointer(new float[]{0.0f});
    }

    public static int toCudnnDataType(DataType type) {
        switch (type) {
            case DOUBLE: {
                return 1;
            }
            case FLOAT: {
                return 0;
            }
            case INT: {
                return 4;
            }
            case HALF: {
                return 2;
            }
        }
        throw new RuntimeException("Cannot convert type: " + type);
    }

    public boolean checkSupported() {
        return true;
    }

    protected static int[] adaptForTensorDescr(int[] shapeOrStrides) {
        int i;
        if (shapeOrStrides.length >= 4) {
            return shapeOrStrides;
        }
        int[] out = new int[4];
        for (i = 0; i < shapeOrStrides.length; ++i) {
            out[i] = shapeOrStrides[i];
        }
        while (i < 4) {
            out[i] = 1;
            ++i;
        }
        return out;
    }

    protected static class TensorArray
    extends PointerPointer<cudnnTensorStruct> {
        public TensorArray() {
        }

        public TensorArray(long size) {
            PointerPointer p = new PointerPointer(size);
            p.deallocate(false);
            this.address = p.address();
            this.limit = p.limit();
            this.capacity = p.capacity();
            cudnnTensorStruct t = new cudnnTensorStruct();
            int i = 0;
            while ((long)i < this.capacity) {
                BaseCudnnHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor((cudnnTensorStruct)t));
                this.put(i, (Pointer)t);
                ++i;
            }
            this.deallocator(new Deallocator(this, (Pointer)p));
        }

        public TensorArray(TensorArray a) {
            super((Pointer)a);
        }

        static class Deallocator
        extends TensorArray
        implements Pointer.Deallocator {
            Pointer owner;

            Deallocator(TensorArray a, Pointer owner) {
                this.address = a.address;
                this.capacity = a.capacity;
                this.owner = owner;
            }

            public void deallocate() {
                int i = 0;
                while (!this.isNull() && (long)i < this.capacity) {
                    cudnnTensorStruct t = (cudnnTensorStruct)this.get(cudnnTensorStruct.class, i);
                    BaseCudnnHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor((cudnnTensorStruct)t));
                    ++i;
                }
                if (this.owner != null) {
                    this.owner.deallocate();
                    this.owner = null;
                }
                this.setNull();
            }
        }
    }

    protected static class DataCache
    extends Pointer {
        public DataCache() {
        }

        public DataCache(long size) {
            this.position = 0L;
            this.limit = this.capacity = size;
            int error = cudart.cudaMalloc((Pointer)this, (long)size);
            if (error != 0) {
                log.warn("Cannot allocate " + size + " bytes of device memory (CUDA error = " + error + "), proceeding with host memory");
                BaseCudnnHelper.checkCuda(cudart.cudaMallocHost((Pointer)this, (long)size));
                this.deallocator(new HostDeallocator(this));
            } else {
                this.deallocator(new Deallocator(this));
            }
        }

        public DataCache(DataCache c) {
            super((Pointer)c);
        }

        static class HostDeallocator
        extends DataCache
        implements Pointer.Deallocator {
            HostDeallocator(DataCache c) {
                super(c);
            }

            public void deallocate() {
                BaseCudnnHelper.checkCuda(cudart.cudaFreeHost((Pointer)this));
                this.setNull();
            }
        }

        static class Deallocator
        extends DataCache
        implements Pointer.Deallocator {
            Deallocator(DataCache c) {
                super(c);
            }

            public void deallocate() {
                BaseCudnnHelper.checkCuda(cudart.cudaFree((Pointer)this));
                this.setNull();
            }
        }
    }

    protected static class CudnnContext
    extends cudnnContext {
        public CudnnContext() {
            Nd4j.create((int)1);
            AtomicAllocator.getInstance();
        }

        public CudnnContext(CudnnContext c) {
            super((Pointer)c);
        }

        protected void createHandles() {
            BaseCudnnHelper.checkCudnn(cudnn.cudnnCreate((cudnnContext)this));
        }

        protected void destroyHandles() {
            BaseCudnnHelper.checkCudnn(cudnn.cudnnDestroy((cudnnContext)this));
        }

        protected static class Deallocator
        extends CudnnContext
        implements Pointer.Deallocator {
            Deallocator(CudnnContext c) {
                super(c);
            }

            public void deallocate() {
                this.destroyHandles();
            }
        }
    }
}

