diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java
index d2518eeb6..0bfddeb32 100644
--- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java
+++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java
@@ -16,6 +16,7 @@
package org.datavec.image.loader;
+import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.image.data.Image;
import org.datavec.image.transform.ImageTransform;
@@ -35,10 +36,9 @@ import java.util.Random;
/**
* Created by nyghtowl on 12/17/15.
*/
+@Slf4j
public abstract class BaseImageLoader implements Serializable {
- protected static final Logger log = LoggerFactory.getLogger(BaseImageLoader.class);
-
public enum MultiPageMode {
MINIBATCH, FIRST //, CHANNELS,
}
@@ -62,13 +62,37 @@ public abstract class BaseImageLoader implements Serializable {
public abstract INDArray asRowVector(InputStream inputStream) throws IOException;
+ /** As per {@link #asMatrix(File, boolean)} but always in NCHW/channels_first format */
public abstract INDArray asMatrix(File f) throws IOException;
+ /**
+ * Load an image from a file to an INDArray
+ * @param f File to load the image from
+ * @param nchw If true: return image in NCHW/channels_first [1, channels, height, width] format; if false, return
+ * in NHWC/channels_last [1, height, width, channels] format
+ * @return Image file as an INDArray
+ */
+ public abstract INDArray asMatrix(File f, boolean nchw) throws IOException;
+
public abstract INDArray asMatrix(InputStream inputStream) throws IOException;
+ /**
+ * Load an image file from an input stream to an INDArray
+ * @param inputStream Input stream to load the image from
+ * @param nchw If true: return image in NCHW/channels_first [1, channels, height, width] format; if false, return
+ * in NHWC/channels_last [1, height, width, channels] format
+ * @return Image file stream as an INDArray
+ */
+ public abstract INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException;
+ /** As per {@link #asMatrix(File)} but as an {@link Image}*/
public abstract Image asImageMatrix(File f) throws IOException;
+ /** As per {@link #asMatrix(File, boolean)} but as an {@link Image}*/
+ public abstract Image asImageMatrix(File f, boolean nchw) throws IOException;
+ /** As per {@link #asMatrix(InputStream)} but as an {@link Image}*/
public abstract Image asImageMatrix(InputStream inputStream) throws IOException;
+ /** As per {@link #asMatrix(InputStream, boolean)} but as an {@link Image}*/
+ public abstract Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException;
public static void downloadAndUntar(Map urlMap, File fullDir) {
diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java
index 3d390c698..e513ebed3 100644
--- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java
+++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java
@@ -16,6 +16,7 @@
package org.datavec.image.loader;
+import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.bytedeco.javacv.OpenCVFrameConverter;
@@ -47,6 +48,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*;
* There is a special preProcessor used to normalize the dataset based on Sergey Zagoruyko example
* https://github.com/szagoruyko/cifar.torch
*/
+@Slf4j
public class CifarLoader extends NativeImageLoader implements Serializable {
public static final int NUM_TRAIN_IMAGES = 50000;
public static final int NUM_TEST_IMAGES = 10000;
diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java
index d246c65ad..9c2c61d57 100644
--- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java
+++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java
@@ -249,7 +249,14 @@ public class ImageLoader extends BaseImageLoader {
* @throws IOException
*/
public INDArray asMatrix(File f) throws IOException {
- return NDArrayUtil.toNDArray(fromFile(f));
+ return asMatrix(f, true);
+ }
+
+ @Override
+ public INDArray asMatrix(File f, boolean nchw) throws IOException {
+ try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
+ return asMatrix(is, nchw);
+ }
}
/**
@@ -259,34 +266,68 @@ public class ImageLoader extends BaseImageLoader {
* @return the input stream to convert
*/
public INDArray asMatrix(InputStream inputStream) throws IOException {
- if (channels == 3)
- return toBgr(inputStream);
- try {
- BufferedImage image = ImageIO.read(inputStream);
- return asMatrix(image);
- } catch (IOException e) {
- throw new IOException("Unable to load image", e);
+ return asMatrix(inputStream, true);
+ }
+
+ @Override
+ public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
+ INDArray ret;
+ if (channels == 3) {
+ ret = toBgr(inputStream);
+ } else {
+ try {
+ BufferedImage image = ImageIO.read(inputStream);
+ ret = asMatrix(image);
+ } catch (IOException e) {
+ throw new IOException("Unable to load image", e);
+ }
}
+ if(ret.rank() == 3){
+ ret = ret.reshape(1, ret.size(0), ret.size(1), ret.size(2));
+ }
+ if(!nchw)
+ ret = ret.permute(0,2,3,1); //NCHW to NHWC
+ return ret;
}
@Override
public org.datavec.image.data.Image asImageMatrix(File f) throws IOException {
+ return asImageMatrix(f, true);
+ }
+
+ @Override
+ public org.datavec.image.data.Image asImageMatrix(File f, boolean nchw) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
- return asImageMatrix(bis);
+ return asImageMatrix(bis, nchw);
}
}
@Override
public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throws IOException {
- if (channels == 3)
- return toBgrImage(inputStream);
- try {
- BufferedImage image = ImageIO.read(inputStream);
- INDArray asMatrix = asMatrix(image);
- return new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth());
- } catch (IOException e) {
- throw new IOException("Unable to load image", e);
+ return asImageMatrix(inputStream, true);
+ }
+
+ @Override
+ public org.datavec.image.data.Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
+ org.datavec.image.data.Image ret;
+ if (channels == 3) {
+ ret = toBgrImage(inputStream);
+ } else {
+ try {
+ BufferedImage image = ImageIO.read(inputStream);
+ INDArray asMatrix = asMatrix(image);
+ ret = new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth());
+ } catch (IOException e) {
+ throw new IOException("Unable to load image", e);
+ }
}
+ if(ret.getImage().rank() == 3){
+ INDArray a = ret.getImage();
+ ret.setImage(a.reshape(1, a.size(0), a.size(1), a.size(2)));
+ }
+ if(!nchw)
+ ret.setImage(ret.getImage().permute(0,2,3,1)); //NCHW to NHWC
+ return ret;
}
/**
diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java
index d28c73318..b71c53e42 100644
--- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java
+++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java
@@ -17,6 +17,7 @@
package org.datavec.image.loader;
+import lombok.extern.slf4j.Slf4j;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.PathLabelGenerator;
import org.datavec.api.io.labels.PatternPathLabelGenerator;
@@ -48,6 +49,7 @@ import java.util.Random;
* most images are in color, although a few are grayscale
*
*/
+@Slf4j
public class LFWLoader extends BaseImageLoader implements Serializable {
public final static int NUM_IMAGES = 13233;
@@ -270,19 +272,39 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
throw new UnsupportedOperationException();
}
+ @Override
+ public INDArray asMatrix(File f, boolean nchw) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public INDArray asMatrix(InputStream inputStream) throws IOException {
throw new UnsupportedOperationException();
}
+ @Override
+ public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public Image asImageMatrix(File f) throws IOException {
throw new UnsupportedOperationException();
}
+ @Override
+ public Image asImageMatrix(File f, boolean nchw) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public Image asImageMatrix(InputStream inputStream) throws IOException {
throw new UnsupportedOperationException();
}
+ @Override
+ public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
}
diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java
index 88bc161f2..ae9e2a322 100644
--- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java
+++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java
@@ -248,17 +248,27 @@ public class NativeImageLoader extends BaseImageLoader {
@Override
public INDArray asMatrix(File f) throws IOException {
+ return asMatrix(f, true);
+ }
+
+ @Override
+ public INDArray asMatrix(File f, boolean nchw) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
- return asMatrix(bis);
+ return asMatrix(bis, nchw);
}
}
@Override
public INDArray asMatrix(InputStream is) throws IOException {
- Mat mat = streamToMat(is);
+ return asMatrix(is, true);
+ }
+
+ @Override
+ public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
+ Mat mat = streamToMat(inputStream);
INDArray a;
if (this.multiPageMode != null) {
- a = asMatrix(mat.data(), mat.cols());
+ a = asMatrix(mat.data(), mat.cols());
}else{
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (image == null || image.empty()) {
@@ -272,7 +282,11 @@ public class NativeImageLoader extends BaseImageLoader {
a = asMatrix(image);
image.deallocate();
}
- return a;
+ if(nchw) {
+ return a;
+ } else {
+ return a.permute(0, 2, 3, 1); //NCHW to NHWC
+ }
}
/**
@@ -331,19 +345,29 @@ public class NativeImageLoader extends BaseImageLoader {
}
public Image asImageMatrix(String filename) throws IOException {
- return asImageMatrix(filename);
+ return asImageMatrix(new File(filename));
}
@Override
public Image asImageMatrix(File f) throws IOException {
+ return asImageMatrix(f, true);
+ }
+
+ @Override
+ public Image asImageMatrix(File f, boolean nchw) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
- return asImageMatrix(bis);
+ return asImageMatrix(bis, nchw);
}
}
@Override
public Image asImageMatrix(InputStream is) throws IOException {
- Mat mat = streamToMat(is);
+ return asImageMatrix(is, true);
+ }
+
+ @Override
+ public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
+ Mat mat = streamToMat(inputStream);
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (image == null || image.empty()) {
PIX pix = pixReadMem(mat.data(), mat.cols());
@@ -354,6 +378,8 @@ public class NativeImageLoader extends BaseImageLoader {
pixDestroy(pix);
}
INDArray a = asMatrix(image);
+ if(!nchw)
+ a = a.permute(0,2,3,1); //NCHW to NHWC
Image i = new Image(a, image.channels(), image.rows(), image.cols());
image.deallocate();
diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java
index 1683980f0..a82f12409 100644
--- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java
+++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java
@@ -16,10 +16,16 @@
package org.datavec.image.loader;
+import org.datavec.image.data.Image;
import org.junit.Test;
+import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.awt.image.BufferedImage;
+import java.io.BufferedInputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.InputStream;
import java.util.Random;
import static org.junit.Assert.assertEquals;
@@ -208,4 +214,57 @@ public class TestImageLoader {
private BufferedImage makeRandomBufferedImage(boolean alpha) {
return makeRandomBufferedImage(alpha, rng.nextInt() % 100 + 100, rng.nextInt() % 100 + 100);
}
+
+
+ @Test
+ public void testNCHW_NHWC() throws Exception {
+ File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg");
+
+ ImageLoader il = new ImageLoader(32, 32, 3);
+
+ //asMatrix(File, boolean)
+ INDArray a_nchw = il.asMatrix(f);
+ INDArray a_nchw2 = il.asMatrix(f, true);
+ INDArray a_nhwc = il.asMatrix(f, false);
+
+ assertEquals(a_nchw, a_nchw2);
+ assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
+
+
+ //asMatrix(InputStream, boolean)
+ try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
+ a_nchw = il.asMatrix(is);
+ }
+ try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
+ a_nchw2 = il.asMatrix(is, true);
+ }
+ try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
+ a_nhwc = il.asMatrix(is, false);
+ }
+ assertEquals(a_nchw, a_nchw2);
+ assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
+
+
+ //asImageMatrix(File, boolean)
+ Image i_nchw = il.asImageMatrix(f);
+ Image i_nchw2 = il.asImageMatrix(f, true);
+ Image i_nhwc = il.asImageMatrix(f, false);
+
+ assertEquals(i_nchw.getImage(), i_nchw2.getImage());
+ assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
+
+
+ //asImageMatrix(InputStream, boolean)
+ try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
+ i_nchw = il.asImageMatrix(is);
+ }
+ try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
+ i_nchw2 = il.asImageMatrix(is, true);
+ }
+ try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
+ i_nhwc = il.asImageMatrix(is, false);
+ }
+ assertEquals(i_nchw.getImage(), i_nchw2.getImage());
+ assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
+ }
}
diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java
index 6e7705569..68e93107c 100644
--- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java
+++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java
@@ -24,20 +24,19 @@ import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.Java2DFrameConverter;
import org.bytedeco.javacv.OpenCVFrameConverter;
+import org.datavec.image.data.Image;
import org.datavec.image.data.ImageWritable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
+import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
import java.awt.image.BufferedImage;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.InputStream;
-import java.io.IOException;
+import java.io.*;
import java.lang.reflect.Field;
import java.util.Random;
@@ -604,4 +603,56 @@ public class TestNativeImageLoader {
}
}
+ @Test
+ public void testNCHW_NHWC() throws Exception {
+ File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg");
+
+ NativeImageLoader il = new NativeImageLoader(32, 32, 3);
+
+ //asMatrix(File, boolean)
+ INDArray a_nchw = il.asMatrix(f);
+ INDArray a_nchw2 = il.asMatrix(f, true);
+ INDArray a_nhwc = il.asMatrix(f, false);
+
+ assertEquals(a_nchw, a_nchw2);
+ assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
+
+
+ //asMatrix(InputStream, boolean)
+ try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
+ a_nchw = il.asMatrix(is);
+ }
+ try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
+ a_nchw2 = il.asMatrix(is, true);
+ }
+ try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
+ a_nhwc = il.asMatrix(is, false);
+ }
+ assertEquals(a_nchw, a_nchw2);
+ assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
+
+
+ //asImageMatrix(File, boolean)
+ Image i_nchw = il.asImageMatrix(f);
+ Image i_nchw2 = il.asImageMatrix(f, true);
+ Image i_nhwc = il.asImageMatrix(f, false);
+
+ assertEquals(i_nchw.getImage(), i_nchw2.getImage());
+ assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
+
+
+ //asImageMatrix(InputStream, boolean)
+ try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
+ i_nchw = il.asImageMatrix(is);
+ }
+ try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
+ i_nchw2 = il.asImageMatrix(is, true);
+ }
+ try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
+ i_nhwc = il.asImageMatrix(is, false);
+ }
+ assertEquals(i_nchw.getImage(), i_nchw2.getImage());
+ assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
+ }
+
}
diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java
index fdcbe959a..26cd83f06 100644
--- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java
+++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java
@@ -474,16 +474,18 @@ public class TestImageRecordReader {
public void testNCHW_NCHW() throws Exception {
//Idea: labels order should be consistent regardless of input file order
File f0 = testDir.newFolder();
- Resources.copyDirectory("datavec-data-image/testimages/", f0);
+ new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
- FileSplit fs = new FileSplit(f0, new Random(12345));
- assertEquals(6, fs.locations().length);
+ FileSplit fs0 = new FileSplit(f0, new Random(12345));
+ FileSplit fs1 = new FileSplit(f0, new Random(12345));
+ assertEquals(6, fs0.locations().length);
+ assertEquals(6, fs1.locations().length);
ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true);
- nchw.initialize(fs);
+ nchw.initialize(fs0);
ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false);
- nhwc.initialize(fs);
+ nhwc.initialize(fs1);
while(nchw.hasNext()){
assertTrue(nhwc.hasNext());
@@ -533,7 +535,7 @@ public class TestImageRecordReader {
//Test record(URI, DataInputStream)
- URI u = fs.locations()[0];
+ URI u = fs0.locations()[0];
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
List l = nchw.record(u, dis);