From 11cb56104519ca82a58f4a2b66fb947d25b071db Mon Sep 17 00:00:00 2001 From: Oleh Date: Wed, 12 Feb 2020 13:12:17 +0200 Subject: [PATCH 1/4] Oleh true broadcast opt (#234) * libnd4j trueBroadcast special case Signed-off-by: Oleg * libnd4j fix trueBroadcast special case Signed-off-by: Oleg * libnd4j special case of TrueBroadcastHelper Signed-off-by: Oleg * libnd4j trueBroadCast special case and test * libnd4j minor changes sync with master * libnd4j changes to TrueBroadcastHelper.hpp per require Signed-off-by: Oleg --- libnd4j/CMakeSettings.json | 2 +- .../include/loops/cpu/TrueBroadcastHelper.hpp | 23 +++++++++++++++++-- .../layers_tests/DeclarableOpsTests14.cpp | 21 +++++++++++++++++ 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/libnd4j/CMakeSettings.json b/libnd4j/CMakeSettings.json index afda69260..867132ab2 100644 --- a/libnd4j/CMakeSettings.json +++ b/libnd4j/CMakeSettings.json @@ -1,4 +1,4 @@ -{ +{ "configurations": [ { "name": "x64-Debug", diff --git a/libnd4j/include/loops/cpu/TrueBroadcastHelper.hpp b/libnd4j/include/loops/cpu/TrueBroadcastHelper.hpp index c79c1f242..6005c3647 100644 --- a/libnd4j/include/loops/cpu/TrueBroadcastHelper.hpp +++ b/libnd4j/include/loops/cpu/TrueBroadcastHelper.hpp @@ -32,9 +32,10 @@ template template void TrueBroadcastHelper::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { + const X* x = reinterpret_cast(xArr.getBuffer()); const Y* y = reinterpret_cast(yArr.getBuffer()); - Z* z = reinterpret_cast(zArr.getBuffer()); + Z* z = reinterpret_cast(zArr.getBuffer()); const auto xShapeInfo = xArr.getShapeInfo(); const auto yShapeInfo = yArr.getShapeInfo(); @@ -44,8 +45,26 @@ void TrueBroadcastHelper::exec(const NDArray& xArr, const NDArray& yArr const int yRank = yArr.rankOf(); const int zRank = zArr.rankOf(); - const Nd4jLong zLen = zArr.lengthOf(); + bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() && 1 == yRank && + 1 == yArr.ews() && 'c' == yArr.ordering() && + 1 == zArr.ews() && 'c' == zArr.ordering()); + if (bSpecialCase) { + auto yLen = (uint32_t)yArr.lengthOf(); + auto func = PRAGMA_THREADS_FOR{ + for (uint32_t i = start; i < stop; i++) { + auto rZ = z + (i * yLen); + auto v = x[i]; + for (uint32_t j = 0; j < yLen; j++) { + rZ[j] = OpType::op(v, y[j]); + } + } + }; + samediff::Threads::parallel_tad(func, 0, xArr.lengthOf()); + return; + } + + const Nd4jLong zLen = zArr.lengthOf(); auto func = PRAGMA_THREADS_FOR { std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp index 1815e5336..600004ec2 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp @@ -532,3 +532,24 @@ TEST_F(DeclarableOpsTests14, repeat_5) { delete result; } +///////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_SpecialCaseTest) { + + auto y = NDArray('c', { 3 }, nd4j::DataType::FLOAT32); + auto x = NDArray('c', { 5, 2, 1 }, nd4j::DataType::FLOAT32); + + auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, nd4j::DataType::FLOAT32); + + y.assign(1.0); + x.linspace(1.0); + + nd4j::ops::add op; + auto result = op.evaluate({ &x, &y }); + ASSERT_EQ(Status::OK(), result->status()); + + auto res = *result->at(0); + + ASSERT_EQ(e, res); + + delete result; +} From f0c684020fb01fd4ef00b12b7e077d7f15cfc53d Mon Sep 17 00:00:00 2001 From: shugeo Date: Wed, 12 Feb 2020 18:02:42 +0200 Subject: [PATCH 2/4] Shugeo resize area fix4 (#229) * Fixed a couple of issues with resize_area op. Signed-off-by: shugeo * Added additional test for alternate params for resize_area testing. Signed-off-by: shugeo --- .../generic/parity_ops/resize_area.cpp | 12 ++-- .../declarable/helpers/cuda/image_resize.cu | 2 +- .../layers_tests/DeclarableOpsTests11.cpp | 59 +++++++++++++++++++ 3 files changed, 66 insertions(+), 7 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_area.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_area.cpp index b0f637c45..984672ad2 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/resize_area.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_area.cpp @@ -80,8 +80,8 @@ namespace nd4j { "resize_area: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); REQUIRE_TRUE(block.numI() <= 1, 0, "resize_area: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); + width = newImageSize->e(1); + height = newImageSize->e(0); } else { REQUIRE_TRUE(block.numI() == 2, 0, "resize_area: Resize params ommited as pair ints nor int tensor."); @@ -95,13 +95,13 @@ namespace nd4j { outputShape[0] = inRank; if (inRank == 4) { outputShape[1] = in[1]; - outputShape[2] = width; - outputShape[3] = height; + outputShape[2] = height; + outputShape[3] = width; outputShape[4] = in[4]; } else { - outputShape[1] = width; - outputShape[2] = height; + outputShape[1] = height; + outputShape[2] = width; outputShape[3] = in[3]; } ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in)); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu index 94df35964..c028daff3 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu @@ -1116,7 +1116,7 @@ namespace helpers { err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream); ScaleCache* cachePool; err = cudaMalloc(&cachePool, sizeof(ScaleCache) * st.batchSize * st.outWidth * st.outHeight); - resizeAreaKernel<<<128, 4, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->getSpecialShapeInfo(), outputPtr, + resizeAreaKernel<<<128, 2, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->getSpecialShapeInfo(), outputPtr, output->specialShapeInfo(), cachePool); err = cudaStreamSynchronize(*stream); err = cudaFree(cachePool); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index 465703768..aeecaccef 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -1520,6 +1520,65 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) { delete results; } +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test14) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + auto size = NDArrayFactory::create({8, 7}); + NDArray expected = NDArrayFactory::create('c', {1, 8, 7, 1}, { + 1.f, 1.6f , 2.1999993f, 2.9999995f , 3.8f , 4.399997f, 5.f , 2.9999995f , 3.5999997f , 4.199999f, + 4.9999995f, 5.8f , 6.3999963f , 7.f , 5.999999f , 6.6f, 7.1999984f , 7.9999995f , 8.8f, + 9.399994f, 10.f , 10.f, 10.6f , 11.199998f, 12.f, 12.8f, 13.399992f , 14.f, 12.f , 12.599999f, + 13.199998f , 13.999998f , 14.800002f , 15.399991f , 16.f , 15.999999f , 16.599998f , 17.199995f, + 18.f , 18.800003f , 19.399986f , 20.000002f , 19.f , 19.599998f , 20.199997f , + 20.999998f , 21.800003f , 22.399984f , 23.000002f , 20.999998f , + 21.599998f , 22.199995f , 22.999998f , 23.800001f , 24.399984f , + 25.f + }); //input.linspace(1); +// auto size = NDArrayFactory::create({6, 6}); + nd4j::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); +// result->printBuffer("Area Resized to 8x7"); +// expected.printBuffer("Area Expect for 8x7"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + //auto size = NDArrayFactory::create({8, 7}); + NDArray expected = NDArrayFactory::create('c', {1, 8, 7, 1}, { + 1.f, 1.6f , 2.1999993f, 2.9999995f , 3.8f , 4.399997f, 5.f , 2.9999995f , 3.5999997f , 4.199999f, + 4.9999995f, 5.8f , 6.3999963f , 7.f , 5.999999f , 6.6f, 7.1999984f , 7.9999995f , 8.8f, + 9.399994f, 10.f , 10.f, 10.6f , 11.199998f, 12.f, 12.8f, 13.399992f , 14.f, 12.f , 12.599999f, + 13.199998f , 13.999998f , 14.800002f , 15.399991f , 16.f , 15.999999f , 16.599998f , 17.199995f, + 18.f , 18.800003f , 19.399986f , 20.000002f , 19.f , 19.599998f , 20.199997f , + 20.999998f , 21.800003f , 22.399984f , 23.000002f , 20.999998f , 21.599998f , 22.199995f , + 22.999998f , 23.800001f , 24.399984f , 25.f + }); + + nd4j::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {8, 7}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); +// result->printBuffer("Area Resized to 8x7"); +// expected.printBuffer("Area Expect for 8x7"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, summaryStatsData_test1) { From 5c9e0bc2bb16f449257bfce80622b195237fb4f7 Mon Sep 17 00:00:00 2001 From: Shams Ul Azeem Date: Thu, 13 Feb 2020 03:58:39 +0500 Subject: [PATCH 3/4] Ignore none type for pythonexception (#237) * Making TypeName enum public * Ignoring None type object for PythonExceptions * better handling of None + test Co-authored-by: Fariz Rahman --- .../java/org/datavec/python/PythonExecutioner.java | 2 +- .../main/java/org/datavec/python/PythonObject.java | 4 +++- .../org/datavec/python/TestPythonExecutioner.java | 11 +++++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index e2d2e5747..530dd0e02 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -226,7 +226,7 @@ public class PythonExecutioner { private static void throwIfExecutionFailed() throws PythonException{ PythonObject ex = getVariable(PYTHON_EXCEPTION_KEY); - if (ex != null && !ex.toString().isEmpty()){ + if (ex != null && !ex.isNone() && !ex.toString().isEmpty()) { setVariable(PYTHON_EXCEPTION_KEY, new PythonObject("")); throw new PythonException(ex); } diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java index c0079919c..84dd16e73 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java @@ -583,7 +583,9 @@ public class PythonObject { } } public boolean isNone() { - return nativePythonObject == null; + return (nativePythonObject == null)|| + (toString().equals("None") && Python.type(this).toString().equals("")); + } } diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java index b8916476c..52e2aad56 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java @@ -322,5 +322,16 @@ public class TestPythonExecutioner { Python.setMainContext(); } + @Test + public void testIsNone(){ + PythonObject d = Python.dict(); + PythonObject none = d.attr("get").call("x"); + Assert.assertTrue(none.isNone()); + d.set(new PythonObject("x"), new PythonObject("y")); + PythonObject notNone = d.attr("get").call("x"); + Assert.assertFalse(notNone.isNone()); + Assert.assertEquals("y", notNone.toString()); + } + } From 8c0e378ec3e9bcc8f7cdffba6935882198e6e375 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Thu, 13 Feb 2020 01:29:08 +0200 Subject: [PATCH 4/4] Improving SameDiff tests coverage (#227) * Gradients tests added * Fix for Standard deviation serialization + test Signed-off-by: Alex Black * More fixes Signed-off-by: Alex Black * Test fixed * Spark config driver host config for CI Signed-off-by: Alex Black * Op validation timeout increase Signed-off-by: Alex Black * Gradient check - fix for low probability test failure due to randomly all 0s mask Signed-off-by: AlexDBlack Co-authored-by: Alex Black --- .../GradientCheckTestsMasking.java | 12 +- .../SparkSequenceVectorsTest.java | 4 +- .../models/word2vec/SparkWord2VecTest.java | 4 +- .../embeddings/word2vec/Word2VecTest.java | 5 +- .../spark/text/TextPipelineTest.java | 2 +- .../spark/parameterserver/BaseSparkTest.java | 3 +- .../spark/BaseSparkKryoTest.java | 4 +- .../deeplearning4j/spark/BaseSparkTest.java | 3 +- .../spark/impl/TestKryoWarning.java | 20 +- ...arameterAveragingSparkVsSingleMachine.java | 1 + .../impl/paramavg/util/ExportSupportTest.java | 2 +- .../stats/TestTrainingStatsCollection.java | 11 +- .../DifferentialFunctionFactory.java | 15 +- .../nd4j/autodiff/samediff/ops/SDBaseOps.java | 12 +- .../nd4j/autodiff/samediff/ops/SDMath.java | 2 +- .../samediff/serde/FlatBuffersMapper.java | 5 +- .../autodiff/validation/OpValidation.java | 10 +- .../converters/ImportClassMapping.java | 1 - .../nd4j/linalg/api/ops/custom/Flatten.java | 30 +- .../linalg/api/ops/custom/FusedBatchNorm.java | 12 +- .../org/nd4j/linalg/api/ops/custom/Lu.java | 2 +- .../linalg/api/ops/custom/MatrixBandPart.java | 5 + .../org/nd4j/linalg/api/ops/custom/Roll.java | 9 + .../api/ops/custom/TriangularSolve.java | 20 + .../api/ops/impl/broadcast/BiasAddGrad.java | 13 +- .../api/ops/impl/indexaccum/FirstIndex.java | 7 +- .../linalg/api/ops/impl/indexaccum/IAMax.java | 2 + .../linalg/api/ops/impl/indexaccum/IAMin.java | 2 + .../linalg/api/ops/impl/indexaccum/IMax.java | 2 + .../linalg/api/ops/impl/indexaccum/IMin.java | 2 + .../api/ops/impl/indexaccum/LastIndex.java | 2 + .../ops/impl/indexaccum/custom/ArgMax.java | 2 + .../ops/impl/indexaccum/custom/ArgMin.java | 2 + .../SoftmaxCrossEntropyWithLogitsLoss.java | 14 +- .../loss/bp/SoftmaxCrossEntropyLossBp.java | 12 + .../SoftmaxCrossEntropyWithLogitsLossBp.java | 11 +- .../ops/impl/reduce/SufficientStatistics.java | 24 +- .../api/ops/impl/reduce/TensorMmul.java | 26 +- .../ops/impl/reduce/bp/BaseReductionBp.java | 10 +- .../linalg/api/ops/impl/reduce/bp/DotBp.java | 22 +- .../api/ops/impl/reduce/floating/Bias.java | 86 ----- .../api/ops/impl/shape/SequenceMask.java | 14 +- .../linalg/api/ops/impl/shape/ZerosLike.java | 21 +- .../impl/transforms/custom/BatchToSpace.java | 6 +- .../impl/transforms/custom/SpaceToBatch.java | 9 +- .../segment/UnsortedSegmentMax.java | 3 +- .../segment/UnsortedSegmentMean.java | 7 +- .../segment/UnsortedSegmentMin.java | 7 +- .../segment/UnsortedSegmentProd.java | 7 +- .../segment/UnsortedSegmentSqrtN.java | 14 +- .../segment/UnsortedSegmentSum.java | 7 +- .../opvalidation/LayerOpValidation.java | 7 +- .../opvalidation/MiscOpValidation.java | 346 +++++++++++++++++- .../opvalidation/ReductionOpValidation.java | 334 +++++++++++++++-- .../opvalidation/ShapeOpValidation.java | 33 +- .../opvalidation/TransformOpValidation.java | 12 +- .../nd4j/linalg/custom/CustomOpsTests.java | 16 + .../org/nd4s/NDArrayExtractionTest.scala | 6 +- 58 files changed, 1027 insertions(+), 255 deletions(-) delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Bias.java diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java index e38bcf274..09afd6c2f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java @@ -414,7 +414,11 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { INDArray l = TestUtils.randomOneHot(mb, 3); INDArray lm = TestUtils.randomBernoulli(mb, 1); - assertTrue(lm.sumNumber().intValue() > 0); + int attempts = 0; + while(attempts++ < 1000 && lm.sumNumber().intValue() == 0){ + lm = TestUtils.randomBernoulli(mb, 1); + } + assertTrue("Could not generate non-zero mask after " + attempts + " attempts", lm.sumNumber().intValue() > 0); boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f) .labels(l).labelMask(lm)); @@ -467,7 +471,11 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { INDArray l = TestUtils.randomOneHot(mb, 3); INDArray lm = TestUtils.randomBernoulli(mb, 1); - assertTrue(lm.sumNumber().intValue() > 0); + int attempts = 0; + while(attempts++ < 1000 && lm.sumNumber().intValue() == 0){ + lm = TestUtils.randomBernoulli(mb, 1); + } + assertTrue("Could not generate non-zero mask after " + attempts + " attempts", lm.sumNumber().intValue() > 0); boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{f}) .labels(new INDArray[]{l}).labelMask(new INDArray[]{lm})); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java index 363b4e293..445788799 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java @@ -67,7 +67,9 @@ public class SparkSequenceVectorsTest extends BaseDL4JTest { } } - SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("SeqVecTests"); + SparkConf sparkConf = new SparkConf().setMaster("local[8]") + .set("spark.driver.host", "localhost") + .setAppName("SeqVecTests"); sc = new JavaSparkContext(sparkConf); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java index 82a04eab8..55d893d8c 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java @@ -61,7 +61,9 @@ public class SparkWord2VecTest extends BaseDL4JTest { sentences.add("one another sentence"); } - SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("SeqVecTests"); + SparkConf sparkConf = new SparkConf().setMaster("local[8]") + .set("spark.driver.host", "localhost") + .setAppName("SeqVecTests"); sc = new JavaSparkContext(sparkConf); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java index a5742ad7c..21de27048 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java @@ -56,7 +56,9 @@ public class Word2VecTest { @Test public void testConcepts() throws Exception { // These are all default values for word2vec - SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest"); + SparkConf sparkConf = new SparkConf().setMaster("local[8]") + .set("spark.driver.host", "localhost") + .setAppName("sparktest"); // Set SparkContext JavaSparkContext sc = new JavaSparkContext(sparkConf); @@ -156,6 +158,7 @@ public class Word2VecTest { @Test public void testSparkW2VonBiggerCorpus() throws Exception { SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest") + .set("spark.driver.host", "localhost") .set("spark.driver.maxResultSize", "4g").set("spark.driver.memory", "8g") .set("spark.executor.memory", "8g"); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java index 2cae12e61..63c84de7d 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java @@ -63,7 +63,7 @@ public class TextPipelineTest extends BaseSparkTest { @Before public void before() throws Exception { - conf = new SparkConf().setMaster("local[4]").setAppName("sparktest"); + conf = new SparkConf().setMaster("local[4]").setAppName("sparktest").set("spark.driver.host", "localhost"); // All the avaliable options. These are default values word2vec = new Word2Vec.Builder().minWordFrequency(1).setNGrams(1) diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java index ccab68e9e..c97292a2c 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java @@ -85,7 +85,8 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable if (sc != null) return sc; // set to test mode - SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").set("spark.driver.host", "localhost").setAppName("sparktest"); + SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]") + .set("spark.driver.host", "localhost").setAppName("sparktest"); sc = new JavaSparkContext(sparkConf); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java index 1c794ebf6..42fa57a37 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java @@ -59,7 +59,9 @@ public class BaseSparkKryoTest extends BaseSparkTest { - SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").setAppName("sparktest"); + SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]") + .setAppName("sparktest") + .set("spark.driver.host", "localhost"); sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"); sparkConf.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator"); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java index bfbcd34b4..3d1a9755a 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java @@ -89,7 +89,8 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable if (sc != null) return sc; // set to test mode - SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").set("spark.driver.host", "localhost").setAppName("sparktest"); + SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]") + .set("spark.driver.host", "localhost").setAppName("sparktest"); sc = new JavaSparkContext(sparkConf); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java index 0969c2602..c85ba82aa 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java @@ -72,8 +72,9 @@ public class TestKryoWarning { @Ignore public void testKryoMessageMLNIncorrectConfig() { //Should print warning message - SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest").set("spark.serializer", - "org.apache.spark.serializer.KryoSerializer"); + SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest") + .set("spark.driver.host", "localhost") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"); doTestMLN(sparkConf); } @@ -83,6 +84,7 @@ public class TestKryoWarning { public void testKryoMessageMLNCorrectConfigKryo() { //Should NOT print warning message SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest") + .set("spark.driver.host", "localhost") .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") .set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator"); @@ -93,7 +95,9 @@ public class TestKryoWarning { @Ignore public void testKryoMessageMLNCorrectConfigNoKryo() { //Should NOT print warning message - SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest"); + SparkConf sparkConf = new SparkConf().setMaster("local[*]") + .set("spark.driver.host", "localhost") + .setAppName("sparktest"); doTestMLN(sparkConf); } @@ -104,8 +108,9 @@ public class TestKryoWarning { @Ignore public void testKryoMessageCGIncorrectConfig() { //Should print warning message - SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest").set("spark.serializer", - "org.apache.spark.serializer.KryoSerializer"); + SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest") + .set("spark.driver.host", "localhost") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"); doTestCG(sparkConf); } @@ -115,6 +120,7 @@ public class TestKryoWarning { public void testKryoMessageCGCorrectConfigKryo() { //Should NOT print warning message SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest") + .set("spark.driver.host", "localhost") .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") .set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator"); @@ -125,7 +131,9 @@ public class TestKryoWarning { @Ignore public void testKryoMessageCGCorrectConfigNoKryo() { //Should NOT print warning message - SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest"); + SparkConf sparkConf = new SparkConf().setMaster("local[*]") + .set("spark.driver.host", "localhost") + .setAppName("sparktest"); doTestCG(sparkConf); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java index c811e13f4..0188b15d9 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java @@ -138,6 +138,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { SparkConf sparkConf = new SparkConf(); sparkConf.setMaster("local[" + nWorkers + "]"); sparkConf.setAppName("Test"); + sparkConf.set("spark.driver.host", "localhost"); JavaSparkContext sc = new JavaSparkContext(sparkConf); return sc; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupportTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupportTest.java index 6a393cd3c..16103a6bf 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupportTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupportTest.java @@ -58,7 +58,7 @@ public class ExportSupportTest { } private void assertSupported(SparkConf conf) throws IOException { - JavaSparkContext sc = new JavaSparkContext(conf.setAppName("Test")); + JavaSparkContext sc = new JavaSparkContext(conf.setAppName("Test").set("spark.driver.host", "localhost")); try { assertTrue(ExportSupport.exportSupported(sc)); } finally { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java index f56dace0e..5b49899c8 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.spark.BaseSparkTest; import org.deeplearning4j.spark.api.Repartition; import org.deeplearning4j.spark.api.stats.CommonSparkTrainingStats; import org.deeplearning4j.spark.api.stats.SparkTrainingStats; @@ -50,17 +51,13 @@ import static org.junit.Assert.*; /** * Created by Alex on 17/06/2016. */ -public class TestTrainingStatsCollection { +public class TestTrainingStatsCollection extends BaseSparkTest { @Test public void testStatsCollection() throws Exception { - int nWorkers = 4; + int nWorkers = numExecutors(); - SparkConf sparkConf = new SparkConf(); - sparkConf.setMaster("local[" + nWorkers + "]"); - sparkConf.setAppName("Test"); - - JavaSparkContext sc = new JavaSparkContext(sparkConf); + JavaSparkContext sc = getContext(); try { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index e38af27d4..8a9bd8edc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -294,6 +294,11 @@ public class DifferentialFunctionFactory { return new ZerosLike(name, sameDiff(), input).outputVariable(); } + public SDVariable zerosLike(String name, SDVariable input, DataType dataType) { + validateDifferentialFunctionsameDiff(input); + return new ZerosLike(name, sameDiff(), input, dataType).outputVariable(); + } + public SDVariable create(String name, SDVariable shape, boolean initialize, DataType dataType) { return create(name, shape, 'c', initialize, dataType); } @@ -1751,12 +1756,12 @@ public class DifferentialFunctionFactory { return new SoftmaxCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariables(); } - public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) { - return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, weights, labels, classDim).outputVariable(); + public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, int classDim) { + return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, labels, classDim).outputVariable(); } - public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) { - return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, weights, labels, classDim).outputVariables(); + public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, int classDim) { + return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, labels, classDim).outputVariables(); } public SDVariable lossSparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels){ @@ -2638,7 +2643,7 @@ public class DifferentialFunctionFactory { return new Polygamma(sameDiff, n,x).outputVariable(); } - public SDVariable roll(SDVariable input, SDVariable shift) { + public SDVariable roll(SDVariable input, int shift) { return new Roll(sameDiff, input, shift).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index f7d4a85bc..8e5d1ca36 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -787,9 +787,10 @@ public abstract class SDBaseOps { * @param number Number of values to generate * @return SDVariable with linearly spaced elements */ - public SDVariable linspace(DataType dataType, double start, double stop, long number) { + // TODO: fix or remove, currently it is internal recursion + /*public SDVariable linspace(DataType dataType, double start, double stop, long number) { return linspace(dataType, start, stop, number); - } + }*/ /** * Create a new 1d array with values evenly spaced between values 'start' and 'stop' @@ -3093,6 +3094,9 @@ public abstract class SDBaseOps { return zerosLike(null, input); } + public SDVariable zerosLike(@NonNull SDVariable input, @NonNull DataType dataType) { + return zerosLike(null, input, dataType); + } /** * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic: * if the input shape changes in later execution, the returned variable's shape will also be updated @@ -3106,6 +3110,10 @@ public abstract class SDBaseOps { return updateVariableNameAndReference(ret, name); } + public SDVariable zerosLike(String name, @NonNull SDVariable input, @NonNull DataType dataType) { + SDVariable ret = f().zerosLike(name, input, dataType); + return updateVariableNameAndReference(ret, name); + } /** * See {@link #any(String, SDVariable, int...)} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 0d0da022e..1e038e193 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -2545,7 +2545,7 @@ public class SDMath extends SDOps { * @param shift number of places to shift elements * @return array */ - public SDVariable roll(String name, SDVariable input, SDVariable shift) { + public SDVariable roll(String name, SDVariable input, int shift) { SDVariable res = f().roll(input,shift); return updateVariableNameAndReference(res, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index d87a890ff..20fffdc4c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -815,8 +815,9 @@ public class FlatBuffersMapper { } int[] dims; - if (node.opType() == Op.Type.REDUCE_FLOAT || node.opType() == Op.Type.REDUCE_SAME || node.opType() == Op.Type.REDUCE_BOOL - || node.opType() == Op.Type.REDUCE_LONG || node.opType() == Op.Type.INDEXREDUCE || node.opType() == Op.Type.REDUCE3) { + Type t = node.opType(); + if (t == Op.Type.REDUCE_FLOAT || t == Op.Type.REDUCE_SAME || t == Op.Type.REDUCE_BOOL + || t == Op.Type.REDUCE_LONG || t == Op.Type.INDEXREDUCE || t == Op.Type.REDUCE3 || t == Type.VARIANCE || t == Type.SUMMARYSTATS) { dims = node.getDimensions(); if (dims == null) dims = new int[0]; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index d57ab7c97..950baa7e8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -16,6 +16,7 @@ package org.nd4j.autodiff.validation; +import org.nd4j.linalg.api.ops.custom.*; import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax; import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin; import org.nd4j.linalg.api.ops.impl.reduce.HashCode; @@ -38,10 +39,6 @@ import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOpDescriptor; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces; -import org.nd4j.linalg.api.ops.custom.BarnesHutGains; -import org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize; -import org.nd4j.linalg.api.ops.custom.SpTreeCell; import org.nd4j.linalg.api.ops.impl.broadcast.bool.*; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.loss.bp.*; @@ -1011,7 +1008,10 @@ public class OpValidation { SpTreeCell.class, CbowRound.class, SkipGramRound.class, - HashCode.class + HashCode.class, + HashCode.class, + BitCast.class, + ToggleBits.class ); return new HashSet<>(list); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 3ed96fe9c..b55852ac7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -200,7 +200,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul.class, org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp.class, org.nd4j.linalg.api.ops.impl.reduce.floating.AMean.class, - org.nd4j.linalg.api.ops.impl.reduce.floating.Bias.class, org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy.class, org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy.class, org.nd4j.linalg.api.ops.impl.reduce.floating.Mean.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Flatten.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Flatten.java index e82b80fed..721c98d45 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Flatten.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Flatten.java @@ -16,21 +16,28 @@ package org.nd4j.linalg.api.ops.custom; +import lombok.Data; +import lombok.NoArgsConstructor; import lombok.val; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import java.util.Arrays; +import java.util.List; + /** * This op takes arbitrary number of arrays as input, and returns single "flattened" vector * * @author raver119@gmail.com */ +@Data +@NoArgsConstructor public class Flatten extends DynamicCustomOp { - private char order; - - public Flatten() { - // - } + private int order; public Flatten(char order, INDArray... inputs) { this.order = order; @@ -47,10 +54,21 @@ public class Flatten extends DynamicCustomOp { outputArguments.add(output); } + public Flatten(SameDiff sameDiff, char order, SDVariable... inputs) { + super(sameDiff, inputs); + this.order = order; + addIArgument(order); + } + @Override public String opName() { return "flatten"; } - + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Arrays.asList(inputDataTypes.get(0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java index 5d9ea3642..7b41695ff 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java @@ -51,6 +51,14 @@ public class FusedBatchNorm extends DynamicCustomOp { public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset, @NonNull SDVariable dataFormat, @NonNull SDVariable isTraining) { super("", sameDiff, new SDVariable[]{x, scale, offset, dataFormat, isTraining}); + this.outputDataType = x.dataType(); + } + + public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset, + int dataFormat, int isTraining) { + super("", sameDiff, new SDVariable[]{x, scale, offset}); + addIArgument(dataFormat, isTraining); + this.outputDataType = x.dataType(); } @Override @@ -78,6 +86,8 @@ public class FusedBatchNorm extends DynamicCustomOp { public List calculateOutputDataTypes(List inputDataTypes){ int n = args().length; Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); - return Arrays.asList(outputDataType, DataType.FLOAT, DataType.FLOAT); //Activations may be half, bfloat16, float32; mean/var is always float + return Arrays.asList(outputDataType == null ? DataType.FLOAT : outputDataType, + outputDataType == null ? DataType.FLOAT : outputDataType, + outputDataType == null ? DataType.FLOAT : outputDataType); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lu.java index af1bf0155..96ccde960 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lu.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lu.java @@ -64,6 +64,6 @@ public class Lu extends DynamicCustomOp { public List calculateOutputDataTypes(List inputDataTypes){ int n = args().length; Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); - return Arrays.asList(inputDataTypes.get(0), indexDataType); + return Arrays.asList(inputDataTypes.get(0), indexDataType == null ? DataType.INT32 : indexDataType); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java index 46d29608e..554781958 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java @@ -46,6 +46,11 @@ public class MatrixBandPart extends DynamicCustomOp { super("", sameDiff, new SDVariable[]{input, minLower, maxUpper}); } + public MatrixBandPart(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int minLower, int maxUpper) { + super("", sameDiff, new SDVariable[]{input}); + addIArgument(minLower, maxUpper); + } + @Override public String opName() { return "matrix_band_part"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Roll.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Roll.java index 9ce7aa641..004a2e5ab 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Roll.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Roll.java @@ -45,6 +45,15 @@ public class Roll extends DynamicCustomOp { super("", sameDiff, new SDVariable[]{input,shift}); } + public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable axes, @NonNull SDVariable shift) { + super("", sameDiff, new SDVariable[]{input,axes,shift}); + } + + public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int shift) { + super("", sameDiff, new SDVariable[]{input}); + addIArgument(shift); + } + @Override public String opName() { return "roll"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/TriangularSolve.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/TriangularSolve.java index 7423d3a91..02734b391 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/TriangularSolve.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/TriangularSolve.java @@ -7,9 +7,13 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; import java.util.Collections; import java.util.List; +import java.util.Map; @NoArgsConstructor public class TriangularSolve extends DynamicCustomOp { @@ -24,11 +28,27 @@ public class TriangularSolve extends DynamicCustomOp { super(sameDiff, new SDVariable[] {matrix, rhs, lower, adjoint}); } + public TriangularSolve(SameDiff sameDiff, SDVariable matrix, SDVariable rhs, + boolean lower, boolean adjoint) { + super(sameDiff, new SDVariable[] {matrix, rhs}); + addBArgument(lower, adjoint); + } + @Override public String opName() { return "triangular_solve"; } + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + if(attributesForNode.containsKey("adjoint")){ + addBArgument(attributesForNode.get("adjoint").getB()); + } + if(attributesForNode.containsKey("lower")){ + addBArgument(attributesForNode.get("lower").getB()); + } + } + @Override public String tensorflowName() { return "MatrixTriangularSolve"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java index 714c9c321..df936bbe6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAddGrad.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.broadcast; +import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -27,6 +28,7 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; import java.util.List; +@NoArgsConstructor public class BiasAddGrad extends DynamicCustomOp { protected boolean nchw = true; @@ -40,7 +42,16 @@ public class BiasAddGrad extends DynamicCustomOp { super(new INDArray[]{input, bias, gradient}, wrapOrNull(output)); } - public BiasAddGrad() {} + public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient, + boolean nchw) { + addInputArgument(input, bias, gradient); + this.nchw = nchw; + addBArgument(nchw); + } + + public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient) { + this(input, bias, gradient, false); + } @Override public int opNum() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java index 8d660eba8..60b278ed7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java @@ -16,6 +16,8 @@ package org.nd4j.linalg.api.ops.impl.indexaccum; +import lombok.Data; +import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; @@ -35,6 +37,8 @@ import java.util.List; * * @author raver119@gmail.com */ +@Data +@NoArgsConstructor public class FirstIndex extends BaseIndexAccumulation { protected Condition condition; protected double compare; @@ -50,9 +54,6 @@ public class FirstIndex extends BaseIndexAccumulation { this.extraArgs = new Object[] {compare, eps, (double) mode}; } - public FirstIndex() {} - - public FirstIndex(INDArray x, @NonNull Condition condition, int... dimension) { this(x, condition, false, dimension); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java index 4c8465ef7..b9c3962aa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum; +import lombok.Data; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -30,6 +31,7 @@ import java.util.List; * * @author Adam Gibson */ +@Data public class IAMax extends BaseIndexAccumulation { public IAMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) { super(sameDiff, i_v, keepDims, dimensions); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java index 0a1383a67..63f40ee6c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum; +import lombok.Data; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -30,6 +31,7 @@ import java.util.List; * * @author Adam Gibson */ +@Data public class IAMin extends BaseIndexAccumulation { public IAMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) { super(sameDiff, i_v, keepDims, dimensions); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java index 24067cb70..7280d7adf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum; +import lombok.Data; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -31,6 +32,7 @@ import java.util.List; * * @author Alex Black */ +@Data public class IMax extends BaseIndexAccumulation { public IMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) { super(sameDiff, i_v, keepDims, dimensions); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java index dc133b638..449ea36a0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum; +import lombok.Data; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -30,6 +31,7 @@ import java.util.List; * * @author Alex Black */ +@Data public class IMin extends BaseIndexAccumulation { public IMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) { super(sameDiff, i_v, keepDims, dimensions); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java index b29af5042..e77d42398 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum; +import lombok.Data; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -36,6 +37,7 @@ import java.util.Map; * * @author raver119@gmail.com */ +@Data public class LastIndex extends BaseIndexAccumulation { protected Condition condition; protected double compare; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMax.java index aa5081c22..67c232079 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMax.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum.custom; +import lombok.Data; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; @@ -29,6 +30,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +@Data public class ArgMax extends DynamicCustomOp { protected DataType outputType; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMin.java index e93e093c5..87601f389 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMin.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum.custom; +import lombok.Data; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; @@ -34,6 +35,7 @@ import java.util.Map; * * @author Alex Black */ +@Data public class ArgMin extends DynamicCustomOp { protected DataType outputType = DataType.LONG; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java index b873ca268..3ef7de264 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; import java.util.Collections; @@ -38,8 +39,14 @@ public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp { protected int classesDim; - public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) { - super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false); +// public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) { +// super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false); +// this.classesDim = classesDim; +// addIArgument(classesDim); +// } + + public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable labels, int classesDim) { + super(null, sameDiff, new SDVariable[]{logits, labels}, false); this.classesDim = classesDim; addIArgument(classesDim); } @@ -66,7 +73,8 @@ public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp { public List doDiff(List grad){ //No external gradient //Args: logits, weigths, label - SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(2), arg(0), arg(1), classesDim); + SDVariable[] args = args(); + SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(0), arg(1), classesDim); return Arrays.asList(grads); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyLossBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyLossBp.java index 9224d0c15..8acf6f648 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyLossBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyLossBp.java @@ -20,14 +20,18 @@ import lombok.NoArgsConstructor; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.impl.loss.BaseLoss; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; +import java.util.Arrays; +import java.util.List; import java.util.Map; @@ -56,4 +60,12 @@ public class SoftmaxCrossEntropyLossBp extends BaseLossBp { public String opName() { return "softmax_cross_entropy_loss_grad"; } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3), + "Expected 2 or 3 input datatypes for %s, got %s", getClass(), inputDataTypes); + + return Arrays.asList(inputDataTypes.get(0), inputDataTypes.get(1), inputDataTypes.get(2)); //Same as predictions + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java index 68cd88f09..cd283ee47 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java @@ -19,8 +19,10 @@ package org.nd4j.linalg.api.ops.impl.loss.bp; import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import java.util.Arrays; import java.util.List; @@ -34,8 +36,8 @@ public class SoftmaxCrossEntropyWithLogitsLossBp extends DynamicCustomOp { protected int classesDim; - public SoftmaxCrossEntropyWithLogitsLossBp(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) { - super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false); + public SoftmaxCrossEntropyWithLogitsLossBp(SameDiff sameDiff, SDVariable logits, SDVariable labels, int classesDim) { + super(null, sameDiff, new SDVariable[]{logits, labels}, false); this.classesDim = classesDim; addIArgument(classesDim); } @@ -49,4 +51,9 @@ public class SoftmaxCrossEntropyWithLogitsLossBp extends DynamicCustomOp { public List doDiff(List grad){ throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported"); } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + return Arrays.asList(arg(0).dataType(), arg(1).dataType()); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/SufficientStatistics.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/SufficientStatistics.java index 4f2c3181e..569d5ad3c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/SufficientStatistics.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/SufficientStatistics.java @@ -16,9 +16,12 @@ package org.nd4j.linalg.api.ops.impl.reduce; +import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.*; @@ -30,12 +33,9 @@ import java.util.*; * * @author Alex Black */ +@NoArgsConstructor public class SufficientStatistics extends DynamicCustomOp { - public SufficientStatistics() { - } - - public SufficientStatistics(SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable axis, SDVariable shift) { super(null, sameDiff, argsNoNull(x, axis, shift), false); } @@ -48,14 +48,30 @@ public class SufficientStatistics extends DynamicCustomOp { } } + public SufficientStatistics(@NonNull INDArray x, @NonNull INDArray axes, INDArray shift) { + if (shift != null) + addInputArgument(x, axes, shift); + else + addInputArgument(x, axes); + } + + public SufficientStatistics(@NonNull INDArray x, @NonNull INDArray axes) { + this(x,axes,null); + } @Override public String opName() { return "sufficient_statistics"; } + @Override public List doDiff(List grad) { throw new UnsupportedOperationException("Backprop not yet implemented for op: " + getClass().getSimpleName()); } + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + // FIXME + return Arrays.asList(inputDataTypes.get(0), inputDataTypes.get(0),inputDataTypes.get(0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java index 008a065ef..eca14e9f4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java @@ -111,7 +111,7 @@ public class TensorMmul extends DynamicCustomOp { int[][] deletedAxes = new int[][]{ removeIndex(aAxes, sumAxes[0]), removeIndex(bAxes, sumAxes[1])}; - int[] gAxes = range(0, i_v1.get(0).getShape().length); + int[] gAxes = range(0, i_v1.get(0).eval().shape().length); int[][] firstAxes = new int[][]{ Arrays.copyOfRange(gAxes, deletedAxes[0].length, gAxes.length), deletedAxes[1] @@ -144,18 +144,20 @@ public class TensorMmul extends DynamicCustomOp { int[][] axes) { int validationLength = Math.min(axes[0].length, axes[1].length); + INDArray aArray = a.eval(); + INDArray bArray = b.eval(); for (int i = 0; i < validationLength; i++) { - if (a.getShape()[axes[0][i]] != b.getShape()[axes[1][i]]) + if (aArray.shape()[axes[0][i]] != bArray.shape()[axes[1][i]]) throw new IllegalArgumentException("Size of the given axes at each dimension must be the same size."); if (axes[0][i] < 0) - axes[0][i] += a.getShape().length; + axes[0][i] += aArray.shape().length; if (axes[1][i] < 0) - axes[1][i] += b.getShape().length; + axes[1][i] += bArray.shape().length; } List listA = new ArrayList<>(); - for (int i = 0; i < a.getShape().length; i++) { + for (int i = 0; i < aArray.shape().length; i++) { if (!Ints.contains(axes[0], i)) listA.add(i); } @@ -164,7 +166,7 @@ public class TensorMmul extends DynamicCustomOp { List listB = new ArrayList<>(); - for (int i = 0; i < b.getShape().length; i++) { + for (int i = 0; i < bArray.shape().length; i++) { if (!Ints.contains(axes[1], i)) listB.add(i); } @@ -172,9 +174,9 @@ public class TensorMmul extends DynamicCustomOp { int[] newAxesB = Ints.concat(axes[1], Ints.toArray(listB)); int n2 = 1; - int aLength = Math.min(a.getShape().length, axes[0].length); + int aLength = Math.min(aArray.shape().length, axes[0].length); for (int i = 0; i < aLength; i++) { - n2 *= a.getShape()[axes[0][i]]; + n2 *= aArray.shape()[axes[0][i]]; } //if listA and listB are empty these do not initialize. @@ -186,13 +188,13 @@ public class TensorMmul extends DynamicCustomOp { } else { oldShapeA = Longs.toArray(listA); for (int i = 0; i < oldShapeA.length; i++) - oldShapeA[i] = a.getShape()[(int) oldShapeA[i]]; + oldShapeA[i] = aArray.shape()[(int) oldShapeA[i]]; } int n3 = 1; - int bNax = Math.min(b.getShape().length, axes[1].length); + int bNax = Math.min(bArray.shape().length, axes[1].length); for (int i = 0; i < bNax; i++) { - n3 *= b.getShape()[axes[1][i]]; + n3 *= bArray.shape()[axes[1][i]]; } @@ -203,7 +205,7 @@ public class TensorMmul extends DynamicCustomOp { } else { oldShapeB = Longs.toArray(listB); for (int i = 0; i < oldShapeB.length; i++) - oldShapeB[i] = b.getShape()[(int) oldShapeB[i]]; + oldShapeB[i] = bArray.shape()[(int) oldShapeB[i]]; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/BaseReductionBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/BaseReductionBp.java index 179faa647..a66225c07 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/BaseReductionBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/BaseReductionBp.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.reduce.bp; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -30,7 +31,7 @@ import java.util.List; /** * @author Alex Black */ - +@NoArgsConstructor public abstract class BaseReductionBp extends DynamicCustomOp { protected boolean keepDims; @@ -96,7 +97,12 @@ public abstract class BaseReductionBp extends DynamicCustomOp { addArgs(); } - public BaseReductionBp(){} + public BaseReductionBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput, INDArray output1, INDArray output2, boolean keepDims, int... dimensions){ + super(null, new INDArray[]{origInput1, origInput2, gradAtOutput}, new INDArray[]{output1, output2}); + this.keepDims = keepDims; + this.dimensions = dimensions; + addArgs(); + } protected void addArgs(){ addTArgument(keepDims ? 1 : 0); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/DotBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/DotBp.java index feb0cb53c..88e78dd2a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/DotBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/DotBp.java @@ -16,17 +16,23 @@ package org.nd4j.linalg.api.ops.impl.reduce.bp; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import java.util.Arrays; +import java.util.List; + /** * Backprop op for Dot pairwise reduction operation * * @author Alex Black */ - +@NoArgsConstructor public class DotBp extends BaseReductionBp { public DotBp(SameDiff sameDiff, SDVariable origInput1, SDVariable origInput2, SDVariable gradAtOutput, boolean keepDims, int... dimensions) { @@ -37,10 +43,22 @@ public class DotBp extends BaseReductionBp { super(origInput1, origInput2, gradAtOutput, output, keepDims, dimensions); } - public DotBp(){} + public DotBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput, + INDArray outputX, INDArray outputY, boolean keepDims, int... dimensions) { + super(origInput1, origInput2, gradAtOutput, outputX, outputY, keepDims, dimensions); + } @Override public String opName() { return "reduce_dot_bp"; } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 input datatype for %s, got input %s", getClass(), dataTypes); + Preconditions.checkState(dataTypes.get(0).isFPType(), "First input must be a floating point type, got %s", dataTypes.get(0)); + Preconditions.checkState(dataTypes.get(1).isFPType(), "Second input (gradient at reduction output) must be a floating point type, got %s", dataTypes.get(1)); + Preconditions.checkState(dataTypes.get(2).isFPType(), "Second input (gradient at reduction output) must be a floating point type, got %s", dataTypes.get(2)); + return Arrays.asList(dataTypes.get(0), dataTypes.get(0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Bias.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Bias.java deleted file mode 100644 index 39d36ace2..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Bias.java +++ /dev/null @@ -1,86 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.reduce.floating; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseReduceFloatOp; -import org.nd4j.linalg.api.ops.BaseReduceOp; - -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - -/** - * Calculate a bias - * - * @author Adam Gibson - */ -public class Bias extends BaseReduceFloatOp { - - private double mean; - - public Bias(SameDiff sameDiff, SDVariable i_v, int[] dimensions, double mean) { - super(sameDiff, i_v, dimensions); - this.mean = mean; - } - - public Bias(SameDiff sameDiff, SDVariable i_v, SDVariable i_v2, int[] dimensions, double mean) { - super(sameDiff, i_v, i_v2, dimensions); - this.mean = mean; - } - - public Bias() {} - - public Bias(INDArray x, int... dimensions) { - super(x, dimensions); - } - - @Override - public Map propertiesForFunction() { - Map ret = new LinkedHashMap<>(); - ret.put("mean",mean); - return ret; - } - - @Override - public int opNum() { - return 2; - } - - @Override - public String opName() { - return "bias"; - } - - @Override - public List doDiff(List f1) { - return null; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java index 60da43dc0..3d1339c7e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java @@ -24,6 +24,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -45,6 +46,7 @@ public class SequenceMask extends DynamicCustomOp { public SequenceMask(SameDiff sameDiff, SDVariable input, SDVariable maxLen, DataType dataType) { super(null, sameDiff, new SDVariable[] {input, maxLen}, false); this.dataType = dataType; + addDArgument(dataType); } public SequenceMask(SameDiff sameDiff, SDVariable input, int maxLen, DataType dataType) { @@ -53,13 +55,23 @@ public class SequenceMask extends DynamicCustomOp { this.is_static_maxlen = true; addIArgument(maxLen); this.dataType = dataType; + addDArgument(dataType); } public SequenceMask(SameDiff sameDiff, SDVariable input, DataType dataType) { super(null, sameDiff, new SDVariable[] {input}, false); this.dataType = dataType; + addDArgument(dataType); } - + + public SequenceMask(INDArray input, int maxLen, DataType dataType) { + addInputArgument(input); + addIArgument(maxLen); + //addIArgument(dataType.toInt()); + addDArgument(dataType); + this.dataType = dataType; + } + @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ZerosLike.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ZerosLike.java index de411ae2a..7225ac355 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ZerosLike.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ZerosLike.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -39,23 +40,37 @@ import java.util.Map; * @author Adam Gibson */ @Slf4j +@NoArgsConstructor public class ZerosLike extends DynamicCustomOp { protected DataType outputType; //Allow customizing dtype for TF import - public ZerosLike() { + public ZerosLike(String name, SameDiff sameDiff, SDVariable input) { + this(name, sameDiff, input, false, input.dataType()); } - public ZerosLike(String name, SameDiff sameDiff, SDVariable input) { - this(name, sameDiff, input, false); + public ZerosLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) { + this(name, sameDiff, input, false, dataType); } public ZerosLike(String name, SameDiff sameDiff, SDVariable input, boolean inPlace) { + this(name, sameDiff, input, inPlace, input.dataType()); + } + + public ZerosLike(String name, SameDiff sameDiff, SDVariable input, boolean inPlace, DataType dataType) { super(name, sameDiff, new SDVariable[]{input}, inPlace); + addDArgument(dataType); } public ZerosLike(INDArray in, INDArray out){ + this(in, out, in.dataType()); + } + + public ZerosLike(INDArray in, INDArray out, DataType dataType) { super(null, in, out, null, null); + if (dataType != null) { + addDArgument(dataType); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java index 923f1e2e4..e580fbf8e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; import java.util.Collections; @@ -52,16 +53,13 @@ public class BatchToSpace extends DynamicCustomOp { } public BatchToSpace(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] crops, boolean inPlace) { - super(null, sameDiff, args, inPlace); + super(null, sameDiff, new SDVariable[]{args[0], sameDiff.constant(Nd4j.createFromArray(crops))}, inPlace); this.blocks = blocks; this.crops = crops; for (val b : blocks) addIArgument(b); - - for (int e = 0; e < crops.length; e++) - addIArgument(crops[e][0], crops[e][1]); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java index 12fe52854..5fa08133c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; import java.util.Collections; @@ -53,16 +54,12 @@ public class SpaceToBatch extends DynamicCustomOp { } public SpaceToBatch(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] padding, boolean inPlace) { - super(null, sameDiff, args, inPlace); + super(null, sameDiff, new SDVariable[]{args[0], sameDiff.constant(Nd4j.createFromArray(padding))}, inPlace); this.blocks = blocks; this.padding = padding; - for (val b : blocks) - addIArgument(b); - - for (int e = 0; e < padding.length; e++) - addIArgument(padding[e][0], padding[e][1]); + addIArgument(blocks[0]); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java index cd69d398f..0e5426c3c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java @@ -58,7 +58,8 @@ public class UnsortedSegmentMax extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3), + "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java index ffa2c9905..b0b7f4457 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.segment; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -31,6 +32,7 @@ import java.util.List; * * @author Alex Black */ +@NoArgsConstructor public class UnsortedSegmentMean extends DynamicCustomOp { private int numSegments; @@ -41,8 +43,6 @@ public class UnsortedSegmentMean extends DynamicCustomOp { addIArgument(numSegments); } - public UnsortedSegmentMean(){ } - @Override public String opName(){ return "unsorted_segment_mean"; @@ -60,7 +60,8 @@ public class UnsortedSegmentMean extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3), + "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java index 3e7a6e562..5b7e1c7e0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.segment; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -31,6 +32,7 @@ import java.util.List; * * @author Alex Black */ +@NoArgsConstructor public class UnsortedSegmentMin extends DynamicCustomOp { private int numSegments; @@ -41,8 +43,6 @@ public class UnsortedSegmentMin extends DynamicCustomOp { addIArgument(numSegments); } - public UnsortedSegmentMin(){ } - @Override public String opName(){ return "unsorted_segment_min"; @@ -60,7 +60,8 @@ public class UnsortedSegmentMin extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3), + "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java index 1a2bc2bac..bca9e1788 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.segment; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -31,6 +32,7 @@ import java.util.List; * * @author Alex Black */ +@NoArgsConstructor public class UnsortedSegmentProd extends DynamicCustomOp { private int numSegments; @@ -41,8 +43,6 @@ public class UnsortedSegmentProd extends DynamicCustomOp { addIArgument(numSegments); } - public UnsortedSegmentProd(){ } - @Override public String opName(){ return "unsorted_segment_prod"; @@ -60,7 +60,8 @@ public class UnsortedSegmentProd extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3), + "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java index e995ec427..77474855c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.segment; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.ArrayList; @@ -31,18 +33,23 @@ import java.util.List; * * @author Alex Black */ +@NoArgsConstructor public class UnsortedSegmentSqrtN extends DynamicCustomOp { private int numSegments; + public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) { + addInputArgument(data, segmentIds); + addIArgument(numSegments); + this.numSegments = numSegments; + } + public UnsortedSegmentSqrtN(SameDiff sameDiff, SDVariable data, SDVariable segmentIds, int numSegments) { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); this.numSegments = numSegments; addIArgument(numSegments); } - public UnsortedSegmentSqrtN(){ } - @Override public String opName(){ return "unsorted_segment_sqrt_n"; @@ -60,7 +67,8 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3), + "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); List out = new ArrayList<>(); for( int i=0; i calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3), + "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); //TODO Allow customizing output type return Collections.singletonList(Nd4j.defaultFloatingPointType()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index fde6190a4..9abe0a483 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -50,6 +50,11 @@ public class LayerOpValidation extends BaseOpValidation { super(backend); } + @Override + public long getTimeoutMilliseconds() { + return 90000L; + } + @Test public void testXwPlusB() { Nd4j.getRandom().setSeed(12345); @@ -319,7 +324,7 @@ public class LayerOpValidation extends BaseOpValidation { @Test public void testIm2Col() { - OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/deeplearning4j/deeplearning4j/issues/6873 + //OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/deeplearning4j/deeplearning4j/issues/6873 Nd4j.getRandom().setSeed(12345); int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}}; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index 59932d670..1f23e12ec 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -32,6 +32,9 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.custom.*; +import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd; +import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad; import org.nd4j.linalg.api.ops.impl.controlflow.compat.StopGradient; import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.api.ops.impl.shape.DiagPart; @@ -513,7 +516,7 @@ public class MiscOpValidation extends BaseOpValidation { @Test public void testTrace(){ //TODO need to work out how to handle shape_op for scalars... - OpValidationSuite.ignoreFailing(); + //OpValidationSuite.ignoreFailing(); Nd4j.getRandom().setSeed(12345); for( int[] inShape : new int[][]{{3,3}}){ @@ -546,12 +549,15 @@ public class MiscOpValidation extends BaseOpValidation { SDVariable x = sameDiff.var("x", arr); SDVariable y = sameDiff.var("y", arr2); SDVariable result = sameDiff.tensorMmul(x, y, new int[][]{{0}, {1}}); - assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}), result.getShape()); - assertEquals(32, sameDiff.numElements()); + assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}), + result.eval().shape()); + assertEquals(16, sameDiff.numElements()); SDVariable loss = sameDiff.standardDeviation(result, true); + sameDiff.addLossVariable(loss); String err = OpValidation.validate(new TestCase(sameDiff)); + assertNull(err); } @Test @@ -1782,4 +1788,338 @@ public class MiscOpValidation extends BaseOpValidation { assertEquals(exp, out); //Values in x not in y assertEquals(exp, outIdx); //Indices of the values in x not in y } + + @Test + public void testDivideNoNan() { + OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff() + + SameDiff sameDiff = SameDiff.create(); + + INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); + INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4); + + SDVariable input1 = sameDiff.var(in1); + SDVariable input2 = sameDiff.var(in2); + + INDArray expected = Nd4j.ones(3,4); + + SDVariable output = new DivideNoNan(sameDiff, input1, input2).outputVariable(); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testDigamma() { + + INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); + + INDArray expected = Nd4j.createFromArray(new double[]{ + -0.5772157,0.42278433,0.9227843,1.2561177,1.5061177,1.7061176,1.8727844,2.0156415,2.1406415,2.2517526,2.3517525,2.4426618 + }).reshape(3,4); + + val tc = new OpTestCase(new Digamma(in1)).expectedOutput(0, expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testFlatten() { + + SameDiff sameDiff = SameDiff.create(); + + INDArray x = Nd4j.linspace(DataType.DOUBLE, 1, 27, 1).reshape(3,3,3); + SDVariable sdx = sameDiff.var(x); + + INDArray expected = Nd4j.linspace(DataType.DOUBLE,1,27,1); + + SDVariable output = new Flatten(sameDiff, 'c', sdx).outputVariable(); + SDVariable loss = sameDiff.standardDeviation(sdx, true); + sameDiff.addLossVariable(loss); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testFusedBatchNorm() { + OpValidationSuite.ignoreFailing(); + SameDiff sameDiff = SameDiff.create(); + + INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*2*3*4).reshape(2,2,3,4); + INDArray scale = Nd4j.create(DataType.DOUBLE, 4); + scale.assign(0.5); + INDArray offset = Nd4j.create(DataType.DOUBLE, 4); + offset.assign(2.0); + + SDVariable input1 = sameDiff.var(x); + SDVariable input2 = sameDiff.var(scale); + SDVariable input3 = sameDiff.var(offset); + + INDArray expectedY = Nd4j.createFromArray(new double[]{ + 985.5258, 985.5258, 985.5258, 985.5258, + 659.7321, 659.7321, 659.7321, 659.7321, + 399.0972, 399.0972, 399.0972, 399.0972, + 203.6210, 203.6210, 203.6210, 203.6210, + 73.3036, 73.3036, 73.3036, 73.3036, + 8.1448, 8.1448, 8.1448, 8.1448, + 8.1448, 8.1448, 8.1448, 8.1448, + 73.3036, 73.3036, 73.3036, 73.3036, + 203.6210, 203.6210, 203.6210, 203.6210, + 399.0972, 399.0972, 399.0972, 399.0972, + 659.7321, 659.7321, 659.7321, 659.7321, + 985.5258, 985.5258, 985.5258, 985.5258}).reshape(x.shape()); + INDArray expectedBatchMean = Nd4j.createFromArray(new double[]{23., 24., 25., 26.}); + INDArray expectedBatchVar = Nd4j.createFromArray(new double[]{208.00001526, 208.00001526, 208.00001526, 208.00001526}); + + SDVariable[] outputs = new FusedBatchNorm(sameDiff, input1, input2, input3, 0, 1).outputVariables(); + SDVariable loss = sameDiff.standardDeviation(input1, true); + sameDiff.addLossVariable(loss); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(outputs[0].name(), expectedY) + .expectedOutput(outputs[1].name(), expectedBatchMean) + .expectedOutput(outputs[2].name(), expectedBatchVar); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testIgamma() { + + INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); + INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4); + + INDArray expected = Nd4j.createFromArray(new double[]{ + 0.63212055,0.59399414,0.5768099,0.56652874,0.5595013,0.5542634,0.5501591,0.5463888,0.54329145,0.54048204,0.5378594,0.53233755 + }).reshape(3,4); + + val tc = new OpTestCase(new Igamma(in1, in2)).expectedOutput(0, expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testIgammaC() { + + INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); + INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4); + + + INDArray expected = Nd4j.createFromArray(new double[]{ + 0.36787945,0.40600586,0.42319012,0.43347126,0.4404987,0.44573656,0.4498409,0.45361117,0.45670855,0.459518,0.46214062,0.46766248 + }).reshape(3,4); + + val tc = new OpTestCase(new Igammac(in1, in2)).expectedOutput(0, expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testLgamma() { + + SameDiff sameDiff = SameDiff.create(); + + INDArray in = Nd4j.linspace(DataType.DOUBLE, 1, 12, 1).reshape(3, 4); + SDVariable sdInput = sameDiff.var(in); + + INDArray expected = Nd4j.createFromArray(new double[]{ + 0.0,0.0,0.6931472,1.7917595,3.1780539,4.787492,6.5792513,8.525162,10.604603,12.801827,15.104413,17.502308 + }).reshape(3,4); + + SDVariable output = new Lgamma(sameDiff, sdInput).outputVariable(); + + SDVariable loss = sameDiff.standardDeviation(sdInput, true); + sameDiff.addLossVariable(loss); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testLu() { + + SameDiff sameDiff = SameDiff.create(); + + INDArray in1 = Nd4j.createFromArray(new double[]{ + 1., 2., 3., 0., 2., 3., 0., 0., 7. + }).reshape(3,3); + + SDVariable input1 = sameDiff.var(in1); + + INDArray expected = Nd4j.createFromArray(new double[]{ + 1., 2., 3., 0., 2., 3., 0., 0., 7 + }).reshape(3,3); + + INDArray pexpected = Nd4j.createFromArray(new int[]{ + 0, 1, 2 + }); + + sameDiff.loss.l2Loss(input1); + SDVariable[] output = new Lu(sameDiff, input1).outputVariables(); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output[0].name(), expected) + .expectedOutput(output[1].name(), pexpected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testMatrixBandPart() { + OpValidationSuite.ignoreFailing(); + SameDiff sameDiff = SameDiff.create(); + + INDArray input = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f, + 0.7271f,0.1804f,0.5056f,0.8925f, + 0.5461f,0.9234f,0.0856f,0.7938f}).reshape(3,4); + + SDVariable sdInput = sameDiff.var(input); + SDVariable sdInput1 = sameDiff.constant(1); + SDVariable sdInput2 = sameDiff.constant(-1); + + INDArray expected = Nd4j.createFromArray(new float[]{ + 0.7788f, 0.8012f, 0.7244f, 0.2309f, + 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.f, 0.9234f, 0.0856f, 0.7938f + }).reshape(3,4); + + sameDiff.loss.l2Loss(sdInput); + SDVariable output = new MatrixBandPart(sameDiff, sdInput, 1, -1).outputVariable(); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testPolygamma() { + + INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4); + INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4); + + INDArray expected = Nd4j.createFromArray(new double[]{ + 1.644934,-0.4041138,0.1189394,-0.03750069,0.01226151,-0.0041002957,0.001392272,-4.780109E-4,1.6549716E-4,-5.7675967E-5,2.0206635E-5,-7.1101636E-6 + }).reshape(3,4); + + val tc = new OpTestCase(new Polygamma(in1, in2)).expectedOutput(0, expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testTriangularSolve() { + + INDArray a = Nd4j.createFromArray(new float[]{ + 3.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, + 1.f, 1.f, 1.f, 1.f + }).reshape(4,4); + + INDArray b = Nd4j.createFromArray(new float[]{ + 4.f, 2.f, 4.f, 2.f + }).reshape(4,1); + + INDArray expected = Nd4j.createFromArray(new float[]{ + 1.333333f, 2.0f, 4.0f, 2.0f + }).reshape(4,1); + + val tc = new OpTestCase(new TriangularSolve(a, b, false, true)).expectedOutput(0, expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testBiasAdd() { + + SameDiff sameDiff = SameDiff.create(); + + INDArray in1 = Nd4j.linspace(1, 12, 12); + INDArray in2 = Nd4j.linspace(1, 12, 12); + + SDVariable input1 = sameDiff.var(in1); + SDVariable input2 = sameDiff.var(in2); + + INDArray expected = Nd4j.createFromArray(new double[]{ + 2.0000, 4.0000, 6.0000, 8.0000, 10.0000, 12.0000, 14.0000, 16.0000, 18.0000, 20.0000, 22.0000, 24.0000 + }); + + SDVariable output = new BiasAdd(sameDiff, input1, input2, false).outputVariable(); + SDVariable loss = sameDiff.standardDeviation(input1, true); + sameDiff.addLossVariable(loss); + SDVariable loss2 = sameDiff.standardDeviation(input2, true); + sameDiff.addLossVariable(loss2); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testBiasAddGrad() { + + SameDiff sameDiff = SameDiff.create(); + + INDArray x = Nd4j.linspace(DataType.FLOAT,1, 24, 24).reshape(2,2,2,3); + INDArray grad = Nd4j.linspace(DataType.FLOAT, 0.1, 0.1, 24).reshape(2,2,2,3); + + INDArray bias = Nd4j.createFromArray(new float[]{-1.f, -2.f, -3.f}); + + INDArray expected = Nd4j.createFromArray(new float[]{9.2f, 10.f , 10.8f}); + + OpTestCase tc = new OpTestCase(new BiasAddGrad(x, bias, grad,false)). + expectedOutput(0, grad). + expectedOutput(1, expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testRoll() { + + INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}). + reshape(2,2,4,2); + + INDArray expected = Nd4j.createFromArray(new double[]{ 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, + 21.41, 21.42, 22.11, 22.12 + }).reshape(x.shape()); + + int shift = 6; + + val tc = new OpTestCase(new Roll(x,shift)).expectedOutput(0,expected); + String err = OpValidation.validate(tc); + + assertNull(err); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java index 6c6ba5b83..3027138a1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java @@ -35,16 +35,20 @@ import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax; import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin; +import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss; import org.nd4j.linalg.api.ops.impl.reduce.Moments; import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments; -import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean; +import org.nd4j.linalg.api.ops.impl.reduce.SufficientStatistics; +import org.nd4j.linalg.api.ops.impl.reduce.floating.*; import org.nd4j.linalg.api.ops.impl.reduce.same.ASum; import org.nd4j.linalg.api.ops.impl.reduce3.*; +import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.BooleanIndexing; +import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.primitives.Pair; @@ -96,7 +100,7 @@ public class ReductionOpValidation extends BaseOpValidation { @Test public void testZeroCount() { List allFailed = new ArrayList<>(); - for (int i = 0; i < 2; i++) { + for (int i = 0; i < 21; i++) { SameDiff sd = SameDiff.create(); INDArray ia; @@ -159,25 +163,25 @@ public class ReductionOpValidation extends BaseOpValidation { @Test public void testReductionGradientsSimple() { - OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES + //OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES //Test reductions: final and only function Nd4j.getRandom().setSeed(12345); List failed = new ArrayList<>(); - for (int i = 0; i < 21; i++) { SameDiff sd = SameDiff.create(); int nOut = 4; int minibatch = 10; - SDVariable input = sd.var("in", -1, nOut); + SDVariable input = sd.var("in", minibatch, nOut); INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100); long length = nOut * minibatch; SDVariable loss; String name; TestCase tc = new TestCase(sd); + boolean gradCheck = true; switch (i) { case 0: loss = sd.mean("loss", input); @@ -234,11 +238,13 @@ public class ReductionOpValidation extends BaseOpValidation { loss = sd.math().countNonZero("loss", input); name = "countNonZero"; tc.expectedOutput("loss", Nd4j.scalar(inputArr.length())); + gradCheck = false; //Long out, not floating point break; case 11: loss = sd.math().countZero("loss", input); name = "countZero"; - tc.expectedOutput("loss", Nd4j.scalar(0)); + tc.expectedOutput("loss", Nd4j.scalar(0L)); + gradCheck = false; //Long out, not floating point break; case 12: loss = sd.math().amax("loss", input); @@ -272,7 +278,7 @@ public class ReductionOpValidation extends BaseOpValidation { loss = sd.math().logSumExp("loss", input); INDArray expArr = Transforms.exp(inputArr); double sum = expArr.sumNumber().doubleValue(); - tc.expected("loss", Nd4j.create(new double[]{Math.log(sum)})); + tc.expected("loss", Nd4j.scalar(Math.log(sum))); break; case 18: inputArr = Nd4j.rand(minibatch, nOut); @@ -307,9 +313,15 @@ public class ReductionOpValidation extends BaseOpValidation { log.info("*** Starting test: " + msg); sd.associateArrayWithVariable(inputArr, input); - + if(gradCheck) { + sd.addLossVariable(loss); + } tc.testName(msg); + if(!gradCheck){ + tc.gradientCheck(false); + } + String error = OpValidation.validate(tc, true); if (error != null) failed.add(error); @@ -629,14 +641,14 @@ public class ReductionOpValidation extends BaseOpValidation { List failed = new ArrayList<>(); for (int[] reduceDims : new int[][]{{Integer.MAX_VALUE}, {0, 1, 2}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2}}) { - for (int i = 6; i < 7; i++) { + for (int i = 0; i < 7; i++) { SameDiff sd = SameDiff.create(); sd.setLogExecution(false); - SDVariable in = sd.var("in", -1, d1, d2); - SDVariable in2 = sd.var("in2", -1, d1, d2); + SDVariable in = sd.var("in", d1, d1, d2); + SDVariable in2 = sd.var("in2", d0, d1, d2); INDArray inArr = Nd4j.randn(new int[]{d0, d1, d2}).muli(100); INDArray in2Arr = Nd4j.randn(inArr.shape()).muli(100); @@ -645,40 +657,43 @@ public class ReductionOpValidation extends BaseOpValidation { SDVariable reduced; String name; TestCase tc = new TestCase(sd); + Double maxRelError = null; switch (i) { case 0: reduced = sd.math().manhattanDistance(in, in2, reduceDims); name = "manhattan"; - exp = Nd4j.getExecutioner().exec(new ManhattanDistance(inArr, in2Arr, null, true, false, reduceDims)); + exp = Nd4j.getExecutioner().exec(new ManhattanDistance(inArr, in2Arr, null, false, false, reduceDims)); break; case 1: reduced = sd.math().euclideanDistance(in, in2, reduceDims); name = "euclidean"; - exp = Nd4j.getExecutioner().exec(new EuclideanDistance(inArr, in2Arr, null, true, false, reduceDims)); + exp = Nd4j.getExecutioner().exec(new EuclideanDistance(inArr, in2Arr, null, false, false, reduceDims)); break; case 2: inArr.muli(1e-4); in2Arr.muli(1e-4); reduced = sd.math().cosineSimilarity(in, in2, reduceDims); name = "cosine"; - exp = Nd4j.getExecutioner().exec(new CosineSimilarity(inArr, in2Arr, null, true, false, reduceDims)); + exp = Nd4j.getExecutioner().exec(new CosineSimilarity(inArr, in2Arr, null, false, false, reduceDims)); + maxRelError = 1e-4; break; case 3: reduced = sd.math().cosineDistance(in, in2, reduceDims); name = "cosinedistance"; - exp = Nd4j.getExecutioner().exec(new CosineDistance(inArr, in2Arr, null, true, false, reduceDims)); + exp = Nd4j.getExecutioner().exec(new CosineDistance(inArr, in2Arr, null, false, false, reduceDims)); + maxRelError = 1e-4; break; case 4: reduced = sd.math().hammingDistance(in, in2, reduceDims); name = "hamming"; - exp = Nd4j.getExecutioner().exec(new HammingDistance(inArr, in2Arr, null, true, false, reduceDims)); + exp = Nd4j.getExecutioner().exec(new HammingDistance(inArr, in2Arr, null, false, false, reduceDims)); break; case 5: name = "jaccard"; reduced = sd.math().jaccardDistance(name, in, in2, reduceDims); inArr.divi(100).addi(0.1); in2Arr.divi(100).addi(0.1); - exp = Nd4j.getExecutioner().exec(new JaccardDistance(inArr, in2Arr, null, true, false, reduceDims)); + exp = Nd4j.getExecutioner().exec(new JaccardDistance(inArr, in2Arr, null, false, false, reduceDims)); if (OpValidationSuite.IGNORE_FAILING && reduceDims.length == 2) continue; @@ -708,6 +723,9 @@ public class ReductionOpValidation extends BaseOpValidation { tc.expected(reduced, exp); + if(maxRelError != null) + tc.gradCheckMaxRelativeError(maxRelError); + String error = OpValidation.validate(tc, true); if (error != null) { failed.add(msg + " - " + error); @@ -768,7 +786,6 @@ public class ReductionOpValidation extends BaseOpValidation { } @Test - @Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") public void testNormalizeMomentsOp() { INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10); INDArray ssSum = data.sum(0); @@ -780,7 +797,7 @@ public class ReductionOpValidation extends BaseOpValidation { INDArray mean = Nd4j.createUninitialized(DataType.DOUBLE, meanExp.shape()); INDArray var = Nd4j.createUninitialized(DataType.DOUBLE, varExp.shape()); - OpTestCase op = new OpTestCase(new NormalizeMoments(Nd4j.scalar(DataType.INT, 10), ssSum, ssSqSum, mean, var)); + OpTestCase op = new OpTestCase(new NormalizeMoments(Nd4j.scalar(DataType.DOUBLE, 10), ssSum, ssSqSum, mean, var)); op.expectedOutput(0, meanExp); op.expectedOutput(1, varExp); @@ -821,7 +838,7 @@ public class ReductionOpValidation extends BaseOpValidation { List failed = new ArrayList<>(); List dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1}, new int[0]); - INDArray in = Nd4j.rand(3, 4); + INDArray in = Nd4j.rand(DataType.DOUBLE,3, 4); for (int t = 0; t < 4; t++) { int[] d = dims.get(t); @@ -838,52 +855,47 @@ public class ReductionOpValidation extends BaseOpValidation { switch (i) { case 0: reduce = s.argmax(dim); - exp = Nd4j.argMax(in, dim).castTo(DataType.DOUBLE); + exp = Nd4j.argMax(in, dim); name = "argmax"; break; case 1: reduce = s.argmin(dim); - exp = Nd4j.argMin(in, dim).castTo(DataType.DOUBLE); + exp = Nd4j.argMin(in, dim); name = "argmin"; break; case 2: reduce = sd.math().iamax(s, dim); exp = Nd4j.getExecutioner().exec(new IAMax(in.dup(), dim)); - exp = exp.castTo(DataType.DOUBLE); name = "iamax"; break; case 3: reduce = sd.math().iamin(s, dim); exp = Nd4j.getExecutioner().exec(new IAMin(in.dup(), dim)); - exp = exp.castTo(DataType.DOUBLE); name = "iamin"; break; case 4: reduce = sd.math().firstIndex(s, Conditions.greaterThan(0), dim); - exp = in.sum(dim).assign(0); - exp = exp.castTo(DataType.DOUBLE); + exp = in.sum(dim).assign(0).castTo(DataType.INT64); name = "firstindex"; break; case 5: reduce = sd.math().lastIndex(s, Conditions.greaterThan(0), dim); - if (t == 0) exp = Nd4j.create(new double[]{2, 2, 2, 2}); - else if (t == 1) exp = Nd4j.create(new double[]{3, 3, 3}); - else exp = Nd4j.scalar(11.0); - exp = exp.castTo(DataType.DOUBLE); + if (t == 0) exp = Nd4j.createFromArray(2L, 2, 2, 2); + else if (t == 1) exp = Nd4j.createFromArray(3L, 3, 3); + else exp = Nd4j.scalar(11L); name = "lastindex"; break; case 6: reduce = sd.matchConditionCount("count", s, Conditions.greaterThan(0), false, dim); - if (t == 0) exp = Nd4j.create(new double[]{3, 3, 3, 3}); - else if (t == 1) exp = Nd4j.create(new double[]{4, 4, 4}); - else exp = Nd4j.scalar(12.0); - exp = exp.castTo(DataType.DOUBLE); + if (t == 0) exp = Nd4j.createFromArray(3L, 3, 3, 3); + else if (t == 1) exp = Nd4j.createFromArray(4L, 4, 4); + else exp = Nd4j.scalar(12L); name = "matchConditionCount"; break; default: throw new RuntimeException(); } - + SDVariable preCast = reduce; reduce = reduce.castTo(DataType.DOUBLE); SDVariable loss; @@ -894,7 +906,7 @@ public class ReductionOpValidation extends BaseOpValidation { } TestCase tc = new TestCase(sd) - .expected(reduce, exp) + .expected(preCast, exp) .gradientCheck(false) .testName(name + " - " + (dim == null ? null : Arrays.toString(dim))); @@ -1335,4 +1347,254 @@ public class ReductionOpValidation extends BaseOpValidation { } } } + @Test + public void testSufficientStatisticsOp() { + INDArray data = Nd4j.createFromArray(new double[]{ + 5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1., + 1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5 + }).reshape(2,2,2,4); + INDArray axes = Nd4j.linspace(DataType.LONG, 0, 3, 1); + + OpTestCase op = new OpTestCase(new SufficientStatistics(data, axes)); + + INDArray expected1 = Nd4j.scalar(8.0); + INDArray expected2 = Nd4j.createFromArray(new double[]{ + 30.2, 5., 7.8, 22.8 + }); + INDArray expected3 = Nd4j.createFromArray(new double[]{ + 154.22, 7., 14.34, 103.62 + }); + + op.expectedOutput(0, expected1); + op.expectedOutput(1, expected2); + op.expectedOutput(2, expected3); + + String err = OpValidation.validate(op); + assertNull(err); + } + + @Test + public void testStandardDeviation() { + + for (boolean keepDims : new boolean[]{false, true}) { + SameDiff sameDiff = SameDiff.create(); + + INDArray in = Nd4j.linspace(1, 8, 8).reshape(2, 4); + SDVariable input = sameDiff.var(in); + INDArray expected = Nd4j.createFromArray(new double[]{ + 2, 2, 2, 2 + }); + + if(keepDims){ + expected = expected.reshape(1,4); + } + + SDVariable output = new StandardDeviation(sameDiff, input, false, keepDims, new int[]{0}).outputVariable(); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + } + + @Test + public void testSquaredNorm() { + + for (boolean keepDims : new boolean[]{false, true}) { + SameDiff sameDiff = SameDiff.create(); + + INDArray in = Nd4j.linspace(1, 4, 4); + SDVariable input = sameDiff.var(in); + INDArray expected = Nd4j.scalar(30.0000); + if(keepDims) + expected = expected.reshape(1); + + SDVariable output = new SquaredNorm(sameDiff, input, keepDims, new int[]{0}).outputVariable(); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + } + + @Test + public void testShannonEntropy() { + OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695 + + SameDiff sameDiff = SameDiff.create(); + + INDArray in = Nd4j.linspace(1, 4, 4).castTo(DataType.DOUBLE); + SDVariable input = sameDiff.var(in); + INDArray expected = Nd4j.scalar(-69.68162); + + SDVariable output = new ShannonEntropy(sameDiff, input, new int[]{0}).outputVariable(); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testEntropy() { + + SameDiff sameDiff = SameDiff.create(); + + INDArray in = Nd4j.linspace(1, 4, 4); + SDVariable input = sameDiff.var(in); + double expected = -10.2273; + + SDVariable output = new Entropy(sameDiff, input, new int[]{0}).outputVariable(); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), Nd4j.scalar(expected)); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testAMean() { + + SameDiff sameDiff = SameDiff.create(); + + INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4); + SDVariable input = sameDiff.var(in); + INDArray expected = Nd4j.createFromArray(new double[]{ + 5.0000, 6.0000, 7.0000, 8.0000 + }); + + SDVariable output = new AMean(sameDiff, input, new int[]{0}).outputVariable(); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testMean() { + + SameDiff sameDiff = SameDiff.create(); + + INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4); + SDVariable input = sameDiff.var(in); + INDArray expected = Nd4j.createFromArray(new double[]{ + 5.0000, 6.0000, 7.0000, 8.0000 + }); + + SDVariable output = new Mean(sameDiff, input, false, new int[]{0}).outputVariable(); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testNorm1() { + + SameDiff sameDiff = SameDiff.create(); + + INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4); + SDVariable input = sameDiff.var(in); + INDArray expected = Nd4j.createFromArray(new double[]{ + 15.0000, 18.0000, 21.0000, 24.0000 + }); + + SDVariable output = new Norm1(sameDiff, input, false, new int[]{0}).outputVariable(); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testNorm2() { + + SameDiff sameDiff = SameDiff.create(); + + INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4); + SDVariable input = sameDiff.var(in); + INDArray expected = Nd4j.createFromArray(new double[]{ + 10.3441, 11.8322, 13.3791, 14.9666 + }); + + SDVariable output = new Norm2(sameDiff, input, false, new int[]{0}).outputVariable(); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testNormMax() { + + SameDiff sameDiff = SameDiff.create(); + + INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4); + SDVariable input = sameDiff.var(in); + INDArray expected = Nd4j.createFromArray(new double[]{ + 9.0000, 10.0000, 11.0000, 12.0000 + }); + + SDVariable output = new NormMax(sameDiff, input, false, new int[]{0}).outputVariable(); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } + + @Test + public void testSoftmaxCrossEntropyWithLogitsLoss() { + OpValidationSuite.ignoreFailing(); + + SameDiff sameDiff = SameDiff.create(); + + INDArray labels = Nd4j.createFromArray(new double[]{ + 0,1,1,0,0,0,1,0,1,0,1,1,1,0,1,0,1,0,0,1,1,0,1,0 + }).reshape(2,3,4); + + INDArray logits = Nd4j.linspace(DataType.DOUBLE, 0.1, 0.1, 24).reshape(2,3,4); + INDArray expected = Nd4j.createFromArray(new double[]{ + 0.26328, 1.46328, 1.72656, 0. , 0.26328, 0. , 1.46328, 0.26328, 1.72656, 0. , 1.72656, 1.46328 + }).reshape(3,4); + + SDVariable sdLogits = sameDiff.var("logits", logits); + SDVariable sdLabels = sameDiff.var("labels", labels); + SDVariable loss = sameDiff.math().abs(sdLogits); + + + SDVariable output = new SoftmaxCrossEntropyWithLogitsLoss(sameDiff, sdLogits, sdLabels, 0).outputVariable(); + sameDiff.setLossVariables(output); + + TestCase tc = new TestCase(sameDiff) + .gradientCheck(true) + .expectedOutput(output.name(), expected); + + String err = OpValidation.validate(tc); + assertNull(err); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 071551719..733628490 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -1016,7 +1016,7 @@ public class ShapeOpValidation extends BaseOpValidation { @Test public void testConstant(){ - OpValidationSuite.ignoreFailing(); + //OpValidationSuite.ignoreFailing(); //Case 0: no shape SameDiff sd = SameDiff.create(); @@ -1035,7 +1035,9 @@ public class ShapeOpValidation extends BaseOpValidation { INDArray exp = Nd4j.valueArrayOf(new long[]{3,4,5}, 3.0); loss = constant.std(true); - assertNull(OpValidation.validate(new TestCase(sd).expected(constant, ia))); + assertNull(OpValidation.validate(new TestCase(sd) + .gradientCheck(false) + .expected(constant, Nd4j.create(DataType.FLOAT, 3,4,5)))); } @@ -1272,7 +1274,7 @@ public class ShapeOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable data = sd.var("data", d); - SDVariable segments = sd.var("segments", s); + SDVariable segments = sd.constant("segments", s); SDVariable sm; INDArray exp; @@ -1326,6 +1328,7 @@ public class ShapeOpValidation extends BaseOpValidation { } SDVariable loss = sm.std(true); + sd.addLossVariable(loss); TestCase tc = new TestCase(sd) .testName(op) @@ -1363,17 +1366,19 @@ public class ShapeOpValidation extends BaseOpValidation { @Test public void testSequenceMask() { - OpValidationSuite.ignoreFailing(); //2018-01-09: output datatype issue? SameDiff sameDiff = SameDiff.create(); - INDArray arr = Nd4j.create(new float[] {1, 3, 2}).reshape(3); - SDVariable lengths = sameDiff.var("lengths", arr); + INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2}); + // arr is not trainable, so it's constant in model + SDVariable lengths = sameDiff.constant(arr); // Test with static max len int maxlen = 5; - INDArray expected = Nd4j.create(new float[] {1, 0, 0, 0, 0, - 1, 1, 1, 0, 0, - 1, 1, 0, 0, 0}, - new long[]{3, 5}); + INDArray expected = Nd4j.create(new float[] { + 1.f, 0.f, 0.f, 0.f, 0.f, + 1.f, 1.f, 1.f, 0.f, 0.f, + 1.f, 1.f, 0.f, 0.f, 0.f + }).reshape(3,5); + INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.FLOAT)); SDVariable result1 = sameDiff.sequenceMask(lengths, maxlen, DataType.FLOAT); assertArrayEquals(expected.shape(), result1.eval().shape()); assertEquals(expected, result1.eval()); @@ -1382,14 +1387,14 @@ public class ShapeOpValidation extends BaseOpValidation { String err = OpValidation.validate(new TestCase(sameDiff) .expected(result1, expected) - .gradCheckSkipVariables(lengths.name())); + .gradientCheck(false)); assertNull(err); // Test with dynamic maxlen - lengths = sameDiff.var("lengths2", arr); // required because of an internal samediff bug - SDVariable maxLen = sameDiff.var("maxLen", Nd4j.create(new float[]{5}).reshape(1)); + lengths = sameDiff.constant("lengths2", arr); + SDVariable maxLen = sameDiff.constant("maxLen", Nd4j.scalar(5)); SDVariable result2 = sameDiff.sequenceMask(lengths, maxLen, DataType.FLOAT); - assertArrayEquals(expected.shape(), result2.eval().shape()); +// assertArrayEquals(expected.shape(), result2.eval().shape()); assertEquals(expected, result2.eval()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index 154201a35..c855875cf 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -303,7 +303,7 @@ public class TransformOpValidation extends BaseOpValidation { @Test public void testBatchToSpace() { - OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863 + //OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863 Nd4j.getRandom().setSeed(1337); int miniBatch = 4; @@ -314,7 +314,6 @@ public class TransformOpValidation extends BaseOpValidation { int[] cropShape = new int[]{M, 2}; INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE); - INDArray blocks = Nd4j.create(new float[]{2, 2}, blockShape).castTo(DataType.INT); INDArray crops = Nd4j.create(new float[]{0, 0, 0, 0}, cropShape).castTo(DataType.INT); SameDiff sd = SameDiff.create(); @@ -323,7 +322,8 @@ public class TransformOpValidation extends BaseOpValidation { INDArray expOut = Nd4j.create(DataType.DOUBLE, 1, 2, 2, 1); DynamicCustomOp op = DynamicCustomOp.builder("batch_to_space") - .addInputs(input, blocks, crops) + .addInputs(input, crops) + .addIntegerArguments(2) .addOutputs(expOut).build(); Nd4j.getExecutioner().exec(op); @@ -340,7 +340,7 @@ public class TransformOpValidation extends BaseOpValidation { @Test public void testSpaceToBatch() { - OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863 + //OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863 Nd4j.getRandom().setSeed(7331); @@ -352,7 +352,6 @@ public class TransformOpValidation extends BaseOpValidation { int[] paddingShape = new int[]{M, 2}; INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE); - INDArray blocks = Nd4j.create(new float[]{2, 2}, blockShape).castTo(DataType.INT); INDArray padding = Nd4j.create(new float[]{0, 0, 0, 0}, paddingShape).castTo(DataType.INT); SameDiff sd = SameDiff.create(); @@ -361,7 +360,8 @@ public class TransformOpValidation extends BaseOpValidation { INDArray expOut = Nd4j.create(DataType.DOUBLE, miniBatch, 1, 1, 1); DynamicCustomOp op = DynamicCustomOp.builder("space_to_batch") - .addInputs(input, blocks, padding) + .addIntegerArguments(2) + .addInputs(input, padding) .addOutputs(expOut).build(); Nd4j.getExecutioner().exec(op); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index b8d795460..ae8f934e9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -37,6 +37,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.shape.Create; import org.nd4j.linalg.api.ops.impl.shape.OnesLike; +import org.nd4j.linalg.api.ops.impl.shape.SequenceMask; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp; @@ -1737,4 +1738,19 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, ret[0]); } + + @Test + public void testSequenceMask() { + INDArray arr = Nd4j.createFromArray(new int[]{1, 3, 2}); + // Test with static max len + int maxlen = 2; + INDArray expected = Nd4j.createFromArray(new int[]{ + 1,0,0, + 1,1,1, + 1,1,0 + }).reshape(3, 3); + + INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.INT32)); + assertEquals(expected, ret[0]); + } } diff --git a/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala b/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala index b84270817..02474f771 100644 --- a/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala +++ b/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala @@ -318,8 +318,8 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest => "num2Scalar" should "convert number to Scalar INDArray" in { - assert(1.toScalar.data() == List(1).toNDArray.data()) - assert(2f.toScalar.data() == List(2).toNDArray.data()) - assert(3d.toScalar.data() == List(3).toNDArray.data()) + assert(1.toScalar.reshape(1) == List(1).toNDArray) + assert(2f.toScalar.reshape(1) == List(2f).toNDArray) + assert(3d.toScalar.reshape(1) == List(3d).toNDArray) } }