/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.concurrency;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.concurrency.BasicAffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * CUDA-backed affinity manager: maps threads to devices, replicates arrays and
 * buffers onto a chosen device, and tags/synchronizes the host/device location
 * of array data through {@link AtomicAllocator}.
 *
 * <p>Device selection in {@link #getNextDevice(long)} and the lazy device count
 * in {@link #getNumberOfDevices()} are guarded by this instance's monitor.
 */
public class CudaAffinityManager
extends BasicAffinityManager {
    private static final Logger logger = LoggerFactory.getLogger(CudaAffinityManager.class);

    /** Thread id -> device id, filled lazily by {@link #getDeviceForThread(long)}. */
    private final Map<Long, Integer> affinityMap = new ConcurrentHashMap<Long, Integer>();

    /** Index of the next entry to hand out from the configured list of available devices. */
    private final AtomicInteger devPtr = new AtomicInteger(0);

    // NOTE(review): not referenced anywhere in this class; kept as-is in case it is reached by other means.
    private final ThreadLocal<AtomicBoolean> affiliated = new ThreadLocal<AtomicBoolean>();

    /** Cached device count; a negative value means "not queried yet". */
    private final AtomicInteger numberOfDevices = new AtomicInteger(-1);

    /**
     * Returns the device id currently reported by the native ops layer.
     *
     * @return the active device id
     */
    public Integer getDeviceForCurrentThread() {
        return NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
    }

    /**
     * Returns the device id recorded for the given thread id.
     *
     * <p>If no mapping exists and {@code threadId} belongs to the calling thread, the
     * device reported by the native ops layer is recorded and returned.
     *
     * @param threadId id of the thread to look up
     * @return the device id mapped to that thread
     * @throws RuntimeException if no mapping exists and {@code threadId} is not the calling thread
     */
    public Integer getDeviceForThread(long threadId) {
        Integer id = this.affinityMap.get(threadId);
        if (id != null) {
            return id;
        }
        if (threadId != Thread.currentThread().getId()) {
            throw new RuntimeException("Affinity for thread [" + threadId + "] wasn't defined yet");
        }
        // Only the owning thread can ask the native layer which device it is on.
        id = NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
        this.affinityMap.put(threadId, id);
        return id;
    }

    /**
     * Picks the next device for a thread, cycling through the configured list of
     * available devices. When single-GPU mode is forced, or no devices are reported,
     * the first configured device is returned instead.
     *
     * @param threadId id of the thread being mapped (used for logging only)
     * @return the selected device id
     */
    protected Integer getNextDevice(long threadId) {
        Integer device;
        if (!CudaEnvironment.getInstance().getConfiguration().isForcedSingleGPU() && this.getNumberOfDevices() > 0) {
            // Advancing and wrapping the pointer must happen as one step.
            synchronized (this) {
                device = CudaEnvironment.getInstance().getConfiguration().getAvailableDevices().get(this.devPtr.getAndIncrement());
                if (this.devPtr.get() >= CudaEnvironment.getInstance().getConfiguration().getAvailableDevices().size()) {
                    this.devPtr.set(0);
                }
                Thread current = Thread.currentThread();
                // The thread name is only known when the caller is mapping itself.
                String threadName = current.getId() == threadId ? current.getName() : "N/A";
                logger.debug("Mapping thread [{} - {}] to device [{}], out of [{}] devices...", threadId, threadName, device, CudaEnvironment.getInstance().getConfiguration().getAvailableDevices().size());
            }
        } else {
            device = CudaEnvironment.getInstance().getConfiguration().getAvailableDevices().get(0);
            logger.debug("Single device is forced, mapping to device [{}]", device);
        }
        return device;
    }

    /**
     * Returns the number of devices reported by the native ops layer. The value is
     * queried on first use and cached afterwards.
     *
     * @return the number of available devices
     */
    public int getNumberOfDevices() {
        if (this.numberOfDevices.get() < 0) {
            synchronized (this) {
                // Re-check under the lock so the native query is not repeated by racing threads.
                if (this.numberOfDevices.get() < 0) {
                    this.numberOfDevices.set(NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices());
                }
            }
        }
        return this.numberOfDevices.get();
    }

    /**
     * Relocates both the data buffer and the shape-info buffer of the given array.
     * A {@code null} array is ignored.
     *
     * @param array array to touch, may be {@code null}
     */
    public void touch(INDArray array) {
        if (array == null) {
            return;
        }
        this.touch(array.data());
        this.touch(array.shapeInfoDataBuffer());
    }

    /**
     * Relocates the given buffer: constant buffers go through the constant handler,
     * all others through the allocator's memory handler. A {@code null} buffer is ignored.
     *
     * @param buffer buffer to touch, may be {@code null}
     */
    public void touch(DataBuffer buffer) {
        if (buffer == null) {
            return;
        }
        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(buffer);
        if (point.isConstant()) {
            Nd4j.getConstantHandler().relocateConstantSpace(buffer);
        } else {
            AtomicAllocator.getInstance().getMemoryHandler().relocateObject(buffer);
        }
    }

    /**
     * Creates a copy of the given array on the specified device.
     *
     * <p>The active device is switched for the duration of the copy when it differs
     * from {@code deviceId}, and is restored afterwards even if the copy fails.
     *
     * @param deviceId target device id
     * @param array    array to replicate; {@code null} yields {@code null}
     * @return the replicated array; for arrays where {@code isS()} is true, a plain
     *         {@code dup} in the same ordering
     * @throws UnsupportedOperationException if {@code array} is a view
     */
    public synchronized INDArray replicateToDevice(Integer deviceId, INDArray array) {
        if (array == null) {
            return null;
        }
        if (array.isS()) {
            return array.dup(array.ordering());
        }
        if (array.isView()) {
            throw new UnsupportedOperationException("It's impossible to replicate View");
        }
        long[] shape = array.shape();
        long[] stride = array.stride();
        int elementWiseStride = array.elementWiseStride();
        char ordering = array.ordering();
        DataType dtype = array.dataType();
        boolean empty = array.isEmpty();
        // Result is unused; the call is kept for its effect on the array's allocation state.
        AtomicAllocator.getInstance().getPointer(array, AtomicAllocator.getInstance().getDeviceContext());
        int currentDeviceId = this.getDeviceForCurrentThread();
        boolean switchDevice = currentDeviceId != deviceId;
        if (switchDevice) {
            this.unsafeSetDevice(deviceId);
        }
        try {
            DataBuffer newDataBuffer = this.replicateToDevice(deviceId, array.data());
            DataBuffer newShapeBuffer = (DataBuffer)Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, (long)elementWiseStride, ordering, dtype, empty).getFirst();
            return Nd4j.createArrayFromShapeBuffer(newDataBuffer, newShapeBuffer);
        } finally {
            // Always put the thread back on its original device, even when replication throws.
            if (switchDevice) {
                this.unsafeSetDevice(currentDeviceId);
            }
        }
    }

    /**
     * Creates a copy of the given buffer on the specified device.
     *
     * <p>The active device is switched for the duration of the copy when it differs
     * from {@code deviceId}, and is restored afterwards even if the copy fails.
     *
     * @param deviceId target device id
     * @param buffer   buffer to replicate; {@code null} yields {@code null}
     * @return a new buffer of the same data type and length holding the copied data
     */
    public DataBuffer replicateToDevice(Integer deviceId, DataBuffer buffer) {
        if (buffer == null) {
            return null;
        }
        int currentDeviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        boolean switchDevice = currentDeviceId != deviceId;
        if (switchDevice) {
            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
        }
        try {
            DataBuffer dstBuffer = Nd4j.createBuffer(buffer.dataType(), buffer.length(), false);
            AtomicAllocator.getInstance().memcpy(dstBuffer, buffer);
            return dstBuffer;
        } finally {
            // Always put the thread back on its original device, even when the copy throws.
            if (switchDevice) {
                Nd4j.getAffinityManager().unsafeSetDevice(Integer.valueOf(currentDeviceId));
            }
        }
    }

    /**
     * Marks where the given array's data was last written. Empty arrays are ignored.
     *
     * @param array    array to tag; must not be {@code null}
     * @param location {@code HOST} ticks a host write, {@code DEVICE} ticks a device write,
     *                 {@code EVERYWHERE} ticks a device write followed by a host read;
     *                 any other value is a no-op
     */
    public void tagLocation(INDArray array, AffinityManager.Location location) {
        if (array.isEmpty()) {
            return;
        }
        if (location == AffinityManager.Location.HOST) {
            AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite();
        } else if (location == AffinityManager.Location.DEVICE) {
            AtomicAllocator.getInstance().getAllocationPoint(array).tickDeviceWrite();
        } else if (location == AffinityManager.Location.EVERYWHERE) {
            AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(array);
            point.tickDeviceWrite();
            point.tickHostRead();
        }
    }

    /**
     * Marks where the given buffer's data was last written.
     *
     * @param buffer   buffer to tag
     * @param location {@code HOST} ticks a host write, {@code DEVICE} ticks a device write,
     *                 {@code EVERYWHERE} ticks a device write followed by a host read;
     *                 any other value is a no-op
     */
    public void tagLocation(DataBuffer buffer, AffinityManager.Location location) {
        if (location == AffinityManager.Location.HOST) {
            AtomicAllocator.getInstance().getAllocationPoint(buffer).tickHostWrite();
        } else if (location == AffinityManager.Location.DEVICE) {
            AtomicAllocator.getInstance().getAllocationPoint(buffer).tickDeviceWrite();
        } else if (location == AffinityManager.Location.EVERYWHERE) {
            AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(buffer);
            point.tickDeviceWrite();
            point.tickHostRead();
        }
    }

    /**
     * Returns the device id the allocator reports for the given array.
     *
     * @param array array to look up; must not be {@code null}
     * @return the device id for the array
     * @throws NullPointerException if {@code array} is {@code null}
     */
    public Integer getDeviceForArray(@NonNull INDArray array) {
        if (array == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        return AtomicAllocator.getInstance().getDeviceId(array);
    }

    /**
     * Switches the active device through the native ops layer and resets the memory
     * handler's cached context. No affinity bookkeeping in this class is updated.
     *
     * @param deviceId device id to activate; must not be {@code null}
     */
    public void unsafeSetDevice(Integer deviceId) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId.intValue());
        AtomicAllocator.getInstance().getMemoryHandler().resetCachedContext();
    }

    /**
     * Synchronizes the array's data to the requested location. {@code null}, empty
     * and {@code isS()} arrays are ignored.
     *
     * @param array    array to synchronize, may be {@code null}
     * @param location {@code HOST} synchronizes host data, {@code DEVICE} synchronizes to
     *                 the device, any other value does both (host first)
     */
    public void ensureLocation(INDArray array, AffinityManager.Location location) {
        if (array == null || array.isEmpty() || array.isS()) {
            return;
        }
        // The host pointer is allocated lazily; make sure it exists before synchronizing.
        ((BaseCudaDataBuffer)array.data()).lazyAllocateHostPointer();
        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(array);
        switch (location) {
            case HOST: {
                AtomicAllocator.getInstance().synchronizeHostData(array);
                break;
            }
            case DEVICE: {
                AtomicAllocator.getInstance().getFlowController().synchronizeToDevice(point);
                break;
            }
            default: {
                // Covers EVERYWHERE (and any other value): bring both sides up to date.
                AtomicAllocator.getInstance().synchronizeHostData(array);
                AtomicAllocator.getInstance().getFlowController().synchronizeToDevice(point);
            }
        }
    }

    /**
     * Reports where the array's data is currently up to date.
     *
     * @param array array to inspect; must not be {@code null}
     * @return {@code EVERYWHERE} for empty arrays or when both sides are actual,
     *         {@code DEVICE} when only the device side is actual, otherwise {@code HOST}
     */
    public AffinityManager.Location getActiveLocation(INDArray array) {
        if (array.isEmpty()) {
            return AffinityManager.Location.EVERYWHERE;
        }
        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(array);
        if (point.isActualOnDeviceSide() && point.isActualOnHostSide()) {
            return AffinityManager.Location.EVERYWHERE;
        }
        if (point.isActualOnDeviceSide()) {
            return AffinityManager.Location.DEVICE;
        }
        return AffinityManager.Location.HOST;
    }

    /**
     * Tells whether cross-device access can be used.
     *
     * @return {@code true} only when the native layer reports P2P as available and the
     *         configuration allows cross-device access
     */
    public boolean isCrossDeviceAccessSupported() {
        return NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable() && CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed();
    }

    /**
     * Forwards the cross-device access setting to the CUDA configuration.
     *
     * @param reallyAllow {@code true} to allow cross-device access, {@code false} to forbid it
     */
    public void allowCrossDeviceAccess(boolean reallyAllow) {
        CudaEnvironment.getInstance().getConfiguration().allowCrossDeviceAccess(reallyAllow);
    }
}

