/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.workspace;

import java.util.List;
import java.util.Queue;
import lombok.NonNull;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.allocator.impl.MemoryTracker;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.workspace.CudaWorkspaceDeallocator;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.AllocationsTracker;
import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationKind;
import org.nd4j.linalg.api.memory.enums.DebugMode;
import org.nd4j.linalg.api.memory.enums.LocationPolicy;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.memory.pointers.PointersPair;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * CUDA implementation of {@link Nd4jWorkspace}.
 *
 * <p>When {@link #init()} runs with a positive {@code currentSize}, the workspace allocates a host
 * buffer of {@code currentSize + 1024} bytes. Unless the mirroring policy is
 * {@link MirroringPolicy#HOST_ONLY}, it also allocates a device buffer of the same size.
 * Requests are rounded up to a multiple of 32 bytes and carved out of those buffers by advancing
 * {@code deviceOffset}/{@code hostOffset}.
 *
 * <p>Requests that cannot be served from the buffers are "spilled":
 * <ul>
 *   <li>normally into separately allocated external buffers, which
 *       {@link #clearExternalAllocations()} releases;</li>
 *   <li>in trimmed mode into pinned buffers, which {@link #clearPinnedAllocations(boolean)} frees
 *       once they are more than two steps old.</li>
 * </ul>
 */
public class CudaWorkspace extends Nd4jWorkspace {
    private static final Logger log = LoggerFactory.getLogger(CudaWorkspace.class);

    /** Extra bytes added to each primary buffer allocation (and accounted for on release). */
    private static final long PADDING_BYTES = 1024L;

    /** Byte alignment applied to request sizes and to the start of the device buffer. */
    private static final long ALIGNMENT_BYTES = 32L;

    /**
     * Creates a workspace with the given configuration.
     *
     * @param configuration workspace configuration; must not be null
     * @throws NullPointerException if {@code configuration} is null
     */
    public CudaWorkspace(@NonNull WorkspaceConfiguration configuration) {
        super(configuration);
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
    }

    /**
     * Creates a workspace with the given configuration and identifier.
     *
     * @param configuration workspace configuration; must not be null
     * @param workspaceId workspace identifier; must not be null
     * @throws NullPointerException if either argument is null
     */
    public CudaWorkspace(@NonNull WorkspaceConfiguration configuration, @NonNull String workspaceId) {
        super(configuration, workspaceId);
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        if (workspaceId == null) {
            throw new NullPointerException("workspaceId is marked non-null but is null");
        }
    }

    /**
     * Creates a workspace with the given configuration and identifier, bound to a target device.
     *
     * @param configuration workspace configuration; must not be null
     * @param workspaceId workspace identifier; must not be null
     * @param deviceId device reported by {@link #targetDevice()}
     * @throws NullPointerException if {@code configuration} or {@code workspaceId} is null
     */
    public CudaWorkspace(@NonNull WorkspaceConfiguration configuration, @NonNull String workspaceId, Integer deviceId) {
        super(configuration, workspaceId);
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        if (workspaceId == null) {
            throw new NullPointerException("workspaceId is marked non-null but is null");
        }
        this.deviceId = deviceId;
    }

    /**
     * Allocates the backing buffers after the base class has initialized {@code currentSize}.
     *
     * <p>If {@code currentSize} is positive, a host buffer of {@code currentSize + 1024} bytes is
     * allocated. Unless mirroring is HOST_ONLY, a device buffer of the same size is allocated too,
     * registered with the allocation trackers, and both offsets are advanced so the first device
     * allocation starts on a 32-byte boundary.
     *
     * @throws ND4JIllegalStateException if the location policy is MMAP (unsupported on CUDA), or if
     *     the host or device buffer cannot be allocated
     */
    protected void init() {
        if (this.workspaceConfiguration.getPolicyLocation() == LocationPolicy.MMAP) {
            throw new ND4JIllegalStateException("CUDA do not support MMAP workspaces yet");
        }
        super.init();
        if (this.currentSize.get() <= 0L) {
            return;
        }
        this.isInit.set(true);
        final long bytes = this.currentSize.get();
        final long paddedBytes = bytes + PADDING_BYTES;
        if (this.isDebug.get()) {
            log.info("Allocating [{}] workspace on device_{}, {} bytes...", new Object[]{this.id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), bytes});
            Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread();
        }

        final Pointer hostPtr = this.memoryManager.allocate(paddedBytes, MemoryKind.HOST, false);
        if (hostPtr == null) {
            throw new ND4JIllegalStateException("Can't allocate memory for workspace");
        }
        this.workspace.setHostPointer(new PagedPointer(hostPtr));

        if (this.workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY) {
            return;
        }

        final Pointer devicePtr = this.memoryManager.allocate(paddedBytes, MemoryKind.DEVICE, false);
        if (devicePtr == null) {
            // Previously a null device allocation was silently wrapped in a PagedPointer.
            throw new ND4JIllegalStateException("Can't allocate device memory for workspace");
        }
        this.workspace.setDevicePointer(new PagedPointer(devicePtr));

        final Integer device = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        AllocationsTracker.getInstance().markAllocated(AllocationKind.GENERAL, device, paddedBytes);
        MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(device.intValue(), paddedBytes);

        final long misalignment = this.workspace.getDevicePointer().address() % ALIGNMENT_BYTES;
        if (misalignment != 0L) {
            // NOTE(review): both offsets are derived from the device address only; the host
            // buffer's own alignment is not checked — confirm this is intended.
            this.deviceOffset.set(ALIGNMENT_BYTES - misalignment);
            this.hostOffset.set(ALIGNMENT_BYTES - misalignment);
        }
    }

    /**
     * Allocates device memory from this workspace; shorthand for
     * {@code alloc(requiredMemory, MemoryKind.DEVICE, type, initialize)}.
     */
    public PagedPointer alloc(long requiredMemory, DataType type, boolean initialize) {
        return this.alloc(requiredMemory, MemoryKind.DEVICE, type, initialize);
    }

    /**
     * Releases the workspace buffers and resets its size to zero.
     *
     * <p>The method:
     * <ol>
     *   <li>calls {@code reset()};</li>
     *   <li>clears external allocations when {@code extended} is true;</li>
     *   <li>clears pinned allocations (all of them when {@code extended} is true);</li>
     *   <li>frees the host and device buffers and un-registers the device buffer
     *       ({@code size + 1024} bytes) from the allocation trackers.</li>
     * </ol>
     *
     * @param extended whether to also drop external allocations and all pinned allocations
     */
    public synchronized void destroyWorkspace(boolean extended) {
        final long size = this.currentSize.getAndSet(0L);
        this.reset();
        if (extended) {
            this.clearExternalAllocations();
        }
        this.clearPinnedAllocations(extended);
        if (this.workspace.getHostPointer() != null) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost((Pointer) this.workspace.getHostPointer());
        }
        if (this.workspace.getDevicePointer() != null) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice((Pointer) this.workspace.getDevicePointer(), 0);
            final Integer device = Nd4j.getAffinityManager().getDeviceForCurrentThread();
            AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, device, size + PADDING_BYTES);
            MemoryTracker.getInstance().decrementWorkspaceAmount(device.intValue(), size + PADDING_BYTES);
        }
        this.workspace.setDevicePointer(null);
        this.workspace.setHostPointer(null);
    }

    /**
     * Allocates {@code requiredMemory} bytes of the given memory kind.
     *
     * <p>The size is first rounded up to a multiple of 32 bytes. The element capacity of the
     * returned pointer is {@code requiredMemory / sizeOf(type)}, computed from the un-rounded size.
     * If the workspace is not in use, memory is allocated outside the workspace and tracked as an
     * external allocation.
     *
     * @param requiredMemory requested size in bytes
     * @param kind {@link MemoryKind#DEVICE} or {@link MemoryKind#HOST}
     * @param type data type used to compute the element capacity
     * @param initialize if true, in-buffer memory is zero-filled; for separately allocated
     *     buffers the flag is forwarded to the memory manager
     * @return the allocated pointer, or {@code null} for DEVICE requests when mirroring is HOST_ONLY
     * @throws ND4JIllegalStateException if the request does not fit and the spill policy is neither
     *     REALLOCATE nor EXTERNAL, if a device memset fails, or if {@code kind} is unsupported
     */
    public PagedPointer alloc(long requiredMemory, MemoryKind kind, DataType type, boolean initialize) {
        final long numElements = requiredMemory / (long) Nd4j.sizeOfDataType(type);
        if (requiredMemory % ALIGNMENT_BYTES != 0L) {
            requiredMemory += ALIGNMENT_BYTES - requiredMemory % ALIGNMENT_BYTES;
        }
        if (!this.isUsed.get()) {
            return this.allocOutsideWorkspace(requiredMemory, numElements, kind, initialize);
        }

        // Trimmed mode: in ENDOFBUFFER_REACHED mode, a DEVICE request that pushes this cycle past
        // the initial block size (or any request once trimmed mode is on) bypasses the buffer.
        final boolean trimmer = (this.workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED
                && requiredMemory + this.cycleAllocations.get() > this.initialBlockSize.get()
                && this.initialBlockSize.get() > 0L
                && kind == MemoryKind.DEVICE)
                || this.trimmedMode.get();
        if (trimmer && this.workspaceConfiguration.getPolicySpill() == SpillPolicy.REALLOCATE && !this.trimmedMode.get()) {
            this.trimmedMode.set(true);
            this.trimmedStep.set(this.stepsCount.get());
        }

        if (kind == MemoryKind.DEVICE) {
            return this.allocDevice(requiredMemory, numElements, type, initialize, trimmer);
        }
        if (kind == MemoryKind.HOST) {
            return this.allocHost(requiredMemory, numElements, type, initialize, trimmer);
        }
        throw new ND4JIllegalStateException("Unknown MemoryKind was passed in: " + kind);
    }

    /** Returns whether the workspace manager's debug mode forces every allocation to spill. */
    private static boolean isSpillEverythingMode() {
        return Nd4j.getWorkspaceManager().getDebugMode() == DebugMode.SPILL_EVERYTHING;
    }

    /** Serves a request while the workspace is disabled: allocates memory and records it as external. */
    private PagedPointer allocOutsideWorkspace(long requiredMemory, long numElements, MemoryKind kind, boolean initialize) {
        if (this.disabledCounter.incrementAndGet() % 10L == 0L) {
            log.warn("Workspace was turned off, and wasn't enabled after {} allocations", (Object) this.disabledCounter.get());
        }
        if (kind == MemoryKind.DEVICE) {
            final PagedPointer pointer = new PagedPointer(this.memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), numElements);
            final PointersPair pair = new PointersPair(null, pointer);
            // Record the size so clearExternalAllocations() undoes the MemoryTracker increment below
            // (previously omitted here, so the tracked workspace amount only ever grew).
            pair.setRequiredMemory(Long.valueOf(requiredMemory));
            this.externalAllocations.add(pair);
            MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue(), requiredMemory);
            return pointer;
        }
        final PagedPointer pointer = new PagedPointer(this.memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements);
        this.externalAllocations.add(new PointersPair(pointer, null));
        return pointer;
    }

    /** Serves a DEVICE request from the device buffer, retrying after a reset, or spills it. */
    private PagedPointer allocDevice(long requiredMemory, long numElements, DataType type, boolean initialize, boolean trimmer) {
        if (this.deviceOffset.get() + requiredMemory <= this.currentSize.get() && !trimmer && !isSpillEverythingMode()) {
            this.cycleAllocations.addAndGet(requiredMemory);
            final long prevOffset = this.deviceOffset.getAndAdd(requiredMemory);
            if (this.workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY) {
                return null;
            }
            final PagedPointer ptr = this.workspace.getDevicePointer().withOffset(prevOffset, numElements);
            if (this.isDebug.get()) {
                log.info("Workspace [{}] device_{}: alloc array of {} bytes, capacity of {} elements; prevOffset: {}; newOffset: {}; size: {}; address: {}", new Object[]{this.id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory, numElements, prevOffset, this.deviceOffset.get(), this.currentSize.get(), ptr.address()});
            }
            if (initialize) {
                final CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
                final int ret = NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync((Pointer) ptr, 0, requiredMemory, 0, (Pointer) context.getSpecialStream());
                if (ret == 0) {
                    throw new ND4JIllegalStateException("memset failed device_" + Nd4j.getAffinityManager().getDeviceForCurrentThread());
                }
                context.syncSpecialStream();
            }
            return ptr;
        }
        // Wrap around to the start of the buffer, but only if the request can fit there. Without
        // the size guard, an oversized request recursed forever (StackOverflowError).
        if (this.workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED
                && this.currentSize.get() > 0L
                && requiredMemory <= this.currentSize.get()
                && !trimmer
                && !isSpillEverythingMode()) {
            this.deviceOffset.set(0L);
            this.resetPlanned.set(true);
            return this.alloc(requiredMemory, MemoryKind.DEVICE, type, initialize);
        }
        return this.spillDevice(requiredMemory, numElements, initialize, trimmer);
    }

    /** Allocates a DEVICE request outside the buffer as an external (or, in trimmed mode, pinned) allocation. */
    private PagedPointer spillDevice(long requiredMemory, long numElements, boolean initialize, boolean trimmer) {
        if (trimmer) {
            this.pinnedAllocationsSize.addAndGet(requiredMemory);
        } else {
            this.spilledAllocationsSize.addAndGet(requiredMemory);
        }
        if (this.isDebug.get()) {
            log.info("Workspace [{}] device_{}: spilled DEVICE array of {} bytes, capacity of {} elements", new Object[]{this.id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory, numElements});
        }
        this.cycleAllocations.addAndGet(requiredMemory);
        if (this.workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY) {
            return null;
        }
        switch (this.workspaceConfiguration.getPolicySpill()) {
            case REALLOCATE:
            case EXTERNAL: {
                if (!trimmer) {
                    this.externalCount.incrementAndGet();
                    final PagedPointer pointer = new PagedPointer(this.memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), numElements);
                    // NOTE(review): return value discarded (kept from the original) — confirm intent.
                    pointer.isLeaked();
                    final PointersPair pair = new PointersPair(null, pointer);
                    pair.setRequiredMemory(Long.valueOf(requiredMemory));
                    this.externalAllocations.add(pair);
                    MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue(), requiredMemory);
                    return pointer;
                }
                this.pinnedCount.incrementAndGet();
                final PagedPointer pointer = new PagedPointer(this.memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), numElements);
                // NOTE(review): return value discarded (kept from the original) — confirm intent.
                pointer.isLeaked();
                this.pinnedAllocations.add(new PointersPair(Long.valueOf(this.stepsCount.get()), Long.valueOf(requiredMemory), null, pointer));
                MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue(), requiredMemory);
                return pointer;
            }
            default:
                break;
        }
        throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full");
    }

    /** Serves a HOST request from the host buffer, retrying after a reset, or spills it. */
    private PagedPointer allocHost(long requiredMemory, long numElements, DataType type, boolean initialize, boolean trimmer) {
        if (this.hostOffset.get() + requiredMemory <= this.currentSize.get() && !trimmer && !isSpillEverythingMode()) {
            final long prevOffset = this.hostOffset.getAndAdd(requiredMemory);
            final PagedPointer ptr = this.workspace.getHostPointer().withOffset(prevOffset, numElements);
            if (initialize) {
                Pointer.memset((Pointer) ptr, 0, requiredMemory);
            }
            return ptr;
        }
        // Wrap around to the start of the buffer, but only if the request can fit there. Without
        // the size guard, an oversized request recursed forever (StackOverflowError).
        if (this.workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED
                && this.currentSize.get() > 0L
                && requiredMemory <= this.currentSize.get()
                && !trimmer
                && !isSpillEverythingMode()) {
            this.hostOffset.set(0L);
            return this.alloc(requiredMemory, MemoryKind.HOST, type, initialize);
        }
        switch (this.workspaceConfiguration.getPolicySpill()) {
            case REALLOCATE:
            case EXTERNAL: {
                if (!trimmer) {
                    final PagedPointer pointer = new PagedPointer(this.memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements);
                    this.externalAllocations.add(new PointersPair(pointer, null));
                    return pointer;
                }
                final PagedPointer pointer = new PagedPointer(this.memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements);
                // NOTE(review): return value discarded (kept from the original) — confirm intent.
                pointer.isLeaked();
                // Host pins record 0 bytes, matching the fact that they never add to pinnedAllocationsSize.
                this.pinnedAllocations.add(new PointersPair(Long.valueOf(this.stepsCount.get()), Long.valueOf(0L), pointer, null));
                return pointer;
            }
            default:
                break;
        }
        throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full");
    }

    /**
     * Frees pinned allocations from the head of the pinned queue.
     *
     * <p>Unless {@code extended} is true, this stops at the first entry whose allocation step is
     * within two steps of the current step. Each freed device entry decrements
     * {@code MemoryTracker} and {@code pinnedCount}. Every freed entry subtracts its recorded size
     * from {@code pinnedAllocationsSize}.
     *
     * @param extended if true, free every pinned allocation regardless of age
     * @throws IllegalStateException if the queue yields a null entry while reporting non-empty
     */
    protected void clearPinnedAllocations(boolean extended) {
        if (this.isDebug.get()) {
            log.info("Workspace [{}] device_{} threadId {} cycle {}: clearing pinned allocations...", new Object[]{this.id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Thread.currentThread().getId(), this.cyclesCount.get()});
        }
        while (!this.pinnedAllocations.isEmpty()) {
            final PointersPair pair = (PointersPair) this.pinnedAllocations.peek();
            if (pair == null) {
                throw new IllegalStateException("Pinned allocations queue returned null while non-empty");
            }
            final long stepNumber = pair.getAllocationCycle();
            final long stepCurrent = this.stepsCount.get();
            if (this.isDebug.get()) {
                log.info("Allocation step: {}; Current step: {}", (Object) stepNumber, (Object) stepCurrent);
            }
            if (stepNumber + 2L >= stepCurrent && !extended) {
                break;
            }
            this.pinnedAllocations.remove();
            if (pair.getDevicePointer() != null) {
                NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice((Pointer) pair.getDevicePointer(), 0);
                MemoryTracker.getInstance().decrementWorkspaceAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue(), pair.getRequiredMemory().longValue());
                this.pinnedCount.decrementAndGet();
                if (this.isDebug.get()) {
                    log.info("deleting external device allocation ");
                }
            }
            if (pair.getHostPointer() != null) {
                NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost((Pointer) pair.getHostPointer());
                if (this.isDebug.get()) {
                    log.info("deleting external host allocation ");
                }
            }
            this.pinnedAllocationsSize.addAndGet(-pair.getRequiredMemory());
        }
    }

    /**
     * Frees every external (spilled or disabled-workspace) allocation and resets the spill counters.
     *
     * <p>Pending operations are committed first. For device entries with a recorded size, the
     * allocation trackers are decremented.
     *
     * @throws RuntimeException wrapping any exception raised while freeing; in that case the
     *     external list and counters are left untouched
     */
    protected void clearExternalAllocations() {
        if (this.isDebug.get()) {
            log.info("Workspace [{}] device_{} threadId {} guid [{}]: clearing external allocations...", new Object[]{this.id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Thread.currentThread().getId(), this.guid});
        }
        Nd4j.getExecutioner().commit();
        try {
            for (PointersPair pair : this.externalAllocations) {
                if (pair.getHostPointer() != null) {
                    NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost((Pointer) pair.getHostPointer());
                    if (this.isDebug.get()) {
                        log.info("deleting external host allocation... ");
                    }
                }
                if (pair.getDevicePointer() != null) {
                    NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice((Pointer) pair.getDevicePointer(), 0);
                    if (this.isDebug.get()) {
                        log.info("deleting external device allocation... ");
                    }
                    final Long recordedSize = pair.getRequiredMemory();
                    if (recordedSize != null) {
                        AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, Nd4j.getAffinityManager().getDeviceForCurrentThread(), recordedSize.longValue());
                        MemoryTracker.getInstance().decrementWorkspaceAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue(), recordedSize.longValue());
                    }
                }
            }
        } catch (Exception e) {
            log.error("RC: Workspace [{}] device_{} threadId {} guid [{}]: clearing external allocations...", new Object[]{this.id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Thread.currentThread().getId(), this.guid});
            throw new RuntimeException(e);
        }
        this.spilledAllocationsSize.set(0L);
        this.externalCount.set(0);
        this.externalAllocations.clear();
    }

    /** Intentionally a no-op for the CUDA workspace: buffers are not cleared on reset. */
    protected void resetWorkspace() {
        // Nothing to do; the original implementation only returned early.
    }

    /** Returns the pair holding the primary host/device buffers. */
    protected PointersPair workspace() {
        return this.workspace;
    }

    /** Returns the live queue of pinned allocations (not a copy). */
    protected Queue<PointersPair> pinnedPointers() {
        return this.pinnedAllocations;
    }

    /** Returns the live list of external allocations (not a copy). */
    protected List<PointersPair> externalPointers() {
        return this.externalAllocations;
    }

    /** Returns a new {@link CudaWorkspaceDeallocator} bound to this workspace. */
    public Deallocator deallocator() {
        return new CudaWorkspaceDeallocator(this);
    }

    /** Returns {@code "Workspace_<id>_<n>"}, where {@code n} comes from the deallocator service's counter. */
    public String getUniqueId() {
        return "Workspace_" + this.getId() + "_" + Nd4j.getDeallocatorService().nextValue();
    }

    /** Returns the device this workspace is bound to. */
    public int targetDevice() {
        return this.deviceId;
    }

    /** Returns the current device-buffer offset, which is the primary offset for CUDA workspaces. */
    public long getPrimaryOffset() {
        return this.getDeviceOffset();
    }
}

