From fec620fafa6383287bf8f17be7c364832ae177db Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Wed, 4 Mar 2020 04:46:32 +0400 Subject: [PATCH] TensorflowConversion Data Types (#284) * dtypes * bf16 and bool * tests --- .../conversion/TensorflowConversionTest.java | 46 +++++++++++++++---- .../conversion/TensorflowConversion.java | 28 ++++++++++- 2 files changed, 64 insertions(+), 10 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java index fbf4249bd..fea5a5aa8 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/TensorflowConversionTest.java @@ -16,9 +16,11 @@ package org.nd4j.tensorflow.conversion; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.junit.Test; import org.nd4j.BaseND4JTest; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; @@ -29,7 +31,9 @@ import static org.bytedeco.tensorflow.global.tensorflow.*; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; +import static org.nd4j.linalg.api.buffer.DataType.*; +@Slf4j public class TensorflowConversionTest extends BaseND4JTest { @Test @@ -53,15 +57,39 @@ public class TensorflowConversionTest extends BaseND4JTest { @Test public void testConversionFromNdArray() throws Exception { - INDArray arr = Nd4j.linspace(1,4,4); - TensorflowConversion tensorflowConversion =TensorflowConversion.getInstance(); - TF_Tensor tf_tensor = tensorflowConversion.tensorFromNDArray(arr); - INDArray fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor); - assertEquals(arr,fromTensor); - arr.addi(1.0); - tf_tensor = tensorflowConversion.tensorFromNDArray(arr); - fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor); - assertEquals(arr,fromTensor); + DataType[] dtypes = new DataType[]{ + DOUBLE, + FLOAT, + SHORT, + LONG, + BYTE, + UBYTE, + UINT16, + UINT32, + UINT64, + BFLOAT16, + BOOL, + INT, + HALF + }; + for(DataType dtype: dtypes){ + log.debug("Testing conversion for data type " + dtype); + INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2).castTo(dtype); + TensorflowConversion tensorflowConversion =TensorflowConversion.getInstance(); + TF_Tensor tf_tensor = tensorflowConversion.tensorFromNDArray(arr); + INDArray fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor); + assertEquals(arr,fromTensor); + if (dtype == BOOL){ + arr.putScalar(3, 0); + } + else{ + arr.addi(1.0); + } + tf_tensor = tensorflowConversion.tensorFromNDArray(arr); + fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor); + assertEquals(arr,fromTensor); + } + } diff --git a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java index c79ec4fd3..5204a00c1 100644 --- a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java +++ b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java @@ -121,8 +121,16 @@ public class TensorflowConversion { default: throw new IllegalArgumentException("Unsupported compression algorithm: " + algo); } break; + case SHORT: type = DT_INT16; break; case LONG: type = DT_INT64; break; case UTF8: type = DT_STRING; break; + case BYTE: type = DT_INT8; break; + case UBYTE: type = DT_UINT8; break; + case UINT16: type = DT_UINT16; break; + case UINT32: type = DT_UINT32; break; + case UINT64: type = DT_UINT64; break; + case BFLOAT16: type = DT_BFLOAT16; break; + case BOOL: type = DT_BOOL; break; default: throw new IllegalArgumentException("Unsupported data type: " + dataType); } @@ -250,6 +258,15 @@ public class TensorflowConversion { case FLOAT: return FloatIndexer.create(new FloatPointer(pointer)); case INT: return IntIndexer.create(new IntPointer(pointer)); case LONG: return LongIndexer.create(new LongPointer(pointer)); + case SHORT: return ShortIndexer.create(new ShortPointer(pointer)); + case BYTE: return ByteIndexer.create(new BytePointer(pointer)); + case UBYTE: return UByteIndexer.create(new BytePointer(pointer)); + case UINT16: return UShortIndexer.create(new ShortPointer(pointer)); + case UINT32: return UIntIndexer.create(new IntPointer(pointer)); + case UINT64: return ULongIndexer.create(new LongPointer(pointer)); + case BFLOAT16: return Bfloat16Indexer.create(new ShortPointer(pointer)); + case HALF: return HalfIndexer.create(new ShortPointer(pointer)); + case BOOL: return BooleanIndexer.create(new BooleanPointer(pointer)); default: throw new IllegalArgumentException("Illegal type " + type); } } @@ -258,9 +275,18 @@ public class TensorflowConversion { switch(tensorflowType) { case DT_DOUBLE: return DataType.DOUBLE; case DT_FLOAT: return DataType.FLOAT; - case DT_INT32: return DataType.LONG; + case DT_HALF: return DataType.HALF; + case DT_INT16: return DataType.SHORT; + case DT_INT32: return DataType.INT; case DT_INT64: return DataType.LONG; case DT_STRING: return DataType.UTF8; + case DT_INT8: return DataType.BYTE; + case DT_UINT8: return DataType.UBYTE; + case DT_UINT16: return DataType.UINT16; + case DT_UINT32: return DataType.UINT32; + case DT_UINT64: return DataType.UINT64; + case DT_BFLOAT16: return DataType.BFLOAT16; + case DT_BOOL: return DataType.BOOL; default: throw new IllegalArgumentException("Illegal type " + tensorflowType); } }