From 256c9d20b0482df7c84843ca42794fd5289591e0 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 23 Jan 2020 09:51:02 +0300 Subject: [PATCH] alloc check for RNG (#179) * missing alloc validation in RandomGenerator for CUDA Signed-off-by: raver119 * set error message if rng alloc failed Signed-off-by: raver119 * check for error code during RNG creation in java Signed-off-by: raver119 --- libnd4j/blas/cuda/NativeOps.cu | 8 +++++++- libnd4j/include/graph/RandomGenerator.h | 6 +++++- .../org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java | 4 ++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index 16c888c0a..1f8149865 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -3602,7 +3602,13 @@ void deleteGraphContext(nd4j::graph::Context* ptr) { nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { - return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed); + try { + return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) { diff --git a/libnd4j/include/graph/RandomGenerator.h b/libnd4j/include/graph/RandomGenerator.h index e58f415c5..de475b8f8 100644 --- a/libnd4j/include/graph/RandomGenerator.h +++ b/libnd4j/include/graph/RandomGenerator.h @@ -28,6 +28,7 @@ #include #include #include +#include #ifdef __CUDACC__ #include @@ -46,7 +47,10 @@ namespace nd4j { public: void *operator new(size_t len) { void *ptr; - cudaHostAlloc(&ptr, len, cudaHostAllocDefault); + auto res = cudaHostAlloc(&ptr, len, cudaHostAllocDefault); + if (res != 0) + throw std::runtime_error("CudaManagedRandomGenerator: failed to allocate memory"); + return ptr; } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java index 6e2d8ebf0..68fff737a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java @@ -54,6 +54,10 @@ public class CudaNativeRandom extends NativeRandom { public void init() { nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef); + + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + setSeed(seed); }