/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.handler.impl;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.apache.commons.lang3.RandomUtils;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.common.base.Preconditions;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.concurrency.DeviceAllocationsTracker;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.jita.flow.impl.GridFlowController;
import org.nd4j.jita.handler.MemoryHandler;
import org.nd4j.jita.memory.MemoryProvider;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemcpyDirection;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueLaunchContext;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Table;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link MemoryHandler} implementation used by the CUDA backend.
 *
 * <p>In this class several of the allocation-related entry points are stubs: {@link #alloc},
 * {@link #copyback}, {@link #copyforward}, {@link #relocateObject}, {@link #promoteObject} and
 * {@link #purgeZeroObject} always throw, while {@link #relocate}, {@link #free} and
 * {@link #initializeDevice} do nothing. The functional parts are the memcpy helpers, the typed
 * host/device pointer accessors, allocation bookkeeping queries, and a per-thread cache of
 * {@link CudaContext} instances.
 */
public class CudaZeroHandler
implements MemoryHandler {
    /** Shared configuration; replaced by {@link #init}. Static, so the most recent {@code init} call wins. */
    private static Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
    private static final Logger log = LoggerFactory.getLogger(CudaZeroHandler.class);
    /** Running total of bytes registered through {@code pickupHostAllocation}. */
    protected final AtomicLong zeroUseCounter = new AtomicLong(0L);
    /** Created in {@link #init}; {@code null} until then. */
    protected volatile DeviceAllocationsTracker deviceMemoryTracker;
    protected Map<Long, Integer> devicesAffinity = new ConcurrentHashMap<Long, Integer>();
    private final ReentrantReadWriteLock deviceLock = new ReentrantReadWriteLock();
    private final AtomicInteger devPtr = new AtomicInteger(0);
    private final AtomicBoolean wasInitialised = new AtomicBoolean(false);
    private final FlowController flowController;
    private final AllocationStatus INITIAL_LOCATION;
    /** One slot per device, indexed by device id; a slot stays {@code null} until a handle is first requested. */
    private final List<cublasHandle_t> cublasHandles = new ArrayList<cublasHandle_t>();
    private final AffinityManager affinityManager = Nd4j.getAffinityManager();
    /** Per-thread cache filled lazily by {@link #getCudaContext()}. */
    private final transient ThreadLocal<CudaContext> tlContext = new ThreadLocal<CudaContext>();
    /** One objectId map per device, indexed by device id. */
    private final List<ConcurrentHashMap<Long, Long>> deviceAllocations = new ArrayList<ConcurrentHashMap<Long, Long>>();
    /** bucketId -> (objectId -> objectId) for host-side allocations. */
    private final Map<Long, ConcurrentHashMap<Long, Long>> zeroAllocations = new ConcurrentHashMap<Long, ConcurrentHashMap<Long, Long>>();
    private final AtomicLong zeroCounter = new AtomicLong(0L);
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    /** Guards lazy creation of entries in {@link #cublasHandles}. */
    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

    /**
     * Marks the shared configuration as initialized, selects the flow controller for the configured
     * execution model, and creates one (empty) bookkeeping slot per available device.
     *
     * @throws RuntimeException if the configured execution model is not {@code SEQUENTIAL}
     * @throws ND4JIllegalStateException if device 0 reports a compute-capability major version below 3
     */
    public CudaZeroHandler() {
        configuration.setInitialized();
        this.INITIAL_LOCATION = configuration.getFirstMemory();
        switch (configuration.getExecutionModel()) {
            case SEQUENTIAL: {
                this.flowController = new GridFlowController();
                break;
            }
            default: {
                throw new RuntimeException("Unknown ExecutionModel: [" + configuration.getExecutionModel() + "]");
            }
        }
        int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices();
        for (int i = 0; i < numDevices; ++i) {
            this.deviceAllocations.add(new ConcurrentHashMap<Long, Long>());
            // Placeholder so that the handle list can later be addressed by device id.
            this.cublasHandles.add(null);
        }
        if (NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMajor(0) < 3) {
            throw new ND4JIllegalStateException("CUDA backend requires compute capatibility of 3.0 and above to run.");
        }
    }

    /**
     * Installs the given configuration (into the static field shared by all instances), creates the
     * device allocation tracker from it, and initializes the flow controller with the allocator.
     *
     * @param configuration configuration to use from now on; must not be {@code null}
     * @param allocator allocator handed to the flow controller; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    @Override
    public void init(@NonNull Configuration configuration, @NonNull Allocator allocator) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        if (allocator == null) {
            throw new NullPointerException("allocator is marked non-null but is null");
        }
        CudaZeroHandler.configuration = configuration;
        this.deviceMemoryTracker = new DeviceAllocationsTracker(CudaZeroHandler.configuration);
        this.flowController.init(allocator);
    }

    /**
     * Registers a host allocation: picks a bucket id via {@code RandomUtils.nextInt(0, numberOfGcThreads)},
     * adds the point's byte size to {@link #zeroUseCounter}, stores the bucket id on the point, and
     * records the point's object id in that bucket (creating the bucket on first use).
     *
     * @param point allocation point to register
     */
    private void pickupHostAllocation(AllocationPoint point) {
        int numBuckets = configuration.getNumberOfGcThreads();
        long bucketId = RandomUtils.nextInt(0, numBuckets);
        long reqMemory = point.getNumberOfBytes();
        this.zeroUseCounter.addAndGet(reqMemory);
        point.setBucketId(bucketId);
        if (!this.zeroAllocations.containsKey(bucketId)) {
            log.debug("Creating bucketID: {}", bucketId);
            // Re-check under the monitor so that only one thread installs the bucket map.
            synchronized (this) {
                if (!this.zeroAllocations.containsKey(bucketId)) {
                    this.zeroAllocations.put(bucketId, new ConcurrentHashMap<Long, Long>());
                }
            }
        }
        this.zeroAllocations.get(bucketId).put(point.getObjectId(), point.getObjectId());
    }

    /**
     * Not supported by this handler.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public PointersPair alloc(AllocationStatus targetMode, AllocationPoint point, AllocationShape shape, boolean initialize) {
        throw new UnsupportedOperationException();
    }

    /**
     * Always reports success; no device is actually queried.
     *
     * @return {@code true}
     */
    @Override
    public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) {
        return true;
    }

    /** No-op in this handler. */
    @Override
    public void relocate(AllocationStatus currentStatus, AllocationStatus targetStatus, AllocationPoint point, AllocationShape shape, CudaContext context) {
    }

    /**
     * @throws UnsupportedOperationException always
     * @deprecated not supported by this handler
     */
    @Override
    @Deprecated
    public void copyback(AllocationPoint point, AllocationShape shape) {
        throw new UnsupportedOperationException("Deprecated call");
    }

    /**
     * @throws UnsupportedOperationException always
     * @deprecated not supported by this handler
     */
    @Override
    @Deprecated
    public void copyforward(AllocationPoint point, AllocationShape shape) {
        throw new UnsupportedOperationException("Deprecated call");
    }

    /**
     * @throws IllegalStateException always; the message names the point's current allocation status
     * @deprecated not supported by this handler
     */
    @Override
    @Deprecated
    public void fallback(AllocationPoint point, AllocationShape shape) {
        throw new IllegalStateException("Can't fallback from [" + point.getAllocationStatus() + "]");
    }

    /** No-op in this handler. */
    @Override
    public void free(AllocationPoint point, AllocationStatus target) {
    }

    /**
     * @return the initial allocation status captured from the configuration at construction time
     */
    @Override
    public AllocationStatus getInitialLocation() {
        return this.INITIAL_LOCATION;
    }

    /** No-op in this handler. */
    @Override
    public void initializeDevice(Long threadId, Integer deviceId) {
    }

    /**
     * Copies {@code length} bytes from {@code srcPointer} into {@code dstBuffer} at byte offset
     * {@code dstOffset}.
     *
     * <p>For a constant buffer only the host side is written, synchronously via
     * {@code Pointer.memcpy}. Otherwise the data is sent to the device pointer on the context's
     * special stream and, when the point also has a host pointer, to the host pointer as well.
     * Calls with {@code length < 1} return immediately.
     *
     * @param dstBuffer destination buffer; must be a {@link BaseCudaDataBuffer}
     * @param srcPointer source memory
     * @param length number of bytes to copy
     * @param dstOffset byte offset added to the destination address
     * @throws IllegalArgumentException if {@code length} exceeds the buffer size in bytes
     *         (presumably; thrown by {@code Preconditions.checkArgument})
     * @throws IllegalStateException if a native {@code memcpyAsync} call returns 0
     */
    @Override
    public void memcpyAsync(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
        if (length < 1L) {
            return;
        }
        long dstSizeBytes = dstBuffer.length() * (long)Nd4j.sizeOfDataType(dstBuffer.dataType());
        Preconditions.checkArgument(length <= dstSizeBytes, "Length requested is bigger than target DataBuffer length");
        AllocationPoint point = ((BaseCudaDataBuffer)dstBuffer).getAllocationPoint();

        if (dstBuffer.isConstant()) {
            // Constant buffers: plain host-to-host copy, no stream involved.
            CudaPointer dstPointer = new CudaPointer(point.getHostPointer().address() + dstOffset, 0L);
            CudaPointer srcPointerJ = new CudaPointer(srcPointer, length);
            long profD = PerformanceTracker.getInstance().helperStartTransaction();
            Pointer.memcpy((Pointer)dstPointer, (Pointer)srcPointerJ, length);
            PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);
            point.tickHostRead();
            return;
        }

        // Host -> device transfer on the special stream.
        CudaPointer deviceDst = new CudaPointer(point.getDevicePointer().address() + dstOffset);
        CudaContext deviceContext = this.flowController.prepareAction(point, new AllocationPoint[0]);
        long profDevice = PerformanceTracker.getInstance().helperStartTransaction();
        this.flowController.commitTransfer(deviceContext.getSpecialStream());
        if (this.nativeOps.memcpyAsync((Pointer)deviceDst, srcPointer, length, CudaConstants.cudaMemcpyHostToDevice, (Pointer)deviceContext.getSpecialStream()) == 0) {
            throw new IllegalStateException("MemcpyAsync H2D failed: [" + srcPointer.address() + "] -> [" + deviceDst.address() + "]");
        }
        this.flowController.commitTransfer(deviceContext.getSpecialStream());
        PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profDevice, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
        this.flowController.registerAction(deviceContext, point, new AllocationPoint[0]);
        point.tickDeviceWrite();

        // Mirror the same bytes into the host copy when one exists.
        if (point.getHostPointer() != null) {
            CudaPointer hostDst = new CudaPointer(point.getHostPointer().address() + dstOffset);
            CudaContext hostContext = this.flowController.prepareAction(point, new AllocationPoint[0]);
            long profHost = PerformanceTracker.getInstance().helperStartTransaction();
            if (this.nativeOps.memcpyAsync((Pointer)hostDst, srcPointer, length, CudaConstants.cudaMemcpyHostToHost, (Pointer)hostContext.getSpecialStream()) == 0) {
                throw new IllegalStateException("MemcpyAsync H2H failed: [" + srcPointer.address() + "] -> [" + hostDst.address() + "]");
            }
            this.flowController.commitTransfer(hostContext.getSpecialStream());
            PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profHost, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);
            if (point.getAllocationStatus() == AllocationStatus.HOST) {
                this.flowController.registerAction(hostContext, point, new AllocationPoint[0]);
            }
            point.tickHostRead();
        }
    }

    /**
     * Issues a device-to-device {@code memcpyAsync} into the buffer's device pointer (plus
     * {@code dstOffset}) on the context's old stream, then ticks the point's device-write marker.
     * The stream is not synchronized here.
     *
     * @throws ND4JIllegalStateException if the native call returns 0
     */
    @Override
    public void memcpyDevice(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset, CudaContext context) {
        AllocationPoint point = ((BaseCudaDataBuffer)dstBuffer).getAllocationPoint();
        CudaPointer deviceDst = new CudaPointer(point.getDevicePointer().address() + dstOffset);
        if (this.nativeOps.memcpyAsync((Pointer)deviceDst, srcPointer, length, CudaConstants.cudaMemcpyDeviceToDevice, (Pointer)context.getOldStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
        point.tickDeviceWrite();
    }

    /**
     * Copies {@code length} bytes from {@code srcPointer} into the buffer's host pointer and, when the
     * point's status is {@code DEVICE}, forwards the freshly written host bytes to the device pointer.
     * Both copies use the current thread's old stream, which is synchronized before returning.
     *
     * @throws ND4JIllegalStateException if a native {@code memcpyAsync} call returns 0
     */
    @Override
    public void memcpySpecial(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
        CudaContext context = this.getCudaContext();
        AllocationPoint point = ((BaseCudaDataBuffer)dstBuffer).getAllocationPoint();
        CudaPointer hostDst = new CudaPointer(point.getHostPointer().address() + dstOffset);
        long profH = PerformanceTracker.getInstance().helperStartTransaction();
        if (this.nativeOps.memcpyAsync((Pointer)hostDst, srcPointer, length, CudaConstants.cudaMemcpyHostToHost, (Pointer)context.getOldStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
        PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profH, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);
        if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
            CudaPointer deviceDst = new CudaPointer(point.getDevicePointer().address() + dstOffset);
            long profD = PerformanceTracker.getInstance().helperStartTransaction();
            if (this.nativeOps.memcpyAsync((Pointer)deviceDst, (Pointer)hostDst, length, CudaConstants.cudaMemcpyHostToDevice, (Pointer)context.getOldStream()) == 0) {
                throw new ND4JIllegalStateException("memcpyAsync failed");
            }
            // Synchronize before registering so the transaction is recorded after the copy completes.
            context.syncOldStream();
            PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), profD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
        }
        context.syncOldStream();
        point.tickDeviceWrite();
    }

    /**
     * Runs {@link #memcpyAsync} and then synchronizes the current thread's old stream.
     */
    @Override
    public void memcpyBlocking(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
        CudaContext context = this.getCudaContext();
        this.memcpyAsync(dstBuffer, srcPointer, length, dstOffset);
        context.syncOldStream();
    }

    /**
     * Copies the full contents of {@code srcBuffer} ({@code length * elementSize} bytes) into
     * {@code dstBuffer}'s pointer as returned by the allocator. The source is read from the device
     * when it is reported as actual on the device side, and from the host otherwise. The old stream
     * is synchronized before returning.
     *
     * @throws ND4JIllegalStateException if the native {@code memcpyAsync} call returns 0
     */
    @Override
    public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
        CudaContext context = this.getCudaContext();
        AllocationPoint dstPoint = ((BaseCudaDataBuffer)dstBuffer).getAllocationPoint();
        AllocationPoint srcPoint = ((BaseCudaDataBuffer)srcBuffer).getAllocationPoint();
        long profDH = PerformanceTracker.getInstance().helperStartTransaction();
        Nd4j.getExecutioner().push();
        MemcpyDirection direction;
        if (srcPoint.isActualOnDeviceSide()) {
            Pointer src = AtomicAllocator.getInstance().getPointer(srcBuffer, context);
            Pointer dst = AtomicAllocator.getInstance().getPointer(dstBuffer, context);
            if (this.nativeOps.memcpyAsync(dst, src, srcBuffer.length() * (long)srcBuffer.getElementSize(), CudaConstants.cudaMemcpyDeviceToDevice, (Pointer)context.getOldStream()) == 0) {
                throw new ND4JIllegalStateException("memcpyAsync failed");
            }
            dstPoint.tickDeviceWrite();
            direction = MemcpyDirection.DEVICE_TO_DEVICE;
        } else {
            Pointer src = AtomicAllocator.getInstance().getHostPointer(srcBuffer);
            Pointer dst = AtomicAllocator.getInstance().getPointer(dstBuffer, context);
            if (this.nativeOps.memcpyAsync(dst, src, srcBuffer.length() * (long)srcBuffer.getElementSize(), CudaConstants.cudaMemcpyHostToDevice, (Pointer)context.getOldStream()) == 0) {
                throw new ND4JIllegalStateException("memcpyAsync failed");
            }
            direction = MemcpyDirection.HOST_TO_DEVICE;
        }
        dstPoint.tickDeviceWrite();
        context.syncOldStream();
        // NOTE(review): the halved start value is passed through unchanged from the original code; confirm intent.
        PerformanceTracker.getInstance().helperRegisterTransaction(srcPoint.getDeviceId(), profDH / 2L, dstPoint.getNumberOfBytes(), direction);
    }

    /**
     * Returns the buffer's device pointer wrapped in a pointer type matching the buffer's data type.
     *
     * @param buffer buffer to resolve; must be a {@link BaseCudaDataBuffer}
     * @param context context whose old stream is used for the optional locality check
     * @return a typed view of the device pointer, the untyped {@link CudaPointer} for data types not
     *         listed in the switch, or {@code null} when the point has no device pointer
     * @throws UnsupportedOperationException if the point's status is {@code DEVICE} but its data is
     *         not actual on the device side
     */
    @Override
    public Pointer getDevicePointer(DataBuffer buffer, CudaContext context) {
        AllocationPoint dstPoint = ((BaseCudaDataBuffer)buffer).getAllocationPoint();
        if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE && !dstPoint.isActualOnDeviceSide()) {
            throw new UnsupportedOperationException("Pew-pew");
        }
        if (dstPoint.getDevicePointer() == null) {
            return null;
        }
        CudaPointer p = new CudaPointer(dstPoint.getDevicePointer(), buffer.length(), 0L);
        if (OpProfiler.getInstance().getConfig().isCheckLocality()) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().tryPointer((Pointer)context.getOldStream(), (Pointer)p, 1);
        }
        switch (buffer.dataType()) {
            case DOUBLE: {
                return p.asDoublePointer();
            }
            case FLOAT: {
                return p.asFloatPointer();
            }
            case UINT32:
            case INT: {
                return p.asIntPointer();
            }
            case SHORT:
            case UINT16:
            case HALF:
            case BFLOAT16: {
                return p.asShortPointer();
            }
            case UINT64:
            case LONG: {
                return p.asLongPointer();
            }
            case UTF8:
            case UBYTE:
            case BYTE: {
                return p.asBytePointer();
            }
            case BOOL: {
                return p.asBooleanPointer();
            }
        }
        return p;
    }

    /**
     * Returns the buffer's host pointer wrapped in a pointer type matching the buffer's data type,
     * after asking the flow controller to synchronize the point to the host.
     *
     * <p>Unlike {@link #getDevicePointer}, byte-sized and boolean data types are not mapped to typed
     * pointers here; they receive the untyped {@link CudaPointer}.
     *
     * @param buffer buffer to resolve; must be a {@link BaseCudaDataBuffer}
     * @return a typed view of the host pointer, or {@code null} when the point has no host pointer
     */
    @Override
    public Pointer getHostPointer(DataBuffer buffer) {
        AllocationPoint dstPoint = ((BaseCudaDataBuffer)buffer).getAllocationPoint();
        if (dstPoint.getHostPointer() == null) {
            return null;
        }
        this.synchronizeThreadDevice(Thread.currentThread().getId(), dstPoint.getDeviceId(), dstPoint);
        CudaPointer p = new CudaPointer(dstPoint.getHostPointer(), buffer.length(), 0L);
        switch (buffer.dataType()) {
            case DOUBLE: {
                return p.asDoublePointer();
            }
            case FLOAT: {
                return p.asFloatPointer();
            }
            case UINT32:
            case INT: {
                return p.asIntPointer();
            }
            case SHORT:
            case UINT16:
            case HALF:
            case BFLOAT16: {
                return p.asShortPointer();
            }
            case UINT64:
            case LONG: {
                return p.asLongPointer();
            }
        }
        return p;
    }

    /**
     * Not supported. The allocation point is still looked up first, as in the original code.
     *
     * @throws UnsupportedOperationException always (unless the lookup itself fails first)
     */
    @Override
    public synchronized void relocateObject(DataBuffer buffer) {
        AtomicAllocator.getInstance().getAllocationPoint(buffer);
        throw new UnsupportedOperationException("Pew-pew");
    }

    /**
     * Not supported. The allocation point is still looked up first, as in the original code.
     *
     * @throws UnsupportedOperationException always (unless the lookup itself fails first)
     */
    @Override
    public boolean promoteObject(DataBuffer buffer) {
        AtomicAllocator.getInstance().getAllocationPoint(buffer);
        throw new UnsupportedOperationException("Pew-pew");
    }

    /**
     * Builds a fresh table with one {@code HOST} row (column 0, value {@link #zeroUseCounter}) and one
     * {@code DEVICE} row per configured device holding {@link #getAllocatedDeviceMemory(Integer)}.
     *
     * @return a new table; modifying it does not affect this handler
     */
    @Override
    public Table<AllocationStatus, Integer, Long> getAllocationStatistics() {
        Table<AllocationStatus, Integer, Long> table = HashBasedTable.create();
        table.put(AllocationStatus.HOST, 0, this.zeroUseCounter.get());
        for (Integer deviceId : configuration.getAvailableDevices()) {
            table.put(AllocationStatus.DEVICE, deviceId, this.getAllocatedDeviceMemory(deviceId));
        }
        return table;
    }

    /**
     * @return the size reported by the device allocation tracker for {@code device}; requires
     *         {@link #init} to have been called, since the tracker is created there
     */
    @Override
    public long getAllocatedDeviceMemory(Integer device) {
        return this.deviceMemoryTracker.getAllocatedSize(device);
    }

    /**
     * @return the current value of {@link #zeroUseCounter}
     */
    @Override
    public long getAllocatedHostMemory() {
        return this.zeroUseCounter.get();
    }

    /**
     * @return number of object ids tracked for the given device
     * @throws IndexOutOfBoundsException if {@code deviceId} is not a valid index into the device list
     */
    @Override
    public long getAllocatedDeviceObjects(Integer deviceId) {
        return this.deviceAllocations.get(deviceId).size();
    }

    /**
     * @return number of object ids tracked in the given host bucket, or 0 when the bucket does not exist
     */
    @Override
    public long getAllocatedHostObjects(Long bucketId) {
        // Single lookup instead of containsKey + get.
        ConcurrentHashMap<Long, Long> bucket = this.zeroAllocations.get(bucketId);
        return bucket == null ? 0L : bucket.size();
    }

    /**
     * @return total number of object ids tracked across all host buckets
     */
    @Override
    public long getAllocatedHostObjects() {
        long total = 0L;
        for (ConcurrentHashMap<Long, Long> bucket : this.zeroAllocations.values()) {
            total += bucket.size();
        }
        return total;
    }

    /**
     * @return the key-set view of the object ids tracked for the given device
     */
    @Override
    public Set<Long> getDeviceTrackingPoints(Integer deviceId) {
        return this.deviceAllocations.get(deviceId).keySet();
    }

    /**
     * @return the key-set view of the object ids tracked in the given host bucket, or a new empty set
     *         when the bucket does not exist
     */
    @Override
    public Set<Long> getHostTrackingPoints(Long bucketId) {
        ConcurrentHashMap<Long, Long> bucket = this.zeroAllocations.get(bucketId);
        if (bucket == null) {
            return new HashSet<Long>();
        }
        return bucket.keySet();
    }

    /**
     * Drops device-side tracking for a point whose status is {@code DEVICE} and flips its status to
     * {@code HOST}. Points with any other status are ignored.
     *
     * @throws IllegalStateException if {@code objectId} is not tracked for {@code deviceId} before the
     *         removal, or is still tracked afterwards
     */
    @Override
    public void purgeDeviceObject(Long threadId, Integer deviceId, Long objectId, AllocationPoint point, boolean copyback) {
        if (point.getAllocationStatus() != AllocationStatus.DEVICE) {
            return;
        }
        this.flowController.waitTillReleased(point);
        this.free(point, AllocationStatus.DEVICE);
        if (!this.deviceAllocations.get(deviceId).containsKey(objectId)) {
            throw new IllegalStateException("Can't happen ever");
        }
        this.forget(point, AllocationStatus.DEVICE);
        if (this.deviceAllocations.get(deviceId).containsKey(objectId)) {
            throw new IllegalStateException("Can't happen ever");
        }
        point.setAllocationStatus(AllocationStatus.HOST);
    }

    /**
     * Not supported by this handler.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void purgeZeroObject(Long bucketId, Long objectId, AllocationPoint point, boolean copyback) {
        throw new UnsupportedOperationException("Pew-pew");
    }

    /**
     * Removes the point's object id from the bookkeeping map for the given location.
     *
     * <p>For {@code HOST}, nothing happens when the point has no host pointer, no bucket id, or when
     * its bucket was never created; there is nothing to forget in those cases.
     *
     * @param point allocation point to stop tracking
     * @param location which side's bookkeeping to update; values other than {@code DEVICE} and
     *        {@code HOST} are ignored
     */
    @Override
    public void forget(AllocationPoint point, AllocationStatus location) {
        if (location == AllocationStatus.DEVICE) {
            this.deviceAllocations.get(point.getDeviceId()).remove(point.getObjectId());
        } else if (location == AllocationStatus.HOST && point.getHostPointer() != null) {
            Long bucketId = point.getBucketId();
            // The bucket map is created lazily, so it may legitimately be absent.
            ConcurrentHashMap<Long, Long> bucket = bucketId == null ? null : this.zeroAllocations.get(bucketId);
            if (bucket != null) {
                bucket.remove(point.getObjectId());
            }
        }
    }

    /**
     * @return the device id the affinity manager reports for the current thread
     */
    @Override
    public Integer getDeviceId() {
        int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        return deviceId;
    }

    /**
     * @return a {@link CudaPointer} constructed from the current device id's numeric value
     */
    @Override
    public Pointer getDeviceIdPointer() {
        return new CudaPointer(this.getDeviceId().intValue());
    }

    /**
     * @return a new set holding the device ids listed in the configuration
     */
    @Override
    public Set<Integer> getAvailableDevices() {
        return new HashSet<Integer>(configuration.getAvailableDevices());
    }

    /**
     * @return the same context as {@link #getCudaContext()}
     */
    @Override
    public CudaContext getDeviceContext() {
        return this.getCudaContext();
    }

    /**
     * Returns the cuBLAS handle cached for the current thread's device, creating it from the given
     * launch context on first use. Creation is serialized by the write lock.
     *
     * <p>The cached entry is replaced in place with {@code List.set}. The previous
     * {@code remove(Integer)} + {@code add(index, ...)} pair resolved {@code remove} to
     * {@code remove(Object)}, which removed nothing, so each new handle was inserted and shifted
     * the slots of higher device ids.
     *
     * @param lc launch context supplying the native handle
     * @return the handle for the current thread's device; never {@code null}
     */
    protected cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
        int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        // Acquire before the try block so that unlock() only runs when the lock is actually held.
        this.lock.writeLock().lock();
        try {
            cublasHandle_t handle = this.cublasHandles.get(deviceId);
            if (handle == null) {
                handle = new cublasHandle_t(this.nativeOps.lcBlasHandle(lc));
                this.cublasHandles.set(deviceId, handle);
            }
            return handle;
        }
        finally {
            this.lock.writeLock().unlock();
        }
    }

    /**
     * Returns the calling thread's cached {@link CudaContext}, building and caching one from the
     * native default launch context when the thread has none yet.
     *
     * @return the context for the current thread; never {@code null}
     */
    @Override
    public CudaContext getCudaContext() {
        CudaContext ctx = this.tlContext.get();
        if (ctx == null) {
            OpaqueLaunchContext lc = this.nativeOps.defaultLaunchContext();
            ctx = CudaContext.builder()
                    .bufferScalar(this.nativeOps.lcScalarPointer(lc))
                    .bufferReduction(this.nativeOps.lcReductionPointer(lc))
                    .bufferAllocation(this.nativeOps.lcAllocationPointer(lc))
                    // NOTE(review): bufferSpecial is fed the scalar pointer, same as bufferScalar; confirm this is intended.
                    .bufferSpecial(this.nativeOps.lcScalarPointer(lc))
                    .oldStream(new cudaStream_t(this.nativeOps.lcExecutionStream(lc)))
                    .specialStream(new cudaStream_t(this.nativeOps.lcCopyStream(lc)))
                    .cublasHandle(this.getCudaCublasHandle(lc))
                    .solverHandle(new cusolverDnHandle_t(this.nativeOps.lcSolverHandle(lc)))
                    .build();
            this.tlContext.set(ctx);
        }
        return ctx;
    }

    /**
     * Drops the calling thread's cached context; the next {@link #getCudaContext()} call rebuilds it.
     */
    @Override
    public void resetCachedContext() {
        this.tlContext.remove();
    }

    /**
     * @return {@code true}
     */
    @Override
    public boolean isDeviceDependant() {
        return true;
    }

    /**
     * Delegates to the flow controller's {@code synchronizeToHost} for the given point; the thread
     * and device ids are not used.
     */
    @Override
    public void synchronizeThreadDevice(Long threadId, Integer deviceId, AllocationPoint point) {
        this.flowController.synchronizeToHost(point);
    }

    /**
     * Delegates to the flow controller's {@code registerAction}.
     */
    @Override
    public void registerAction(CudaContext context, INDArray result, INDArray ... operands) {
        this.flowController.registerAction(context, result, operands);
    }

    /**
     * @return the flow controller selected in the constructor
     */
    @Override
    public FlowController getFlowController() {
        return this.flowController;
    }

    /**
     * @return always {@code null}; this handler exposes no memory provider
     */
    @Override
    public MemoryProvider getMemoryProvider() {
        return null;
    }
}

