From fbf7c9d38b2c21b0c2b2d350c84484dbfd106e9d Mon Sep 17 00:00:00 2001 From: shugeo Date: Thu, 2 Jan 2020 22:25:41 +0200 Subject: [PATCH] Fixed lu for cuda platform and tests. (#158) Signed-off-by: shugeo --- .../ops/declarable/generic/parity_ops/lup.cpp | 17 ++++-- .../ops/declarable/helpers/cuda/lup.cu | 7 ++- .../layers_tests/DeclarableOpsTests12.cpp | 54 +++++++++++++++++++ 3 files changed, 70 insertions(+), 8 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/lup.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/lup.cpp index 83b4a42d9..e0e960159 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/lup.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/lup.cpp @@ -28,10 +28,14 @@ namespace nd4j { CUSTOM_OP_IMPL(lu, 1, 2, false, 0, 0) { auto input = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); - auto p = OUTPUT_VARIABLE(1); - REQUIRE_TRUE(input->rankOf() >=2, 0, "matrix_inverse: The rank of input array should not less than 2, but %i is given", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "matrix_inverse: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); + auto p = OUTPUT_VARIABLE(1); + if (block.getIArguments()->size()) { + DataType dtype = (DataType)INT_ARG(0); + REQUIRE_TRUE(dtype == nd4j::DataType::INT32 || dtype == nd4j::DataType::INT64, 0, "lu: Permutation data type should be 32bit or 64bit int only, but '%s' given.", DataTypeUtils::asString(dtype).c_str()); } + + REQUIRE_TRUE(input->rankOf() >=2, 0, "lu: The rank of input array should not less than 2, but %i is given", input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "lu: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); helpers::lu(block.launchContext(), input, z, p); return Status::OK(); @@ -41,7 +45,12 @@ namespace nd4j { auto in = inputShape->at(0); auto shapeVector = ShapeUtils::shapeAsVector(in); auto luShape = ShapeBuilders::copyShapeInfoAndType(in, in, true, block.workspace()); - auto luP = ShapeBuilders::createShapeInfo(nd4j::DataType::INT32, shape::order(in), shapeVector.size() - 1, + auto dtype = nd4j::DataType::INT32; + if (block.getIArguments()->size()) { + dtype = (DataType)INT_ARG(0); + REQUIRE_TRUE(dtype == nd4j::DataType::INT32 || dtype == nd4j::DataType::INT64, 0, "lu: Permutation data type should be 32bit or 64bit int only, but '%s' given.", DataTypeUtils::asString(dtype).c_str()); + } + auto luP = ShapeBuilders::createShapeInfo(dtype, shape::order(in), shapeVector.size() - 1, shapeVector.data(), block.workspace()); return SHAPELIST(CONSTANT(luShape), CONSTANT(luP)); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index ce1dc2e95..3e8def28a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -598,14 +598,13 @@ namespace helpers { static void lu_(LaunchContext * context, NDArray* input, NDArray* output, NDArray* permutationVectors) { auto n = input->sizeAt(-1); auto stream = context->getCudaStream(); - auto iota = NDArrayFactory::create('c', {n}); + NDArray iota('c', {n}, permutationVectors->dataType());// = NDArrayFactory::create(); // ('c', {n}); iota.linspace(0); iota.syncToDevice(); output->assign(input); // fill up output tensor with zeros - output->tickWriteDevice(); +// output->tickWriteDevice(); permutationVectors->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), iota, *permutationVectors, true, nullptr); - permutationVectors->tickWriteDevice(); - +// permutationVectors->tickWriteDevice(); auto tads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-2, -1}); auto permutaionTads = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {-1}); auto batchNum = tads.numberOfTads(); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index c1cb872b4..4cf10ed00 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -2677,3 +2677,57 @@ TEST_F(DeclarableOpsTests12, LU_Test_3_3) { ASSERT_TRUE(expP.equalsTo(p)); delete res; } + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_4_1) { + + auto in = NDArrayFactory::create('c', {2, 2,2}, {0.7788f, 0.8012f, + 0.7244f, 0.2309f, + 0.7271f, 0.1804f, + 0.5056f, 0.8925f}); + auto expLU = NDArrayFactory::create('c', {2, 2,2}, { + 0.7788f, 0.8012f, 0.930149f, -0.514335f, + 0.7271f, 0.1804f, 0.695365f, 0.767056f + }); + + auto expP = NDArrayFactory::create('c', {2,2}, {0, 1, 0, 1}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars4_1"); +// p->printIndexedBuffer("Permutaions4_1"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_4_2) { + + auto in = NDArrayFactory::create('c', {2, 2,2}, {0.7788f, 0.8012f, + 0.7244f, 0.2309f, + 0.7271f, 0.1804f, + 0.5056f, 0.8925f}); + auto expLU = NDArrayFactory::create('c', {2, 2,2}, { + 0.7788f, 0.8012f, 0.930149f, -0.514335f, + 0.7271f, 0.1804f, 0.695365f, 0.767056f + }); + + auto expP = NDArrayFactory::create('c', {2,2}, {0, 1, 0, 1}); + nd4j::ops::lu op; + + auto res = op.execute({&in}, {}, {nd4j::DataType::INT64}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + auto p = res->at(1); +// z->printIndexedBuffer("Triangulars4_2"); +// p->printIndexedBuffer("Permutaions4_2"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + delete res; +}