diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java
index 1eb893e3f..623158c68 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.gradientcheck;
+import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
@@ -53,9 +54,9 @@ import java.util.Arrays;
import java.util.Map;
import java.util.Random;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.*;
+@Slf4j
public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
public static final boolean PRINT_RESULTS = true;
@@ -287,6 +288,56 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
}
}
+ @Test
+ public void testElementWiseVertexBroadcast(){
+
+ ElementWiseVertex.Op[] ops =
+ new ElementWiseVertex.Op[] {ElementWiseVertex.Op.Add, ElementWiseVertex.Op.Average,
+ ElementWiseVertex.Op.Subtract, ElementWiseVertex.Op.Max, ElementWiseVertex.Op.Product};
+
+ for(boolean firstSmaller : new boolean[]{false, true}) {
+ for (ElementWiseVertex.Op op : ops) {
+ ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+ .updater(new NoOp())
+ .dataType(DataType.DOUBLE)
+ .activation(Activation.TANH)
+ .seed(12345)
+ .graphBuilder()
+ .addInputs("in")
+ .setOutputs("out")
+ .layer("l1", new DenseLayer.Builder().nIn(3).nOut(firstSmaller ? 1 : 3).build(), "in") //[mb,3], or [mb,1] when firstSmaller
+ .layer("l2", new DenseLayer.Builder().nIn(3).nOut(firstSmaller ? 3 : 1).build(), "in") //[mb,1], or [mb,3] when firstSmaller
+ .addVertex("ew", new ElementWiseVertex(op), "l1", "l2")
+ .layer("out", new OutputLayer.Builder().nIn(3).nOut(2).lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).build(), "ew")
+ .build();
+
+ ComputationGraph graph = new ComputationGraph(conf);
+ graph.init();
+
+ for (int mb : new int[]{1, 5}) {
+ String msg = (firstSmaller ? "first smaller, " : "second smaller, ") + "mb=" + mb + ", op=" + op;
+
+ log.info("Test: {}", msg);
+
+ INDArray in = Nd4j.rand(DataType.FLOAT, mb, 3);
+
+ INDArray out = graph.outputSingle(in);
+ assertArrayEquals(new long[]{mb, 2}, out.shape());
+
+ INDArray labels = TestUtils.randomOneHot(mb, 2);
+
+ graph.fit(new DataSet(in, labels));
+
+ boolean gradOK = GradientCheckUtil.checkGradients(graph, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
+ DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{in},
+ new INDArray[]{labels});
+ assertTrue(msg, gradOK);
+ TestUtils.testModelSerialization(graph);
+ }
+ }
+ }
+ }
+
@Test
public void testCnnDepthMerge() {
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java
index afae4b1dc..5bbb8846d 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.nn.conf.graph;
import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -34,6 +35,7 @@ import org.nd4j.linalg.activations.impl.ActivationTanH;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution;
+import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
@@ -42,6 +44,8 @@ import org.nd4j.linalg.primitives.Pair;
import java.util.Map;
+import static org.junit.Assert.assertArrayEquals;
+
/**
* Created by binesh on 6/14/2017.
*/
@@ -690,6 +694,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
Assert.assertEquals(0, mse(nullsafe(gradients.get("dense2_b")), dEdb2), this.epsilon);
}
+
private static double mse(INDArray output, INDArray target) {
double mse_expect = Transforms.pow(output.sub(target), 2.0).sumNumber().doubleValue()
/ (output.columns() * output.rows());
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java
index a259c7b2b..6cd8f06b3 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java
@@ -350,7 +350,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
"Use .addInputs(String...) to label (and give an ordering to) the network inputs");
}
if ((networkOutputs == null || networkOutputs.isEmpty()) && !allowNoOutput) {
- throw new IllegalStateException("Invalid configuration: network has no outputs." +
+ throw new IllegalStateException("Invalid configuration: network has no outputs. " +
"Use .setOutput(String...) to specify (and give an ordering to) the output vertices, " +
"or use allowNoOutputs(true) to disable this check");
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java
index d9fe6acee..b22aae451 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java
@@ -27,15 +27,20 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldSubOp;
+import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
+import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import java.util.Arrays;
+
/** An ElementWiseVertex is used to combine the activations of two or more layer in an element-wise manner
* For example, the activations may be combined by addition, subtraction or multiplication or by selecting the maximum.
* Addition, Average, Product and Max may use an arbitrary number of input arrays. Note that in the case of subtraction, only two inputs may be used.
@@ -80,17 +85,44 @@ public class ElementWiseVertex extends BaseGraphVertex {
if (inputs.length == 1)
return workspaceMgr.dup(ArrayType.ACTIVATIONS, inputs[0]);
+ boolean isBc = false;
+ for(int i=1; i(null, new INDArray[] {workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon)});
+ boolean broadcastCase = false;
+ for( int i=1; i input 0 backprops epsilon, input 1 backprops epsilon.sum(1,keepDim=true)
+ if(inputs[i].equalShapes(epsilon)){
+ out[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
+ } else {
+ int[] bcDim = Shape.getBroadcastDimensions(inputs[i].shape(), epsilon.shape());
+ try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)){
+ out[i] = epsilon.sum(true, bcDim);
+ }
+ }
+ }
+ }
return new Pair<>(null, out);
case Average:
INDArray[] outAverage = new INDArray[nInForwardPass];
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)){
- for (int i = 0; i < nInForwardPass; i++)
- outAverage[i] = epsilon.div(nInForwardPass);
+ for (int i = 0; i < nInForwardPass; i++) {
+ if(inputs[i].equalShapes(epsilon)){
+ outAverage[i] = epsilon.div(nInForwardPass);
+ } else {
+ int[] bcDim = Shape.getBroadcastDimensions(inputs[i].shape(), epsilon.shape());
+ outAverage[i] = epsilon.div(nInForwardPass).sum(true, bcDim);
+ }
+ }
}
return new Pair<>(null, outAverage);
case Subtract:
INDArray[] out2 = new INDArray[2];
- out2[0] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
- out2[1] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon).negi();
+ if(!broadcastCase){
+ out2[0] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
+ out2[1] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon).negi();
+ } else {
+ if(inputs[0].equalShapes(epsilon)){
+ //Second input is smaller/broadcast
+ out2[0] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon);
+ int[] bcDim = Shape.getBroadcastDimensions(inputs[1].shape(), epsilon.shape());
+ try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
+ out2[1] = epsilon.sum(true, bcDim).negi();
+ }
+ } else {
+ //First input is smaller/broadcast
+ int[] bcDim = Shape.getBroadcastDimensions(inputs[0].shape(), epsilon.shape());
+ try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)) {
+ out2[0] = epsilon.sum(true, bcDim);
+ }
+ out2[1] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilon).negi();
+ }
+ }
return new Pair<>(null, out2);
case Product:
INDArray[] out_product = new INDArray[nInForwardPass];
+ INDArray[] inBc = inputs;
+ if(broadcastCase){
+ inBc = new INDArray[inputs.length];
+ for( int i=0; i(null, out_product);
case Max:
INDArray[] outMax = new INDArray[nInForwardPass];
INDArray maxIndices = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, DataType.INT, epsilon.shape(), epsilon.ordering());
+
+ INDArray[] bcIn = inputs;
+ if(broadcastCase){
+ //Broadcast to right shape...
+ bcIn = new INDArray[inputs.length];
+ for( int i=0; i(null, outMax);
default:
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java
index e77dd7fd0..3e85772cb 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java
@@ -2476,6 +2476,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
// length/data.length can be different in case of Threshold conversion
if(isEmpty() || isS())
return false;
+
return Shape.offset(jvmShapeInfo.javaShapeInformation) > 0
|| (length() < data().length() && data.dataType() != DataType.INT)
|| data().originalDataBuffer() != null;
@@ -4577,7 +4578,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return ret;
} else {
INDArray ret = this.dup(order);
- return ret.reshape(order, shape);
+ return Nd4j.create(ret.data(), shape);
}
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java
index d0de5fb21..6c877f96d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java
@@ -44,6 +44,10 @@ public class Max extends BaseDynamicTransformOp {
super(sameDiff, args, inPlace);
}
+ public Max( INDArray first, INDArray second, INDArray out){
+ super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out});
+ }
+
public Max( INDArray[] inputs, INDArray[] outputs) {
super(inputs, outputs);
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java
index 97beae406..73bfbacc7 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java
@@ -44,6 +44,10 @@ public class Min extends BaseDynamicTransformOp {
super(sameDiff, args, inPlace);
}
+ public Min( INDArray first, INDArray second, INDArray out){
+ super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out});
+ }
+
public Min( INDArray[] inputs, INDArray[] outputs) {
super(inputs, outputs);
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java
index 2816d4e60..db682174c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java
@@ -26,7 +26,7 @@ import java.util.Collections;
import java.util.List;
/**
- * Calculate the absolute minimum over a vector
+ * Calculate the maximum value between two arrays in an elementwise fashion, broadcasting if required
*
* @author raver119@gmail.com
*/
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java
index eb24acdef..6585ace19 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java
@@ -26,7 +26,7 @@ import java.util.Collections;
import java.util.List;
/**
- * Calculate the absolute minimum over a vector
+ * Calculate the minimum value between two arrays in an elementwise fashion, broadcasting if required
*
* @author raver119@gmail.com
*/
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
index b82a9e19e..a63d6c43a 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
@@ -538,7 +538,7 @@ public class Nd4j {
public static INDArray create(int[] sliceShape, float[]... arrays) {
//TODO: Remove duplicate code.
int slices = arrays.length;
- INDArray ret = Nd4j.create(ArrayUtil.combine(new int[] {slices}, sliceShape));
+ INDArray ret = Nd4j.createUninitialized(DataType.FLOAT, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape)));
for (int i = 0; i < ret.slices(); i++)
ret.putSlice(i, Nd4j.create(arrays[i]).reshape(ArrayUtil.toLongArray(sliceShape)));
return ret;
@@ -572,7 +572,7 @@ public class Nd4j {
*/
public static INDArray create(int[] sliceShape, double[]... arrays) {
int slices = arrays.length;
- INDArray ret = Nd4j.create(ArrayUtil.combine(new int[] {slices}, sliceShape));
+ INDArray ret = Nd4j.createUninitialized(DataType.DOUBLE, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape)));
for (int i = 0; i < ret.slices(); i++)
ret.putSlice(i, Nd4j.create(arrays[i]).reshape(ArrayUtil.toLongArray(sliceShape)));
return ret;
@@ -3984,6 +3984,7 @@ public class Nd4j {
* @return the created ndarray.
*/
public static INDArray create(int[] data, long[] shape, DataType type) {
+ checkShapeValues(data.length, shape);
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@@ -3991,6 +3992,7 @@ public class Nd4j {
* See {@link #create(int[], long[], DataType)}
*/
public static INDArray create(long[] data, long[] shape, DataType type) {
+ checkShapeValues(data.length, shape);
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@@ -3998,6 +4000,7 @@ public class Nd4j {
* See {@link #create(int[], long[], DataType)}
*/
public static INDArray create(double[] data, long[] shape, DataType type) {
+ checkShapeValues(data.length, shape);
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@@ -4005,6 +4008,7 @@ public class Nd4j {
* See {@link #create(int[], long[], DataType)}
*/
public static INDArray create(float[] data, long[] shape, DataType type) {
+ checkShapeValues(data.length, shape);
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@@ -4012,6 +4016,7 @@ public class Nd4j {
* See {@link #create(int[], long[], DataType)}
*/
public static INDArray create(short[] data, long[] shape, DataType type) {
+ checkShapeValues(data.length, shape);
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@@ -4019,6 +4024,7 @@ public class Nd4j {
* See {@link #create(int[], long[], DataType)}
*/
public static INDArray create(byte[] data, long[] shape, DataType type) {
+ checkShapeValues(data.length, shape);
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@@ -4026,6 +4032,7 @@ public class Nd4j {
* See {@link #create(int[], long[], DataType)}
*/
public static INDArray create(boolean[] data, long[] shape, DataType type) {
+ checkShapeValues(data.length, shape);
return INSTANCE.create(data, shape, Nd4j.getStrides(shape), type, Nd4j.getMemoryManager().getCurrentWorkspace());
}
@@ -5165,17 +5172,17 @@ public class Nd4j {
protected static void checkShapeValues(int length, int... shape) {
checkShapeValues(shape);
- if (ArrayUtil.prodLong(shape) > length)
+ if (ArrayUtil.prodLong(shape) != length && !(length == 1 && shape.length == 0))
throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape)
- + " doesn't match data length: " + length);
+ + " doesn't match data length: " + length + " - prod(shape) must equal the number of values provided");
}
protected static void checkShapeValues(int length, long... shape) {
checkShapeValues(shape);
- if (ArrayUtil.prodLong(shape) > length)
+ if (ArrayUtil.prodLong(shape) != length && !(length == 1 && shape.length == 0))
throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape)
- + " doesn't match data length: " + length);
+ + " doesn't match data length: " + length + " - prod(shape) must equal the number of values provided");
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java
index e676197ee..57660b8d7 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java
@@ -45,9 +45,11 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
import org.nd4j.linalg.api.ops.impl.transforms.same.*;
import org.nd4j.linalg.api.ops.impl.transforms.strict.*;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
+import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.inverse.InvertMatrix;
+import java.util.Arrays;
import java.util.List;
/**
@@ -858,11 +860,11 @@ public class Transforms {
* @return
*/
public static INDArray max(INDArray first, INDArray second, boolean dup) {
- INDArray result = first;
- if (dup) {
- result = first.ulike();
- }
- return exec(new OldMax(first, second, result));
+ long[] outShape = broadcastResultShape(first, second); //Also validates
+ Preconditions.checkState(dup || Arrays.equals(outShape, first.shape()), "Cannot do inplace max operation when first input is not equal to result shape (%ndShape vs. result %s)",
+ first, outShape);
+ INDArray out = dup ? Nd4j.create(first.dataType(), outShape) : first;
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(first, second, out))[0];
}
/**
@@ -908,10 +910,11 @@ public class Transforms {
* @return
*/
public static INDArray min(INDArray first, INDArray second, boolean dup) {
- if (dup) {
- first = first.dup();
- }
- return exec(new OldMin(second, first, first));
+ long[] outShape = broadcastResultShape(first, second); //Also validates
+ Preconditions.checkState(dup || Arrays.equals(outShape, first.shape()), "Cannot do inplace min operation when first input is not equal to result shape (%ndShape vs. result %s)",
+ first, outShape);
+ INDArray out = dup ? Nd4j.create(first.dataType(), outShape) : first;
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(first, second, out))[0];
}
/**
@@ -1179,4 +1182,15 @@ public class Transforms {
}
}
+
+ protected static long[] broadcastResultShape(INDArray first, INDArray second){
+ if(first.equalShapes(second)){
+ return first.shape();
+ } else if(Shape.areShapesBroadcastable(first.shape(), second.shape())){
+ return Shape.broadcastOutputShape(first.shape(), second.shape());
+ } else {
+ throw new IllegalStateException("Array shapes are not broadcastable: " + Arrays.toString(first.shape()) +
+ " vs. " + Arrays.toString(second.shape()));
+ }
+ }
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java
index ab2685be2..246b0a8ef 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java
@@ -2699,6 +2699,21 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(expected, actual);
}
+ @Test
+ public void testBroadcastDiv2(){
+ INDArray arr = Nd4j.ones(DataType.DOUBLE, 1, 64, 125, 125).muli(2);
+ INDArray vec = Nd4j.ones(DataType.DOUBLE, 64).muli(2);
+
+ INDArray exp = Nd4j.ones(DataType.DOUBLE, 1, 64, 125, 125);
+ INDArray out = arr.like();
+
+ for( int i=0; i<10; i++ ) {
+ out.assign(0.0);
+ Nd4j.getExecutioner().exec(new BroadcastDivOp(arr, vec, out, 1));
+ assertEquals(exp, out);
+ }
+ }
+
@Test
public void testBroadcastMult() {
@@ -7417,7 +7432,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray arr1a = Nd4j.create(new long[]{2,3}, 'c').get(NDArrayIndex.all(), NDArrayIndex.interval(0,2));
INDArray arr3 = arr1a.reshape('c', false, 4,1);
- assertFalse(arr3.isView()); //Should be copy
+ boolean isView = arr3.isView();
+ assertFalse(isView); //Should be copy
try{
INDArray arr4 = arr1a.reshape('c', true, 4,1);
@@ -7861,6 +7877,54 @@ public class Nd4jTestsC extends BaseNd4jTest {
final INDArray arr2 = arr1.reshape(3,1);
assertEquals("Incorrect type!", DataType.FLOAT, arr1.mmul(arr2).dataType());
}
+
+
+ @Test
+ public void testCreateDtypes() {
+ int[] sliceShape = new int[] {9};
+ float[] arrays = new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
+ double [] arrays_double = new double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0};
+
+ INDArray x = Nd4j.create( sliceShape, arrays, arrays );
+ assertEquals(DataType.FLOAT, x.dataType());
+
+ INDArray xd = Nd4j.create( sliceShape, arrays_double, arrays_double );
+ assertEquals(DataType.DOUBLE, xd.dataType());
+ }
+
+
+ @Test
+ public void testCreateShapeValidation(){
+ try {
+ Nd4j.create(new double[]{1, 2, 3}, new int[]{1, 1});
+ fail();
+ } catch (Exception t){
+ assertTrue(t.getMessage().contains("length"));
+ }
+
+ try {
+ Nd4j.create(new float[]{1, 2, 3}, new int[]{1, 1});
+ fail();
+ } catch (Exception t){
+ assertTrue(t.getMessage().contains("length"));
+ }
+
+ try {
+ Nd4j.create(new byte[]{1, 2, 3}, new long[]{1, 1}, DataType.BYTE);
+ fail();
+ } catch (Exception t){
+ assertTrue(t.getMessage().contains("length"));
+ }
+
+ try {
+ Nd4j.create(new double[]{1, 2, 3}, new int[]{1, 1}, 'c');
+ fail();
+ } catch (Exception t){
+ assertTrue(t.getMessage().contains("length"));
+ }
+ }
+
+
///////////////////////////////////////////////////////
protected static void fillJvmArray3D(float[][][] arr) {
int cnt = 1;
diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java
index ee83e82e5..82fff8437 100644
--- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java
+++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java
@@ -2601,7 +2601,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
}
Pointer.memcpy(pointer, oldPointer, this.length() * getElementSize());
- //this.underlyingLength = length;
+ this.underlyingLength = length;
+ this.length = length;
return this;
}