diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
index 4c6ce710d..1c3bd8c89 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
@@ -578,7 +578,12 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge.class,
org.nd4j.linalg.api.ops.random.impl.Range.class,
org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class,
- org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class
+ org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class,
+ org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits.class,
+ org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits.class,
+ org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class,
+ org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class
+
);
static {
diff --git a/nd4s/pom.xml b/nd4s/pom.xml
index 63e5495a7..d30ae4c9b 100644
--- a/nd4s/pom.xml
+++ b/nd4s/pom.xml
@@ -30,7 +30,7 @@
org.nd4j
nd4s
- pom
+ jar
nd4s
@@ -280,6 +280,19 @@
+
+ org.apache.maven.plugins
+ maven-jar-plugin
+
+
+ make-a-jar
+ compile
+
+ jar
+
+
+
+
diff --git a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala
new file mode 100644
index 000000000..8ca21b72e
--- /dev/null
+++ b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala
@@ -0,0 +1,157 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4s.samediff
+
+import org.nd4j.linalg.api.ndarray.INDArray
+import org.nd4j.autodiff.samediff.SDVariable
+import org.nd4j.autodiff.samediff.SameDiff
+import org.nd4j.linalg.api.buffer.DataType
+import org.nd4j.linalg.factory.Nd4j
+
+/**
+ * Provides wrappers for nd4j SameDiff and related classes.
+ *
+ * Wrappers are designed to be used implicitly, client code
+ * should be similar to nd4j with additional syntactic sugar
+ * and Scala specific stuff.
+ *
+ * @author Alexander Stoyakin
+ */
+class SameDiffWrapper {
+
+ var sd: SameDiff = SameDiff.create()
+
+ def this(sd: SameDiff) {
+ this
+ this.sd = sd
+ }
+
+ def bind(name: String, data: INDArray): SDVariable =
+ sd.`var`(name, data)
+
+ def bind(name: String, dataType: DataType, shape: Array[Long]): SDVariable =
+ sd.`var`(name, dataType, shape: _*)
+
+ def bind(name: String, dataType: DataType, shape: Array[Int]): SDVariable =
+ sd.`var`(name, dataType, shape: _*)
+
+ def placeHolder(name: String, dataType: DataType, shape: Long*): SDVariable =
+ sd.placeHolder("ph1", DataType.FLOAT, 3, 4)
+}
+
+class SDVariableWrapper {
+
+ var thisVariable: SDVariable = null
+ var isScalar: Boolean = false
+
+ def this(variable: SDVariable) {
+ this
+ thisVariable = variable
+ }
+
+ def *(other: SDVariable): SDVariable =
+ thisVariable.mul(other)
+
+ def +(other: SDVariable): SDVariable =
+ thisVariable.add(other)
+
+ def /(other: SDVariable): SDVariable =
+ if (isScalar)
+ thisVariable.rdiv(other)
+ else
+ thisVariable.rdiv(other)
+
+ def -(other: SDVariable): SDVariable =
+ if (isScalar)
+ thisVariable.rsub(other)
+ else
+ thisVariable.sub(other)
+
+ def %(other: SDVariable): SDVariable = thisVariable.mod(null, other)
+
+ def `//`(other: SDVariable): SDVariable = thisVariable.fdiv(null, other)
+
+ def unary_-(): SDVariable = thisVariable.neg
+
+ def ^(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.xor(thisVariable, other)
+ def |(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.or(thisVariable, other)
+ def &(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.and(thisVariable, other)
+
+ def <<(x: Int)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShift(null, thisVariable, x)
+ def >>(x: Int)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShiftRight(null, thisVariable, x)
+
+ // Overloads for numeric arguments
+ // Float
+ def *(other: Float)(implicit sameDiff: SameDiff): SDVariable =
+ thisVariable.mul(sameDiff.constant(other))
+
+ def +(other: Float)(implicit sameDiff: SameDiff): SDVariable =
+ thisVariable.add(sameDiff.constant(other))
+
+ def -(other: Float)(implicit sameDiff: SameDiff): SDVariable =
+ if (isScalar)
+ thisVariable.rsub(sameDiff.constant(other))
+ else
+ thisVariable.sub(sameDiff.constant(other))
+
+ def /(other: Float)(implicit sameDiff: SameDiff): SDVariable =
+ if (isScalar)
+ thisVariable.rdiv(sameDiff.constant(other))
+ else
+ thisVariable.div(sameDiff.constant(other))
+
+ def %(other: Float)(implicit sameDiff: SameDiff): SDVariable =
+ thisVariable.mod(null, sameDiff.constant(other))
+
+ def `//`(other: Float)(implicit sameDiff: SameDiff): SDVariable =
+ thisVariable.fdiv(null, sameDiff.constant(other))
+
+ //Double
+ def *(other: Double)(implicit sameDiff: SameDiff): SDVariable =
+ thisVariable.mul(sameDiff.constant(other))
+
+ def +(other: Double)(implicit sameDiff: SameDiff): SDVariable =
+ thisVariable.add(sameDiff.constant(other))
+
+ def -(other: Double)(implicit sameDiff: SameDiff): SDVariable =
+ if (isScalar)
+ thisVariable.rsub(sameDiff.constant(other))
+ else
+ thisVariable.sub(sameDiff.constant(other))
+
+ def /(other: Double)(implicit sameDiff: SameDiff): SDVariable =
+ if (isScalar)
+ thisVariable.rdiv(sameDiff.constant(other))
+ else
+ thisVariable.div(sameDiff.constant(other))
+
+ def %(other: Double)(implicit sameDiff: SameDiff): SDVariable =
+ thisVariable.mod(null, sameDiff.constant(other))
+
+ def `//`(other: Double)(implicit sameDiff: SameDiff): SDVariable =
+ thisVariable.fdiv(null, sameDiff.constant(other))
+
+ // Int
+ def **(x: Int): SDVariable =
+ thisVariable.pow(x)
+
+ def ^(other: Boolean)(implicit sameDiff: SameDiff): SDVariable =
+ sameDiff.math.xor(thisVariable, sameDiff.constant(Nd4j.scalar(other)))
+ def |(other: Boolean)(implicit sameDiff: SameDiff): SDVariable =
+ sameDiff.math.or(thisVariable, sameDiff.constant(Nd4j.scalar(other)))
+ def &(other: Boolean)(implicit sameDiff: SameDiff): SDVariable =
+ sameDiff.math.and(thisVariable, sameDiff.constant(Nd4j.scalar(other)))
+}
diff --git a/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala b/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala
new file mode 100644
index 000000000..c10ff367c
--- /dev/null
+++ b/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala
@@ -0,0 +1,46 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4s.samediff.implicits
+
+import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff }
+import org.nd4j.linalg.factory.Nd4j
+import org.nd4s.samediff.{ SDVariableWrapper, SameDiffWrapper }
+
+object Implicits {
+ implicit def SameDiffToWrapper(sd: SameDiff): SameDiffWrapper =
+ new SameDiffWrapper(sd)
+
+ implicit def SDVariableToWrapper(variable: SDVariable): SDVariableWrapper =
+ new SDVariableWrapper(variable)
+
+ implicit def FloatToSDVariable(x: Float)(implicit sd: SameDiff): SDVariableWrapper = {
+ val result = new SDVariableWrapper(sd.constant(x))
+ result.isScalar = true
+ result
+ }
+
+ implicit def DoubleToSDVariable(x: Double)(implicit sd: SameDiff): SDVariableWrapper = {
+ val result = new SDVariableWrapper(sd.constant(x))
+ result.isScalar = true
+ result
+ }
+
+ implicit def BooleanToSDVariable(x: Boolean)(implicit sd: SameDiff): SDVariableWrapper = {
+ val result = new SDVariableWrapper(sd.constant(Nd4j.scalar(x)))
+ result.isScalar = true
+ result
+ }
+}
diff --git a/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala b/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala
index c0a1a95d5..5894e31d7 100644
--- a/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala
+++ b/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala
@@ -48,8 +48,8 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
assert(extracted == expected)
}
- it should "be able to extract a part of 2d matrix with double data and offset" in {
- val ndArray = (1 to 9).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C, offset = 4)
+ it should "be able to extract a part of 2d matrix with double data" in {
+ val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C)
val expectedArray = Array(
Array(5d, 6d),
diff --git a/nd4s/src/test/scala/org/nd4s/NDArrayProjectionAPITest.scala b/nd4s/src/test/scala/org/nd4s/NDArrayProjectionAPITest.scala
index f9d4a5e68..388f440ce 100644
--- a/nd4s/src/test/scala/org/nd4s/NDArrayProjectionAPITest.scala
+++ b/nd4s/src/test/scala/org/nd4s/NDArrayProjectionAPITest.scala
@@ -303,7 +303,7 @@ class NDArrayProjectionAPITest extends FlatSpec {
}
"SliceProjectedNDArray" should "filter slice correctly" in {
- val ndArray = (1d until 10d by 1).asNDArray(2, 2, 2)
+ val ndArray = (1d until 9d by 1).asNDArray(2, 2, 2)
val result = ndArray.sliceP withFilter (input => false)
assert(result.filtered.isEmpty)
}
diff --git a/nd4s/src/test/scala/org/nd4s/samediff/ConstructionTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/ConstructionTest.scala
new file mode 100644
index 000000000..700c626e4
--- /dev/null
+++ b/nd4s/src/test/scala/org/nd4s/samediff/ConstructionTest.scala
@@ -0,0 +1,117 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4s.samediff
+
+import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff }
+import org.nd4j.linalg.api.buffer.DataType
+import org.nd4j.linalg.api.ndarray.INDArray
+import org.nd4j.linalg.factory.Nd4j
+import org.nd4s.Implicits._
+import org.nd4s.samediff.implicits.Implicits._
+import org.scalatest.{ FlatSpec, Matchers }
+
+class ConstructionTest extends FlatSpec with Matchers {
+
+ "SameDiff" should "allow composition of arithmetic operations" in {
+
+ val sd = SameDiff.create()
+ val ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4)
+ val w1 = sd.bind("w1", Nd4j.rand(DataType.FLOAT, 4, 5))
+ val b1 = sd.bind("b1", Nd4j.rand(DataType.FLOAT, 5))
+
+ val mmul1 = ph1 * w1
+ val badd1 = mmul1 + b1
+
+ val loss1 = badd1.std("loss1", true)
+
+ sd.setLossVariables("loss1")
+ sd.createGradFunction
+ for (v <- Array[SDVariable](ph1, w1, b1, mmul1, badd1, loss1)) {
+ assert(v.getVarName != null && v.gradient != null)
+ }
+ }
+
+ "SameDiff" should "provide arithmetic operations for float arguments in arbitrary order" in {
+
+ implicit val sd = SameDiff.create()
+ val w1 = sd.bind("w1", 4.0f.toScalar)
+ var evaluated = w1.eval.castTo(DataType.FLOAT)
+ evaluated.toFloatVector.head shouldBe 4.0f
+
+ val w2 = w1 * 2.0f
+ w2.eval.toFloatVector.head shouldBe 8.0f
+ val w3 = w2 + 2.0f
+ w3.eval.toFloatVector.head shouldBe 10.0f
+
+ val w4 = 2.0f * w1
+ w4.eval.toFloatVector.head shouldBe 8.0f
+ val w5 = 2.0f + w2
+ w5.eval.toFloatVector.head shouldBe 10.0f
+
+ val w6 = w1 / 2.0f
+ w6.eval.toFloatVector.head shouldBe 2.0f
+ val w7 = w2 - 2.0f
+ w7.eval.toFloatVector.head shouldBe 6.0f
+
+ val w8 = 2.0f / w1
+ w8.eval.toFloatVector.head shouldBe 2.0f
+
+ val w9 = 2.0f - w2
+ w9.eval.toFloatVector.head shouldBe 6.0f
+ }
+
+ "SameDiff" should "provide arithmetic operations for double arguments in arbitrary order" in {
+ implicit val sd = SameDiff.create()
+ val w1 = sd.bind("w1", 4.0.toScalar)
+ var evaluated = w1.eval.castTo(DataType.DOUBLE)
+ evaluated.toFloatVector.head shouldBe 4.0
+
+ val w2 = w1 * 2.0
+ w2.eval.toFloatVector.head shouldBe 8.0
+ val w3 = w2 + 2.0
+ w3.eval.toFloatVector.head shouldBe 10.0
+
+ val w4 = 2.0 * w1
+ w4.eval.toFloatVector.head shouldBe 8.0
+ val w5 = 2.0 + w2
+ w5.eval.toFloatVector.head shouldBe 10.0
+
+ val w6 = w1 / 2.0
+ w6.eval.toFloatVector.head shouldBe 2.0
+ val w7 = w2 - 2.0
+ w7.eval.toFloatVector.head shouldBe 6.0
+
+ val w8 = 2.0 / w1
+ w8.eval.toFloatVector.head shouldBe 2.0
+ val w9 = 2.0 - w2
+ w9.eval.toFloatVector.head shouldBe 6.0f
+ }
+
+ "SameDiff" should "provide unary math operators" in {
+ implicit val sd = SameDiff.create()
+ val w1 = sd.bind("w1", 4.0.toScalar)
+ var evaluated = w1.eval.castTo(DataType.DOUBLE)
+ evaluated.toFloatVector.head shouldBe 4.0
+
+ val w2 = -w1
+ var evaluated2 = w2.eval.castTo(DataType.DOUBLE)
+ evaluated2.toFloatVector.head shouldBe -4.0
+
+ val w3 = w1 ** 2
+ var evaluated3 = w3.eval.castTo(DataType.DOUBLE)
+ evaluated3.toFloatVector.head shouldBe 16.0
+ }
+}
diff --git a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala
new file mode 100644
index 000000000..a2c113b50
--- /dev/null
+++ b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala
@@ -0,0 +1,191 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4s.samediff
+
+import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff }
+import org.nd4j.linalg.api.buffer.DataType
+import org.nd4j.linalg.api.ndarray.INDArray
+import org.nd4j.linalg.factory.Nd4j
+import org.nd4s.Implicits._
+import org.nd4s.samediff.implicits.Implicits._
+import org.scalatest.{ FlatSpec, Matchers }
+
+class MathTest extends FlatSpec with Matchers {
+
+ "SameDiff" should "allow composition of arithmetic operations" in {
+
+ val sd = SameDiff.create()
+ val ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4)
+ val w1 = sd.bind("w1", Nd4j.rand(DataType.FLOAT, 4, 5))
+ val b1 = sd.bind("b1", Nd4j.rand(DataType.FLOAT, 5))
+
+ val mmul1 = ph1 * w1
+ val badd1 = mmul1 + b1
+
+ val loss1 = badd1.std("loss1", true)
+
+ sd.setLossVariables("loss1")
+ sd.createGradFunction
+ for (v <- Array[SDVariable](ph1, w1, b1, mmul1, badd1, loss1)) {
+ assert(v.getVarName != null && v.gradient != null)
+ }
+ }
+
+ "SameDiff" should "provide arithmetic operations for float arguments in arbitrary order" in {
+
+ implicit val sd = SameDiff.create()
+ val w1 = sd.bind("w1", 4.0f.toScalar)
+ var evaluated = w1.eval.castTo(DataType.FLOAT)
+ evaluated.toFloatVector.head shouldBe 4.0f
+
+ val w2 = w1 * 2.0f
+ w2.eval.toFloatVector.head shouldBe 8.0f
+ val w3 = w2 + 2.0f
+ w3.eval.toFloatVector.head shouldBe 10.0f
+
+ val w4 = 2.0f * w1
+ w4.eval.toFloatVector.head shouldBe 8.0f
+ val w5 = 2.0f + w2
+ w5.eval.toFloatVector.head shouldBe 10.0f
+
+ val w6 = w1 / 2.0f
+ w6.eval.toFloatVector.head shouldBe 2.0f
+ val w7 = w2 - 2.0f
+ w7.eval.toFloatVector.head shouldBe 6.0f
+
+ val w8 = 2.0f / w1
+ w8.eval.toFloatVector.head shouldBe 2.0f
+
+ val w9 = 2.0f - w2
+ w9.eval.toFloatVector.head shouldBe 6.0f
+ }
+
+ "SameDiff" should "provide arithmetic operations for double arguments in arbitrary order" in {
+ implicit val sd = SameDiff.create()
+ val w1 = sd.bind("w1", 4.0.toScalar)
+ var evaluated = w1.eval.castTo(DataType.DOUBLE)
+ evaluated.toFloatVector.head shouldBe 4.0
+
+ val w2 = w1 * 2.0
+ w2.eval.toFloatVector.head shouldBe 8.0
+ val w3 = w2 + 2.0
+ w3.eval.toFloatVector.head shouldBe 10.0
+
+ val w4 = 2.0 * w1
+ w4.eval.toFloatVector.head shouldBe 8.0
+ val w5 = 2.0 + w2
+ w5.eval.toFloatVector.head shouldBe 10.0
+
+ val w6 = w1 / 2.0
+ w6.eval.toFloatVector.head shouldBe 2.0
+ val w7 = w2 - 2.0
+ w7.eval.toFloatVector.head shouldBe 6.0
+
+ val w8 = 2.0 / w1
+ w8.eval.toFloatVector.head shouldBe 2.0
+ val w9 = 2.0 - w2
+ w9.eval.toFloatVector.head shouldBe 6.0f
+ }
+
+ "SameDiff" should "provide floor division" in {
+ implicit val sd = SameDiff.create()
+ val w1 = sd.bind("w1", 4.0.toScalar)
+ val w2 = sd.bind("w2", 1.2.toScalar)
+ val w3 = w1 `//` w2
+ w3.eval.toFloatVector.head shouldBe 3.0
+
+ val w4 = w1 `//` 1.5
+ w4.eval.toFloatVector.head shouldBe 2.0
+
+ val w5 = 9.5 `//` w1
+ w5.eval.toFloatVector.head shouldBe 2.0
+ }
+
+ "SameDiff" should "provide remainder division" in {
+ implicit val sd = SameDiff.create()
+ val w1 = sd.bind("w1", 40.0.toScalar)
+ val w2 = sd.bind("w2", 12.0.toScalar)
+ val w3 = w2 % w1
+ w3.eval.toFloatVector.head shouldBe 12.0
+ val w4 = w1 % w2
+ w4.eval.toFloatVector.head shouldBe 4.0
+
+ val w5 = w1 % 15.0
+ w5.eval.toFloatVector.head shouldBe 10.0
+
+ val w6 = 10.0 % w1
+ w6.eval.toFloatVector.head shouldBe 10.0
+ }
+
+ "SameDiff" should "provide unary math operators" in {
+ implicit val sd = SameDiff.create()
+ val w1 = sd.bind("w1", 4.0.toScalar)
+ var evaluated = w1.eval.castTo(DataType.DOUBLE)
+ evaluated.toFloatVector.head shouldBe 4.0
+
+ val w2 = -w1
+ var evaluated2 = w2.eval.castTo(DataType.DOUBLE)
+ evaluated2.toFloatVector.head shouldBe -4.0
+
+ val w3 = w1 ** 2
+ var evaluated3 = w3.eval.castTo(DataType.DOUBLE)
+ evaluated3.toFloatVector.head shouldBe 16.0
+ }
+
+ "SameDiff" should "provide boolean logic operators" in {
+ implicit val sd = SameDiff.create()
+ val w1 = sd.constant(Nd4j.scalar(true))
+ val w2 = sd.constant(Nd4j.scalar(true))
+
+ val w3 = w1 | w2
+ w3.eval.toIntVector.head shouldBe 1
+
+ val w4 = w1 & w2
+ w4.eval.toIntVector.head shouldBe 1
+
+ val w5 = w1 ^ w2
+ w5.eval.toIntVector.head shouldBe 0
+
+ val w6 = w1 | false
+ w6.eval.toIntVector.head shouldBe 1
+
+ val w7 = w1 & false
+ w7.eval.toIntVector.head shouldBe 0
+
+ val w8 = w1 ^ false
+ w8.eval.toIntVector.head shouldBe 1
+
+ val w9 = false | w1
+ w9.eval.toIntVector.head shouldBe 1
+
+ val w10 = false & w1
+ w10.eval.toIntVector.head shouldBe 0
+
+ val w11 = false ^ w1
+ w11.eval.toIntVector.head shouldBe 1
+ }
+
+ "SameDiff" should "provide shifting operations" in {
+ implicit val sd = SameDiff.create()
+ val w1 = sd.constant(16)
+
+ val w2 = w1 << 2
+ w2.eval.toIntVector.head shouldBe 64
+
+ val w3 = w1 >> 2
+ w3.eval.toIntVector.head shouldBe 4
+ }
+}
diff --git a/nd4s/src/test/scala/org/nd4s/samediff/SameDiffTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/SameDiffTest.scala
new file mode 100644
index 000000000..a99b78214
--- /dev/null
+++ b/nd4s/src/test/scala/org/nd4s/samediff/SameDiffTest.scala
@@ -0,0 +1,123 @@
+package org.nd4s.samediff
+
+import java.lang.reflect.Field
+import java.util
+import java.util.{ Arrays, Collections, HashMap, List, Map }
+
+import com.google.common.collect.{ Lists, Maps }
+import org.junit.Assert._
+import org.junit.Assume.assumeNotNull
+import org.nd4j.autodiff.samediff._
+import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional
+import org.nd4j.autodiff.validation.{ OpValidation, TestCase }
+import org.nd4j.linalg.activations.Activation
+import org.nd4j.linalg.api.blas.params.MMulTranspose
+import org.nd4j.linalg.api.buffer.DataType
+import org.nd4j.linalg.api.ndarray.INDArray
+import org.nd4j.linalg.api.ops.DynamicCustomOp
+import org.nd4j.linalg.api.ops.impl.layers.{ ExternalErrorsFunction, Linear }
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.{ Conv2DConfig, LocalResponseNormalizationConfig }
+import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance
+import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray
+import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax
+import org.nd4j.linalg.api.ops.impl.transforms.comparison.{ OldMax, OldMin }
+import org.nd4j.linalg.api.ops.impl.transforms.custom._
+import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution
+import org.nd4j.linalg.api.shape.LongShapeDescriptor
+import org.nd4j.linalg.checkutil.NDArrayCreationUtil
+import org.nd4j.linalg.dataset.{ DataSet, MultiDataSet }
+import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator
+import org.nd4j.linalg.factory.Nd4j
+import org.nd4j.linalg.indexing.NDArrayIndex
+import org.nd4j.linalg.indexing.NDArrayIndex.all
+import org.nd4j.linalg.learning.config.Adam
+import org.nd4j.linalg.ops.transforms.Transforms
+import org.nd4j.weightinit.impl.{ OneInitScheme, UniformInitScheme, ZeroInitScheme }
+import org.nd4s.samediff.implicits.Implicits._
+import org.scalatest.{ FlatSpec, Matchers }
+import scala.collection.JavaConversions._
+
+class SameDiffTest extends FlatSpec with Matchers {
+
+ "SameDiff" should "allow Mse backwards execution" in {
+
+ implicit val sd: SameDiff = SameDiff.create
+
+ val nOut: Int = 4
+ val minibatch: Int = 3
+ val input: SDVariable = sd.bind("in", DataType.FLOAT, Array[Long](minibatch, nOut))
+ val label: SDVariable = sd.bind("label", DataType.FLOAT, Array[Long](minibatch, nOut))
+
+ val diff: SDVariable = input - label
+ val sqDiff: SDVariable = diff * diff
+ //val sqDiff: SDVariable = diff ** 2
+ val msePerEx: SDVariable = sd.mean("msePerEx", sqDiff, 1)
+ val avgMSE: SDVariable = sd.mean("loss", msePerEx, 0)
+
+ val inputArr: INDArray = Nd4j.rand(DataType.FLOAT, minibatch, nOut)
+ val labelArr: INDArray = Nd4j.rand(DataType.FLOAT, minibatch, nOut)
+
+ sd.associateArrayWithVariable(inputArr, input)
+ sd.associateArrayWithVariable(labelArr, label)
+
+ val result: INDArray = sd.execAndEndResult
+ assertEquals(1, result.length)
+
+ val emptyMap = new HashMap[String, INDArray]()
+ sd.execBackwards(emptyMap)
+ }
+
+ "SameDiff" should "run test dense layer forward pass" in {
+ Nd4j.getRandom.setSeed(12345)
+ implicit val sd = SameDiff.create
+ val iInput = Nd4j.rand(3, 4)
+ val iWeights = Nd4j.rand(4, 5)
+ val iBias = Nd4j.rand(1, 5)
+ val input = sd.bind("input", iInput)
+ val weights = sd.bind("weights", iWeights)
+ val bias = sd.bind("bias", iBias)
+ val mmul = sd.mmul("mmul", input, weights)
+
+ val z = mmul + bias
+
+ val out = sd.nn.sigmoid("out", z)
+ val expMmul = iInput.mmul(iWeights)
+ val expZ = expMmul.addRowVector(iBias)
+ val expOut = Transforms.sigmoid(expZ, true)
+ sd.exec(new HashMap[String, INDArray](), sd.outputs)
+ assertEquals(expMmul, mmul.getArr)
+ assertEquals(expZ, z.getArr)
+ assertEquals(expOut, out.getArr)
+ }
+
+ "SameDiff" should "convert placeholder to constant" in {
+ Nd4j.getRandom.setSeed(12345)
+ val sd = SameDiff.create
+ val in = sd.placeHolder("in", DataType.FLOAT, 1, 3)
+ val in2 = sd.placeHolder("in2", DataType.FLOAT, 3, 4)
+ val b = sd.bind("b", Nd4j.rand(DataType.FLOAT, 1, 4))
+ val mmul = in.mmul(in2)
+ val add = mmul + b
+ val tanh = sd.math.tanh(add)
+ val loss = sd.variance(tanh, true)
+ val inArr = Nd4j.rand(DataType.FLOAT, 1, 3)
+ in.setArray(inArr)
+ val inArr2 = Nd4j.rand(DataType.FLOAT, 3, 4)
+ val c = TrainingConfig.builder
+ .updater(new Adam(0.1))
+ .weightDecay(0.01, true)
+ .dataSetFeatureMapping("in", "in2")
+ .skipBuilderValidation(true)
+ .build
+ sd.setTrainingConfig(c)
+ sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(Array[INDArray](inArr, inArr2), null)), 1)
+ val out = tanh.eval
+ in.convertToConstant
+ val out2 = tanh.eval
+ assertEquals(out, out2)
+ assertEquals(VariableType.CONSTANT, in.getVariableType)
+ assertEquals(inArr, in.getArr)
+ //Sanity check on fitting:
+ sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(Array[INDArray](inArr2), null)), 1)
+ }
+}
diff --git a/nd4s/src/test/scala/org/nd4s/samediff/TrainingTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/TrainingTest.scala
new file mode 100644
index 000000000..d51707ee1
--- /dev/null
+++ b/nd4s/src/test/scala/org/nd4s/samediff/TrainingTest.scala
@@ -0,0 +1,125 @@
+package org.nd4s.samediff
+
+import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff, TrainingConfig }
+import org.nd4j.linalg.api.buffer.DataType
+import org.nd4j.linalg.api.ndarray.INDArray
+import org.nd4j.linalg.dataset.{ DataSet, MultiDataSet }
+import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator
+import org.nd4j.linalg.factory.Nd4j
+import org.nd4j.linalg.learning.config.Adam
+import org.nd4s.Implicits._
+import org.nd4s.samediff.implicits.Implicits._
+import org.scalatest.{ FlatSpec, Matchers }
+
+class TrainingTest extends FlatSpec with Matchers {
+
+ "SameDiff" should "allow loss calculation" in {
+ for (i <- 0 until 2) {
+ implicit val sd = SameDiff.create
+ val ph = sd.placeHolder("ph", DataType.FLOAT, 3, 4)
+ val w = sd.bind("w", Nd4j.rand(DataType.FLOAT, 4, 5))
+ val b = sd.bind("b", Nd4j.rand(DataType.FLOAT, 5))
+ val mmul = ph.mmul(w)
+ val badd = mmul + b
+ val add = badd + 1
+ val shape = add.shape
+ val unused1 = ph.mul(2)
+ val unused2 = ph.sub(4)
+ val unused3 = unused1.div(unused2)
+ val loss1 = add.std("l1", true)
+ val loss2 = mmul.mean("l2")
+ Console.println(sd.summary)
+ if (i == 0) {
+ sd.setLossVariables("l1", "l2")
+ sd.createGradFunction()
+ } else {
+ val tc = TrainingConfig.builder
+ .updater(new Adam(0.01))
+ .minimize("l1", "l2")
+ .dataSetFeatureMapping("ph")
+ .markLabelsUnused
+ .build
+ sd.setTrainingConfig(tc)
+ val ds = new DataSet(Nd4j.create(3, 4), null)
+ sd.fit(ds)
+ sd.fit(ds)
+ }
+ for (s <- Array[String]("w", "b", badd.getVarName, add.getVarName, "l1", "l2")) {
+ val gradVar = sd.getVariable(s).gradient
+ assert(gradVar != null)
+ }
+ //Unused:
+ assert(!shape.hasGradient)
+ try assert(shape.gradient == null)
+ catch {
+ case e: IllegalStateException =>
+ assert(e.getMessage.contains("only floating point variables"))
+ }
+ for (s <- Array[String](unused1.getVarName, unused2.getVarName, unused3.getVarName)) {
+ assert(sd.getVariable(s).gradient == null)
+ }
+ }
+ }
+
+ "SameDiff" should "allow creating and running model with 2 losses: train on the first one, then change losses" in {
+ // TODO: try to get rid of implicit here
+ implicit val sd = SameDiff.create
+ val ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4)
+ val w1 = sd.bind("w1", Nd4j.rand(DataType.FLOAT, 4, 5))
+ val b1 = sd.bind("b1", Nd4j.rand(DataType.FLOAT, 5))
+ val mmul1 = ph1.mmul(w1)
+ val badd1 = mmul1 + b1
+
+ val ph2 = sd.placeHolder("ph2", DataType.FLOAT, 3, 2)
+ val w2 = sd.bind("w2", Nd4j.rand(DataType.FLOAT, 2, 6))
+ val b2 = sd.bind("b2", Nd4j.rand(DataType.FLOAT, 6))
+ val mmul2 = ph2.mmul(w2)
+ val badd2 = mmul2 + b2
+ val loss1 = badd1.std("loss1", true)
+ val loss2 = badd2.std("loss2", true)
+ //First: create grad function for optimizing loss 1 only
+ sd.setLossVariables("loss1")
+ sd.createGradFunction()
+ for (v <- Array[SDVariable](ph1, w1, b1, mmul1, badd1, loss1)) {
+ assert(v.gradient != null)
+ }
+ for (v <- Array[SDVariable](ph2, w2, b2, mmul2, badd2, loss2)) {
+ assert(v.gradient == null)
+ }
+ //Now, set to other loss function
+ sd.setLossVariables("loss2")
+ sd.createGradFunction()
+ for (v <- Array[SDVariable](ph1, w1, b1, mmul1, badd1, loss1)) {
+ assert(v.gradient == null)
+ }
+ for (v <- Array[SDVariable](ph2, w2, b2, mmul2, badd2, loss2)) {
+ assert(v.gradient != null)
+ }
+ //Train the first side of the graph. The other side should remain unmodified!
+ sd.setLossVariables("loss1")
+ var w1Before = w1.getArr.dup
+ var b1Before = b1.getArr.dup
+ var w2Before = w2.getArr.dup
+ var b2Before = b2.getArr.dup
+ val tc = TrainingConfig.builder.updater(new Adam(1e-2)).dataSetFeatureMapping("ph1", "ph2").markLabelsUnused.build
+ sd.setTrainingConfig(tc)
+ val mds = new MultiDataSet(Array[INDArray](Nd4j.rand(DataType.FLOAT, 3, 4), Nd4j.rand(DataType.FLOAT, 3, 2)),
+ new Array[INDArray](0))
+ sd.fit(new SingletonMultiDataSetIterator(mds), 3)
+ assert(w1Before != w1.getArr)
+ assert(b1Before != b1.getArr)
+ assert(w2Before == w2.getArr)
+ assert(b2Before == b2.getArr)
+ //Train second side of graph; first side should be unmodified
+ sd.setLossVariables("loss2")
+ w1Before = w1.getArr.dup
+ b1Before = b1.getArr.dup
+ w2Before = w2.getArr.dup
+ b2Before = b2.getArr.dup
+ sd.fit(new SingletonMultiDataSetIterator(mds), 3)
+ assert(w1Before == w1.getArr)
+ assert(b1Before == b1.getArr)
+ assert(w2Before != w2.getArr)
+ assert(b2Before != b2.getArr)
+ }
+}