diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index 16c888c0a..1f8149865 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -3602,7 +3602,13 @@ void deleteGraphContext(nd4j::graph::Context* ptr) { nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { - return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed); + try { + return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) { diff --git a/libnd4j/include/graph/RandomGenerator.h b/libnd4j/include/graph/RandomGenerator.h index e58f415c5..de475b8f8 100644 --- a/libnd4j/include/graph/RandomGenerator.h +++ b/libnd4j/include/graph/RandomGenerator.h @@ -28,6 +28,7 @@ #include #include #include +#include #ifdef __CUDACC__ #include @@ -46,7 +47,10 @@ namespace nd4j { public: void *operator new(size_t len) { void *ptr; - cudaHostAlloc(&ptr, len, cudaHostAllocDefault); + auto res = cudaHostAlloc(&ptr, len, cudaHostAllocDefault); + if (res != 0) + throw std::runtime_error("CudaManagedRandomGenerator: failed to allocate memory"); + return ptr; } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java index 6e2d8ebf0..68fff737a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java @@ -54,6 +54,10 @@ public class CudaNativeRandom extends NativeRandom { public void init() { nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef); + + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + setSeed(seed); }