/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.jcublas.ops.executioner;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import lombok.NonNull;
import org.bytedeco.javacpp.BooleanPointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.tad.DeviceTADManager;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.INDArrayStatistics;
import org.nd4j.linalg.api.ops.BaseReduceBoolOp;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.shape.TadPack;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.api.shape.options.ArrayType;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.bindings.Nd4jCuda;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.jcublas.ops.executioner.CudaOpContext;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
import org.nd4j.nativeblas.OpaqueConstantShapeBuffer;
import org.nd4j.nativeblas.OpaqueDataBuffer;
import org.nd4j.nativeblas.OpaqueShapeList;
import org.nd4j.nativeblas.OpaqueTadPack;
import org.nd4j.nativeblas.OpaqueVariable;
import org.nd4j.nativeblas.OpaqueVariablesSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CudaExecutioner
extends DefaultOpExecutioner {
    private static final Logger log = LoggerFactory.getLogger(CudaExecutioner.class);
    /** Device-side native ops binding, shared by all executioner instances. */
    protected static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    /** Supplies TAD shape-info / offset buffer pairs for an array and a set of dimensions. */
    protected static TADManager tadManager = new DeviceTADManager();
    /**
     * Per-thread scratch array of native pointers passed as the first ("extras") argument to the
     * native calls. The exec methods create it lazily with 32 slots.
     */
    protected ThreadLocal<PointerPointer> extraz = new ThreadLocal<>();
    protected volatile transient Properties properties;
    /** Per-thread name of the last op recorded by the exec methods (set in CUDA debug mode). */
    protected ThreadLocal<String> lastOp = new ThreadLocal<>();
    protected Map<String, CustomOpDescriptor> customOps = null;
    /** Mirrors {@code nativeOps.isExperimentalEnabled()}; initialised in the constructor. */
    protected AtomicBoolean experimentalMode = new AtomicBoolean(false);

    /**
     * Creates a CUDA executioner and copies the native backend's experimental flag into
     * {@link #experimentalMode}.
     */
    public CudaExecutioner() {
        boolean nativeExperimental = nativeOps.isExperimentalEnabled();
        experimentalMode.set(nativeExperimental);
    }

    /**
     * Returns the native ops binding used by this executioner.
     *
     * @return the shared device {@link NativeOps} instance
     */
    public NativeOps getNativeOps() {
        return CudaExecutioner.nativeOps;
    }

    /**
     * Returns the op name last recorded on the calling thread.
     *
     * @return the recorded name, or {@code null} if nothing was recorded on this thread
     *         (e.g. when CUDA debug mode is off)
     */
    public String getLastOp() {
        final String recordedName = lastOp.get();
        return recordedName;
    }

    /**
     * Executes a broadcast op on the CUDA backend.
     *
     * <p>Fills the per-thread extras pointer array with host shape-info pointers, the context's
     * stream, the device id, the context scratch buffers and the TAD shape/offset pointers for x
     * and z along the op's dimensions. It then dispatches to {@code execBroadcast} or
     * {@code execBroadcastBool} depending on {@code op.getOpType()}.
     *
     * @param op the broadcast op; x, y and z are all dereferenced on the dispatch path, so all
     *           three must be set
     * @return {@code op.z()}
     * @throws UnsupportedOperationException if the op type is neither BROADCAST nor BROADCAST_BOOL
     * @throws RuntimeException if the native backend reports an error code after the call
     */
    public INDArray exec(BroadcastOp op) {
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        this.checkForCompression((Op)op);
        int[] dimension = op.dimensions().toIntVector();
        // Lazily create this thread's 32-slot extras array.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
        // The op name is only recorded for diagnostics when CUDA debug mode is on.
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        OpaqueDataBuffer x = op.x() == null ? null : ((BaseCudaDataBuffer)op.x().data()).getOpaqueDataBuffer();
        OpaqueDataBuffer y = op.y() == null ? null : ((BaseCudaDataBuffer)op.y().data()).getOpaqueDataBuffer();
        OpaqueDataBuffer z = op.z() == null ? null : ((BaseCudaDataBuffer)op.z().data()).getOpaqueDataBuffer();
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        // TAD (tensor-along-dimension) shape info and offsets for x along the broadcast dimensions.
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);
        // NOTE(review): these null initializers are dead; both pointers are assigned just below.
        Pointer devTadShapeInfoZ = null;
        Pointer devTadOffsetsZ = null;
        // TAD shape info and offsets for z along the same dimensions.
        Pair tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), dimension);
        devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getFirst(), context);
        devTadOffsetsZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getSecond(), context);
        // Extras slots 0-13, in order: host x shape, context stream, device id, context
        // allocation/reduction/scalar/special buffers, host y shape, host z shape, host x TAD shape,
        // device x TAD shape, x TAD offsets, device z TAD shape, z TAD offsets.
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, devTadShapeInfoZ, devTadOffsetsZ});
        // NOTE(review): dimensionPointer is never read; both native calls receive the dimensions
        // through op.dimensions() instead.
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
        switch (op.getOpType()) {
            case BROADCAST: {
                nativeOps.execBroadcast(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), (LongPointer)xShapeInfo, y, (LongPointer)AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), (LongPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), z, (LongPointer)AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)AtomicAllocator.getInstance().getHostPointer(op.dimensions().shapeInfoDataBuffer()), (LongPointer)AtomicAllocator.getInstance().getPointer(op.dimensions().shapeInfoDataBuffer(), context));
                break;
            }
            case BROADCAST_BOOL: {
                // The extra null argument after z's shape pointers has no counterpart in
                // execBroadcast; presumably extra params — confirm against NativeOps.
                nativeOps.execBroadcastBool(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), (LongPointer)xShapeInfo, y, (LongPointer)AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), (LongPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), z, (LongPointer)AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), null, ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)AtomicAllocator.getInstance().getHostPointer(op.dimensions().shapeInfoDataBuffer()), (LongPointer)AtomicAllocator.getInstance().getPointer(op.dimensions().shapeInfoDataBuffer(), context));
                break;
            }
            default: {
                throw new UnsupportedOperationException("Unknown op type: " + op.getOpType());
            }
        }
        // Surface any error the native backend recorded during the call.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        this.profilingConfigurableHookOut((Op)op, null, st);
        return op.z();
    }

    /**
     * Dispatches a reduction op to the matching native CUDA kernel.
     *
     * <p>This body is decompiler output; its labeled blocks encode the original if/else chain:
     * <ul>
     *   <li>{@code block32}: argument setup. For a {@link Variance} op it then issues the
     *       summary-stats call; any other op leaves via {@code break block32}.</li>
     *   <li>Right after {@code block32}: pairwise (reduce3) kernels, reached only when
     *       {@code op.y()} is non-null.</li>
     *   <li>Right after {@code block34}: whole-array kernels, reached when the result is a
     *       scalar.</li>
     *   <li>Right after {@code block35}: dimensional kernels ({@code execReduce*2}).</li>
     *   <li>Every path that issues a native call exits through {@code break block33} to the
     *       shared native error check and profiling hook.</li>
     * </ul>
     *
     * <p>Empty reductions ({@code BaseReduceOp.isEmptyReduce()}) return early. They copy x into
     * z, or set z to a dup of x when z is absent, and make no native call.
     *
     * @param op        the reduction to run. On the non-empty path op.z() is dereferenced
     *                  without a null check, so the caller must have set it. NOTE(review):
     *                  exec(ReduceOp) does so before calling here.
     * @param dimension dimensions to reduce along; every entry must be below x's rank or equal
     *                  to {@link Integer#MAX_VALUE}
     * @return {@code op.z()}
     * @throws ND4JIllegalStateException if a dimension is out of range or x/y lengths mismatch
     * @throws UnsupportedOperationException if no kernel matches the op type
     * @throws RuntimeException if the native backend reports an error code
     */
    protected INDArray naiveExec(ReduceOp op, int ... dimension) {
        long st;
        // Every native-call path ends with "break block33", landing on the shared epilogue.
        block33: {
            OpaqueDataBuffer z;
            OpaqueDataBuffer x;
            Pointer extraArgs;
            PointerPointer xShapeInfoHostPointer;
            Pointer xShapeInfo;
            Pointer hostZShapeInfo;
            Pointer hostXShapeInfo;
            CudaContext context;
            // "break block35" skips the scalar-result kernels and falls into the dimensional ones.
            block35: {
                INDArray ret;
                // "break block34" skips the pairwise (reduce3) kernels when y is absent.
                block34: {
                    OpaqueDataBuffer y;
                    Pointer yDevTadShapeInfo;
                    Pointer yDevTadOffsets;
                    Pointer devTadOffsets;
                    Pointer devTadShapeInfo;
                    Pointer hostYShapeInfo;
                    // Setup; "break block32" skips the Variance-only branch at its end.
                    block32: {
                        DataType argsType;
                        st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
                        // Empty reduction: the result is x itself, copied; no native call is made.
                        if (op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()) {
                            if (op.z() != null) {
                                Preconditions.checkState((boolean)op.x().equalShapes(op.z()), (String)"For empty reductions, result (z) array must have same shape as x shape. Got: x=%ndShape, z=%ndShape", (Object)op.x(), (Object)op.z());
                                op.z().assign(op.x());
                                return op.z();
                            }
                            op.setZ(op.x().dup());
                            return op.z();
                        }
                        ret = op.z();
                        this.checkForCompression((Op)op);
                        op.validateDataTypes(null);
                        // Reject dimensions at or above x's rank. Integer.MAX_VALUE is let through;
                        // below, a single MAX_VALUE entry is treated like an empty dimension list.
                        for (int i = 0; i < dimension.length; ++i) {
                            if (dimension[i] < op.x().rank() || dimension[i] == Integer.MAX_VALUE) continue;
                            throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
                        }
                        context = AtomicAllocator.getInstance().getDeviceContext();
                        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
                            this.lastOp.set(op.opName());
                        }
                        hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
                        hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
                        hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
                        // TAD shape info and offsets for x along the reduction dimensions.
                        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
                        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
                        devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
                        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
                        devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
                        xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
                        if (this.extraz.get() == null) {
                            this.extraz.set(new PointerPointer(32L));
                        }
                        // Extras slots 0-11, in order: host x shape, context stream, device id, context
                        // allocation/reduction/scalar/special buffers, host y shape, host z shape,
                        // host x TAD shape, device x TAD shape, x TAD offsets. Slots 12/13 are set for y below.
                        xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
                        yDevTadOffsets = null;
                        yDevTadShapeInfo = null;
                        if (op.y() != null) {
                            // Case 1: no dimensions, or y's length differs from the first x TAD.
                            // x and y must then have equal lengths unless this is a complex
                            // accumulation. y gets its own TADs unless the result is a scalar.
                            if (dimension.length == 0 || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE || op.x().tensorAlongDimension(0L, dimension).length() != op.y().length()) {
                                if (!op.isComplexAccumulation() && op.x().length() != op.y().length()) {
                                    throw new ND4JIllegalStateException("Op.X [" + op.x().length() + "] and Op.Y [" + op.y().length() + "] lengths should match");
                                }
                                if (!op.z().isScalar()) {
                                    Pair yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension);
                                    yDevTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)yTadBuffers.getFirst(), context);
                                    DataBuffer yOffsets = (DataBuffer)yTadBuffers.getSecond();
                                    yDevTadOffsets = yOffsets == null ? null : AtomicAllocator.getInstance().getPointer(yOffsets, context);
                                    xShapeInfoHostPointer.put(12L, yDevTadShapeInfo);
                                    xShapeInfoHostPointer.put(13L, yDevTadOffsets);
                                }
                            } else {
                                // Case 2: y is as long as one x TAD. y is used as a single TAD with
                                // its own shape info and constant {0, 0} offsets; slot 13 stays null.
                                DataBuffer fakeOffsets = Nd4j.getConstantHandler().getConstantBuffer(new int[]{0, 0}, DataType.LONG);
                                yDevTadOffsets = fakeOffsets == null ? null : AtomicAllocator.getInstance().getPointer(fakeOffsets, context);
                                yDevTadShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
                                xShapeInfoHostPointer.put(12L, AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context));
                                xShapeInfoHostPointer.put(13L, null);
                            }
                        }
                        // Extra args are materialised in x's type for REDUCE_LONG / REDUCE_BOOL and
                        // in z's type for every other op type.
                        switch (op.getOpType()) {
                            case REDUCE_LONG: 
                            case REDUCE_BOOL: {
                                argsType = op.x().dataType();
                                break;
                            }
                            default: {
                                argsType = op.z().dataType();
                            }
                        }
                        extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(argsType), context) : null;
                        // NOTE(review): dimensionPointer is never read; kernels receive the dimensions
                        // through op.dimensions() instead.
                        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
                        x = op.x() == null ? null : ((BaseCudaDataBuffer)op.x().data()).getOpaqueDataBuffer();
                        y = op.y() == null ? null : ((BaseCudaDataBuffer)op.y().data()).getOpaqueDataBuffer();
                        // NOTE(review): opaqueDataBuffer is a decompiler artifact and is unused.
                        OpaqueDataBuffer opaqueDataBuffer = z = op.z() == null ? null : ((BaseCudaDataBuffer)op.z().data()).getOpaqueDataBuffer();
                        if (!(op instanceof Variance)) break block32;
                        // Variance: whole-array summary stats for a scalar result, per-TAD otherwise.
                        if (ret.isScalar()) {
                            nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, z, (LongPointer)hostZShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()), ((Variance)op).isBiasCorrected());
                        } else {
                            nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, z, (LongPointer)hostZShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null, ((Variance)op).isBiasCorrected(), (LongPointer)devTadShapeInfo, (LongPointer)devTadOffsets);
                        }
                        break block33;
                    }
                    if (op.y() == null) break block34;
                    // Pairwise (reduce3) kernels: complex accumulation, whole-array (scalar result),
                    // or per-TAD.
                    if (op.isComplexAccumulation()) {
                        LongPointerWrapper dT = new LongPointerWrapper(devTadOffsets);
                        LongPointerWrapper yT = new LongPointerWrapper(yDevTadOffsets);
                        nativeOps.execReduce3All(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, y, (LongPointer)hostYShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), z, (LongPointer)hostZShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer)devTadShapeInfo, (LongPointer)dT, (LongPointer)yDevTadShapeInfo, (LongPointer)yT);
                    } else if (ret.isScalar()) {
                        nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, y, (LongPointer)hostYShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), z, (LongPointer)hostZShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context));
                    } else {
                        nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, y, (LongPointer)hostYShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), z, (LongPointer)hostZShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer)devTadShapeInfo, (LongPointer)devTadOffsets, (LongPointer)yDevTadShapeInfo, (LongPointer)yDevTadOffsets);
                    }
                    break block33;
                }
                if (!ret.isScalar()) break block35;
                // Scalar result without y: whole-array kernels. NOTE(review): these fetch z's device
                // shape pointer via getPointer(buffer) without the context argument, unlike the
                // other calls in this method.
                switch (op.getOpType()) {
                    case REDUCE_FLOAT: {
                        nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, z, (LongPointer)hostZShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()));
                        break block33;
                    }
                    case REDUCE_BOOL: {
                        nativeOps.execReduceBool(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, z, (LongPointer)hostZShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()));
                        break block33;
                    }
                    case REDUCE_LONG: {
                        nativeOps.execReduceLong(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, z, (LongPointer)hostZShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()));
                        break block33;
                    }
                    case REDUCE_SAME: {
                        nativeOps.execReduceSame(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, z, (LongPointer)hostZShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()));
                        break block33;
                    }
                    default: {
                        throw new UnsupportedOperationException();
                    }
                }
            }
            // Non-scalar result without y: dimensional kernels along op.dimensions().
            switch (op.getOpType()) {
                case REDUCE_FLOAT: {
                    nativeOps.execReduceFloat2(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, z, (LongPointer)hostZShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                    break;
                }
                case REDUCE_BOOL: {
                    nativeOps.execReduceBool2(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, z, (LongPointer)hostZShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                    break;
                }
                case REDUCE_SAME: {
                    nativeOps.execReduceSame2(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, z, (LongPointer)hostZShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                    break;
                }
                case REDUCE_LONG: {
                    nativeOps.execReduceLong2(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, z, (LongPointer)hostZShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                    break;
                }
                default: {
                    throw new UnsupportedOperationException();
                }
            }
        }
        // Shared epilogue: surface native errors, then record profiling data.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        this.profilingConfigurableHookOut((Op)op, null, st);
        return op.z();
    }

    /**
     * Executes a variance op through the generic reduction path.
     *
     * @param op the variance op
     * @return the op's result array, as returned by {@link #exec(ReduceOp)}
     */
    public INDArray exec(Variance op) {
        final ReduceOp asReduction = (ReduceOp) op;
        return this.exec(asReduction);
    }

    /**
     * Executes a reduction op on the CUDA backend.
     *
     * <p>Empty reductions copy x into z, or set z to a dup of x when z is absent. Otherwise, if
     * z is null or aliases x, a result array with the reduction shape (or, for a complex
     * accumulation, shape {@code [TADs(x), TADs(y)]}) is allocated. A caller-supplied z must have
     * the expected number of elements. The native dispatch itself is done by
     * {@link #naiveExec(ReduceOp, int...)}.
     *
     * @param op the reduction to execute
     * @return {@code op.z()}, or the result of {@code op.noOp()} when x is a vector whose
     *         reduction shape keeps every element and there is no y
     * @throws ND4JIllegalStateException if x/y TAD counts or sizes disagree, or a supplied z has
     *         the wrong length
     */
    public INDArray exec(ReduceOp op) {
        this.checkForCompression((Op)op);
        if (op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()) {
            if (op.z() != null) {
                Preconditions.checkState((boolean)op.x().equalShapes(op.z()), (String)"For empty reductions, result (z) array must have same shape as x shape. Got: x=%ndShape, z=%ndShape", (Object)op.x(), (Object)op.z());
                op.z().assign(op.x());
                return op.z();
            }
            op.setZ(op.x().dup());
            return op.z();
        }
        int[] dimension = op.dimensions().toIntVector();
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        boolean wholeDims = Shape.wholeArrayDimension((int[])dimension) || op.x().rank() == dimension.length || dimension.length == 0;
        // The reduction shape is derived from the longer of x and y (x alone when y is absent).
        INDArray shapeSource = op.y() == null ? op.x() : (op.x().length() > op.y().length() ? op.x() : op.y());
        long[] retShape = Shape.reductionShape(shapeSource, (int[])dimension, true, op.isKeepDims());
        // Use prodLong rather than prod: prod returns an int and overflows once the result has
        // more than Integer.MAX_VALUE elements, which breaks the length comparisons below.
        long retLength = ArrayUtil.prodLong((long[])retShape);
        if (op.x().isVector() && op.x().length() == retLength && retLength > 1L && op.y() == null) {
            return op.noOp();
        }
        DataType dtype = op.resultType();
        INDArray ret = null;
        if (op.z() == null || op.z() == op.x()) {
            if (op.isComplexAccumulation()) {
                long xT = op.x().tensorsAlongDimension(dimension);
                long yT = op.y().tensorsAlongDimension(dimension);
                ret = Nd4j.createUninitialized((DataType)dtype, (long[])new long[]{xT, yT});
            } else {
                if (op.y() != null) {
                    if (op.x().length() == op.y().length()) {
                        if (!wholeDims && op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) {
                            throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + Arrays.toString(op.x().shape()) + ", y shape = " + Arrays.toString(op.y().shape()) + ", dimension = " + Arrays.toString(dimension) + ")");
                        }
                    } else {
                        if (dimension.length == 0) {
                            throw new ND4JIllegalStateException("TAD vs TAD comparison requires dimension (or other comparison mode was supposed to be used?)");
                        }
                        // y is compared against each x TAD, so it must be exactly one TAD long.
                        long xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension);
                        if (xTADSize != op.y().length()) {
                            throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution: (x TAD size = " + xTADSize + ", y size = " + op.y().length());
                        }
                    }
                }
                ret = Nd4j.create((DataType)dtype, (long[])retShape);
            }
            op.setZ(ret);
        } else if (op.z().length() != (retShape.length == 0 ? 1L : retLength)) {
            throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]");
        }
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        this.naiveExec(op, dimension);
        this.profilingConfigurableHookOut((Op)op, null, st);
        return op.z();
    }

    /**
     * Executes an index-reduction op on the CUDA backend.
     *
     * <p>The op's dimensions are normalised against x's rank. When z is absent, a LONG array with
     * the reduction shape is allocated for the resulting indices. An empty z is returned
     * unchanged. Otherwise {@code execIndexReduce} writes the indices into z.
     *
     * @param op the index reduction to execute
     * @return {@code op.z()}, holding the computed indices
     * @throws IllegalArgumentException if x is empty and a reduced axis has size 0
     * @throws RuntimeException if the native backend reports an error code
     */
    public INDArray exec(IndexAccumulation op) {
        int[] dimension = Shape.normalizeAxis((int)op.x().rank(), (int[])op.dimensions().toIntVector());
        if (op.x().isEmpty()) {
            for (int d : dimension) {
                Preconditions.checkArgument(op.x().shape()[d] != 0L, (String)"IndexReduce can't be issued along axis with 0 in shape");
            }
        }
        if (op.z() == null) {
            long[] retShape = Shape.reductionShape((INDArray)op.x(), (int[])dimension, (boolean)true, (boolean)op.isKeepDims());
            op.setZ(Nd4j.createUninitialized((DataType)DataType.LONG, (long[])retShape));
        }
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        this.checkForCompression((Op)op);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Even when every reduced axis has size 1 (z as long as x), the kernel must still run: the
        // result is a set of indices written into z, never x's own values.
        if (op.z().isEmpty()) {
            return op.z();
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
        Pointer hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
        // TAD shape info and offsets for x along the normalised dimensions.
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
        // Extras slots 0-11: host x shape, context stream, device id, context scratch buffers,
        // host y/z shapes, host and device x TAD shape, x TAD offsets.
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
        // Extra args are materialised in x's data type.
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.x().dataType()), context) : null;
        OpaqueDataBuffer x = op.x() == null ? null : ((BaseCudaDataBuffer)op.x().data()).getOpaqueDataBuffer();
        OpaqueDataBuffer z = op.z() == null ? null : ((BaseCudaDataBuffer)op.z().data()).getOpaqueDataBuffer();
        nativeOps.execIndexReduce(xShapeInfoHostPointer, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, z, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        this.profilingConfigurableHookOut((Op)op, null, st);
        return op.z();
    }

    /**
     * Executes {@code op} without an explicit op context.
     *
     * @param op the op to execute
     * @return the op's z array, as returned by {@link #exec(Op, OpContext)}
     */
    public INDArray exec(Op op) {
        final OpContext noContext = null;
        return exec(op, noContext);
    }

    /**
     * Executes {@code op}, dispatching on its runtime type. The checks are made in this order:
     * transform, reduce, scalar, broadcast, index accumulation, random, custom. Ops matching none
     * of these are not executed.
     *
     * @param op the op to execute
     * @param oc optional op context forwarded to the type-specific path (may be {@code null})
     * @return {@code op.z()}
     */
    public INDArray exec(Op op, OpContext oc) {
        this.checkForCompression(op);
        if (op instanceof TransformOp) {
            this.invoke((TransformOp) op, oc);
            return op.z();
        }
        if (op instanceof ReduceOp) {
            final ReduceOp reduction = (ReduceOp) op;
            this.invoke(reduction, oc, reduction.dimensions().toIntVector());
            return op.z();
        }
        if (op instanceof ScalarOp) {
            this.invoke((ScalarOp) op, oc);
            return op.z();
        }
        if (op instanceof BroadcastOp) {
            this.invoke((BroadcastOp) op, oc);
            return op.z();
        }
        if (op instanceof IndexAccumulation) {
            final IndexAccumulation indexReduction = (IndexAccumulation) op;
            this.invoke(indexReduction, oc, indexReduction.dimensions().toIntVector());
            return op.z();
        }
        if (op instanceof RandomOp) {
            this.exec((RandomOp) op, oc, Nd4j.getRandom());
            return op.z();
        }
        if (op instanceof CustomOp) {
            this.exec((CustomOp) op, oc);
        }
        return op.z();
    }

    /**
     * Runs a transform op without an explicit op context and hands the same op back.
     *
     * @param op the transform op to execute
     * @return {@code op} itself
     */
    public TransformOp execAndReturn(TransformOp op) {
        final OpContext noContext = null;
        this.checkForCompression((Op) op);
        this.invoke(op, noContext);
        return op;
    }

    /**
     * Executes a broadcast op (BROADCAST or BROADCAST_BOOL) through the native backend.
     * <p>
     * Resolves x, y and z via getX/getY/getZ (using {@code oc} when supplied). It then looks up
     * TAD shape info and offsets along {@code op.getDimension()} for both x and z, and packs
     * host/device pointers into the reusable {@code extraz} pointer array. Finally it calls
     * {@code execBroadcast} or {@code execBroadcastBool} depending on {@code op.getOpType()}.
     *
     * @param op the broadcast op to run
     * @param oc optional op context; may be {@code null}
     * @return always {@code null}
     * @throws UnsupportedOperationException if the op type is neither BROADCAST nor BROADCAST_BOOL
     * @throws RuntimeException if the native layer reports a non-zero error code
     */
    protected CudaContext invoke(BroadcastOp op, OpContext oc) {
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        INDArray x = this.getX((Op)op, oc);
        INDArray y = this.getY((Op)op, oc);
        INDArray z = this.getZ((Op)op, oc);
        this.checkForCompression((Op)op);
        // Lazily create the reusable 32-slot pointer array used to pass extra pointers to native code.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        // NOTE(review): x is dereferenced here, so the `x == null` checks below can never take effect.
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context);
        Pointer hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        Pointer hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        Pointer hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
        // TAD shape info and offsets of x along the broadcast dimensions.
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(x, op.getDimension());
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);
        // TAD shape info and offsets of z along the same dimensions.
        Pointer devTadShapeInfoZ = null;
        Pointer devTadOffsetsZ = null;
        Pair tadBuffersZ = tadManager.getTADOnlyShapeInfo(z, op.getDimension());
        devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getFirst(), context);
        devTadOffsetsZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getSecond(), context);
        // Slots 0-13: host x shape, stream, device id, the four context buffers, host y/z shapes,
        // x TAD (host shape, device shape, device offsets), z TAD (device shape, device offsets).
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, devTadShapeInfoZ, devTadOffsetsZ});
        Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context);
        // NOTE(review): dimensionPointer is never read; the native calls below take op.dimensions() instead.
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(op.getDimension()), context);
        OpaqueDataBuffer xb = x == null ? null : ((BaseCudaDataBuffer)x.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer yb = y == null ? null : ((BaseCudaDataBuffer)y.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer zb = z == null ? null : ((BaseCudaDataBuffer)z.data()).getOpaqueDataBuffer();
        switch (op.getOpType()) {
            case BROADCAST: {
                nativeOps.execBroadcast(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, yb, (LongPointer)hostYShapeInfo, (LongPointer)yShapeInfo, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                break;
            }
            case BROADCAST_BOOL: {
                // The extra `null` after zShapeInfo occupies an argument slot that execBroadcast does not have.
                nativeOps.execBroadcastBool(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, yb, (LongPointer)hostYShapeInfo, (LongPointer)yShapeInfo, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, null, ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                break;
            }
            default: {
                throw new UnsupportedOperationException("Unknown opType: " + op.getOpType());
            }
        }
        // Native errors are reported through lastErrorCode/lastErrorMessage rather than exceptions.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        this.profilingConfigurableHookOut((Op)op, oc, st);
        return null;
    }

    /**
     * Executes an index-accumulation (index reduction) op through the native backend.
     * <p>
     * If no distinct output array is available (z is null or aliases x), a LONG output is
     * allocated: a scalar for a full reduction, otherwise an array of the reduction shape. A
     * caller-supplied z must match the expected reduction shape. A scalar z, or a null /
     * {@code Integer.MAX_VALUE} dimension, goes through {@code execIndexReduceScalar}. Everything
     * else goes through {@code execIndexReduce}.
     *
     * @param op the index-accumulation op
     * @param oc optional op context; may be {@code null}
     * @param dimension dimensions to reduce over; normalized against x's rank, may be null
     * @return always {@code null}
     * @throws IllegalStateException if a supplied z has the wrong shape
     * @throws ND4JIllegalStateException if a dimension is not below x's rank
     * @throws RuntimeException if the native layer reports a non-zero error code
     */
    protected CudaContext invoke(IndexAccumulation op, OpContext oc, int[] dimension) {
        INDArray x = this.getX((Op)op, oc);
        INDArray y = this.getY((Op)op, oc);
        INDArray z = this.getZ((Op)op, oc);
        dimension = Shape.normalizeAxis((int)x.rank(), (int[])dimension);
        // Full reduction without a distinct output: allocate a LONG scalar to hold the index.
        if ((dimension == null || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) && (z == x || z == null)) {
            z = Nd4j.createUninitialized((DataType)DataType.LONG, (long[])new long[0], (char)'c');
            this.setZ(z, (Op)op, oc);
        }
        boolean keepDims = op.isKeepDims();
        long[] retShape = Shape.reductionShape((INDArray)x, (int[])dimension, (boolean)true, (boolean)keepDims);
        if (z == null || x == z) {
            INDArray ret = Nd4j.createUninitialized((DataType)DataType.LONG, (long[])retShape);
            this.setZ(ret, (Op)op, oc);
            z = ret;
        } else if (!Arrays.equals(retShape, z.shape())) {
            throw new IllegalStateException("Z array shape does not match expected return type for op " + op + ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(z.shape()));
        }
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        this.checkForCompression((Op)op);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Debug mode is only read here. Executing an op must not switch on the global debug
        // flag, because that would change the behaviour of every subsequent op.
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        if (dimension != null) {
            for (int i = 0; i < dimension.length; ++i) {
                if (dimension[i] < x.rank() || dimension[i] == Integer.MAX_VALUE) continue;
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + x.rank() + "]");
            }
        }
        CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context);
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(x.dataType()), context) : null;
        Pointer hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        Pointer hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        Pointer hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
        // TAD lookup needs concrete dimensions; fall back to dimension 0 when none were given.
        int[] fdimension = dimension;
        if (fdimension == null) {
            fdimension = new int[]{0};
        }
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(x, fdimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context);
        OpaqueDataBuffer xb = x == null ? null : ((BaseCudaDataBuffer)x.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer zb = z == null ? null : ((BaseCudaDataBuffer)z.data()).getOpaqueDataBuffer();
        // Slots 0-11: host x shape, stream, device id, the four context buffers, host y/z shapes, x TAD pointers.
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
        if (z.isScalar() || dimension == null || dimension[0] == Integer.MAX_VALUE) {
            nativeOps.execIndexReduceScalar(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo);
        } else {
            if (dimension != null && dimension.length > 1) {
                Arrays.sort(dimension);
            }
            // NOTE(review): dimensionPointer is not passed to the native call, which uses op.dimensions().
            Pointer dimensionPointer = AtomicAllocator.getInstance().getHostPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension));
            nativeOps.execIndexReduce(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        this.profilingConfigurableHookOut((Op)op, oc, st);
        return null;
    }

    /**
     * Executes a reduction op, optionally pairwise with a y array, through the native backend.
     * <p>
     * Two cases return early without any native call:
     * <ul>
     * <li>a {@link BaseReduceOp} whose {@code isEmptyReduce()} is true: x is assigned into z,
     *     or z is set to {@code x.dup()} when no z exists;</li>
     * <li>a {@link BaseReduceBoolOp} on an empty x with a null or {@code {Integer.MAX_VALUE}}
     *     dimension: z is set to / filled with the op's {@code emptyValue()}.</li>
     * </ul>
     * Otherwise a null dimension is replaced by {@code {Integer.MAX_VALUE}}, and multiple dimensions
     * are sorted. z is then allocated, or checked against the expected result type and
     * reduction shape. The op is dispatched as follows:
     * <ul>
     * <li>scalar z: {@code execSummaryStatsScalar} (Variance), {@code execReduce3Scalar} (y present),
     *     or {@code execReduce{Float,Bool,Same,Long}} by op type;</li>
     * <li>non-scalar z: the TAD variants {@code execSummaryStatsTad}, {@code execReduce3Tad}, or
     *     {@code execReduce*2}.</li>
     * </ul>
     * The labeled blocks are decompiler output. Each {@code break blockNN} encodes one branch of
     * that dispatch, and every native path finishes at the error check following {@code block38}.
     *
     * @param op the reduce op
     * @param oc optional op context; may be {@code null}
     * @param dimension dimensions to reduce over; may be {@code null}
     * @return the device context obtained at the start of the call
     * @throws ND4JIllegalStateException for a dimension not below x's rank, mismatched x/y TADs,
     *         or a supplied z of the wrong data type or shape
     * @throws UnsupportedOperationException for a reduce op type with no native dispatch here
     * @throws RuntimeException if the native layer reports a non-zero error code
     */
    protected CudaContext invoke(ReduceOp op, OpContext oc, int[] dimension) {
        long st;
        CudaContext context;
        // block38: all native dispatch; breaking out of it jumps to the error check below.
        block38: {
            OpaqueDataBuffer zb;
            OpaqueDataBuffer yb;
            OpaqueDataBuffer xb;
            Pointer zShapeInfo;
            Pointer yDevTadOffsets;
            Pointer yDevTadShapeInfo;
            PointerPointer xShapeInfoHostPointer;
            Pointer hostZShapeInfo;
            Pointer hostYShapeInfo;
            Pointer hostXShapeInfo;
            Pointer extraArgs;
            Pointer xShapeInfo;
            Pointer devTadOffsets;
            Pointer devTadShapeInfo;
            INDArray y;
            // block36: setup + scalar-output dispatch; breaking out of it selects the TAD path.
            block36: {
                // block39: breaking out of it selects the plain (non-pairwise) scalar reduce switch.
                block39: {
                    // block37: breaking out of it skips the scalar Variance path.
                    block37: {
                        DataType dataType;
                        context = AtomicAllocator.getInstance().getDeviceContext();
                        INDArray x = this.getX((Op)op, oc);
                        y = this.getY((Op)op, oc);
                        INDArray z = this.getZ((Op)op, oc);
                        // Empty reduction: result equals x, so copy it instead of launching a kernel.
                        if (op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()) {
                            if (z != null) {
                                Preconditions.checkState((boolean)x.equalShapes(z), (String)"For empty reductions, result (z) array must have same shape as x shape. Got: x=%ndShape, z=%ndShape", (Object)x, (Object)z);
                                z.assign(x);
                                return context;
                            }
                            op.setZ(x.dup());
                            return context;
                        }
                        // Boolean reduction of an empty array over all dimensions: use the op's empty value.
                        if (op instanceof BaseReduceBoolOp && x.isEmpty() && (dimension == null || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE)) {
                            if (z == null) {
                                op.setZ(Nd4j.scalar((boolean)((BaseReduceBoolOp)op).emptyValue()));
                            } else {
                                z.assign(((BaseReduceBoolOp)op).emptyValue());
                            }
                            return context;
                        }
                        st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
                        this.checkForCompression((Op)op);
                        dimension = Shape.normalizeAxis((int)x.rank(), (int[])dimension);
                        if (this.extraz.get() == null) {
                            this.extraz.set(new PointerPointer(32L));
                        }
                        if (dimension == null) {
                            dimension = new int[]{Integer.MAX_VALUE};
                        }
                        if (dimension != null && dimension.length > 1) {
                            Arrays.sort(dimension);
                        }
                        for (int i = 0; i < dimension.length; ++i) {
                            if (dimension[i] < x.rank() || dimension[i] == Integer.MAX_VALUE) continue;
                            throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + x.rank() + "]");
                        }
                        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
                            this.lastOp.set(op.opName());
                        }
                        // For an empty x no TAD lookup is done: x's own data buffer stands in and offsets are null.
                        Pair tadBuffers = x.isEmpty() ? Pair.makePair((Object)x.data(), null) : tadManager.getTADOnlyShapeInfo(x, dimension);
                        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
                        devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
                        DataBuffer offsets = x.isEmpty() ? null : (DataBuffer)tadBuffers.getSecond();
                        devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
                        xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context);
                        long[] retShape = Shape.reductionShape((INDArray)x, (int[])dimension, (boolean)true, (boolean)op.isKeepDims());
                        // Pairwise (reduce3) validation: either same length with matching TAD counts,
                        // or y must be exactly one x TAD long.
                        if (y != null) {
                            if (x.length() == y.length()) {
                                if (x.tensorsAlongDimension(dimension) != y.tensorsAlongDimension(dimension)) {
                                    throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + Arrays.toString(x.shape()) + ", y shape = " + Arrays.toString(y.shape()) + ", dimension = " + Arrays.toString(dimension) + ")");
                                }
                            } else {
                                long xTADSize = x.length() / x.tensorsAlongDimension(dimension);
                                if (xTADSize != y.length()) {
                                    throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution: (x TAD size = " + xTADSize + ", y size = " + y.length());
                                }
                            }
                        }
                        // dataType2 is a decompiler artifact and is never read.
                        DataType dataType2 = dataType = oc != null ? op.resultType(oc) : op.resultType();
                        if (z == null) {
                            INDArray ret = Nd4j.createUninitialized((DataType)dataType, (long[])retShape);
                            this.setZ(ret, (Op)op, oc);
                            z = ret;
                        } else if (z.dataType() != dataType || !Arrays.equals(retShape, z.shape())) {
                            throw new ND4JIllegalStateException("Output array for op " + op.getClass().getSimpleName() + " should have type " + dataType + " and shape " + Arrays.toString(retShape) + " but has datatype " + z.dataType() + " and shape " + Arrays.toString(z.shape()));
                        }
                        // Extra args use x's type for BOOL outputs and REDUCE_LONG ops, z's type otherwise.
                        DataBuffer eb = op.extraArgsDataBuff(z.dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? x.dataType() : z.dataType());
                        extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(eb, context) : null;
                        hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
                        hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
                        hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
                        // Slots 0-11 of the shared pointer array; slots 12/13 receive y's TAD pointers below when y exists.
                        xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
                        Pair yTadBuffers = y == null ? null : tadManager.getTADOnlyShapeInfo(y, dimension);
                        yDevTadShapeInfo = y == null ? null : AtomicAllocator.getInstance().getPointer((DataBuffer)yTadBuffers.getFirst(), context);
                        DataBuffer yOffsets = y == null ? null : (DataBuffer)yTadBuffers.getSecond();
                        // `pointer` is a decompiler artifact and is never read.
                        Pointer pointer = yDevTadOffsets = yOffsets == null ? null : AtomicAllocator.getInstance().getPointer(yOffsets, context);
                        if (y != null) {
                            xShapeInfoHostPointer.put(12L, yDevTadShapeInfo);
                            xShapeInfoHostPointer.put(13L, yDevTadOffsets);
                        }
                        zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context);
                        xb = x == null ? null : ((BaseCudaDataBuffer)x.data()).getOpaqueDataBuffer();
                        yb = y == null ? null : ((BaseCudaDataBuffer)y.data()).getOpaqueDataBuffer();
                        zb = z == null ? null : ((BaseCudaDataBuffer)z.data()).getOpaqueDataBuffer();
                        op.validateDataTypes(null);
                        // Non-scalar output: take the TAD path after block36.
                        if (!z.isScalar()) break block36;
                        if (!(op instanceof Variance)) break block37;
                        nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, ((Variance)op).isBiasCorrected());
                        break block38;
                    }
                    // Scalar output, not Variance: pairwise reduce3 when y exists.
                    if (y == null) break block39;
                    Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context);
                    nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, yb, (LongPointer)hostYShapeInfo, (LongPointer)yShapeInfo, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo);
                    break block38;
                }
                // Scalar output, single-input reduction: dispatch on the reduce op type.
                switch (op.getOpType()) {
                    case REDUCE_FLOAT: {
                        nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo);
                        break block38;
                    }
                    case REDUCE_BOOL: {
                        nativeOps.execReduceBool(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo);
                        break block38;
                    }
                    case REDUCE_SAME: {
                        nativeOps.execReduceSame(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo);
                        break block38;
                    }
                    case REDUCE_LONG: {
                        nativeOps.execReduceLong(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo);
                        break block38;
                    }
                    default: {
                        throw new UnsupportedOperationException();
                    }
                }
            }
            // Non-scalar output: TAD-based variants.
            // NOTE(review): dimensionPointer is never read; the native calls take op.dimensions() instead.
            Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
            if (y != null) {
                Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context);
                nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, yb, (LongPointer)hostYShapeInfo, (LongPointer)yShapeInfo, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer)devTadShapeInfo, (LongPointer)devTadOffsets, (LongPointer)yDevTadShapeInfo, (LongPointer)yDevTadOffsets);
            } else if (op instanceof Variance) {
                nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null, ((Variance)op).isBiasCorrected(), (LongPointer)devTadShapeInfo, (LongPointer)devTadOffsets);
            } else {
                switch (op.getOpType()) {
                    case REDUCE_FLOAT: {
                        nativeOps.execReduceFloat2(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                        break;
                    }
                    case REDUCE_SAME: {
                        nativeOps.execReduceSame2(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                        break;
                    }
                    case REDUCE_BOOL: {
                        nativeOps.execReduceBool2(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                        break;
                    }
                    case REDUCE_LONG: {
                        nativeOps.execReduceLong2(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, extraArgs, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                        break;
                    }
                    default: {
                        throw new UnsupportedOperationException();
                    }
                }
            }
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        this.profilingConfigurableHookOut((Op)op, oc, st);
        // NOTE(review): presumably flushes queued work so z is ready on return; confirm against commit().
        Nd4j.getExecutioner().commit();
        return context;
    }

    /**
     * Executes a scalar op along the given dimensions via {@code execScalarTad} /
     * {@code execScalarBoolTad}.
     * <p>
     * Arrays are read directly from {@code op.x()}, {@code op.y()} and {@code op.z()}. No op context
     * is consulted here, and profiling is reported with a {@code null} context.
     *
     * @param op the scalar op; its y array is passed as the per-TAD scalar operand
     * @param dimension dimensions used for the TAD lookup; sorted in place when it has more than one entry
     * @return always {@code null}
     * @throws UnsupportedOperationException if the op type is neither SCALAR nor SCALAR_BOOL
     * @throws RuntimeException if the native layer reports a non-zero error code
     */
    protected CudaContext intercept(ScalarOp op, int[] dimension) {
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        // Sorts the caller's array in place.
        if (dimension != null && dimension.length > 1) {
            Arrays.sort(dimension);
        }
        CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        Pointer hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        // NOTE(review): these dereference x, y and z unconditionally, unlike the null-guarded host lookups above.
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
        // TAD shape info and offsets for x, then for z, along the same dimensions.
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);
        Pointer devTadShapeInfoZ = null;
        Pointer devTadOffsetsZ = null;
        Pair tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), dimension);
        devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getFirst(), context);
        devTadOffsetsZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getSecond(), context);
        // Slots 0-13: host x shape, stream, device id, the four context buffers, host y/z shapes, x TAD, z TAD.
        PointerPointer extraPointers = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, devTadShapeInfoZ, devTadOffsetsZ});
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context) : null;
        // NOTE(review): dimensionPointer is never read; the native calls take op.dimensions() instead.
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
        OpaqueDataBuffer x = op.x() == null ? null : ((BaseCudaDataBuffer)op.x().data()).getOpaqueDataBuffer();
        OpaqueDataBuffer y = op.y() == null ? null : ((BaseCudaDataBuffer)op.y().data()).getOpaqueDataBuffer();
        OpaqueDataBuffer z = op.z() == null ? null : ((BaseCudaDataBuffer)op.z().data()).getOpaqueDataBuffer();
        switch (op.getOpType()) {
            case SCALAR: {
                nativeOps.execScalarTad(extraPointers, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, z, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, y, (LongPointer)hostYShapeInfo, (LongPointer)yShapeInfo, extraArgs, ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer)devTadShapeInfo, (LongPointer)devTadOffsets, (LongPointer)devTadShapeInfoZ, (LongPointer)devTadOffsetsZ);
                break;
            }
            case SCALAR_BOOL: {
                nativeOps.execScalarBoolTad(extraPointers, op.opNum(), x, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, z, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, y, (LongPointer)hostYShapeInfo, (LongPointer)yShapeInfo, extraArgs, ((BaseCudaDataBuffer)op.dimensions().data()).getOpaqueDataBuffer(), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), null, (LongPointer)devTadShapeInfo, (LongPointer)devTadOffsets, (LongPointer)devTadShapeInfoZ, (LongPointer)devTadOffsetsZ);
                break;
            }
            default: {
                throw new UnsupportedOperationException();
            }
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        this.profilingConfigurableHookOut((Op)op, null, st);
        return null;
    }

    /**
     * Runs a scalar op without an explicit op context.
     *
     * @param op the scalar op to execute
     * @return {@code op.z()} after execution
     */
    public INDArray exec(ScalarOp op) {
        final OpContext noContext = null;
        this.invoke(op, noContext);
        final INDArray result = op.z();
        return result;
    }

    /**
     * Executes a scalar op (SCALAR or SCALAR_BOOL) through the native backend.
     * <p>
     * When no output array is set, one is allocated and registered with the op (or op context):
     * {@code x.ulike()} for SCALAR, or an uninitialized BOOL array of x's shape for SCALAR_BOOL.
     * The same array is the kernel's output, so the registered z holds the result. If the op has
     * dimensions, execution is handed to {@code intercept(op, dims)} instead.
     *
     * @param op the scalar op
     * @param oc optional op context; may be {@code null}
     * @return always {@code null}
     * @throws ND4JIllegalStateException for an unknown op type while allocating z, or when x and z
     *         lengths differ
     * @throws UnsupportedOperationException for an op type with no native dispatch here
     * @throws RuntimeException if the native layer reports a non-zero error code
     */
    protected CudaContext invoke(ScalarOp op, OpContext oc) {
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        this.checkForCompression((Op)op);
        INDArray x = this.getX((Op)op, oc);
        INDArray y = this.getY((Op)op, oc);
        INDArray z = this.getZ((Op)op, oc);
        if (z == null) {
            switch (op.getOpType()) {
                case SCALAR: {
                    // Register the very array the kernel writes into below; registering a second
                    // ulike() copy would leave the visible output uninitialized.
                    z = x.ulike();
                    this.setZ(z, (Op)op, oc);
                    break;
                }
                case SCALAR_BOOL: {
                    z = Nd4j.createUninitialized((DataType)DataType.BOOL, (long[])x.shape());
                    this.setZ(z, (Op)op, oc);
                    break;
                }
                default: {
                    throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() + "]");
                }
            }
        }
        if (x.length() != z.length()) {
            throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: [" + Arrays.toString(x.shapeInfoDataBuffer().asInt()) + "] != [" + Arrays.toString(z.shapeInfoDataBuffer().asInt()) + "]");
        }
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        // Scalar-along-dimension ops use the TAD path. NOTE(review): the profiling hook opened above
        // is not closed on this path; intercept() opens and closes its own.
        if (op.dimensions() != null) {
            this.intercept(op, op.dimensions().toIntVector());
            return null;
        }
        CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
        Pointer hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        Pointer hostYShapeInfo = op.scalar() == null ? null : AddressRetriever.retrieveHostPointer(op.scalar().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context);
        // Extra args use x's type for SCALAR_BOOL (z is BOOL there) and z's type otherwise.
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.SCALAR_BOOL ? x.dataType() : z.dataType()), context) : null;
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context);
        // Slots 0-10: host x shape, stream, device id, the four context buffers, host scalar/z shapes, no TAD info.
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, null, null});
        OpaqueDataBuffer xb = x == null ? null : ((BaseCudaDataBuffer)x.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer yb = op.scalar() == null ? null : ((BaseCudaDataBuffer)op.scalar().data()).getOpaqueDataBuffer();
        OpaqueDataBuffer zb = z == null ? null : ((BaseCudaDataBuffer)z.data()).getOpaqueDataBuffer();
        switch (op.getOpType()) {
            case SCALAR_BOOL: {
                nativeOps.execScalarBool(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, yb, (LongPointer)hostYShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), extraArgs);
                break;
            }
            case SCALAR: {
                nativeOps.execScalar(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, yb, (LongPointer)hostYShapeInfo, (LongPointer)AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), extraArgs);
                break;
            }
            default: {
                throw new UnsupportedOperationException("Unknown op type: " + op.getOpType());
            }
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        this.profilingConfigurableHookOut((Op)op, oc, st);
        return null;
    }

    /**
     * Executes a transform op on the current CUDA device context.
     *
     * <p>When the op has a Y array it is dispatched as a pairwise transform
     * ({@code execPairwiseTransformBool} for TRANSFORM_BOOL / PAIRWISE_BOOL, {@code execPairwiseTransform}
     * otherwise). Without Y it is dispatched by op type to the matching unary transform entry point.
     * If the op has no Z array, an uninitialized array of {@code op.resultType()} with X's shape and
     * ordering is created and installed as Z via {@code setZ}.
     *
     * @param op the transform op to execute
     * @param oc op context handed to getX/getY/getZ/setZ
     * @return always {@code null}; the result is written into the op's Z array
     * @throws IllegalStateException if the op has no X array
     * @throws UnsupportedOperationException if a unary op has an op type with no transform entry point
     * @throws RuntimeException if the native layer reports an error
     */
    protected CudaContext invoke(TransformOp op, OpContext oc) {
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        INDArray x = this.getX((Op)op, oc);
        INDArray y = this.getY((Op)op, oc);
        INDArray z = this.getZ((Op)op, oc);
        // X is mandatory: its shape drives Z allocation and every native call below.
        if (x == null) {
            throw new IllegalStateException("Transform op [" + op.opName() + "] has no input (X) array");
        }
        this.checkForCompression((Op)op);
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext context = allocator.getDeviceContext();
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        Pointer xShapeInfo = allocator.getPointer(x.shapeInfoDataBuffer(), context);
        Pointer hostXShapeInfo = AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        Pointer hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        INDArray ret = null;
        if (z == null) {
            ret = Nd4j.createUninitialized((DataType)op.resultType(), (long[])x.shape(), (char)x.ordering());
            this.setZ(ret, (Op)op, oc);
            z = ret;
        }
        // Boolean transforms take their extra args in X's type; all others in Z's type.
        Pointer extraArgs = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.TRANSFORM_BOOL || op.getOpType() == Op.Type.PAIRWISE_BOOL ? x.dataType() : z.dataType()), context) : null;
        Pointer hostZShapeInfo = AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
        op.validateDataTypes(oc, this.experimentalMode.get());
        Pointer zShapeInfo = allocator.getPointer(z.shapeInfoDataBuffer(), context);
        // Slot layout of the extra-pointer array is positional; the TAD, dimension and result slots are unused by transforms.
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{
                hostXShapeInfo, context.getOldStream(), allocator.getDeviceIdPointer(),
                context.getBufferAllocation(), context.getBufferReduction(),
                context.getBufferScalar(), context.getBufferSpecial(),
                hostYShapeInfo, hostZShapeInfo,
                null, null, null,        // host TAD shape info, device TAD shape info, device TAD offsets
                null, null, null,        // host max-TAD shape info, device max-TAD shape info, device max-TAD offsets
                null, null, null,        // dimension device pointer, dimension host pointer, result pointer
                new CudaPointer(0L),     // dimension length: transforms carry no dimensions
                null});                  // result host shape
        OpaqueDataBuffer xb = ((BaseCudaDataBuffer)x.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer yb = y == null ? null : ((BaseCudaDataBuffer)y.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer zb = ((BaseCudaDataBuffer)z.data()).getOpaqueDataBuffer();
        if (y != null) {
            Pointer yShapeInfo = allocator.getPointer(y.shapeInfoDataBuffer(), context);
            switch (op.getOpType()) {
                case TRANSFORM_BOOL:
                case PAIRWISE_BOOL: {
                    nativeOps.execPairwiseTransformBool(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, yb, (LongPointer)hostYShapeInfo, (LongPointer)yShapeInfo, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, extraArgs);
                    break;
                }
                default: {
                    nativeOps.execPairwiseTransform(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, yb, (LongPointer)hostYShapeInfo, (LongPointer)yShapeInfo, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, extraArgs);
                    break;
                }
            }
        } else {
            switch (op.getOpType()) {
                case TRANSFORM_ANY: {
                    nativeOps.execTransformAny(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, extraArgs);
                    break;
                }
                case TRANSFORM_FLOAT: {
                    nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, extraArgs);
                    break;
                }
                case TRANSFORM_BOOL: {
                    nativeOps.execTransformBool(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, extraArgs);
                    break;
                }
                case TRANSFORM_SAME: {
                    nativeOps.execTransformSame(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, extraArgs);
                    break;
                }
                case TRANSFORM_STRICT: {
                    nativeOps.execTransformStrict(xShapeInfoHostPointer, op.opNum(), xb, (LongPointer)hostXShapeInfo, (LongPointer)xShapeInfo, zb, (LongPointer)hostZShapeInfo, (LongPointer)zShapeInfo, extraArgs);
                    break;
                }
                default: {
                    throw new UnsupportedOperationException("Unknown op type: " + op.getOpType());
                }
            }
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        // NOTE(review): these touches appear to keep extraArgs / ret reachable until after the native
        // call has been issued (GC keep-alive); confirm before removing.
        if (extraArgs != null) {
            extraArgs.address();
        }
        if (ret != null) {
            ret.elementWiseStride();
        }
        this.profilingConfigurableHookOut((Op)op, oc, st);
        return null;
    }

    /**
     * Creates a parameter surface for {@code batch} via {@code createInt} and attaches it to the batch.
     *
     * <p>The requested length is the sample aggregate's required batch memory size multiplied by 4.
     *
     * @param batch the batch that receives the new parameter surface
     * @return the buffer that was just attached to {@code batch}
     */
    protected <T extends Aggregate> DataBuffer getBuffer(Batch<T> batch) {
        long requestedLength = batch.getSample().getRequiredBatchMemorySize() * 4L;
        DataBuffer surface = Nd4j.getDataBufferFactory().createInt(requestedLength, false);
        batch.setParamsSurface(surface);
        return surface;
    }

    /**
     * Batched aggregate execution is not implemented by this executioner.
     *
     * @param batch ignored
     * @throws UnsupportedOperationException always
     */
    public <T extends Aggregate> void exec(Batch<T> batch) {
        final String reason = "Pew-pew";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Splits a list of aggregates into batches, executes each batch, then syncs the device's old stream.
     *
     * <p>An empty list returns immediately without touching the device. Note that {@link #exec(Batch)}
     * in this class always throws {@link UnsupportedOperationException}, so a non-empty list fails on
     * its first batch.
     *
     * @param batch aggregates to execute
     */
    public void exec(List<Aggregate> batch) {
        if (batch.isEmpty()) {
            return;
        }
        // Maximum number of aggregates per batch handed to Batch.getBatches.
        final int partitionSize = 8192;
        List<Batch<Aggregate>> batches = Batch.getBatches(batch, partitionSize);
        for (Batch<Aggregate> single : batches) {
            this.exec(single);
        }
        CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
        context.syncOldStream();
    }

    /**
     * Single-aggregate execution is not implemented by this executioner.
     *
     * @param op ignored
     * @throws UnsupportedOperationException always
     */
    public void exec(Aggregate op) {
        final String reason = "Pew-pew";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Executes a random op using the global generator returned by {@code Nd4j.getRandom()}.
     *
     * @param op the random op to execute
     * @return the op's output array
     */
    public INDArray exec(RandomOp op) {
        Random globalRng = Nd4j.getRandom();
        return this.exec(op, globalRng);
    }

    /**
     * Executes a random op with the given generator and no explicit op context.
     *
     * @param op  the random op to execute
     * @param rng generator whose native state pointer drives the op
     * @return the op's output array
     */
    public INDArray exec(RandomOp op, Random rng) {
        final OpContext noContext = null;
        return this.exec(op, noContext, rng);
    }

    /**
     * Executes a random op on the current CUDA device context.
     *
     * <p>Dispatches to {@code execRandom3} when X and Y are present, {@code execRandom2} when only X is
     * present, and {@code execRandom} otherwise. For triple-argument RNG ops that were given only Z,
     * Z is reused as both X and Y.
     *
     * @param op  the random op to execute
     * @param oc  op context handed to getX/getY/getZ (may be {@code null}, as passed by {@code exec(RandomOp, Random)})
     * @param rng generator whose native state pointer drives the op
     * @return the Z array the op wrote into
     * @throws IllegalStateException if {@code rng} has no native state pointer, or the op has no Z array
     * @throws RuntimeException if the native layer reports an error
     */
    public INDArray exec(RandomOp op, OpContext oc, Random rng) {
        INDArray x = this.getX((Op)op, oc);
        INDArray y = this.getY((Op)op, oc);
        INDArray z = this.getZ((Op)op, oc);
        if (op instanceof BaseRandomOp && ((BaseRandomOp)op).isTripleArgRngOp() && z != null && x == null && y == null) {
            x = z;
            y = z;
        }
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        this.checkForCompression((Op)op);
        if (rng.getStatePointer() == null) {
            throw new IllegalStateException("You should use one of NativeRandom classes for NativeOperations execution");
        }
        // Every native entry point below writes into Z, so fail clearly instead of with an NPE.
        if (z == null) {
            throw new IllegalStateException("Random op [" + op.opName() + "] has no output (Z) array");
        }
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getDeviceContext();
        Pointer hostZShapeInfo = AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());
        PointerPointer extraZZ = this.extraz.get().put(new Pointer[]{hostZShapeInfo, context.getOldStream(), allocator.getDeviceIdPointer()});
        Pointer hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        Pointer hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        OpaqueDataBuffer xb = x == null ? null : ((BaseCudaDataBuffer)x.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer yb = y == null ? null : ((BaseCudaDataBuffer)y.data()).getOpaqueDataBuffer();
        OpaqueDataBuffer zb = ((BaseCudaDataBuffer)z.data()).getOpaqueDataBuffer();
        if (x != null && y != null) {
            nativeOps.execRandom3(extraZZ, op.opNum(), rng.getStatePointer(),
                    xb, (LongPointer)hostXShapeInfo, (LongPointer)allocator.getPointer(x.shapeInfoDataBuffer(), context),
                    yb, (LongPointer)hostYShapeInfo, (LongPointer)allocator.getPointer(y.shapeInfoDataBuffer(), context),
                    zb, (LongPointer)hostZShapeInfo, (LongPointer)allocator.getPointer(z.shapeInfoDataBuffer(), context),
                    allocator.getPointer(op.extraArgsDataBuff(z.dataType()), context));
        } else if (x != null) {
            nativeOps.execRandom2(extraZZ, op.opNum(), rng.getStatePointer(),
                    xb, (LongPointer)hostXShapeInfo, (LongPointer)allocator.getPointer(x.shapeInfoDataBuffer(), context),
                    zb, (LongPointer)hostZShapeInfo, (LongPointer)allocator.getPointer(z.shapeInfoDataBuffer(), context),
                    allocator.getPointer(op.extraArgsDataBuff(z.dataType()), context));
        } else {
            nativeOps.execRandom(extraZZ, op.opNum(), rng.getStatePointer(),
                    zb, (LongPointer)hostZShapeInfo, (LongPointer)allocator.getPointer(z.shapeInfoDataBuffer(), context),
                    allocator.getPointer(op.extraArgsDataBuff(z.dataType()), context));
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        this.profilingConfigurableHookOut((Op)op, oc, st);
        return z;
    }

    /**
     * Returns backend environment properties, building them on first call and refreshing the
     * volatile entries (per-device free/total memory, free memory, bandwidth) on later calls.
     *
     * <p>On first call, the superclass properties are extended with "backend", the available device
     * count, a per-device list of name/memory/compute-capability maps, the BLAS vendor, free memory
     * and the current memory bandwidth. The resulting object is cached and returned on every call.
     *
     * @return the cached environment properties
     */
    public synchronized Properties getEnvironmentInformation() {
        if (this.properties != null) {
            @SuppressWarnings("unchecked")
            List<Map<String, Object>> knownDevices = (List<Map<String, Object>>) this.properties.get("cuda.devicesInformation");
            for (int device = 0; device < nativeOps.getAvailableDevices(); ++device) {
                Map<String, Object> info = knownDevices.get(device);
                info.put("cuda.freeMemory", nativeOps.getDeviceFreeMemory(device));
                info.put("cuda.totalMemory", nativeOps.getDeviceTotalMemory(device));
            }
            this.properties.put("cuda.devicesInformation", knownDevices);
            this.properties.put("memory.free", Pointer.maxBytes() - Pointer.totalBytes());
            this.properties.put("memoryBandwidth", PerformanceTracker.getInstance().getCurrentBandwidth());
            return this.properties;
        }
        Properties fresh = super.getEnvironmentInformation();
        List<Map<String, Object>> devices = new ArrayList<Map<String, Object>>();
        for (int device = 0; device < nativeOps.getAvailableDevices(); ++device) {
            Map<String, Object> info = new HashMap<String, Object>();
            info.put("cuda.deviceName", nativeOps.getDeviceName(device));
            info.put("cuda.freeMemory", nativeOps.getDeviceFreeMemory(device));
            info.put("cuda.totalMemory", nativeOps.getDeviceTotalMemory(device));
            info.put("cuda.deviceMajor", Long.valueOf(nativeOps.getDeviceMajor(device)));
            info.put("cuda.deviceMinor", Long.valueOf(nativeOps.getDeviceMinor(device)));
            devices.add(info);
        }
        fresh.put("backend", "CUDA");
        fresh.put("cuda.availableDevices", nativeOps.getAvailableDevices());
        fresh.put("cuda.devicesInformation", devices);
        fresh.put("blas.vendor", Nd4j.factory().blas().getBlasVendor().toString());
        fresh.put("memory.free", Pointer.maxBytes() - Pointer.totalBytes());
        fresh.put("memoryBandwidth", PerformanceTracker.getInstance().getCurrentBandwidth());
        this.properties = fresh;
        return this.properties;
    }

    /**
     * Returns the TAD manager this executioner uses for tensor-along-dimension shape info.
     *
     * @return the executioner's TAD manager
     */
    public TADManager getTADManager() {
        final TADManager manager = tadManager;
        return manager;
    }

    /**
     * Prints environment information; delegates entirely to the superclass implementation.
     */
    public void printEnvironmentInformation() {
        super.printEnvironmentInformation();
    }

    /**
     * Synchronizes both the old and the special stream of the current device context.
     */
    public void commit() {
        CudaContext deviceContext = AtomicAllocator.getInstance().getDeviceContext();
        deviceContext.syncOldStream();
        deviceContext.syncSpecialStream();
    }

    /**
     * Returns the custom-op descriptors reported by the native library, parsing them once and caching.
     *
     * <p>The native list is ';'-separated; each non-empty entry is ':'-separated as
     * name:hash:numInputs:numOutputs:allowsInplace(1 = true):numTArgs:numIArgs.
     * An empty or missing list yields (and caches) an empty map.
     *
     * @return an unmodifiable map from op name to descriptor
     */
    public synchronized Map<String, CustomOpDescriptor> getCustomOperations() {
        if (this.customOps != null) {
            return this.customOps;
        }
        String list = nativeOps.getAllCustomOps();
        if (list == null || list.isEmpty()) {
            log.warn("No customs ops available!");
            this.customOps = Collections.emptyMap();
            return this.customOps;
        }
        Map<String, CustomOpDescriptor> descriptors = new HashMap<String, CustomOpDescriptor>();
        for (String entry : list.split(";")) {
            if (entry.isEmpty()) {
                continue;
            }
            String[] fields = entry.split(":");
            CustomOpDescriptor descriptor = CustomOpDescriptor.builder()
                    .hash(Long.parseLong(fields[1]))
                    .numInputs(Integer.parseInt(fields[2]))
                    .numOutputs(Integer.parseInt(fields[3]))
                    .allowsInplace(Integer.parseInt(fields[4]) == 1)
                    .numTArgs(Integer.parseInt(fields[5]))
                    .numIArgs(Integer.parseInt(fields[6]))
                    .build();
            descriptors.put(fields[0], descriptor);
        }
        this.customOps = Collections.unmodifiableMap(descriptors);
        return this.customOps;
    }

    /**
     * Copies a native shape-info buffer into a {@link LongShapeDescriptor}.
     *
     * <p>Element 0 is read as the rank; {@code rank * 2 + 4} elements are then copied and decoded
     * with {@code Shape} / {@code ArrayOptionsHelper}. The descriptor is flagged empty when the array
     * type is {@code ArrayType.EMPTY}.
     *
     * @param ptr pointer to the native shape-info buffer
     * @return the decoded shape descriptor
     */
    protected LongShapeDescriptor getShapeFromPointer(LongPointer ptr) {
        final int rank = (int) ptr.get(0L);
        final long[] shapeInfo = new long[rank * 2 + 4];
        int idx = 0;
        while (idx < shapeInfo.length) {
            shapeInfo[idx] = ptr.get((long) idx);
            idx++;
        }
        final boolean empty = ArrayOptionsHelper.arrayType(shapeInfo) == ArrayType.EMPTY;
        return LongShapeDescriptor.fromShape(
                Shape.shape(shapeInfo),
                Shape.stride(shapeInfo),
                (long) Shape.elementWiseStride(shapeInfo),
                Shape.order(shapeInfo),
                ArrayOptionsHelper.dataType(shapeInfo),
                empty);
    }

    /**
     * Calculates output shapes for {@code op} using its own arguments (no op context).
     *
     * @param op the custom op; must not be {@code null}
     * @return the output shape descriptors
     * @throws NullPointerException if {@code op} is {@code null}
     */
    public List<LongShapeDescriptor> calculateOutputShape(@NonNull CustomOp op) {
        if (op != null) {
            return this.calculateOutputShape(op, null);
        }
        throw new NullPointerException("op is marked non-null but is null");
    }

    /**
     * Asks the native library to calculate output shapes for a custom op.
     *
     * <p>Inputs and the integer, floating-point, boolean and data-type arguments come from
     * {@code opContext} when it is non-null, otherwise from {@code op} itself. Inputs not resident on
     * the device are moved there first. If the op expects inputs but none were supplied, an empty
     * list is returned.
     *
     * @param op        the custom op; must not be {@code null}
     * @param opContext optional context overriding the op's own arguments
     * @return one shape descriptor per output reported by the native layer
     * @throws NullPointerException if {@code op} is {@code null}
     * @throws RuntimeException     if the native layer reports an error or returns no shape list
     */
    public List<LongShapeDescriptor> calculateOutputShape(@NonNull CustomOp op, OpContext opContext) {
        if (op == null) {
            throw new NullPointerException("op is marked non-null but is null");
        }
        Nd4j.getExecutioner().commit();
        long hash = op.opHash();
        List<LongShapeDescriptor> result = new ArrayList<LongShapeDescriptor>();
        int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments();
        if (nIn == 0 && op.getDescriptor().getNumInputs() >= 1) {
            if (log.isTraceEnabled()) {
                log.trace("Could not calculate output shape for op {}: number of input args was 0", op.getClass().getName());
            }
            return Collections.emptyList();
        }
        // First nIn slots hold host buffers, the next nIn hold the matching device buffers.
        PointerPointer inputBuffers = new PointerPointer((long)(nIn * 2));
        PointerPointer inputShapes = new PointerPointer((long)nIn);
        List<INDArray> inputArgs = opContext != null ? opContext.getInputArrays() : op.inputArguments();
        int cnt = 0;
        for (INDArray in : inputArgs) {
            AffinityManager.Location loc = Nd4j.getAffinityManager().getActiveLocation(in);
            if (loc != AffinityManager.Location.DEVICE && loc != AffinityManager.Location.EVERYWHERE) {
                Nd4j.getAffinityManager().ensureLocation(in, AffinityManager.Location.DEVICE);
                AtomicAllocator.getInstance().tickDeviceWrite(in);
            }
            if (!in.isEmpty()) {
                inputBuffers.put((long)cnt, in.data().addressPointer());
                inputBuffers.put((long)(cnt + nIn), AtomicAllocator.getInstance().getPointer(in.data()));
            }
            inputShapes.put((long)cnt++, in.shapeInfoDataBuffer().addressPointer());
        }

        int nIArgs = opContext != null ? opContext.numIArguments() : op.numIArguments();
        LongPointer iArgs = nIArgs > 0 ? new LongPointer((long)nIArgs) : null;
        cnt = 0;
        if (opContext != null) {
            for (long i : opContext.getIArguments()) {
                iArgs.put((long)cnt++, i);
            }
        } else {
            for (long i : op.iArgs()) {
                iArgs.put((long)cnt++, i);
            }
        }

        int nTArgs = opContext != null ? opContext.numTArguments() : op.numTArguments();
        DoublePointer tArgs = nTArgs > 0 ? new DoublePointer((long)nTArgs) : null;
        int nBArgs = opContext != null ? opContext.numBArguments() : op.numBArguments();
        BooleanPointer bArgs = nBArgs > 0 ? new BooleanPointer((long)nBArgs) : null;
        int nDArgs = opContext != null ? opContext.numDArguments() : op.numDArguments();
        IntPointer dArgs = nDArgs > 0 ? new IntPointer((long)nDArgs) : null;

        cnt = 0;
        if (opContext != null) {
            for (boolean b : opContext.getBArguments()) {
                bArgs.put((long)cnt++, b);
            }
        } else {
            for (boolean b : op.bArgs()) {
                bArgs.put((long)cnt++, b);
            }
        }
        cnt = 0;
        if (opContext != null) {
            for (double t : opContext.getTArguments()) {
                tArgs.put((long)cnt++, t);
            }
        } else {
            for (double t : op.tArgs()) {
                tArgs.put((long)cnt++, t);
            }
        }
        cnt = 0;
        if (opContext != null) {
            for (DataType d : opContext.getDArguments()) {
                dArgs.put((long)cnt++, d.toInt());
            }
        } else {
            for (DataType d : op.dArgs()) {
                dArgs.put((long)cnt++, d.toInt());
            }
        }

        OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, nIn, tArgs, nTArgs, iArgs, nIArgs, bArgs, nBArgs, dArgs, nDArgs);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        if (ptrptr == null) {
            throw new RuntimeException("Output shape calculation for op [" + op.opName() + "] returned no shape list");
        }
        // Free the native shape list even if decoding one of the shapes fails.
        try {
            long numShapes = nativeOps.getShapeListSize(ptrptr);
            for (long e = 0; e < numShapes; ++e) {
                result.add(this.getShapeFromPointer(new PagedPointer((Pointer)nativeOps.getShape(ptrptr, e)).asLongPointer()));
            }
        } finally {
            nativeOps.deleteShapeList((Pointer)ptrptr);
        }
        return result;
    }

    /**
     * Executes a custom op, allocating its outputs first when none were provided and it is not in-place.
     *
     * <p>A fresh {@link CudaOpContext} is populated from the op (inputs, outputs, b/i/t/d args, RNG
     * states), executed via {@code exec(CustomOp, OpContext)}, and closed exactly once via
     * try-with-resources. The RNG states left in the context are copied back into
     * {@code Nd4j.getRandom()}.
     *
     * @param op the custom op to execute
     * @return the op's output arrays
     * @throws ND4JIllegalStateException if outputs had to be inferred and shape calculation failed
     * @throws ND4JOpProfilerException   rethrown unchanged from execution
     * @throws RuntimeException          wrapping any other exception raised during execution or close
     */
    public INDArray[] exec(CustomOp op) {
        Nd4j.getExecutioner().commit();
        boolean shapeOverride = false;
        if (op.numOutputArguments() == 0 && !op.isInplaceCall()) {
            try {
                List<LongShapeDescriptor> list = this.calculateOutputShape(op);
                if (list.isEmpty()) {
                    throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
                }
                for (LongShapeDescriptor shape : list) {
                    op.addOutputArgument(new INDArray[]{Nd4j.create((LongShapeDescriptor)shape, (boolean)false)});
                }
                shapeOverride = true;
            }
            catch (Exception e) {
                throw new ND4JIllegalStateException("Op name " + op.opName() + " - no output arrays were provided and calculateOutputShape failed to execute", (Throwable)e);
            }
        }
        String name = op.opName();
        // try-with-resources closes the context exactly once, including when close() itself fails.
        try (CudaOpContext context = (CudaOpContext)this.buildContext()) {
            if (shapeOverride) {
                context.shapeFunctionOverride(true);
            }
            context.markInplace(op.isInplaceCall());
            context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState());
            context.setInputArrays(op.inputArguments());
            context.setOutputArrays(op.outputArguments());
            context.setBArguments(op.bArgs());
            context.setIArguments(op.iArgs());
            context.setTArguments(op.tArgs());
            context.setDArguments(op.dArgs());
            INDArray[] result = this.exec(op, (OpContext)context);
            Pair<Long, Long> states = context.getRngStates();
            Nd4j.getRandom().setStates(states.getFirst().longValue(), states.getSecond().longValue());
            return result;
        }
        catch (ND4JOpProfilerException e) {
            throw e;
        }
        catch (Exception e) {
            throw new RuntimeException("Op [" + name + "] execution failed", e);
        }
    }

    /**
     * Records the debug flag locally and forwards it to the native library.
     *
     * @param reallyEnable {@code true} to enable debug mode
     */
    public void enableDebugMode(boolean reallyEnable) {
        final boolean flag = reallyEnable;
        this.debug.set(flag);
        nativeOps.enableDebugMode(flag);
    }

    /**
     * Records the verbose flag locally and forwards it to the native library.
     *
     * @param reallyEnable {@code true} to enable verbose mode
     */
    public void enableVerboseMode(boolean reallyEnable) {
        final boolean flag = reallyEnable;
        this.verbose.set(flag);
        nativeOps.enableVerboseMode(flag);
    }

    /**
     * Registers a serialized graph with the native library under {@code id}.
     *
     * @param id    graph identifier
     * @param graph pointer to the graph data
     * @throws RuntimeException if the native layer reports an error
     */
    public void registerGraph(long id, Pointer graph) {
        nativeOps.registerGraph(null, id, graph);
        int errorCode = nativeOps.lastErrorCode();
        if (errorCode != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
    }

    /**
     * Executes a graph previously registered under {@code id}.
     *
     * <p>Each entry of {@code map} is handed to the native side as a host buffer / host shape pair,
     * together with the integer index looked up for its name in {@code reverseMap}. Converting native
     * output variables back into named arrays is not implemented for this backend, so a non-empty
     * result set is rejected with {@link UnsupportedOperationException}. The native variable set is
     * released on every path, including the throwing ones.
     *
     * @param id         identifier the graph was registered with
     * @param map        named input arrays
     * @param reverseMap input name to native variable index
     * @return an empty map when the native result set is empty
     * @throws NullPointerException       if {@code map} or {@code reverseMap} is {@code null}
     * @throws ND4JIllegalStateException   if an input name has no index in {@code reverseMap},
     *                                     or the graph finishes with a non-OK status
     * @throws UnsupportedOperationException if the graph produced output variables
     * @throws RuntimeException            if the native layer reports an error
     */
    public Map<String, INDArray> executeGraph(long id, @NonNull Map<String, INDArray> map, @NonNull Map<String, Integer> reverseMap) {
        if (map == null) {
            throw new NullPointerException("map is marked non-null but is null");
        }
        if (reverseMap == null) {
            throw new NullPointerException("reverseMap is marked non-null but is null");
        }
        Nd4j.getExecutioner().commit();
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        PointerPointer ptrBuffers = new PointerPointer((long)(map.size() * 2));
        PointerPointer ptrShapes = new PointerPointer((long)(map.size() * 2));
        IntPointer ptrIndices = new IntPointer((long)map.size());
        int cnt = 0;
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            String key = entry.getKey();
            Integer index = reverseMap.get(key);
            if (index == null) {
                throw new ND4JIllegalStateException("No graph variable index found in reverseMap for input [" + key + "]");
            }
            INDArray array = entry.getValue();
            ptrBuffers.put((long)cnt, allocator.getHostPointer(array));
            ptrShapes.put((long)cnt, allocator.getHostPointer(array.shapeInfoDataBuffer()));
            ptrIndices.put((long)cnt, index.intValue());
            ++cnt;
        }
        OpaqueVariablesSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
        try {
            if (nativeOps.lastErrorCode() != 0) {
                throw new RuntimeException(nativeOps.lastErrorMessage());
            }
            OpStatus status = OpStatus.byNumber((int)nativeOps.getVariablesSetStatus(result));
            if (status != OpStatus.ND4J_STATUS_OK) {
                throw new ND4JIllegalStateException("Op execution failed: " + status);
            }
            if (nativeOps.getVariablesSetSize(result) > 0L) {
                // Mapping native output variables back to named INDArrays is not implemented here.
                throw new UnsupportedOperationException("Pew-pew");
            }
            if (nativeOps.lastErrorCode() != 0) {
                throw new RuntimeException(nativeOps.lastErrorMessage());
            }
            return new LinkedHashMap<String, INDArray>();
        } finally {
            if (result != null) {
                nativeOps.deleteVariablesSet(result);
            }
        }
    }

    /**
     * Unregisters the graph stored under {@code id} from the native library.
     *
     * @param id graph identifier
     * @throws RuntimeException if the native layer reports an error
     */
    public void forgetGraph(long id) {
        nativeOps.unregisterGraph(null, id);
        int errorCode = nativeOps.lastErrorCode();
        if (errorCode != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
    }

    /**
     * Forwards the element threshold to the native library.
     *
     * @param threshold value passed to {@code nativeOps.setElementThreshold}
     */
    public void setElementsThreshold(int threshold) {
        final int value = threshold;
        nativeOps.setElementThreshold(value);
    }

    /**
     * Forwards the TAD threshold to the native library.
     *
     * @param threshold value passed to {@code nativeOps.setTADThreshold}
     */
    public void setTadThreshold(int threshold) {
        final int value = threshold;
        nativeOps.setTADThreshold(value);
    }

    /**
     * Identifies this executioner as the CUDA backend.
     *
     * @return {@link OpExecutioner.ExecutionerType#CUDA}
     */
    public OpExecutioner.ExecutionerType type() {
        final OpExecutioner.ExecutionerType backend = OpExecutioner.ExecutionerType.CUDA;
        return backend;
    }

    /**
     * Reads the UTF-8 string whose native address is stored at {@code index} of a CUDA UTF-8 buffer.
     *
     * @param buffer a {@code CudaUtf8Buffer}
     * @param index  position in the buffer's long indexer holding the string's address
     * @return the decoded string
     * @throws IllegalArgumentException if {@code buffer} is not a {@code CudaUtf8Buffer}
     */
    public String getString(DataBuffer buffer, long index) {
        Preconditions.checkArgument((boolean)(buffer instanceof CudaUtf8Buffer), (String)"Expected Utf8Buffer");
        LongIndexer indexer = (LongIndexer) buffer.indexer();
        Nd4jCuda.utf8string utf8 = new Nd4jCuda.utf8string((Pointer) new PagedPointer(indexer.get(index)));
        long length = (long) utf8._length();
        return utf8._buffer().capacity(length).getString();
    }

    /**
     * Reports whether experimental mode is currently enabled for this executioner.
     *
     * @return the current experimental-mode flag
     */
    public boolean isExperimentalMode() {
        final boolean enabled = this.experimentalMode.get();
        return enabled;
    }

    /**
     * Applies a scatter update of {@code updates} into {@code array} via {@code nativeOps.scatterUpdate},
     * using TAD shape info computed over {@code axis} for both arrays.
     *
     * @param op      update operation; its {@code ordinal()} is passed to the native call as the op code
     * @param array   target array
     * @param indices indices selecting target sub-arrays; its length must equal the number of update TADs
     * @param updates source of the update values
     * @param axis    dimensions used to build the TADs of {@code array} and {@code updates}
     * @throws NullPointerException  if any argument is {@code null}
     * @throws IllegalStateException if the number of update TADs differs from {@code indices.length()}
     * @throws RuntimeException      if the native layer reports an error
     */
    public void scatterUpdate(ScatterUpdate.UpdateOp op, @NonNull INDArray array, @NonNull INDArray indices, @NonNull INDArray updates, @NonNull int[] axis) {
        if (array == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        if (indices == null) {
            throw new NullPointerException("indices is marked non-null but is null");
        }
        if (updates == null) {
            throw new NullPointerException("updates is marked non-null but is null");
        }
        if (axis == null) {
            throw new NullPointerException("axis is marked non-null but is null");
        }
        CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
        // Pair of (TAD shape info, second buffer); the second buffer's length is compared against the
        // index count below. NOTE(review): presumably the TAD offsets, one per TAD — confirm.
        Pair tadX = tadManager.getTADOnlyShapeInfo(array, axis);
        Pair tadY = tadManager.getTADOnlyShapeInfo(updates, axis);
        if (((DataBuffer)tadY.getSecond()).length() != indices.length()) {
            throw new IllegalStateException("Number of updates doesn't match number of indices. Bad dimensions used?");
        }
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Extra pointers: slot 0 is left null, slot 1 is the device context's old stream.
        PointerPointer stuff = this.extraz.get().put(new Pointer[]{null, context.getOldStream()});
        // Positional native call: for array and updates, host TAD shape, (null), device buffer, device TAD shape,
        // device second-TAD buffer; then indices host buffer/shape and device buffer/shape.
        // NOTE(review): op.ordinal() as op code relies on the enum order matching the native side — confirm.
        nativeOps.scatterUpdate(stuff, op.ordinal(), (int)indices.length(), null, (LongPointer)AtomicAllocator.getInstance().getHostPointer((DataBuffer)tadX.getFirst()), null, AtomicAllocator.getInstance().getPointer(array, context), (LongPointer)AtomicAllocator.getInstance().getPointer((DataBuffer)tadX.getFirst()), (LongPointer)AtomicAllocator.getInstance().getPointer((DataBuffer)tadX.getSecond()), null, (LongPointer)AtomicAllocator.getInstance().getHostPointer((DataBuffer)tadY.getFirst()), null, AtomicAllocator.getInstance().getPointer(updates, context), (LongPointer)AtomicAllocator.getInstance().getPointer((DataBuffer)tadY.getFirst()), (LongPointer)AtomicAllocator.getInstance().getPointer((DataBuffer)tadY.getSecond()), AtomicAllocator.getInstance().getHostPointer(indices), (LongPointer)AtomicAllocator.getInstance().getHostPointer(indices.shapeInfoDataBuffer()), AtomicAllocator.getInstance().getPointer(indices, context), (LongPointer)AtomicAllocator.getInstance().getPointer(indices.shapeInfoDataBuffer(), context));
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
    }

    /**
     * Creates a new CUDA-specific op context.
     *
     * @return a fresh {@link CudaOpContext}
     */
    public OpContext buildContext() {
        final OpContext created = new CudaOpContext();
        return created;
    }

    /**
     * Executes a custom op through a prepared op context via {@code nativeOps.execCustomOp2}.
     *
     * <p>After execution, the pointers/indexers of the op's non-empty input and output buffers are
     * actualized, and every output argument is marked as device-written.
     *
     * @param op      the custom op to execute
     * @param context context holding the op's inputs, outputs and arguments
     * @return the context's output arrays, or an empty array when it has none
     * @throws RuntimeException if the native layer reports an error or a non-zero status
     */
    public INDArray[] exec(CustomOp op, OpContext context) {
        Nd4j.getExecutioner().commit();
        final long start = this.profilingConfigurableHookIn(op, context);
        // NOTE(review): ctx is never read below; the call is kept in case obtaining the device context
        // has side effects for the current thread — confirm before removing.
        CudaContext ctx = AtomicAllocator.getInstance().getDeviceContext();
        final int status = nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer());
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        if (status != 0) {
            throw new RuntimeException("Op [" + op.opName() + "] execution failed");
        }
        for (INDArray input : op.inputArguments()) {
            if (!input.isEmpty()) {
                ((BaseCudaDataBuffer) input.data()).actualizePointerAndIndexer();
            }
        }
        for (INDArray output : op.outputArguments()) {
            if (!output.isEmpty()) {
                ((BaseCudaDataBuffer) output.data()).actualizePointerAndIndexer();
            }
            AtomicAllocator.getInstance().tickDeviceWrite(output);
        }
        this.profilingConfigurableHookOut(op, context, start);
        List<INDArray> outputs = context.getOutputArrays();
        return outputs.isEmpty() ? new INDArray[0] : outputs.toArray(new INDArray[outputs.size()]);
    }

    /**
     * Collects summary statistics for an array via the native {@code inspectArray} routine.
     * <p>
     * Host data is synchronized before the native call. The statistics are read back from a
     * {@code Nd4jCuda.DebugInfo} struct that the native side fills in.
     *
     * @param array the array to inspect; must not be null
     * @return min/max/mean/std-dev values plus inf, NaN, negative, positive and zero counts
     * @throws NullPointerException if {@code array} is null
     * @throws RuntimeException     if the native layer reports an error
     */
    public INDArrayStatistics inspectArray(@NonNull INDArray array) {
        if (array == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        final Nd4jCuda.DebugInfo debugInfo = new Nd4jCuda.DebugInfo();
        final AtomicAllocator allocator = AtomicAllocator.getInstance();
        final CudaContext ctx = allocator.getDeviceContext();
        allocator.synchronizeHostData(array);

        // Lazily create the reusable extras holder for this executioner.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        final PointerPointer extras = this.extraz.get().put(new Pointer[]{
                null,
                ctx.getOldStream(),
                allocator.getDeviceIdPointer(),
                ctx.getBufferAllocation(),
                ctx.getBufferReduction(),
                ctx.getBufferScalar(),
                ctx.getBufferSpecial()});

        final DataBuffer shapeInfo = array.shapeInfoDataBuffer();
        nativeOps.inspectArray(extras,
                allocator.getHostPointer(array),
                (LongPointer) allocator.getHostPointer(shapeInfo),
                allocator.getPointer(array, ctx),
                (LongPointer) allocator.getPointer(shapeInfo),
                (Pointer) debugInfo);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }

        return INDArrayStatistics.builder()
                .minValue(debugInfo._minValue())
                .maxValue(debugInfo._maxValue())
                .meanValue(debugInfo._meanValue())
                .stdDevValue(debugInfo._stdDevValue())
                .countInf(debugInfo._infCount())
                .countNaN(debugInfo._nanCount())
                .countNegative(debugInfo._negativeCount())
                .countPositive(debugInfo._positiveCount())
                .countZero(debugInfo._zeroCount())
                .build();
    }

    /**
     * Builds a shape-info buffer through the native {@code shapeBuffer} call.
     * <p>
     * The primary/special pointers of the returned native handle are wrapped in a
     * {@code CudaLongDataBuffer}.
     *
     * @param shape             array shape; its length is passed as the rank
     * @param stride            array strides
     * @param elementWiseStride element-wise stride encoded into the shape info
     * @param order             array order encoded into the shape info
     * @param dtype             data type encoded into the shape info
     * @param empty             whether the shape info describes an empty array
     * @return a buffer of length {@code Shape.shapeInfoLength(shape.length)}
     * @throws RuntimeException if the native layer reports an error before or after the call
     */
    public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        final OpaqueConstantShapeBuffer dbf;
        // Temporary off-heap copies of shape/stride. The original already let these become
        // unreachable right after the call, so the native side must not retain them. Closing
        // them frees the memory now instead of whenever the GC-driven deallocator runs.
        try (LongPointer shapePtr = new LongPointer(shape);
             LongPointer stridePtr = new LongPointer(stride)) {
            dbf = nativeOps.shapeBuffer(shape.length, shapePtr, stridePtr, dtype.toInt(), order, elementWiseStride, empty);
        }
        try {
            if (nativeOps.lastErrorCode() != 0) {
                throw new RuntimeException(nativeOps.lastErrorMessage());
            }
            return new CudaLongDataBuffer(nativeOps.getConstantShapeBufferPrimary(dbf), nativeOps.getConstantShapeBufferSpecial(dbf), (long)Shape.shapeInfoLength((long)shape.length));
        } finally {
            // Always release the native handle, even if the error check or the wrapping throws.
            if (dbf != null) {
                nativeOps.deleteConstantShapeBuffer(dbf);
            }
        }
    }

    /**
     * Builds a shape-info buffer through the native {@code shapeBufferEx} call.
     * <p>
     * The primary/special pointers of the returned native handle are wrapped in a
     * {@code CudaLongDataBuffer}.
     *
     * @param shape             array shape; its length is passed as the rank
     * @param stride            array strides
     * @param elementWiseStride element-wise stride encoded into the shape info
     * @param order             array order encoded into the shape info
     * @param dtype             data type encoded into the shape info
     * @param extras            extra flags passed through to {@code shapeBufferEx}
     * @return a buffer of length {@code Shape.shapeInfoLength(shape.length)}
     * @throws RuntimeException if the native layer reports an error before or after the call
     */
    public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, long extras) {
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        final OpaqueConstantShapeBuffer dbf;
        // Temporary off-heap copies of shape/stride. The original already let these become
        // unreachable right after the call, so the native side must not retain them. Closing
        // them frees the memory now instead of whenever the GC-driven deallocator runs.
        try (LongPointer shapePtr = new LongPointer(shape);
             LongPointer stridePtr = new LongPointer(stride)) {
            dbf = nativeOps.shapeBufferEx(shape.length, shapePtr, stridePtr, dtype.toInt(), order, elementWiseStride, extras);
        }
        try {
            if (nativeOps.lastErrorCode() != 0) {
                throw new RuntimeException(nativeOps.lastErrorMessage());
            }
            return new CudaLongDataBuffer(nativeOps.getConstantShapeBufferPrimary(dbf), nativeOps.getConstantShapeBufferSpecial(dbf), (long)Shape.shapeInfoLength((long)shape.length));
        } finally {
            // Always release the native handle, even if the error check or the wrapping throws.
            if (dbf != null) {
                nativeOps.deleteConstantShapeBuffer(dbf);
            }
        }
    }

    /**
     * Computes TAD shape info and offsets for {@code array} along {@code dimension}.
     * <p>
     * Uses the native {@code tadOnlyShapeInfo} call. The pack's primary/special shape-info and
     * offset pointers are wrapped in {@code CudaLongDataBuffer}s.
     *
     * @param array     the array whose host shape info is passed to the native call
     * @param dimension the dimensions along which TADs are taken
     * @return a {@code TadPack} holding the TAD shape-info buffer and the offsets buffer
     * @throws RuntimeException if the native layer reports an error before or after the call
     */
    public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        final OpaqueTadPack pack;
        // Temporary off-heap copy of the dimensions. The original already let it become
        // unreachable right after the call; closing it frees the memory deterministically.
        try (IntPointer dimensionPtr = new IntPointer(dimension)) {
            pack = nativeOps.tadOnlyShapeInfo((LongPointer)array.shapeInfoDataBuffer().addressPointer(), dimensionPtr, dimension.length);
        }
        try {
            if (nativeOps.lastErrorCode() != 0) {
                throw new RuntimeException(nativeOps.lastErrorMessage());
            }
            CudaLongDataBuffer tadShape = new CudaLongDataBuffer((Pointer)nativeOps.getPrimaryShapeInfo(pack), (Pointer)nativeOps.getSpecialShapeInfo(pack), (long)nativeOps.getShapeInfoLength(pack));
            CudaLongDataBuffer tadOffsets = new CudaLongDataBuffer((Pointer)nativeOps.getPrimaryOffsets(pack), (Pointer)nativeOps.getSpecialOffsets(pack), nativeOps.getNumberOfTads(pack));
            return new TadPack((DataBuffer)tadShape, (DataBuffer)tadOffsets);
        } finally {
            // Always release the native pack handle, even if the error check or the wrapping throws.
            if (pack != null) {
                nativeOps.deleteTadPack(pack);
            }
        }
    }

    /**
     * Creates a constant data buffer from {@code long} values via {@code constantBufferLong}.
     *
     * @param values      the values to store; their count becomes the buffer length
     * @param desiredType the data type of the resulting buffer
     * @return a buffer wrapping the native primary/special pointers, flagged as constant
     * @throws RuntimeException if the native layer reports an error before or after the call
     */
    public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        final OpaqueConstantDataBuffer dbf;
        // Temporary off-heap copy of the values. The original already let it become unreachable
        // right after the call; closing it frees the memory deterministically.
        try (LongPointer valuesPtr = new LongPointer(values)) {
            dbf = nativeOps.constantBufferLong(desiredType.toInt(), valuesPtr, values.length);
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        // The constant handle is intentionally not deleted here (same as before); the returned
        // buffer wraps its pointers directly.
        DataBuffer buffer = Nd4j.createBuffer((Pointer)nativeOps.getConstantDataBufferPrimary(dbf), (Pointer)nativeOps.getConstantDataBufferSpecial(dbf), (long)values.length, (DataType)desiredType);
        buffer.setConstant(true);
        return buffer;
    }

    /**
     * Creates a constant data buffer from {@code double} values via {@code constantBufferDouble}.
     *
     * @param values      the values to store; their count becomes the buffer length
     * @param desiredType the data type of the resulting buffer
     * @return a buffer wrapping the native primary/special pointers, flagged as constant
     * @throws RuntimeException if the native layer reports an error before or after the call
     */
    public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        final OpaqueConstantDataBuffer dbf;
        // Temporary off-heap copy of the values. The original already let it become unreachable
        // right after the call; closing it frees the memory deterministically.
        try (DoublePointer valuesPtr = new DoublePointer(values)) {
            dbf = nativeOps.constantBufferDouble(desiredType.toInt(), valuesPtr, values.length);
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        // The constant handle is intentionally not deleted here (same as before); the returned
        // buffer wraps its pointers directly.
        DataBuffer buffer = Nd4j.createBuffer((Pointer)nativeOps.getConstantDataBufferPrimary(dbf), (Pointer)nativeOps.getConstantDataBufferSpecial(dbf), (long)values.length, (DataType)desiredType);
        buffer.setConstant(true);
        return buffer;
    }

    /**
     * Queries the native use count of a buffer's underlying opaque data buffer.
     *
     * @param buffer the buffer to query; it is cast unconditionally to {@code BaseCudaDataBuffer}
     * @return the value reported by {@code dbUseCount}
     */
    public int useCount(DataBuffer buffer) {
        final BaseCudaDataBuffer cudaBuffer = (BaseCudaDataBuffer) buffer;
        return nativeOps.dbUseCount(cudaBuffer.getOpaqueDataBuffer());
    }
}

