From 9cc8803b8dbfc1b1c17dbf4a8ea452e4eece222e Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 4 Dec 2019 22:52:06 +1100 Subject: [PATCH] DL4J + Keras import: Causal Conv1D support (#107) * Keras causal conv1d support first steps Signed-off-by: AlexDBlack * Add tests Signed-off-by: AlexDBlack * Causal conv mode Signed-off-by: AlexDBlack * Gradient check and fixes for causal conv1d Signed-off-by: AlexDBlack * Fix Conv1D import and testing Signed-off-by: AlexDBlack * Cleanup Signed-off-by: AlexDBlack * Small keras test fix Signed-off-by: Alex Black * Don't allow setting causal convolution mode to conv2d/3d layers Signed-off-by: Alex Black * More robustly infer nIn for recurrent layers for ambiguous NCW and NWC cases Signed-off-by: Alex Black * Polish and cleanup Signed-off-by: Alex Black --- .../gradientcheck/CNN1DGradientCheckTest.java | 74 +++++++ .../convolution/ConvolutionLayerTest.java | 69 +++++++ .../nn/modelimport/keras/KerasLayer.java | 4 + .../keras/config/KerasLayerConfiguration.java | 1 + .../modelimport/keras/layers/KerasInput.java | 21 +- .../modelimport/keras/layers/KerasLoss.java | 7 +- .../convolutional/KerasConvolution.java | 1 - .../convolutional/KerasConvolution1D.java | 17 +- .../convolutional/KerasConvolutionUtils.java | 3 +- .../keras/layers/recurrent/KerasLSTM.java | 19 ++ .../layers/recurrent/KerasSimpleRnn.java | 18 ++ .../layers/wrappers/KerasBidirectional.java | 12 +- .../keras/utils/KerasLayerUtils.java | 11 ++ .../keras/e2e/KerasModelEndToEndTest.java | 185 +++++++++++++++--- .../nn/conf/ConvolutionMode.java | 14 +- .../nn/conf/layers/Convolution1DLayer.java | 5 + .../nn/conf/layers/Convolution3D.java | 6 + .../nn/conf/layers/ConvolutionLayer.java | 15 ++ .../nn/conf/layers/Deconvolution2D.java | 6 + .../conf/layers/DepthwiseConvolution2D.java | 6 + .../conf/layers/SeparableConvolution2D.java | 6 + .../nn/conf/layers/Subsampling1DLayer.java | 5 + .../nn/conf/layers/Subsampling3DLayer.java | 6 + 
.../nn/conf/layers/SubsamplingLayer.java | 15 ++ .../convolution/Convolution1DLayer.java | 94 +++++++++ .../util/Convolution1DUtils.java | 7 +- .../deeplearning4j/util/ConvolutionUtils.java | 17 +- 27 files changed, 588 insertions(+), 56 deletions(-) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java index 64748f932..a0a109cb1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java @@ -27,6 +27,8 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.util.Convolution1DUtils; +import org.deeplearning4j.util.ConvolutionUtils; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -442,4 +444,76 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { } } } + + @Test + public void testCnn1Causal() { + int convNIn = 2; + int convNOut1 = 3; + int convNOut2 = 4; + int finalNOut = 3; + + int[] lengths = {11, 12, 13, 9, 10, 11}; + int[] kernels = {2, 3, 2, 4, 2, 3}; + int[] dilations = {1, 1, 2, 1, 2, 1}; + int[] strides = {1, 2, 1, 2, 1, 1}; + boolean[] masks = {false, true, false, true, false, true}; + boolean[] hasB = {true, false, true, false, true, true}; + + for (int i = 0; i < lengths.length; i++) { + int length = lengths[i]; + int k = kernels[i]; + int d = dilations[i]; + int st = strides[i]; + boolean mask = masks[i]; + boolean hasBias = hasB[i]; + //TODO has bias + String s = "k=" + k + ", s=" + st + "d=" + d + ", seqLen=" + length; + 
log.info("Starting test: " + s); + Nd4j.getRandom().setSeed(12345); + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) + .updater(new NoOp()) + .activation(Activation.TANH) + .weightInit(new NormalDistribution(0, 1)) + .seed(12345) + .list() + .layer(new Convolution1DLayer.Builder().kernelSize(k) + .dilation(d) + .hasBias(hasBias) + .convolutionMode(ConvolutionMode.Causal) + .stride(st).nIn(convNIn).nOut(convNOut1) + .build()) + .layer(new Convolution1DLayer.Builder().kernelSize(k) + .dilation(d) + .convolutionMode(ConvolutionMode.Causal) + .stride(st).nIn(convNOut1).nOut(convNOut2) + .build()) + .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nOut(finalNOut).build()) + .setInputType(InputType.recurrent(convNIn, length)).build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + INDArray f = Nd4j.rand(DataType.DOUBLE, 2, convNIn, length); + INDArray fm = null; + if (mask) { + fm = Nd4j.create(2, length); + fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1); + fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, length-2)).assign(1); + } + + long outSize1 = Convolution1DUtils.getOutputSize(length, k, st, 0, ConvolutionMode.Causal, d); + long outSize2 = Convolution1DUtils.getOutputSize(outSize1, k, st, 0, ConvolutionMode.Causal, d); + + INDArray label = TestUtils.randomOneHotTimeSeries(2, finalNOut, (int)outSize2); + + boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, + DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, label, fm, null); + + assertTrue(s, gradOK); + TestUtils.testModelSerialization(net); + } + } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java index 
1c4b764bd..431831487 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java @@ -712,4 +712,73 @@ public class ConvolutionLayerTest extends BaseDL4JTest { assertTrue(msg,msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels")); } } + + @Test + public void testConv1dCausalAllowed(){ + new Convolution1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build(); + new Subsampling1DLayer.Builder().convolutionMode(ConvolutionMode.Causal).kernelSize(2).build(); + } + + @Test + public void testConv2dNoCausalAllowed(){ + + try{ + new ConvolutionLayer.Builder().convolutionMode(ConvolutionMode.Causal).build(); + fail("Expected exception"); + } catch (Throwable t){ + String m = t.getMessage().toLowerCase(); + assertTrue(m, m.contains("causal") && m.contains("1d")); + } + + try{ + new Deconvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build(); + fail("Expected exception"); + } catch (Throwable t){ + String m = t.getMessage().toLowerCase(); + assertTrue(m, m.contains("causal") && m.contains("1d")); + } + + try{ + new DepthwiseConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build(); + fail("Expected exception"); + } catch (Throwable t){ + String m = t.getMessage().toLowerCase(); + assertTrue(m, m.contains("causal") && m.contains("1d")); + } + + try{ + new SeparableConvolution2D.Builder().convolutionMode(ConvolutionMode.Causal).build(); + fail("Expected exception"); + } catch (Throwable t){ + String m = t.getMessage().toLowerCase(); + assertTrue(m, m.contains("causal") && m.contains("1d")); + } + + try{ + new SubsamplingLayer.Builder().convolutionMode(ConvolutionMode.Causal).build(); + fail("Expected exception"); + } catch (Throwable t){ + String m = t.getMessage().toLowerCase(); + assertTrue(m, 
m.contains("causal") && m.contains("1d")); + } + } + + @Test + public void testConv3dNoCausalAllowed(){ + try{ + new Convolution3D.Builder().convolutionMode(ConvolutionMode.Causal).build(); + fail("Expected exception"); + } catch (Throwable t){ + String m = t.getMessage().toLowerCase(); + assertTrue(m, m.contains("causal") && m.contains("1d")); + } + + try{ + new Subsampling3DLayer.Builder().convolutionMode(ConvolutionMode.Causal).build(); + fail("Expected exception"); + } catch (Throwable t){ + String m = t.getMessage().toLowerCase(); + assertTrue(m, m.contains("causal") && m.contains("1d")); + } + } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java index a31ac6177..7d70077af 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java @@ -356,6 +356,10 @@ public class KerasLayer { return this.layer; } + public void setLayer(Layer layer){ + this.layer = layer; + } + /** * Whether this Keras layer maps to a DL4J Vertex. 
* diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java index 84a85a2d5..6d6fc42c9 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java @@ -233,6 +233,7 @@ public class KerasLayerConfiguration { private final String LAYER_BORDER_MODE_SAME = "same"; private final String LAYER_BORDER_MODE_VALID = "valid"; private final String LAYER_BORDER_MODE_FULL = "full"; + private final String LAYER_BORDER_MODE_CAUSAL = "causal"; /* Noise layers */ private final String LAYER_FIELD_RATE = "rate"; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java index c1df4b592..785e480d1 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java @@ -124,7 +124,26 @@ public class KerasInput extends KerasLayer { myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]); break; case 2: - myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]); + if(this.dimOrder != null) { + switch (this.dimOrder) { + case TENSORFLOW: //NWC == channels_last + myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]); + break; + case THEANO: //NCW == channels_first + myInputType = new 
InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); + break; + case NONE: + //Assume RNN in [mb, seqLen, size] format + myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); + break; + default: + throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder); + } + } else { + //Assume RNN in [mb, seqLen, size] format + myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); + } + break; case 3: switch (this.dimOrder) { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java index 6fd72bd3e..e3c603287 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.layers.RnnLossLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; +import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.ArrayList; @@ -96,13 +97,13 @@ public class KerasLoss extends KerasLayer { */ public FeedForwardLayer getLossLayer(InputType type) throws UnsupportedKerasConfigurationException { if (type instanceof InputType.InputTypeFeedForward) { - this.layer = new LossLayer.Builder(loss).name(this.layerName).build(); + this.layer = new LossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build(); } else if (type instanceof InputType.InputTypeRecurrent) { - this.layer = 
new RnnLossLayer.Builder(loss).name(this.layerName).build(); + this.layer = new RnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build(); } else if (type instanceof InputType.InputTypeConvolutional) { - this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).build(); + this.layer = new CnnLossLayer.Builder(loss).name(this.layerName).activation(Activation.IDENTITY).build(); } else { throw new UnsupportedKerasConfigurationException("Unsupported output layer type" + "got : " + type.toString()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution.java index f1d2f0210..a5f3e15ae 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution.java @@ -79,7 +79,6 @@ abstract public class KerasConvolution extends KerasLayer { public KerasConvolution(Map layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { super(layerConfig, enforceTrainingConfig); - } /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java index 3da88d3b1..120870de9 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java +++ 
b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java @@ -185,18 +185,11 @@ public class KerasConvolution1D extends KerasConvolution { break; case THEANO: - paramValue = kerasParamValue.permute(2, 1, 0); - paramValue = paramValue.reshape( - paramValue.size(0), paramValue.size(1), - paramValue.size(2), 1).dup(); - for (int i = 0; i < paramValue.tensorsAlongDimension(2, 3); i++) { - INDArray copyFilter = paramValue.tensorAlongDimension(i, 2, 3).dup(); - double[] flattenedFilter = copyFilter.ravel().data().asDouble(); - ArrayUtils.reverse(flattenedFilter); - INDArray newFilter = Nd4j.create(flattenedFilter, copyFilter.shape()); - INDArray inPlaceFilter = paramValue.tensorAlongDimension(i, 2, 3); - inPlaceFilter.muli(0).addi(newFilter.castTo(inPlaceFilter.dataType())); - } + //Convert from keras [k,nIn,nOut] to DL4J conv2d [nOut, nIn, k, 1] + long k = kerasParamValue.size(0); + long nIn = kerasParamValue.size(1); + long nOut = kerasParamValue.size(2); + paramValue = kerasParamValue.permute(2, 1, 0).dup('c').reshape(nOut, nIn, k, 1); break; default: throw new InvalidKerasConfigurationException("Unknown keras backend " + this.getDimOrder()); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java index b60b41459..0968260b7 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java @@ -264,7 +264,8 @@ public class KerasConvolutionUtils { } else if 
(borderMode.equals(conf.getLAYER_BORDER_MODE_VALID()) || borderMode.equals(conf.getLAYER_BORDER_MODE_FULL())) { convolutionMode = ConvolutionMode.Truncate; - + } else if(borderMode.equals(conf.getLAYER_BORDER_MODE_CAUSAL())) { + convolutionMode = ConvolutionMode.Causal; } else { throw new UnsupportedKerasConfigurationException("Unsupported convolution border mode: " + borderMode); } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java index 7d5603261..1c205bbca 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java @@ -23,11 +23,13 @@ import lombok.val; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; @@ -186,6 +188,9 @@ public class KerasLSTM extends KerasLayer { .biasInit(0.0) // TODO: this is incorrect .l1(this.weightL1Regularization) .l2(this.weightL2Regularization); + Integer 
nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf); + if(nIn != null) + builder.setNIn(nIn); if (biasConstraint != null) builder.constrainBias(biasConstraint); if (weightConstraint != null) @@ -436,6 +441,20 @@ public class KerasLSTM extends KerasLayer { log.warn("Attemping to set weights for unknown parameters: " + unknownParamNames.substring(1, unknownParamNames.length() - 1)); } + + + FeedForwardLayer ffl; + if(this.layer instanceof BaseWrapperLayer){ + BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer; + ffl = (FeedForwardLayer)bwl.getUnderlying(); + } else { + ffl = (FeedForwardLayer) this.layer; + } + if(ffl.getNIn() != wRows){ + //Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config) + //We can reliably infer nIn from the shape of the weights array however + ffl.setNIn(wRows); + } } /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java index 6f5edf597..f6ecbb6a5 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java @@ -22,11 +22,13 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import 
org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; @@ -154,6 +156,9 @@ public class KerasSimpleRnn extends KerasLayer { .biasInit(0.0) .l1(this.weightL1Regularization) .l2(this.weightL2Regularization); + Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf); + if(nIn != null) + builder.setNIn(nIn); if (biasConstraint != null) builder.constrainBias(biasConstraint); if (weightConstraint != null) @@ -282,6 +287,19 @@ public class KerasSimpleRnn extends KerasLayer { log.warn("Attemping to set weights for unknown parameters: " + unknownParamNames.substring(1, unknownParamNames.length() - 1)); } + + FeedForwardLayer ffl; + if(this.layer instanceof BaseWrapperLayer){ + BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer; + ffl = (FeedForwardLayer)bwl.getUnderlying(); + } else { + ffl = (FeedForwardLayer) this.layer; + } + if(ffl.getNIn() != W.rows()){ + //Workaround/hack for ambiguous input shapes (nIn inference) for some RNN models (using NCW format but not recorded in config) + //We can reliably infer nIn from the shape of the weights array however + ffl.setNIn(W.rows()); + } } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java index 40f1f7074..d37ee399c 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java +++ 
b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java @@ -229,8 +229,8 @@ public class KerasBidirectional extends KerasLayer { @Override public void setWeights(Map weights) throws InvalidKerasConfigurationException { - Map forwardWeights = getUnderlyingWeights(weights, "forward"); - Map backwardWeights = getUnderlyingWeights(weights, "backward"); + Map forwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getFwd(), weights, "forward"); + Map backwardWeights = getUnderlyingWeights(((Bidirectional)this.layer).getBwd(), weights, "backward"); this.weights = new HashMap<>(); @@ -241,7 +241,7 @@ public class KerasBidirectional extends KerasLayer { } - private Map getUnderlyingWeights(Map weights, String direction) + private Map getUnderlyingWeights(Layer l, Map weights, String direction) throws InvalidKerasConfigurationException { int keras1SubstringLength; if (kerasRnnlayer instanceof KerasLSTM) @@ -270,8 +270,12 @@ public class KerasBidirectional extends KerasLayer { weights = newWeights; } + Layer layerBefore = kerasRnnlayer.getLayer(); + kerasRnnlayer.setLayer(l); kerasRnnlayer.setWeights(weights); - return kerasRnnlayer.getWeights(); + Map ret = kerasRnnlayer.getWeights(); + kerasRnnlayer.setLayer(layerBefore); + return ret; } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java index 8d80d3f38..3494ecf49 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java @@ -505,6 +505,17 @@ public class KerasLayerUtils { return nOut; } + public static Integer 
getNInFromInputDim(Map layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException { + Map innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); + if(innerConfig.containsKey(conf.getLAYER_FIELD_INPUT_DIM())){ + Object id = innerConfig.get(conf.getLAYER_FIELD_INPUT_DIM()); + if(id instanceof Number){ + return ((Number)id).intValue(); + } + } + return null; + } + /** * Get dropout from Keras layer configuration. * diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index 0565cc091..d4f458a39 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -24,6 +24,8 @@ import org.deeplearning4j.eval.ROCMultiClass; import org.deeplearning4j.gradientcheck.GradientCheckUtil; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.layers.IOutputLayer; +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.LossLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; @@ -47,6 +49,8 @@ import org.nd4j.linalg.activations.impl.*; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.function.BiFunction; +import org.nd4j.linalg.function.Function; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT; @@ -58,10 +62,7 @@ import 
java.io.InputStream; import java.net.URL; import java.nio.file.Files; import java.nio.file.StandardCopyOption; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Random; +import java.util.*; import static org.junit.Assert.*; @@ -86,7 +87,16 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { @Rule public final TemporaryFolder testDir = new TemporaryFolder(); - @Test(expected = IllegalStateException.class) + public static final BiFunction nwc2ncwExpected = new BiFunction() { + @Override + public INDArray apply(String s, INDArray array) { + if(array.rank() == 3) + return array.permute(0, 2, 1); //NWC to NCW + return array; + } + }; + + @Test(expected = IllegalStateException.class) public void fileNotFoundEndToEnd() throws Exception { String modelPath = "modelimport/keras/examples/foo/bar.h5"; importEndModelTest(modelPath, null, true, true, false, false); @@ -154,28 +164,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importImdbLstmTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); } @Test public void importImdbLstmThKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); } @Test public void importImdbLstmTfKeras2() throws Exception { String modelPath = 
"modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); } @Test public void importImdbLstmThKeras2() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, nwc2ncwExpected); } /** @@ -247,7 +257,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" + "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); } /** @@ -598,6 +608,122 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { model.summary(); } + @Test + public void testCausalCon1D() throws Exception { + String[] names = new String[]{ + "causal_conv1d_k2_s1_d1_cl_model.h5", + "causal_conv1d_k2_s1_d2_cl_model.h5", + "causal_conv1d_k2_s2_d1_cl_model.h5", + "causal_conv1d_k2_s3_d1_cl_model.h5", + "causal_conv1d_k3_s1_d1_cl_model.h5", + "causal_conv1d_k3_s1_d2_cl_model.h5", + "causal_conv1d_k3_s2_d1_cl_model.h5", + "causal_conv1d_k3_s3_d1_cl_model.h5", + "causal_conv1d_k4_s1_d1_cl_model.h5", + "causal_conv1d_k4_s1_d2_cl_model.h5", + "causal_conv1d_k4_s2_d1_cl_model.h5", + "causal_conv1d_k4_s3_d1_cl_model.h5" + }; + 
+ for(String name : names ){ + System.out.println("Starting test: " + name); + String modelPath = "modelimport/keras/examples/causal_conv1d/" + name; + String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5"); + Function f = new Function() { + @Override + public INDArray apply(INDArray i) { + //NWC to NCW + return i.permute(0, 2, 1); + } + }; + + MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, + true, true, false, f, nwc2ncwExpected); + Layer l = net.getLayer(0); + Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig(); + assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode()); + } + } + + @Test + public void testCon1D() throws Exception { + String[] names = new String[]{ + "conv1d_k2_s1_d1_cf_same_model.h5", + "conv1d_k2_s1_d1_cf_valid_model.h5", + "conv1d_k2_s1_d1_cl_same_model.h5", + "conv1d_k2_s1_d1_cl_valid_model.h5", + "conv1d_k2_s1_d2_cf_same_model.h5", + "conv1d_k2_s1_d2_cf_valid_model.h5", + "conv1d_k2_s1_d2_cl_same_model.h5", + "conv1d_k2_s1_d2_cl_valid_model.h5", + "conv1d_k2_s2_d1_cf_same_model.h5", + "conv1d_k2_s2_d1_cf_valid_model.h5", + "conv1d_k2_s2_d1_cl_same_model.h5", + "conv1d_k2_s2_d1_cl_valid_model.h5", + "conv1d_k2_s3_d1_cf_same_model.h5", + "conv1d_k2_s3_d1_cf_valid_model.h5", + "conv1d_k2_s3_d1_cl_same_model.h5", + "conv1d_k2_s3_d1_cl_valid_model.h5", + "conv1d_k3_s1_d1_cf_same_model.h5", + "conv1d_k3_s1_d1_cf_valid_model.h5", + "conv1d_k3_s1_d1_cl_same_model.h5", + "conv1d_k3_s1_d1_cl_valid_model.h5", + "conv1d_k3_s1_d2_cf_same_model.h5", + "conv1d_k3_s1_d2_cf_valid_model.h5", + "conv1d_k3_s1_d2_cl_same_model.h5", + "conv1d_k3_s1_d2_cl_valid_model.h5", + "conv1d_k3_s2_d1_cf_same_model.h5", + "conv1d_k3_s2_d1_cf_valid_model.h5", + "conv1d_k3_s2_d1_cl_same_model.h5", + "conv1d_k3_s2_d1_cl_valid_model.h5", + "conv1d_k3_s3_d1_cf_same_model.h5", + "conv1d_k3_s3_d1_cf_valid_model.h5", + 
"conv1d_k3_s3_d1_cl_same_model.h5", + "conv1d_k3_s3_d1_cl_valid_model.h5", + "conv1d_k4_s1_d1_cf_same_model.h5", + "conv1d_k4_s1_d1_cf_valid_model.h5", + "conv1d_k4_s1_d1_cl_same_model.h5", + "conv1d_k4_s1_d1_cl_valid_model.h5", + "conv1d_k4_s1_d2_cf_same_model.h5", + "conv1d_k4_s1_d2_cf_valid_model.h5", + "conv1d_k4_s1_d2_cl_same_model.h5", + "conv1d_k4_s1_d2_cl_valid_model.h5", + "conv1d_k4_s2_d1_cf_same_model.h5", + "conv1d_k4_s2_d1_cf_valid_model.h5", + "conv1d_k4_s2_d1_cl_same_model.h5", + "conv1d_k4_s2_d1_cl_valid_model.h5", + "conv1d_k4_s3_d1_cf_same_model.h5", + "conv1d_k4_s3_d1_cf_valid_model.h5", + "conv1d_k4_s3_d1_cl_same_model.h5", + "conv1d_k4_s3_d1_cl_valid_model.h5", + }; + + for(String name : names ){ + System.out.println("Starting test: " + name); + String modelPath = "modelimport/keras/examples/conv1d/" + name; + String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5"); + Function f = name.contains("_cf_") ? null : new Function() { + @Override + public INDArray apply(INDArray i) { + //NWC to NCW + return i.permute(0, 2, 1); + } + }; + + BiFunction f2 = name.contains("_cf_") ? 
null : new BiFunction() { + @Override + public INDArray apply(String s, INDArray array) { +// if("conv".equals(s)){ + return array.permute(0, 2, 1); +// } + } + }; + + importEndModelTest(modelPath, inputsOutputPath, true, true, + true, true, false, f, f2); + } + } + private ComputationGraph importFunctionalModelH5Test(String modelPath) throws Exception { return importFunctionalModelH5Test(modelPath, null, false); } @@ -640,6 +766,12 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, boolean checkGradients, boolean enforceTrainingConfig) throws Exception { + return importEndModelTest(modelPath, inputsOutputsPath, tfOrdering, checkPredictions, checkGradients, true, enforceTrainingConfig, null, null); + } + + public MultiLayerNetwork importEndModelTest(String modelPath, String inputsOutputsPath, boolean tfOrdering, boolean checkPredictions, + boolean checkGradients, boolean enforceTrainingConfig, boolean checkAuc, Function inputPreProc, + BiFunction expectedPreProc) throws Exception { MultiLayerNetwork model; try(InputStream is = Resources.asStream(modelPath)) { File modelFile = createTempFile(TEMP_MODEL_FILENAME, H5_EXTENSION); @@ -658,20 +790,25 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { if (checkPredictions) { INDArray input = getInputs(outputsArchive, tfOrdering)[0]; + if(inputPreProc != null) + input = inputPreProc.apply(input); + Map activationsKeras = getActivations(outputsArchive, tfOrdering); for (int i = 0; i < model.getLayers().length; i++) { String layerName = model.getLayerNames().get(i); if (activationsKeras.containsKey(layerName)) { INDArray activationsDl4j = model.feedForwardToLayer(i, input, false).get(i + 1); - if (activationsDl4j.shape().length == 3) - activationsDl4j = activationsDl4j.permute(0, 2, 1); - compareINDArrays(layerName, activationsKeras.get(layerName), activationsDl4j, EPS); 
- + INDArray exp = activationsKeras.get(layerName); + if(expectedPreProc != null) + exp = expectedPreProc.apply(layerName, exp); + compareINDArrays(layerName, exp, activationsDl4j, EPS); } } INDArray predictionsKeras = getPredictions(outputsArchive, tfOrdering)[0]; INDArray predictionsDl4j = model.output(input, false); + if(expectedPreProc != null) + predictionsKeras = expectedPreProc.apply("output", predictionsKeras); compareINDArrays("predictions", predictionsKeras, predictionsDl4j, EPS); INDArray outputs = getOutputs(outputsArchive, true)[0]; @@ -680,7 +817,8 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { } val nOut = (int) outputs.size(-1); - compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS); + if(checkAuc) + compareMulticlassAUC("predictions", outputs, predictionsKeras, predictionsDl4j, nOut, EPS); } if (checkGradients && ! SKIP_GRAD_CHECKS) { @@ -760,20 +898,23 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { return predictions; } - private static void compareINDArrays(String label, INDArray a, INDArray b, double eps) { - INDArray diff = a.sub(b.castTo(a.dataType())); + private static void compareINDArrays(String label, INDArray expected, INDArray actual, double eps) { + if(!expected.equalShapes(actual)){ + throw new IllegalStateException("Shapes do not match for \"" + label + "\": got " + Arrays.toString(expected.shape()) + " vs " + Arrays.toString(actual.shape())); + } + INDArray diff = expected.sub(actual.castTo(expected.dataType())); double min = diff.minNumber().doubleValue(); double max = diff.maxNumber().doubleValue(); - log.info(label + ": " + a.equalsWithEps(b, eps) + ", " + min + ", " + max); + log.info(label + ": " + expected.equalsWithEps(actual, eps) + ", " + min + ", " + max); double threshold = 1e-7; - double aAbsMax = Math.max(Math.abs(a.minNumber().doubleValue()), Math.abs(a.maxNumber().doubleValue())); - double bAbsMax = Math.max(Math.abs(b.minNumber().doubleValue()), 
Math.abs(b.maxNumber().doubleValue())); + double aAbsMax = Math.max(Math.abs(expected.minNumber().doubleValue()), Math.abs(expected.maxNumber().doubleValue())); + double bAbsMax = Math.max(Math.abs(actual.minNumber().doubleValue()), Math.abs(actual.maxNumber().doubleValue())); // skip too small absolute inputs if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) { - assertTrue(a.equalsWithEps(b.castTo(a.dataType()), eps)); + boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps); + assertTrue("Output differs: " + label, eq); } - } private static void compareMulticlassAUC(String label, INDArray target, INDArray a, INDArray b, int nbClasses, diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ConvolutionMode.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ConvolutionMode.java index d6b1e0b55..4bd1050f0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ConvolutionMode.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ConvolutionMode.java @@ -69,6 +69,18 @@ package org.deeplearning4j.nn.conf; *
*
*
+ * Causal: Causal padding mode can only be used for 1D convolutional neural networks.
+ * The motivation behind causal padding mode is that the output time steps depend only on current and past time steps.
 * That is, out[t] (for time t) depends only on values in[T] for T <= t
+ * The output size of 1D convolution/subsampling layers is the same as with SAME convolution mode - + * i.e., outSize = ceil( inputSize / stride )
 + * Padding is also the same as SAME mode, but all padding is on the left (start of sequence) instead of being on both + * left and right of the input
+ * For more details on causal convolutions, see WaveNet: A Generative Model For Audio, + * section 2.1. + *
+ *
+ *
* For further information on output sizes for convolutional neural networks, see the "Spatial arrangement" section at * http://cs231n.github.io/convolutional-networks/ * @@ -76,6 +88,6 @@ package org.deeplearning4j.nn.conf; */ public enum ConvolutionMode { - Strict, Truncate, Same + Strict, Truncate, Same, Causal } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java index d4ccc4811..b220ba5a6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java @@ -124,6 +124,11 @@ public class Convolution1DLayer extends ConvolutionLayer { this.setKernelSize((int[]) null); } + @Override + protected boolean allowCausal() { + return true; + } + /** * @param kernelSize Kernel size * @param stride Stride diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java index 61475bf98..cc26169cf 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java @@ -163,6 +163,12 @@ public class Convolution3D extends ConvolutionLayer { super(new int[] {2, 2, 2}, new int[] {1, 1, 1}, new int[] {0, 0, 0}, new int[] {1, 1, 1}, 3); } + @Override + protected boolean allowCausal() { + //Causal convolution - allowed for 1D only + return false; + } + public Builder(int[] kernelSize, int[] stride, int[] padding, int[] dilation) { super(kernelSize, stride, padding, dilation, 3); } diff --git 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java index 4fdf1e9cc..b0c5bb3d4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java @@ -30,6 +30,7 @@ import org.deeplearning4j.nn.params.ConvolutionParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -283,6 +284,12 @@ public class ConvolutionLayer extends FeedForwardLayer { super(); } + @Override + protected boolean allowCausal() { + //Causal convolution - allowed for 1D only + return false; + } + /** * Size of the convolution rows/columns * @@ -456,6 +463,14 @@ public class ConvolutionLayer extends FeedForwardLayer { protected BaseConvBuilder() {} + protected abstract boolean allowCausal(); + + protected void setConvolutionMode(ConvolutionMode convolutionMode){ + Preconditions.checkState(allowCausal() || convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" + + " convolutional neural network layers"); + this.convolutionMode = convolutionMode; + } + /** * If true (default): include bias parameters in the model. False: no bias. 
* diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java index 03b6ec405..11c9fdb7b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java @@ -133,6 +133,12 @@ public class Deconvolution2D extends ConvolutionLayer { super(); } + @Override + protected boolean allowCausal() { + //Causal convolution - allowed for 1D only + return false; + } + /** * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java index 03fec1191..e103cb0a0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java @@ -133,6 +133,12 @@ public class DepthwiseConvolution2D extends ConvolutionLayer { super(); } + @Override + protected boolean allowCausal() { + //Causal convolution - allowed for 1D only + return false; + } + /** * Set channels multiplier for depth-wise convolution * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java index 181cc5311..133c14869 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java +++ 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java @@ -184,6 +184,12 @@ public class SeparableConvolution2D extends ConvolutionLayer { super(); } + @Override + protected boolean allowCausal() { + //Causal convolution - allowed for 1D only + return false; + } + /** * Set channels multiplier of channels-wise step in separable convolution * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java index 4da7ff011..9f3162374 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java @@ -167,6 +167,11 @@ public class Subsampling1DLayer extends SubsamplingLayer { this(poolingType, DEFAULT_KERNEL, DEFAULT_STRIDE, DEFAULT_PADDING); } + @Override + protected boolean allowCausal() { + return true; + } + public Builder() { this(DEFAULT_POOLING, DEFAULT_KERNEL, DEFAULT_STRIDE, DEFAULT_PADDING); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java index 2fcc345a1..550e29e4f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java @@ -431,6 +431,12 @@ public class Subsampling3DLayer extends NoParamLayer { this.setPoolingType(poolingType); } + protected void setConvolutionMode(ConvolutionMode convolutionMode){ + Preconditions.checkState(convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" + + " convolutional 
neural network layers"); + this.convolutionMode = convolutionMode; + } + /** * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java index c20526cf1..be6764e9a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -270,6 +271,12 @@ public class SubsamplingLayer extends NoParamLayer { super(poolingType); } + @Override + protected boolean allowCausal() { + //Only conv1d/subsampling1d can use causal mode + return false; + } + /** * Kernel size * @@ -449,6 +456,14 @@ public class SubsamplingLayer extends NoParamLayer { this.eps = eps; } + protected abstract boolean allowCausal(); + + public void setConvolutionMode(ConvolutionMode convolutionMode){ + Preconditions.checkState(allowCausal() || convolutionMode != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" + + " convolutional neural network layers"); + this.convolutionMode = convolutionMode; + } + /** * Set the convolution mode for the Convolution layer. 
See {@link ConvolutionMode} for more details * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java index 1ffd19062..985c2f06b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java @@ -18,18 +18,30 @@ package org.deeplearning4j.nn.layers.convolution; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.params.ConvolutionParamInitializer; +import org.deeplearning4j.util.Convolution1DUtils; import org.deeplearning4j.util.ConvolutionUtils; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Broadcast; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import java.util.Arrays; +import java.util.List; /** * 1D (temporal) convolutional layer. 
Currently, we just subclass off the @@ -70,6 +82,52 @@ public class Convolution1DLayer extends ConvolutionLayer { Broadcast.mul(epsilon, maskOut, epsilon, 0, 2); } + if(layerConf().getConvolutionMode() == ConvolutionMode.Causal){ + Pair fwd = causalConv1dForward(); + IActivation afn = layerConf().getActivationFn(); + INDArray delta = afn.backprop(fwd.getFirst(), epsilon).getFirst(); //TODO handle activation function params + + //TODO eventually we'll use this for all convolution modes - but only after libnd4j has cuDNN support + org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = (org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf(); + Conv1DConfig conf = Conv1DConfig.builder() + .k(c.getKernelSize()[0]) + .s(c.getStride()[0]) + .d(c.getDilation()[0]) + .p(c.getPadding()[0]) + .dataFormat(Conv1DConfig.NCW) + .paddingMode(PaddingMode.CAUSAL) + .build(); + + INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY); + w = w.reshape(w.ordering(), w.size(0), w.size(1), w.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] to [k, iC, oC] + + INDArray[] inputArrs; + INDArray[] outputArrs; + INDArray wg = gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY); + wg = wg.reshape(wg.ordering(), wg.size(0), wg.size(1), wg.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] -> [kW, iC, oC] + INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape()); + if(layerConf().hasBias()){ + INDArray b = getParam(ConvolutionParamInitializer.BIAS_KEY); + b = b.reshape(b.length()); + inputArrs = new INDArray[]{input.castTo(w.dataType()), w, b, delta}; + INDArray bg = gradientViews.get(ConvolutionParamInitializer.BIAS_KEY); + bg = bg.reshape(bg.length()); + outputArrs = new INDArray[]{epsOut, wg, bg}; + } else { + inputArrs = new INDArray[]{input.castTo(w.dataType()), w, delta}; + outputArrs = new INDArray[]{epsOut, wg}; + } + Conv1DDerivative op = new Conv1DDerivative(inputArrs, outputArrs, conf); + Nd4j.exec(op); + + Gradient 
retGradient = new DefaultGradient(); + if(layerConf().hasBias()){ + retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, gradientViews.get(ConvolutionParamInitializer.BIAS_KEY)); + } + retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY), 'c'); + return new Pair<>(retGradient, epsOut); + } + // add singleton fourth dimension to input and next layer's epsilon epsilon = epsilon.reshape(epsilon.size(0), epsilon.size(1), epsilon.size(2), 1); INDArray origInput = input; @@ -98,6 +156,12 @@ public class Convolution1DLayer extends ConvolutionLayer { @Override protected Pair preOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) { assertInputSet(false); + + if(layerConf().getConvolutionMode() == ConvolutionMode.Causal){ + return causalConv1dForward(); + } + + INDArray origInput = input; input = input.reshape(input.size(0), input.size(1), input.size(2), 1); @@ -113,6 +177,36 @@ public class Convolution1DLayer extends ConvolutionLayer { return preOutput; } + protected Pair causalConv1dForward(){ + //TODO eventually we'll use this for all convolution modes - but only after libnd4j has cuDNN support + org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = (org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf(); + Conv1DConfig conf = Conv1DConfig.builder() + .k(c.getKernelSize()[0]) + .s(c.getStride()[0]) + .d(c.getDilation()[0]) + .p(c.getPadding()[0]) + .dataFormat(Conv1DConfig.NCW) + .paddingMode(PaddingMode.CAUSAL) + .build(); + INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY); + w = w.reshape(w.ordering(), w.size(0), w.size(1), w.size(2)).permute(2, 1, 0); //[oC, iC, k, 1] to [k, iC, oC] + + INDArray[] inputs; + if(layerConf().hasBias()){ + INDArray b = getParam(ConvolutionParamInitializer.BIAS_KEY); + b = b.reshape(b.length()); + inputs = new INDArray[]{input.castTo(w.dataType()), w, b}; + } else { + inputs = new 
INDArray[]{input.castTo(w.dataType()), w}; + } + + Conv1D op = new Conv1D(inputs, null, conf); + List outShape = op.calculateOutputShape(); + op.setOutputArgument(0, Nd4j.create(outShape.get(0), false)); + Nd4j.exec(op); + return new Pair<>(op.getOutputArgument(0), null); + } + @Override public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){ INDArray act4d = super.activate(training, workspaceMgr); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java index f0c8d76c9..165483e1c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java @@ -66,7 +66,7 @@ public class Convolution1DUtils { public static long getOutputSize(long inH, int kernel, int strides, int padding, ConvolutionMode convolutionMode, int dilation) { long eKernel = effectiveKernelSize(kernel, dilation); - if (convolutionMode == ConvolutionMode.Same) { + if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) { return (int) Math.ceil(inH / ((double) strides)); } return (inH - eKernel + 2 * padding) / strides + 1; @@ -92,7 +92,7 @@ public class Convolution1DUtils { boolean atrous = (eKernel == kernel); validateShapes(inputData, eKernel, strides, padding, convolutionMode, dilation, inH, atrous); - if (convolutionMode == ConvolutionMode.Same) { + if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) { int outH = (int) Math.ceil(inH / ((double) strides)); return outH; } @@ -106,8 +106,9 @@ public class Convolution1DUtils { boolean atrous) { int inH = inShape; + boolean t = convolutionMode == ConvolutionMode.Truncate; - if (convolutionMode != ConvolutionMode.Same && (eKernel <= 0 || eKernel > inH + 2 * padding)) { + if (t 
&& (eKernel <= 0 || eKernel > inH + 2 * padding)) { StringBuilder sb = new StringBuilder(); sb.append("Invalid input data or configuration: "); if (atrous) sb.append("effective "); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java index 56421bc00..3a447c361 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java @@ -121,7 +121,7 @@ public class ConvolutionUtils { int[] inShape = new int[]{inH, inW}; validateShapes(inputData, eKernel, strides, padding, convolutionMode, dilation, inShape, atrous); - if (convolutionMode == ConvolutionMode.Same) { + if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) { int outH = (int) Math.ceil(inH / ((double) strides[0])); int outW = (int) Math.ceil(inW / ((double) strides[1])); @@ -142,7 +142,9 @@ public class ConvolutionUtils { int inH = inShape[0]; int inW = inShape[1]; - if (convolutionMode != ConvolutionMode.Same && (eKernel[0] <= 0 || eKernel[0] > inH + 2 * padding[0])) { + boolean t = (convolutionMode == ConvolutionMode.Truncate); + + if (t && (eKernel[0] <= 0 || eKernel[0] > inH + 2 * padding[0])) { StringBuilder sb = new StringBuilder(); sb.append("Invalid input data or configuration: "); if (atrous) sb.append("effective "); @@ -158,7 +160,7 @@ public class ConvolutionUtils { throw new DL4JInvalidInputException(sb.toString()); } - if (convolutionMode != ConvolutionMode.Same && (eKernel[1] <= 0 || eKernel[1] > inW + 2 * padding[1])) { + if (t && (eKernel[1] <= 0 || eKernel[1] > inW + 2 * padding[1])) { StringBuilder sb = new StringBuilder(); sb.append("Invalid input data or configuration: "); if (atrous) sb.append("effective "); @@ -175,8 +177,7 @@ public class ConvolutionUtils { throw new 
DL4JInvalidInputException(sb.toString()); } - if (eKernel.length == 3 && convolutionMode != ConvolutionMode.Same - && (eKernel[2] <= 0 || eKernel[2] > inShape[2] + 2 * padding[2])) { + if (eKernel.length == 3 && t && (eKernel[2] <= 0 || eKernel[2] > inShape[2] + 2 * padding[2])) { int inD = inShape[2]; StringBuilder sb = new StringBuilder(); sb.append("Invalid input data or configuration: "); @@ -615,7 +616,7 @@ public class ConvolutionUtils { */ public static INDArray cnn1dMaskReduction(INDArray in, int kernel, int stride, int padding, int dilation, ConvolutionMode cm){ Preconditions.checkState(in.rank()==2, "Rank must be 2 for cnn1d mask array - shape ", in.shape()); - if(cm == ConvolutionMode.Same && stride == 1 ){ + if((cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) && stride == 1 ){ return in; } @@ -630,7 +631,7 @@ public class ConvolutionUtils { int[] k = new int[]{kernel,1}; int[] s = new int[]{stride, 1}; int[] d = new int[]{dilation, 1}; - if (cm == ConvolutionMode.Same) { + if (cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) { outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, null, cm, d); //Also performs validation } else { pad = new int[]{padding, 0}; @@ -645,7 +646,7 @@ public class ConvolutionUtils { .sH(s[0]).sW(s[1]) .pH(pad == null ? 0 : pad[0]).pW(pad == null ? 0 : pad[1]) .dH(d[0]).dW(d[1]) - .isSameMode(cm== ConvolutionMode.Same) + .isSameMode(cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) .isNHWC(false) .build());