package org.nd4j.linalg.jcublas.context;

import java.util.Objects;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;

/* loaded from: input_file:org/nd4j/linalg/jcublas/context/CudaContext.class */
public class CudaContext {
    private cudaStream_t oldStream;
    private cudaStream_t specialStream;
    private cublasHandle_t cublasHandle;
    private cusolverDnHandle_t solverHandle;
    private Pointer bufferReduction;
    private Pointer bufferAllocation;
    private Pointer bufferScalar;
    private Pointer bufferSpecial;
    private int deviceId;
    private static final transient NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();

    /* loaded from: input_file:org/nd4j/linalg/jcublas/context/CudaContext$CudaContextBuilder.class */
    public static class CudaContextBuilder {
        private cudaStream_t oldStream;
        private cudaStream_t specialStream;
        private cublasHandle_t cublasHandle;
        private cusolverDnHandle_t solverHandle;
        private Pointer bufferReduction;
        private Pointer bufferAllocation;
        private Pointer bufferScalar;
        private Pointer bufferSpecial;
        private int deviceId;

        CudaContextBuilder() {
        }

        public CudaContextBuilder oldStream(cudaStream_t cudastream_t) {
            this.oldStream = cudastream_t;
            return this;
        }

        public CudaContextBuilder specialStream(cudaStream_t cudastream_t) {
            this.specialStream = cudastream_t;
            return this;
        }

        public CudaContextBuilder cublasHandle(cublasHandle_t cublashandle_t) {
            this.cublasHandle = cublashandle_t;
            return this;
        }

        public CudaContextBuilder solverHandle(cusolverDnHandle_t cusolverdnhandle_t) {
            this.solverHandle = cusolverdnhandle_t;
            return this;
        }

        public CudaContextBuilder bufferReduction(Pointer pointer) {
            this.bufferReduction = pointer;
            return this;
        }

        public CudaContextBuilder bufferAllocation(Pointer pointer) {
            this.bufferAllocation = pointer;
            return this;
        }

        public CudaContextBuilder bufferScalar(Pointer pointer) {
            this.bufferScalar = pointer;
            return this;
        }

        public CudaContextBuilder bufferSpecial(Pointer pointer) {
            this.bufferSpecial = pointer;
            return this;
        }

        public CudaContextBuilder deviceId(int i) {
            this.deviceId = i;
            return this;
        }

        public CudaContext build() {
            return new CudaContext(this.oldStream, this.specialStream, this.cublasHandle, this.solverHandle, this.bufferReduction, this.bufferAllocation, this.bufferScalar, this.bufferSpecial, this.deviceId);
        }

        public String toString() {
            return "CudaContext.CudaContextBuilder(oldStream=" + this.oldStream + ", specialStream=" + this.specialStream + ", cublasHandle=" + this.cublasHandle + ", solverHandle=" + this.solverHandle + ", bufferReduction=" + this.bufferReduction + ", bufferAllocation=" + this.bufferAllocation + ", bufferScalar=" + this.bufferScalar + ", bufferSpecial=" + this.bufferSpecial + ", deviceId=" + this.deviceId + ")";
        }
    }

    public String toString() {
        return "CudaContext{bufferReduction=" + this.bufferReduction + ", bufferScalar=" + this.bufferScalar + ", deviceId=" + this.deviceId + '}';
    }

    public void syncOldStream() {
        if (nativeOps.streamSynchronize(this.oldStream) == 0) {
            throw new ND4JIllegalStateException("CUDA stream synchronization failed");
        }
    }

    public void syncSpecialStream() {
        if (nativeOps.streamSynchronize(this.specialStream) == 0) {
            throw new ND4JIllegalStateException("CUDA special stream synchronization failed");
        }
    }

    public Pointer getCublasStream() {
        return new PointerPointer(getOldStream()).get(0L);
    }

    public cublasHandle_t getCublasHandle() {
        return new cublasHandle_t(new PointerPointer(this.cublasHandle).get(0L));
    }

    public cusolverDnHandle_t getSolverHandle() {
        return new cusolverDnHandle_t(new PointerPointer(this.solverHandle).get(0L));
    }

    public static CudaContextBuilder builder() {
        return new CudaContextBuilder();
    }

    public cudaStream_t getOldStream() {
        return this.oldStream;
    }

    public cudaStream_t getSpecialStream() {
        return this.specialStream;
    }

    public Pointer getBufferReduction() {
        return this.bufferReduction;
    }

    public Pointer getBufferAllocation() {
        return this.bufferAllocation;
    }

    public Pointer getBufferScalar() {
        return this.bufferScalar;
    }

    public Pointer getBufferSpecial() {
        return this.bufferSpecial;
    }

    public int getDeviceId() {
        return this.deviceId;
    }

    public void setOldStream(cudaStream_t cudastream_t) {
        this.oldStream = cudastream_t;
    }

    public void setSpecialStream(cudaStream_t cudastream_t) {
        this.specialStream = cudastream_t;
    }

    public void setCublasHandle(cublasHandle_t cublashandle_t) {
        this.cublasHandle = cublashandle_t;
    }

    public void setSolverHandle(cusolverDnHandle_t cusolverdnhandle_t) {
        this.solverHandle = cusolverdnhandle_t;
    }

    public void setBufferReduction(Pointer pointer) {
        this.bufferReduction = pointer;
    }

    public void setBufferAllocation(Pointer pointer) {
        this.bufferAllocation = pointer;
    }

    public void setBufferScalar(Pointer pointer) {
        this.bufferScalar = pointer;
    }

    public void setBufferSpecial(Pointer pointer) {
        this.bufferSpecial = pointer;
    }

    public void setDeviceId(int i) {
        this.deviceId = i;
    }

    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof CudaContext)) {
            return false;
        }
        CudaContext cudaContext = (CudaContext) obj;
        if (!cudaContext.canEqual(this) || getDeviceId() != cudaContext.getDeviceId()) {
            return false;
        }
        cudaStream_t oldStream = getOldStream();
        cudaStream_t oldStream2 = cudaContext.getOldStream();
        if (oldStream == null) {
            if (oldStream2 != null) {
                return false;
            }
        } else if (!oldStream.equals(oldStream2)) {
            return false;
        }
        cudaStream_t specialStream = getSpecialStream();
        cudaStream_t specialStream2 = cudaContext.getSpecialStream();
        if (specialStream == null) {
            if (specialStream2 != null) {
                return false;
            }
        } else if (!specialStream.equals(specialStream2)) {
            return false;
        }
        cublasHandle_t cublasHandle = getCublasHandle();
        cublasHandle_t cublasHandle2 = cudaContext.getCublasHandle();
        if (cublasHandle == null) {
            if (cublasHandle2 != null) {
                return false;
            }
        } else if (!cublasHandle.equals(cublasHandle2)) {
            return false;
        }
        cusolverDnHandle_t solverHandle = getSolverHandle();
        cusolverDnHandle_t solverHandle2 = cudaContext.getSolverHandle();
        if (solverHandle == null) {
            if (solverHandle2 != null) {
                return false;
            }
        } else if (!solverHandle.equals(solverHandle2)) {
            return false;
        }
        Pointer bufferReduction = getBufferReduction();
        Pointer bufferReduction2 = cudaContext.getBufferReduction();
        if (bufferReduction == null) {
            if (bufferReduction2 != null) {
                return false;
            }
        } else if (!bufferReduction.equals(bufferReduction2)) {
            return false;
        }
        Pointer bufferAllocation = getBufferAllocation();
        Pointer bufferAllocation2 = cudaContext.getBufferAllocation();
        if (bufferAllocation == null) {
            if (bufferAllocation2 != null) {
                return false;
            }
        } else if (!bufferAllocation.equals(bufferAllocation2)) {
            return false;
        }
        Pointer bufferScalar = getBufferScalar();
        Pointer bufferScalar2 = cudaContext.getBufferScalar();
        if (bufferScalar == null) {
            if (bufferScalar2 != null) {
                return false;
            }
        } else if (!bufferScalar.equals(bufferScalar2)) {
            return false;
        }
        Pointer bufferSpecial = getBufferSpecial();
        Pointer bufferSpecial2 = cudaContext.getBufferSpecial();
        return bufferSpecial == null ? bufferSpecial2 == null : bufferSpecial.equals(bufferSpecial2);
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof CudaContext;
    }

    public int hashCode() {
        int deviceId = (1 * 59) + getDeviceId();
        cudaStream_t oldStream = getOldStream();
        int hashCode = (deviceId * 59) + (oldStream == null ? 43 : oldStream.hashCode());
        cudaStream_t specialStream = getSpecialStream();
        int hashCode2 = (hashCode * 59) + (specialStream == null ? 43 : specialStream.hashCode());
        cublasHandle_t cublasHandle = getCublasHandle();
        int hashCode3 = (hashCode2 * 59) + (cublasHandle == null ? 43 : cublasHandle.hashCode());
        cusolverDnHandle_t solverHandle = getSolverHandle();
        int hashCode4 = (hashCode3 * 59) + (solverHandle == null ? 43 : solverHandle.hashCode());
        Pointer bufferReduction = getBufferReduction();
        int hashCode5 = (hashCode4 * 59) + (bufferReduction == null ? 43 : bufferReduction.hashCode());
        Pointer bufferAllocation = getBufferAllocation();
        int hashCode6 = (hashCode5 * 59) + (bufferAllocation == null ? 43 : bufferAllocation.hashCode());
        Pointer bufferScalar = getBufferScalar();
        int hashCode7 = (hashCode6 * 59) + (bufferScalar == null ? 43 : bufferScalar.hashCode());
        Pointer bufferSpecial = getBufferSpecial();
        return (hashCode7 * 59) + (bufferSpecial == null ? 43 : bufferSpecial.hashCode());
    }

    public CudaContext(cudaStream_t cudastream_t, cudaStream_t cudastream_t2, cublasHandle_t cublashandle_t, cusolverDnHandle_t cusolverdnhandle_t, Pointer pointer, Pointer pointer2, Pointer pointer3, Pointer pointer4, int i) {
        this.deviceId = -1;
        this.oldStream = cudastream_t;
        this.specialStream = cudastream_t2;
        this.cublasHandle = cublashandle_t;
        this.solverHandle = cusolverdnhandle_t;
        this.bufferReduction = pointer;
        this.bufferAllocation = pointer2;
        this.bufferScalar = pointer3;
        this.bufferSpecial = pointer4;
        this.deviceId = i;
    }

    public CudaContext() {
        this.deviceId = -1;
    }
}
