diff --git a/.gitignore b/.gitignore index ad2e28e6f..fd33cb142 100644 --- a/.gitignore +++ b/.gitignore @@ -70,3 +70,6 @@ venv2/ # Ignore the nd4j files that are created by javacpp at build to stop merge conflicts nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java + +# Ignore meld temp files +*.orig diff --git a/arbiter/arbiter-core/pom.xml b/arbiter/arbiter-core/pom.xml index 6ce3c9c1f..76251f4cd 100644 --- a/arbiter/arbiter-core/pom.xml +++ b/arbiter/arbiter-core/pom.xml @@ -14,7 +14,8 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - + arbiter org.deeplearning4j @@ -33,10 +34,10 @@ nd4j-api ${nd4j.version} - - com.google.code.findbugs - * - + + com.google.code.findbugs + * + diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..75a64d05f --- /dev/null +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.arbiter.optimize; + +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.arbiter.optimize"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..b8d200350 --- /dev/null +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.arbiter; + +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.arbiter"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..b305b123b --- /dev/null +++ b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.arbiter.server; + +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.arbiter.server"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java index dbf05d34f..57bef758d 100644 --- a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java +++ b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.server; import lombok.Data; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; @@ -27,7 +28,7 @@ import java.io.IOException; * Created by agibsonccc on 3/13/17. 
*/ @Data -public class MnistDataSetIteratorFactory implements DataSetIteratorFactory { +public class MnistDataSetIteratorFactory extends BaseDL4JTest implements DataSetIteratorFactory { /** * @return */ diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java index e1a7f820e..c4a75ffb4 100644 --- a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java +++ b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java @@ -17,13 +17,14 @@ package org.deeplearning4j.arbiter.server; import lombok.AllArgsConstructor; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; @AllArgsConstructor -public class TestDataFactoryProviderMnist implements DataSetIteratorFactory { +public class TestDataFactoryProviderMnist extends BaseDL4JTest implements DataSetIteratorFactory { private int batchSize; private int terminationIter; diff --git a/arbiter/arbiter-ui/pom.xml b/arbiter/arbiter-ui/pom.xml index 2067a3fc7..7392392db 100644 --- a/arbiter/arbiter-ui/pom.xml +++ b/arbiter/arbiter-ui/pom.xml @@ -54,6 +54,13 @@ ${dl4j.version} + + org.deeplearning4j + deeplearning4j-common-tests + ${dl4j.version} + test + + ch.qos.logback logback-classic diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..fee20847c --- /dev/null +++ 
b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.arbiter.optimize; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.deeplearning4j.BaseDL4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.arbiter.optimize"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java index 804c6f974..ddf73e455 100644 --- a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java +++ b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.MultiLayerSpace; @@ -70,7 +71,7 @@ import java.util.concurrent.TimeUnit; /** * Created by Alex on 19/07/2017. 
*/ -public class TestBasic { +public class TestBasic extends BaseDL4JTest { @Test @Ignore diff --git a/change-scala-versions.sh b/change-scala-versions.sh index 8968abbf3..aace1b05e 100755 --- a/change-scala-versions.sh +++ b/change-scala-versions.sh @@ -88,5 +88,15 @@ find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ #Scala maven plugin, 2.11 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ -exec bash -c "sed_i 's/\(scalaVersion>\)'$FROM_VERSION'<\/scalaVersion>/\1'$TO_VERSION'<\/scalaVersion>/g' {}" \; + +# Disable deeplearning4j-nlp-korean for scala 2.12 - see https://github.com/eclipse/deeplearning4j/issues/8840 +if [ $TO_VERSION = $SCALA_211_VERSION ]; then + #Enable + sed -i 's/ / deeplearning4j-nlp-korean<\/module>/g' deeplearning4j/deeplearning4j-nlp-parent/pom.xml +else + #Disable + sed -i 's/ deeplearning4j-nlp-korean<\/module>/ /g' deeplearning4j/deeplearning4j-nlp-parent/pom.xml +fi + echo "Done updating Scala versions."; diff --git a/datavec/datavec-api/pom.xml b/datavec/datavec-api/pom.xml index 10ed3517a..3c3eec86e 100644 --- a/datavec/datavec-api/pom.xml +++ b/datavec/datavec-api/pom.xml @@ -15,7 +15,8 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - + datavec-parent org.datavec @@ -79,6 +80,14 @@ ${nd4j.version} + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + + + ch.qos.logback logback-classic diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java b/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java index c55d4d3bb..92a1f737b 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java @@ -47,9 +47,8 @@ public class RecordConverter { * * @return the array */ - @Deprecated - public static INDArray toArray(Collection record, int size) { - return 
toArray(record); + public static INDArray toArray(DataType dataType, Collection record, int size) { + return toArray(dataType, record); } /** @@ -78,13 +77,23 @@ public class RecordConverter { /** * Convert a set of records in to a matrix + * As per {@link #toMatrix(DataType, List)} but hardcoded to Float datatype * @param records the records ot convert * @return the matrix for the records */ public static INDArray toMatrix(List> records) { + return toMatrix(DataType.FLOAT, records); + } + + /** + * Convert a set of records in to a matrix + * @param records the records ot convert + * @return the matrix for the records + */ + public static INDArray toMatrix(DataType dataType, List> records) { List toStack = new ArrayList<>(); for(List l : records){ - toStack.add(toArray(l)); + toStack.add(toArray(dataType, l)); } return Nd4j.vstack(toStack); @@ -92,10 +101,20 @@ public class RecordConverter { /** * Convert a record to an INDArray. May contain a mix of Writables and row vector NDArrayWritables. + * As per {@link #toArray(DataType, Collection)} but hardcoded to Float datatype * @param record the record to convert * @return the array */ - public static INDArray toArray(Collection record) { + public static INDArray toArray(Collection record){ + return toArray(DataType.FLOAT, record); + } + + /** + * Convert a record to an INDArray. May contain a mix of Writables and row vector NDArrayWritables. 
+ * @param record the record to convert + * @return the array + */ + public static INDArray toArray(DataType dataType, Collection record) { List l; if(record instanceof List){ l = (List)record; @@ -124,7 +143,7 @@ public class RecordConverter { } } - INDArray arr = Nd4j.create(1, length); + INDArray arr = Nd4j.create(dataType, 1, length); int k = 0; for (Writable w : record ) { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/AssertTestsExtendBaseClass.java b/datavec/datavec-api/src/test/java/org/datavec/api/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..43c606123 --- /dev/null +++ b/datavec/datavec-api/src/test/java/org/datavec/api/AssertTestsExtendBaseClass.java @@ -0,0 +1,57 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.api; + +import lombok.extern.slf4j.Slf4j; +import org.datavec.api.transform.serde.testClasses.CustomCondition; +import org.datavec.api.transform.serde.testClasses.CustomFilter; +import org.datavec.api.transform.serde.testClasses.CustomTransform; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + Set> res = new HashSet<>(); + res.add(CustomCondition.class); + res.add(CustomFilter.class); + res.add(CustomTransform.class); + return res; + } + + @Override + protected String getPackageName() { + return "org.datavec.api"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java index 70a7ffa7b..84d9b259f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java @@ -25,6 +25,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import 
java.io.File; import java.nio.charset.StandardCharsets; @@ -34,7 +35,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; -public class CSVLineSequenceRecordReaderTest { +public class CSVLineSequenceRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java index 888d8b523..c293d4544 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java @@ -26,6 +26,8 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.BaseCompatOp; import java.io.File; import java.nio.charset.StandardCharsets; @@ -37,7 +39,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -public class CSVMultiSequenceRecordReaderTest { +public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java index 9b84fddc3..9f297d83b 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java @@ -24,6 +24,7 @@ import 
org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.util.ArrayList; @@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 19/09/2016. */ -public class CSVNLinesSequenceRecordReaderTest { +public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest { @Test public void testCSVNLinesSequenceRecordReader() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java index 534cc986e..471dc07c4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java @@ -31,6 +31,7 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.File; @@ -44,7 +45,7 @@ import java.util.NoSuchElementException; import static org.junit.Assert.*; -public class CSVRecordReaderTest { +public class CSVRecordReaderTest extends BaseND4JTest { @Test public void testNext() throws Exception { CSVRecordReader reader = new CSVRecordReader(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java index fbbd992d1..e0763bbbc 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java +++ 
b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java @@ -26,6 +26,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.File; @@ -39,7 +40,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; -public class CSVSequenceRecordReaderTest { +public class CSVSequenceRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder tempDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java index 8e60acad9..fe0c94c4c 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java @@ -22,6 +22,7 @@ import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordRea import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.util.LinkedList; @@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals; * * @author Justin Long (crockpotveggies) */ -public class CSVVariableSlidingWindowRecordReaderTest { +public class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest { @Test public void testCSVVariableSlidingWindowRecordReader() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java index 
c67e32192..d6f03d815 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.api.loader.FileBatch; import java.io.File; @@ -36,7 +37,7 @@ import java.util.List; import static org.junit.Assert.*; -public class FileBatchRecordReaderTest { +public class FileBatchRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java index 533f5be66..6bf66880f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java @@ -23,6 +23,7 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.net.URI; @@ -36,7 +37,7 @@ import static org.junit.Assert.assertFalse; /** * Created by nyghtowl on 11/14/15. 
*/ -public class FileRecordReaderTest { +public class FileRecordReaderTest extends BaseND4JTest { @Test public void testReset() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java index bfeadef36..2f91579f0 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.databind.ObjectMapper; @@ -39,7 +40,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; -public class JacksonLineRecordReaderTest { +public class JacksonLineRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java index c95de48e7..f1fa8d6b2 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java @@ -30,6 +30,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.databind.ObjectMapper; @@ -48,7 +49,7 @@ 
import static org.junit.Assert.assertFalse; /** * Created by Alex on 11/04/2016. */ -public class JacksonRecordReaderTest { +public class JacksonRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java index 75871a6b7..5e8ca6546 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.IOException; @@ -44,7 +45,7 @@ import static org.junit.Assert.assertEquals; * * @author dave@skymind.io */ -public class LibSvmRecordReaderTest { +public class LibSvmRecordReaderTest extends BaseND4JTest { @Test public void testBasicRecord() throws IOException, InterruptedException { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java index 5027357eb..17a41f4d4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java @@ -29,6 +29,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,7 +49,7 @@ import static org.junit.Assert.assertEquals; /** * Created by agibsonccc 
on 11/17/14. */ -public class LineReaderTest { +public class LineReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java index 778d14424..539b0f351 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java @@ -32,6 +32,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.File; @@ -45,7 +46,7 @@ import static org.junit.Assert.assertFalse; /** * Created by Alex on 12/04/2016. */ -public class RegexRecordReaderTest { +public class RegexRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java index 92f8c57e4..25d2959ce 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.IOException; @@ -42,7 +43,7 @@ import static org.junit.Assert.assertEquals; * * @author dave@skymind.io */ -public class SVMLightRecordReaderTest { +public 
class SVMLightRecordReaderTest extends BaseND4JTest { @Test public void testBasicRecord() throws IOException, InterruptedException { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java index a06d56400..fa68c4a1f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java @@ -23,6 +23,7 @@ import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordRe import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -33,7 +34,7 @@ import static org.junit.Assert.*; /** * Created by Alex on 21/05/2016. */ -public class TestCollectionRecordReaders { +public class TestCollectionRecordReaders extends BaseND4JTest { @Test public void testCollectionSequenceRecordReader() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java index 172d884a3..266ad2edc 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java @@ -20,11 +20,12 @@ import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import static org.junit.Assert.assertEquals; -public class 
TestConcatenatingRecordReader { +public class TestConcatenatingRecordReader extends BaseND4JTest { @Test public void test() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java index c249737a3..91fc22886 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java @@ -34,6 +34,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.databind.ObjectMapper; @@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals; * Note however that not all are used/usable with spark (such as Collection[Sequence]RecordReader * and the rest are generally used without being initialized on a particular dataset */ -public class TestSerialization { +public class TestSerialization extends BaseND4JTest { @Test public void testRR() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java index 5daad01b3..ff3ceb9be 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; import 
org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.util.ArrayList; @@ -39,7 +40,7 @@ import static org.junit.Assert.assertTrue; /** * Created by agibsonccc on 3/21/17. */ -public class TransformProcessRecordReaderTests { +public class TransformProcessRecordReaderTests extends BaseND4JTest { @Test public void simpleTransformTest() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java index 5a165b0ac..c3a8f4181 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.io.File; import java.util.ArrayList; @@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals; /** * @author raver119@gmail.com */ -public class CSVRecordWriterTest { +public class CSVRecordWriterTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java index 0c7d70b09..91996056d 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; @@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals; * * @author dave@skymind.io */ -public class LibSvmRecordWriterTest { +public class LibSvmRecordWriterTest extends BaseND4JTest { @Test public void testBasic() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java index 7b9b8c203..f057c7d45 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java @@ -25,6 +25,7 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.writable.*; import org.datavec.api.writable.NDArrayWritable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; @@ -47,7 +48,7 @@ import static org.junit.Assert.assertEquals; * * @author dave@skymind.io */ -public class SVMLightRecordWriterTest { +public class SVMLightRecordWriterTest extends BaseND4JTest { @Test public void testBasic() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java index e8ce37bd3..59e1feee8 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java @@ -16,6 +16,7 @@ package org.datavec.api.split; +import org.nd4j.BaseND4JTest; import org.nd4j.shade.guava.io.Files; import org.datavec.api.io.filters.BalancedPathFilter; import 
org.datavec.api.io.filters.RandomPathFilter; @@ -36,7 +37,7 @@ import static org.junit.Assert.assertEquals; * * @author saudet */ -public class InputSplitTests { +public class InputSplitTests extends BaseND4JTest { @Test public void testSample() throws URISyntaxException { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java index 797f546dd..f8be04d47 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java @@ -17,13 +17,14 @@ package org.datavec.api.split; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.net.URI; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class NumberedFileInputSplitTests { +public class NumberedFileInputSplitTests extends BaseND4JTest { @Test public void testNumberedFileInputSplitBasic() { String baseString = "/path/to/files/prefix%d.suffix"; diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java index c618c625d..94119015c 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.function.Function; import java.io.File; @@ -40,7 +41,7 @@ import java.util.Random; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; -public class TestStreamInputSplit { +public class TestStreamInputSplit extends BaseND4JTest { @Rule public 
TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java index 457f07097..ea6b9fea4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java @@ -17,6 +17,7 @@ package org.datavec.api.split; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.net.URI; import java.net.URISyntaxException; @@ -28,7 +29,7 @@ import static org.junit.Assert.assertArrayEquals; /** * @author Ede Meijer */ -public class TransformSplitTest { +public class TransformSplitTest extends BaseND4JTest { @Test public void testTransform() throws URISyntaxException { Collection inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv")); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java index c9fb57eb9..f27f7527f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java @@ -16,6 +16,7 @@ package org.datavec.api.split.parittion; +import org.nd4j.BaseND4JTest; import org.nd4j.shade.guava.io.Files; import org.datavec.api.conf.Configuration; import org.datavec.api.split.FileSplit; @@ -31,7 +32,7 @@ import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -public class PartitionerTests { +public class PartitionerTests extends BaseND4JTest { @Test public void testRecordsPerFilePartition() { Partitioner partitioner = new NumberOfRecordsPartitioner(); diff --git 
a/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java index eeb4be27a..efb9f2b6e 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java @@ -26,12 +26,13 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; import static org.junit.Assert.assertEquals; -public class TestTransformProcess { +public class TestTransformProcess extends BaseND4JTest { @Test public void testExecution(){ diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java index da4e53398..0c69959d6 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java @@ -24,6 +24,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.TestTransforms; import org.datavec.api.writable.*; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; @@ -33,7 +34,7 @@ import static org.junit.Assert.assertTrue; /** * Created by Alex on 24/03/2016. 
*/ -public class TestConditions { +public class TestConditions extends BaseND4JTest { @Test public void testIntegerCondition() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java index 4d96b5b6e..314ee72ff 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -37,7 +38,7 @@ import static org.junit.Assert.assertTrue; /** * Created by Alex on 21/03/2016. */ -public class TestFilters { +public class TestFilters extends BaseND4JTest { @Test diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java index e6ae74185..1d113c6ff 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java @@ -23,6 +23,7 @@ import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -33,7 +34,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 18/04/2016. 
*/ -public class TestJoin { +public class TestJoin extends BaseND4JTest { @Test public void testJoin() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java index 059cb618c..57ec54e8a 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java @@ -18,6 +18,7 @@ package org.datavec.api.transform.ops; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.io.Serializable; import java.util.*; @@ -27,7 +28,7 @@ import static org.junit.Assert.assertTrue; /** * Created by huitseeker on 5/14/17. */ -public class AggregableMultiOpTest { +public class AggregableMultiOpTest extends BaseND4JTest { private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java index 487926c7a..c722dada4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java @@ -19,6 +19,7 @@ package org.datavec.api.transform.ops; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -30,7 +31,7 @@ import static org.junit.Assert.assertTrue; /** * Created by huitseeker on 5/14/17. 
*/ -public class AggregatorImplsTest { +public class AggregatorImplsTest extends BaseND4JTest { private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java index 076e4412d..a636e7239 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java @@ -18,6 +18,7 @@ package org.datavec.api.transform.ops; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -29,7 +30,7 @@ import static org.junit.Assert.assertTrue; /** * Created by huitseeker on 5/14/17. */ -public class DispatchOpTest { +public class DispatchOpTest extends BaseND4JTest { private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java index 1b0a20430..9aef39aa4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java @@ -29,6 +29,7 @@ import org.datavec.api.transform.ops.IAggregableReduceOp; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.*; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; @@ -38,7 +39,7 @@ import static org.junit.Assert.fail; /** * Created by Alex on 
21/03/2016. */ -public class TestMultiOpReduce { +public class TestMultiOpReduce extends BaseND4JTest { @Test public void testMultiOpReducerDouble() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java index f9debfe2c..dc6443630 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java @@ -21,13 +21,14 @@ import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.Arrays; import java.util.List; import static org.junit.Assert.assertEquals; -public class TestReductions { +public class TestReductions extends BaseND4JTest { @Test public void testGeographicMidPointReduction(){ diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java index dff90f8b9..8e33b742c 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java @@ -19,13 +19,14 @@ package org.datavec.api.transform.schema; import org.datavec.api.transform.metadata.ColumnMetaData; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import static org.junit.Assert.assertEquals; /** * Created by Alex on 18/07/2016. 
*/ -public class TestJsonYaml { +public class TestJsonYaml extends BaseND4JTest { @Test public void testToFromJsonYaml() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java index 870c10680..6cbcafff4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java @@ -18,13 +18,14 @@ package org.datavec.api.transform.schema; import org.datavec.api.transform.ColumnType; import org.junit.Test; +import org.nd4j.BaseND4JTest; import static org.junit.Assert.assertEquals; /** * Created by Alex on 04/09/2016. */ -public class TestSchemaMethods { +public class TestSchemaMethods extends BaseND4JTest { @Test public void testNumberedColumnAdding() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java index 0f48eeff3..56c8d3f1e 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java @@ -30,6 +30,7 @@ import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -41,7 +42,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 16/04/2016. 
*/ -public class TestReduceSequenceByWindowFunction { +public class TestReduceSequenceByWindowFunction extends BaseND4JTest { @Test public void testReduceSequenceByWindowFunction() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java index 6695599a5..98dd49587 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -35,7 +36,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 19/04/2016. */ -public class TestSequenceSplit { +public class TestSequenceSplit extends BaseND4JTest { @Test public void testSequenceSplitTimeSeparation() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java index 99fe9227d..cc12adc53 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java @@ -26,6 +26,7 @@ import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -37,7 +38,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 16/04/2016. 
*/ -public class TestWindowFunctions { +public class TestWindowFunctions extends BaseND4JTest { @Test public void testTimeWindowFunction() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java index 03731f6d6..1da9f48e5 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java @@ -23,13 +23,14 @@ import org.datavec.api.transform.serde.testClasses.CustomCondition; import org.datavec.api.transform.serde.testClasses.CustomFilter; import org.datavec.api.transform.serde.testClasses.CustomTransform; import org.junit.Test; +import org.nd4j.BaseND4JTest; import static org.junit.Assert.assertEquals; /** * Created by Alex on 11/01/2017. */ -public class TestCustomTransformJsonYaml { +public class TestCustomTransformJsonYaml extends BaseND4JTest { @Test public void testCustomTransform() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java index d09995009..dd6e0941a 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java @@ -61,6 +61,7 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; import java.util.concurrent.TimeUnit; @@ -70,7 +71,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 20/07/2016. 
*/ -public class TestYamlJsonSerde { +public class TestYamlJsonSerde extends BaseND4JTest { public static YamlSerializer y = new YamlSerializer(); public static JsonSerializer j = new JsonSerializer(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java index e17650a87..ac69e3397 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java @@ -21,6 +21,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; @@ -29,7 +30,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 21/03/2016. */ -public class TestReduce { +public class TestReduce extends BaseND4JTest { @Test public void testReducerDouble() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java index 38ec1fda9..daa5c15c8 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java @@ -47,6 +47,7 @@ import org.datavec.api.writable.comparator.LongWritableComparator; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.File; @@ -58,7 +59,7 @@ import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; -public class RegressionTestJson { +public class RegressionTestJson extends BaseND4JTest { @Test public void 
regressionTestJson100a() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java index 9f647d365..00c4b745f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java @@ -47,6 +47,7 @@ import org.datavec.api.writable.comparator.LongWritableComparator; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; import java.util.concurrent.TimeUnit; @@ -56,7 +57,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 18/07/2016. */ -public class TestJsonYaml { +public class TestJsonYaml extends BaseND4JTest { @Test public void testToFromJsonYaml() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java index 600ee0b25..1d440913b 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java @@ -56,6 +56,7 @@ import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; import org.junit.Assert; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -72,7 +73,7 @@ import static org.junit.Assert.*; /** * Created by Alex on 21/03/2016. */ -public class TestTransforms { +public class TestTransforms extends BaseND4JTest { public static Schema getSchema(ColumnType type, String... 
colNames) { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java index 78d929e65..c6dad8359 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java @@ -26,6 +26,7 @@ import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -39,7 +40,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 02/06/2017. */ -public class TestNDArrayWritableTransforms { +public class TestNDArrayWritableTransforms extends BaseND4JTest { @Test public void testNDArrayWritableBasic() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java index 7eb3efdef..394457443 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java @@ -27,6 +27,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.serde.JsonSerializer; import org.datavec.api.transform.serde.YamlSerializer; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.Arrays; import java.util.List; @@ -36,7 +37,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 20/07/2016. 
*/ -public class TestYamlJsonSerde { +public class TestYamlJsonSerde extends BaseND4JTest { public static YamlSerializer y = new YamlSerializer(); public static JsonSerializer j = new JsonSerializer(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java index 8ec5233c7..4c2c718ae 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java @@ -20,6 +20,7 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -30,7 +31,7 @@ import static org.junit.Assert.assertEquals; /** * Created by agibsonccc on 10/22/16. */ -public class ParseDoubleTransformTest { +public class ParseDoubleTransformTest extends BaseND4JTest { @Test public void testDoubleTransform() { List record = new ArrayList<>(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java index 0dffb6dab..64f6a4422 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java @@ -35,6 +35,7 @@ import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import java.io.File; import java.util.ArrayList; @@ -46,7 +47,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 25/03/2016. 
*/ -public class TestUI { +public class TestUI extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java index 9b95bbfb4..b68ae43ee 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java @@ -18,6 +18,7 @@ package org.datavec.api.util; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.io.BufferedReader; import java.io.File; @@ -33,7 +34,7 @@ import static org.hamcrest.core.IsEqual.equalTo; /** * @author raver119@gmail.com */ -public class ClassPathResourceTest { +public class ClassPathResourceTest extends BaseND4JTest { private boolean isWindows = false; //File sizes are reported slightly different on Linux vs. Windows diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java index 1545938f6..d47ec60d7 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java @@ -20,6 +20,7 @@ import org.datavec.api.timeseries.util.TimeSeriesWritableUtils; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.ArrayList; @@ -27,7 +28,7 @@ import java.util.List; import static org.junit.Assert.assertArrayEquals; -public class TimeSeriesUtilsTest { +public class TimeSeriesUtilsTest extends BaseND4JTest { @Test public void testTimeSeriesCreation() { diff --git 
a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java index 6dfacdd93..dbc62ed93 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java @@ -16,6 +16,7 @@ package org.datavec.api.writable; +import org.nd4j.BaseND4JTest; import org.nd4j.shade.guava.collect.Lists; import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; @@ -31,7 +32,7 @@ import java.util.TimeZone; import static org.junit.Assert.assertEquals; -public class RecordConverterTest { +public class RecordConverterTest extends BaseND4JTest { @Test public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() { INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT); @@ -86,7 +87,7 @@ public class RecordConverterTest { new IntWritable(1)); INDArray exp = Nd4j.create(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 1}, new long[]{1, 10}, DataType.FLOAT); - INDArray act = RecordConverter.toArray(l); + INDArray act = RecordConverter.toArray(DataType.FLOAT, l); assertEquals(exp, act); } @@ -101,7 +102,7 @@ public class RecordConverterTest { {1,2,3,4,5}, {6,7,8,9,10}}).castTo(DataType.FLOAT); - INDArray act = RecordConverter.toMatrix(Arrays.asList(l1,l2)); + INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1,l2)); assertEquals(exp, act); } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java index 81c2f2d73..9242927e1 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java +++ 
b/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java @@ -18,6 +18,7 @@ package org.datavec.api.writable; import org.datavec.api.transform.metadata.NDArrayMetaData; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -28,7 +29,7 @@ import static org.junit.Assert.*; /** * Created by Alex on 02/06/2017. */ -public class TestNDArrayWritableAndSerialization { +public class TestNDArrayWritableAndSerialization extends BaseND4JTest { @Test public void testIsValid() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java index 93d7ed31b..bd636e62b 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java @@ -18,6 +18,7 @@ package org.datavec.api.writable; import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -31,9 +32,7 @@ import java.util.List; import static org.junit.Assert.*; -public class WritableTest { - - +public class WritableTest extends BaseND4JTest { @Test public void testWritableEqualityReflexive() { diff --git a/datavec/datavec-arrow/pom.xml b/datavec/datavec-arrow/pom.xml index 04420a5e9..60409bc53 100644 --- a/datavec/datavec-arrow/pom.xml +++ b/datavec/datavec-arrow/pom.xml @@ -49,6 +49,12 @@ arrow-format ${arrow.version} + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java index 1d8fddc0e..edd036f0a 100644 --- 
a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java @@ -40,6 +40,7 @@ import org.datavec.arrow.recordreader.ArrowWritableRecordBatch; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; @@ -56,7 +57,7 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -public class ArrowConverterTest { +public class ArrowConverterTest extends BaseND4JTest { private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); diff --git a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..f2cf7ce09 --- /dev/null +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.arrow; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.BaseND4JTest; +import org.nd4j.AbstractAssertTestsClass; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.arrow"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java index 390bfdcd9..59ba5a546 100644 --- a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java @@ -31,6 +31,7 @@ import org.datavec.api.writable.Writable; import org.datavec.arrow.recordreader.ArrowRecordReader; import org.datavec.arrow.recordreader.ArrowRecordWriter; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.primitives.Triple; import java.io.File; @@ -41,7 +42,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; -public class RecordMapperTest { +public class RecordMapperTest extends BaseND4JTest { @Test public void testMultiWrite() throws Exception { diff --git 
a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java index 6951561cd..e49a9fcc4 100644 --- a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable; import org.datavec.arrow.ArrowConverter; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -35,7 +36,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -public class ArrowWritableRecordTimeSeriesBatchTests { +public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest { private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); diff --git a/datavec/datavec-data/datavec-data-audio/pom.xml b/datavec/datavec-data/datavec-data-audio/pom.xml index 1f99eab7c..3b9674cd9 100644 --- a/datavec/datavec-data/datavec-data-audio/pom.xml +++ b/datavec/datavec-data/datavec-data-audio/pom.xml @@ -57,6 +57,13 @@ with-dependencies + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + + - + datavec-parent org.datavec @@ -31,6 +32,12 @@ datavec-api ${project.version} + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + com.maxmind.geoip2 geoip2 diff --git a/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..7d4a6836c --- /dev/null +++ b/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/AssertTestsExtendBaseClass.java @@ 
-0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.api.transform; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.api.transform"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-data/datavec-hadoop/pom.xml b/datavec/datavec-data/datavec-hadoop/pom.xml index 7b74ead38..a6c72b968 100644 --- a/datavec/datavec-data/datavec-hadoop/pom.xml +++ b/datavec/datavec-data/datavec-hadoop/pom.xml @@ -60,6 +60,13 @@ + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/AssertTestsExtendBaseClass.java b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..2aaf25041 --- /dev/null +++ b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/AssertTestsExtendBaseClass.java @@ -0,0 +1,48 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.hadoop; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.hadoop"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-excel/pom.xml b/datavec/datavec-excel/pom.xml index 00fc890d8..49dc26db8 100644 --- a/datavec/datavec-excel/pom.xml +++ b/datavec/datavec-excel/pom.xml @@ -51,6 +51,13 @@ poi-ooxml ${poi.version} + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/AssertTestsExtendBaseClass.java b/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..1b61f7f6c --- /dev/null +++ b/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.poi.excel; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.poi.excel"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-jdbc/pom.xml b/datavec/datavec-jdbc/pom.xml index bfafd25d0..6ef9b0441 100644 --- a/datavec/datavec-jdbc/pom.xml +++ b/datavec/datavec-jdbc/pom.xml @@ -58,6 +58,13 @@ ${derby.version} test + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/AssertTestsExtendBaseClass.java b/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..1db810b7b --- /dev/null +++ 
b/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.api.records.reader; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.api.records.reader"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-local/pom.xml b/datavec/datavec-local/pom.xml index 5c2c6f4ac..3adc0e011 100644 --- a/datavec/datavec-local/pom.xml +++ b/datavec/datavec-local/pom.xml @@ -81,6 +81,13 @@ test + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + + diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/AssertTestsExtendBaseClass.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..991b8466d --- /dev/null +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.local.transforms; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.local.transforms"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java index ba7048547..1a46789ad 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java @@ -28,6 +28,7 @@ import org.datavec.local.transforms.AnalyzeLocal; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.io.ClassPathResource; @@ -63,7 +64,7 @@ public class TestAnalyzeLocal { list.add(rr.next()); } - INDArray arr = RecordConverter.toMatrix(list); + INDArray arr = RecordConverter.toMatrix(DataType.DOUBLE, list); INDArray mean = arr.mean(0); INDArray std = 
arr.std(0); diff --git a/datavec/datavec-python/pom.xml b/datavec/datavec-python/pom.xml index 55cf6c5da..526b8238a 100644 --- a/datavec/datavec-python/pom.xml +++ b/datavec/datavec-python/pom.xml @@ -64,6 +64,13 @@ nd4j-native-api ${project.version} + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/AssertTestsExtendBaseClass.java b/datavec/datavec-python/src/test/java/org/datavec/python/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..83aa2fe5a --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.python; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.python"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml index 10cca8e5a..3b564b1b3 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml @@ -30,11 +30,6 @@ - - org.nd4j - nd4j-jackson - ${nd4j.version} - org.datavec datavec-spark-inference-server_2.11 @@ -51,6 +46,13 @@ datavec-spark-inference-model ${project.parent.version} + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..3bff86e98 --- /dev/null +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.transform.client; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.transform.client"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml index 470340dc1..bac20d42e 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml @@ -45,6 +45,13 @@ datavec-local ${project.version} + + + org.nd4j + 
nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java index 67d7fe44a..f76e9885f 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java @@ -33,6 +33,7 @@ import org.datavec.spark.transform.model.Base64NDArrayBody; import org.datavec.spark.transform.model.BatchCSVRecord; import org.datavec.spark.transform.model.SequenceBatchCSVRecord; import org.datavec.spark.transform.model.SingleCSVRecord; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.serde.base64.Nd4jBase64; @@ -91,7 +92,7 @@ public class CSVSparkTransform { transformProcess.getInitialSchema(),record.getValues()), transformProcess.getInitialSchema()); List finalRecord = execute(Arrays.asList(record2),transformProcess).get(0); - INDArray convert = RecordConverter.toArray(finalRecord); + INDArray convert = RecordConverter.toArray(DataType.DOUBLE, finalRecord); return new Base64NDArrayBody(Nd4jBase64.base64String(convert)); } diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..4c6f529b9 --- /dev/null +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* 
****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.spark.transform; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.spark.transform"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml index 47951b1aa..0c05f327b 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml @@ -164,6 +164,13 @@ spark-core_2.11 ${spark.version} + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..4c6f529b9 --- /dev/null +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.spark.transform; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.spark.transform"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark/pom.xml b/datavec/datavec-spark/pom.xml index 50194c91e..345b774c3 100644 --- a/datavec/datavec-spark/pom.xml +++ b/datavec/datavec-spark/pom.xml @@ -130,6 +130,12 @@ test + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/AssertTestsExtendBaseClass.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..9539251e6 --- /dev/null +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* 
****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.spark; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.spark"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java index 5352ec10d..fcc20d661 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java @@ -104,7 +104,7 @@ public class NormalizationTests extends BaseSparkTest { } - INDArray arr = RecordConverter.toMatrix(data); + INDArray arr = RecordConverter.toMatrix(DataType.DOUBLE, data); Schema schema = builder.build(); JavaRDD> rdd = sc.parallelize(data); @@ -127,9 +127,9 @@ public class NormalizationTests extends BaseSparkTest { zeroToOne.transform(new DataSet(zeroToOnes, zeroToOnes)); INDArray zeroMeanUnitVarianceDataFrame = - RecordConverter.toMatrix(Normalization.zeromeanUnitVariance(schema, rdd).collect()); + RecordConverter.toMatrix(DataType.DOUBLE, Normalization.zeromeanUnitVariance(schema, rdd).collect()); INDArray zeroMeanUnitVarianceDataFrameZeroToOne = - RecordConverter.toMatrix(Normalization.normalize(schema, rdd).collect()); + RecordConverter.toMatrix(DataType.DOUBLE, Normalization.normalize(schema, rdd).collect()); assertEquals(standardScalered, zeroMeanUnitVarianceDataFrame); assertTrue(zeroToOnes.equalsWithEps(zeroMeanUnitVarianceDataFrameZeroToOne, 1e-1)); diff --git 
a/deeplearning4j/deeplearning4j-common-tests/pom.xml b/deeplearning4j/deeplearning4j-common-tests/pom.xml index 23e030df3..5a4ba921d 100644 --- a/deeplearning4j/deeplearning4j-common-tests/pom.xml +++ b/deeplearning4j/deeplearning4j-common-tests/pom.xml @@ -37,6 +37,11 @@ nd4j-api ${project.version} + + org.nd4j + nd4j-common-tests + ${nd4j.version} + ch.qos.logback logback-classic diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml index 496bb6b1b..90c88d4c3 100644 --- a/deeplearning4j/deeplearning4j-core/pom.xml +++ b/deeplearning4j/deeplearning4j-core/pom.xml @@ -164,20 +164,6 @@ oshi-core ${oshi.version} - - - - org.reflections - reflections - ${reflections.version} - test - - - com.google.code.findbugs - * - - - diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java index 5f0567094..34d4db39e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java @@ -17,15 +17,8 @@ package org.deeplearning4j; import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.reflections.Reflections; -import org.reflections.scanners.MethodAnnotationsScanner; -import org.reflections.util.ClasspathHelper; -import org.reflections.util.ConfigurationBuilder; - -import java.lang.reflect.Method; import java.util.*; - -import static org.junit.Assert.assertEquals; +import org.nd4j.AbstractAssertTestsClass; /** * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) @@ -33,45 +26,24 @@ import static org.junit.Assert.assertEquals; * Other than a small set of exceptions, all tests must extend this * * @author Alex Black + * @author Alexander Stoyakin */ @Slf4j -public class 
AssertTestsExtendBaseClass extends BaseDL4JTest { +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { @Override - public long getTimeoutMilliseconds() { - return 240000L; + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; } - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - private static final Set> exclusions = new HashSet<>(); + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } - @Test - public void checkTestClasses(){ - - Reflections reflections = new Reflections(new ConfigurationBuilder() - .setUrls(ClasspathHelper.forPackage("org.deeplearning4j")) - .setScanners(new MethodAnnotationsScanner())); - Set methods = reflections.getMethodsAnnotatedWith(Test.class); - Set> s = new HashSet<>(); - for(Method m : methods){ - s.add(m.getDeclaringClass()); - } - - List> l = new ArrayList<>(s); - Collections.sort(l, new Comparator>() { - @Override - public int compare(Class aClass, Class t1) { - return aClass.getName().compareTo(t1.getName()); - } - }); - - int count = 0; - for(Class c : l){ - if(!BaseDL4JTest.class.isAssignableFrom(c) && !exclusions.contains(c)){ - log.error("Test {} does not extend BaseDL4JTest (directly or indirectly). 
All tests must extend this class for proper memory tracking and timeouts", c); - count++; - } - } - assertEquals("Number of tests not extending BaseDL4JTest", 0, count); + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java index d90ce628b..d54693f73 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.layers.recurrent.LSTM; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.util.ModelSerializer; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; @@ -124,12 +125,20 @@ public class TestUtils { return randomOneHot(examples, nOut, new Random(12345)); } + public static INDArray randomOneHot(DataType dataType, long examples, long nOut){ + return randomOneHot(dataType, examples, nOut, new Random(12345)); + } + public static INDArray randomOneHot(long examples, long nOut, long rngSeed){ return randomOneHot(examples, nOut, new Random(rngSeed)); } - public static INDArray randomOneHot(long examples, long nOut, Random rng){ - INDArray arr = Nd4j.create(examples, nOut); + public static INDArray randomOneHot(long examples, long nOut, Random rng) { + return randomOneHot(Nd4j.defaultFloatingPointType(), examples,nOut, rng); + } + + public static INDArray randomOneHot(DataType dataType, long examples, long nOut, Random rng){ + INDArray arr = Nd4j.create(dataType, examples, nOut); for( int i=0; i> classesToTest = new ArrayList<>(); 
classesToTest.add(org.deeplearning4j.nn.layers.normalization.BatchNormalization.class); @@ -185,10 +191,11 @@ public class ValidateCuDNN extends BaseDL4JTest { //Test ONLY LRN - no other CuDNN functionality (i.e., DL4J impls for everything else) Nd4j.getRandom().setSeed(12345); + int minibatch = 8; int numClasses = 10; //imageHeight,imageWidth,channels - int imageHeight = 240; - int imageWidth = 240; + int imageHeight = 48; + int imageWidth = 48; int channels = 3; IActivation activation = new ActivationIdentity(); MultiLayerConfiguration multiLayerConfiguration = new NeuralNetConfiguration.Builder() @@ -229,8 +236,8 @@ public class ValidateCuDNN extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(multiLayerConfiguration); net.init(); - int[] fShape = new int[]{32, channels, imageHeight, imageWidth}; - int[] lShape = new int[]{32, numClasses}; + int[] fShape = new int[]{minibatch, channels, imageHeight, imageWidth}; + int[] lShape = new int[]{minibatch, numClasses}; List> classesToTest = new ArrayList<>(); classesToTest.add(org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization.class); diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java index c89532ff5..67a2958b7 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java @@ -70,6 +70,11 @@ public class TestConvolution extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Override + public long getTimeoutMilliseconds() { + return 240000L; + } + @Test public void testSameModeActivationSizes() { int inH = 3; @@ -117,6 +122,8 @@ public class TestConvolution extends BaseDL4JTest { for (ConvolutionMode c : cm) { for 
(ConvolutionLayer.AlgoMode a : new ConvolutionLayer.AlgoMode[]{ConvolutionLayer.AlgoMode.NO_WORKSPACE, ConvolutionLayer.AlgoMode.PREFER_FASTEST}) { for (boolean conv : new boolean[]{true, false}) { + String msg = c + " - " + a + " - " + (conv ? "conv" : "subsampling"); + System.out.println(msg); org.deeplearning4j.nn.conf.layers.Layer l; if (conv) { @@ -125,7 +132,9 @@ public class TestConvolution extends BaseDL4JTest { l = new SubsamplingLayer.Builder().kernelSize(4, 4).stride(2, 2).build(); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) + .seed(12345) .l2(0.0005).updater(new Sgd(0.01)).weightInit(WeightInit.XAVIER).convolutionMode(c).cudnnAlgoMode(a).list() .layer(0, l) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) @@ -159,32 +168,32 @@ public class TestConvolution extends BaseDL4JTest { throw new RuntimeException(); - INDArray in = Nd4j.rand(new int[]{1, 1, 20, 20}); //(20-4+0)/2 +1 = 9 + INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, 1, 20, 20}); //(20-4+0)/2 +1 = 9 INDArray outCudnn = layerCudnn.activate(in, false, LayerWorkspaceMgr.noWorkspaces()); INDArray outStd = layerStandard.activate(in, false, LayerWorkspaceMgr.noWorkspaces()); - assertEquals(outStd, outCudnn); + assertEquals(msg, outStd, outCudnn); //Check backprop: - INDArray epsilon = Nd4j.rand(outStd.shape()); - Pair pCudnn = layerCudnn.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); - Pair pStd = layerStandard.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); + INDArray epsilon = Nd4j.rand(DataType.DOUBLE, outStd.shape()); + Pair pCudnn = layerCudnn.backpropGradient(epsilon.dup(), LayerWorkspaceMgr.noWorkspaces()); + Pair pStd = layerStandard.backpropGradient(epsilon.dup(), LayerWorkspaceMgr.noWorkspaces()); - System.out.println(Arrays.toString(pStd.getSecond().data().asFloat())); - 
System.out.println(Arrays.toString(pCudnn.getSecond().data().asFloat())); +// System.out.println(Arrays.toString(pStd.getSecond().data().asFloat())); +// System.out.println(Arrays.toString(pCudnn.getSecond().data().asFloat())); INDArray epsOutStd = pStd.getSecond(); INDArray epsOutCudnn = pCudnn.getSecond(); - assertTrue(epsOutStd.equalsWithEps(epsOutCudnn, 1e-4)); + assertTrue(msg, epsOutStd.equalsWithEps(epsOutCudnn, 1e-4)); if (conv) { INDArray gradStd = pStd.getFirst().gradient(); INDArray gradCudnn = pCudnn.getFirst().gradient(); - assertTrue(gradStd.equalsWithEps(gradCudnn, 1e-4)); + assertTrue(msg, gradStd.equalsWithEps(gradCudnn, 1e-4)); } } } @@ -192,7 +201,7 @@ public class TestConvolution extends BaseDL4JTest { } - @Test @Ignore //AB 2019/05/21 - Ignored to get master passing - issue logged here: https://github.com/deeplearning4j/deeplearning4j/issues/7766 + @Test public void validateXceptionImport() throws Exception { File dir = testDir.newFolder(); File fSource = Resources.asFile("modelimport/keras/examples/xception/xception_tf_keras_2.h5"); diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index 1b8c42e14..eb06a70ae 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -61,6 +61,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } + @Test public void testGradientCNNMLN() { //Parameterized test, testing combinations of: diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java 
b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java index 9e43f042b..a2ab8236f 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java @@ -77,6 +77,10 @@ public class CuDNNGradientChecks extends BaseDL4JTest { DataTypeUtil.setDTypeForContext(DataType.DOUBLE); } + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } @Test public void testConvolutional() throws Exception { diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnDropout.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnDropout.java index 3edf564b5..c46bb99d9 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnDropout.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnDropout.java @@ -30,6 +30,11 @@ import static org.junit.Assert.assertTrue; public class ValidateCudnnDropout extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } + @Test public void testCudnnDropoutSimple() { for (int[] shape : new int[][]{{10, 10}, {5, 2, 5, 2}}) { diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnLSTM.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnLSTM.java index 08b57aa65..6bbb934a5 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnLSTM.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnLSTM.java @@ -46,6 +46,11 @@ import static org.junit.Assert.*; */ public class ValidateCudnnLSTM extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + 
return 180000L; + } + @Test public void validateImplSimple() throws Exception { @@ -109,7 +114,7 @@ public class ValidateCudnnLSTM extends BaseDL4JTest { mln1.computeGradientAndScore(); mln2.computeGradientAndScore(); - assertEquals(mln1.score(), mln2.score(), 1e-8); + assertEquals(mln1.score(), mln2.score(), 1e-5); Gradient g1 = mln1.gradient(); Gradient g2 = mln2.gradient(); diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java index 2f47c2c8b..92d5d579e 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java @@ -494,7 +494,7 @@ public class RecordReaderMultiDataSetIterator implements MultiDataSetIterator, S List c = list.get(i); if (details.entireReader) { //Convert entire reader contents, without modification - INDArray converted = RecordConverter.toArray(c); + INDArray converted = RecordConverter.toArray(Nd4j.defaultFloatingPointType(), c); putExample(arr, converted, i); } else if (details.oneHot) { //Convert a single column to a one-hot representation diff --git a/deeplearning4j/deeplearning4j-graph/pom.xml b/deeplearning4j/deeplearning4j-graph/pom.xml index ebc6740d9..645b4eca2 100644 --- a/deeplearning4j/deeplearning4j-graph/pom.xml +++ b/deeplearning4j/deeplearning4j-graph/pom.xml @@ -57,7 +57,6 @@ ${project.version} test - diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java 
b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java new file mode 100644 index 000000000..7341a3a2c --- /dev/null +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.graph; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendedBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.graph"; + } + + @Override + protected Class getBaseClass() {return BaseDL4JTest.class; } +} + diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index 7538d39bc..3e1efa365 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -98,7 +98,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 90000L; + return 180000L; //Most benchmarks should run very quickly; large timeout is to avoid issues with unusually slow download of test resources } @Test(expected = IllegalStateException.class) diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml index e3ca20366..0886e8d5b 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml @@ -41,11 +41,6 @@ deeplearning4j-nearestneighbors-model ${project.version} - - org.nd4j - nd4j-jackson - ${nd4j.version} - 
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..d7c03956f --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/AssertTestsExtendBaseClass.java @@ -0,0 +1,52 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.text.tokenization.tokenizer; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} + + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..c767c3e72 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/AssertTestsExtendBaseClass.java @@ -0,0 +1,53 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +import com.atilika.kuromoji.TestUtils; +import com.atilika.kuromoji.ipadic.RandomizedInputTest; +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + exclusions.add(TestUtils.class); + exclusions.add(RandomizedInputTest.class); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..ccf95a8ea --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..85b0c39a9 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ + +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml index f27fd7a94..668c728ae 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml @@ -14,76 +14,77 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - - 4.0.0 - - org.deeplearning4j - deeplearning4j-nlp-parent - 
1.0.0-SNAPSHOT - + + 4.0.0 + + org.deeplearning4j + deeplearning4j-nlp-parent + 1.0.0-SNAPSHOT + - deeplearning4j-nlp + deeplearning4j-nlp - - - org.nd4j - nd4j-native-api - ${nd4j.version} - + + + org.nd4j + nd4j-native-api + ${nd4j.version} + - - commons-lang - commons-lang - 2.6 - - - org.deeplearning4j - deeplearning4j-core - ${project.version} - + + commons-lang + commons-lang + 2.6 + + + org.deeplearning4j + deeplearning4j-core + ${project.version} + - - org.threadly - threadly - ${threadly.version} - + + org.threadly + threadly + ${threadly.version} + - - junit - junit - test - + + junit + junit + test + - - org.mockito - mockito-core - ${mockito.version} - test - + + org.mockito + mockito-core + ${mockito.version} + test + - - ch.qos.logback - logback-classic - test - - - org.apache.commons - commons-lang3 - ${commonslang.version} - - - com.github.vinhkhuc - jfasttext - 0.4 - + + ch.qos.logback + logback-classic + test + + + org.apache.commons + commons-lang3 + ${commonslang.version} + + + com.github.vinhkhuc + jfasttext + 0.4 + - - org.deeplearning4j - deeplearning4j-common-tests - ${project.version} - test - - + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..6fb3b0316 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} + diff --git a/deeplearning4j/deeplearning4j-nn/pom.xml b/deeplearning4j/deeplearning4j-nn/pom.xml index 77acb2dc7..b817c0dc6 100644 --- a/deeplearning4j/deeplearning4j-nn/pom.xml +++ b/deeplearning4j/deeplearning4j-nn/pom.xml @@ -89,12 +89,6 @@ ${nd4j.version} - - org.nd4j - nd4j-jackson - ${nd4j.version} - - com.github.oshi diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java index 60ecbf057..8fedee7b0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java @@ -31,6 +31,7 @@ import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -179,10 +180,10 @@ public class LocallyConnected1D extends SameDiffLayer { //NCW format. if(cm == ConvolutionMode.Same) { layerInput = sameDiff.nn().pad(layerInput, - sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, paddingR}})), 0); + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, paddingR}})), PadMode.CONSTANT, 0); } else { layerInput = sameDiff.nn().pad(layerInput, - sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, padding}})), 0); + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, padding}})), PadMode.CONSTANT, 0); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java index 5044017a0..6fad9ec69 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java @@ -32,6 +32,7 @@ import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import 
org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -185,10 +186,10 @@ public class LocallyConnected2D extends SameDiffLayer { //NCHW format if(cm == ConvolutionMode.Same){ layerInput = sameDiff.nn().pad(layerInput, - sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}})), 0.0); + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}})), PadMode.CONSTANT, 0.0); } else { layerInput = sameDiff.nn().pad(layerInput, - sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}})), 0.0); + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}})), PadMode.CONSTANT, 0.0); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/BoundingBoxesDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/BoundingBoxesDeserializer.java new file mode 100644 index 000000000..dec2c6c33 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/BoundingBoxesDeserializer.java @@ -0,0 +1,65 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.nn.conf.layers.objdetect; + +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer; +import org.nd4j.shade.jackson.core.JsonParser; +import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.databind.DeserializationContext; +import org.nd4j.shade.jackson.databind.JsonDeserializer; +import org.nd4j.shade.jackson.databind.JsonNode; + +import java.io.IOException; + +/** + * Custom deserializer to handle change in format between beta6 (and earlier) and later versions + * + * @author Alex Black + */ +public class BoundingBoxesDeserializer extends JsonDeserializer { + @Override + public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException, JsonProcessingException { + JsonNode node = jp.getCodec().readTree(jp); + if(node.has("dataBuffer")){ + //Must be legacy format serialization + JsonNode arr = node.get("dataBuffer"); + int rank = node.get("rankField").asInt(); + int numElements = node.get("numElements").asInt(); + int offset = node.get("offsetField").asInt(); + JsonNode shape = node.get("shapeField"); + JsonNode stride = node.get("strideField"); + int[] shapeArr = new int[rank]; + int[] strideArr = new int[rank]; + DataBuffer buff = Nd4j.createBuffer(numElements); + for (int i = 0; i < numElements; i++) { + buff.put(i, arr.get(i).asDouble()); + } + + String ordering = 
node.get("orderingField").asText(); + for (int i = 0; i < rank; i++) { + shapeArr[i] = shape.get(i).asInt(); + strideArr[i] = stride.get(i).asInt(); + } + + return Nd4j.create(buff, shapeArr, strideArr, offset, ordering.charAt(0)); + } + //Standard/new format + return new NDArrayTextDeSerializer().deserialize(node); + } +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java index 32f5627d5..24bda07f6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java @@ -34,10 +34,9 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.impl.LossL2; +import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer; import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; -import org.nd4j.shade.serde.jackson.VectorDeSerializer; -import org.nd4j.shade.serde.jackson.VectorSerializer; import java.util.Arrays; import java.util.Collection; @@ -77,8 +76,8 @@ public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer { private double lambdaNoObj; private ILossFunction lossPositionScale; private ILossFunction lossClassPredictions; - @JsonSerialize(using = VectorSerializer.class) - @JsonDeserialize(using = VectorDeSerializer.class) + @JsonSerialize(using = NDArrayTextSerializer.class) + @JsonDeserialize(using = BoundingBoxesDeserializer.class) private INDArray boundingBoxes; private Yolo2OutputLayer() { diff --git 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java index b3a7a3c12..b38945e95 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java @@ -116,10 +116,10 @@ public class SubsamplingLayer extends AbstractLayer { } //Define the function for external errors: - fn = sameDiff.f().externalErrors(layerOutput); + fn = SameDiffUtils.externalErrors(sameDiff, null,layerOutput); fn.outputVariable(); this.outputKey = outputVar.name(); diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..1d8c3d578 --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java @@ -0,0 +1,48 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.remote; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.remote"; + } + + @Override + protected Class getBaseClass() { return BaseDL4JTest.class; } +} + diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java index 445788799..15e449e94 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java @@ -43,6 +43,12 @@ import static org.junit.Assert.assertNotEquals; * @author raver119@gmail.com */ public class SparkSequenceVectorsTest extends BaseDL4JTest { + + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + protected static List> sequencesCyclic; private JavaSparkContext sc; diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java index f3b3f974a..a7bdfd45b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java @@ -48,6 +48,12 @@ import static org.junit.Assert.*; * @author raver119@gmail.com */ public class SparkWord2VecTest extends BaseDL4JTest { + + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + private static List sentences; private JavaSparkContext sc; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java index 475572edd..af39a474c 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java @@ -34,6 +34,11 @@ import java.util.Map; public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable { protected transient JavaSparkContext sc; + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + @Before public void before() throws Exception { sc = getContext(); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java 
b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java index 63c84de7d..9e5ad1d67 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java @@ -31,6 +31,7 @@ import org.deeplearning4j.spark.text.functions.CountCumSum; import org.deeplearning4j.spark.text.functions.TextPipeline; import org.deeplearning4j.text.stopwords.StopWords; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Counter; @@ -350,7 +351,7 @@ public class TextPipelineTest extends BaseSparkTest { * * @throws Exception */ - @Test + @Test @Ignore //AB 2020/04/19 https://github.com/eclipse/deeplearning4j/issues/8849 public void testZipFunction1() throws Exception { JavaSparkContext sc = getContext(); JavaRDD corpusRDD = getCorpusRDD(sc); @@ -388,7 +389,7 @@ public class TextPipelineTest extends BaseSparkTest { sc.stop(); } - @Test + @Test @Ignore //AB 2020/04/19 https://github.com/eclipse/deeplearning4j/issues/8849 public void testZipFunction2() throws Exception { JavaSparkContext sc = getContext(); JavaRDD corpusRDD = getCorpusRDD(sc); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java index 50aa564c1..9a28fe351 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java +++ 
b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java @@ -53,6 +53,12 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable protected transient DataSet data; protected transient JavaRDD sparkData; + + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + @Before public void before() { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java index 42fa57a37..f7b4da172 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java @@ -28,6 +28,11 @@ import java.util.Map; */ public class BaseSparkKryoTest extends BaseSparkTest { + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + @Override public JavaSparkContext getContext() { if (sc != null) { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java index 3d1a9755a..be78ec7cd 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java @@ -55,6 +55,11 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable protected transient DataSet data; protected transient JavaRDD sparkData; + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + @Before public void 
before() { diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml index 282867e7e..44b868ae6 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml @@ -44,11 +44,6 @@ deeplearning4j-nlp ${project.version} - - org.nd4j - nd4j-jackson - ${nd4j.version} - junit diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java index 037c44b40..4aae630ca 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java @@ -140,7 +140,7 @@ public class ConvolutionalIterationListener extends BaseTrainingListener { ComputationGraph l = (ComputationGraph) model; Layer[] layers = l.getLayers(); if(layers.length != activations.size()) - throw new RuntimeException(); + throw new RuntimeException("layers.length != activations.size(). 
Got layers.length="+layers.length+", activations.size()="+activations.size()); for( int i=0; i& whatArrsToCheck = std::vector(), const std::vector& IdxRange = {0., 1.}, const LossFunc loss = SUM, const std::vector& outArrsFFIdx = {}); + const std::vector& whatArrsToCheck = std::vector(), const std::vector& IdxRange = {0., 1.}, const LossFunc loss = SUM); }; diff --git a/libnd4j/include/helpers/impl/GradCheck.cpp b/libnd4j/include/helpers/impl/GradCheck.cpp index f3daa798c..12ecab75f 100644 --- a/libnd4j/include/helpers/impl/GradCheck.cpp +++ b/libnd4j/include/helpers/impl/GradCheck.cpp @@ -49,7 +49,7 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector& ////////////////////////////////////////////////////////////////////////// bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, - const std::vector& whatArrsToCheck, const std::vector& idxRange, const LossFunc loss, const std::vector& outArrsFFIdx) { + const std::vector& whatArrsToCheck, const std::vector& idxRange, const LossFunc loss) { const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP @@ -82,23 +82,12 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons int numOutArrs = outArrsFF.size(); double scorePlus = 0.; - if(!outArrsFFIdx.empty()) { - for(const auto& k : outArrsFFIdx) { // loop through independent output arrays - if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); - scorePlus += tmpScalar.e(0); - } - } - else { - for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays - if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); - else - 
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); - scorePlus += tmpScalar.e(0); - } + for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays + if(loss == SUM) + outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); + scorePlus += tmpScalar.e(0); } // subtract epsilon, feed forward @@ -106,23 +95,12 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons outArrsFF = opFF.execute(argsHolderFF); double scoreMinus = 0.; - if(!outArrsFFIdx.empty()) { - for(const auto& k : outArrsFFIdx) { // loop through independent output arrays - if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); - scoreMinus += tmpScalar.e(0); - } - } - else { - for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays - if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); - scoreMinus += tmpScalar.e(0); - } + for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays + if(loss == SUM) + outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); + scoreMinus += tmpScalar.e(0); } // restore initial element value diff --git a/libnd4j/include/helpers/impl/RandomLauncher.cpp b/libnd4j/include/helpers/impl/RandomLauncher.cpp index 8114c2ec4..f7cdd0f3a 100644 --- a/libnd4j/include/helpers/impl/RandomLauncher.cpp +++ b/libnd4j/include/helpers/impl/RandomLauncher.cpp @@ -26,8 +26,6 @@ #include namespace sd { - // FIXME: implement this - void RandomLauncher::applyDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) { if (z == nullptr) z = array; @@ -35,8 +33,12 @@ namespace sd { ExtraArguments arguments({retainProb}); PointersManager pm(context, "applyDropOut"); + 
NDArray::prepareSpecialUse({z}, {array}); + NativeOpExecutioner::execRandom(context, random::DropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({z}, {array}); } void RandomLauncher::applyInvertedDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) { @@ -46,8 +48,12 @@ namespace sd { ExtraArguments arguments({retainProb}); PointersManager pm(context, "applyInvertedDropOut"); + NDArray::prepareSpecialUse({z}, {array}); + NativeOpExecutioner::execRandom(context, random::DropOutInverted, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({z}, {array}); } void RandomLauncher::applyAlphaDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z) { @@ -57,63 +63,95 @@ namespace sd { ExtraArguments arguments({retainProb, alpha, beta, alphaPrime}); PointersManager pm(context, "applyAlphaDropOut"); + NDArray::prepareSpecialUse({z}, {array}); + NativeOpExecutioner::execRandom(context, random::AlphaDropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({z}, {array}); } void RandomLauncher::fillBernoulli(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double prob) { ExtraArguments arguments({prob}); PointersManager pm(context, "fillBernoulli"); + 
NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::BernoulliDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillUniform(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double from, double to) { ExtraArguments arguments({from, to}); PointersManager pm(context, "fillUniform"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::UniformDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillGaussian(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { ExtraArguments arguments({mean, stdev}); PointersManager pm(context, "fillGaussian"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillExponential(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double lambda) { ExtraArguments arguments({lambda}); PointersManager pm(context, "fillExponential"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::ExponentialDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), 
array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillLogNormal(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { ExtraArguments arguments({mean, stdev}); PointersManager pm(context, "fillLogNormal"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillTruncatedNormal(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { ExtraArguments arguments({mean, stdev}); PointersManager pm(context, "fillTruncatedNormal"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::TruncatedNormalDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillBinomial(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, int trials, double prob) { ExtraArguments arguments({(double) trials, prob}); PointersManager pm(context, "fillBinomial"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::BinomialDistributionEx, &rng, array->buffer(), array->shapeInfo(), 
array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } } diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp index c6cd2e8f1..dee9a7c88 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -15,7 +16,7 @@ ******************************************************************************/ // -// created by Yurii Shyrma on 15.02.2018 +// @author Yurii Shyrma (iuriish@yahoo.com) // #include @@ -30,83 +31,157 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gru, 5, 1, false, 0, 0) { - auto x = INPUT_VARIABLE(0); // input [time x bS x iS] - auto h0 = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS x nU] + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] - auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [iS x 3*nU] - auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nU x 3*nU] - auto b = INPUT_VARIABLE(4); // biases, [3*nU] + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] - auto h = 
OUTPUT_VARIABLE(0); // cell outputs [time x bS x nU], that is per each time step + auto h = OUTPUT_VARIABLE(0); // cell outputs [time, bS, nOut], that is per each time step - const int rank = x->rankOf(); // = 3 - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int iS = x->sizeAt(2); - const int nU = h0->sizeAt(1); + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); - const std::vector h0CorrectShape = {bS, nU}; - const std::vector wxCorrectShape = {iS, 3*nU}; - const std::vector whCorrectShape = {nU, 3*nU}; - const std::vector bCorrectShape = {3*nU}; + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3*nOut}; + const std::vector whCorrectShape = {nOut, 3*nOut}; + const std::vector bCorrectShape = {3*nOut}; - REQUIRE_TRUE(h0->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(h0).c_str()); + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), 
ShapeUtils::shapeAsString(b).c_str()); - helpers::gruTimeLoop(block.launchContext(), x, h0, Wx, Wh, b, h); + helpers::gruTimeLoop(block.launchContext(), x, hI, Wx, Wh, b, h); return Status::OK(); } +////////////////////////////////////////////////////////////////////////// +DECLARE_TYPES(gru) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - DECLARE_TYPES(gru) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - +////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(gru) { - const auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] - const auto h0ShapeInfo = inputShape->at(1); // initial cell output [bS x numUnits], that is at time step t=0 - const auto WxShapeInfo = inputShape->at(2); // input-to-hidden weights, [inSize x 3*numUnits] - const auto WhShapeInfo = inputShape->at(3); // hidden-to-hidden weights, [numUnits x 3*numUnits] - const auto bShapeInfo = inputShape->at(4); // biases, [3*numUnits] - const int rank = shape::rank(xShapeInfo); // = 3 - const auto time = xShapeInfo[1]; - const auto bS = xShapeInfo[2]; - const auto inSize = xShapeInfo[3]; - const auto numUnits = h0ShapeInfo[2]; + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] - const std::vector h0CorrectShape = {bS, numUnits}; - const std::vector wxCorrectShape = {inSize, 3*numUnits}; - const std::vector whCorrectShape = {numUnits, 3*numUnits}; - const std::vector bCorrectShape = {3*numUnits}; + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] - REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but 
got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3*nOut}; + const std::vector whCorrectShape = {nOut, 3*nOut}; + const std::vector bCorrectShape = {3*nOut}; - // evaluate output shapeInfo - Nd4jLong *hShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights 
array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - hShapeInfo[0] = rank; - hShapeInfo[1] = time; - hShapeInfo[2] = bS; - hShapeInfo[3] = numUnits; - - ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(h0ShapeInfo)); + auto* hShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(hI->dataType(), hI->ordering(), {time, bS, nOut}); return SHAPELIST(hShapeInfo); } +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(gru_bp, 6, 5, false, 0, 0) { + + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] + + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] + + auto dLdh = INPUT_VARIABLE(5); // gradient vs. ff output, [time, bS, nOut] + + auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. ff input, [time, bS, nIn] + auto dLdhI = OUTPUT_NULLIFIED(1); // gradient vs. initial cell output, [bS, nOut] + auto dLdWx = OUTPUT_NULLIFIED(2); // gradient vs. input-to-hidden weights, [nIn, 3*nOut] + auto dLdWh = OUTPUT_NULLIFIED(3); // gradient vs. hidden-to-hidden weights, [nOut, 3*nOut] + auto dLdb = OUTPUT_NULLIFIED(4); // gradient vs. 
biases [3*nOut] + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); + + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3*nOut}; + const std::vector whCorrectShape = {nOut, 3*nOut}; + const std::vector bCorrectShape = {3*nOut}; + const std::vector hCorrectShape = {time, bS, nOut}; + + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU_BP operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU_BP operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU_BP operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape),0, "GRU_BP operation: wrong shape of gradient vs. 
ff output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); + + helpers::gruTimeLoopBp(block.launchContext(), x, hI, Wx, Wh, b, dLdh, dLdx, dLdhI, dLdWx, dLdWh, dLdb); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +DECLARE_TYPES(gru_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +////////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(gru_bp) { + + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] + + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] + + auto dLdh = INPUT_VARIABLE(5); // gradient vs. ff output, [time, bS, nOut] + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); + + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3*nOut}; + const std::vector whCorrectShape = {nOut, 3*nOut}; + const std::vector bCorrectShape = {3*nOut}; + const std::vector hCorrectShape = {time, bS, nOut}; + + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU_BP operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU_BP operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU_BP operation: wrong shape of hidden-to-hidden 
weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape),0, "GRU_BP operation: wrong shape of gradient vs. ff output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); + + Nd4jLong* dLdxShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), x->getShapeInfo()); + Nd4jLong* dLdhIShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), hI->getShapeInfo()); + Nd4jLong* dLdWxShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), Wx->getShapeInfo()); + Nd4jLong* dLdWhShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), Wh->getShapeInfo()); + Nd4jLong* dLdbShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), b->getShapeInfo()); + + return SHAPELIST(dLdxShapeInfo, dLdhIShapeInfo, dLdWxShapeInfo, dLdWhShapeInfo, dLdbShapeInfo); +} + } } diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp index 204a1ca63..037f09736 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp @@ -161,7 +161,7 @@ CUSTOM_OP_IMPL(gruCell_bp, 10, 6, false, 0, 0) { REQUIRE_TRUE(dLdc->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdc).c_str()); 
REQUIRE_TRUE(dLdh->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); - helpers::gruCellBP(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc); + helpers::gruCellBp(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp index 8637fe990..871291165 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp @@ -727,12 +727,10 @@ CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) { dLdcLBwd = new NDArray((*dLdcL)({1,2, 0,0, 0,0})); } - // FIXME looks like sum (directionMode == 2) is impossible for backprop if(dLdh) { if(directionMode == 2) { // sum - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: mode for bidirectional sum and dLdh being present has no sense for backpropagation !"); - // dLdhFwd = dLdh; - // dLdhBwd = new NDArray(dLdh->ordering(), dLdh->getShapeAsVector(), dLdh->dataType(), dLdh->getContext()); // automatically nullifies content + dLdhFwd = dLdh; + dLdhBwd = dLdh; } else if(directionMode == 3) { // concat dLdhFwd = new NDArray(dataFormat <= 1 ? 
(*dLdh)({0,0, 0,0, 0,nOut}) : (*dLdh)({0,0, 0,nOut, 0,0})); @@ -744,21 +742,20 @@ CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) { } } + NDArray dLdxBwd = dLdx->ulike(); - + // FIXME - following two calls are independent and may run in different streams helpers::lstmLayerTimeLoopBp(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, dLdhFwd, dLdhLFwd, dLdcLFwd, params, true, dLdx, &dLdWxFwd, &dLdWrFwd, dLdbFwd, dLdhIFwd, dLdcIFwd, dLdWpFwd); - NDArray dLdxBwd = dLdx->ulike(); helpers::lstmLayerTimeLoopBp(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, dLdhBwd, dLdhLBwd, dLdcLBwd, params, false, &dLdxBwd, &dLdWxBwd, &dLdWrBwd, dLdbBwd, dLdhIBwd, dLdcIBwd, dLdWpBwd); *dLdx += dLdxBwd; delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; delete cIBwd; - delete dLdhBwd; delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd; + delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd; delete dLdWpFwd; delete dLdWpBwd; delete dLdbFwd; delete dLdbBwd; delete dLdhIFwd; delete dLdhIBwd; delete dLdcIFwd; delete dLdcIBwd; - if(dLdhFwd != dLdh) - delete dLdhFwd; + if(!(dLdh && directionMode == 2)) { delete dLdhFwd; delete dLdhBwd; } } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp index 46f32e399..4f24219bd 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp @@ -293,7 +293,7 @@ CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) { helpers::lstmLayerCell(x,Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c); - helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp); + helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, 
dLdWp); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp index 77e851104..8ca01540c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_max.cpp @@ -29,11 +29,11 @@ namespace sd { auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); Nd4jLong wrong; - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_max: segment indices should be in range [0, %i), but %i > %i", + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_max: segment indices should be in range [0, %ld), but %ld != %ld", numOfClasses, wrong, numOfClasses); helpers::unsortedSegmentMaxFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp index cad59b7e9..7aa46295c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp +++ 
b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp @@ -30,11 +30,11 @@ namespace sd { Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_mean: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_mean: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); Nd4jLong wrong; - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_mean: segment indices should be in range [0, %i), but %i > %i", + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_mean: segment indices should be in range [0, %ld), but %ld != %ld", numOfClasses, wrong, numOfClasses); helpers::unsortedSegmentMeanFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp index 87b96e844..76dd982f7 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp @@ -29,11 +29,11 @@ namespace sd { auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? 
INPUT_VARIABLE(2)->e(0) : INT_ARG(0); REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); Nd4jLong wrong; - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_min: segment indices should be in range [0, %i), but %i > %i", + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_min: segment indices should be in range [0, %ld), but %ld > %ld", numOfClasses, wrong, numOfClasses); helpers::unsortedSegmentMinFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp index e430c8f77..d2f491c55 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp @@ -29,11 +29,11 @@ namespace sd { auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? 
INPUT_VARIABLE(2)->e(0) : INT_ARG(0); REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); Nd4jLong wrong = 0; - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_prod: segment indices should be in range [0, %i), but %i > %i", + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_prod: segment indices should be in range [0, %ld), but %ld != %ld", numOfClasses, wrong, numOfClasses); helpers::unsortedSegmentProdFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp index eeaa6e2c2..a8dbf8eaf 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp @@ -29,11 +29,11 @@ namespace sd { auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? 
INPUT_VARIABLE(2)->e(0) : INT_ARG(0); REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sqrt_n: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); Nd4jLong wrong; - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sqrt_n: segment indices should be in range [0, %i), but %i > %i", + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sqrt_n: segment indices should be in range [0, %ld), but %ld != %ld", numOfClasses, wrong, numOfClasses); helpers::unsortedSegmentSqrtNFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp index 941496424..1afcab34f 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp @@ -29,11 +29,11 @@ namespace sd { auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? 
INPUT_VARIABLE(2)->e(0) : INT_ARG(0); REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sum: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0)); + REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %ld != %ld", idxSegments->lengthOf(), input->sizeAt(0)); Nd4jLong wrong; - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sum: segment indices should be in range [0, %i), but %i > %i", + REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sum: segment indices should be in range [0, %ld), but %ld > %ld", numOfClasses, wrong, numOfClasses); helpers::unsortedSegmentSumFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); diff --git a/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp b/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp index 54fd8fb0e..374456be6 100644 --- a/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp +++ b/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp @@ -26,24 +26,38 @@ namespace sd { namespace ops { - CUSTOM_OP_IMPL(lin_space, 3, 1, false, 0, 0) { - auto output = OUTPUT_VARIABLE(0); - auto start = INPUT_VARIABLE(0); - auto finish = INPUT_VARIABLE(1); - auto numOfElements = INPUT_VARIABLE(2); + CUSTOM_OP_IMPL(lin_space, 0, 1, false, 0, 0) { - if (numOfElements->e(0) == 1) { + auto output = OUTPUT_VARIABLE(0); + + const int nInputs = block.width(); + bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0)); + + 
REQUIRE_TRUE(bInputs, 0, "lin_space OP: Have to be supplied correct inputs, input size or T_ARG size have to be equal 3, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT()); + + auto start = (nInputs > 0) ? INPUT_VARIABLE(0)->e(0) : static_cast(T_ARG(0)); + auto finish = (nInputs > 0) ? INPUT_VARIABLE(1)->e(0) : static_cast(T_ARG(1)); + auto numOfElements = (nInputs > 0) ? INPUT_VARIABLE(2)->e(0) : static_cast(I_ARG(0)); + + if (numOfElements == 1) { output->assign(start); return Status::OK(); } - output->linspace(start->e(0), (finish->e(0) - start->e(0)) / (numOfElements->e(0) - 1.)); + output->linspace(start, (finish - start) / ( numOfElements - 1.0 )); return Status::OK(); } DECLARE_SHAPE_FN(lin_space) { - auto dataType = ArrayOptions::dataType(inputShape->at(0)); - Nd4jLong steps = INPUT_VARIABLE(2)->e(0); + + const int nInputs = block.width(); + bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0)); + REQUIRE_TRUE(bInputs, 0, "lin_space OP: Have to be supplied correct inputs, input size or T_ARG size have to be equal 3, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT() ); + + + auto dataType = (nInputs > 0) ? ArrayOptions::dataType(inputShape->at(0)) : ( block.numD() > 0 ? static_cast(D_ARG(0)) : DataType::FLOAT32) ; + Nd4jLong steps = (nInputs > 0) ? 
INPUT_VARIABLE(2)->e(0) : static_cast(I_ARG(0)); + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(steps, dataType)); } diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index f3131c193..8fae1b63c 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -1433,16 +1433,20 @@ namespace sd { /** * lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space) * - * input params: + * optional input params: * 0 - startVal - NDArray scalar (float point) * 1 - finishVal - NDArray scalar (float point) * 2 - numOfElements - NDArray scalar (integer) - * + * Optional: + * T args + * 0 - startVal + * 1 - finishVal] + * 2 - numOfElements * output: * 0 - 1D NDArray with the same type as input and length as given with numOfElements param. */ #if NOT_EXCLUDED(OP_lin_space) - DECLARE_CUSTOM_OP(lin_space, 3, 1, false, 0, 0); + DECLARE_CUSTOM_OP(lin_space, 0, 1, false, 0, 0); #endif /** diff --git a/libnd4j/include/ops/declarable/headers/recurrent.h b/libnd4j/include/ops/declarable/headers/recurrent.h index dd219867f..aeeae24c4 100644 --- a/libnd4j/include/ops/declarable/headers/recurrent.h +++ b/libnd4j/include/ops/declarable/headers/recurrent.h @@ -345,6 +345,10 @@ namespace ops { DECLARE_CUSTOM_OP(gru, 5, 1, false, 0, 0); #endif + #if NOT_EXCLUDED(OP_gru) + DECLARE_CUSTOM_OP(gru_bp, 6, 5, false, 0, 0); + #endif + ////////////////////////////////////////////////////////////////////////// /** * Implementation of operation "static RNN time sequences" with peep hole connections: diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp deleted file mode 100644 index b00036b81..000000000 --- a/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp +++ /dev/null @@ -1,421 +0,0 @@ 
-/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author Yurii Shyrma (iuriish@yahoo.com), created on 15.02.2018, Alex Black -// - -// implementation of gated Recurrent Unit cell -// (cf. https://arxiv.org/abs/1406.1078). -// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio -// "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" - - -#include -#include -#include -#include - -namespace sd { -namespace ops { -namespace helpers { - - -////////////////////////////////////////////////////////////////////////// -void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, - const NDArray* b, const NDArray* bc, - NDArray* r, NDArray* u, NDArray* c, NDArray* h) { - - //Inputs: - // x input [bS, iS], iS - input size - // hLast previous cell output [bS, nU], that is at previous time step t-1, nU - number of units - // W RU weights - [iS+nU, 2*nU] - reset and update gates - // Wc C weights - [iS+nU, nU] - cell gate - // b r and u biases, [2*nU] - reset and update gates - // bc c biases, [nU] - cell gate - - //Outputs: - // r Reset gate output [bS, nU] - // u Update gate output [bS, 
nU] - // c Cell gate output [bS, nU] - // h current cell output [bS, nU] - - /***************************************************************************************/ - /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ - /** however it is more math-friendly and convenient for backprop formulas derivation) **/ - - const int bS = x->sizeAt(0); - const int iS = x->sizeAt(1); - const int nU = hLast->sizeAt(1); - - NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] - NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] - NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] - NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] - - NDArray br = (*b)({0, nU}); // [nU] - NDArray bu = (*b)({nU, 2*nU}); // [nU] - - // × means matrix multipication - // * means element-wise product or so called Hadamard product - - // reset gate - r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r->applyTransform(transform::Sigmoid, *r); - - // update gate - u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u->applyTransform(transform::Sigmoid, *u); - - // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) - c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c->applyTransform(transform::Tanh, *c); - - NDArray temp = 1.f - *c * *c; - - // cell output - h->assign(*u * *hLast + (1.f - *u) * *c); - - - /***************************************************************************************/ - /*************** THIS IS MORE OPTIMAZED CODE (should think about concat) ***************/ - /***************************************************************************************/ -/* - //Concat inputs: x + hLast : [bs, iS + nU] - 
NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context); // concat([bs, iS], [bs, nU]) -> [bs, iS + nU] - helpers::concat(context, {const_cast(x), const_cast(hLast)}, xhConcat, {1}); - - //mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u) - auto m = mmul(xhConcat, *W) + *b ; // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU] - // m += *bru; - - m.applyTransform(transform::Sigmoid); //sigmoid(rz) and sigmoid(uz) - - r->assign(m({0,0, 0, nU})); - u->assign(m({0,0, nU, 2*nU})); - - // hLast = hLast * r - xhConcat({0,0, iS, iS+nU}) *= *r; - - //c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c) - MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c - *c += *bc; - c->applyTransform(transform::Tanh); - - //Output: h = (1-u).*c + u .* hPrev - //auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast(h)->assign(&hResult); - u->applyPairwiseTransform(pairwise::Multiply, hLast, h, nullptr); //h = u * hLast - auto temp = (1.0f - *u); - temp *= (*c); - (*h) += temp; -*/ -} - -////////////////////////////////////////////////////////////////////////// -void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { - - // x input [time, bS, iS] - // hLast initial cell output (at time step = 0) [bS, nU] - // Wx input-to-hidden weights, [iS, 3*nU] - // Wh hidden-to-hidden weights, [nU, 3*nU] - // b biases, [3*nU] - - // h is cell outputs at each time step [time, bS, nU] - - const int time = x->sizeAt(0); - - NDArray ht_1(*hLast); - - // loop through time steps - for (int t = 0; t < time; ++t) { - - auto xt = (*x)({t,t+1, 0,0, 0,0}); - auto ht = (*h)({t,t+1, 0,0, 0,0}); - - // helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht); - // ht_1.assign(ht); - } -} - -////////////////////////////////////////////////////////////////////////// -void gruCellBP(sd::LaunchContext* context, - const NDArray* x, const NDArray* hLast, - 
const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, - const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, - NDArray* dLdx, NDArray* dLdhLast, - NDArray* dLdW, NDArray* dLdWc, - NDArray* dLdb, NDArray* dLdbc) { - - //Inputs: - // x input [bS, iS] - // hLast previous cell output [bS, nU], that is at previous time step t-1 - // W weights - [iS+nU, 2*nU] - reset and update gates - // Wc C weights - [iS+nU, nU] - cell gate - // b r and u biases, [2*nU] - reset and update gates - // bc c biases, [nU] - cell gate - // dLdr gradient wrt reset gate, [bS, nU] - // dLdu gradient wrt update gate, [bS, nU] - // dLdc gradient wrt cell state, [bS, nU] - // dLdh gradient wrt current cell output, [bS, nU] - - //Outputs: - // dLdx gradient wrt x, [bS, iS], - // dLdhLast gradient wrt hLast, [bS, nU] - // dLdW gradient wrt W, [iS+nU, 2*nU] - // dLdWc gradient wrt Wc, [iS+nU, nU] - // dLdb gradient wrt bru [2*nU] - // dLdbc gradient wrt bc [nU] - - // * means element-wise product or so called Hadamard product - // × means matrix multiplication - - /************************************************************************************************/ - /******************************* THIS IS NOT OPTIMAZED CODE *************************************/ - /*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/ - - const int bS = x->sizeAt(0); - const int iS = x->sizeAt(1); - const int nU = hLast->sizeAt(1); - - NDArray xT = x->transpose(); // [iS, bS] - NDArray hLastT = hLast->transpose(); // [nU, bS] - - NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] - NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] - NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] - NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] - - NDArray br = (*b)({0, nU}); // [nU] - NDArray bu = 
(*b)({nU, 2*nU}); // [nU] - - NDArray WrxT = Wrx.transpose(); // [nU, iS] - NDArray WuxT = Wux.transpose(); // [nU, iS] - NDArray WrhT = Wrh.transpose(); // [nU, nU] - NDArray WuhT = Wuh.transpose(); // [nU, nU] - - NDArray WcxT = Wcx.transpose(); // [nU, iS] - NDArray WchT = Wch.transpose(); // [nU, nU] - - NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU] - NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU] - NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU] - NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU] - - NDArray dLdbr = (*dLdb)({0, nU}); // [nU] - NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU] - - - // ***** feed forward step ***** // - - // reset gate - NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r.applyTransform(transform::Sigmoid, r); - - // update gate - NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u.applyTransform(transform::Sigmoid, u); - - // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) - NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c.applyTransform(transform::Tanh, c); - - // h = (1 - u) * c + u * hPrev - - - // ***** back prop step ***** // - - // notations: - // Zr = x × Wrx + hLast × Wrh + br - // Zu = x × Wux + hLast × Wuh + bu - // Sr = sigmoid(Zr) - // Su = sigmoid(Zu) - // Zc = x × Wcx + (r * hlast) × Wch + bc - - - // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx - // = dLdx_u + dLdx_c - // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... 
× WuxT| = (dLdu * dudZu) × WuxT - // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1 - // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT - // dZcdr = (... * hLast) × WchT - // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT - // drdx = drdZr * dZrdx - // dZrdx = ... × WrxT - // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT - // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT - - - // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast - // = dLdhLast_h + dLdhLast_u + dLdhLast_c - // dLdhLast_h = dLdh * dhdhLas = dLdh * u - // dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT - // dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) = - // = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast = - // = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1 - // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT - // dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT - // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 = - // = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT - - - // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx - // dZrdWrx = xT × ... 
- // finally dLdWrx = xT × (dLdr * drdZr) - - - // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh - // dZrdWrh = hLastT × ... - // finally dLdWrh = hLastT × (dLdr * drdZr) - - - // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux - // dZudWux = xT × ... - // dLdu * dudZu * dZudWux = xT × (dLdu * dudZu) - - - // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh - // dZudWuh = hLastT × ... - // finally dLdWuh = hLastT × (dLdu * dudZu) - - - // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx - // dZcdWcx = xT × ... - // finally dLdWcx = xT × (dLdc * dcdZc) - - - // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch - // dZcdWch = (r*hLast)^T × ... - // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc) - - - // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr = - // = dLdr * drdZr * dZrdbr - // dZrdbr = 1 - // finally dLdbr = dLdr * drdZr - - - // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu - // dZudbu = 1 - // finally dLdbu = dLdu * dudZu - - - // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc - // dZcdbc = 1 - // finally dLdbc = dLdc * dcdZc - - NDArray dhdc = 1.f - u; // [bS, nU] - NDArray dhdu = *hLast - c; // [bS, nU] - NDArray dudZu = u * dhdc; // [bS, nU] - NDArray drdZr = r * (1.f - r); // [bS, nU] - NDArray dcdZc = 1.f - c * c; // [bS, nU] - NDArray dLdZc = *dLdc * dcdZc; // [bS, nU] - NDArray dLdZu = *dLdu * dudZu; // [bS, nU] - NDArray dLdZr = *dLdr * drdZr; // [bS, nU] - - // NDArray dLdc = *dLdh * dhdc; // [bS, nU] - // NDArray dLdu = *dLdh * dhdu; // [bS, nU] - // NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU] - - 
dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS] - - dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU] - - dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU] - dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] - dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] - - dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] -} - -// ////////////////////////////////////////////////////////////////////////// -// FIXME - gruTimeLoopBP is not correct -// template -// void gruTimeLoopBP(const std::vector*>& inArrs, const std::vector*>& outArrs) { - -// NDArray* x = inArrs[0]; // input [time, bS, iS] -// NDArray* hi = inArrs[1]; // previous/initial cell output [bS, nU], that is at previous time step t-1 -// NDArray* Wx = inArrs[2]; // input-to-hidden weights, [iS, 3*nU] -// NDArray* Wh = inArrs[3]; // hidden-to-hidden weights, [nU, 3*nU] -// NDArray* b = inArrs[4]; // biases, [3*nU] -// NDArray* dLdh = inArrs[5]; // gradient wrt output, [time, bS, nU], that is epsilon_next - -// NDArray* dLdx = outArrs[0]; // gradient wrt x, [time, bS, iS], that is epsilon -// NDArray* dLdhi = outArrs[1]; // gradient wrt hi, [bS, nU] -// NDArray* dLdWx = outArrs[2]; // gradient wrt Wx, [iS, 3*nU] -// NDArray* dLdWh = outArrs[3]; // gradient wrt Wh, [nU, 3*nU] -// NDArray* dLdb = outArrs[4]; // gradient wrt b, [3*nU] - -// const Nd4jLong time = x->sizeAt(0); -// const Nd4jLong bS = x->sizeAt(1); -// const Nd4jLong iS = x->sizeAt(2); -// const Nd4jLong nU = hi->sizeAt(1); - 
-// NDArray h(hi->ordering(), {time, bS, nU}); // feed forward output - -// // first step, time = 0, feed forward -// NDArray x0 = (*x)({{0,1}, {}, {}}); -// NDArray hLast = h({{0,1}, {}, {}}); -// helpers::gruCell({&x0, hi, Wx, Wh, b}, &hLast); - -// // first step, time = 0, back prop -// NDArray dLdx0 = (*dLdx)({{0,1}, {}, {}}); -// NDArray dLdhLast = (*dLdh)({{0,1}, {}, {}}); -// helpers::gruCellBP({&x0, hi, Wx, Wh, b, &dLdhLast, nullptr, nullptr, nullptr}, {&dLdx0, dLdhi, dLdWx, dLdWh, dLdb}); - -// // loop through the rest time steps -// for (Nd4jLong t = time-1; t > 0; --t) { -// for (Nd4jLong t = 1; t < time; ++t) { - -// NDArray xt = (*x)({{t,t+1}, {}, {}}); -// NDArray ht = h({{t,t+1}, {}, {}}); -// NDArray ht_1 = h({{t-1,t}, {}, {}}); -// NDArray dLdxt = (*dLdx)({{t,t+1}, {}, {}}); -// NDArray dLdht = (*dLdh)({{t,t+1}, {}, {}}); - -// NDArray dLdWxt_1 = dLdWx; -// NDArray dLdWht_1 = dLdWh; -// NDArray dLdbt_1 = dLdb; - -// // feed forward, calculation of ht -// helpers::gruCell({&xt, &ht_1, Wx, Wh, b}, &ht); - -// // back prop -// helpers::gruCellBP({&xt, &ht_1, Wx, Wh, b, &dLdht, &dLdWxt_1, &dLdWht_1, &dLdbt_1}, {&dLdxt, nullptr, dLdWx, dLdWh, dLdb}); -// } -// } - - -} -} -} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gru.cu b/libnd4j/include/ops/declarable/helpers/cuda/gru.cu deleted file mode 100644 index bd4e878e3..000000000 --- a/libnd4j/include/ops/declarable/helpers/cuda/gru.cu +++ /dev/null @@ -1,365 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author Yurii Shyrma (iuriish@yahoo.com), created on 15.02.2018 -// - -// implementation of gated Recurrent Unit cell -// (cf. https://arxiv.org/abs/1406.1078). -// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio -// "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" - - -#include -#include -#include -#include - -namespace sd { -namespace ops { -namespace helpers { - - -////////////////////////////////////////////////////////////////////////// -void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, - const NDArray* b, const NDArray* bc, - NDArray* r, NDArray* u, NDArray* c, NDArray* h) { - - //Inputs: - // x input [bS, iS], iS - input size - // hLast previous cell output [bS, nU], that is at previous time step t-1, nU - number of units - // W RU weights - [iS+nU, 2*nU] - reset and update gates - // Wc C weights - [iS+nU, nU] - cell gate - // b r and u biases, [2*nU] - reset and update gates - // bc c biases, [nU] - cell gate - - //Outputs: - // r Reset gate output [bS, nU] - // u Update gate output [bS, nU] - // c Cell gate output [bS, nU] - // h current cell output [bS, nU] - - /***************************************************************************************/ - /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ - /** however it is more math-friendly and convenient for backprop formulas 
derivation) **/ - - const int bS = x->sizeAt(0); - const int iS = x->sizeAt(1); - const int nU = hLast->sizeAt(1); - - NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] - NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] - NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] - NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] - - NDArray br = (*b)({0, nU}); // [nU] - NDArray bu = (*b)({nU, 2*nU}); // [nU] - - // × means matrix multipication - // * means element-wise product or so called Hadamard product - - // reset gate - r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r->applyTransform(transform::Sigmoid, *r); - - // update gate - u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u->applyTransform(transform::Sigmoid, *u); - - // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) - c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c->applyTransform(transform::Tanh, *c); - - NDArray temp = 1.f - *c * *c; - - // cell output - h->assign(*u * *hLast + (1.f - *u) * *c); - - - /***************************************************************************************/ - /*************** THIS IS MORE OPTIMAZED CODE (should think about concat) ***************/ - /***************************************************************************************/ -/* - //Concat inputs: x + hLast : [bs, iS + nU] - NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context); // concat([bs, iS], [bs, nU]) -> [bs, iS + nU] - helpers::concat(context, {const_cast(x), const_cast(hLast)}, xhConcat, {1}); - - //mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u) - auto m = mmul(xhConcat, *W) + *b ; // [bs, iS+nU] * 
[iS+nU, 2*nU] = [bs, 2*nU] - // m += *bru; - - m.applyTransform(transform::Sigmoid); //sigmoid(rz) and sigmoid(uz) - - r->assign(m({0,0, 0, nU})); - u->assign(m({0,0, nU, 2*nU})); - - // hLast = hLast * r - xhConcat({0,0, iS, iS+nU}) *= *r; - - //c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c) - MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c - *c += *bc; - c->applyTransform(transform::Tanh); - - //Output: h = (1-u).*c + u .* hPrev - //auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast(h)->assign(&hResult); - u->applyPairwiseTransform(pairwise::Multiply, hLast, h, nullptr); //h = u * hLast - auto temp = (1.0f - *u); - temp *= (*c); - (*h) += temp; -*/ -} - -////////////////////////////////////////////////////////////////////////// -void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { - - // x input [time, bS, iS] - // hLast initial cell output (at time step = 0) [bS, nU] - // Wx input-to-hidden weights, [iS, 3*nU] - // Wh hidden-to-hidden weights, [nU, 3*nU] - // b biases, [3*nU] - - // h is cell outputs at each time step [time, bS, nU] - - const int time = x->sizeAt(0); - - NDArray ht_1(*hLast); - - // loop through time steps - for (int t = 0; t < time; ++t) { - - auto xt = (*x)({t,t+1, 0,0, 0,0}); - auto ht = (*h)({t,t+1, 0,0, 0,0}); - - // helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht); - // ht_1.assign(ht); - } -} - -////////////////////////////////////////////////////////////////////////// -void gruCellBP(sd::LaunchContext* context, - const NDArray* x, const NDArray* hLast, - const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, - const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, - NDArray* dLdx, NDArray* dLdhLast, - NDArray* dLdW, NDArray* dLdWc, - NDArray* dLdb, NDArray* dLdbc) { - - //Inputs: - // x input [bS, iS] - // hLast previous cell output [bS, 
nU], that is at previous time step t-1 - // W weights - [iS+nU, 2*nU] - reset and update gates - // Wc C weights - [iS+nU, nU] - cell gate - // b r and u biases, [2*nU] - reset and update gates - // bc c biases, [nU] - cell gate - // dLdr gradient wrt reset gate, [bS, nU] - // dLdu gradient wrt update gate, [bS, nU] - // dLdc gradient wrt cell state, [bS, nU] - // dLdh gradient wrt current cell output, [bS, nU] - - //Outputs: - // dLdx gradient wrt x, [bS, iS], - // dLdhLast gradient wrt hLast, [bS, nU] - // dLdW gradient wrt W, [iS+nU, 2*nU] - // dLdWc gradient wrt Wc, [iS+nU, nU] - // dLdb gradient wrt bru [2*nU] - // dLdbc gradient wrt bc [nU] - - // * means element-wise product or so called Hadamard product - // × means matrix multiplication - - /************************************************************************************************/ - /******************************* THIS IS NOT OPTIMAZED CODE *************************************/ - /*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/ - - const int bS = x->sizeAt(0); - const int iS = x->sizeAt(1); - const int nU = hLast->sizeAt(1); - - NDArray xT = x->transpose(); // [iS, bS] - NDArray hLastT = hLast->transpose(); // [nU, bS] - - NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] - NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] - NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] - NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] - - NDArray br = (*b)({0, nU}); // [nU] - NDArray bu = (*b)({nU, 2*nU}); // [nU] - - NDArray WrxT = Wrx.transpose(); // [nU, iS] - NDArray WuxT = Wux.transpose(); // [nU, iS] - NDArray WrhT = Wrh.transpose(); // [nU, nU] - NDArray WuhT = Wuh.transpose(); // [nU, nU] - - NDArray WcxT = Wcx.transpose(); // [nU, iS] - NDArray WchT = Wch.transpose(); // [nU, nU] - - NDArray dLdWrx = 
(*dLdW)({0,iS, 0,nU}); // [iS, nU] - NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU] - NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU] - NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU] - - NDArray dLdbr = (*dLdb)({0, nU}); // [nU] - NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU] - - - // ***** feed forward step ***** // - - // reset gate - NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r.applyTransform(transform::Sigmoid, r); - - // update gate - NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u.applyTransform(transform::Sigmoid, u); - - // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) - NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c.applyTransform(transform::Tanh, c); - - // h = (1 - u) * c + u * hPrev - - - // ***** back prop step ***** // - - // notations: - // Zr = x × Wrx + hLast × Wrh + br - // Zu = x × Wux + hLast × Wuh + bu - // Sr = sigmoid(Zr) - // Su = sigmoid(Zu) - // Zc = x × Wcx + (r * hlast) × Wch + bc - - - // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx - // = dLdx_u + dLdx_c - // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT - // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1 - // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT - // dZcdr = (... * hLast) × WchT - // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT - // drdx = drdZr * dZrdx - // dZrdx = ... 
× WrxT - // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT - // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT - - - // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast - // = dLdhLast_h + dLdhLast_u + dLdhLast_c - // dLdhLast_h = dLdh * dhdhLas = dLdh * u - // dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT - // dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) = - // = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast = - // = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1 - // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT - // dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT - // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 = - // = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT - - - // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx - // dZrdWrx = xT × ... - // finally dLdWrx = xT × (dLdr * drdZr) - - - // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh - // dZrdWrh = hLastT × ... - // finally dLdWrh = hLastT × (dLdr * drdZr) - - - // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux - // dZudWux = xT × ... - // dLdu * dudZu * dZudWux = xT × (dLdu * dudZu) - - - // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh - // dZudWuh = hLastT × ... 
- // finally dLdWuh = hLastT × (dLdu * dudZu) - - - // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx - // dZcdWcx = xT × ... - // finally dLdWcx = xT × (dLdc * dcdZc) - - - // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch - // dZcdWch = (r*hLast)^T × ... - // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc) - - - // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr = - // = dLdr * drdZr * dZrdbr - // dZrdbr = 1 - // finally dLdbr = dLdr * drdZr - - - // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu - // dZudbu = 1 - // finally dLdbu = dLdu * dudZu - - - // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc - // dZcdbc = 1 - // finally dLdbc = dLdc * dcdZc - - NDArray dhdc = 1.f - u; // [bS, nU] - NDArray dhdu = *hLast - c; // [bS, nU] - NDArray dudZu = u * dhdc; // [bS, nU] - NDArray drdZr = r * (1.f - r); // [bS, nU] - NDArray dcdZc = 1.f - c * c; // [bS, nU] - NDArray dLdZc = *dLdc * dcdZc; // [bS, nU] - NDArray dLdZu = *dLdu * dudZu; // [bS, nU] - NDArray dLdZr = *dLdr * drdZr; // [bS, nU] - - // NDArray dLdc = *dLdh * dhdc; // [bS, nU] - // NDArray dLdu = *dLdh * dhdu; // [bS, nU] - // NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU] - - dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS] - - dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU] - - dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU] - dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // 
[nU, bS] × [bS, nU] = [nU, nU] - - dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] - dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] - - dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] -} - - -} -} -} - diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 2ca731912..c986260e8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -341,14 +341,16 @@ namespace helpers { static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) { auto stream = context->getCudaStream(); auto n = input->rows(); - cusolverDnHandle_t cusolverH = nullptr; + std::lock_guard lock(*LaunchContext::deviceMutex()); + + cusolverDnHandle_t* cusolverH = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr; // create solver handle - cusolverStatus_t status = cusolverDnCreate(&cusolverH); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("Cannot create cuSolver handle", status); - } + cusolverStatus_t status; //cusolverDnCreate(&cusolverH); +// if (CUSOLVER_STATUS_SUCCESS != status) { +// throw cuda_exception::build("Cannot create cuSolver handle", status); +// } // set solver stream - status = cusolverDnSetStream(cusolverH, *stream); + status = cusolverDnSetStream(*cusolverH, *stream); if (CUSOLVER_STATUS_SUCCESS != status) { throw cuda_exception::build("Cannot set up stream for cuda solver", status); } @@ -368,7 +370,7 @@ namespace helpers { // compute internal buffer size double *matrix = reinterpret_cast(input->specialBuffer()); status = cusolverDnDgetrf_bufferSize( - cusolverH, + *cusolverH, n, n, matrix, @@ -386,7 +388,7 @@ namespace helpers { if (permutation == nullptr) { status = cusolverDnDgetrf( - cusolverH, + *cusolverH, n, n, matrix, @@ -404,7 +406,7 @@ namespace helpers { NDArray permutVector('c', {n}, 
sd::DataType::INT32, context); int* permutationBuf = permutVector.dataBuffer()->specialAsT(); status = cusolverDnDgetrf( - cusolverH, + *cusolverH, n, n, matrix, @@ -440,7 +442,7 @@ namespace helpers { float *d_work = nullptr; status = cusolverDnSgetrf_bufferSize( - cusolverH, + *cusolverH, n, n, matrix, @@ -458,7 +460,7 @@ namespace helpers { if (permutation == nullptr) status = cusolverDnSgetrf( - cusolverH, + *cusolverH, n, n, matrix, @@ -470,7 +472,7 @@ namespace helpers { NDArray permutVector('c', {n}, DataType::INT32, context); int *permutationBuf = reinterpret_cast(permutVector.specialBuffer()); status = cusolverDnSgetrf( - cusolverH, + *cusolverH, n, n, matrix, @@ -504,7 +506,7 @@ namespace helpers { if (err) { throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err); } - cusolverDnDestroy(cusolverH); +// cusolverDnDestroy(cusolverH); // NDArray::registerSpecialUse({input}, {input}); input->tickWriteDevice(); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu index 44f924bf0..5c3d2811c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu @@ -170,23 +170,25 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr } } + std::lock_guard lock(*LaunchContext::deviceMutex()); + // create cusolverDn handle - cusolverDnHandle_t handle = nullptr; - cusolverStatus_t status = cusolverDnCreate(&handle); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdQR: cuda failed !", status); + cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr; + //cusolverStatus_t status = cusolverDnCreate(&handle); + if(handle == nullptr) + throw cuda_exception::build("svdQR: cuda failed !", -1); // stream - status = cusolverDnSetStream(handle, *context->getCudaStream()); + auto status = cusolverDnSetStream(*handle, 
*context->getCudaStream()); if(status != CUSOLVER_STATUS_SUCCESS) throw cuda_exception::build("svdQR: cuda failed !", status); // query working space of SVD int lwork = 0; if(A->dataType() == DataType::DOUBLE) - status = cusolverDnDgesvd_bufferSize(handle, m, n, &lwork); + status = cusolverDnDgesvd_bufferSize(*handle, m, n, &lwork); else if(A->dataType() == DataType::FLOAT32) - status = cusolverDnSgesvd_bufferSize(handle, m, n, &lwork); + status = cusolverDnSgesvd_bufferSize(*handle, m, n, &lwork); else throw std::invalid_argument("svdQR: given data type is unsupported !"); @@ -227,10 +229,10 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr // choose appropriate cuda gemm api depending on data types if(A->dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); + status = cusolverDnDgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); } else if(A->dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? 
reinterpret_cast(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); + status = cusolverDnSgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); } else throw std::invalid_argument("svdQR: given data type is unsupported !"); @@ -259,8 +261,8 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr if (rWork) cudaFree(rWork); - if(handle) - cusolverDnDestroy(handle); +// if(handle) +// cusolverDnDestroy(handle); // cudaDeviceReset(); } @@ -346,14 +348,16 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA ldv = pV->strideAt(1); } + std::lock_guard lock(*LaunchContext::deviceMutex()); + // create cusolverDn handle - cusolverDnHandle_t handle = nullptr; - cusolverStatus_t status = cusolverDnCreate(&handle); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); + cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle(); + //cusolverStatus_t status = cusolverDnCreate(&handle); + if(handle == nullptr) + throw cuda_exception::build("svdJcb: cuda failed !", -1); // stream - status = cusolverDnSetStream(handle, *context->getCudaStream()); + auto status = cusolverDnSetStream(*handle, *context->getCudaStream()); if(status != CUSOLVER_STATUS_SUCCESS) throw cuda_exception::build("svdJcb: cuda failed !", status); @@ -391,9 +395,9 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA // query working space of SVD int lwork = 0; if(A->dataType() == DataType::DOUBLE) - status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, 
reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); + status = cusolverDnDgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); else if(A->dataType() == DataType::FLOAT32) - status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); + status = cusolverDnSgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); else throw std::invalid_argument("svdJcb: given data type is unsupported !"); @@ -410,10 +414,10 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA // choose appropriate cuda gemm api depending on data types if(A->dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? 
reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + status = cusolverDnDgesvdj(*handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); } else if(A->dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + status = cusolverDnSgesvdj(*handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? 
reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); } else throw std::invalid_argument("svdJcb: given data type is unsupported !"); @@ -446,8 +450,8 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA cudaFree(devInfo); if (dWork ) cudaFree(dWork); - if(handle) - cusolverDnDestroy(handle); +// if(handle) +// cusolverDnDestroy(handle); if(gesvdjParams) cusolverDnDestroyGesvdjInfo(gesvdjParams); diff --git a/libnd4j/include/ops/declarable/helpers/gru.h b/libnd4j/include/ops/declarable/helpers/gru.h index 3fecfa71b..9e98e4046 100644 --- a/libnd4j/include/ops/declarable/helpers/gru.h +++ b/libnd4j/include/ops/declarable/helpers/gru.h @@ -31,10 +31,26 @@ namespace helpers { const NDArray* bru, const NDArray* bc, NDArray* r, NDArray* u, NDArray* c, NDArray* h); + void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc, const NDArray* b, + NDArray* gates, NDArray* h); + void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h); - void gruCellBP(sd::LaunchContext* context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, NDArray* dLdx, NDArray* dLdhLast, NDArray* dLdW, NDArray* dLdWc, NDArray* dLdb, NDArray* dLdbc); + void gruCellBp(sd::LaunchContext* context, + const NDArray* x, const NDArray* hLast, + const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, + const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, + NDArray* dLdx, NDArray* dLdhLast, + NDArray* dLdW, NDArray* dLdWc, + NDArray* dLdb, NDArray* dLdbc); + void gruCellBp(sd::LaunchContext* context, + const NDArray* x, const NDArray* hI, const NDArray* Wx, 
const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* gates, + NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb); + + void gruTimeLoopBp(sd::LaunchContext * context, + const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, + NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb); } } } diff --git a/libnd4j/include/ops/declarable/helpers/impl/gru.cpp b/libnd4j/include/ops/declarable/helpers/impl/gru.cpp new file mode 100644 index 000000000..277188428 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/impl/gru.cpp @@ -0,0 +1,546 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. + * + * ThnIn program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which nIn available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * dnIntributed under the License nIn dnIntributed on an "AS nIn" BASnIn, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permnInsions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 15.02.2018, Alex Black +// + +// implementation of gated Recurrent Unit cell +// (cf. https://arxiv.org/abs/1406.1078). 
+// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio +// "Learning Phrase Representations using RNN Encoder-Decoder for StatnIntical Machine Translation" + + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////// +void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* W, const NDArray* Wc, + const NDArray* b, const NDArray* bc, + NDArray* r, NDArray* u, NDArray* c, NDArray* h) { + + //Inputs: + // x input [bS, nIn], nIn - input size + // hI previous cell output [bS, nOut], that is at previous time step t-1, nOut - number of units + // W RU weights - [nIn+nOut, 2*nOut] - reset and update gates + // Wc C weights - [nIn+nOut, nOut] - cell gate + // b r and u biases, [2*nOut] - reset and update gates + // bc c biases, [nOut] - cell gate + + //Outputs: + // r Reset gate output [bS, nOut] + // u Update gate output [bS, nOut] + // c Cell gate output [bS, nOut] + // h current cell output [bS, nOut] + + /***************************************************************************************/ + /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ + /** however it is more math-friendly and convenient for backprop formulas derivation) **/ + + const int bS = x->sizeAt(0); + const int nIn = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray Wrx = (*W)({0,nIn, 0,nOut}); // [nIn, nOut] + NDArray Wux = (*W)({0,nIn, nOut,2*nOut}); // [nIn, nOut] + NDArray Wrh = (*W)({nIn,nIn+nOut, 0,nOut}); // [nOut, nOut] + NDArray Wuh = (*W)({nIn,nIn+nOut, nOut,2*nOut}); // [nOut, nOut] + + NDArray Wcx = (*Wc)({0,nIn, 0,0}); // reset cell weights [nIn, nOut] + NDArray Wch = (*Wc)({nIn,nIn+nOut, 0,0}); // updates cell weights [nOut, nOut] + + NDArray br = (*b)({0, nOut}); // [nOut] + NDArray bu = (*b)({nOut, 2*nOut}); // [nOut] + 
+ // × means matrix multipication + // * means element-wise product or so called Hadamard product + + // reset gate + r->assign(mmul(*x, Wrx) + mmul(*hI, Wrh) + br); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut] + r->applyTransform(transform::Sigmoid, *r); + + // update gate + u->assign(mmul(*x, Wux) + mmul(*hI, Wuh) + bu); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut] + u->applyTransform(transform::Sigmoid, *u); + + // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) + c->assign(mmul(*x, Wcx) + mmul(*r * *hI, Wch) + *bc); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut] + c->applyTransform(transform::Tanh, *c); + + // cell output + h->assign(*u * *hI + (1.f - *u) * *c); +} + +////////////////////////////////////////////////////////////////////////// +void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, + NDArray* gates, NDArray* h) { + + //Inputs: + // x input [bS, nIn] + // hI previous cell output [bS, nOut], that is at previous time step t-1 + // Wx weights for x - [nIn, 3*nOut] + // Wh weights for h - [nOut, 3*nOut] + // b biases [3*nOut] + + // 3*nOut means following sequence: reset, update, cell + + //Outputs: + // gates [bS, 3*nOut] = reset gate [bS, nOut] + update gate [bS, nOut] + cell gate [bS, nOut] + // h current cell output [bS, nOut] + + // formulas: + // zr = x × Wxr + hI × Whr + br + // zu = x × Wxu + hI × Whu + bu + // r = sigmoid(zr) + // u = sigmoid(zu) + // zc = x × Wxc + (r * hI) × Whc + bc + // c = tanh(zc) + // h = (1-u)*c + u*hI + + const int bS = x->sizeAt(0); + const int nIn = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray temp = gates->ulike(); + MmulHelper::mmul(x, Wx, &temp); // [bS, nIn] × [nIn, 3*nOut] = [bS, 3*nOut] + temp += *b; + + MmulHelper::mmul(hI, Wh, gates); // [bS, nOut] × [nOut, 3*nOut] = [bS, 3*nOut] + + NDArray ru = (*gates)({0,0, 
0,2*nOut}); // [bS, 2*nOut] + + NDArray r = (*gates)({0,0, 0,nOut}); // [bS, nOut] + NDArray u = (*gates)({0,0, nOut,2*nOut}); // [bS, nOut] + NDArray c = (*gates)({0,0, 2*nOut,3*nOut}); // [bS, nOut] + + // reset and update gates + ru += temp({0,0, 0,2*nOut}); + ru.applyTransform(transform::Sigmoid, ru); + + // cell gate + c.assign(c*r + temp({0,0, 2*nOut, 3*nOut})); + c.applyTransform(transform::Tanh, c); + + // cell output + h->assign(u * *hI + (1.f - u) * c); +} + +////////////////////////////////////////////////////////////////////////// +void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { + + // sL means time steps + + // x input [sL, bS, nIn] + // hI initial cell output (at time step = 0) [bS, nOut] + // Wx input-to-hidden weights, [nIn, 3*nOut] + // Wh hidden-to-hidden weights, [nOut, 3*nOut] + // b biases, [3*nOut] + + // h cell outputs at each time step [sL, bS, nOut] + + const int sL = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray gates(h->ordering(), {bS, 3*nOut}, h->dataType(), context); + + auto xSet = x->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] + auto hSet = h->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] + + // time loop + for (int t = 0; t < sL; ++t) + gruCell(context, xSet.at(t), t == 0 ? 
hI : hSet.at(t-1), Wx, Wh, b, &gates, hSet.at(t)); +} + +////////////////////////////////////////////////////////////////////////// +void gruCellBp(sd::LaunchContext* context, + const NDArray* x, const NDArray* hLast, + const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, + const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, + NDArray* dLdx, NDArray* dLdhLast, + NDArray* dLdW, NDArray* dLdWc, + NDArray* dLdb, NDArray* dLdbc) { + + //Inputs: + // x input [bS, iS] + // hLast previous cell output [bS, nU], that is at previous time step t-1 + // W weights - [iS+nU, 2*nU] - reset and update gates + // Wc C weights - [iS+nU, nU] - cell gate + // b r and u biases, [2*nU] - reset and update gates + // bc c biases, [nU] - cell gate + // dLdr gradient wrt reset gate, [bS, nU] + // dLdu gradient wrt update gate, [bS, nU] + // dLdc gradient wrt cell state, [bS, nU] + // dLdh gradient wrt current cell output, [bS, nU] + + //Outputs: + // dLdx gradient wrt x, [bS, iS], + // dLdhLast gradient wrt hLast, [bS, nU] + // dLdW gradient wrt W, [iS+nU, 2*nU] + // dLdWc gradient wrt Wc, [iS+nU, nU] + // dLdb gradient wrt bru [2*nU] + // dLdbc gradient wrt bc [nU] + + // * means element-wise product or so called Hadamard product + // × means matrix multiplication + + /************************************************************************************************/ + /******************************* THIS IS NOT OPTIMAZED CODE *************************************/ + /*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/ + + const int bS = x->sizeAt(0); + const int iS = x->sizeAt(1); + const int nU = hLast->sizeAt(1); + + NDArray xT = x->transpose(); // [iS, bS] + NDArray hLastT = hLast->transpose(); // [nU, bS] + + NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] + NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] + NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] + NDArray Wuh = (*W)({iS,iS+nU, 
nU,2*nU}); // [nU, nU] + + NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] + NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] + + NDArray br = (*b)({0, nU}); // [nU] + NDArray bu = (*b)({nU, 2*nU}); // [nU] + + NDArray WrxT = Wrx.transpose(); // [nU, iS] + NDArray WuxT = Wux.transpose(); // [nU, iS] + NDArray WrhT = Wrh.transpose(); // [nU, nU] + NDArray WuhT = Wuh.transpose(); // [nU, nU] + + NDArray WcxT = Wcx.transpose(); // [nU, iS] + NDArray WchT = Wch.transpose(); // [nU, nU] + + NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU] + NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU] + NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU] + NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU] + + NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU] + NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU] + + NDArray dLdbr = (*dLdb)({0, nU}); // [nU] + NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU] + + + // ***** feed forward step ***** // + + // reset gate + NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] + r.applyTransform(transform::Sigmoid, r); + + // update gate + NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] + u.applyTransform(transform::Sigmoid, u); + + // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) + NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] + c.applyTransform(transform::Tanh, c); + + // h = (1 - u) * c + u * hPrev + + + // ***** back prop step ***** // + + // notations: + // Zr = x × Wrx + hLast × Wrh + br + // Zu = x × Wux + hLast × Wuh + bu + // Sr = sigmoid(Zr) + // Su = sigmoid(Zu) + // Zc = x × Wcx + (r * hlast) × Wch + bc + + + // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx + // = dLdx_u + 
dLdx_c + // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT + // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1 + // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT + // dZcdr = (... * hLast) × WchT + // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT + // drdx = drdZr * dZrdx + // dZrdx = ... × WrxT + // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT + // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT + + + // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast + // = dLdhLast_h + dLdhLast_u + dLdhLast_c + // dLdhLast_h = dLdh * dhdhLas = dLdh * u + // dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT + // dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) = + // = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast = + // = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1 + // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT + // dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT + // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 = + // = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT + + + // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx = + // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx + // dZrdWrx = xT × ... 
+ // finally dLdWrx = xT × (dLdr * drdZr) + + + // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh = + // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh + // dZrdWrh = hLastT × ... + // finally dLdWrh = hLastT × (dLdr * drdZr) + + + // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux + // dZudWux = xT × ... + // dLdu * dudZu * dZudWux = xT × (dLdu * dudZu) + + + // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh + // dZudWuh = hLastT × ... + // finally dLdWuh = hLastT × (dLdu * dudZu) + + + // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx + // dZcdWcx = xT × ... + // finally dLdWcx = xT × (dLdc * dcdZc) + + + // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch + // dZcdWch = (r*hLast)^T × ... + // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc) + + + // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr = + // = dLdr * drdZr * dZrdbr + // dZrdbr = 1 + // finally dLdbr = dLdr * drdZr + + + // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu + // dZudbu = 1 + // finally dLdbu = dLdu * dudZu + + + // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc + // dZcdbc = 1 + // finally dLdbc = dLdc * dcdZc + + NDArray dhdc = 1.f - u; // [bS, nU] + NDArray dhdu = *hLast - c; // [bS, nU] + NDArray dudZu = u * dhdc; // [bS, nU] + NDArray drdZr = r * (1.f - r); // [bS, nU] + NDArray dcdZc = 1.f - c * c; // [bS, nU] + NDArray dLdZc = *dLdc * dcdZc; // [bS, nU] + NDArray dLdZu = *dLdu * dudZu; // [bS, nU] + NDArray dLdZr = *dLdr * drdZr; // [bS, nU] + + // NDArray dLdc = *dLdh * dhdc; // [bS, nU] + // NDArray dLdu = *dLdh * dhdu; // [bS, nU] + // NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU] + + 
dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS] + + dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU] + + dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU] + dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU] + dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU] + dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU] + + dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] + dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] + + dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] + dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] + + dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] +} + + +////////////////////////////////////////////////////////////////////////// +void gruCellBp(sd::LaunchContext* context, + const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* gates, + NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { + + //Inputs: + // x input [bS, nIn] + // hI previous cell output [bS, nOut], that nIn at previous time step t-1 + // Wx input-to-hidden weights - [nIn, 3*nOut] + // Wh hidden-to-hidden weights - [nOut, 3*nOut] + // b biases, [3*nOut] - reset and update gates + // dLdh gradient vs. ff output, [bS, nOut] + + //Outputs: + // dLdx gradient vs. x, [bS, nIn], + // dLdhI gradient vs. hI, [bS, nOut] + // dLdWx gradient vs. W, [nIn, 3*nOut] + // dLdWh gradient vs. Wc, [nOut, 3*nOut] + // dLdb gradient vs. 
b [3*nOut] + + // 3*nOut means following sequence: reset, update, cell + + // * means element-wnIne product or so called Hadamard product + // × means matrix multiplication + + // formulas: + // zr = x × Wxr + hI × Whr + br + // zu = x × Wxu + hI × Whu + bu + // r = sigmoid(zr) + // u = sigmoid(zu) + // zc = x × Wxc + (r * hI) × Whc + bc + // c = tanh(zc) + // h = (1-u)*c + u*hI + + // dLdhI += dLdh; [bS, nOut] + + + // dhdc = 1 - u [bS, nOut] + // dhdu = -c + hI [bS, nOut] + + // dcdzc = 1 - c*c; [bS, nOut] + // dudzu = u*(1-u) [bS, nOut] + // drdzr = r(1-r) [bS, nOut] + + // dzcdr = (...*hI × WhcT) [bS, nOut] + + // dLdzr = dLdh*dhdc*dcdzc*dzcdr*drdzr = (dLdzc*hI*r(1-r) × WhcT); [bS, nOut] + // dLdzu = dLdh*dhdu*dudzu = dLdh*(hI-c)*u*(1-u) [bS, nOut] + // dLdzc = dLdh*dhdc*dcdzc = dLdh*(1-u)*(1-c*c) [bS, nOut] + + // dLdx = dLdzr × WxrT + dLdzu × WxuT + dLdzc × WxcT, [bs, nOut] × [nOut, nIn] + ... = [bS, nIn] + + // dLdhI = dLdzr × WhrT + dLdzu × WhuT + dLdzc × WhcT, [bs, nOut] × [nOut, nOut] + ... 
= [bS, nOut] + + // dLdWxr = xT × dLdzr [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxu = xT × dLdzu [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxc = xT × dLdzc [nIn, bS] x [bS, nOut] = [nIn, nOut] + + // dLdWhr = xT × dLdzr [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWhu = xT × dLdzu [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWhc = (r*hI)T × dLdzc [nOut, bS] x [bS, nOut] = [nOut, nOut] + + // dLdbr = dLdzr.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbu = dLdzu.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbc = dLdzc.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + + const int nOut = hI->sizeAt(1); + + NDArray dLdz = gates->ulike(); // [bS, 3*nOut] + + NDArray dLdzru = dLdz({0,0, 0,2*nOut}); // [bS, 2*nOut] + + NDArray dLdzr = dLdz({0,0, 0,nOut}); // [bS, nOut] + NDArray dLdzu = dLdz({0,0, nOut,2*nOut}); // [bS, nOut] + NDArray dLdzc = dLdz({0,0, 2*nOut,3*nOut}); // [bS, nOut] + + NDArray r = (*gates)({0,0, 0,nOut}); // [bS, nOut] + NDArray u = (*gates)({0,0, nOut,2*nOut}); // [bS, nOut] + NDArray c = (*gates)({0,0, 2*nOut,3*nOut}); // [bS, nOut] + + NDArray WhcT = (*Wh)({0,0, 2*nOut,3*nOut}).transpose(); + + if(dLdh) + *dLdhI += *dLdh; + + NDArray temp1 = 1 - u; // [bS, nOut] + + // dLdzc + dLdzc.assign(*dLdhI * temp1 * (1-c*c)); // [bS, nOut] + + // dLdzu + dLdzu.assign(*dLdhI * (*hI - c) * u * temp1); // [bS, nOut] + + // dLdzr + NDArray temp2 = dLdzc * (*hI) * r *(1-r); + MmulHelper::mmul(&temp2, &WhcT, &dLdzr); // [bS, nOut] x [nOut, nOut] = [bS, nOut] + + // dLdx + NDArray WxT = Wx->transpose(); + MmulHelper::mmul(&dLdz, &WxT, dLdx); // [bS, 3*nOut] x [3*nOut, nIn] = [bS, nIn] + + // dLdWx + *dLdWx += mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 3*nOut] = [nIn, 3*nOut] + + // dLdb + *dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0}); // [bS, 3*nOut] -> reduce -> [3*nOut]; + + dLdzc *= r; + + // dLdhI + NDArray WhT = Wh->transpose(); + dLdhI->assign(*dLdhI*u + mmul(dLdz, WhT)); // [bS, 3*nOut] x 
[3*nOut, nOut] = [bS, nOut] + + // dLdWr + *dLdWh += mmul(hI->transpose(), dLdz); // [nOut, bS] x [bS, 3*nOut] = [nOut, 3*nOut] +} + + +////////////////////////////////////////////////////////////////////////// +void gruTimeLoopBp(sd::LaunchContext * context, + const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, + NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { + // sL means time steps + + // x input [sL, bS, nIn] + // hI initial cell output (at time step = 0) [bS, nOut] + // Wx input-to-hidden weights, [nIn, 3*nOut] + // Wh hidden-to-hidden weights, [nOut, 3*nOut] + // b biases, [3*nOut] + // dLdh gradient vs. ff output, [sL, bS, nOut] + + // dLdx gradient vs. x, [sL, bS, nIn], + // dLdhI gradient vs. hI, [bS, nOut] + // dLdWx gradient vs. W, [nIn, 3*nOut] + // dLdWh gradient vs. Wc, [nOut, 3*nOut] + // dLdb gradient vs. b [3*nOut] + + const int sL = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray gates(x->ordering(), {sL, bS, 3*nOut}, dLdh->dataType(), x->getContext()); + NDArray h(x->ordering(), {sL+1, bS, nOut}, dLdh->dataType(), x->getContext()); + + auto xSet = x->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] + auto dLdhSet = dLdh->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] + auto hSet = h.allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] + auto gatesSet = gates.allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] + auto dLdxSet = dLdx->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] + + hSet.at(0)->assign(hI); + + // forward time loop + for (int t = 0; t < sL; ++t) + gruCell(context, xSet.at(t), hSet.at(t), Wx, Wh, b, gatesSet.at(t), hSet.at(t+1)); + + // backward time loop + for (int t = sL-1; t >= 0; --t) + gruCellBp(context, xSet.at(t), hSet.at(t), Wx, Wh, b, dLdhSet.at(t), gatesSet.at(t), + dLdxSet.at(t), dLdhI, dLdWx, 
dLdWh, dLdb); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp index 9fce17c4b..bffd13128 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp @@ -189,54 +189,6 @@ static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL] } -////////////////////////////////////////////////////////////////////////// -// x{M,K} x y{K,N} = z{M,N}, dzdy{K,N,M,N} - Jacobian derivative -> if x.rankOf() == 2 -// x{K} x y{K,N} = z{N}, dzdy{K,N,N} - Jacobian derivative -> if x.rankOf() == 1 -static NDArray mmulJacobianWeightsDeriv(const int nOut, const NDArray& x) { - - std::vector outShape = x.rankOf() == 1 ? std::vector({x.sizeAt(0), nOut, nOut}) : std::vector({x.sizeAt(1), nOut, x.sizeAt(0), nOut}); - - NDArray dzdy(x.ordering(), outShape, x.dataType(), x.getContext()); - - if(x.rankOf() == 1) { - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - if(i1 == i2) - dzdy.p(i0,i1,i2, x.e(i0)); - else - dzdy.p(i0,i1,i2, 0); - } - } - } - }; - - samediff::Threads::parallel_for(func, 0,dzdy.sizeAt(0),1, 0,dzdy.sizeAt(1),1, 0,dzdy.sizeAt(2),1); - } - else { - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - for (auto i3 = 0; i3 < dzdy.sizeAt(3); ++i3) { - if(i1 == i3) - dzdy.p(i0,i1,i2,i3, x.e(i2,i0)); - else - dzdy.p(i0,i1,i2,i3, 0); - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0,dzdy.sizeAt(0),1, 0,dzdy.sizeAt(1),1, 0,dzdy.sizeAt(2),1); - } - - return dzdy; -} ////////////////////////////////////////////////////////////////////////// void 
lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, @@ -245,25 +197,25 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, NDArray* h, NDArray* c) { // * -> means element-wise multiplication - // ^ -> means matrix multiplication + // × -> means matrix multiplication /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ /** the objective is to provide math-readable code **/ // equations (no peephole connections) - // it = σ(Wxi ^ xt + Wri ^ ht-1 + bi) - // ft = σ(Wxf ^ xt + Wrf ^ ht-1 + bf) - // c't = tanh(Wxc ^ xt + Wrc ^ ht-1 + bc) + // it = σ(Wxi × xt + Wri × ht-1 + bi) + // ft = σ(Wxf × xt + Wrf × ht-1 + bf) + // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) // ct = ft * ct-1 + it * c't - // ot = σ(Wxo ^ xt + Wro ^ ht-1 + bo) + // ot = σ(Wxo × xt + Wro × ht-1 + bo) // ht = ot * tanh(ct) // equations (peephole connections are present) - // it = σ(Wxi ^ xt + Wri ^ ht-1 + Wpi * ct-1 + bi) - // ft = σ(Wxf ^ xt + Wrf ^ ht-1 + Wpf * ct-1 + bf) - // c't = tanh(Wxc ^ xt + Wrc ^ ht-1 + bc) + // it = σ(Wxi × xt + Wri × ht-1 + Wpi * ct-1 + bi) + // ft = σ(Wxf × xt + Wrf × ht-1 + Wpf * ct-1 + bf) + // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) // ct = ft * ct-1 + it * c't - // ot = σ(Wxo ^ xt + Wro ^ ht-1 + Wpo * ct + bo) + // ot = σ(Wxo × xt + Wro × ht-1 + Wpo * ct + bo) // ht = ot * tanh(ct) @@ -399,7 +351,7 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, ////////////////////////////////////////////////////////////////////////// void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const NDArray* dLdh, const NDArray* dLdc, + const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, const NDArray* z, const NDArray* a, const NDArray* c, const std::vector& params, NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) 
{ @@ -407,10 +359,10 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con /** the objective is to provide math-readable code **/ // equations (no peephole connections) - // zi = x ^ Wxi + hI ^ Wri + bi - // zf = x ^ Wxf + hI ^ Wrf + bf - // zg = x ^ Wxg + hI ^ Wrg + bg - // zo = x ^ Wxo + hI ^ Wro + bo + // zi = x × Wxi + hI × Wri + bi + // zf = x × Wxf + hI × Wrf + bf + // zg = x × Wxg + hI × Wrg + bg + // zo = x × Wxo + hI × Wro + bo // i = act(zi) // f = act(zf) // g = actC(zg) @@ -419,10 +371,10 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con // h = o * actH(c) // equations (peephole connections are present) - // zi = x ^ Wxi + hI ^ Wri + cI * Wpi + bi - // zf = x ^ Wxf + hI ^ Wrf + cI * Wpf + bf - // zg = x ^ Wxg + hI ^ Wrg + bg - // zo = x ^ Wxo + hI ^ Wro + c * Wpo + bo + // zi = x × Wxi + hI × Wri + cI * Wpi + bi + // zf = x × Wxf + hI × Wrf + cI * Wpf + bf + // zg = x × Wxg + hI × Wrg + bg + // zo = x × Wxo + hI × Wro + c * Wpo + bo // i = act(zi) // f = act(zf) // g = actC(zg) @@ -449,18 +401,19 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con // params[11] - beta value for output activation // INPUTS: - // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr - // Wx - input weights [nIn, 4*nOut] - // Wr - recurrent weights [nOut, 4*nOut] - // b - biases [4*nOut], optional, may be nullptr - // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr - // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr - // Wp - peephole weights [3*nOut], optional, may be nullptr - // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr - // dLdc - loss derivative with respect to c, [bS, nOut] or [nOut] if seqLen != nullptr - // z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] - // a - i,f,g,o taken 
from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] - // c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut] + // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr + // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr + // Wp - peephole weights [3*nOut], optional, may be nullptr + // dLdh - loss derivative with respect to h at each time step, [bS, nOut] or [nOut] if seqLen != nullptr + // dLdhL - loss derivative with respect to h at last time step, [bS, nOut] or [nOut] if seqLen != nullptr + // dLdcL - loss derivative with respect to c at last time step, [bS, nOut] or [nOut] if seqLen != nullptr + // z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] + // a - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] + // c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut] // OUTPUTS: // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr @@ -485,19 +438,19 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con // dLdzg = dLdcI*dcdg*dgdzg; [bS, nOut](or[nOut]) // dLdzo = dLdhI*dhdo*dodzo; [bS, nOut](or[nOut]) - // dLdx = dLdzi^WxiT + dLdzf^WxfT + dLdzg^WxgT + dLdzo^WxoT, [bS, nIn] - // dLdhI = dLdzi^WriT + dLdzf^WrfT + dLdzg^WrgT + dLdzo^WroT, [bS, nOut] + // dLdx = dLdzi×WxiT + dLdzf×WxfT + dLdzg×WxgT + dLdzo×WxoT, [bS, nIn] + // dLdhI = dLdzi×WriT + dLdzf×WrfT + dLdzg×WrgT + dLdzo×WroT, [bS, nOut] // dLdcI = dLdcI*dcdcI, [bS, nOut] - // dLdWxi = xT^dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxf = xT^dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxg = xT^dLdzg [nIn, bS] x 
[bS, nOut] = [nIn, nOut] - // dLdWxo = xT^dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxi = xT×dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxf = xT×dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxg = xT×dLdzg [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxo = xT×dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWri = hIT^dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWrf = hIT^dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWrg = hIT^dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWro = hIT^dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWri = hIT×dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWrf = hIT×dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWrg = hIT×dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWro = hIT×dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut] // dLdbi = dLdzi.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] // dLdbf = dLdzf.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] @@ -563,10 +516,12 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con if(dLdh) *dLdhI += *dLdh; - if(dLdc) - *dLdcI += *dLdc; - else - *dLdcI += *dLdhI * dhdc; + if(dLdhL) + *dLdhI += *dLdhL; + if(dLdcL) + *dLdcI += *dLdcL; + + *dLdcI += *dLdhI * dhdc; dLdzi *= *dLdcI; // [bS, nOut](or[nOut]) dLdzf *= *dLdcI; // [bS, nOut](or[nOut]) @@ -662,25 +617,27 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const std::vector shapeOut = {bS, nOut}; + const auto type = h ? h->dataType() : (hL ? 
hL->dataType() : cL->dataType()); + auto h0 = const_cast(hI); if(!hI) { - h0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + h0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); h0->nullify(); } auto c0 = const_cast(cI); if(!cI) { - c0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + c0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); c0->nullify(); } auto ct = cL; if(!cL) - ct = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + ct = new NDArray(x->ordering(), shapeOut, type, x->getContext()); auto ht = hL; if(!h && !hL) - ht = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + ht = new NDArray(x->ordering(), shapeOut, type, x->getContext()); // create sets of required (depends on seqLen presence) sub-arrays std::vector dims; @@ -989,17 +946,19 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const int bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); const int nOut = Wx->sizeAt(-1) / 4; + const auto type = dLdh ? dLdh->dataType() : (dLdhL ? 
dLdhL->dataType() : dLdcL->dataType()); + auto dLdh0 = dLdhI; if(!hI) - dLdh0 = new NDArray(x->ordering(), {bS, nOut}, x->dataType(), x->getContext()); // this constructor nullifies array automatically + dLdh0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically auto dLdc0 = dLdcI; if(!cI) - dLdc0 = new NDArray(x->ordering(), {bS, nOut}, x->dataType(), x->getContext()); // this constructor nullifies array automatically + dLdc0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically - NDArray z(x->ordering(), {sL, bS, 4*nOut}, x->dataType(), x->getContext()); + NDArray z(x->ordering(), {sL, bS, 4*nOut}, type, x->getContext()); NDArray a = z.ulike(); - NDArray h(x->ordering(), {sL+1, bS, nOut}, x->dataType(), x->getContext()); + NDArray h(x->ordering(), {sL+1, bS, nOut}, type, x->getContext()); NDArray c = h.ulike(); // create sets of required (depends on seqLen presence) sub-arrays @@ -1041,9 +1000,9 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(dLdh) dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] - if(!dLdh && dLdhL) + if(dLdhL) dLdhLSet = new ResultSet(dLdhL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - if(!dLdh && !dLdhL) + if(dLdcL) dLdcLSet = new ResultSet(dLdcL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] } @@ -1054,13 +1013,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(!seqLen) { // seqLen is absent if(hI) - h({0,1, 0,0, 0,0}).assign(hI); + hSet->at(0)->assign(hI); else - h({0,1, 0,0, 0,0}).nullify(); + hSet->at(0)->nullify(); if(cI) - c({0,1, 0,0, 0,0}).assign(cI); + cSet->at(0)->assign(cI); else - c({0,1, 0,0, 0,0}).nullify(); + cSet->at(0)->nullify(); // ff for (int t = 0; t < sL; ++t) @@ -1068,9 +1027,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* 
Wx, const NDArray* Wr, // bp for (int t = sL-1; t >= 0; --t) { - const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : (t == sL-1 ? dLdhL : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == sL-1 ? dLdcL : nullptr); - lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr; + const NDArray* dLdhhL = (t == sL-1 && dLdhL) ? dLdhL : nullptr; + const NDArray* dLdccL = (t == sL-1 && dLdcL) ? dLdcL : nullptr; + lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t), aSet->at(t), cSet->at(t+1), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); } } @@ -1086,13 +1046,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, } if(hI) - h({0,1, e,e+1, 0,0}).assign(hISet->at(e)); + hSet->at(e)->assign(hISet->at(e)); else - h({0,1, e,e+1, 0,0}).nullify(); + hSet->at(e)->nullify(); if(cI) - c({0,1, e,e+1, 0,0}).assign(cISet->at(e)); + cSet->at(e)->assign(cISet->at(e)); else - c({0,1, e,e+1, 0,0}).nullify(); + cSet->at(e)->nullify(); // ff for (int t = 0; t < limit; ++t) @@ -1102,9 +1062,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = limit-1; t >= 0; --t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == limit-1 && dLdhL ? dLdhLSet->at(e) : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == limit-1 ? dLdcLSet->at(e) : nullptr); - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = (t == limit-1 && dLdhL) ? dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = (t == limit-1 && dLdcL) ? 
dLdcLSet->at(e) : nullptr; + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at((t+1)*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); } @@ -1119,13 +1080,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(!seqLen) { // backward or bidirectional, seqLen is absent if(hI) - h({sL,sL+1, 0,0, 0,0}).assign(hI); + hSet->at(sL)->assign(hI); else - h({sL,sL+1, 0,0, 0,0}).nullify(); + hSet->at(sL)->nullify(); if(cI) - c({sL,sL+1, 0,0, 0,0}).assign(cI); + cSet->at(sL)->assign(cI); else - c({sL,sL+1, 0,0, 0,0}).nullify(); + cSet->at(sL)->nullify(); // ff for (int t = sL-1; t >= 0; --t) @@ -1133,9 +1094,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = 0; t < sL; ++t) { - const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : (t == 0 ? dLdhL : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == 0 ? dLdcL : nullptr); - lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr; + const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhL : nullptr; + const NDArray* dLdccL = (t == 0 && dLdcL) ? 
dLdcL : nullptr; + lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t), aSet->at(t), cSet->at(t), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); } } @@ -1151,13 +1113,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, } if(hI) - h({sL,sL+1, e,e+1, 0,0}).assign(hISet->at(e)); + hSet->at(sL*bS + e)->assign(hISet->at(e)); else - h({sL,sL+1, e,e+1, 0,0}).nullify(); + hSet->at(sL*bS + e)->nullify(); if(cI) - c({sL,sL+1, e,e+1, 0,0}).assign(cISet->at(e)); + cSet->at(sL*bS + e)->assign(cISet->at(e)); else - c({sL,sL+1, e,e+1, 0,0}).nullify(); + cSet->at(sL*bS + e)->nullify(); // ff for (int t = sL - 1; t >= sL-limit; --t) @@ -1167,9 +1129,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = sL-limit; t < sL; ++t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == sL-limit && dLdhL ? dLdhLSet->at(e) : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == sL-limit ? dLdcLSet->at(e) : nullptr); - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = (t == sL-limit && dLdhL) ? dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = (t == sL-limit && dLdcL) ? dLdcLSet->at(e) : nullptr; + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); } @@ -1206,9 +1169,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = 0; t < limit; ++t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? 
dLdhSet->at(ind) : (t == 0 && dLdhL ? dLdhLSet->at(e) : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == 0 ? dLdcLSet->at(e) : nullptr); - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcLSet->at(e) : nullptr; + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); } @@ -1248,10 +1212,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // /** the objective is to provide math-readable code **/ // // equations (no peephole connections) -// // zi = x ^ Wxi + hI ^ Wri + bi -// // zf = x ^ Wxf + hI ^ Wrf + bf -// // zg = x ^ Wxg + hI ^ Wrg + bg -// // zo = x ^ Wxo + hI ^ Wro + bo +// // zi = x × Wxi + hI × Wri + bi +// // zf = x × Wxf + hI × Wrf + bf +// // zg = x × Wxg + hI × Wrg + bg +// // zo = x × Wxo + hI × Wro + bo // // i = act(zi) // // f = act(zf) // // g = actC(zg) @@ -1260,10 +1224,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // // h = o * actH(c) // // equations (peephole connections are present) -// // zi = x ^ Wxi + hI ^ Wri + cI * Wpi + bi -// // zf = x ^ Wxf + hI ^ Wrf + cI * Wpf + bf -// // zg = x ^ Wxg + hI ^ Wrg + bg -// // zo = x ^ Wxo + hI ^ Wro + c * Wpo + bo +// // zi = x × Wxi + hI × Wri + cI * Wpi + bi +// // zf = x × Wxf + hI × Wrf + cI * Wpf + bf +// // zg = x × Wxg + hI × Wrg + bg +// // zo = x × Wxo + hI × Wro + c * Wpo + bo // // i = act(zi) // // f = act(zf) // // g = actC(zg) @@ -1333,13 +1297,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // // oFactor = *dLdh*dhdzo 
[bS, nOut] // // tempC = dcdcI + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0; -// // tempIFE = dcdzi^WriT + dcdzf^WrfT + dcdzg^WrgT -// // tempO = dhdzo^WroT +// // tempIFE = dcdzi×WriT + dcdzf×WrfT + dcdzg×WrgT +// // tempO = dhdzo×WroT // // dhIdcI = dhdc_from_previous_time_step -// // dLdx = iFactor^WxiT + fFactor^WxfT + eFactor^WxgT + oFactor^WxoT, [bS, nIn] -// // dLdhI = iFactor^WriT + fFactor^WrfT + eFactor^WrgT + oFactor^WroT, [bS, nOut] +// // dLdx = iFactor×WxiT + fFactor×WxfT + eFactor×WxgT + oFactor×WxoT, [bS, nIn] +// // dLdhI = iFactor×WriT + fFactor×WrfT + eFactor×WrgT + oFactor×WroT, [bS, nOut] // // dLdcI = factor*tempC + dLdhI * dhIdcI, dhIdcI=0 if firstIter, [bS, nOut] // // dcdWxi(dcIdWxi) = dcdzi*dzidWxi + tempIFE*dhIdWxi + tempC*dcIdWxi, dcIdWxi=dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] diff --git a/libnd4j/include/ops/declarable/helpers/lstmLayer.h b/libnd4j/include/ops/declarable/helpers/lstmLayer.h index 3a2d173b5..29c434865 100644 --- a/libnd4j/include/ops/declarable/helpers/lstmLayer.h +++ b/libnd4j/include/ops/declarable/helpers/lstmLayer.h @@ -42,7 +42,7 @@ void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArra ////////////////////////////////////////////////////////////////////////// void ND4J_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const NDArray* dLdh, const NDArray* dLdc, + const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, const NDArray* z, const NDArray* a, const NDArray* c, const std::vector& params, NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index d09a40120..6763d1403 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ 
b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -369,6 +369,7 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!"); REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !"); REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !"); + REQUIRE_TRUE(hasInitH == hasInitC, 0, "LSTM_LAYER_MKLDNN operation: either both of or neither of initial C and initial H must be provided"); count = 0; auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output @@ -498,7 +499,7 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { DataType WrType = Wr->dataType(); DataType bType = b != nullptr ? b->dataType() : (xType == DataType::HALF ? xType : DataType::FLOAT32); DataType hIType = hI != nullptr ? hI->dataType() : xType; - DataType cIType = cI != nullptr ? hI->dataType() : xType; + DataType cIType = cI != nullptr ? cI->dataType() : xType; DataType hType = h != nullptr ? h->dataType() : xType; DataType hLType = hL != nullptr ? hL->dataType() : xType; DataType cLType = cL != nullptr ? 
cL->dataType() : xType; @@ -509,7 +510,8 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { && !hasSeqLen //Sequence length array not supported in MKL DNN && dataFormat < 2 //Data format - only 0 and 1 supported in MKL DNN- 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn] && directionMode < 4 //Direction mode - only 0-3 supported in MKL DNN (no extra dim option) - 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat - && retLastH == retLastC; //Return both lastH and lastC, or return neither (not just 1 or other) + && retLastH == retLastC //Return both lastH and lastC, or return neither (not just 1 or other) + && hasInitH == hasInitC; //Need both or neither initial H and C return block.isUseMKLDNN() && featuresSupported && ( (xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) || diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 03e5ae53f..6d89bd182 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -2010,6 +2010,34 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test1) { ASSERT_TRUE(expect.equalsTo(res)); +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, LinSpace_Test2) { + + NDArray expect = NDArrayFactory::create({1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, + 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}); + + sd::ops::lin_space op; + auto result = op.evaluate({}, {1, 12}, {23}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + ASSERT_EQ( res->dataType(), sd::DataType::FLOAT32 ); + ASSERT_TRUE(expect.equalsTo(res)); + +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, 
LinSpace_Test3) { + + NDArray expect('c', { 23 }, {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}, sd::DataType::DOUBLE ); + + sd::ops::lin_space op; + auto result = op.evaluate({}, {1, 12}, {23}, {}, { sd::DOUBLE }); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + + ASSERT_EQ( res->dataType(), expect.dataType()); + ASSERT_TRUE(expect.equalsTo(res)); + } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index cee574dec..4052e260d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -2084,11 +2084,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) { const auto hasInitH = true; // initial output is provided const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = false; // dLdh per each time step + const auto retFullSeq = true; // dLdh per each time step const auto retLastH = true; // output at last time step - const auto retLastC = true; // cells state at last time step + const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2097,6 +2097,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) { NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE); NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray dLdcL('c', {bS, nOut}, 
sd::DataType::DOUBLE); @@ -2113,12 +2114,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) { std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::SUM, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); ASSERT_TRUE(isGradCorrect); } @@ -2131,63 +2132,6 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_2) { const int nIn = 2; const int nOut = 3; - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 0; // forward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = false; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] - const auto retLastH = false; // output at last time step - const auto retLastC = true; // cells state at last time step - - const double cellClip = 0.5; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); - NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); - NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); - NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); - - x.linspace(-2,0.1); - hI.linspace(-1.5,0.1); - cI.linspace(0.7,-0.1); - Wx.linspace(1,-0.1); - Wr.linspace(-1,0.1); - Wp.linspace(0.2,0.2); - b.linspace(1,-0.15); - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); - - sd::ops::lstmLayer opFF; - sd::ops::lstmLayer_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN); - - ASSERT_TRUE(isGradCorrect); -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) { - - const int sL = 3; - const int bS = 2; - const int nIn = 2; - const int nOut = 3; - const int dataFormat = 1; // [bS,sL,nIn] const int directionMode = 0; // forward const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates @@ -2199,11 +2143,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) { const auto hasInitH = true; // initial output is provided const auto hasInitC = true; // initial 
cell state is provided const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = false; // output at last time step + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = false; // output at last time step const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2233,13 +2177,13 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) { sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) { const int sL = 4; const int bS = 3; @@ -2258,10 +2202,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = false; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ 
-2272,6 +2216,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2286,18 +2232,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { const int sL = 3; const int bS = 2; @@ -2315,11 +2261,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { const auto hasInitH = true; // initial output is provided const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = false; // dLdh per each time step + const auto retFullSeq = true; // dLdh per each time step const auto retLastH = true; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto 
retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2328,7 +2274,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE); NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2343,18 +2291,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { const int sL = 3; const int bS = 2; @@ -2373,10 +2321,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = false; // output at last time step - 
const auto retLastC = false; // cells state at last time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2387,6 +2335,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2401,18 +2351,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { const int sL = 3; const int bS = 2; @@ -2430,11 +2380,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { const auto hasInitH = true; // initial output 
is provided const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = false; // dLdh per each time step + const auto retFullSeq = true; // dLdh per each time step const auto retLastH = true; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2444,7 +2394,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2459,18 +2411,24 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, 
&Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { const int sL = 3; const int bS = 2; @@ -2489,10 +2447,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) { const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = false; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS,sL,nIn}, sd::DataType::DOUBLE); NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2503,6 +2461,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) { NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); NDArray dLdh('c', {bS,sL,2*nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2517,18 
+2477,24 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) { std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) { const int sL = 3; const int bS = 2; @@ -2547,10 +2513,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) { const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent const auto 
retFullSeq = true; // dLdh per each time step - const auto retLastH = false; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2561,6 +2527,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) { NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); NDArray dLdh('c', {sL, 2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2575,12 +2543,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) { std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder 
argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); ASSERT_TRUE(isGradCorrect); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 5fffa73c5..3d86cd92b 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -1904,6 +1904,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP16) { const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1,0}); ASSERT_TRUE(isGradCorrect); } + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) { @@ -1922,3 +1923,68 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) { ASSERT_TRUE(isGradCorrect); } + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, gru_1) { + + const int sL = 3; + const int bS = 2; + const int nIn = 5; + const int nOut = 4; + + + NDArray x('c', {sL, bS, nIn}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. , 6.5, 7. , 7.5, 8. , 8.5, 9. , 9.5, 10. , 10.5, 11. , 11.5, 12. , 12.5, 13. , 13.5, 14. 
, 14.5, 15.}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::FLOAT32); + NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::FLOAT32); + NDArray b('c', {3*nOut}, sd::DataType::FLOAT32); + + NDArray expH('c', {sL, bS, nOut}, {-1.681847, -1.062565, -0.443283, 0.175998,0.837823, 1.488041, 2.13826 , 2.788478, -0.888747, -0.491826, -0.094907, 0.302014, + 0.751355, 1.182715, 1.614075, 2.045434, -0.388876, -0.126716, 0.135444, 0.397604,0.710558, 1.002922, 1.295287, 1.587651}, sd::DataType::FLOAT32); + + Wx = 0.003; + Wh = 0.006; + b = 0.5; + + NDArray dLdC('c', { 2, 2 }, sd::DataType::DOUBLE); + + sd::ops::gru op; + auto results = op.evaluate({&x, &hI, &Wx, &Wh, &b}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* h = results.at(0); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, gru_bp_1) { + + const int sL = 3; + const int bS = 2; + const int nIn = 5; + const int nOut = 4; + + + NDArray x('c', {sL, bS, nIn}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. , 6.5, 7. , 7.5, 8. , 8.5, 9. , 9.5, 10. , 10.5, 11. , 11.5, 12. , 12.5, 13. , 13.5, 14. 
, 14.5, 15.}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::DOUBLE); + NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {3*nOut}, sd::DataType::DOUBLE); + + NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE); + + Wx.linspace(1,-0.1); + Wh.linspace(0.2,0.2); + b.linspace(1,-0.15); + + const OpArgsHolder argsHolderFF({&x, &hI, &Wx, &Wh, &b}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &hI, &Wx, &Wh, &b, &dLdh}, {}, {}); + + sd::ops::gru opFF; + sd::ops::gru_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index 7f39c3d76..d8478e471 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -1868,7 +1868,26 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_5) { // exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(result.at(0))); - +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_6) { + auto x = NDArrayFactory::create({5,1,7,2,3,4,1,3}); + auto idx = NDArrayFactory::create({0,0,0,1,2,2,3,3}); + //NDArray exp({1.7320508075688772, 1., 1.4142135623730951, 1.4142135623730951}); +// auto exp = NDArrayFactory::create({7.5055537, 2., 4.9497476, 2.828427}); + sd::ops::unsorted_segment_sqrt_n op; + +try { + auto result = op.evaluate({&x, &idx}, {}, {1}); + ASSERT_NE(result.status(), Status::OK()); +} +catch (std::exception& err) { + +} + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + //ASSERT_TRUE(exp.equalsTo(result.at(0))); } //////////////////////////////////////////////////////////////////////////////// diff --git 
a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index 6f559230b..29c681544 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -1334,6 +1334,20 @@ TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) { ASSERT_EQ(Status::OK(), status); } +TEST_F(JavaInteropTests, test_linspace_shape_1) { + if (!Environment::getInstance()->isCPU()) + return; + + sd::ops::lin_space op; + double tArgs[2] = {1.0, 10.0}; + Nd4jLong iArgs = 10L; + int dArg = (int) sd::DataType::FLOAT32; + auto result = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 2, &iArgs, 1, nullptr, 0, &dArg, 1); + + ASSERT_EQ(1, result->size()); + delete result; +} + /* TEST_F(JavaInteropTests, Test_Results_Conversion_1) { auto pl = sd::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index 94bda0b78..e21b2d270 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -153,6 +153,7 @@ public abstract class DifferentialFunction { public Map propertiesForFunction() { Map fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this); Map ret = new LinkedHashMap<>(); + Preconditions.checkNotNull(fields, "DifferentialFunctionClassHolder returned null fields for %s - op has not been added to ImportClassMapping?", getClass()); for(val entry : fields.entrySet()) { try { @@ -474,6 +475,11 @@ public abstract class DifferentialFunction { return outputVariables()[0]; } + public List outputs(){ + SDVariable[] out 
= outputVariables(); + return out == null ? null : Arrays.asList(out); + } + public String[] outputVariablesNames(){ SDVariable[] outputVars = outputVariables(); @@ -501,14 +507,6 @@ public abstract class DifferentialFunction { */ public abstract List doDiff(List f1); - /** - * Shortcut for the {@link DifferentialFunctionFactory} - * @return - */ - public DifferentialFunctionFactory f() { - return sameDiff.f(); - } - /** * Return the arguments for a given function @@ -575,7 +573,7 @@ public abstract class DifferentialFunction { copied = true; } - SDVariable gradVar = f().add(grad, vals.get(i)); + SDVariable gradVar = var.getSameDiff().math.add(grad, vals.get(i)); vals.set(i, gradVar); sameDiff.setGradientForVariableName(var.name(), gradVar); } else { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java deleted file mode 100644 index 093e3099b..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ /dev/null @@ -1,2659 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.autodiff.functions; - -import java.lang.reflect.Method; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import lombok.Data; -import lombok.NonNull; -import lombok.val; -import org.apache.commons.lang3.ArrayUtils; -import org.nd4j.autodiff.loss.LossReduce; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.enums.DataFormat; -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.api.blas.params.MMulTranspose; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.NoOp; -import org.nd4j.linalg.api.ops.custom.*; -import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd; -import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; -import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches; -import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex; -import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax; -import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMin; -import org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex; -import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; -import org.nd4j.linalg.api.ops.impl.layers.convolution.*; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; -import 
org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; -import org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss; -import org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss; -import org.nd4j.linalg.api.ops.impl.loss.HingeLoss; -import org.nd4j.linalg.api.ops.impl.loss.HuberLoss; -import org.nd4j.linalg.api.ops.impl.loss.L2Loss; -import org.nd4j.linalg.api.ops.impl.loss.LogLoss; -import org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss; -import org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss; -import org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss; -import org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss; -import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss; -import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss; -import org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits; -import org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss; -import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.CosineDistanceLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.HingeLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.HuberLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.LogLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.LogPoissonLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.MeanPairwiseSquaredErrorLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.MeanSquaredErrorLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.SigmoidCrossEntropyLossBp; -import 
org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyWithLogitsLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.SparseSoftmaxCrossEntropyLossWithLogitsBp; -import org.nd4j.linalg.api.ops.impl.reduce.Mmul; -import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; -import org.nd4j.linalg.api.ops.impl.reduce.Moments; -import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments; -import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul; -import org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction; -import org.nd4j.linalg.api.ops.impl.reduce.bool.All; -import org.nd4j.linalg.api.ops.impl.reduce.bool.Any; -import org.nd4j.linalg.api.ops.impl.reduce.bp.*; -import org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul; -import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp; -import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean; -import org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy; -import org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy; -import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean; -import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1; -import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2; -import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax; -import org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy; -import org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm; -import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero; -import org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero; -import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; -import org.nd4j.linalg.api.ops.impl.reduce.same.AMax; -import org.nd4j.linalg.api.ops.impl.reduce.same.AMin; -import org.nd4j.linalg.api.ops.impl.reduce.same.ASum; -import org.nd4j.linalg.api.ops.impl.reduce.same.Max; -import org.nd4j.linalg.api.ops.impl.reduce.same.Min; -import org.nd4j.linalg.api.ops.impl.reduce.same.Prod; -import org.nd4j.linalg.api.ops.impl.reduce.same.Sum; -import 
org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity; -import org.nd4j.linalg.api.ops.impl.reduce3.Dot; -import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; -import org.nd4j.linalg.api.ops.impl.scalar.*; -import org.nd4j.linalg.api.ops.impl.scalar.Pow; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNotEquals; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterMax; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterMin; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; -import org.nd4j.linalg.api.ops.impl.shape.*; -import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp; -import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp; -import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp; -import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation; -import org.nd4j.linalg.api.ops.impl.summarystats.Variance; -import org.nd4j.linalg.api.ops.impl.transforms.Pad; -import org.nd4j.linalg.api.ops.impl.transforms.ReluLayer; -import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; -import org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite; -import 
org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf; -import org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN; -import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform; -import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm; -import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue; -import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace; -import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; -import org.nd4j.linalg.api.ops.impl.transforms.custom.*; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum; -import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast; -import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt; -import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp; -import 
org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor; -import org.nd4j.linalg.api.ops.impl.transforms.same.Abs; -import org.nd4j.linalg.api.ops.impl.transforms.same.Ceil; -import org.nd4j.linalg.api.ops.impl.transforms.same.Cube; -import org.nd4j.linalg.api.ops.impl.transforms.same.Floor; -import org.nd4j.linalg.api.ops.impl.transforms.same.Identity; -import org.nd4j.linalg.api.ops.impl.transforms.same.Negative; -import org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal; -import org.nd4j.linalg.api.ops.impl.transforms.same.Round; -import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; -import org.nd4j.linalg.api.ops.impl.transforms.same.Square; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean; -import 
org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMaxBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMeanBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMinBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentProdBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentSumBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMaxBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMeanBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp; -import org.nd4j.linalg.api.ops.impl.transforms.strict.*; -import org.nd4j.linalg.api.ops.random.custom.*; -import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; -import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution; -import org.nd4j.linalg.api.ops.random.impl.DropOutInverted; -import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution; -import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution; -import org.nd4j.linalg.api.ops.random.impl.Range; -import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution; -import org.nd4j.linalg.api.ops.random.impl.UniformDistribution; -import org.nd4j.linalg.api.shape.Shape; -import org.nd4j.linalg.indexing.conditions.Condition; -import org.nd4j.linalg.util.ArrayUtil; - -/** - * - */ -@Data -public class DifferentialFunctionFactory { - - protected 
SameDiff sameDiff; - private static Map methodNames; - - /** - * @param sameDiff - */ - public DifferentialFunctionFactory(SameDiff sameDiff) { - if (sameDiff != null) { - this.sameDiff = sameDiff; - if (methodNames == null) { - methodNames = new HashMap<>(); - Method[] methods = getClass().getDeclaredMethods(); - for (Method method : methods) - methodNames.put(method.getName().toLowerCase(), method); - } - } else { - throw new IllegalArgumentException("Input not null value."); - } - - - } - - public SameDiff sameDiff() { - return sameDiff; - } - - - public SDVariable invoke(String name, Object[] args) { - try { - return (SDVariable) methodNames.get(name).invoke(this, args); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - public ExternalErrorsFunction externalErrors(SDVariable... inputs) { - return externalErrors(null, inputs); - } - - public ExternalErrorsFunction externalErrors(Map externalGradients, SDVariable... inputs) { - Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" + - " be specified when using external errors: got %s", inputs); - ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff(), Arrays.asList(inputs), externalGradients); - fn.outputVariable(); - return fn; - } - - public SDVariable zerosLike(SDVariable input) { - return zerosLike(null, input); - } - - public SDVariable zerosLike(String name, SDVariable input) { - validateDifferentialFunctionsameDiff(input); - return new ZerosLike(name, sameDiff(), input).outputVariable(); - } - - public SDVariable zerosLike(String name, SDVariable input, DataType dataType) { - validateDifferentialFunctionsameDiff(input); - return new ZerosLike(name, sameDiff(), input, dataType).outputVariable(); - } - - public SDVariable create(String name, SDVariable shape, boolean initialize, DataType dataType) { - return create(name, shape, 'c', initialize, dataType); - } - - public SDVariable create(String name, SDVariable shape, char order, 
boolean initialize, DataType dataType) { - validateDifferentialFunctionsameDiff(shape); - return new Create(name, sameDiff(), shape, order, initialize, dataType).outputVariable(); - } - - public SDVariable onesLike(String name, SDVariable input, DataType dataType) { - validateDifferentialFunctionsameDiff(input); - return new OnesLike(name, sameDiff(), input, dataType).outputVariable(); - } - - public SDVariable linspace(SDVariable lower, SDVariable upper, SDVariable count, DataType dt) { - return new org.nd4j.linalg.api.ops.impl.shape.Linspace(sameDiff(), lower, upper, count, dt).outputVariable(); - } - - public SDVariable range(double from, double to, double step, DataType dataType) { - return new Range(sameDiff(), from, to, step, dataType).outputVariable(); - } - - public SDVariable range(SDVariable from, SDVariable to, SDVariable step, DataType dataType) { - return new Range(sameDiff(), from, to, step, dataType).outputVariable(); - } - - public SDVariable[] listdiff(SDVariable x, SDVariable y){ - return new ListDiff(sameDiff(), x, y).outputVariables(); - } - - public SDVariable cast(SDVariable toCast, DataType toType){ - return new Cast(sameDiff(), toCast, toType).outputVariable(); - } - - public SDVariable[] meshgrid(boolean cartesian, SDVariable... inputs) { - return new MeshGrid(sameDiff(), cartesian, inputs).outputVariables(); - } - - public SDVariable randomUniform(double min, double max, SDVariable shape, DataType dataType) { - return new DistributionUniform(sameDiff(), shape, min, max, dataType).outputVariable(); - } - - public SDVariable randomUniform(double min, double max, long... shape) { - return new UniformDistribution(sameDiff(), min, max, shape).outputVariable(); - } - - public SDVariable randomNormal(double mean, double std, SDVariable shape) { - return new RandomNormal(sameDiff(), shape, mean, std).outputVariable(); - } - - public SDVariable randomNormal(double mean, double std, long... 
shape) { - return new GaussianDistribution(sameDiff(), mean, std, shape).outputVariable(); - } - - public SDVariable randomBernoulli(double p, SDVariable shape) { - return new RandomBernoulli(sameDiff(), shape, p).outputVariable(); - } - - public SDVariable randomBernoulli(double p, long... shape) { - return new BernoulliDistribution(sameDiff(), p, shape).outputVariable(); - } - - public SDVariable randomBinomial(int nTrials, double p, long... shape) { - return new BinomialDistribution(sameDiff(), nTrials, p, shape).outputVariable(); - } - - public SDVariable randomLogNormal(double mean, double stdev, long... shape) { - return new LogNormalDistribution(sameDiff(), mean, stdev, shape).outputVariable(); - } - - public SDVariable randomNormalTruncated(double mean, double stdev, long... shape) { - return new TruncatedNormalDistribution(sameDiff(), mean, stdev, shape).outputVariable(); - } - - public SDVariable randomGamma(SDVariable shape, SDVariable alpha, SDVariable beta, int... seeds) { - return new RandomGamma(sameDiff(), shape, alpha, beta, seeds).outputVariable(); - } - - public SDVariable randomPoisson(SDVariable shape, SDVariable rate, int... seeds) { - return new RandomPoisson(sameDiff(), shape, rate, seeds).outputVariable(); - } - - public SDVariable randomShuffle(SDVariable values, int... seeds) { - return new RandomShuffle(sameDiff(), values, seeds).outputVariable(); - } - - /** - * Exponential distribution: P(x) = lambda * exp(-lambda * x) - * - * @param lambda Must be > 0 - * @param shape Shape of the output - */ - public SDVariable randomExponential(double lambda, SDVariable shape) { - return new RandomExponential(sameDiff(), shape, lambda).outputVariable(); - } - - - public SDVariable pad(SDVariable input, SDVariable padding, Pad.Mode mode, double padValue){ - return new Pad(sameDiff(), input, padding, mode, padValue).outputVariable(); - } - - /** - * Local response normalization operation. 
- * - * @param input the inputs to lrn - * @param lrnConfig the configuration - * @return - */ - public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) { - LocalResponseNormalization lrn = LocalResponseNormalization.sameDiffBuilder() - .inputFunctions(new SDVariable[]{input}) - .sameDiff(sameDiff()) - .config(lrnConfig) - .build(); - - return lrn.outputVariable(); - } - - /** - * Conv1d operation. - * - * @param input the inputs to conv1d - * @param weights conv1d weights - * @param conv1DConfig the configuration - * @return - */ - public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { - Conv1D conv1D = Conv1D.sameDiffBuilder() - .inputFunctions(new SDVariable[]{input, weights}) - .sameDiff(sameDiff()) - .config(conv1DConfig) - .build(); - - return conv1D.outputVariable(); - } - - /** - * Conv1d operation. - * - * @param input the inputs to conv1d - * @param weights conv1d weights - * @param bias conv1d bias - * @param conv1DConfig the configuration - * @return - */ - public SDVariable conv1d(SDVariable input, SDVariable weights, SDVariable bias, Conv1DConfig conv1DConfig) { - - SDVariable[] args; - - if(bias == null){ - args = new SDVariable[]{input, weights}; - } else { - args = new SDVariable[]{input, weights, bias}; - } - - Conv1D conv1D = Conv1D.sameDiffBuilder() - .inputFunctions(args) - .sameDiff(sameDiff()) - .config(conv1DConfig) - .build(); - - return conv1D.outputVariable(); - } - - /** - * Conv2d operation. 
- * - * @param inputs the inputs to conv2d - * @param conv2DConfig the configuration - * @return - */ - public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { - Conv2D conv2D = Conv2D.sameDiffBuilder() - .inputFunctions(inputs) - .sameDiff(sameDiff()) - .config(conv2DConfig) - .build(); - - return conv2D.outputVariable(); - } - - public SDVariable upsampling2d(SDVariable input, boolean nchw, int scaleH, int scaleW) { - return new Upsampling2d(sameDiff(), input, nchw, scaleH, scaleW).outputVariable(); - } - - public SDVariable upsampling2dBp(SDVariable input, SDVariable gradient, boolean nchw, int scaleH, int scaleW) { - return new Upsampling2dDerivative(sameDiff(), input, gradient, nchw, scaleH, scaleW).outputVariable(); - } - - - /** - * Average pooling 2d operation. - * - * @param input the inputs to pooling - * @param pooling2DConfig the configuration - * @return - */ - public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { - AvgPooling2D avgPooling2D = AvgPooling2D.sameDiffBuilder() - .input(input) - .sameDiff(sameDiff()) - .config(pooling2DConfig) - .build(); - - return avgPooling2D.outputVariable(); - } - - /** - * Max pooling 2d operation. - * - * @param input the inputs to pooling - * @param pooling2DConfig the configuration - * @return - */ - public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { - MaxPooling2D maxPooling2D = MaxPooling2D.sameDiffBuilder() - .input(input) - .sameDiff(sameDiff()) - .config(pooling2DConfig) - .build(); - - return maxPooling2D.outputVariable(); - } - - /** - * Avg pooling 3d operation. - * - * @param input the inputs to pooling - * @param pooling3DConfig the configuration - * @return - */ - public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { - pooling3DConfig.setType(Pooling3D.Pooling3DType.AVG); - return new AvgPooling3D(sameDiff(), input, pooling3DConfig).outputVariable(); - } - - - /** - * Max pooling 3d operation. 
- * - * @param input the inputs to pooling - * @param pooling3DConfig the configuration - * @return - */ - public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { - pooling3DConfig.setType(Pooling3D.Pooling3DType.MAX); - return new MaxPooling3D(sameDiff(), input, pooling3DConfig).outputVariable(); - } - - - /** - * Separable Conv2d operation. - * - * @param inputs the inputs to conv2d - * @param conv2DConfig the configuration - * @return - */ - public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { - SConv2D sconv2D = SConv2D.sameDiffSBuilder() - .inputFunctions(inputs) - .sameDiff(sameDiff()) - .conv2DConfig(conv2DConfig) - .build(); - - return sconv2D.outputVariable(); - } - - - /** - * Depth-wise Conv2d operation. This is just separable convolution with - * only the depth-wise weights specified. - * - * @param inputs the inputs to conv2d - * @param depthConv2DConfig the configuration - * @return - */ - public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { - SConv2D depthWiseConv2D = SConv2D.sameDiffSBuilder() - .inputFunctions(inputs) - .sameDiff(sameDiff()) - .conv2DConfig(depthConv2DConfig) - .build(); - - return depthWiseConv2D.outputVariable(); - } - - - /** - * Deconv2d operation. 
- * - * @param inputs the inputs to conv2d - * @param deconv2DConfig the configuration - * @return - */ - public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { - DeConv2D deconv2D = DeConv2D.sameDiffBuilder() - .inputs(inputs) - .sameDiff(sameDiff()) - .config(deconv2DConfig) - .build(); - - return deconv2D.outputVariable(); - } - - public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) { - DeConv3D d = new DeConv3D(sameDiff(), input, weights, bias, config); - return d.outputVariable(); - } - - public SDVariable[] deconv3dDerivative(SDVariable input, SDVariable weights, SDVariable bias, SDVariable grad, DeConv3DConfig config) { - DeConv3DDerivative d = new DeConv3DDerivative(sameDiff(), input, weights, bias, grad, config); - return d.outputVariables(); - } - - /** - * Conv3d operation. - * - * @param inputs the inputs to conv3d - * @param conv3DConfig the configuration - * @return - */ - public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) { - Conv3D conv3D = Conv3D.sameDiffBuilder() - .inputFunctions(inputs) - .config(conv3DConfig) - .sameDiff(sameDiff()) - .build(); - - val outputVars = conv3D.outputVariables(); - return outputVars[0]; - } - - - /** - * Batch norm operation. - */ - public SDVariable batchNorm(SDVariable input, SDVariable mean, - SDVariable variance, SDVariable gamma, - SDVariable beta, - boolean applyGamma, boolean applyBeta, - double epsilon, int... 
axis) { - BatchNorm batchNorm = BatchNorm.builder() - .inputFunctions(new SDVariable[]{input, mean, variance, gamma, beta}) - .applyGamma(applyGamma) - .applyBeta(applyBeta) - .epsilon(epsilon) - .sameDiff(sameDiff()) - .axis(axis) - .build(); - - val outputVars = batchNorm.outputVariables(); - return outputVars[0]; - } - - public SDVariable im2Col(SDVariable input, Conv2DConfig config) { - return new Im2col(sameDiff(), input, config).outputVariable(); - } - - public SDVariable im2ColBp(SDVariable im2colInput, SDVariable gradientAtOutput, Conv2DConfig config) { - return new Im2colBp(sameDiff(), im2colInput, gradientAtOutput, config).outputVariable(); - } - - public SDVariable col2Im(SDVariable input, Conv2DConfig config) { - return new Col2Im(sameDiff(), input, config).outputVariable(); - } - - public SDVariable extractImagePatches(SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode){ - return new ExtractImagePatches(sameDiff(), input, new int[]{kH, kW}, new int[]{sH, sW}, new int[]{rH, rW}, sameMode).outputVariable(); - } - - public SDVariable[] moments(SDVariable input, int... 
axes) { - return new Moments(sameDiff(), input, axes).outputVariables(); - } - - public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, double shift) { - return new NormalizeMoments(sameDiff(), counts, means, variances, shift).outputVariables(); - } - - - public SDVariable tile(@NonNull SDVariable iX, @NonNull int[] repeat) { - return new Tile(sameDiff(), iX, repeat).outputVariable(); - } - - public SDVariable tileBp(@NonNull SDVariable in, @NonNull SDVariable grad, @NonNull int[] repeat){ - return new TileBp(sameDiff, in, grad, repeat).outputVariable(); - } - - public SDVariable tile(@NonNull SDVariable iX, @NonNull SDVariable repeat) { - return new Tile(sameDiff(), iX, repeat).outputVariable(); - } - - public SDVariable tileBp(@NonNull SDVariable in, @NonNull SDVariable repeat, @NonNull SDVariable grad){ - return new TileBp(sameDiff, in, repeat, grad).outputVariable(); - } - - public SDVariable dropout(SDVariable input, double p) { - return new DropOutInverted(sameDiff(), input, p).outputVariable(); - } - - - public SDVariable sum(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Sum(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable sumBp(SDVariable i_x, SDVariable grad, boolean keepDims, int... dimensions) { - return new SumBp(sameDiff(), i_x, grad, keepDims, dimensions).outputVariable(); - } - - - public SDVariable prod(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Prod(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable prodBp(SDVariable preReduceInput, SDVariable grad, boolean keepDims, int... dimensions) { - return new ProdBp(sameDiff(), preReduceInput, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable mean(SDVariable in, boolean keepDims, int... 
dimensions) { - return new Mean(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable meanBp(SDVariable in, SDVariable grad, boolean keepDims, int... dimensions) { - return new MeanBp(sameDiff(), in, grad, keepDims, dimensions).outputVariable(); - } - - - public SDVariable std(SDVariable i_x, boolean biasCorrected, boolean keepDims, int... dimensions) { - return new StandardDeviation(sameDiff(), i_x, biasCorrected, keepDims, dimensions).outputVariable(); - } - - public SDVariable stdBp(SDVariable stdInput, SDVariable gradient, boolean biasCorrected, boolean keepDims, int... dimensions) { - return new StandardDeviationBp(sameDiff(), stdInput, gradient, biasCorrected, keepDims, dimensions).outputVariable(); - } - - - public SDVariable variance(SDVariable i_x, boolean biasCorrected, boolean keepDims, int... dimensions) { - return new Variance(sameDiff(), i_x, biasCorrected, keepDims, dimensions).outputVariable(); - } - - public SDVariable varianceBp(SDVariable stdInput, SDVariable gradient, boolean biasCorrected, boolean keepDims, int... dimensions) { - return new VarianceBp(sameDiff(), stdInput, gradient, biasCorrected, keepDims, dimensions).outputVariable(); - } - - public SDVariable standardize(SDVariable i_x, int... dimensions) { - return new Standardize(sameDiff(), i_x, dimensions).outputVariable(); - } - - public SDVariable standardizeBp(SDVariable stdInput, SDVariable gradient, int... dimensions) { - return new StandardizeBp(sameDiff(), stdInput, gradient, dimensions).outputVariable(); - } - - public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) { - return new LayerNorm(sameDiff(), input, gain, bias, channelsFirst, dimensions).outputVariable(); - } - - public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, boolean channelsFirst, int... 
dimensions) { - return new LayerNormBp(sameDiff(), input, gain, bias, gradient, channelsFirst, dimensions).outputVariables(); - } - - public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) { - return new LayerNorm(sameDiff(), input, gain, channelsFirst, dimensions).outputVariable(); - } - - public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable gradient, boolean channelsFirst, int... dimensions) { - return new LayerNormBp(sameDiff(), input, gain, gradient, channelsFirst, dimensions).outputVariables(); - } - - public SDVariable squaredNorm(SDVariable input, boolean keepDims, int... dimensions) { - return new SquaredNorm(sameDiff(), input, keepDims, dimensions).outputVariable(); - } - - public SDVariable squaredNormBp(SDVariable preReduceInput, SDVariable gradient, boolean keepDims, int... dimensions) { - return new SquaredNormBp(sameDiff(), preReduceInput, gradient, keepDims, dimensions).outputVariable(); - } - - public SDVariable entropy(SDVariable in, int... dimensions) { - return new Entropy(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable logEntropy(SDVariable in, int... dimensions) { - return new LogEntropy(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable shannonEntropy(SDVariable in, int... dimensions){ - return new ShannonEntropy(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable countNonZero(SDVariable input, int... dimensions) { - return new CountNonZero(sameDiff(), input, dimensions).outputVariable(); - } - - public SDVariable countZero(SDVariable input, int... 
dimensions) { - return new CountZero(sameDiff(), input, dimensions).outputVariable(); - } - - public SDVariable zeroFraction(SDVariable input) { - return new ZeroFraction(sameDiff(), input).outputVariable(); - } - - public SDVariable scalarMax(SDVariable in, Number num) { - return new ScalarMax(sameDiff(), in, num).outputVariable(); - } - - public SDVariable scalarMin(SDVariable in, Number num) { - return new ScalarMin(sameDiff(), in, num).outputVariable(); - } - - public SDVariable scalarSet(SDVariable in, Number num) { - return new ScalarSet(sameDiff(), in, num).outputVariable(); - } - - public SDVariable scalarFloorMod(SDVariable in, Number num) { - return new ScalarFMod(sameDiff(), in, num).outputVariable(); - } - - public SDVariable max(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Max(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable max(SDVariable first, SDVariable second) { - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sameDiff(), first, second) - .outputVariable(); - } - - public SDVariable maxBp(SDVariable i_x, SDVariable grad, boolean keepDims, int... dimensions) { - return new MaxBp(sameDiff(), i_x, grad, keepDims, dimensions).outputVariable(); - } - - - public SDVariable min(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Min(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable minBp(SDVariable i_x, SDVariable grad, boolean keepDims, int... dimensions) { - return new MinBp(sameDiff(), i_x, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable min(SDVariable first, SDVariable second) { - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sameDiff(), first, second) - .outputVariable(); - } - - public SDVariable amax(SDVariable in, int... dimensions) { - return new AMax(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable amin(SDVariable in, int... 
dimensions) { - return new AMin(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable amean(SDVariable in, int... dimensions) { - return new AMean(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable asum(SDVariable in, int... dimensions) { - return new ASum(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable argmax(SDVariable in, boolean keepDims, int... dimensions) { - return new IMax(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable argmin(SDVariable in, boolean keepDims, int... dimensions) { - return new IMin(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) { - return new IAMax(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) { - return new IAMin(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable firstIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - return new FirstIndex(sameDiff(), in, condition, keepDims, dimensions).outputVariable(); - } - - public SDVariable lastIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - return new LastIndex(sameDiff(), in, condition, keepDims, dimensions).outputVariable(); - } - - /** - * Returns a count of the number of elements that satisfy the condition - * - * @param in Input - * @param condition Condition - * @return Number of elements that the condition is satisfied for - */ - public SDVariable matchConditionCount(SDVariable in, Condition condition, boolean keepDims, int... 
dimensions) { - return new MatchCondition(sameDiff(), in, condition, keepDims, dimensions).outputVariable(); - } - - /** - * Returns a boolean mask of equal shape to the input, where the condition is satisfied - * - * @param in Input - * @param condition Condition - * @return Boolean mask - */ - public SDVariable matchCondition(SDVariable in, Condition condition) { - return new MatchConditionTransform(sameDiff(), in, condition).outputVariable(); - } - - public SDVariable cumsum(SDVariable in, boolean exclusive, boolean reverse, int... axis) { - return new CumSum(sameDiff(), in, exclusive, reverse, axis).outputVariable(); - } - - public SDVariable cumsumBp(SDVariable in, SDVariable grad, boolean exclusive, boolean reverse, int... axis) { - return new CumSumBp(sameDiff(), in, grad, exclusive, reverse, axis).outputVariable(); - } - - public SDVariable cumprod(SDVariable in, boolean exclusive, boolean reverse, int... axis) { - return new CumProd(sameDiff(), in, exclusive, reverse, axis).outputVariable(); - } - - public SDVariable cumprodBp(SDVariable in, SDVariable grad, boolean exclusive, boolean reverse, int... axis) { - return new CumProdBp(sameDiff(), in, grad, exclusive, reverse, axis).outputVariable(); - } - - public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) { - return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable(); - } - - public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad, boolean nchw) { - return new BiasAddGrad(sameDiff(), input, bias, grad, nchw).outputVariables(); - } - - public SDVariable norm1(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Norm1(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable norm1Bp(SDVariable preReduceIn, SDVariable grad, boolean keepDims, int... 
dimensions) { - return new Norm1Bp(sameDiff(), preReduceIn, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable norm2(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Norm2(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable norm2Bp(SDVariable preReduceIn, SDVariable grad, boolean keepDims, int... dimensions) { - return new Norm2Bp(sameDiff(), preReduceIn, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable normmax(SDVariable i_x, boolean keepDims, int... dimensions) { - return new NormMax(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable normmaxBp(SDVariable preReduceIn, SDVariable grad, boolean keepDims, int... dimensions) { - return new NormMaxBp(sameDiff(), preReduceIn, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable reductionShape(SDVariable shape, SDVariable axis, boolean keepDim){ - return new ReductionShape(sameDiff(), shape, axis, keepDim).outputVariable(); - } - - /** - * Add 1s as required to the array make an array possible to be broadcast with the original (pre-reduce) array. - *

- * Example: if doing [a,b,c].sum(1), result is [a,c]. To 'undo' this in a way that can be auto-broadcast, - * we want to expand as required - i.e., [a,c] -> [a,1,c] which can be auto-broadcast with the original [a,b,c]. - * This is typically only used with reduction operations backprop. - * - * @param origRank Rank of the original array, before the reduction was executed - * @param reduceDims Dimensions that the original array was reduced from - * @param toExpand Array to add 1s to the shape to (such that it can be - * @return Reshaped array. - */ - public SDVariable reductionBroadcastableWithOrigShape(int origRank, int[] reduceDims, SDVariable toExpand) { - if (Shape.isWholeArray(origRank, reduceDims)) { - //Output is [1,1] which is already broadcastable - return toExpand; - } else if (origRank == 2 && reduceDims.length == 1) { - //In this case: [a,b] -> [1,b] or [a,b] -> [a,1] - //both are already broadcastable - return toExpand; - } else { - //Example: [a,b,c].sum(1) -> [a,c]... want [a,1,c] - for (int d : reduceDims) { - toExpand = sameDiff().expandDims(toExpand, d); - } - return toExpand; - } - } - - public SDVariable reductionBroadcastableWithOrigShape(SDVariable origInput, SDVariable axis, SDVariable toExpand) { - SDVariable shape = origInput.shape(); - SDVariable reduceShape = reductionShape(shape, axis, true); - SDVariable reshaped = toExpand.reshape(reduceShape); - return reshaped; - } - - - public SDVariable gradientBackwardsMarker(SDVariable iX) { - return new GradientBackwardsMarker(sameDiff(), iX, sameDiff.scalar(iX.name() + "-pairgrad", 1.0)).outputVariable(); - } - - public SDVariable abs(SDVariable iX) { - return new Abs(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable neg(SDVariable iX) { - return new Negative(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable cos(SDVariable iX) { - return new Cos(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable sin(SDVariable iX) { - return new 
Sin(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable tan(SDVariable iX) { - return new Tan(sameDiff(), iX, false).outputVariable(); - - } - - - public SDVariable permute(SDVariable iX, int... dimensions) { - return new Permute(sameDiff(), iX, dimensions).outputVariable(); - } - - public SDVariable permute(SDVariable in, SDVariable dimensions) { - return new Permute(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable noop(SDVariable input) { - return new NoOp(sameDiff(), input).outputVariable(); - } - - public SDVariable identity(SDVariable input) { - return new Identity(sameDiff(), input).outputVariable(); - } - - public SDVariable all(SDVariable input, int... dimensions) { - return new All(sameDiff(), input, dimensions).outputVariable(); - } - - public SDVariable any(SDVariable input, int... dimensions) { - return new Any(sameDiff(), input, dimensions).outputVariable(); - } - - public SDVariable invertPermutation(SDVariable input, boolean inPlace) { - return new InvertPermutation(sameDiff(), input, inPlace).outputVariable(); - } - - public SDVariable transpose(SDVariable iX) { - return new Transpose(sameDiff(), iX).outputVariable(); - } - - - public SDVariable acos(SDVariable iX) { - return new ACos(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable asin(SDVariable iX) { - return new ASin(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable atan(SDVariable iX) { - return new ATan(sameDiff(), iX, false).outputVariable(); - - } - - public SDVariable atan2(SDVariable y, SDVariable x) { - return new ATan2(sameDiff(), y, x).outputVariable(); - } - - - public SDVariable cosh(SDVariable iX) { - return new Cosh(sameDiff(), iX, false).outputVariable(); - - } - - - public SDVariable sinh(SDVariable iX) { - return new Sinh(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable tanh(SDVariable iX) { - return new Tanh(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable 
tanhRational(SDVariable in) { - return new RationalTanh(sameDiff(), in, false).outputVariable(); - } - - public SDVariable tanhRectified(SDVariable in) { - return new RectifiedTanh(sameDiff(), in, false).outputVariable(); - } - - public SDVariable tanhDerivative(SDVariable iX, SDVariable wrt) { - return new org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative(sameDiff(), iX, wrt).outputVariable(); - } - - public SDVariable tanhRationalBp(SDVariable in, SDVariable epsilon) { - return new RationalTanhBp(sameDiff(), in, epsilon).outputVariable(); - } - - public SDVariable tanhRectifiedBp(SDVariable in, SDVariable epsilon) { - return new RectifiedTanhBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * Use {@link #tanhRationalBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable tanhRationalDerivative(SDVariable in) { - return new RationalTanhDerivative(sameDiff(), in, false).outputVariable(); - } - - /** - * Use {@link #tanhRectifiedBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable tanhRectifiedDerivative(SDVariable in) { - return new RectifiedTanhDerivative(sameDiff(), in, false).outputVariable(); - } - - public SDVariable step(SDVariable in, double cutoff) { - return new Step(sameDiff(), in, false, cutoff).outputVariable(); - } - - - public SDVariable acosh(SDVariable iX) { - return new ACosh(sameDiff(), iX).outputVariable(); - } - - - public SDVariable asinh(SDVariable iX) { - return new ASinh(sameDiff(), iX).outputVariable(); - } - - - public SDVariable atanh(SDVariable iX) { - return new ATanh(sameDiff(), iX).outputVariable(); - } - - - public SDVariable exp(SDVariable iX) { - return new Exp(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable expm1(SDVariable iX) { - return new Expm1(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable rsqrt(SDVariable iX) { - return new RSqrt(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable log(SDVariable iX) { - return new 
Log(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable log(SDVariable in, double base) { - return new LogX(sameDiff(), in, base).outputVariable(); - } - - public SDVariable log1p(SDVariable iX) { - return new Log1p(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable isFinite(SDVariable ix) { - return new IsFinite(sameDiff(), ix, false).outputVariable(); - } - - public SDVariable isInfinite(SDVariable ix) { - return new IsInf(sameDiff(), ix, false).outputVariable(); - } - - public SDVariable isNaN(SDVariable ix) { - return new IsNaN(sameDiff(), ix, false).outputVariable(); - } - - public SDVariable isMax(SDVariable ix) { - return new IsMax(sameDiff(), ix).outputVariable(); - } - - public SDVariable replaceWhere(SDVariable to, SDVariable from, Condition condition) { - return new CompareAndReplace(sameDiff(), to, from, condition).outputVariable(); - } - - public SDVariable replaceWhere(SDVariable to, Number set, Condition condition) { - return new CompareAndSet(sameDiff(), to, set, condition).outputVariable(); - } - - public SDVariable round(SDVariable ix) { - return new Round(sameDiff(), ix, false).outputVariable(); - } - - public SDVariable or(SDVariable iX, SDVariable i_y) { - return new Or(sameDiff(), iX, i_y).outputVariable(); - } - - public SDVariable and(SDVariable ix, SDVariable iy) { - return new And(sameDiff(), ix, iy).outputVariable(); - } - - public SDVariable xor(SDVariable ix, SDVariable iy) { - return new Xor(sameDiff(), ix, iy).outputVariable(); - } - - public SDVariable shift(SDVariable ix, SDVariable shift) { - return new ShiftBits(sameDiff(), ix, shift).outputVariable(); - } - - public SDVariable rshift(SDVariable ix, SDVariable shift) { - return new RShiftBits(sameDiff(), ix, shift).outputVariable(); - } - - public SDVariable rotl(SDVariable ix, SDVariable shift) { - return new CyclicShiftBits(sameDiff(), ix, shift).outputVariable(); - } - - public SDVariable rotr(SDVariable ix, SDVariable shift) { - return new 
CyclicRShiftBits(sameDiff(), ix, shift).outputVariable(); - } - - public SDVariable bitwiseHammingDist(SDVariable x, SDVariable y) { - return new BitsHammingDistance(sameDiff(), x, y).outputVariable(); - } - - public SDVariable bitwiseAnd(SDVariable x, SDVariable y){ - return new BitwiseAnd(sameDiff(), x, y).outputVariable(); - } - - public SDVariable bitwiseOr(SDVariable x, SDVariable y){ - return new BitwiseOr(sameDiff(), x, y).outputVariable(); - } - - public SDVariable bitwiseXor(SDVariable x, SDVariable y){ - return new BitwiseXor(sameDiff(), x, y).outputVariable(); - } - - public SDVariable eq(SDVariable iX, SDVariable i_y) { - return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable(); - } - - - public SDVariable neq(SDVariable iX, double i_y) { - return new ScalarNotEquals(sameDiff(), iX, i_y).outputVariable(); - } - - - public SDVariable neqi(SDVariable iX, double i_y) { - return new ScalarNotEquals(sameDiff(), iX, i_y, true).outputVariable(); - } - - - public SDVariable neqi(SDVariable iX, SDVariable i_y) { - return new NotEqualTo(sameDiff(), new SDVariable[]{iX, i_y}, true).outputVariable(); - } - - public SDVariable neq(SDVariable iX, SDVariable i_y) { - return new NotEqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable(); - } - - public SDVariable pow(SDVariable iX, double i_y) { - return new Pow(sameDiff(), iX, false, i_y).outputVariable(); - } - - public SDVariable pow(SDVariable x, SDVariable y){ - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sameDiff(), x, y).outputVariable(); - } - - public SDVariable sqrt(SDVariable iX) { - return new Sqrt(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable square(SDVariable iX) { - return new Square(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable cube(SDVariable iX) { - return new Cube(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable cubeBp(SDVariable in, SDVariable epsilon) { - return new 
CubeBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * @deprecated Use {@link #cubeBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable cubeDerivative(SDVariable iX) { - return new CubeDerivative(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable floor(SDVariable iX) { - return new Floor(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable floorDiv(SDVariable x, SDVariable y) { - return new FloorDivOp(sameDiff(), x, y).outputVariable(); - } - - public List floorDivBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new FloorDivBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - public SDVariable floorMod(SDVariable x, SDVariable y) { - return new FloorModOp(sameDiff(), x, y).outputVariable(); - } - - public List floorModBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new FloorModBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - public SDVariable ceil(SDVariable x) { - return new Ceil(sameDiff(), x).outputVariable(); - } - - public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValueMax) { - return new ClipByValue(sameDiff(), x, clipValueMin, clipValueMax).outputVariable(); - } - - public SDVariable clipByNorm(SDVariable x, double clipValue) { - return new ClipByNorm(sameDiff(), x, clipValue).outputVariable(); - } - - public SDVariable clipByNorm(SDVariable x, double clipValue, int... 
dimensions) { - return new ClipByNorm(sameDiff(), x, clipValue, dimensions).outputVariable(); - } - - public SDVariable relu(SDVariable iX, double cutoff) { - return new RectifiedLinear(sameDiff(), iX, false, cutoff).outputVariable(); - } - - public SDVariable reluDerivative(SDVariable input, SDVariable grad){ - return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable(); - } - - public SDVariable thresholdRelu(SDVariable in, SDVariable epsilon, double cutoff){ - return new ThresholdRelu(sameDiff(), in, cutoff).outputVariable(); - } - - public SDVariable thresholdReluBp(SDVariable in, SDVariable epsilon, double cutoff){ - return new ThresholdReluBp(sameDiff(), in, epsilon, cutoff).outputVariable(); - } - - public SDVariable relu6(SDVariable iX, double cutoff) { - return new Relu6(sameDiff(), iX, false, cutoff).outputVariable(); - } - - public SDVariable relu6Derivative(SDVariable iX, SDVariable wrt, double cutoff) { - return new Relu6Derivative(sameDiff(), iX, wrt, cutoff).outputVariable(); - } - - public SDVariable softmax(SDVariable iX) { - return new SoftMax(sameDiff(), new SDVariable[]{iX}).outputVariable(); - } - - public SDVariable softmax(SDVariable iX, int dimension) { - return new SoftMax(sameDiff(), new SDVariable[]{iX}, dimension).outputVariable(); - } - - - public SDVariable hardTanh(SDVariable iX) { - return new HardTanh(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable hardTanhBp(SDVariable in, SDVariable epsilon) { - return new HardTanhBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * @deprecated Use {@link #hardTanhBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable hardTanhDerivative(SDVariable iX) { - return new HardTanhDerivative(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable hardSigmoid(SDVariable in) { - return new HardSigmoid(sameDiff(), in, false).outputVariable(); - } - - public SDVariable hardSigmoidBp(SDVariable in, SDVariable epsilon){ - return new 
HardSigmoidBp(sameDiff(), in, epsilon).outputVariable(); - } - - public SDVariable sigmoid(SDVariable iX) { - return new Sigmoid(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable sigmoidDerivative(SDVariable iX, SDVariable wrt) { - return new SigmoidDerivative(sameDiff(), iX, wrt).outputVariable(); - } - - - public SDVariable logSigmoid(SDVariable iX) { - return new LogSigmoid(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable powDerivative(SDVariable iX, double pow) { - return new PowDerivative(sameDiff(), iX, false, pow).outputVariable(); - } - - public SDVariable[] powBp(SDVariable x, SDVariable pow, SDVariable gradient) { - return new PowBp(sameDiff(), x, pow, gradient).outputVariables(); - } - - public SDVariable mishDerivative(SDVariable iX) { - return new MishDerivative(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable swish(SDVariable iX) { - return new Swish(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable swishDerivative(SDVariable iX) { - return new SwishDerivative(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable gelu(SDVariable iX, boolean precise) { - if (precise) - return new PreciseGELU(sameDiff(), iX, false, precise).outputVariable(); - else - return new GELU(sameDiff(), iX, false, precise).outputVariable(); - } - - public SDVariable geluDerivative(SDVariable iX, boolean precise) { - if (precise) - return new PreciseGELUDerivative(sameDiff(), iX, false, precise).outputVariable(); - else - return new GELUDerivative(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable sign(SDVariable iX) { - return new Sign(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable expandDims(SDVariable iX, int axis) { - return new ExpandDims(sameDiff(), new SDVariable[]{iX}, axis).outputVariable(); - } - - public SDVariable squeeze(SDVariable iX, int... 
axis) { - return new Squeeze(sameDiff(), iX, axis).outputVariable(); - } - - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, DataType dataType) { - return new ConfusionMatrix(sameDiff(), labels, pred, dataType).outputVariable(); - } - - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses) { - return new ConfusionMatrix(sameDiff(), labels, pred, numClasses).outputVariable(); - } - - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights) { - return new ConfusionMatrix(sameDiff(), labels, pred, weights).outputVariable(); - } - - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) { - return new ConfusionMatrix(sameDiff(), labels, pred, numClasses, weights).outputVariable(); - } - - public SDVariable matrixDeterminant(SDVariable in){ - return new MatrixDeterminant(sameDiff(), in, false).outputVariable(); - } - - public SDVariable matrixInverse(SDVariable in){ - return new MatrixInverse(sameDiff(), in, false).outputVariable(); - } - - public SDVariable onehot(SDVariable indices, int depth, int axis, double on, double off, DataType dataType) { - return new OneHot(sameDiff(), indices, depth, axis, on, off, dataType).outputVariable(); - } - - public SDVariable onehot(SDVariable indices, int depth) { - return new OneHot(sameDiff(), indices, depth).outputVariable(); - } - - public SDVariable reciprocal(SDVariable a) { - return new Reciprocal(sameDiff(), a).outputVariable(); - } - - - public SDVariable repeat(SDVariable iX, int axis) { - return new Repeat(sameDiff(), new SDVariable[]{iX}, axis).outputVariable(); - - } - - public SDVariable stack(SDVariable[] values, int axis) { - return new Stack(sameDiff(), values, axis).outputVariable(); - } - - public SDVariable parallel_stack(SDVariable[] values) { - return new ParallelStack(sameDiff(), values).outputVariable(); - } - - public SDVariable[] unstack(SDVariable value, int 
axis) { - return new Unstack(sameDiff(), value, axis).outputVariables(); - } - - public SDVariable[] unstack(SDVariable value, int axis, int num) { - return new Unstack(sameDiff(), value, axis, num).outputVariables(); - } - - public SDVariable assign(SDVariable x, SDVariable y) { - return new Assign(sameDiff(), x, y).outputVariable(); - } - - public SDVariable assign(SDVariable x, Number num) { - return new ScalarSet(sameDiff(), x, num).outputVariable(); - } - - - public SDVariable softsign(SDVariable iX) { - return new SoftSign(sameDiff(), iX, false).outputVariable(); - - } - - public SDVariable softsignBp(SDVariable in, SDVariable epsilon) { - return new SoftSignBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * @deprecated Use {@link #softsignBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable softsignDerivative(SDVariable iX) { - return new SoftSignDerivative(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable softplus(SDVariable iX) { - return new SoftPlus(sameDiff(), iX, false).outputVariable(); - - } - - - public SDVariable elu(SDVariable iX) { - return new ELU(sameDiff(), iX).outputVariable(); - - } - - public SDVariable eluBp(SDVariable in, SDVariable epsilon, double alpha) { - return new EluBp(sameDiff(), in, epsilon, alpha).outputVariable(); - } - - - public SDVariable leakyRelu(SDVariable iX, double alpha) { - return new LeakyReLU(sameDiff(), iX, false, alpha).outputVariable(); - - } - - public SDVariable leakyReluBp(SDVariable in, SDVariable epsilon, double cutoff) { - return new LeakyReLUBp(sameDiff(), in, epsilon, cutoff).outputVariable(); - } - - /** - * @deprecated Use {@link #leakyReluBp(SDVariable, SDVariable, double)} - */ - @Deprecated - public SDVariable leakyReluDerivative(SDVariable iX, double cutoff) { - return new LeakyReLUDerivative(sameDiff(), iX, false, cutoff).outputVariable(); - } - - public SDVariable prelu(SDVariable x, SDVariable alpha, int... 
sharedAxes){ - return new PRelu(sameDiff(), x, alpha, sharedAxes).outputVariable(); - } - - public SDVariable[] preluBp(SDVariable in, SDVariable alpha, SDVariable epsilon, int... sharedAxes){ - return new PReluBp(sameDiff(), in, alpha, epsilon, sharedAxes).outputVariables(); - } - - public SDVariable reshape(SDVariable iX, int[] shape) { - return new Reshape(sameDiff(), iX, ArrayUtil.toLongArray(shape)).outputVariable(); - } - - public SDVariable reshape(SDVariable iX, long[] shape) { - return new Reshape(sameDiff(), iX, shape).outputVariable(); - } - - public SDVariable reshape(SDVariable iX, SDVariable shape) { - return new Reshape(sameDiff(), iX, shape).outputVariable(); - } - - public SDVariable reverse(SDVariable x, int... dimensions) { - return new Reverse(sameDiff(), x, dimensions).outputVariable(); - } - - public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seq_dim, int batch_dim) { - return new ReverseSequence(sameDiff(), x, seq_lengths, seq_dim, batch_dim).outputVariable(); - } - - public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths) { - return new ReverseSequence(sameDiff(), x, seq_lengths).outputVariable(); - } - - public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen, DataType dataType) { - return new SequenceMask(sameDiff(), lengths, maxLen, dataType).outputVariable(); - } - - public SDVariable sequenceMask(SDVariable lengths, int maxLen, DataType dataType) { - return new SequenceMask(sameDiff(), lengths, maxLen, dataType).outputVariable(); - } - - public SDVariable sequenceMask(SDVariable lengths, DataType dataType) { - return new SequenceMask(sameDiff(), lengths, dataType).outputVariable(); - } - - public SDVariable concat(int dimension, SDVariable... 
inputs) { - return new Concat(sameDiff(), dimension, inputs).outputVariable(); - } - - public SDVariable fill(SDVariable shape, DataType dataType, double value) { - return new Fill(sameDiff(), shape, dataType, value).outputVariable(); - } - - public SDVariable dot(SDVariable x, SDVariable y, int... dimensions) { - return new Dot(sameDiff(), x, y, dimensions).outputVariable(); - } - - public SDVariable[] dotBp(SDVariable in1, SDVariable in2, SDVariable grad, boolean keepDims, int... dimensions) { - return new DotBp(sameDiff(), in1, in2, grad, keepDims, dimensions).outputVariables(); - } - - public SDVariable cosineSimilarity(SDVariable iX, SDVariable i_y, int... dimensions) { - return new CosineSimilarity(sameDiff(), iX, i_y, dimensions).outputVariable(); - } - - public SDVariable cosineDistance(SDVariable ix, SDVariable iy, int... dimensions) { - return new CosineDistance(sameDiff(), ix, iy, dimensions).outputVariable(); - } - - - public SDVariable euclideanDistance(SDVariable iX, SDVariable i_y, int... dimensions) { - return new EuclideanDistance(sameDiff(), iX, i_y, dimensions).outputVariable(); - } - - - public SDVariable manhattanDistance(SDVariable iX, SDVariable i_y, int... dimensions) { - return new ManhattanDistance(sameDiff(), iX, i_y, dimensions).outputVariable(); - } - - public SDVariable hammingDistance(SDVariable ix, SDVariable iy, int... dimensions) { - return new HammingDistance(sameDiff(), ix, iy, dimensions).outputVariable(); - } - - public SDVariable jaccardDistance(SDVariable ix, SDVariable iy, int... 
dimensions) { - return new JaccardDistance(sameDiff(), ix, iy, dimensions).outputVariable(); - } - - public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable inputs, SDVariable weights) { - return new WeightedCrossEntropyLoss(sameDiff(), targets, inputs, weights).outputVariable(); - } - - public SDVariable lossL2(SDVariable var){ - return new L2Loss(sameDiff(), var).outputVariable(); - } - - public SDVariable lossAbsoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new AbsoluteDifferenceLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable(); - } - - public SDVariable[] lossAbsoluteDifferenceBP(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new AbsoluteDifferenceLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossCosineDistance(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, int dimension){ - return new CosineDistanceLoss(sameDiff(), lossReduce, predictions, weights, label, dimension).outputVariable(); - } - - public SDVariable[] lossCosineDistanceBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, int dimension){ - return new CosineDistanceLossBp(sameDiff(), lossReduce, predictions, weights, label, dimension).outputVariables(); - } - - public SDVariable lossHinge(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new HingeLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable(); - } - - public SDVariable[] lossHingeBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new HingeLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossHuber(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double 
delta){ - return new HuberLoss(sameDiff(), lossReduce, predictions, weights, label, delta).outputVariable(); - } - - public SDVariable[] lossHuberBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double delta){ - return new HuberLossBp(sameDiff(), lossReduce, predictions, weights, label, delta).outputVariables(); - } - - public SDVariable lossLog(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double epsilon){ - return new LogLoss(sameDiff(), lossReduce, predictions, weights, label, epsilon).outputVariable(); - } - - public SDVariable[] lossLogBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double epsilon){ - return new LogLossBp(sameDiff(), lossReduce, predictions, weights, label, epsilon).outputVariables(); - } - - public SDVariable lossLogPoisson(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new LogPoissonLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable(); - } - - public SDVariable[] lossLogPoissonBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new LogPoissonLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossLogPoissonFull(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new LogPoissonLoss(sameDiff(), lossReduce, predictions, weights, label, true).outputVariable(); - } - - public SDVariable[] lossLogPoissonFullBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new LogPoissonLossBp(sameDiff(), lossReduce, predictions, weights, label, true).outputVariables(); - } - - public SDVariable lossMeanPairwiseSquaredError(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new MeanPairwiseSquaredErrorLoss(sameDiff(), lossReduce, predictions, 
weights, label).outputVariable(); - } - - public SDVariable[] lossMeanPairwiseSquaredErrorBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new MeanPairwiseSquaredErrorLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossMeanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new MeanSquaredErrorLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable(); - } - - public SDVariable[] lossMeanSquaredErrorBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new MeanSquaredErrorLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossSigmoidCrossEntropy(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) { - return new SigmoidCrossEntropyLoss(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariable(); - } - - public SDVariable[] lossSigmoidCrossEntropyBp(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) { - return new SigmoidCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariables(); - } - - public SDVariable lossSoftmaxCrossEntropy(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) { - return new SoftmaxCrossEntropyLoss(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariable(); - } - - public SDVariable[] lossSoftmaxCrossEntropyBp(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) { - return new SoftmaxCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariables(); - } - - public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, int 
classDim) { - return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, labels, classDim).outputVariable(); - } - - public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, int classDim) { - return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, labels, classDim).outputVariables(); - } - - public SDVariable lossSparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels){ - return new SparseSoftmaxCrossEntropyLossWithLogits(sameDiff(), logits, labels).outputVariable(); - } - - public SDVariable[] lossSparseSoftmaxCrossEntropyBp(SDVariable logits, SDVariable labels){ - return new SparseSoftmaxCrossEntropyLossWithLogitsBp(sameDiff(), logits, labels).outputVariables(); - } - - - public SDVariable xwPlusB(SDVariable input, SDVariable weights, SDVariable bias) { - return new XwPlusB(sameDiff(), input, weights, bias).outputVariable(); - } - - public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) { - return new ReluLayer(sameDiff(), input, weights, bias).outputVariable(); - } - - public SDVariable mmul(SDVariable x, - SDVariable y, - MMulTranspose mMulTranspose) { - validateDifferentialFunctionsameDiff(x); - validateDifferentialFunctionsameDiff(y); - return new Mmul(sameDiff(), x, y, mMulTranspose).outputVariable(); - } - - - public SDVariable mmul(SDVariable x, - SDVariable y) { - return mmul(x, y, MMulTranspose.allFalse()); - } - - public List mmulBp(SDVariable x, SDVariable y, SDVariable eps, MMulTranspose mt) { - return Arrays.asList(new MmulBp(sameDiff(), x, y, eps, mt).outputVariables()); - } - - public SDVariable[] batchMmul(SDVariable[] matricesA, - SDVariable[] matricesB) { - return batchMmul(matricesA, matricesB, false, false); - } - - - public SDVariable[] batchMmul(SDVariable[] matricesA, - SDVariable[] matricesB, - boolean transposeA, - boolean transposeB) { - return batchMmul(ArrayUtils.addAll(matricesA, matricesB), transposeA, transposeB); - } - - - public SDVariable[] 
batchMmul(SDVariable[] matrices, - boolean transposeA, - boolean transposeB) { - return new BatchMmul(sameDiff(), matrices, transposeA, transposeB).outputVariables(); - } - - - public SDVariable tensorMmul(SDVariable x, - SDVariable y, - int[][] dimensions) { - validateDifferentialFunctionsameDiff(x); - validateDifferentialFunctionsameDiff(y); - return new TensorMmul(sameDiff(), x, y, dimensions).outputVariable(); - } - - public SDVariable dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled) { - return new DotProductAttention(sameDiff(), queries, keys, values, mask, scaled, false).outputVariable(); - } - - public List dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled, boolean withWeights) { - return Arrays.asList(new DotProductAttention(sameDiff(), queries, keys, values, mask, scaled, withWeights).outputVariables()); - } - - public List dotProductAttentionBp(SDVariable queries, SDVariable keys, SDVariable values, SDVariable gradient, SDVariable mask, boolean scaled) { - return Arrays.asList(new DotProductAttentionBp(sameDiff(), queries, keys, values, gradient, mask, scaled).outputVariables()); - } - - public SDVariable multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled) { - return new MultiHeadDotProductAttention(sameDiff(), queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable(); - } - - public List multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values,SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled, boolean withWeights) { - return Arrays.asList(new MultiHeadDotProductAttention(sameDiff(), queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, withWeights).outputVariables()); - } - - public List multiHeadDotProductAttentionBp(SDVariable 
queries, SDVariable keys, SDVariable values,SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable gradient, SDVariable mask, boolean scaled) { - return Arrays.asList(new MultiHeadDotProductAttentionBp(sameDiff(), queries, keys, values, Wq, Wk, Wv, Wo, gradient, mask, scaled).outputVariables()); - } - - public SDVariable softmaxDerivative(SDVariable functionInput, SDVariable wrt, Integer dimension) { - validateDifferentialFunctionsameDiff(functionInput); - return new SoftmaxBp(sameDiff(), functionInput, wrt, dimension).outputVariable(); - } - - - public SDVariable logSoftmax(SDVariable i_v) { - validateDifferentialFunctionsameDiff(i_v); - return new LogSoftMax(sameDiff(), i_v).outputVariable(); - - } - - - public SDVariable logSoftmax(SDVariable i_v, int dimension) { - validateDifferentialFunctionsameDiff(i_v); - return new LogSoftMax(sameDiff(), i_v, dimension).outputVariable(); - - } - - - public SDVariable logSoftmaxDerivative(SDVariable arg, SDVariable wrt) { - validateDifferentialFunctionsameDiff(arg); - return new LogSoftMaxDerivative(sameDiff(), arg, wrt).outputVariable(); - } - - - public SDVariable logSoftmaxDerivative(SDVariable arg, SDVariable wrt, int dimension) { - validateDifferentialFunctionsameDiff(arg); - return new LogSoftMaxDerivative(sameDiff(), arg, wrt, dimension).outputVariable(); - } - - public SDVariable logSumExp(SDVariable arg, boolean keepDims, int... 
dimension) { - return new LogSumExp(sameDiff(), arg, keepDims, dimension).outputVariable(); - } - - - public SDVariable selu(SDVariable arg) { - validateDifferentialFunctionsameDiff(arg); - return new SELU(sameDiff(), arg, false).outputVariable(); - } - - public SDVariable seluBp(SDVariable in, SDVariable epsilon) { - validateDifferentialFunctionsameDiff(in); - return new SeluBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * @deprecated Use {@link #seluBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable seluDerivative(SDVariable arg) { - validateDifferentialFunctionsameDiff(arg); - return new SELUDerivative(sameDiff(), arg, false).outputVariable(); - } - - - public SDVariable rsub(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new RSubOp(sameDiff(), differentialFunction, i_v).outputVariable(); - } - - public List rsubBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new RSubBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable rdiv(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new RDivOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public List rdivBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new RDivBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable rdivi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new RDivOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - } - - - public SDVariable rsubi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new RSubOp(sameDiff(), differentialFunction, i_v).outputVariable(); - } - - public SDVariable add(SDVariable differentialFunction, 
SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new AddOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - - } - - public SDVariable mergeAdd(SDVariable... differentialFunctions) { - for (SDVariable df : differentialFunctions) - validateDifferentialFunctionsameDiff(df); - return new MergeAddOp(sameDiff(), differentialFunctions, false).outputVariable(); - } - - public SDVariable mergeMax(SDVariable... differentialFunctions) { - for (SDVariable df : differentialFunctions) - validateDifferentialFunctionsameDiff(df); - return new MergeMax(sameDiff(), differentialFunctions).outputVariable(); - } - - public SDVariable mergeAvg(SDVariable... differentialFunctions) { - for (SDVariable df : differentialFunctions) - validateDifferentialFunctionsameDiff(df); - return new MergeAvg(sameDiff(), differentialFunctions).outputVariable(); - } - - public SDVariable diag(SDVariable sdVariable) { - validateDifferentialFunctionsameDiff(sdVariable); - return new Diag(sameDiff(), new SDVariable[]{sdVariable}, false).outputVariable(); - } - - public SDVariable diagPart(SDVariable sdVariable) { - validateDifferentialFunctionsameDiff(sdVariable); - return new DiagPart(sameDiff(), new SDVariable[]{sdVariable}, false).outputVariable(); - } - - public SDVariable setDiag(SDVariable in, SDVariable diag) { - return new MatrixSetDiag(sameDiff(), in, diag, false).outputVariable(); - } - - - public SDVariable batchToSpace(SDVariable differentialFunction, int[] blocks, int[][] crops) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new BatchToSpace(sameDiff(), new SDVariable[]{differentialFunction}, blocks, crops, false) - .outputVariable(); - } - - public SDVariable spaceToBatch(SDVariable differentialFunction, int[] blocks, int[][] padding) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SpaceToBatch(sameDiff(), new SDVariable[]{differentialFunction}, blocks, padding, 
false) - .outputVariable(); - } - - public SDVariable depthToSpace(SDVariable differentialFunction, int blocksSize, DataFormat dataFormat) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new DepthToSpace(sameDiff(), new SDVariable[]{differentialFunction}, blocksSize, dataFormat) - .outputVariable(); - } - - public SDVariable spaceToDepth(SDVariable differentialFunction, int blocksSize, DataFormat dataFormat) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SpaceToDepth(sameDiff(), new SDVariable[]{differentialFunction}, blocksSize, dataFormat) - .outputVariable(); - } - - public SDVariable[] dynamicPartition(SDVariable differentialFunction, SDVariable partitions, int numPartitions) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new DynamicPartition(sameDiff(), differentialFunction, partitions, numPartitions) - .outputVariables(); - } - - public SDVariable[] dynamicPartitionBp(SDVariable input, SDVariable partitions, SDVariable[] grads, int numPartitions){ - return new DynamicPartitionBp(sameDiff(), input, partitions, grads, numPartitions).outputVariables(); - } - - public SDVariable dynamicStitch(SDVariable[] indices, SDVariable[] differentialFunctions) { - for (SDVariable df : differentialFunctions) - validateDifferentialFunctionsameDiff(df); - - return new DynamicStitch(sameDiff(), indices, differentialFunctions).outputVariable(); - } - - public SDVariable segmentMax(SDVariable data, SDVariable segmentIds){ - return new SegmentMax(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentMaxBp(SDVariable data, SDVariable segmentIds, SDVariable gradient){ - return new SegmentMaxBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - public SDVariable segmentMin(SDVariable data, SDVariable segmentIds){ - return new SegmentMin(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentMinBp(SDVariable data, SDVariable 
segmentIds, SDVariable gradient){ - return new SegmentMinBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - public SDVariable segmentMean(SDVariable data, SDVariable segmentIds){ - return new SegmentMean(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentMeanBp(SDVariable data, SDVariable segmentIds, SDVariable gradient){ - return new SegmentMeanBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - public SDVariable segmentProd(SDVariable data, SDVariable segmentIds){ - return new SegmentProd(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentProdBp(SDVariable data, SDVariable segmentIds, SDVariable gradient){ - return new SegmentProdBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - public SDVariable segmentSum(SDVariable data, SDVariable segmentIds){ - return new SegmentSum(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentSumBp(SDVariable data, SDVariable segmentIds, SDVariable gradient){ - return new SegmentSumBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - - public SDVariable unsortedSegmentMax(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentMax(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentMaxBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentMaxBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentMin(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentMin(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentMinBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentMinBp(sameDiff(), data, segmentIds, gradient, 
numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentMean(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentMean(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentMeanBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentMeanBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentProd(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentProd(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentProdBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentProdBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentSum(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentSum(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentSumBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentSumBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentSqrtN(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentSqrtN(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentSqrtNBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentSqrtNBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - - - - public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides, - int[] rates, boolean isSameMode) { - validateDifferentialFunctionsameDiff(df); - return new Dilation2D(sameDiff(), new SDVariable[]{df, 
weights}, strides, rates, isSameMode, false) - .outputVariable(); - } - - public SDVariable shape(SDVariable df) { - validateDifferentialFunctionsameDiff(df); - return new org.nd4j.linalg.api.ops.impl.shape.Shape(sameDiff(), df, false).outputVariable(); - } - - public SDVariable size(SDVariable in) { - return new Size(sameDiff(), in).outputVariable(); - } - - public SDVariable sizeAt(SDVariable in, int dimension){ - return new SizeAt(sameDiff(), in, dimension).outputVariable(); - } - - public SDVariable rank(SDVariable df) { - return new Rank(sameDiff(), df, false).outputVariable(); - } - - public SDVariable gather(SDVariable df, int[] indices, int axis) { - validateDifferentialFunctionsameDiff(df); - return new Gather(sameDiff(), df, indices, axis, false).outputVariable(); - } - - public SDVariable gather(SDVariable df, SDVariable indices, int axis) { - validateDifferentialFunctionsameDiff(df); - return new Gather(sameDiff(), df, indices, axis, false).outputVariable(); - } - - public SDVariable gatherNd(SDVariable df, SDVariable indices) { - validateDifferentialFunctionsameDiff(df); - return new GatherNd(sameDiff(), df, indices).outputVariable(); - } - - public SDVariable trace(SDVariable in){ - return new Trace(sameDiff(), in).outputVariable(); - } - - public SDVariable cross(SDVariable a, SDVariable b) { - validateDifferentialFunctionsameDiff(a); - return new Cross(sameDiff(), new SDVariable[]{a, b}).outputVariable(); - } - - public SDVariable erf(SDVariable differentialFunction) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new Erf(sameDiff(), differentialFunction, false).outputVariable(); - } - - public SDVariable erfc(SDVariable differentialFunction) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new Erfc(sameDiff(), differentialFunction, false).outputVariable(); - } - - public SDVariable addi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - 
return new AddOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - } - - public List addBp(SDVariable x, SDVariable y, SDVariable grad) { - SDVariable[] ret = new AddBpOp(sameDiff(), x, y, grad).outputVariables(); - return Arrays.asList(ret); - } - - - public SDVariable sub(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SubOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public SDVariable squaredDifference(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SquaredDifferenceOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false) - .outputVariable(); - } - - - public List subBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new SubBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable subi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SubOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - - } - - - public SDVariable mul(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new MulOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public List mulBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new MulBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - public List modBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new ModBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable muli(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new MulOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - } 
- - public SDVariable mod(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ModOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public SDVariable div(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new DivOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public SDVariable truncatedDiv(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new TruncateDivOp(sameDiff(), differentialFunction, i_v, false).outputVariable(); - } - - public List divBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new DivBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable divi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new DivOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - } - - - public SDVariable rsub(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarReverseSubtraction(sameDiff(), differentialFunction, i_v).outputVariable(); - - } - - - public SDVariable rdiv(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarReverseDivision(sameDiff(), differentialFunction, i_v).outputVariable(); - - } - - - public SDVariable rdivi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarReverseDivision(sameDiff(), differentialFunction, i_v, true).outputVariable(); - } - - - public SDVariable rsubi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new 
ScalarReverseSubtraction(sameDiff(), differentialFunction, i_v, true).outputVariable(); - - } - - - public SDVariable add(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarAdd(sameDiff(), differentialFunction, i_v, false).outputVariable(); - } - - - public SDVariable addi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarAdd(sameDiff(), differentialFunction, i_v, true).outputVariable(); - } - - - public SDVariable sub(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarSubtraction(sameDiff(), differentialFunction, i_v).outputVariable(); - } - - - public SDVariable subi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarSubtraction(sameDiff(), differentialFunction, i_v, true).outputVariable(); - - } - - - public SDVariable mul(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarMultiplication(sameDiff(), differentialFunction, i_v).outputVariable(); - - } - - - public SDVariable muli(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarMultiplication(sameDiff(), differentialFunction, i_v, true).outputVariable(); - - } - - - public SDVariable div(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarDivision(sameDiff(), differentialFunction, i_v).outputVariable(); - } - - - public SDVariable divi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarDivision(sameDiff(), differentialFunction, i_v, true).outputVariable(); - } - - - public SDVariable gt(SDVariable 
functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new GreaterThan(sameDiff(), new SDVariable[]{functionInput, functionInput1}, false).outputVariable(); - } - - - public SDVariable lt(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new LessThan(sameDiff(), new SDVariable[]{functionInput, functionInput1}, false).outputVariable(); - } - - - public SDVariable gti(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new GreaterThan(sameDiff(), new SDVariable[]{functionInput, functionInput1}, true).outputVariable(); - } - - - public SDVariable lti(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new LessThan(sameDiff(), new SDVariable[]{functionInput, functionInput1}, true).outputVariable(); - } - - - public SDVariable gte(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new GreaterThanOrEqual(sameDiff(), new SDVariable[]{functionInput, functionInput1}, false).outputVariable(); - } - - - public SDVariable lte(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new LessThanOrEqual(sameDiff(), new SDVariable[]{functionInput, functionInput1}, false).outputVariable(); - } - - - public SDVariable gtei(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - 
validateDifferentialFunctionsameDiff(functionInput1); - return new GreaterThanOrEqual(sameDiff(), new SDVariable[]{functionInput, functionInput1}, true).outputVariable(); - } - - - public SDVariable ltOrEqi(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new LessThanOrEqual(sameDiff(), new SDVariable[]{functionInput, functionInput1}, true).outputVariable(); - } - - - public SDVariable gt(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarGreaterThan(sameDiff(), functionInput, functionInput1, false).outputVariable(); - } - - - public SDVariable lt(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarLessThan(sameDiff(), functionInput, functionInput1, false).outputVariable(); - } - - - public SDVariable gti(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarGreaterThan(sameDiff(), functionInput, functionInput1, true).outputVariable(); - } - - - public SDVariable lti(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarLessThan(sameDiff(), functionInput, functionInput1, true).outputVariable(); - } - - - public SDVariable gte(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarGreaterThanOrEqual(sameDiff(), functionInput, functionInput1, false).outputVariable(); - } - - - public SDVariable lte(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarLessThanOrEqual(sameDiff(), functionInput, functionInput1, false).outputVariable(); - } - - - public SDVariable gtei(SDVariable functionInput, double functionInput1) { - 
validateDifferentialFunctionsameDiff(functionInput); - return new ScalarGreaterThanOrEqual(sameDiff(), functionInput, functionInput1, true).outputVariable(); - } - - - public SDVariable ltei(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarLessThanOrEqual(sameDiff(), functionInput, functionInput1, true).outputVariable(); - } - - - public SDVariable eq(SDVariable iX, double i_y) { - return new ScalarEquals(sameDiff(), iX, i_y).outputVariable(); - } - - public SDVariable eqi(SDVariable iX, double i_y) { - return new ScalarEquals(sameDiff(), iX, i_y, true).outputVariable(); - } - - public SDVariable isNonDecreasing(SDVariable iX) { - validateDifferentialFunctionsameDiff(iX); - return new IsNonDecreasing(sameDiff(), new SDVariable[]{iX}, false).outputVariable(); - } - - public SDVariable isStrictlyIncreasing(SDVariable iX) { - validateDifferentialFunctionsameDiff(iX); - return new IsStrictlyIncreasing(sameDiff(), new SDVariable[]{iX}, false).outputVariable(); - } - - public SDVariable isNumericTensor(SDVariable iX) { - validateDifferentialFunctionsameDiff(iX); - return new IsNumericTensor(sameDiff(), new SDVariable[]{iX}, false).outputVariable(); - } - - public SDVariable slice(SDVariable input, int[] begin, int[] size) { - return new Slice(sameDiff(), input, begin, size).outputVariable(); - } - - public SDVariable slice(SDVariable input, SDVariable begin, SDVariable size) { - return new Slice(sameDiff(), input, begin, size).outputVariable(); - } - - public SDVariable sliceBp(SDVariable input, SDVariable gradient, int[] begin, int[] size) { - return new SliceBp(sameDiff(), input, gradient, begin, size).outputVariable(); - } - - public SDVariable sliceBp(SDVariable input, SDVariable gradient, SDVariable begin, SDVariable size) { - return new SliceBp(sameDiff(), input, gradient, begin, size).outputVariable(); - } - - - public SDVariable stridedSlice(SDVariable input, int[] begin, int[] end, int[] 
strides) { - return new StridedSlice(sameDiff(), input, begin, end, strides).outputVariable(); - } - - public SDVariable stridedSlice(SDVariable input, long[] begin, long[] end, long[] strides) { - return new StridedSlice(sameDiff(), input, begin, end, strides).outputVariable(); - } - - - public SDVariable stridedSlice(SDVariable in, int[] begin, int[] end, int[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return new StridedSlice(sameDiff(), in, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask).outputVariable(); - } - - public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return new StridedSlice(sameDiff(), in, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask).outputVariable(); - } - - public SDVariable stridedSliceBp(SDVariable in, SDVariable grad, long[] begin, long[] end, long[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return new StridedSliceBp(sameDiff(), in, grad, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask).outputVariable(); - } - - public SDVariable stridedSliceBp(SDVariable in, SDVariable grad, SDVariable begin, SDVariable end, SDVariable strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return new StridedSliceBp(sameDiff(), in, grad, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask).outputVariable(); - } - - public SDVariable scatterAdd(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterAdd(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterSub(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterSub(sameDiff(), ref, indices, 
updates).outputVariable(); - } - - public SDVariable scatterMul(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterMul(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterDiv(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterDiv(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterMax(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterMax(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterMin(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterMin(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterUpdate(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterUpdate(sameDiff(), ref, indices, updates).outputVariable(); - } - - - public SDVariable merge(SDVariable... inputs){ - return new Merge(sameDiff(), inputs).outputVariable(); - } - - public SDVariable[] switchOp(SDVariable input, SDVariable predicate){ - return new Switch(sameDiff(), input, predicate).outputVariables(); - } - - - public void validateDifferentialFunctionsameDiff( - SDVariable function) { - - Preconditions.checkState(function != null, "Passed in function was null."); - Preconditions.checkState(function.getSameDiff() == sameDiff); - - Preconditions.checkState(function.getSameDiff() == this.getSameDiff(), - "Function applications must be contained " + - "in same sameDiff. The left %s must match this function %s", function, this); - Preconditions.checkState(sameDiff == this.getSameDiff(), "Function applications must be " + - "contained in same sameDiff. The left %s must match this function ", function, this); - } - - - public void validateDifferentialFunctionGraph(SDVariable function) { - Preconditions.checkState(function.getSameDiff() == this.getSameDiff(), - "Function applications must be contained in same graph. 
The left %s must match this function %s", - function, this); - - } - - - /** - * @param func - * @param input - * @return - */ - public SDVariable doRepeat(SDVariable func, - SDVariable input) { - validateDifferentialFunctionsameDiff(func); - validateDifferentialFunctionsameDiff(input); - - return tile(func, ArrayUtil.toInts(input.getShape())); - } - - public SDVariable enter(SDVariable x, String frameName){ - return new Enter(sameDiff, frameName, x).outputVariable(); - } - - public SDVariable enter(SDVariable x, String frameName, boolean isConstant){ - return new Enter(sameDiff, frameName, x, isConstant).outputVariable(); - } - - public SDVariable exit(SDVariable x){ - return new Exit(sameDiff, x).outputVariable(); - } - - public SDVariable nextIteration(SDVariable x){ - return new NextIteration(sameDiff, x).outputVariable(); - } - - public SDVariable adjustContrast(SDVariable in, SDVariable factor) { - return new AdjustContrast(sameDiff, in, factor).outputVariable(); - } - - public SDVariable adjustContrastV2(SDVariable in, SDVariable factor) { - return new AdjustContrastV2(sameDiff, in, factor).outputVariable(); - } - - public SDVariable bitCast(SDVariable in, SDVariable dataType) { - return new BitCast(sameDiff, in, dataType).outputVariable(); - } - - public SDVariable compareAndBitpack(SDVariable threshold) { - return new CompareAndBitpack(sameDiff, threshold).outputVariable(); - } - - public SDVariable divideNoNan(SDVariable in1, SDVariable in2) { - return new DivideNoNan(sameDiff, in1, in2).outputVariable(); - } - - public SDVariable drawBoundingBoxes(SDVariable boxes, SDVariable colors) { - return new DrawBoundingBoxes(sameDiff, boxes, colors).outputVariable(); - } - - public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max, - int num_bits, boolean narrow) { - return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max,num_bits,narrow).outputVariable(); - } - - public SDVariable betainc( SDVariable a, SDVariable 
b, SDVariable x) { - return new BetaInc(sameDiff, a, b, x).outputVariable(); - } - - public SDVariable[] fusedBatchNorm(SDVariable x, SDVariable scale, SDVariable offset, - SDVariable dataFormat, SDVariable isTraining) { - return new FusedBatchNorm(sameDiff,x,scale,offset,dataFormat,isTraining).outputVariables(); - } - - public SDVariable matrixBandPart(SDVariable input, SDVariable minLower, SDVariable maxUpper) { - return new MatrixBandPart(sameDiff,input,minLower,maxUpper).outputVariable(); - } - - public SDVariable[] maxPoolWithArgmax(SDVariable x, Pooling2DConfig pooling2DConfig) { - return new MaxPoolWithArgmax(sameDiff, x, pooling2DConfig).outputVariables(); - } - - public SDVariable polygamma(SDVariable n, SDVariable x) { - return new Polygamma(sameDiff, n,x).outputVariable(); - } - - public SDVariable roll(SDVariable input, int shift) { - return new Roll(sameDiff, input, shift).outputVariable(); - } - - public SDVariable toggleBits(SDVariable x) { - return new ToggleBits(sameDiff, x).outputVariable(); - } - - - public String toString() { - return "DifferentialFunctionFactory{methodNames=" + methodNames + "}"; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index 3b29e6ccb..5ee0801d0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -253,7 +253,7 @@ public class SDVariable implements Serializable { * @return Negated variable */ public SDVariable neg(){ - return sameDiff.f().neg(this); + return sameDiff.math.neg(this); } /** @@ -579,7 +579,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable add(String varName, double scalar) { - val function = 
sameDiff.f().add(this,scalar); + val function = sameDiff.math.add(this,scalar); return sameDiff.updateVariableNameAndReference(function,varName); } @@ -600,7 +600,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable add(String name, SDVariable x) { - val result = sameDiff.f().add(this, x); + val result = sameDiff.math.add(this, x); return sameDiff.updateVariableNameAndReference(result, name); } @@ -636,7 +636,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable sub(String varName, double scalar) { - val result = sameDiff.f().sub(this, scalar); + val result = sameDiff.math.sub(this, scalar); return sameDiff.updateVariableNameAndReference(result, varName); } @@ -657,7 +657,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable sub(String name, SDVariable x) { - val result = sameDiff.f().sub(this,x); + val result = sameDiff.math.sub(this,x); return sameDiff.updateVariableNameAndReference(result,name); } @@ -693,7 +693,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable div(String varName, double scalar) { - val function = sameDiff.f().div(this,scalar); + val function = sameDiff.math.div(this,scalar); return sameDiff.updateVariableNameAndReference(function,varName); } @@ -714,7 +714,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable div(String name, SDVariable x) { - val result = sameDiff.f().div(this, x); + val result = sameDiff.math.div(this, x); return sameDiff.updateVariableNameAndReference(result, name); } @@ -728,7 +728,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable fdiv(String name, SDVariable x) { - val result = sameDiff.f().floorDiv(this, x); + val result = sameDiff.math.floorDiv(this, x); return 
sameDiff.updateVariableNameAndReference(result, name); } @@ -742,7 +742,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable mod(String name, SDVariable x) { - val result = sameDiff.f().mod(this, x); + val result = sameDiff.math.mod(this, x); return sameDiff.updateVariableNameAndReference(result, name); } @@ -762,7 +762,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable mul(String varName, double scalar) { - val function = sameDiff.f().mul(this, scalar); + val function = sameDiff.math.mul(this, scalar); return sameDiff.updateVariableNameAndReference(function,varName); } @@ -784,7 +784,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable mul(String name, SDVariable x) { - val result = sameDiff.f().mul(this, x); + val result = sameDiff.math.mul(this, x); return sameDiff.updateVariableNameAndReference(result,name); } @@ -820,7 +820,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable pow(String varName, double scalar) { - SDVariable ret = sameDiff.f().pow(this, scalar); + SDVariable ret = sameDiff.math.pow(this, scalar); return sameDiff.updateVariableNameAndReference(ret, varName); } @@ -840,7 +840,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable rsub(String varName, double scalar) { - val function = sameDiff.f().rsub(this,scalar); + val function = sameDiff.math.rsub(this,scalar); return sameDiff.updateVariableNameAndReference(function,varName); } @@ -861,7 +861,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable rsub(String name, SDVariable x) { - val result = sameDiff.f().rsub(this,x); + val result = sameDiff.math.rsub(this,x); return sameDiff.updateVariableNameAndReference(result,name); } @@ -881,7 +881,7 @@ public class SDVariable 
implements Serializable { * @return Output variable */ public SDVariable rdiv(String varName, double scalar) { - val function = sameDiff.f().rdiv(this, scalar); + val function = sameDiff.math.rdiv(this, scalar); return sameDiff.updateVariableNameAndReference(function, varName); } @@ -902,34 +902,11 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable rdiv(String name, SDVariable x) { - val result = sameDiff.f().rdiv(this,x); + val result = sameDiff.math.rdiv(this,x); return sameDiff.updateVariableNameAndReference(result,name); } - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable truncatedDiv(SDVariable sameDiffVariable) { - return truncatedDiv(null,sameDiffVariable); - - } - - - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable truncatedDiv(String varName, SDVariable sameDiffVariable) { - val function = sameDiff.f().truncatedDiv(this, sameDiffVariable); - return sameDiff.updateVariableNameAndReference(function,varName); - - } - /** * See {@link #squaredDifference(String, SDVariable)} */ @@ -943,7 +920,7 @@ public class SDVariable implements Serializable { * @return squared difference between variables */ public SDVariable squaredDifference(String name, SDVariable x) { - val result = sameDiff.f().squaredDifference(this, x); + val result = sameDiff.math().squaredDifference(this, x); return sameDiff.updateVariableNameAndReference(result, name); } @@ -1431,7 +1408,7 @@ public class SDVariable implements Serializable { } public SDVariable permute(SDVariable dimensions){ - return sameDiff.permute(null, this, dimensions); + return sameDiff.permute( this, dimensions); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index c51ac28a1..77d46b889 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -24,7 +24,6 @@ import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.functions.DifferentialFunctionFactory; import org.nd4j.autodiff.listeners.*; import org.nd4j.autodiff.listeners.impl.HistoryListener; import org.nd4j.autodiff.listeners.records.History; @@ -53,8 +52,7 @@ import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray; import org.nd4j.linalg.api.ops.impl.transforms.Assert; @@ -95,7 +93,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Matcher; import java.util.regex.Pattern; -import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; +import static org.nd4j.autodiff.util.SameDiffUtils.stackOutputs; /** * SameDiff is the entrypoint for ND4J's automatic differentiation functionality. 
@@ -141,7 +139,7 @@ public class SameDiff extends SDBaseOps { //////////////////////////////////////// - private DifferentialFunctionFactory functionFactory; +// private DifferentialFunctionFactory functionFactory; // counter for auto-naming variables private int variableId = 0; @@ -296,15 +294,6 @@ public class SameDiff extends SDBaseOps { return this; } - /** - * Returns this samediff instance's {@link DifferentialFunctionFactory} - * - * @return DifferentialFunctionFactory - */ - public DifferentialFunctionFactory f() { - return functionFactory; - } - /** * Set the current SameDiff-wide {@link Listener} instances. * @@ -917,7 +906,6 @@ public class SameDiff extends SDBaseOps { private SameDiff() { super(null); super.sd = this; - functionFactory = new DifferentialFunctionFactory(this); sameDiffFunctionInstances = new LinkedHashMap<>(); fieldVariableResolutionMapping = HashBasedTable.create(); } @@ -5945,7 +5933,7 @@ public class SameDiff extends SDBaseOps { if(switches.containsKey(argument.name())) return switches.get(argument.name())[1]; - SDVariable[] s = f().switchOp(argument, pred); + SDVariable[] s = switchOp(argument, pred); switches.put(argument.name(), s); return s[1]; } @@ -5955,7 +5943,7 @@ public class SameDiff extends SDBaseOps { this.removeArgumentInterceptor(); if(declared.contains(trueOut.name())) { - SDVariable[] s = f().switchOp(trueOut, pred); + SDVariable[] s = switchOp(trueOut, pred); switches.put(trueOut.name(), s); trueOut = s[1]; } @@ -5975,7 +5963,7 @@ public class SameDiff extends SDBaseOps { if(switches.containsKey(argument.name())) return switches.get(argument.name())[0]; - SDVariable[] s = f().switchOp(argument, pred); + SDVariable[] s = switchOp(argument, pred); switches.put(argument.name(), s); return s[0]; } @@ -5985,13 +5973,13 @@ public class SameDiff extends SDBaseOps { this.removeArgumentInterceptor(); if(declared2.contains(falseOut.name())) { - SDVariable[] s = f().switchOp(falseOut, pred); + SDVariable[] s = switchOp(falseOut, 
pred); switches.put(falseOut.name(), s); falseOut = s[0]; } falseScope.close(); - SDVariable output = f().merge(trueOut, falseOut); + SDVariable output = merge(trueOut, falseOut); ifScope.close(); @@ -6042,11 +6030,9 @@ public class SameDiff extends SDBaseOps { SDVariable[] entered = new SDVariable[loopVars.length]; for(int i = 0 ; i < loopVars.length ; i++){ - entered[i] = f().enter(loopVars[i], frameName); + entered[i] = new Enter(this, frameName, loopVars[i]).outputVariable(); } - //counter = SD.f().enter(counter, frameName); - SDVariable[] merged = new SDVariable[loopVars.length]; Merge[] mergeOps = new Merge[loopVars.length]; for(int i = 0 ; i < loopVars.length ; i++){ @@ -6072,19 +6058,16 @@ public class SameDiff extends SDBaseOps { SDVariable[] trueSwitches = new SDVariable[loopVars.length]; SDVariable[] exits = new SDVariable[loopVars.length]; for(int i = 0 ; i < loopVars.length ; i++){ - SDVariable[] s = f().switchOp(merged[i], cond_result); + SDVariable[] s = switchOp(merged[i], cond_result); trueSwitches[i] = s[1]; alreadyEntered.add(s[1].name()); - exits[i] = f().exit(s[0]); + exits[i] = new Exit(this, s[0]).outputVariable(); } - //SDVariable[] cs = SD.f().switchOp(counter, cond_result); - //SDVariable counterExit = SD.f().exit(cs[0]); - //counter = cs[1]; - final Set declared = Sets.newHashSet(this.variableMap().keySet()); final Map done = new HashMap<>(); + final SameDiff sd = this; this.addArgumentInterceptor(new ArgumentInterceptor() { @Override public SDVariable intercept(SDVariable argument) { @@ -6098,7 +6081,7 @@ public class SameDiff extends SDBaseOps { if(done.containsKey(argument.name())) return done.get(argument.name()); - SDVariable e = f().enter(argument, frameName, true); + SDVariable e = new Enter(sd, frameName, argument, true).outputVariable(); done.put(argument.name(), e); return e; } @@ -6112,7 +6095,7 @@ public class SameDiff extends SDBaseOps { //counter.add(1); for(int i = 0 ; i < loopVars.length ; i++){ - SDVariable n = 
f().nextIteration(outs[i]); + SDVariable n = new NextIteration(this, outs[i]).outputVariable(); mergeOps[i].replaceArg(1,n); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java index 5cafa09aa..193229ff9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java @@ -27,7 +27,7 @@ import lombok.Setter; import org.nd4j.autodiff.listeners.Listener; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.util.TrainingUtils; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; @@ -165,7 +165,7 @@ public class OutputConfig { Preconditions.checkState(outputs.size() == 1, "Can only use execSingleBatches() when exactly one output is specified, there were %s", outputs.size()); - return TrainingUtils + return SameDiffUtils .getSingleOutput(sd.outputBatches(data, listeners, outputs.toArray(new String[0])), outputs.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index 3b53e5b65..157ec1fbb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -37,12 +37,11 @@ public class SDBaseOps { /** * Boolean and array reduction operation, optionally along specified dimensions
* - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public SDVariable all(SDVariable x, int... dimensions) { - SDValidation.validateBool("all", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return new org.nd4j.linalg.api.ops.impl.reduce.bool.All(sd,x, dimensions).outputVariable(); } @@ -51,12 +50,11 @@ public class SDBaseOps { * Boolean and array reduction operation, optionally along specified dimensions
* * @param name name May be null. Name for the output variable - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public SDVariable all(String name, SDVariable x, int... dimensions) { - SDValidation.validateBool("all", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.bool.All(sd,x, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); @@ -65,12 +63,11 @@ public class SDBaseOps { /** * Boolean or array reduction operation, optionally along specified dimensions
* - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public SDVariable any(SDVariable x, int... dimensions) { - SDValidation.validateBool("any", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(sd,x, dimensions).outputVariable(); } @@ -79,12 +76,11 @@ public class SDBaseOps { * Boolean or array reduction operation, optionally along specified dimensions
* * @param name name May be null. Name for the output variable - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public SDVariable any(String name, SDVariable x, int... dimensions) { - SDValidation.validateBool("any", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(sd,x, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); @@ -196,6 +192,8 @@ public class SDBaseOps { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param in Input variable (NUMERIC type) * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions @@ -220,6 +218,8 @@ public class SDBaseOps { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable * @param in Input variable (NUMERIC type) @@ -246,6 +246,8 @@ public class SDBaseOps { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param in Input variable (NUMERIC type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) @@ -269,6 +271,8 @@ public class SDBaseOps { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable * @param in Input variable (NUMERIC type) @@ -744,6 +748,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -762,6 +768,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -964,6 +972,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -982,6 +992,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1032,6 +1044,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1050,6 +1064,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1245,6 +1261,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1263,6 +1281,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1313,6 +1333,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1331,6 +1353,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1581,6 +1605,8 @@ public class SDBaseOps { * Element-wise maximum operation: out[i] = max(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) @@ -1596,6 +1622,8 @@ public class SDBaseOps { * Element-wise maximum operation: out[i] = max(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable * @param first First input array (NUMERIC type) @@ -1695,6 +1723,38 @@ public class SDBaseOps { return sd.updateVariableNameAndReference(out, name); } + /** + * The merge operation is a control operation that forwards the either of the inputs to the output, when
+ * the first of them becomes available. If both are available, the output is undefined (either input could
+ * be forwarded to the output)
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output (NUMERIC type) + */ + public SDVariable merge(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("merge", "x", x); + SDValidation.validateNumerical("merge", "y", y); + return new org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge(sd,x, y).outputVariable(); + } + + /** + * The merge operation is a control operation that forwards the either of the inputs to the output, when
+ * the first of them becomes available. If both are available, the output is undefined (either input could
+ * be forwarded to the output)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output (NUMERIC type) + */ + public SDVariable merge(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("merge", "x", x); + SDValidation.validateNumerical("merge", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
* @@ -1785,6 +1845,8 @@ public class SDBaseOps { * Element-wise minimum operation: out[i] = min(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) @@ -1800,6 +1862,8 @@ public class SDBaseOps { * Element-wise minimum operation: out[i] = min(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable * @param first First input array (NUMERIC type) @@ -1916,6 +1980,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1934,6 +2000,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -4176,6 +4244,32 @@ public class SDBaseOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Switch operation
+ * Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output<br>
+ * + * @param x Input variable (NDARRAY type) + * @param predicate Predictate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output (BOOL type) + */ + public SDVariable[] switchOp(SDVariable x, SDVariable predicate) { + SDValidation.validateBool("switchOp", "predicate", predicate); + return new org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch(sd,x, predicate).outputVariables(); + } + + /** + * Switch operation
+ * Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output<br>
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param x Input variable (NDARRAY type) + * @param predicate Predictate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output (BOOL type) + */ + public SDVariable[] switchOp(String[] names, SDVariable x, SDVariable predicate) { + SDValidation.validateBool("switchOp", "predicate", predicate); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch(sd,x, predicate).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + /** * //TODO: Ops must be documented.
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index 70940863a..ef030e952 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -24,6 +24,7 @@ import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.ImageResizeMethod; public class SDImage extends SDOps { public SDImage(SameDiff sameDiff) { @@ -254,6 +255,98 @@ public class SDImage extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Resize images to size using the specified method.
+ * + * @param input 4D image [NCHW] (NUMERIC type) + * @param size new height and width (INT type) + * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. + * @param antialis Whether to use an anti-aliasing filter when downsampling an image + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public SDVariable imageResize(SDVariable input, SDVariable size, boolean preserveAspectRatio, + boolean antialis, ImageResizeMethod ImageResizeMethod) { + SDValidation.validateNumerical("imageResize", "input", input); + SDValidation.validateInteger("imageResize", "size", size); + return new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, preserveAspectRatio, antialis, ImageResizeMethod).outputVariable(); + } + + /** + * Resize images to size using the specified method.
+ * + * @param name name May be null. Name for the output variable + * @param input 4D image [NCHW] (NUMERIC type) + * @param size new height and width (INT type) + * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. + * @param antialis Whether to use an anti-aliasing filter when downsampling an image + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. 
+ * @return output Output image (NUMERIC type) + */ + public SDVariable imageResize(String name, SDVariable input, SDVariable size, + boolean preserveAspectRatio, boolean antialis, ImageResizeMethod ImageResizeMethod) { + SDValidation.validateNumerical("imageResize", "input", input); + SDValidation.validateInteger("imageResize", "size", size); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, preserveAspectRatio, antialis, ImageResizeMethod).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Resize images to size using the specified method.
+ * + * @param input 4D image [NCHW] (NUMERIC type) + * @param size new height and width (INT type) + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public SDVariable imageResize(SDVariable input, SDVariable size, + ImageResizeMethod ImageResizeMethod) { + SDValidation.validateNumerical("imageResize", "input", input); + SDValidation.validateInteger("imageResize", "size", size); + return new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, false, false, ImageResizeMethod).outputVariable(); + } + + /** + * Resize images to size using the specified method.
+ * + * @param name name May be null. Name for the output variable + * @param input 4D image [NCHW] (NUMERIC type) + * @param size new height and width (INT type) + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public SDVariable imageResize(String name, SDVariable input, SDVariable size, + ImageResizeMethod ImageResizeMethod) { + SDValidation.validateNumerical("imageResize", "input", input); + SDValidation.validateInteger("imageResize", "size", size); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, false, false, ImageResizeMethod).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Greedily selects a subset of bounding boxes in descending order of score
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index ead137a57..66d47f905 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -24,6 +24,7 @@ import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PartitionMode; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.indexing.conditions.Condition; @@ -32,6 +33,67 @@ public class SDMath extends SDOps { super(sameDiff); } + /** + * Clips tensor values to a maximum average L2-norm.
+ * + * @param x Input variable (NUMERIC type) + * @param clipValue Value for clipping + * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable clipByAvgNorm(SDVariable x, double clipValue, int... dimensions) { + SDValidation.validateNumerical("ClipByAvgNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd,x, clipValue, dimensions).outputVariable(); + } + + /** + * Clips tensor values to a maximum average L2-norm.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param clipValue Value for clipping + * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable clipByAvgNorm(String name, SDVariable x, double clipValue, int... dimensions) { + SDValidation.validateNumerical("ClipByAvgNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd,x, clipValue, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Looks up ids in a list of embedding tensors.
+ * + * @param x Input tensor (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (NUMERIC type) + * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' + * @return output Shifted output (NUMERIC type) + */ + public SDVariable embeddingLookup(SDVariable x, SDVariable indices, PartitionMode PartitionMode) { + SDValidation.validateNumerical("EmbeddingLookup", "x", x); + SDValidation.validateNumerical("EmbeddingLookup", "indices", indices); + return new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); + } + + /** + * Looks up ids in a list of embedding tensors.
+ * + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (NUMERIC type) + * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' + * @return output Shifted output (NUMERIC type) + */ + public SDVariable embeddingLookup(String name, SDVariable x, SDVariable indices, + PartitionMode PartitionMode) { + SDValidation.validateNumerical("EmbeddingLookup", "x", x); + SDValidation.validateNumerical("EmbeddingLookup", "indices", indices); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Elementwise absolute value operation: out = abs(x)
* @@ -104,6 +166,68 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise addition operation, out = x + y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable add(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("add", "x", x); + SDValidation.validateNumerical("add", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise addition operation, out = x + y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable add(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("add", "x", x); + SDValidation.validateNumerical("add", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar add operation, out = in + scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable add(SDVariable x, double value) { + SDValidation.validateNumerical("add", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(sd,x, value).outputVariable(); + } + + /** + * Scalar add operation, out = in + scalar
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable add(String name, SDVariable x, double value) { + SDValidation.validateNumerical("add", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
* @@ -1064,6 +1188,68 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise division operation, out = x / y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable div(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("div", "x", x); + SDValidation.validateNumerical("div", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise division operation, out = x / y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable div(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("div", "x", x); + SDValidation.validateNumerical("div", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar division operation, out = in / scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable div(SDVariable x, double value) { + SDValidation.validateNumerical("div", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(sd,x, value).outputVariable(); + } + + /** + * Scalar division operation, out = in / scalar
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable div(String name, SDVariable x, double value) { + SDValidation.validateNumerical("div", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Entropy reduction: -sum(x * log(x))
* @@ -1490,6 +1676,104 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise floor division operation, out = floor(x / y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorDiv(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("floorDiv", "x", x); + SDValidation.validateNumerical("floorDiv", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise floor division operation, out = floor(x / y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorDiv(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("floorDiv", "x", x); + SDValidation.validateNumerical("floorDiv", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Pairwise Modulus division operation
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorMod(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("floorMod", "x", x); + SDValidation.validateNumerical("floorMod", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise Modulus division operation
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorMod(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("floorMod", "x", x); + SDValidation.validateNumerical("floorMod", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar floor modulus operation
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorMod(SDVariable x, double value) { + SDValidation.validateNumerical("floorMod", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd,x, value).outputVariable(); + } + + /** + * Scalar floor modulus operation
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorMod(String name, SDVariable x, double value) { + SDValidation.validateNumerical("floorMod", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Hamming distance reduction operation. The output contains the cosine distance for each
* tensor/subset along the specified dimensions:
@@ -2198,6 +2482,42 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise max operation, out = max(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public SDVariable max(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("max", "x", x); + SDValidation.validateNumerical("max", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd,x, y).outputVariable(); + } + + /** + * Pairwise max operation, out = max(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public SDVariable max(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("max", "x", x); + SDValidation.validateNumerical("max", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
* out = sum_i in[i]
@@ -2308,6 +2628,78 @@ public class SDMath extends SDOps { return sd.updateVariableNamesAndReferences(out, names); } + /** + * Pairwise min operation, out = min(x, y)<br>
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public SDVariable min(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("min", "x", x); + SDValidation.validateNumerical("min", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd,x, y).outputVariable(); + } + + /** + * Pairwise min operation, out = min(x, y)<br>
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public SDVariable min(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("min", "x", x); + SDValidation.validateNumerical("min", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Pairwise modulus (remainder) operation, out = x % y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mod(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mod", "x", x); + SDValidation.validateNumerical("mod", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise modulus (remainder) operation, out = x % y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mod(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mod", "x", x); + SDValidation.validateNumerical("mod", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Calculate the mean and (population) variance for the input variable, for the specified axis
* @@ -2334,6 +2726,68 @@ public class SDMath extends SDOps { return sd.updateVariableNamesAndReferences(out, names); } + /** + * Pairwise multiplication operation, out = x * y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mul(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mul", "x", x); + SDValidation.validateNumerical("mul", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise multiplication operation, out = x * y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mul(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mul", "x", x); + SDValidation.validateNumerical("mul", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar multiplication operation, out = in * scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable mul(SDVariable x, double value) { + SDValidation.validateNumerical("mul", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(sd,x, value).outputVariable(); + } + + /** + * Scalar multiplication operation, out = in * scalar
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable mul(String name, SDVariable x, double value) { + SDValidation.validateNumerical("mul", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Elementwise negative operation: out = -x
* @@ -2480,6 +2934,96 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Rational Tanh Approximation elementwise function, as described in the paper:
+ * Compact Convolutional Neural Network Cascade for Face Detection
+ * This is a faster Tanh approximation
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rationalTanh(SDVariable x) { + SDValidation.validateNumerical("rationalTanh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(sd,x).outputVariable(); + } + + /** + * Rational Tanh Approximation elementwise function, as described in the paper:
+ * Compact Convolutional Neural Network Cascade for Face Detection
+ * This is a faster Tanh approximation
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rationalTanh(String name, SDVariable x) { + SDValidation.validateNumerical("rationalTanh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Pairwise reverse division operation, out = y / x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rdiv(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("rdiv", "x", x); + SDValidation.validateNumerical("rdiv", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise reverse division operation, out = y / x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rdiv(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("rdiv", "x", x); + SDValidation.validateNumerical("rdiv", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar reverse division operation, out = scalar / in
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable rdiv(SDVariable x, double value) { + SDValidation.validateNumerical("rdiv", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(sd,x, value).outputVariable(); + } + + /** + * Scalar reverse division operation, out = scalar / in
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable rdiv(String name, SDVariable x, double value) { + SDValidation.validateNumerical("rdiv", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
* @@ -2504,6 +3048,30 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Rectified tanh operation: max(0, tanh(in))
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rectifiedTanh(SDVariable x) { + SDValidation.validateNumerical("rectifiedTanh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(sd,x).outputVariable(); + } + + /** + * Rectified tanh operation: max(0, tanh(in))
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rectifiedTanh(String name, SDVariable x) { + SDValidation.validateNumerical("rectifiedTanh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Element-wise round function: out = round(x).
* Rounds (up or down depending on value) to the nearest integer value.
@@ -2554,6 +3122,68 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise reverse subtraction operation, out = y - x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsub(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("rsub", "x", x); + SDValidation.validateNumerical("rsub", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise reverse subtraction operation, out = y - x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsub(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("rsub", "x", x); + SDValidation.validateNumerical("rsub", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar reverse subtraction operation, out = scalar - in
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsub(SDVariable x, double value) { + SDValidation.validateNumerical("rsub", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(sd,x, value).outputVariable(); + } + + /** + * Scalar reverse subtraction operation, out = scalar - in
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsub(String name, SDVariable x, double value) { + SDValidation.validateNumerical("rsub", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Set the diagonal value to the specified values
* If input is
@@ -2752,6 +3382,42 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise squared difference operation.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable squaredDifference(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("squaredDifference", "x", x); + SDValidation.validateNumerical("squaredDifference", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise squared difference operation.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable squaredDifference(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("squaredDifference", "x", x); + SDValidation.validateNumerical("squaredDifference", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Standardize input variable along given axis
*


@@ -2832,6 +3498,68 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise subtraction operation, out = x - y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sub(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("sub", "x", x); + SDValidation.validateNumerical("sub", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise subtraction operation, out = x - y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sub(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("sub", "x", x); + SDValidation.validateNumerical("sub", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar subtraction operation, out = in - scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable sub(SDVariable x, double value) { + SDValidation.validateNumerical("sub", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(sd,x, value).outputVariable(); + } + + /** + * Scalar subtraction operation, out = in - scalar
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable sub(String name, SDVariable x, double value) { + SDValidation.validateNumerical("sub", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Elementwise tangent operation: out = tan(x)
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 7b18c3614..15d70aac5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -24,12 +24,37 @@ import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; public class SDNN extends SDOps { public SDNN(SameDiff sameDiff) { super(sameDiff); } + /** + * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cReLU(SDVariable x) { + SDValidation.validateNumerical("CReLU", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd,x).outputVariable(); + } + + /** + * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cReLU(String name, SDVariable x) { + SDValidation.validateNumerical("CReLU", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Neural network batch normalization operation.
* For details, see https://arxiv.org/abs/1502.03167
@@ -698,6 +723,39 @@ public class SDNN extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Padding operation
+ * + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param PadMode Padding format + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public SDVariable pad(SDVariable input, SDVariable padding, PadMode PadMode, double constant) { + SDValidation.validateNumerical("pad", "input", input); + SDValidation.validateNumerical("pad", "padding", padding); + return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode, constant).outputVariable(); + } + + /** + * Padding operation
+ * + * @param name name May be null. Name for the output variable + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param PadMode Padding format + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public SDVariable pad(String name, SDVariable input, SDVariable padding, PadMode PadMode, + double constant) { + SDValidation.validateNumerical("pad", "input", input); + SDValidation.validateNumerical("pad", "padding", padding); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode, constant).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Padding operation
* @@ -709,7 +767,7 @@ public class SDNN extends SDOps { public SDVariable pad(SDVariable input, SDVariable padding, double constant) { SDValidation.validateNumerical("pad", "input", input); SDValidation.validateNumerical("pad", "padding", padding); - return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, constant).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode.CONSTANT, constant).outputVariable(); } /** @@ -724,7 +782,35 @@ public class SDNN extends SDOps { public SDVariable pad(String name, SDVariable input, SDVariable padding, double constant) { SDValidation.validateNumerical("pad", "input", input); SDValidation.validateNumerical("pad", "padding", padding); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, constant).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode.CONSTANT, constant).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * GELU activation function - Gaussian Error Linear Units
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This method uses the precise method
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable preciseGelu(SDVariable x) { + SDValidation.validateNumerical("preciseGelu", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(sd,x).outputVariable(); + } + + /** + * GELU activation function - Gaussian Error Linear Units
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This method uses the precise method
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable preciseGelu(String name, SDVariable x) { + SDValidation.validateNumerical("preciseGelu", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(sd,x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java index 88792bddb..fc406caa6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,20 +14,21 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import org.nd4j.autodiff.functions.DifferentialFunctionFactory; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.indexing.conditions.Condition; -/** - * Abstract class for defining categories of operations - such as {@link SDMath} that is available via {@code SameDiff.math()} - * - * @author Alex Black - */ -public abstract class SDOps { - - protected final SameDiff sd; +public class SDOps { + protected SameDiff sd; public SDOps() { sd = null; @@ -37,11 +38,5 @@ public abstract class SDOps { this.sd = sameDiff; } - protected DifferentialFunctionFactory f() { - return sd.f(); - } - protected SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) { - return sd.updateVariableNameAndReference(varToUpdate, newVarName); - } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/OpPredicate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/OpPredicate.java index 97a47d257..2b91300eb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/OpPredicate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/OpPredicate.java @@ -21,7 +21,6 @@ import org.nd4j.autodiff.samediff.SameDiff; /** * An 
OpPredicate defines whether an operation ({@link DifferentialFunction}) matches or not.
- * Used mainly in {@link org.nd4j.autodiff.functions.DifferentialFunctionFactory} * * @author Alex Black */ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java new file mode 100644 index 000000000..a3f9ddea2 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.util; + +import java.util.*; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; +import org.nd4j.linalg.api.ops.impl.shape.ReductionShape; +import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; +import org.nd4j.linalg.exception.ND4JException; +import org.nd4j.linalg.factory.Nd4j; + +/** + * Utilities for SameDiff training and inference + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class SameDiffUtils { + + /** + * Stack batch outputs, like an output from {@link org.nd4j.autodiff.samediff.SameDiff#output(MultiDataSetIterator, String...)} + */ + public static Map stackOutputs(List> outputs){ + Map> outs = new HashMap<>(); + for(Map batch : outputs){ + for(String k : batch.keySet()){ + if(!outs.containsKey(k)) + outs.put(k, new ArrayList()); + outs.get(k).add(batch.get(k)); + } + } + + Map ret = new HashMap<>(); + for(String k : outs.keySet()){ + try { + ret.put(k, Nd4j.concat(0, outs.get(k).toArray(new INDArray[0]))); + } catch(Exception e){ + throw new ND4JException("Error concatenating batch outputs", e); + } + } + return ret; + } + + /** + * Get a list of batch outputs for a single variable from a list of batch outputs for all variables + */ + public static List getSingleOutput(List> outputs, String output){ + List batches = new ArrayList<>(); + for(Map batch : outputs) + batches.add(batch.get(output)); + + return batches; + } + + public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map externalGradients, SDVariable... 
inputs) { + Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" + + " be specified when using external errors: got %s", inputs); + ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients); + fn.outputVariable(); + return fn; + } + + public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, SDVariable[] inputs) { + return externalErrors(sameDiff, null, inputs); + } + + + + /** + * Add 1s as required to the array make an array possible to be broadcast with the original (pre-reduce) array. + *

+ * Example: if doing [a,b,c].sum(1), result is [a,c]. To 'undo' this in a way that can be auto-broadcast, + * we want to expand as required - i.e., [a,c] -> [a,1,c] which can be auto-broadcast with the original [a,b,c]. + * This is typically only used with reduction operations backprop. + * + * @param origRank Rank of the original array, before the reduction was executed + * @param reduceDims Dimensions that the original array was reduced from + * @param toExpand Array to add 1s to the shape to (such that it can be + * @return Reshaped array. + */ + public static SDVariable reductionBroadcastableWithOrigShape(int origRank, int[] reduceDims, SDVariable toExpand) { + if (Shape.isWholeArray(origRank, reduceDims)) { + //Output is [1,1] which is already broadcastable + return toExpand; + } else if (origRank == 2 && reduceDims.length == 1) { + //In this case: [a,b] -> [1,b] or [a,b] -> [a,1] + //both are already broadcastable + return toExpand; + } else { + //Example: [a,b,c].sum(1) -> [a,c]... 
want [a,1,c] + for (int d : reduceDims) { + toExpand = toExpand.getSameDiff().expandDims(toExpand, d); + } + return toExpand; + } + } + + public static SDVariable reductionBroadcastableWithOrigShape(SDVariable origInput, SDVariable axis, SDVariable toExpand) { + SDVariable shape = origInput.shape(); + SDVariable reduceShape = reductionShape(shape, axis, true); + SDVariable reshaped = toExpand.reshape(reduceShape); + return reshaped; + } + + public static SDVariable reductionShape(SDVariable shape, SDVariable axis, boolean keepDim){ + return new ReductionShape(shape.getSameDiff(), shape, axis, keepDim).outputVariable(); + } + + public static void validateDifferentialFunctionSameDiff(SameDiff sameDiff, SDVariable function, DifferentialFunction op) { + + Preconditions.checkState(function != null, "Passed in function was null."); + Preconditions.checkState(function.getSameDiff() == sameDiff); + + Preconditions.checkState(function.getSameDiff() == sameDiff, + "Function applications must be contained " + + "in same sameDiff. The left %s must match this function %s", function, op); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java deleted file mode 100644 index 289bd15be..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.nd4j.autodiff.util; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import lombok.AccessLevel; -import lombok.NoArgsConstructor; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -import org.nd4j.linalg.exception.ND4JException; -import org.nd4j.linalg.factory.Nd4j; - -/** - * Utilities for SameDiff training and inference - */ -@NoArgsConstructor(access = AccessLevel.PRIVATE) -public class TrainingUtils { - - /** - * Stack batch outputs, like an output from {@link org.nd4j.autodiff.samediff.SameDiff#output(MultiDataSetIterator, String...)} - */ - public static Map stackOutputs(List> outputs){ - Map> outs = new HashMap<>(); - for(Map batch : outputs){ - for(String k : batch.keySet()){ - if(!outs.containsKey(k)) - outs.put(k, new ArrayList()); - outs.get(k).add(batch.get(k)); - } - } - - Map ret = new HashMap<>(); - for(String k : outs.keySet()){ - try { - ret.put(k, Nd4j.concat(0, outs.get(k).toArray(new INDArray[0]))); - } catch(Exception e){ - throw new ND4JException("Error concatenating batch outputs", e); - } - } - return ret; - } - - /** - * Get a list of batch outputs for a single variable from a list of batch outputs for all variables - */ - public static List getSingleOutput(List> outputs, String output){ - List batches = new ArrayList<>(); - for(Map batch : outputs) - batches.add(batch.get(output)); - - return batches; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java new file mode 100644 index 000000000..42043dad7 --- /dev/null +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java @@ -0,0 +1,43 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. 
*/ +public enum ImageResizeMethod { + ResizeBilinear, + + ResizeBicubic, + + ResizeNearest, + + ResizeGaussian, + + ResizeLanczos5, + + ResizeMitchelcubic, + + ResizeArea +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PadMode.java similarity index 75% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PadMode.java index 533209ed7..4802ebdaf 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/EpochStepCounter.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PadMode.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2020 Konduit K.K. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,8 +14,16 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.rl4j.learning; +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== -public interface EpochStepCounter { - int getCurrentEpochStep(); +package org.nd4j.enums; + +/** + * Padding format */ +public enum PadMode { + CONSTANT, + + REFLECT, + + SYMMETRIC } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PartitionMode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PartitionMode.java new file mode 100644 index 000000000..565ffd792 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PartitionMode.java @@ -0,0 +1,27 @@ 
+/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * partition_mode == 0 - i.e. 'mod' , 1 - 'div' */ +public enum PartitionMode { + MOD, + + DIV +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/WeightsFormat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/WeightsFormat.java new file mode 100644 index 000000000..865d23282 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/WeightsFormat.java @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * Weights format: [kH, kW, iC, oC] or [oC, iC, kH, kW], or [oC, kH, kW, iC] */ +public enum WeightsFormat { + YXIO, + + OIYX, + + OYXI +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 043a16e87..6af2d462a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -93,6 +93,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.grid.FreeGridOp.class, org.nd4j.linalg.api.ops.impl.image.CropAndResize.class, org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches.class, + org.nd4j.linalg.api.ops.impl.image.ImageResize.class, org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression.class, org.nd4j.linalg.api.ops.impl.image.NonMaxSuppressionV3.class, org.nd4j.linalg.api.ops.impl.image.ResizeBilinear.class, @@ -127,6 +128,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace.class, org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class, + org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2DBp.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp.class, org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization.class, 
@@ -146,6 +148,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer.class, + org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayerBp.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell.class, @@ -322,9 +325,12 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.shape.Unstack.class, org.nd4j.linalg.api.ops.impl.shape.ZerosLike.class, org.nd4j.linalg.api.ops.impl.shape.bp.ConcatBp.class, + org.nd4j.linalg.api.ops.impl.shape.bp.MergeMaxBp.class, + org.nd4j.linalg.api.ops.impl.shape.bp.MergeAvgBp.class, org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp.class, org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp.class, org.nd4j.linalg.api.ops.impl.shape.bp.TileBp.class, + org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup.class, org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray.class, org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayConcat.class, org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayGather.class, @@ -354,6 +360,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf.class, org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN.class, org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform.class, + org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm.class, org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm.class, org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNormBp.class, org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue.class, @@ -365,6 +372,8 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND.class, 
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.CReluBp.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class, @@ -406,6 +415,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse.class, org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Max.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Min.class, org.nd4j.linalg.api.ops.impl.transforms.custom.MirrorPad.class, org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention.class, @@ -492,11 +502,12 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.ModBpOp.class, + org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MergeAddBp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SquaredDifferenceBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Not.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or.class, diff --git
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 07a2bf9b8..46daa869b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -2045,8 +2045,18 @@ public abstract class BaseNDArray implements INDArray, Iterable { throw new ND4JIllegalArgumentException("Indices must be a vector or matrix."); } - if(indices.rows() == rank()) { - INDArray ret = Nd4j.create(indices.dataType(), indices.columns()); + if (rank() == 1) { + Preconditions.checkArgument(indices.rank() <= 1, "For 1D vector indices must be either scalar or vector as well"); + val ret = Nd4j.createUninitialized(this.dataType(), indices.length()); + for (int e = 0; e < indices.length(); e++) { + val idx = indices.getLong(e); + val value = getDouble(idx); + ret.putScalar(e, value); + } + + return ret; + } else if(indices.rows() == rank()) { + INDArray ret = Nd4j.create(this.dataType(), indices.columns()); for(int i = 0; i < indices.columns(); i++) { int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); @@ -5391,6 +5401,10 @@ public abstract class BaseNDArray implements INDArray, Iterable { return sorted.getDouble(sorted.length() - 1); double pos = (quantile.doubleValue() / 100.0) * (double) (sorted.length() + 1); + if (pos < 1) + return sorted.getDouble(0); + else if (pos >= sorted.length()) + return sorted.getDouble(sorted.length() - 1); double fposition = FastMath.floor(pos); int position = (int)fposition; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java index afdc11aa4..e2ca9329e 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java @@ -22,6 +22,7 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -56,8 +57,6 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp int[] dimension) { super(sameDiff, inPlace, new Object[]{i_v2}); if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); this.sameDiff = sameDiff; this.inPlace = inPlace; this.dimension = dimension; @@ -80,9 +79,6 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp super(sameDiff, extraArgs); this.dimension = dimension; if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); - this.sameDiff = sameDiff; sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this); @@ -107,7 +103,7 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp super(sameDiff, inPlace, extraArgs); this.dimension = dimension; if (i_v != null) { - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); sameDiff.addArgsFor(new SDVariable[]{i_v},this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java index 291b66d2b..56201560a 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java @@ -22,6 +22,7 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -57,8 +58,6 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp { int[] dimension) { super(sameDiff, inPlace, new Object[]{i_v2}); if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); this.sameDiff = sameDiff; this.inPlace = inPlace; this.dimension = dimension; @@ -80,8 +79,8 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp { super(sameDiff, extraArgs); this.dimension = dimension; if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v1, this); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.sameDiff = sameDiff; sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this); @@ -107,7 +106,7 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp { super(sameDiff, inPlace, extraArgs); this.dimension = dimension; if (i_v != null) { - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); sameDiff.addArgsFor(new SDVariable[]{i_v},this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java index 8b598242c..502874cc1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java @@ -20,6 +20,7 @@ import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -46,7 +47,6 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum super(sameDiff,null); if (i_v != null) { this.dimensions = dimensions; - f().validateDifferentialFunctionsameDiff(i_v); sameDiff.addArgsFor(new SDVariable[]{i_v},this); this.xVertexId = i_v.name(); @@ -65,8 +65,8 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum super(sameDiff,null); if (i_v != null) { this.dimensions = dimensions; - f().validateDifferentialFunctionsameDiff(i_v); - f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.xVertexId = i_v.name(); this.yVertexId = i_v2.name(); sameDiff.addArgsFor(new SDVariable[]{i_v,i_v2},this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java index 30f3d0bf5..5c45ecf50 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java @@ -28,6 +28,7 @@ import 
org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -348,7 +349,9 @@ public abstract class BaseOp extends DifferentialFunction implements Op { if (dimensions == null || dimensions.length == 0) dimensions = new int[]{Integer.MAX_VALUE}; - this.dimensionz = Shape.ndArrayDimFromInt(dimensions); + try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + this.dimensionz = Shape.ndArrayDimFromInt(dimensions); + } } public INDArray dimensions() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java index 9e5b8f67b..66c3e95d3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java @@ -23,6 +23,7 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -59,7 +60,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { dimensions = new int[] {Integer.MAX_VALUE}; this.dimensions = dimensions; - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); this.keepDims = keepDims; this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); @@ -83,8 +84,8 @@ public 
abstract class BaseReduceOp extends BaseOp implements ReduceOp { this.xVertexId = i_v.name(); this.yVertexId = i_v2.name(); - f().validateDifferentialFunctionsameDiff(i_v); - f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.keepDims = keepDims; sameDiff.addArgsFor(new String[]{xVertexId,yVertexId},this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java index 8cb7e50b4..858b6a81c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -73,7 +74,7 @@ public abstract class BaseScalarBoolOp extends BaseOp implements ScalarOp { if (i_v != null) { this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); } else { throw new IllegalArgumentException("Input not null variable."); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java index 254069929..66a204602 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java @@ -21,6 +21,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -94,7 +95,7 @@ public abstract class BaseScalarOp extends BaseOp implements ScalarOp { this.scalarValue = Nd4j.scalar(i_v.dataType(), scalar); this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java index 4e498edeb..7f8e0487e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -52,8 +53,8 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { boolean inPlace) { super(sameDiff,inPlace,new Object[] {i_v2}); if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - 
f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v1, this); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.sameDiff = sameDiff; this.inPlace = inPlace; this.xVertexId = i_v1.name(); @@ -77,8 +78,8 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { super(sameDiff,extraArgs); if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v1, this); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.sameDiff = sameDiff; this.xVertexId = i_v1.name(); this.yVertexId = i_v2.name(); @@ -104,7 +105,7 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { super(sameDiff,inPlace,extraArgs); if (i_v != null) { - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); this.xVertexId = i_v.name(); sameDiff.addArgsFor(new SDVariable[]{i_v},this); } else { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java index b4cf2d05a..692571df9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java @@ -20,6 +20,7 @@ import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; @@ -38,6 +39,10 @@ public class NoOp extends DynamicCustomOp { super("noop", sd, 
new SDVariable[]{in}); } + public NoOp(INDArray in) { + addInputArgument(in); + } + @Override public List doDiff(List f1) { return Collections.singletonList(f1.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java index f1b7f7398..cea3b388a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java @@ -85,7 +85,7 @@ public class BiasAdd extends DynamicCustomOp { @Override public List doDiff(List gradient){ - return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0), nchw)); + return new BiasAddGrad(sameDiff, arg(0), arg(1), gradient.get(0), nchw).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java index 1bb451bf1..c1aff757d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java @@ -25,6 +25,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.tensorflow.framework.AttrValue; @@ -41,6 +42,10 @@ public abstract class 
BaseCompatOp extends DynamicCustomOp { super(null, sameDiff, inputs); } + public BaseCompatOp(INDArray... inputs) { + addInputArgument(inputs); + } + public BaseCompatOp(){ } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java index 993a5b11e..9adbd78df 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.Op.Type; import org.tensorflow.framework.AttrValue; @@ -36,6 +37,10 @@ public class Merge extends BaseCompatOp { super(sd, inputs); } + public Merge(INDArray... 
inputs) { + super(inputs); + } + public Merge(){ } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/StopGradient.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/StopGradient.java index c7d90e4c8..f302c752a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/StopGradient.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/StopGradient.java @@ -50,6 +50,6 @@ public class StopGradient extends BaseDynamicTransformOp { @Override public List doDiff(List gradients){ - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java index 1b6c2f5e2..a7804f39f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.Op.Type; import org.tensorflow.framework.AttrValue; @@ -44,6 +45,10 @@ public class Switch extends BaseCompatOp { this.predicate = predicate; } + public Switch(INDArray input, INDArray predicate) { + addInputArgument(input, predicate); + } + public Switch(){ } /** diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ImageResize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ImageResize.java new file mode 100644 index 000000000..4bdca62a6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ImageResize.java @@ -0,0 +1,67 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.image; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.enums.ImageResizeMethod; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class ImageResize extends DynamicCustomOp { + + + + @Override + public String opName() { + return "image_resize"; + } + + + public ImageResize(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable size, boolean preserveAspectRatio, boolean antialias, ImageResizeMethod method) { + super("image_resize", sameDiff, new SDVariable[]{in, size}); + addBArgument(preserveAspectRatio, antialias); + addIArgument(method.ordinal()); + } + + public ImageResize(@NonNull INDArray in, @NonNull INDArray size, boolean preserveAspectRatio, boolean antialias, ImageResizeMethod method) { + super("image_resize", new INDArray[]{in, size}, null); + Preconditions.checkArgument(in.rank()==4,"expected input image in NHWC format i.e [batchSize, height, width, channels]"); + addBArgument(preserveAspectRatio, antialias); + addIArgument(method.ordinal()); + } + + + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + +} \ No newline at end of file diff --git
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java index b9c3962aa..181321d4f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java @@ -74,6 +74,6 @@ public class IAMax extends BaseIndexAccumulation { @Override public List doDiff(List grad){ - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java index 63f40ee6c..760fca314 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java @@ -76,6 +76,6 @@ public class IAMin extends BaseIndexAccumulation { @Override public List doDiff(List grad){ - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java index 8b7872b49..127239bc7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java @@ -45,15 +45,14 @@ public class IMax extends 
BaseIndexAccumulation { super(x, z, dimensions); } - public IMax(INDArray x, boolean keepDims, int... dimensions) { - super(x, keepDims, dimensions); - - } - public IMax(INDArray x, int... dimensions) { super(x, null, dimensions); } + public IMax(INDArray x, boolean keepDims, int... dimensions) { + super(x, null, dimensions); + this.keepDims = keepDims; + } @Override public int opNum() { @@ -83,6 +82,6 @@ public class IMax extends BaseIndexAccumulation { @Override public List doDiff(List f1) { //Not differentiable, but (assuming no ties) output does not change for a given infinitesimal change in the input - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java index 06b3deb1c..a459e8c9c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java @@ -53,6 +53,7 @@ public class IMin extends BaseIndexAccumulation { } + @Override public int opNum() { return 1; @@ -77,6 +78,6 @@ public class IMin extends BaseIndexAccumulation { @Override public List doDiff(List f1) { //Not differentiable, but (assuming no ties) output does not change for a given infinitesimal change in the input - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java index 9635c6f36..5417b14cd 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java @@ -60,7 +60,6 @@ public class Conv2D extends DynamicCustomOp { SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig); } - @Builder(builderMethodName = "sameDiffBuilder") public Conv2D(SameDiff sameDiff, SDVariable[] inputFunctions, @@ -71,7 +70,7 @@ public class Conv2D extends DynamicCustomOp { } public Conv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){ - super(inputs, outputs); + super(inputs, outputs); initConfig(config); } @@ -103,7 +102,8 @@ public class Conv2D extends DynamicCustomOp { config.getDH(), config.getDW(), ArrayUtil.fromBoolean(config.isSameMode()), - config.getDataFormat().equalsIgnoreCase(Conv2DConfig.NCHW) ? 0 : 1); + config.getDataFormat().equalsIgnoreCase("NCHW") ? 0 : 1, + config.getWeightsFormat().ordinal()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java index 436659443..d0b04b36a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java @@ -161,8 +161,7 @@ public class DeConv3D extends DynamicCustomOp { @Override public List doDiff(List f1) { SDVariable bias = args().length > 2 ? 
arg(2) : null; - SDVariable[] outVars = f().deconv3dDerivative(arg(0), arg(1), bias, f1.get(0), config); - return Arrays.asList(outVars); + return new DeConv3DDerivative(sameDiff, arg(0), arg(1), bias, f1.get(0), config).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java index afb51af58..798b544b8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java @@ -56,9 +56,11 @@ public class DepthwiseConv2D extends DynamicCustomOp { protected Conv2DConfig config; + public DepthwiseConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig); + } @Builder(builderMethodName = "sameDiffBuilder") @@ -71,14 +73,14 @@ public class DepthwiseConv2D extends DynamicCustomOp { addArgs(); } - public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){ + public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config) { super(inputs, outputs); this.config = config; addArgs(); } - public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){ + public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config) { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } @@ -127,7 +129,7 @@ public class DepthwiseConv2D extends DynamicCustomOp { @Override public Map propertiesForFunction() { 
- if(config == null && !iArguments.isEmpty()){ + if (config == null && !iArguments.isEmpty()) { config = Conv2DConfig.builder() .kH(iArguments.get(0)) .kW(iArguments.get(1)) @@ -308,7 +310,9 @@ public class DepthwiseConv2D extends DynamicCustomOp { @Override public List doDiff(List f1) { - throw new UnsupportedOperationException("Not implemented yet"); + SDVariable bias = args().length==2 ? null : arg(2); + return Arrays.asList(new DepthwiseConv2DBp(sameDiff, arg(0), arg(1), bias, f1.get(0), this.config).outputVariables()); + } @@ -323,7 +327,7 @@ public class DepthwiseConv2D extends DynamicCustomOp { } @Override - public List calculateOutputDataTypes(List inputDataTypes){ + public List calculateOutputDataTypes(List inputDataTypes) { int n = args().length; Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2DBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2DBp.java new file mode 100644 index 000000000..482944fe2 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2DBp.java @@ -0,0 +1,150 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.convolution; + +import lombok.*; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.converters.DifferentialFunctionClassHolder; +import org.nd4j.imports.descriptors.properties.AttributeAdapter; +import org.nd4j.imports.descriptors.properties.PropertyMapping; +import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter; +import org.nd4j.imports.descriptors.properties.adapters.NDArrayShapeAdapter; +import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater; +import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.util.ArrayUtil; + +import java.lang.reflect.Field; +import java.util.*; + + +/** + * Backpropagation for Depthwise Conv2D operation + */ +@Slf4j +@Getter +@NoArgsConstructor +public class DepthwiseConv2DBp extends DynamicCustomOp { + + protected Conv2DConfig config; + + + public DepthwiseConv2DBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull SDVariable gradO, @NonNull Conv2DConfig config){ + super(sameDiff, wrapFilterNull(input, weights, 
bias, gradO)); + this.config = config; + addArgs(); + + } + + public DepthwiseConv2DBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull SDVariable gradO, @NonNull Conv2DConfig config){ + super(sameDiff, wrapFilterNull(input, weights, gradO)); + this.config = config; + addArgs(); + + } + + + @Override + public long[] iArgs() { + if (iArguments.size() == 0) + addArgs(); + + return super.iArgs(); + } + + protected void addArgs() { + addIArgument(config.getKH(), + config.getKW(), + config.getSH(), + config.getSW(), + config.getPH(), + config.getPW(), + config.getDH(), + config.getDW(), + ArrayUtil.fromBoolean(config.isSameMode()), + config.getDataFormat().equalsIgnoreCase(Conv2DConfig.NCHW) ? 0 : 1); + + } + + @Override + public Object getValue(Field property) { + if (config == null) { + config = Conv2DConfig.builder().build(); + } + + try { + val t = config.getValue(property); + return t; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public Map propertiesForFunction() { + if (config == null && !iArguments.isEmpty()) { + config = Conv2DConfig.builder() + .kH(iArguments.get(0)) + .kW(iArguments.get(1)) + .sH(iArguments.get(2)) + .sW(iArguments.get(3)) + .pH(iArguments.get(4)) + .pW(iArguments.get(5)) + .dH(iArguments.get(6)) + .dW(iArguments.get(7)) + .isSameMode(iArguments.get(8) == 1) + .dataFormat(iArguments.get(9) == 1 ? 
Conv2DConfig.NHWC : Conv2DConfig.NCHW) + .build(); + } + return config.toProperties(); + } + + + @Override + public boolean isConfigProperties() { + return true; + } + + @Override + public String configFieldName() { + return "config"; + } + + @Override + public String opName() { + return "depthwise_conv2d_bp"; + } + + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + int n = args().length; + List list = new ArrayList(); + for(int i=0;i doDiff(List grad) { - return Collections.singletonList(f().im2ColBp(arg(), grad.get(0), conv2DConfig)); + return new Im2colBp(sameDiff, arg(), grad.get(0), conv2DConfig).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java index df345a2f3..3370b6f30 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java @@ -99,7 +99,7 @@ public class Upsampling2d extends DynamicCustomOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().upsampling2dBp(arg(), f1.get(0), nchw, scaleH, scaleW)); + return new Upsampling2dDerivative(sameDiff, arg(), f1.get(0), nchw, scaleH, scaleW).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java index 40a2a3908..92701a696 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java @@ -22,6 +22,7 @@ import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; import org.nd4j.base.Preconditions; +import org.nd4j.enums.WeightsFormat; import org.nd4j.linalg.util.ConvConfigUtil; @Data @@ -50,9 +51,11 @@ public class Conv2DConfig extends BaseConvolutionConfig { private boolean isSameMode; @Builder.Default private String dataFormat = NCHW; + @Builder.Default + private WeightsFormat weightsFormat = WeightsFormat.YXIO; public Conv2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW, boolean isSameMode, - String dataFormat) { + String dataFormat, WeightsFormat weightsFormat) { this.kH = kH; this.kW = kW; @@ -64,6 +67,7 @@ public class Conv2DConfig extends BaseConvolutionConfig { this.dW = dW; this.isSameMode = isSameMode; this.dataFormat = dataFormat; + this.weightsFormat = weightsFormat; validate(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java index a433b23d6..bebbd5f8f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; import lombok.Getter; +import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -23,6 +24,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import 
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2DBp; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; import org.nd4j.shade.guava.primitives.Booleans; @@ -60,6 +62,7 @@ import java.util.Map; * 1: output at last step hL - rank 3 or 4, depends on DirectionMode and dataFormat<
* 2: cell state at last step cL - same shape as in hL
*/ +@NoArgsConstructor public class LSTMLayer extends DynamicCustomOp { @Getter @@ -68,14 +71,18 @@ public class LSTMLayer extends DynamicCustomOp { @Getter private LSTMLayerWeights weights; + private SDVariable cLast; + private SDVariable yLast; + private SDVariable maxTSLength; - public LSTMLayer() { - } public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig configuration) { super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast)); this.configuration = configuration; this.weights = weights; + this.cLast = cLast; + this.yLast = yLast; + this.maxTSLength = maxTSLength; addIArgument(iArgs()); addTArgument(tArgs()); addBArgument(bArgs(weights, maxTSLength, yLast, cLast)); @@ -124,7 +131,13 @@ public class LSTMLayer extends DynamicCustomOp { @Override public List doDiff(List grads) { - throw new UnsupportedOperationException("Not yet implemented"); + int i=0; + SDVariable grad0 = this.configuration.isRetFullSequence() ? grads.get(i++): null; + SDVariable grad1 = this.configuration.isRetLastH() ? grads.get(i++): null; + SDVariable grad2 = this.configuration.isRetLastC() ? 
grads.get(i++): null; + + return Arrays.asList(new LSTMLayerBp(sameDiff, arg(0), this.cLast, this.yLast, this.maxTSLength, + this.weights, this.configuration, grad0, grad1,grad2).outputVariables()); } @@ -155,7 +168,7 @@ public class LSTMLayer extends DynamicCustomOp { } - public boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) { + protected boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) { return new boolean[]{ weights.hasBias(), // hasBiases: B_ARG(0) maxTSLength != null, // hasSeqLen: B_ARG(1) @@ -169,6 +182,16 @@ public class LSTMLayer extends DynamicCustomOp { } + @Override + public boolean isConfigProperties() { + return true; + } + + @Override + public String configFieldName() { + return "configuration"; + } + @Override public int getNumOutputs(){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayerBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayerBp.java new file mode 100644 index 000000000..d6ffcd6e5 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayerBp.java @@ -0,0 +1,176 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; +import org.nd4j.shade.guava.primitives.Booleans; + +import javax.xml.crypto.Data; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + + +/** + * LSTM layer backpropagation + */ +@NoArgsConstructor +public class LSTMLayerBp extends DynamicCustomOp { + + @Getter + private LSTMLayerConfig configuration; + + @Getter + private LSTMLayerWeights weights; + + private SDVariable cLast; + private SDVariable yLast; + private SDVariable maxTSLength; + + + public LSTMLayerBp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, @NonNull LSTMLayerWeights weights, @NonNull LSTMLayerConfig configuration, + SDVariable dLdh, SDVariable dLdhL, SDVariable dLdcL) { + super("lstmLayer_bp", sameDiff, wrapFilterNull(x, weights.getWeights(), weights.getRWeights(), weights.getBias(), + maxTSLength, yLast, cLast, weights.getPeepholeWeights(), dLdh, dLdhL, dLdcL)); + this.configuration = configuration; + this.weights = weights; + this.cLast = cLast; + this.yLast = yLast; + this.maxTSLength = maxTSLength; + addIArgument(iArgs()); + addTArgument(tArgs()); + addBArgument(bArgs(weights, maxTSLength, yLast, cLast)); + + + Preconditions.checkState(this.configuration.isRetLastH() || 
this.configuration.isRetLastC() || this.configuration.isRetFullSequence(), + "You have to specify at least one output you want to return. Use isRetLastC, isRetLast and isRetFullSequence methods in LSTMLayerConfig builder to specify them"); + + + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + + DataType dt = inputDataTypes.get(1); + Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt); + ArrayList list = new ArrayList<>(); + list.add(dt); // dLdx + list.add(dt); // dLdWx + list.add(dt); // dLdWr + + if (this.weights.hasBias()) { + list.add(dt); + } // dLdb + + if (this.maxTSLength != null) { + list.add(dt); + } // dLdSl + if (this.yLast != null) { + list.add(dt); + } //dLdhI + if (this.cLast != null) { + list.add(dt); + } // dLdcI + if (this.weights.hasPH()) { + list.add(dt); + } // dLdWp + + return list; + } + + + @Override + public String opName() { + return "lstmLayer_bp"; + } + + @Override + public Map propertiesForFunction() { + return configuration.toProperties(true, true); + } + + + public long[] iArgs() { + return new long[]{ + configuration.getLstmdataformat().ordinal(),// INT_ARG(0) + configuration.getDirectionMode().ordinal(), // INT_ARG(1) + configuration.getGateAct().ordinal(), // INT_ARG(2) + configuration.getOutAct().ordinal(), // INT_ARG(3) + configuration.getCellAct().ordinal() // INT_ARG(4) + + }; + } + + public double[] tArgs() { + return new double[]{this.configuration.getCellClip()}; // T_ARG(0) + } + + + protected boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) { + return new boolean[]{ + weights.hasBias(), // hasBiases: B_ARG(0) + maxTSLength != null, // hasSeqLen: B_ARG(1) + yLast != null, // hasInitH: B_ARG(2) + cLast != null, // hasInitC: B_ARG(3) + weights.hasPH(), // hasPH: B_ARG(4) + configuration.isRetFullSequence(), //retFullSequence: B_ARG(5) + configuration.isRetLastH(), // retLastH: B_ARG(6) + configuration.isRetLastC() // 
retLastC: B_ARG(7) + }; + + } + + @Override + public boolean isConfigProperties() { + return true; + } + + @Override + public String configFieldName() { + return "configuration"; + } + + + @Override + public int getNumOutputs() { + + return Booleans.countTrue( + true, + true, + true, + weights.hasBias(), + this.maxTSLength != null, + this.yLast != null, + this.cLast != null, + weights.hasPH() + ); + } + + +} + + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java index 9901213da..226150e8b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java @@ -15,8 +15,10 @@ ******************************************************************************/ package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; +import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; +import lombok.NoArgsConstructor; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; @@ -26,9 +28,10 @@ import java.util.Map; @Builder @Data +@AllArgsConstructor +@NoArgsConstructor public class LSTMLayerConfig { - /** * notations
* for unidirectional: @@ -90,23 +93,23 @@ public class LSTMLayerConfig { * Cell clipping value, if it = 0 then do not apply clipping */ @Builder.Default - private double cellClip; //T_ARG(0) + private double cellClip = 0; //T_ARG(0) public Map toProperties(boolean includeLSTMDataFormat, boolean includeLSTMDirectionMode) { Map ret = new LinkedHashMap<>(); - ret.put("gateAct", gateAct.ordinal()); - ret.put("outAct", outAct.ordinal()); - ret.put("cellAct", cellAct.ordinal()); + ret.put("gateAct", gateAct.toString()); + ret.put("outAct", outAct.toString()); + ret.put("cellAct", cellAct.toString()); ret.put("retFullSequence", retFullSequence); ret.put("retLastH", retLastH); ret.put("retLastC", retLastC); ret.put("cellClip", cellClip); if (includeLSTMDataFormat) - ret.put("LSTMDataFormat", lstmdataformat.ordinal()); + ret.put("lstmdataformat", lstmdataformat.toString()); if (includeLSTMDirectionMode) - ret.put("LSTMDirectionMode", directionMode.ordinal()); + ret.put("directionMode", directionMode.toString()); return ret; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java index adc59e4e0..4f6539eee 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp; import java.util.Arrays; import java.util.List; @@ -58,7 +59,6 @@ public class AbsoluteDifferenceLoss extends BaseLoss { public List doDiff(List grad){ //No external 
gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossAbsoluteDifferenceBP(arg(2), arg(0), arg(1), lossReduce); - return Arrays.asList(grads); + return new AbsoluteDifferenceLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java index 7faa5f6b0..432910391 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.CosineDistanceLossBp; import java.util.Arrays; import java.util.List; @@ -61,8 +62,7 @@ public class CosineDistanceLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient. 
//Args are: predictions, weights, label - SDVariable[] grads = f().lossCosineDistanceBp(arg(2), arg(0), arg(1), lossReduce, dimension); - return Arrays.asList(grads); + return new CosineDistanceLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), dimension).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java index 5d85e4933..d021623d5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.HingeLossBp; import java.util.Arrays; import java.util.List; @@ -56,8 +57,7 @@ public class HingeLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossHingeBp(arg(2), arg(0), arg(1), lossReduce); - return Arrays.asList(grads); + return new HingeLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java index f08d90566..acb74c04c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import 
org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.HuberLossBp; import java.util.Arrays; import java.util.List; @@ -63,8 +64,7 @@ public class HuberLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossHuberBp(arg(2), arg(0), arg(1), lossReduce, delta); - return Arrays.asList(grads); + return new HuberLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), delta).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java index e1fe56e5f..d36d36c2f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java @@ -62,6 +62,6 @@ public class L2Loss extends DynamicCustomOp { public List doDiff(List grad){ //L2 loss: L = 1/2 * sum(x_i^2) //dL/dxi = xi - return Collections.singletonList(f().identity(arg())); + return Collections.singletonList(sameDiff.identity(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java index a7a15f1b5..c13634ee1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; 
+import org.nd4j.linalg.api.ops.impl.loss.bp.LogLossBp; import java.util.Arrays; import java.util.List; @@ -64,8 +65,7 @@ public class LogLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossLogBp(arg(2), arg(0), arg(1), lossReduce, epsilon); - return Arrays.asList(grads); + return new LogLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), epsilon).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java index a893e3f4a..2ec6e54b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.LogPoissonLossBp; import java.util.Arrays; import java.util.List; @@ -73,14 +74,7 @@ public class LogPoissonLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - - SDVariable[] grads; - if(full) { - grads = f().lossLogPoissonFullBp(arg(2), arg(0), arg(1), lossReduce); - }else{ - grads = f().lossLogPoissonBp(arg(2), arg(0), arg(1), lossReduce); - } - return Arrays.asList(grads); + return new LogPoissonLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), full).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java index 6c3c5d01b..676eec5e0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.MeanPairwiseSquaredErrorLossBp; import java.util.Arrays; import java.util.List; @@ -54,7 +55,6 @@ public class MeanPairwiseSquaredErrorLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossMeanPairwiseSquaredErrorBp(arg(2), arg(0), arg(1), lossReduce); - return Arrays.asList(grads); + return new MeanPairwiseSquaredErrorLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java index a9cf27584..c40d9e432 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.MeanSquaredErrorLossBp; import java.util.Arrays; import java.util.List; @@ 
-56,8 +57,7 @@ public class MeanSquaredErrorLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossMeanSquaredErrorBp(arg(2), arg(0), arg(1), lossReduce); - return Arrays.asList(grads); + return new MeanSquaredErrorLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java index 214380a8c..862b405d4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java @@ -27,6 +27,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.impl.loss.bp.SigmoidCrossEntropyLossBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -80,7 +81,6 @@ public class SigmoidCrossEntropyLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossSigmoidCrossEntropyBp(arg(2), arg(0), arg(1), lossReduce, labelSmoothing); - return Arrays.asList(grads); + return new SigmoidCrossEntropyLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), labelSmoothing).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java 
index 57576b78f..e97427e92 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java @@ -25,6 +25,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyLossBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -99,7 +100,6 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossSoftmaxCrossEntropyBp(arg(2), arg(0), arg(1), lossReduce, labelSmoothing); - return Arrays.asList(grads); + return new SoftmaxCrossEntropyLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), labelSmoothing).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java index 3ef7de264..defb8292b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyWithLogitsLossBp; import 
org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; @@ -73,8 +74,6 @@ public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp { public List doDiff(List grad){ //No external gradient //Args: logits, weigths, label - SDVariable[] args = args(); - SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(0), arg(1), classesDim); - return Arrays.asList(grads); + return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff, arg(0), arg(1), classesDim).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java index a0f3288a9..c58933134 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java @@ -28,6 +28,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.impl.loss.bp.SparseSoftmaxCrossEntropyLossWithLogitsBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -96,7 +97,8 @@ public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp { @Override public List doDiff(List grad){ //args: label, logits - SDVariable[] ret = f().lossSparseSoftmaxCrossEntropyBp(arg(1), arg(0)); - return Arrays.asList(f().zerosLike(arg(0)), ret[0]); + SDVariable labelsGrad = sameDiff.zerosLike(arg(0)); + SDVariable logitsGrad = new SparseSoftmaxCrossEntropyLossWithLogitsBp(sameDiff, arg(1), arg(0)).outputVariable(); + return Arrays.asList(labelsGrad, 
logitsGrad); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index 46310893d..d4bec6ac2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -141,8 +141,8 @@ public class Mmul extends DynamicCustomOp { boolean transposeZ) { super(null,sameDiff,new SDVariable[]{x,y}); addIArgument(ArrayUtil.fromBoolean(transposeX), - ArrayUtil.fromBoolean(transposeY), - ArrayUtil.fromBoolean(transposeZ)); + ArrayUtil.fromBoolean(transposeY), + ArrayUtil.fromBoolean(transposeZ)); addTArgument(alpha, beta); mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build(); @@ -266,7 +266,7 @@ public class Mmul extends DynamicCustomOp { @Override public List doDiff(List gradients) { - return sameDiff.f().mmulBp(larg(),rarg(), gradients.get(0), mt); + return Arrays.asList(new MmulBp(sameDiff, larg(), rarg(), gradients.get(0), mt).outputVariables()); } @@ -306,4 +306,3 @@ public class Mmul extends DynamicCustomOp { return Collections.singletonList(dataTypes.get(0)); } } - diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java index bf3cb4af1..89ba1549b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java @@ -25,6 +25,8 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -83,8 +85,8 @@ public class Moments extends DynamicCustomOp { public List doDiff(List grad){ SDVariable dLdMean = grad.get(0); SDVariable dLdVar = grad.get(1); //Note: non-bias-corrected variance - SDVariable meanBp = f().meanBp(arg(), dLdMean, false, axes); - SDVariable varBp = f().varianceBp(arg(), dLdVar, false, false, axes); + SDVariable meanBp = new MeanBp(sameDiff, arg(), dLdMean, false, axes).outputVariable(); + SDVariable varBp = new VarianceBp(sameDiff, arg(), dLdVar, false, false, axes).outputVariable(); return Collections.singletonList(meanBp.add(varBp)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java index c613f107f..820df18ab 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java @@ -51,6 +51,35 @@ public class TensorMmul extends DynamicCustomOp { protected boolean addedEdges; protected MMulTranspose mMulTranspose; + + public TensorMmul(INDArray x, INDArray y, int[][] axes) { + this(x,y,axes[0], axes[1], false, false, false); + } + + /** + * Initialize with the given + * input, pairwise transform, result, and number + * of elements + * + * @param x the input + * @param y the pairwise transform + * @param z the result + */ + public TensorMmul(INDArray x, INDArray y, INDArray z, int[][] axes) { + this(x, y, axes[0], axes[1], false, false, false); + } + + public 
TensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY, + boolean transposeX, boolean transposeY, boolean transposeZ) { + super(null,new INDArray[]{x, y},null); + this.axes = new int[][]{dimensionsX, dimensionsY}; + addIArgument(dimensionsX.length); + addIArgument(dimensionsX); + addIArgument(dimensionsY.length); + addIArgument(dimensionsY); + addBArgument(transposeX, transposeY, transposeZ); + } + public TensorMmul(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, @@ -138,13 +167,13 @@ public class TensorMmul extends DynamicCustomOp { //tensor matrix multiply gradient wrt second variable int[] firstPerm = argsort(combine(deletedAxes[0],keep(argsort(sumAxes[1]),sumAxes[0]))); SDVariable firstResult = doTensorMmul(i_v1.get(0), rarg(), firstAxes); - SDVariable permuted = f().permute(firstResult,firstPerm); + SDVariable permuted = sameDiff.permute(firstResult,firstPerm); ret.add(permuted); //tensor matrix multiply gradient wrt first variable int[] secondPerm = argsort(combine(keep(argsort(sumAxes[0]),sumAxes[1]),deletedAxes[1])); SDVariable secondResult = doTensorMmul(i_v1.get(0), larg(), secondAxes); - SDVariable secondPermuted = f().permute(secondResult,secondPerm); + SDVariable secondPermuted = sameDiff.permute(secondResult,secondPerm); ret.add(secondPermuted); return ret; } @@ -210,7 +239,7 @@ public class TensorMmul extends DynamicCustomOp { } - int[] newShapeB = {n3, -1}; + long[] newShapeB = {n3, -1}; long[] oldShapeB; if (listB.size() == 0) { oldShapeB = new long[] {1}; @@ -221,44 +250,12 @@ public class TensorMmul extends DynamicCustomOp { } - SDVariable at = f() - .reshape(f().permute - (a,newAxesA),newShapeA); - SDVariable bt = f() - .reshape(f() - .permute(b,newAxesB),newShapeB); + SDVariable at = sameDiff.reshape(sameDiff.permute(a,newAxesA),newShapeA); + SDVariable bt = sameDiff.reshape(sameDiff.permute(b,newAxesB),newShapeB); - SDVariable ret = f().mmul(at,bt); + SDVariable ret = sameDiff.mmul(at,bt); long[] aPlusB = 
Longs.concat(oldShapeA, oldShapeB); - return f().reshape(ret, aPlusB); - } - - - public TensorMmul(INDArray x, INDArray y, int[][] axes) { - super(null,new INDArray[]{x, y},null); - this.axes = axes; - this.extraArgs = new Object[] {axes}; - } - - /** - * Initialize with the given - * input, pairwise transform, result, and number - * of elements - * - * @param x the input - * @param y the pairwise transform - * @param z the result - */ - public TensorMmul(INDArray x, INDArray y, INDArray z, int[][] axes) { - super(null,new INDArray[]{x, y, z},null); - this.axes = axes; - } - - public TensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY, - boolean transposeX, boolean transposeY, boolean transposeZ) { - super(null,new INDArray[]{x, y},null); - this.axes = new int[][]{dimensionsX, dimensionsY}; - addBArgument(transposeX, transposeY, transposeZ); + return sameDiff.reshape(ret, aPlusB); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java index a465728d1..8aa12d4d3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java @@ -57,7 +57,7 @@ public class All extends BaseReduceBoolOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java index d4522ca69..4d26e5b70 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java @@ -57,7 +57,7 @@ public class Any extends BaseReduceBoolOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java index cb93a832e..5dfc23f8e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java @@ -68,7 +68,7 @@ public class IsInf extends BaseReduceBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java index c8cd72f2c..a78ae8bd5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java @@ -68,7 +68,7 @@ public class IsNaN extends BaseReduceBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java index 26eabf0ff..edc3298b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -82,10 +83,10 @@ public class LogSumExp extends DynamicCustomOp { //z = log(sum_i exp(x_i)) = log(s) //dL/dx = dL/dz * dz/ds * ds/dx //dz/ds = 1/s - SDVariable exp = f().exp(arg()); + SDVariable exp = sameDiff.math.exp(arg()); SDVariable sumExp = exp.sum(dimensions); SDVariable gradProd = f1.get(0).div(sumExp); - SDVariable dSumExpdx = f().sumBp(arg(), gradProd, keepDims, dimensions).mul(exp); + SDVariable dSumExpdx = new SumBp(sameDiff, arg(), gradProd, keepDims, dimensions).outputVariable().mul(exp); return Collections.singletonList(dSumExpdx); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/AMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/AMean.java index 47cc728ab..e9481fa81 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/AMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/AMean.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; 
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp; import java.util.Collections; import java.util.List; @@ -73,7 +74,7 @@ public class AMean extends BaseReduceFloatOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable meanBp = f().meanBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable meanBp = new MeanBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(meanBp)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java index bb0dd4997..913a573db 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java @@ -16,12 +16,11 @@ package org.nd4j.linalg.api.ops.impl.reduce.floating; -import org.nd4j.autodiff.functions.DifferentialFunctionFactory; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -70,13 +69,13 @@ public class Entropy extends BaseReduceFloatOp { //Then we can do sumBp(z, -dL/dOut) //Note d/dx(x*log(x)) = log(x)+1 - return grad(f(), arg(), f1.get(0), dimensions); + return grad(sameDiff, arg(), f1.get(0), dimensions); } - public static List grad(DifferentialFunctionFactory f, SDVariable arg, SDVariable grad, int[] dimensions){ - SDVariable logx = 
f.log(arg); + public static List grad(SameDiff sd, SDVariable arg, SDVariable grad, int[] dimensions){ + SDVariable logx = sd.math.log(arg); SDVariable xLogX = arg.mul(logx); - SDVariable sumBp = f.sumBp(xLogX, grad.neg(), false, dimensions); + SDVariable sumBp = new SumBp(sd, xLogX, grad.neg(), false, dimensions).outputVariable(); return Collections.singletonList(sumBp.mul(logx.add(1.0))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/LogEntropy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/LogEntropy.java index 52970cc33..837d89c3a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/LogEntropy.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/LogEntropy.java @@ -70,7 +70,7 @@ public class LogEntropy extends BaseReduceFloatOp { @Override public List doDiff(List f1) { //If y=log(x), and x=entropy(in) then dL/dx = dL/dy * dy/dx; d(log(x))/dx = 1/x - List entropyGrad = Entropy.grad(f(), arg(), f1.get(0), dimensions); - return Collections.singletonList(entropyGrad.get(0).div(f().exp(outputVariable()))); + List entropyGrad = Entropy.grad(sameDiff, arg(), f1.get(0), dimensions); + return Collections.singletonList(entropyGrad.get(0).div(sameDiff.math.exp(outputVariable()))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java index bf15f94d4..6309ccf28 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java @@ -20,6 +20,7 @@ import 
org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp; import java.util.Collections; import java.util.List; @@ -67,7 +68,7 @@ public class Mean extends BaseReduceFloatOp { public List doDiff(List i_v1) { //If out = mean(in), then dL/dIn = 1/N * dL/dOut (broadcast to appropriate shape) //Note that N differs for "along dimension" vs. "whole array" reduce cases - return Collections.singletonList(f().meanBp(arg(), i_v1.get(0), keepDims, dimensions)); + return new MeanBp(sameDiff, arg(), i_v1.get(0), keepDims, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java index a2ba88927..96222d7c8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp; import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Collections; @@ -80,6 +81,6 @@ public class Norm1 extends BaseReduceFloatOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().norm1Bp(arg(), grad.get(0), keepDims, dimensions)); + return new Norm1Bp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java index f61c0dc43..be517f5e8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp; import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Collections; @@ -72,7 +73,7 @@ public class Norm2 extends BaseReduceFloatOp { @Override public List doDiff(List grad) { //d norm2(in)/dx = x / norm2(in) - return Collections.singletonList(f().norm2Bp(arg(), grad.get(0), keepDims, dimensions)); + return new Norm2Bp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java index ece542857..ea3fd140d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp; import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Collections; @@ -47,7 +48,6 @@ public class NormMax extends BaseReduceFloatOp { super(x, null, 
z, dimensions); } - public NormMax(INDArray x, int... dimensions) { super(x, dimensions); } @@ -77,7 +77,7 @@ public class NormMax extends BaseReduceFloatOp { public List doDiff(List grad) { //maxnorm(in) = max_i |x_i| //d maxnorm(in)/dx = 0 if x_i is not the max, or d|x|/dx otherwise - return Collections.singletonList(f().normmaxBp(arg(), grad.get(0), keepDims, dimensions)); + return new NormMaxBp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/ShannonEntropy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/ShannonEntropy.java index 44504f855..963224ed8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/ShannonEntropy.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/ShannonEntropy.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -68,10 +69,10 @@ public class ShannonEntropy extends BaseReduceFloatOp { //Then we can do sumBp(z, -dL/dOut) //Note d/dx(x*log2(x)) = (log(x)+1)/log(2) - SDVariable log2x = f().log(arg(),2); - SDVariable logx = f().log(arg()); + SDVariable log2x = sameDiff.math.log(arg(),2); + SDVariable logx = sameDiff.math.log(arg()); SDVariable xLog2X = arg().mul(log2x); - SDVariable sumBp = f().sumBp(xLog2X, f1.get(0).neg(), false, dimensions); + SDVariable sumBp = new SumBp(sameDiff, xLog2X, f1.get(0).neg(), false, dimensions).outputVariable(); return Collections.singletonList(sumBp.mul(logx.add(1.0)).div(Math.log(2.0))); } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java index b11fe5b1f..f80a712d3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SquaredNormBp; import java.util.Collections; import java.util.List; @@ -47,6 +48,10 @@ public class SquaredNorm extends BaseReduceFloatOp { public SquaredNorm(){} + public SquaredNorm(INDArray x, int... dimensions){ + super(x, dimensions); + } + @Override public int opNum() { return 7; @@ -69,6 +74,6 @@ public class SquaredNorm extends BaseReduceFloatOp { @Override public List doDiff(List grad){ - return Collections.singletonList(f().squaredNormBp(arg(), grad.get(0), keepDims, dimensions)); + return new SquaredNormBp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java index d27215a80..7376b0708 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java @@ -56,7 +56,7 @@ public class CountNonZero extends BaseReduceLongOp { @Override public 
List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountZero.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountZero.java index db13dfc85..27476cabc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountZero.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountZero.java @@ -67,7 +67,7 @@ public class CountZero extends BaseReduceLongOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java index 0fb4db830..f2f097aa9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java @@ -52,11 +52,15 @@ public class MatchCondition extends BaseReduceLongOp { public MatchCondition() {} - public MatchCondition(INDArray x, Condition condition, int... dimensions) { this(x, Nd4j.EPS_THRESHOLD, condition, dimensions); } + public MatchCondition(INDArray x, Condition condition, boolean keepDims, int... dimensions) { + this(x, Nd4j.EPS_THRESHOLD, condition, dimensions); + this.keepDims = keepDims; + } + public MatchCondition(INDArray x, double eps, Condition condition, int... 
dimensions) { super(x); this.compare = condition.getValue(); @@ -68,10 +72,6 @@ public class MatchCondition extends BaseReduceLongOp { defineDimensions(dimensions); } - public MatchCondition(INDArray in, Condition condition, boolean keepDim, int... dimensions) { - this(in, condition, dimensions); - } - @Override public int opNum() { return 2; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMax.java index a6533441a..cadc77d4f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMax.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp; import java.util.Collections; import java.util.List; @@ -65,7 +66,7 @@ public class AMax extends BaseReduceSameOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable maxBp = f().maxBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable maxBp = new MaxBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(maxBp)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMin.java index 20a8be906..a01c9c1f5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMin.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMin.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; import java.util.Collections; import java.util.List; @@ -69,7 +70,7 @@ public class AMin extends BaseReduceSameOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable minBp = new MinBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(minBp)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/ASum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/ASum.java index 17d8a0bde..1a15c32ac 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/ASum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/ASum.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -72,7 +73,7 @@ public class ASum extends BaseReduceSameOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable meanBp = f().sumBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable meanBp = new SumBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); 
return Collections.singletonList(sgn.mul(meanBp)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java index a29384a42..8c4563c95 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp; import java.util.Collections; import java.util.List; @@ -79,7 +80,7 @@ public class Max extends BaseReduceSameOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().maxBp(arg(), grad.get(0), keepDims, dimensions)); + return new MaxBp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java index 99c1e038b..1d644b671 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; import java.util.Collections; import java.util.List; @@ -77,6 +78,6 @@ public class Min extends 
BaseReduceSameOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().minBp(arg(), grad.get(0), keepDims, dimensions)); + return new MinBp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java index b5073d0f9..0247e3169 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.BaseNDArray; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp; import java.util.Collections; import java.util.List; @@ -82,6 +83,6 @@ public class Prod extends BaseReduceSameOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().prodBp(arg(), grad.get(0), keepDims, dimensions)); + return new ProdBp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java index 859b89dac..f2c0b1d40 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import 
org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -40,7 +41,6 @@ public class Sum extends BaseReduceSameOp { super(sameDiff, i_v, i_v2, dimensions); } - public Sum() { } @@ -76,7 +76,7 @@ public class Sum extends BaseReduceSameOp { // dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * 1 // But broadcast to shape of the input - return Collections.singletonList(f().sumBp(arg(), i_v1.get(0), keepDims, dimensions)); + return new SumBp(sameDiff, arg(), i_v1.get(0), keepDims, dimensions).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java index 44f1c49fc..a5aab468b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java @@ -84,7 +84,7 @@ public class CosineDistance extends BaseReduce3Op { //Cosine distance = 1 - cosine similarity //Therefore: just need to negate gradients from cosine similarity... 
- List diff = CosineSimilarity.doDiff(sameDiff, f(), larg(), rarg(), i_v1.get(0), keepDims, dimensions); - return Arrays.asList(f().neg(diff.get(0)), f().neg(diff.get(1))); + List diff = CosineSimilarity.doDiff(sameDiff, larg(), rarg(), i_v1.get(0), keepDims, dimensions); + return Arrays.asList(sameDiff.math.neg(diff.get(0)), sameDiff.math.neg(diff.get(1))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java index 27c14473d..b6edbe6fa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java @@ -16,9 +16,9 @@ package org.nd4j.linalg.api.ops.impl.reduce3; -import org.nd4j.autodiff.functions.DifferentialFunctionFactory; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -93,14 +93,14 @@ public class CosineSimilarity extends BaseReduce3Op { //Then: // dc(x,y)/dx_i = 1/b * (y - x * a / (l2(x))^2) - return doDiff(sameDiff, f(), larg(), rarg(), i_v1.get(0), keepDims, dimensions); + return doDiff(sameDiff, larg(), rarg(), i_v1.get(0), keepDims, dimensions); } - public static List doDiff(SameDiff sameDiff, DifferentialFunctionFactory f, SDVariable x, SDVariable y, + public static List doDiff(SameDiff sameDiff, SDVariable x, SDVariable y, SDVariable gradOut, boolean keepDims, int... 
dimensions){ SDVariable a = sameDiff.sum(x.mul(y),true, dimensions); - SDVariable l2x = f.norm2(x, true, dimensions); - SDVariable l2y = f.norm2(y, true, dimensions); + SDVariable l2x = sameDiff.norm2(x, true, dimensions); + SDVariable l2y = sameDiff.norm2(y, true, dimensions); SDVariable b = l2x.mul(l2y); SDVariable l2xSq = sameDiff.math().square(l2x); @@ -110,7 +110,7 @@ public class CosineSimilarity extends BaseReduce3Op { //keepDims or full array reduction broadcastableGrad = gradOut; } else { - broadcastableGrad = sameDiff.f().reductionBroadcastableWithOrigShape(x, sameDiff.constant(Nd4j.createFromArray(dimensions)), gradOut); + broadcastableGrad = SameDiffUtils.reductionBroadcastableWithOrigShape(x, sameDiff.constant(Nd4j.createFromArray(dimensions)), gradOut); } SDVariable dcdx = y.sub(x.mul(a).div(l2xSq)).div(b); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java index 85f0b3e15..bdb172924 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.reduce3; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.reduce.bp.DotBp; import java.util.Arrays; import java.util.List; @@ -86,6 +87,6 @@ public class Dot extends BaseReduce3Op { @Override public List doDiff(List f1) { //TODO KEEP DIMS - return Arrays.asList(f().dotBp(arg(0), arg(1), f1.get(0), false, dimensions)); + return new DotBp(sameDiff, arg(0), arg(1), f1.get(0), false, dimensions).outputs(); } } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java index a25ba6d52..97ccd81e6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce3; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -89,11 +90,11 @@ public class EuclideanDistance extends BaseReduce3Op { SDVariable divBroadcastable = i_v1.get(0).div(euc); if(!keepDims && !(dimensions == null || dimensions.length == 0 || (dimensions.length == 1 && dimensions[0] == Integer.MAX_VALUE))){ //Not keep dims, and not full array reduction -> need to make broadcastable - divBroadcastable = f().reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), divBroadcastable); + divBroadcastable = SameDiffUtils.reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), divBroadcastable); } SDVariable gradX = difference.mul(divBroadcastable); - SDVariable gradY = f().neg(gradX); + SDVariable gradY = sameDiff.math.neg(gradX); return Arrays.asList(gradX, gradY); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java index 994003e78..c520a7c18 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce3; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -90,18 +91,18 @@ public class JaccardDistance extends BaseReduce3Op { //Jaccard distance: https://en.wikipedia.org/wiki/Jaccard_index#Generalized_Jaccard_similarity_and_distance //J(x,y) = 1 - sum_i min(x_i, y_i) / sum_i max(x_i, y_i) - SDVariable min = f().min(larg(), rarg()); - SDVariable max = f().max(larg(), rarg()); + SDVariable min = sameDiff.math.min(larg(), rarg()); + SDVariable max = sameDiff.math.max(larg(), rarg()); SDVariable sumMax = max.sum(true, dimensions); SDVariable sumMin = min.sum(true, dimensions); DataType d = arg().dataType(); - SDVariable xIsMin = f().eq(min, larg()).castTo(d); - SDVariable xIsMax = f().eq(max, larg()).castTo(d); - SDVariable yIsMin = f().eq(min, rarg()).castTo(d); - SDVariable yIsMax = f().eq(max, rarg()).castTo(d); + SDVariable xIsMin = sameDiff.eq(min, larg()).castTo(d); + SDVariable xIsMax = sameDiff.eq(max, larg()).castTo(d); + SDVariable yIsMin = sameDiff.eq(min, rarg()).castTo(d); + SDVariable yIsMax = sameDiff.eq(max, rarg()).castTo(d); - SDVariable sqSumMax = f().square(sumMax); + SDVariable sqSumMax = sameDiff.math.square(sumMax); SDVariable dldx = xIsMax.mul(sumMin).sub(xIsMin.mul(sumMax)).div(sqSumMax); SDVariable dldy = yIsMax.mul(sumMin).sub(yIsMin.mul(sumMax)).div(sqSumMax); @@ -110,7 +111,7 @@ public class JaccardDistance extends BaseReduce3Op { //KeepDims or full array reduction - already broadcastable bcGradOut = f1.get(0); } else { - 
bcGradOut = sameDiff.f().reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), f1.get(0)); + bcGradOut = SameDiffUtils.reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), f1.get(0)); } return Arrays.asList(dldx.mul(bcGradOut), dldy.mul(bcGradOut)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java index 0c007a261..9fdea3afb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce3; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -86,11 +87,11 @@ public class ManhattanDistance extends BaseReduce3Op { //keepDims or full array reduction gradBroadcastable = i_v1.get(0); } else { - gradBroadcastable = sameDiff.f().reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), i_v1.get(0)); + gradBroadcastable = SameDiffUtils.reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), i_v1.get(0)); } SDVariable gradX = sameDiff.math().sign(difference).mul(gradBroadcastable); - SDVariable gradY = f().neg(gradX); + SDVariable gradY = sameDiff.math().neg(gradX); return Arrays.asList(gradX, gradY); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java index 000b0414c..44514ee1a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java @@ -23,6 +23,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -108,7 +109,7 @@ public class LeakyReLU extends BaseScalarOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().leakyReluBp(arg(), i_v.get(0), alpha)); + return new LeakyReLUBp(sameDiff, arg(), i_v.get(0), alpha).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java index f9e30be9c..4b9b37026 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java @@ -29,6 +29,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp; /** * Parameterized ReLU op @@ -80,6 +81,6 @@ public class PRelu extends DynamicCustomOp { @Override public List doDiff(List i_v) { - return Arrays.asList(f().preluBp(arg(0), arg(1), i_v.get(0), sharedAxes)); + 
return new PReluBp(sameDiff, arg(0), arg(1), i_v.get(0), sharedAxes).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java index 5cfab3768..ec15ea537 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -87,7 +88,7 @@ public class Pow extends BaseScalarOp { @Override public List doDiff(List i_v1) { - SDVariable g = f().powDerivative(arg(), this.pow).mul(i_v1.get(0)); - return Arrays.asList(g); + SDVariable g = new PowDerivative(sameDiff, arg(), false, this.pow).outputVariable().mul(i_v1.get(0)); + return Collections.singletonList(g); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java index 944d4d095..98df920bc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp; import java.util.Arrays; import java.util.Collections; @@ -81,6 +82,6 @@ public class RectifiedLinear extends 
BaseScalarOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().thresholdReluBp(arg(), i_v.get(0), scalarValue.getDouble(0))); + return new ThresholdReluBp(sameDiff, arg(), i_v.get(0), scalarValue.getDouble(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java index c80d3c8f9..9b11925c3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java @@ -23,6 +23,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -99,6 +100,6 @@ public class Relu6 extends BaseScalarOp { @Override public List doDiff(List i_v) { SDVariable dLdOut = i_v.get(0); - return Collections.singletonList(f().relu6Derivative(arg(), dLdOut, scalarValue.getDouble(0))); + return new Relu6Derivative(sameDiff, arg(), dLdOut, scalarValue.getDouble(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java index 3aa31771a..f8831c68e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java @@ -16,6 +16,7 @@ package 
org.nd4j.linalg.api.ops.impl.scalar; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -49,9 +50,12 @@ public class ScalarAdd extends BaseScalarOp { this(arr, 0); } + public ScalarAdd(@NonNull SameDiff sameDiff, @NonNull SDVariable i_v, Number scalar) { + this(sameDiff, i_v, scalar, false); + } + public ScalarAdd(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace) { super(sameDiff, i_v, scalar, inPlace); - } public ScalarAdd(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace, Object[] extraArgs) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseDivision.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseDivision.java index 1fec8f808..463012875 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseDivision.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseDivision.java @@ -22,6 +22,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -73,8 +74,8 @@ public class ScalarReverseDivision extends BaseScalarOp { @Override public List doDiff(List i_v1) { - SDVariable g = f().rdiv(f().pow(arg(), 2), -scalarValue.getDouble(0)).mul(i_v1.get(0)); - return Arrays.asList(g); + SDVariable g = sameDiff.math.rdiv(sameDiff.math.pow(arg(), 2), -scalarValue.getDouble(0)).mul(i_v1.get(0)); + return Collections.singletonList(g); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseSubtraction.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseSubtraction.java index d362620e4..972f4ec10 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseSubtraction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseSubtraction.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -79,8 +80,8 @@ public class ScalarReverseSubtraction extends BaseScalarOp { @Override public List doDiff(List i_v1) { - SDVariable g = f().neg(i_v1.get(0)); - return Arrays.asList(g); + SDVariable g = sameDiff.math.neg(i_v1.get(0)); + return Collections.singletonList(g); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSet.java index a8ec8c7f3..d3a8c7f67 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSet.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSet.java @@ -76,7 +76,7 @@ public class ScalarSet extends BaseScalarOp { @Override public List doDiff(List i_v1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java index 65f653d64..04bd39622 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java @@ -96,6 +96,6 @@ public class Step extends BaseScalarOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java index ad2aa9b50..bafe8db88 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java @@ -40,7 +40,7 @@ public class ScalarEquals extends BaseScalarBoolOp { } public ScalarEquals(INDArray x, Number num) { - super(x, num); + this(x, null, num); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java index eafbcbc1a..524e66baa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java @@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import 
org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; import java.util.List; @@ -56,7 +58,7 @@ public class ScalarGreaterThan extends BaseScalarBoolOp { } public ScalarGreaterThan(INDArray x, Number num) { - super(x, num); + this(x, null, num); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java index 0948c01ab..09c001dda 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java @@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; import java.util.List; @@ -40,7 +42,7 @@ public class ScalarGreaterThanOrEqual extends BaseScalarBoolOp { } public ScalarGreaterThanOrEqual(INDArray x, Number num) { - super(x, num); + this(x, null, num); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java index 6f72490a1..740d05a79 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java @@ -18,9 +18,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; import java.util.List; @@ -39,7 +41,7 @@ public class ScalarLessThan extends BaseScalarBoolOp { } public ScalarLessThan(INDArray x, Number num) { - super(x, num); + this(x, null, num); } public ScalarLessThan(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java index 6c9a3a893..343051ec6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java @@ -19,9 +19,11 @@ package org.nd4j.linalg.api.ops.impl.scalar.comparison; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; import java.util.List; @@ -49,7 +51,7 @@ public class ScalarLessThanOrEqual extends BaseScalarBoolOp { } public 
ScalarLessThanOrEqual(INDArray x, Number num) { - super(x, num); + this(x, null, num); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java index f050b686e..52f4b7a99 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java @@ -41,10 +41,9 @@ public class ScalarNotEquals extends BaseScalarBoolOp { } public ScalarNotEquals(INDArray x, Number num) { - super(x, num); + this(x, null, num); } - public ScalarNotEquals(SameDiff sameDiff, SDVariable i_v, Number scalar) { super(sameDiff, i_v, scalar); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java index 160556867..1c524ccf1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.scatter; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -44,12 +45,12 @@ public class ScatterAdd extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } - public ScatterAdd(INDArray ref, INDArray indices, INDArray updates) { - addInputArgument(ref, indices, updates); - } - public ScatterAdd(){} + public ScatterAdd(@NonNull INDArray ref, @NonNull INDArray 
indices, @NonNull INDArray update){ + super(new INDArray[]{ref, indices, update}, null); + } + @Override public String opName() { return "scatter_add"; @@ -88,9 +89,9 @@ public class ScatterAdd extends DynamicCustomOp { List ret = new ArrayList<>(3); ret.add(gradOut.get(0)); //Reference array - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gather = f().gather(gradOut.get(0), arg(1), 0); //Updates + SDVariable gather = sameDiff.gather(gradOut.get(0), arg(1), 0); //Updates ret.add(gather); return ret; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java index 5d6b60c88..c2993eb23 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.scatter; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -44,11 +45,12 @@ public class ScatterDiv extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } - public ScatterDiv(INDArray ref, INDArray indices, INDArray updates) { - addInputArgument(ref, indices, updates); + public ScatterDiv() {} + + public ScatterDiv(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){ + super(new INDArray[]{ref, indices, update}, null); } - public ScatterDiv() {} @Override public String opName() { @@ -77,13 +79,13 @@ public class ScatterDiv extends DynamicCustomOp { SDVariable updates = arg(2); List ret = new ArrayList<>(3); - SDVariable gradRef = f().scatterDiv(gradOut.get(0), indices, updates); + SDVariable 
gradRef = sameDiff.scatterDiv(gradOut.get(0), indices, updates); ret.add(gradRef); //Reference array - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gatherOutGrad = f().gather(gradOut.get(0), indices, 0); //Updates - SDVariable gatherRef = f().gather(ref, indices, 0); - SDVariable updateGrad = gatherOutGrad.mul(gatherRef).div(f().square(updates)).neg(); + SDVariable gatherOutGrad = sameDiff.gather(gradOut.get(0), indices, 0); //Updates + SDVariable gatherRef = sameDiff.gather(ref, indices, 0); + SDVariable updateGrad = gatherOutGrad.mul(gatherRef).div(sameDiff.math.square(updates)).neg(); ret.add(updateGrad); return ret; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java index 7f814d928..12fe3380a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.scatter; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -42,12 +43,12 @@ public class ScatterMax extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } - public ScatterMax(INDArray ref, INDArray indices, INDArray updates) { - addInputArgument(ref, indices, updates); - } - public ScatterMax() {} + public ScatterMax(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){ + super(new INDArray[]{ref, indices, update}, null); + } + @Override public String opName() { return "scatter_max"; @@ -87,12 +88,12 @@ public class ScatterMax extends DynamicCustomOp { SDVariable notModified 
= arg(0).eq(outputVariable()).castTo(arg(0).dataType()); //0 if modified, 1 otherwise SDVariable refGrad = gradOut.get(0).mul(notModified); - SDVariable gatherOut = f().gather(outputVariable(), arg(1), 0); - SDVariable gatherGrad = f().gather(gradOut.get(0), arg(1), 0); + SDVariable gatherOut = sameDiff.gather(outputVariable(), arg(1), 0); + SDVariable gatherGrad = sameDiff.gather(gradOut.get(0), arg(1), 0); SDVariable outIsUpdate = gatherOut.eq(arg(2)).castTo(arg(2).dataType()); SDVariable updateGrad = gatherGrad.mul(outIsUpdate); - return Arrays.asList(refGrad, f().zerosLike(arg(1)), updateGrad); + return Arrays.asList(refGrad, sameDiff.zerosLike(arg(1)), updateGrad); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java index 2539a3d56..91a49487e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.scatter; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -42,12 +43,12 @@ public class ScatterMin extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } - public ScatterMin(INDArray ref, INDArray indices, INDArray updates) { - addInputArgument(ref, indices, updates); - } - public ScatterMin() {} + public ScatterMin(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){ + super(new INDArray[]{ref, indices, update}, null); + } + @Override public String opName() { return "scatter_min"; @@ -88,12 +89,12 @@ public class ScatterMin extends DynamicCustomOp { SDVariable 
notModified = arg(0).eq(outputVariable()).castTo(arg(0).dataType()); //0 if modified, 1 otherwise SDVariable refGrad = gradOut.get(0).mul(notModified); - SDVariable gatherOut = f().gather(outputVariable(), arg(1), 0); - SDVariable gatherGrad = f().gather(gradOut.get(0), arg(1), 0); + SDVariable gatherOut = sameDiff.gather(outputVariable(), arg(1), 0); + SDVariable gatherGrad = sameDiff.gather(gradOut.get(0), arg(1), 0); SDVariable outIsUpdate = gatherOut.eq(arg(2)).castTo(arg(2).dataType()); SDVariable updateGrad = gatherGrad.mul(outIsUpdate); - return Arrays.asList(refGrad, f().zerosLike(arg(1)), updateGrad); + return Arrays.asList(refGrad, sameDiff.zerosLike(arg(1)), updateGrad); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java index 411c59188..705b85d3d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.scatter; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -44,12 +45,12 @@ public class ScatterMul extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } - public ScatterMul(INDArray ref, INDArray indices, INDArray updates) { - addInputArgument(ref, indices, updates); - } - public ScatterMul() {} + public ScatterMul(@NonNull INDArray ref, @NonNull INDArray indices, @NonNull INDArray update){ + super(new INDArray[]{ref, indices, update}, null); + } + @Override public String opName() { return "scatter_mul"; @@ -91,12 +92,12 @@ public class ScatterMul extends DynamicCustomOp { 
SDVariable updates = arg(2); List ret = new ArrayList<>(3); - SDVariable gradRef = f().scatterMul(gradOut.get(0), indices, updates); + SDVariable gradRef = sameDiff.scatterMul(gradOut.get(0), indices, updates); ret.add(gradRef); //Reference array - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gatherOutGrad = f().gather(gradOut.get(0), indices, 0); //Updates - SDVariable gatherRef = f().gather(ref, indices, 0); + SDVariable gatherOutGrad = sameDiff.gather(gradOut.get(0), indices, 0); //Updates + SDVariable gatherRef = sameDiff.gather(ref, indices, 0); SDVariable updateGrad = gatherOutGrad.mul(gatherRef); ret.add(updateGrad); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java index 83c4cc222..15e6d5ac2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java @@ -44,12 +44,12 @@ public class ScatterSub extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } - public ScatterSub(INDArray ref, INDArray indices, INDArray updates) { - addInputArgument(ref, indices, updates); - } - public ScatterSub() {} + public ScatterSub(INDArray ref, INDArray indices, INDArray update){ + super(new INDArray[]{ref, indices, update}, null); + } + @Override public String opName() { return "scatter_sub"; @@ -74,9 +74,9 @@ public class ScatterSub extends DynamicCustomOp { List ret = new ArrayList<>(3); ret.add(gradOut.get(0)); //Reference array - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gather = f().gather(gradOut.get(0), arg(1), 0); //Updates + SDVariable gather = 
sameDiff.gather(gradOut.get(0), arg(1), 0); //Updates ret.add(gather.neg()); return ret; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java index 93e1e5995..2e87af624 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java @@ -54,12 +54,12 @@ public class ScatterUpdate extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } - public ScatterUpdate(INDArray ref, INDArray indices, INDArray updates) { - addInputArgument(ref, indices, updates); - } - public ScatterUpdate(){} + public ScatterUpdate(INDArray ref, INDArray indices, INDArray update){ + super(new INDArray[]{ref, indices, update}, null); + } + @Override public String opName() { return "scatter_upd"; @@ -98,12 +98,12 @@ public class ScatterUpdate extends DynamicCustomOp { SDVariable updates = arg(2); List ret = new ArrayList<>(3); - SDVariable zerosUpdate = f().zerosLike(updates); - SDVariable gradRef = f().scatterMul(gradOut.get(0), indices, zerosUpdate); //TODO optimize + SDVariable zerosUpdate = sameDiff.zerosLike(updates); + SDVariable gradRef = sameDiff.scatterMul(gradOut.get(0), indices, zerosUpdate); //TODO optimize ret.add(gradRef); //Reference array gradient - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gather = f().gather(gradOut.get(0), indices, 0); //Updates + SDVariable gather = sameDiff.gather(gradOut.get(0), indices, 0); //Updates ret.add(gather); return ret; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java index bddcef970..85cf62247 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java @@ -44,7 +44,7 @@ public class Concat extends DynamicCustomOp { } public Concat(int concatDimension, INDArray... arrays) { - super(null, arrays, new INDArray[0]); + super(null, arrays, null); this.concatDimension = concatDimension; addIArgument(concatDimension); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java index f0a6f436a..c51207d64 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java @@ -67,15 +67,16 @@ public class ExpandDims extends DynamicCustomOp { super(null, inputs, outputs); } - public ExpandDims(INDArray input, int axis) { - addInputArgument(input); - addIArgument(axis); - } - public ExpandDims(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(null, sameDiff, args, inPlace); } + public ExpandDims(INDArray x, int axis){ + super(new INDArray[]{x}, null); + this.jaxis = axis; + addIArgument(axis); + } + @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { val targetNode = TFGraphMapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(1)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java index 593531098..7ac8429c8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java @@ -42,6 +42,10 @@ public class GatherNd extends DynamicCustomOp { super(new INDArray[]{df, indices}, null); } + public GatherNd(INDArray[] inputs, INDArray[] outputs){ + super(inputs, outputs); + } + @Override public String opName() { return "gather_nd"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java index 4bc3b3f63..f83d61c0d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -42,6 +43,9 @@ import java.util.Map; public class Linspace extends DynamicCustomOp { private DataType dataType; + private double start; + private double stop; + private long elements; public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) { this(sameDiff, sameDiff.constant(start), sameDiff.constant(stop), sameDiff.constant(number), dataType); @@ -54,7 +58,7 @@ public class Linspace extends DynamicCustomOp { } public Linspace(DataType dataType, double start, double stop, long number) { - this(dataType, Nd4j.scalar(start), Nd4j.scalar(stop), Nd4j.scalar(number)); + this(start, stop, number, dataType); } public Linspace(DataType dataType, INDArray start, INDArray stop, INDArray number) { @@ -67,6 +71,19 @@ public class Linspace extends DynamicCustomOp { addDArgument(dataType); } + public Linspace(double start, double stop, long number, @NonNull DataType dataType) { + super(new INDArray[]{}, null); + this.dataType = dataType; + addDArgument(dataType); + + this.start = start; + this.stop = stop; + this.elements = number; + + addTArgument(this.start, this.stop); + addIArgument(elements); + } + public Linspace(){ } @Override @@ -101,6 +118,6 @@ public class Linspace extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().zerosLike(arg(0)), f().zerosLike(arg(1)), f().zerosLike(arg(2))); + return Arrays.asList(sameDiff.zerosLike(arg(0)), sameDiff.zerosLike(arg(1)), sameDiff.zerosLike(arg(2))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java index b63052eb5..3fc734b6b 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java @@ -24,15 +24,13 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.MergeAvgBp; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; +import java.util.*; @Slf4j public class MergeAvg extends DynamicCustomOp { @@ -74,12 +72,8 @@ public class MergeAvg extends DynamicCustomOp { @Override public List doDiff(List i_v) { - int nArgs = args().length; - SDVariable gradient = sameDiff.setupFunction(i_v.get(0)).div(nArgs); - List ret = new ArrayList<>(); - for (int i = 0; i < args().length; i++) - ret.add(gradient); - return ret; + return Arrays.asList(new MergeAvgBp(sameDiff, args(), i_v.get(0)).outputVariables()); + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java index 2b954e8b7..4e41344fb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java @@ -24,14 +24,12 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.MergeMaxBp; import 
org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; +import java.util.*; @Slf4j public class MergeMax extends DynamicCustomOp { @@ -71,14 +69,8 @@ public class MergeMax extends DynamicCustomOp { @Override public List doDiff(List i_v) { - SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - List ret = new ArrayList<>(); - SDVariable out = outputVariable(); - for (int i = 0; i < args().length; i++){ - SDVariable isMax = out.eq(arg(i)).castTo(arg(i).dataType()); - ret.add(isMax.mul(gradient)); - } - return ret; + return Arrays.asList(new MergeMaxBp(sameDiff, args(), i_v.get(0)).outputVariables()); + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java index affc603e9..c08dcb1d6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java @@ -67,8 +67,7 @@ public class OneHot extends DynamicCustomOp { } public OneHot(INDArray indices, int depth) { - addInputArgument(indices); - addIArgument(depth); + this(indices, null, depth, 0, 1.0, 0.0); } public OneHot(INDArray indices, INDArray output, int depth, int axis, double on, double off) { @@ -80,14 +79,16 @@ public class OneHot extends DynamicCustomOp { addArgs(); } - public OneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) { - addInputArgument(indices); - addIArgument(depth, axis); - addTArgument(on, off); - addDArgument(dataType); + public OneHot(INDArray indices, int depth, int axis, double on, double off) { + this(indices, null, depth, axis, on, off); } - + 
public OneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) { + this(indices, null, depth, axis, on, off); + this.outputType = dataType; + if (outputType != null) + addDArgument(outputType); + } protected void addArgs() { addIArgument(jaxis); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java index 1827a6589..b4d71b40f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -44,6 +45,10 @@ public class ParallelStack extends DynamicCustomOp { super(null, sameDiff, values, false); } + public ParallelStack(INDArray[] inputs){ + super(inputs, null); + } + @Override public String opName() { return "parallel_stack"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java index cfd0bd7ed..85b9e2bb4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java @@ -55,15 +55,16 @@ public class Permute extends Transpose { addIArgument(permuteDims); 
} - public Permute(INDArray input, int... permuteDims){ - addInputArgument(input); - addIArgument(permuteDims); - } - public Permute(SameDiff sd, SDVariable input, SDVariable permuteDims){ super(sd, input, permuteDims); } + public Permute(INDArray input, int... permuteDims){ + super(input, null); + this.permuteDims = permuteDims; + addIArgument(permuteDims); + } + public Permute() { } @@ -77,10 +78,10 @@ public class Permute extends Transpose { SDVariable ret; if(args().length == 1) { //Static dimensions - ret = f().permute(i_v.get(0), reverseDims); + ret = sameDiff.permute(i_v.get(0), reverseDims); } else { //Dynamic dimensions - ret = f().permute(i_v.get(0), sameDiff.invertPermutation(arg(1))); + ret = sameDiff.permute(i_v.get(0), sameDiff.invertPermutation(arg(1))); } return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index 2126dfe27..47960fad3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NoArgsConstructor; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; @@ -41,6 +42,7 @@ import java.util.Map; * @author Adam Gibson */ @Slf4j +@NoArgsConstructor public class Reshape extends DynamicCustomOp { private long[] shape; @@ -61,15 +63,13 @@ public class Reshape extends DynamicCustomOp { addIArgument(shape); } - public Reshape(INDArray in, INDArray shape){ - this(in, shape, null); - } public Reshape(@NonNull INDArray in, @NonNull INDArray shape, INDArray out){ super(null, new INDArray[]{in, shape}, wrapOrNull(out), null, (List)null); } - public Reshape() { + public 
Reshape(INDArray in, INDArray shape){ + this(in, shape, null); } @@ -152,8 +152,8 @@ public class Reshape extends DynamicCustomOp { @Override public List doDiff(List i_v) { - SDVariable origShape = f().shape(arg()); - SDVariable ret = f().reshape(i_v.get(0), origShape); + SDVariable origShape = sameDiff.shape(arg()); + SDVariable ret = sameDiff.reshape(i_v.get(0), origShape); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java index 3c3baf1f6..94a6e6c2b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java @@ -120,7 +120,7 @@ public class SequenceMask extends DynamicCustomOp { @Override public List doDiff(List grad){ //Input is integer indices - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java index 593830327..d9c4c4578 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java @@ -65,7 +65,7 @@ public class ShapeN extends DynamicCustomOp { public List doDiff(List i_v) { List out = new ArrayList<>(); for(SDVariable in : args()){ - out.add(f().zerosLike(in)); + out.add(sameDiff.zerosLike(in)); } return out; } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java index ce3ce9cae..acec28f68 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java @@ -48,11 +48,10 @@ public class Size extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {input}, false); } - public Size(INDArray in) { - addInputArgument(in); + public Size(INDArray in){ + super(new INDArray[] {in}, null); } - @Override public String onnxName() { throw new NoOpNameFoundException("No onnx name found for shape " + opName()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java index 46b8f6286..effe95b2b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java @@ -25,6 +25,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp; import java.util.*; @@ -53,8 +54,10 @@ public class Slice extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{input, begin, end}); } - public Slice(INDArray in, int[] begin, int... size) { - addInputArgument(in); + public Slice(INDArray input, int[] begin, int... 
size){ + super(new INDArray[] {input}, null); + this.begin = begin; + this.size = size; addIArgument(begin); addIArgument(size); } @@ -82,10 +85,10 @@ public class Slice extends DynamicCustomOp { @Override public List doDiff(List grad) { if(args().length == 1) { - return Collections.singletonList(f().sliceBp(arg(), grad.get(0), begin, size)); + return new SliceBp(sameDiff, arg(), grad.get(0), begin, size).outputs(); } else { //Dynamic begin/size - return Collections.singletonList(f().sliceBp(arg(0), grad.get(0), arg(1), arg(2))); + return new SliceBp(sameDiff, arg(0), grad.get(0), arg(1), arg(2)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java index 17a8beb3c..28e92930c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java @@ -129,7 +129,7 @@ public class Stack extends DynamicCustomOp { @Override public List doDiff(List f1) { - return Arrays.asList(f().unstack(f1.get(0), jaxis, args().length)); + return Arrays.asList(sameDiff.unstack(f1.get(0), jaxis, args().length)); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java index 456edfe1c..53deb43ca 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NoArgsConstructor; import lombok.NonNull; import 
lombok.extern.slf4j.Slf4j; import lombok.val; @@ -27,6 +28,7 @@ import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; @@ -259,12 +261,12 @@ public class StridedSlice extends DynamicCustomOp { public List doDiff(List i_v) { if(args().length == 1) { //Array inputs for begin/end/strides - return Collections.singletonList(f().stridedSliceBp(arg(), i_v.get(0), begin, end, strides, beginMask, endMask, - ellipsisMask, newAxisMask, shrinkAxisMask)); + return new StridedSliceBp(sameDiff, arg(), i_v.get(0), begin, end, strides, beginMask, endMask, + ellipsisMask, newAxisMask, shrinkAxisMask).outputs(); } else { //SDVariable inputs for begin/end/strides - return Collections.singletonList(f().stridedSliceBp(arg(), i_v.get(0), arg(1), arg(2), arg(3), beginMask, endMask, - ellipsisMask, newAxisMask, shrinkAxisMask)); + return new StridedSliceBp(sameDiff, arg(), i_v.get(0), arg(1), arg(2), arg(3), beginMask, endMask, + ellipsisMask, newAxisMask, shrinkAxisMask).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java index c2e476f60..687342ed8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java @@ -24,6 +24,7 @@ import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import 
org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -66,13 +67,16 @@ public class Tile extends DynamicCustomOp { this(inputs,outputs,axis,false); } - public Tile(INDArray x, INDArray repeat) { - addInputArgument(x, repeat); + public Tile(INDArray x, INDArray repeat){ + super(null, new INDArray[] {x, repeat}, null); + this.jaxis = null; } - public Tile(INDArray x, int... repeat) { - addInputArgument(x); - addIArgument(repeat); + public Tile(INDArray inputs, int... axis){ + super(null, new INDArray[] {inputs}, null); + this.jaxis = axis; + this.is_static_reps = true; + addArguments(); } public Tile() {} @@ -126,9 +130,9 @@ public class Tile extends DynamicCustomOp { @Override public List doDiff(List i_v) { if(jaxis != null){ - return Collections.singletonList(f().tileBp(arg(), i_v.get(0), jaxis)); + return new TileBp(sameDiff, arg(), i_v.get(0), jaxis).outputs(); }else{ - return Collections.singletonList(f().tileBp(arg(0), arg(1), i_v.get(0))); + return new TileBp(sameDiff, arg(0), arg(1), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java index 95215b686..ea4096f63 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java @@ -60,8 +60,8 @@ public class Transpose extends DynamicCustomOp { super(null, new INDArray[]{input}, result == null ? 
null : new INDArray[]{result}, null, (List) null); } - public Transpose(INDArray input) { - addInputArgument(input); + public Transpose(INDArray input){ + this(input, null); } public Transpose() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeAvgBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeAvgBp.java new file mode 100644 index 000000000..54d39ce89 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeAvgBp.java @@ -0,0 +1,57 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.shape.bp; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.apache.commons.lang3.ArrayUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.ArrayList; +import java.util.List; + + +@NoArgsConstructor +public class MergeAvgBp extends DynamicCustomOp { + + public MergeAvgBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) { + super("mergeavg_bp", sameDiff, ArrayUtils.add(inputs, gradO)); + } + + @Override + public String opName() { + return "mergeavg_bp"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + ArrayList list = new ArrayList(); + for (int i = 0; i < args().length - 1; i++) { + list.add(inputDataTypes.get(0)); + } + return list; + + } + + @Override + public int getNumOutputs() { + return args().length - 1; + } + +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeMaxBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeMaxBp.java new file mode 100644 index 000000000..792036b76 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeMaxBp.java @@ -0,0 +1,56 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.shape.bp; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.apache.commons.lang3.ArrayUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.ArrayList; +import java.util.List; + + +@NoArgsConstructor +public class MergeMaxBp extends DynamicCustomOp { + + public MergeMaxBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) { + super("mergemax_bp", sameDiff, ArrayUtils.add(inputs, gradO)); + } + + @Override + public String opName() { + return "mergemax_bp"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + List list = new ArrayList(); + for (int i=0; i< args().length-1;i++){ + list.add(inputDataTypes.get(0)); + } + return list; + + } + + @Override + public int getNumOutputs(){ + return args().length-1; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/EmbeddingLookup.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/EmbeddingLookup.java new file mode 100644 index 000000000..e59abc268 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/EmbeddingLookup.java @@ -0,0 +1,71 @@ +/* 
****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.shape.tensorops; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.val; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.enums.PartitionMode; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class EmbeddingLookup extends DynamicCustomOp { + + public EmbeddingLookup(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable indices, PartitionMode partitionMode) { + super("embedding_lookup", sameDiff, new SDVariable[]{in, indices}); + addIArgument(partitionMode.ordinal()); + } + + public EmbeddingLookup(@NonNull INDArray in, @NonNull INDArray indices, PartitionMode partitionMode, INDArray output) { + super("embedding_lookup", new INDArray[]{in, indices}, wrapOrNull(output)); + addIArgument(partitionMode.ordinal()); + + } + + public EmbeddingLookup(@NonNull INDArray in, INDArray 
output, PartitionMode partitionMode, @NonNull int... indices) { + super("embedding_lookup", new INDArray[]{in, Nd4j.createFromArray(indices)}, wrapOrNull(output)); + addIArgument(partitionMode.ordinal()); + + + } + + @Override + public String opName() { + return "embedding_lookup"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(1).isIntType(), "Input datatype must be integer point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java index 211fec834..2b6359985 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -101,7 +102,7 @@ public class StandardDeviation extends Variance { //If out = stdev(in) then: //dL/dIn = dL/dOut * dOut/dIn //dOut/dIn_i = (in_i-mean)/(stdev * (n-1)) - return Collections.singletonList(f().stdBp(arg(), 
grad.get(0), biasCorrected, keepDims, dimensions)); + return new StandardDeviationBp(sameDiff, arg(), grad.get(0), biasCorrected, keepDims, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java index adc92549b..64948880c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceOp; import org.nd4j.linalg.api.ops.OpContext; +import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -115,7 +116,7 @@ public class Variance extends BaseReduceOp { //If out = var(in) then: //dL/dIn = dL/dOut * dOut/dIn // with dOut/dIn = (in-mean) * 2/(n-1) - return Collections.singletonList(f().varianceBp(arg(), grad.get(0), biasCorrected, keepDims, dimensions)); + return new VarianceBp(sameDiff, arg(), grad.get(0), biasCorrected, keepDims, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java index 66eeb9b99..7fed19e02 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java @@ -44,6 +44,6 @@ 
public class Angle extends DynamicCustomOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java index b7bd0e0f6..ef8283be8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -46,6 +47,24 @@ public class Pad extends DynamicCustomOp { public Pad(){ } + private static Mode adaptMode(PadMode mode) { + Mode legacyMode = Mode.CONSTANT; + + if (mode == PadMode.CONSTANT) { + legacyMode = Mode.CONSTANT; + } + else if (mode == PadMode.REFLECT) { + legacyMode = Mode.REFLECT; + } + else if (mode == PadMode.SYMMETRIC) { + legacyMode = Mode.SYMMETRIC; + } + return legacyMode; + } + + public Pad(SameDiff sd, SDVariable in, SDVariable padding, PadMode mode, double padValue) { + this(sd, in, padding, adaptMode(mode), padValue); + } public Pad(SameDiff sd, SDVariable in, SDVariable padding, Mode mode, double padValue) { super(sd, new SDVariable[]{in, padding}, false); Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType()); @@ -62,6 +81,10 @@ public class Pad extends DynamicCustomOp { this(in, padding, null, Mode.CONSTANT, padValue); } + 
public Pad(@NonNull INDArray in, @NonNull INDArray padding, @NonNull PadMode mode, double padValue) { + this(in, padding, null, adaptMode(mode), padValue); + } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull Mode mode, double padValue){ super(null, new INDArray[]{in, padding}, out == null ? null : new INDArray[]{out}); Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType()); @@ -70,6 +93,10 @@ public class Pad extends DynamicCustomOp { addTArgument(padValue); } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull PadMode mode, double padValue) { + this(in, padding, out, adaptMode(mode), padValue); + } + @Override public String opName(){ return "pad"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java index 218ec66db..54f91f6d1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java @@ -73,7 +73,7 @@ public class IsMax extends DynamicCustomOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java index 4dd948b4f..053199731 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java @@ -75,6 +75,6 @@ public class BooleanNot extends BaseTransformBoolOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java index 8df844943..8cf81febf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java @@ -73,7 +73,7 @@ public class IsFinite extends BaseTransformBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java index 44cb362a4..95d75e9be 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java @@ -73,7 +73,7 @@ public class IsInf extends BaseTransformBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java index daf9b0ea3..9f8e9ea74 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java @@ -74,7 +74,7 @@ public class IsNaN extends BaseTransformBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java index dea1c9c3b..b3810ba15 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java @@ -62,12 +62,10 @@ public class MatchConditionTransform extends BaseTransformBoolOp { this(x, z, Nd4j.EPS_THRESHOLD, condition); } - public MatchConditionTransform(INDArray x, @NonNull Condition condition) { this(x, null, Nd4j.EPS_THRESHOLD, condition); } - public MatchConditionTransform(INDArray x, INDArray z, double eps, @NonNull Condition condition) { super(x, null, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByAvgNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByAvgNorm.java new file mode 100644 index 000000000..a5f53622b --- /dev/null +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByAvgNorm.java @@ -0,0 +1,71 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.clip; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + + +@NoArgsConstructor +public class ClipByAvgNorm extends DynamicCustomOp { + + private double clipValue; + + + public ClipByAvgNorm(SameDiff sameDiff, SDVariable x, double clipValue, int... dimensions) { + super("clipbyavgnorm", sameDiff, new SDVariable[]{x}); + this.clipValue = clipValue; + this.dimensions = dimensions; + addIArgument(dimensions); + addTArgument(clipValue); + } + + public ClipByAvgNorm(INDArray in, double clipValue, int... dimensions){ + this(in, null, clipValue, dimensions); + } + + public ClipByAvgNorm(INDArray in, INDArray out, double clipValue, int... 
dimensions){ + super("clipbyavgnorm", new INDArray[]{in}, wrapOrNull(out), Collections.singletonList(clipValue), dimensions); + } + + @Override + public String opName() { + return "clipbyavgnorm"; + } + + + + @Override + public List doDiff(List grad) { + throw new UnsupportedOperationException("Not yet implemented"); } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes); + return inputDataTypes; + } + +} + + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java index fadd8720e..ef1ebb38a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java @@ -73,7 +73,7 @@ public class ClipByNorm extends DynamicCustomOp { @Override public List doDiff(List grad) { - return Collections.singletonList(new ClipByNormBp(f().sameDiff(), arg(), grad.get(0), clipValue, dimensions).outputVariable()); + return new ClipByNormBp(sameDiff, arg(), grad.get(0), clipValue, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java index fa465b251..44cde0abb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java @@ -83,8 +83,8 @@ public class ClipByValue extends DynamicCustomOp { @Override public List doDiff(List grad) { //dOut/dIn is 0 if clipped, 1 otherwise - SDVariable notClippedLower = f().gt(arg(), clipValueMin).castTo(arg().dataType()); - SDVariable notClippedUpper = f().lt(arg(), clipValueMax).castTo(arg().dataType()); + SDVariable notClippedLower = sameDiff.gt(arg(), clipValueMin).castTo(arg().dataType()); + SDVariable notClippedUpper = sameDiff.lt(arg(), clipValueMax).castTo(arg().dataType()); SDVariable ret = notClippedLower.mul(notClippedUpper).mul(grad.get(0)); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/CompareAndReplace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/CompareAndReplace.java index 80ab7fd35..9c5b54c72 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/CompareAndReplace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/comparison/CompareAndReplace.java @@ -69,7 +69,7 @@ public class CompareAndReplace extends BaseTransformSameOp { * @param condition */ public CompareAndReplace(INDArray x, INDArray y, Condition condition) { - this(x, y, x, condition); + this(x, y, null, condition); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java index d6230e153..e847142a6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java @@ -85,17 +85,7 @@ public class ATan2 extends BaseDynamicTransformOp { SDVariable y = larg(); SDVariable x = rarg(); -/* SDVariable r = y.div(x); - - SDVariable dOutdr = f().square(r).add(1.0).rdiv(1.0); - SDVariable drdy = x.rdiv(1.0); - SDVariable drdx = f().neg(y).div(f().square(x)); - - SDVariable xGrad = dOutdr.mul(drdx).mul(i_v.get(0)); - SDVariable yGrad = dOutdr.mul(drdy).mul(i_v.get(0)); -*/ - - val xGrad = f().neg(y.div(x.pow(2).add(y.pow(2)))).mul(i_v.get(0)); + val xGrad = sameDiff.math.neg(y.div(x.pow(2).add(y.pow(2)))).mul(i_v.get(0)); val yGrad = x.div(x.pow(2).add(y.pow(2))).mul(i_v.get(0)); return Arrays.asList(yGrad, xGrad); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java index 35c209870..e9ce44c57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -46,6 +47,10 @@ public class Assign extends DynamicCustomOp { super(null,inputs, outputs); } + public Assign(INDArray x, INDArray y ) { + this( new INDArray[]{y ,x},new INDArray[]{y}); // TODO: Still check. y cannot be null, must be same shape as x. + } + @Override public void addIArgument(int... 
arg) { super.addIArgument(arg); @@ -89,7 +94,7 @@ public class Assign extends DynamicCustomOp { @Override public List doDiff(List f1){ //TODO replace with assign backprop op from libnd4j (that handles the broadcast case properly) - return Arrays.asList(f().zerosLike(larg()), f1.get(0)); + return Arrays.asList(sameDiff.zerosLike(larg()), f1.get(0)); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReLU.java new file mode 100644 index 000000000..d442bc141 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReLU.java @@ -0,0 +1,65 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.base.Preconditions; + +import java.util.Collections; +import java.util.List; +import lombok.Getter; +import lombok.NonNull; + +@NoArgsConstructor +public class CReLU extends DynamicCustomOp { + + + public CReLU(SameDiff sd, SDVariable input) { + super(sd, new SDVariable[]{input}); + } + + public CReLU(@NonNull INDArray input) { + super(new INDArray[]{input}, null); + + } + + + @Override + public String opName() { + return "crelu"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List i_v) { + + return Collections.singletonList(new CReluBp(sameDiff, arg(), i_v.get(0)).outputVariable()); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReluBp.java new file mode 100644 index 000000000..7b96afffd --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReluBp.java @@ -0,0 +1,59 @@ +/* ****************************************************************************** + * Copyright 
(c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.base.Preconditions; + +import java.util.Collections; +import java.util.List; +import lombok.Getter; +import lombok.NonNull; + + +@NoArgsConstructor +public class CReluBp extends DynamicCustomOp { + + public CReluBp(SameDiff sd, SDVariable input, SDVariable epsilonNext) { + super(sd, new SDVariable[]{input, epsilonNext}); + } + + public CReluBp(@NonNull INDArray input, @NonNull INDArray epsilonNext, INDArray output) { + super(new INDArray[]{input, epsilonNext}, wrapOrNull(output)); + } + + + @Override + public String opName() { + return "crelu_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + +} diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java index 0be0b08ad..34e3e5f1d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java @@ -28,6 +28,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.CumProdBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -142,7 +143,7 @@ public class CumProd extends DynamicCustomOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().cumprodBp(arg(0), grad.get(0), exclusive, reverse, jaxis)); + return new CumProdBp(sameDiff, arg(0), grad.get(0), exclusive, reverse, jaxis).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java index c24693b01..97c53f4e2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java @@ -29,6 +29,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import 
org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -142,7 +143,7 @@ public class CumSum extends DynamicCustomOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().cumsumBp(arg(0), grad.get(0), exclusive, reverse, jaxis)); + return new CumSumBp(sameDiff, arg(0), grad.get(0), exclusive, reverse, jaxis).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java index d3a5c9676..300f8277a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java @@ -70,7 +70,8 @@ public class DotProductAttention extends DynamicCustomOp { @Override public List doDiff(List gradient) { - return sameDiff.f().dotProductAttentionBp(arg(0), arg(1), arg(2), gradient.get(0), args().length > 3 ? arg(3) : null, scaled); + SDVariable mask = args().length == 4 ? 
arg(3) : null; + return Arrays.asList(new DotProductAttentionBp(sameDiff, arg(0), arg(1), arg(2), gradient.get(0), mask, scaled).outputVariables()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java index 3efc13af0..db11a6a6c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -24,6 +25,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -66,15 +68,23 @@ public class DynamicPartition extends DynamicCustomOp { addArgs(); } - public DynamicPartition(INDArray input, INDArray partitions, int numPartitions) { - addInputArgument(input); - addIArgument(numPartitions); + public DynamicPartition(@NonNull INDArray input, @NonNull INDArray partitions, int numPartitions) { + super(new INDArray[]{input, partitions}, null); + this.numPartitions = numPartitions; + addArgs(); } + public DynamicPartition(INDArray x, INDArray [] partitions, int numPartitions){ + //TODO; This needs fixing. 
+ super(new INDArray[]{x}, null); + // this.partitions = partitions; + this.numPartitions = numPartitions; + addArgs(); + } @Override public List doDiff(List i_v) { - return Arrays.asList(f().dynamicPartitionBp(arg(0), arg(1), i_v.toArray(new SDVariable[i_v.size()]), numPartitions)); + return new DynamicPartitionBp(sameDiff, arg(0), arg(1), i_v.toArray(new SDVariable[i_v.size()]), numPartitions).outputs(); } protected void addArgs() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java index 94c34d108..8c94c3e54 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -61,14 +62,8 @@ public class DynamicStitch extends DynamicCustomOp { this.numPartitions = inputs.length; } - public DynamicStitch(INDArray[] inputs, INDArray[] indices) { - for (INDArray input : inputs) { - addInputArgument(input); - } - - for (INDArray index : indices) { - addInputArgument(index); - } + public DynamicStitch(@NonNull INDArray[] indices, @NonNull INDArray[] inputs) { + super(ArrayUtils.addAll(indices, inputs), null); } @Override @@ -83,7 +78,7 @@ public class DynamicStitch extends DynamicCustomOp { SDVariable[] partition = sameDiff.dynamicPartition(gradient, partitions, numPartitions); List ret = new ArrayList<>(); for (SDVariable i : indices) - ret.add(f().zerosLike(i)); + ret.add(sameDiff.zerosLike(i)); Collections.addAll(ret, partition); 
return ret; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java index 0d1214c9a..e58609f2a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java @@ -48,14 +48,14 @@ public class EqualTo extends BaseDynamicTransformOp { super(inputs, outputs); } - public EqualTo( INDArray x, INDArray y) { - addInputArgument(x, y); - } - public EqualTo(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } + public EqualTo(INDArray x, INDArray y){ + this(new INDArray[]{x, y}, null); + } + @Override public String opName() { return "equals"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java index 73f221f35..0e39eb77b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.exception.ND4JIllegalStateException; +import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -55,19 +56,21 @@ public class Fill extends DynamicCustomOp { super(null,sameDiff, new SDVariable[] {shape}, false); this.value = value; 
this.outputDataType = outputDataType; + this.outputDataType = outputDataType; addArgs(); } + public Fill(INDArray shape, DataType outputDataType, double value) { + super(new INDArray[]{shape, Nd4j.scalar(outputDataType, value)}, null); + this.value = value; + this.outputDataType = outputDataType; + } + public Fill(INDArray shape, INDArray result, double value) { super(null, shape, result, Collections.singletonList(value), null); this.value = value; } - public Fill(INDArray shape, DataType dataType, double value) { - super(null, shape, null, Collections.singletonList(value), null); - this.value = value; - } - public Fill(INDArray shape, INDArray value, INDArray result) { super(null, new INDArray[]{shape, value}, new INDArray[]{result}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java index 6a1ecc2cf..e4b28e56a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java @@ -49,14 +49,14 @@ public class GreaterThan extends BaseDynamicTransformOp { super(inputs, outputs); } - public GreaterThan( INDArray x, INDArray y) { - addInputArgument(x,y); - } - public GreaterThan(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } + public GreaterThan(INDArray x, INDArray y){ + this(new INDArray[]{x, y}, null); + } + @Override public String opName() { return "greater"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java 
index dfb7fe8dd..0ebea5e9c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java @@ -52,9 +52,9 @@ public class GreaterThanOrEqual extends BaseDynamicTransformOp { this(new INDArray[]{x, y}, new INDArray[]{z}); } - public GreaterThanOrEqual(INDArray x, INDArray y) { - this(new INDArray[]{x,y}, null); + public GreaterThanOrEqual(INDArray x, INDArray y){ + this(new INDArray[]{x, y}, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java index 6048c9dff..9612c4dea 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -67,8 +68,8 @@ public class InvertPermutation extends BaseDynamicTransformOp { @Override public List doDiff(List grad) { SDVariable gradient = grad.get(0); - SDVariable invertedGradient = f().invertPermutation(gradient, false); - return Arrays.asList(invertedGradient); + SDVariable invertedGradient = sameDiff.invertPermutation(gradient); + return Collections.singletonList(invertedGradient); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java index 88c0a84ba..4f1804480 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java @@ -45,8 +45,8 @@ public class IsNumericTensor extends DynamicCustomOp { super(null, inputs, outputs); } - public IsNumericTensor(INDArray input) { - addInputArgument(input); + public IsNumericTensor(INDArray inputs) { + super( new INDArray[] {inputs}, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java index 0c4990bb2..f16a92318 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java @@ -100,13 +100,11 @@ public class LayerNorm extends DynamicCustomOp { @Override public List doDiff(List gradient) { - SDVariable[] ret; - if(noBias){ - ret = f().layerNormBp(arg(0), arg(1), gradient.get(0), channelsFirst, dimensions); - }else{ - ret = f().layerNormBp(arg(0), arg(1), arg(2), gradient.get(0), channelsFirst, dimensions); + if (noBias) { + return new LayerNormBp(sameDiff, arg(0), arg(1), gradient.get(0), channelsFirst, dimensions).outputs(); + } else { + return new LayerNormBp(sameDiff, arg(0), arg(1), arg(2), gradient.get(0), channelsFirst, dimensions).outputs(); } - return Arrays.asList(ret); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java index b1a38e0ff..0445b58c4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java @@ -49,14 +49,14 @@ public class LessThan extends BaseDynamicTransformOp { super(inputs, outputs); } - public LessThan( INDArray x, INDArray y) { - addInputArgument(x,y); - } - public LessThan(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } + public LessThan(INDArray x, INDArray y){ + this(new INDArray[]{x, y}, null); + } + @Override public String opName() { return "less"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java index 0ca6bf7e6..06e03335c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java @@ -48,14 +48,14 @@ public class LessThanOrEqual extends BaseDynamicTransformOp { super(inputs, outputs); } - public LessThanOrEqual( INDArray x, INDArray y) { - addInputArgument(x,y); - } - public LessThanOrEqual(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } + public LessThanOrEqual(INDArray x, INDArray y){ + this(new INDArray[]{x, y}, null); + } + @Override public String opName() { return "less_equal"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java index 86c9d9c0a..2de57451f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative; import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; @@ -76,11 +77,9 @@ public class LogSoftMax extends DynamicCustomOp { @Override public List doDiff(List i_v) { if(dimension == null) { - SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0)); - return Collections.singletonList(ret); + return new LogSoftMaxDerivative(sameDiff, arg(), i_v.get(0)).outputs(); } else { - SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0), dimension); - return Collections.singletonList(ret); + return new LogSoftMaxDerivative(sameDiff, arg(), i_v.get(0), dimension).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java index 19d139cbb..9f4b97576 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java @@ -57,8 +57,8 @@ public class MatrixSetDiag extends DynamicCustomOp { @Override public List doDiff(List i_v) { SDVariable grad = i_v.get(0); - SDVariable in1Grad = 
f().setDiag(grad, sameDiff.zerosLike(arg(1))); - SDVariable in2Grad = f().diagPart(grad); + SDVariable in1Grad = sameDiff.math.setDiag(grad, sameDiff.zerosLike(arg(1))); + SDVariable in2Grad = sameDiff.math.diagPart(grad); return Arrays.asList(in1Grad, in2Grad); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java index e8653d4c0..0197f0c79 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java @@ -48,12 +48,12 @@ public class Max extends BaseDynamicTransformOp { super(new INDArray[]{first, second}, out == null ? null : new INDArray[]{out}); } - public Max( INDArray[] inputs, INDArray[] outputs) { - super(inputs, outputs); + public Max( INDArray first, INDArray second){ + this(first, second, null); } - public Max( INDArray x, INDArray y) { - addInputArgument(x,y); + public Max( INDArray[] inputs, INDArray[] outputs) { + super(inputs, outputs); } @Override @@ -73,12 +73,7 @@ public class Max extends BaseDynamicTransformOp { @Override public List doDiff(List f1) { - //TODO Switch to maximum_bp op - https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/broadcastable/maximum.cpp - SDVariable max = outputVariables()[0]; - SDVariable eq1 = sameDiff.eq(larg(), max).castTo(arg(0).dataType()); - SDVariable eq2 = sameDiff.eq(rarg(), max).castTo(arg(1).dataType()); - - return Arrays.asList(eq1.mul(f1.get(0)), eq2.mul(f1.get(0))); + return Arrays.asList(new MaximumBp(sameDiff, arg(0), arg(1), f1.get(0)).outputVariables()); } @Override diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java new file mode 100644 index 000000000..92fb3b0eb --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java @@ -0,0 +1,48 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.ArrayList; +import java.util.List; + +@NoArgsConstructor +public class MaximumBp extends DynamicCustomOp { + + public MaximumBp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y, @NonNull SDVariable gradO) { + super("maximum_bp",sameDiff, new SDVariable[]{x,y, gradO}); + } + + @Override + public String opName() { + return "maximum_bp"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + List list = new ArrayList(); + list.add(inputDataTypes.get(0)); + list.add(inputDataTypes.get(0)); + return list; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java index c195178c2..cf0cf9c58 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java @@ -48,12 +48,12 @@ public class Min extends BaseDynamicTransformOp { super(new INDArray[]{first, second}, out == null ? 
null : new INDArray[]{out}); } - public Min( INDArray[] inputs, INDArray[] outputs) { - super(inputs, outputs); + public Min( INDArray first, INDArray second){ + this(first, second, null); } - public Min( INDArray x, INDArray y) { - addInputArgument(x,y); + public Min( INDArray[] inputs, INDArray[] outputs) { + super(inputs, outputs); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java index 54167bd8b..98765ed96 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java @@ -79,7 +79,7 @@ public class MultiHeadDotProductAttention extends DynamicCustomOp { @Override public List doDiff(List gradient) { - return sameDiff.f().multiHeadDotProductAttentionBp(arg(0), arg(1), arg(2), arg(3), arg(4), arg(5), arg(6), gradient.get(0), args().length > 7 ? arg(7) : null, scaled); + return Arrays.asList(new MultiHeadDotProductAttentionBp(sameDiff, arg(0), arg(1), arg(2), arg(3), arg(4), arg(5), arg(6), gradient.get(0), args().length > 7 ? 
arg(7) : null, scaled).outputVariables()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java index 69d724a7e..ba2e36ea2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java @@ -48,14 +48,14 @@ public class NotEqualTo extends BaseDynamicTransformOp { super(inputs, outputs); } - public NotEqualTo( INDArray x, INDArray y) { - addInputArgument(x,y); - } - public NotEqualTo(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } + public NotEqualTo(INDArray x, INDArray y){ + this(new INDArray[]{x, y}, null); + } + @Override public String opName() { return "not_equals"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java index e155a4f2a..0f8286769 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.PowBp; import java.util.Arrays; import java.util.Collections; @@ -68,8 +69,7 @@ public class Pow extends DynamicCustomOp { SDVariable dldb = outputVariable().mul(sameDiff.math().log(a)).mul(f1.get(0)); return 
Arrays.asList(dlda, dldb);*/ - SDVariable[] g = f().powBp(arg(0), arg(1), f1.get(0)); - return Arrays.asList(g); + return new PowBp(sameDiff, arg(0), arg(1), f1.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java index d1648abab..372f96c18 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java @@ -100,8 +100,8 @@ public class Reverse extends DynamicCustomOp { @Override public List doDiff(List f1) { - SDVariable ret = f().reverse(f1.get(0), dimensions); - return Arrays.asList(ret); + SDVariable ret = sameDiff.reverse(f1.get(0), dimensions); + return Collections.singletonList(ret); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java index 11897fef8..f7494c618 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java @@ -59,6 +59,17 @@ public class ReverseSequence extends DynamicCustomOp { addArguments(); } + public ReverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim){ + super(new INDArray[]{x, seq_lengths}, null); + this.seqDim = seqDim; + this.batchDim = batchDim; + addArguments(); + } + + public ReverseSequence(INDArray x, INDArray seq_lengths){ + this(x, seq_lengths, 1, 0); + } + private 
void addArguments(){ addIArgument(seqDim); addIArgument(batchDim); @@ -67,11 +78,6 @@ public class ReverseSequence extends DynamicCustomOp { public ReverseSequence() { } - public ReverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim) { - addInputArgument(x, seq_lengths); - addIArgument(seqDim, batchDim); - } - @Override public String opName() { return "reverse_sequence"; @@ -115,8 +121,8 @@ public class ReverseSequence extends DynamicCustomOp { @Override public List doDiff(List f1) { - SDVariable ret = f().reverseSequence(f1.get(0), arg(1), seqDim, batchDim); - return Arrays.asList(ret, f().zerosLike(arg(1))); + SDVariable ret = sameDiff.reverseSequence(f1.get(0), arg(1), seqDim, batchDim); + return Arrays.asList(ret, sameDiff.zerosLike(arg(1))); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java index 24c2353c1..737e76f3b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp; import java.util.Collections; import java.util.List; @@ -106,8 +107,7 @@ public class SoftMax extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().softmaxDerivative(arg(), i_v.get(0), this.dimension); - return Collections.singletonList(ret); + return new SoftmaxBp(sameDiff, arg(), i_v.get(0), this.dimension).outputs(); } @Override diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java index 467b36a4e..8acef4029 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java @@ -63,8 +63,7 @@ public class Standardize extends DynamicCustomOp { @Override public List doDiff(List grad) { - SDVariable ret = f().standardizeBp(arg(0), grad.get(0), dimensions); - return Arrays.asList(ret); + return new StandardizeBp(sameDiff, arg(0), grad.get(0), dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java index 82e2ae6e3..1688c03c4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp; /** * Threshold ReLU op. The genral case of {@link RectifiedLinear}. 
@@ -72,6 +73,6 @@ public class ThresholdRelu extends DynamicCustomOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().thresholdReluBp(arg(), f1.get(0), cutoff)); + return new ThresholdReluBp(sameDiff, arg(), f1.get(0), cutoff).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java index 24d79f234..9faca3403 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java @@ -51,12 +51,12 @@ public class Trace extends DynamicCustomOp { @Override public List doDiff(List gradAtOutput){ - SDVariable rows = f().reshape(f().sizeAt(arg(), -2), new long[]{1}); - SDVariable cols = f().reshape(f().sizeAt(arg(), -1), new long[]{1}); - SDVariable eye = sameDiff.math().eye(/*f().shape(gradAtOutput.get(0)),*/ rows, cols); + SDVariable rows = sameDiff.reshape(sameDiff.sizeAt(arg(), -2), 1); + SDVariable cols = sameDiff.reshape(sameDiff.sizeAt(arg(), -1), 1); + SDVariable eye = sameDiff.math().eye(/*sameDiff.shape(gradAtOutput.get(0)),*/ rows, cols); //Reshape gradient from [x,y,z] to [x,y,z,1,1] - SDVariable reshapedGrad = f().expandDims(gradAtOutput.get(0), -1); - reshapedGrad = f().expandDims(reshapedGrad, -1); + SDVariable reshapedGrad = sameDiff.expandDims(gradAtOutput.get(0), -1); + reshapedGrad = sameDiff.expandDims(reshapedGrad, -1); return Collections.singletonList(reshapedGrad.mul(eye)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java index 5b6cd2517..217062e09 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMaxBp; import java.util.Arrays; import java.util.Collections; @@ -38,8 +39,8 @@ public class SegmentMax extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } - public SegmentMax(INDArray data, INDArray segmentIds) { - addInputArgument(data, segmentIds); + public SegmentMax(INDArray data, INDArray segmentIds){ + super(new INDArray[]{data, segmentIds}, null); } public SegmentMax(){ } @@ -56,7 +57,7 @@ public class SegmentMax extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentMaxBp(arg(0), arg(1), gradients.get(0))); + return new SegmentMaxBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java index d0a9a6784..79e887001 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java @@ -22,6 +22,7 @@ 
import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMeanBp; import java.util.Arrays; import java.util.Collections; @@ -38,12 +39,12 @@ public class SegmentMean extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } - public SegmentMean(INDArray data, INDArray segmentIds) { - addInputArgument(data, segmentIds); - } - public SegmentMean(){ } + public SegmentMean(INDArray data, INDArray segmentIds){ + super(new INDArray[]{data, segmentIds}, null); + } + @Override public String opName(){ return "segment_mean"; @@ -56,7 +57,7 @@ public class SegmentMean extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentMeanBp(arg(0), arg(1), gradients.get(0))); + return new SegmentMeanBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java index 2bc369f2a..367d6ee1a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMinBp; import java.util.Arrays; import java.util.Collections; @@ -38,12 +39,12 @@ public class SegmentMin extends DynamicCustomOp { 
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } - public SegmentMin(INDArray data, INDArray segmentIds) { - addInputArgument(data, segmentIds); - } - public SegmentMin(){ } + public SegmentMin(INDArray data, INDArray segmentIds){ + super(new INDArray[]{data, segmentIds}, null); + } + @Override public String opName(){ return "segment_min"; @@ -56,7 +57,7 @@ public class SegmentMin extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentMinBp(arg(0), arg(1), gradients.get(0))); + return new SegmentMinBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java index 3be3625e7..ddd719045 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentProdBp; import java.util.Arrays; import java.util.Collections; @@ -38,12 +39,12 @@ public class SegmentProd extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } - public SegmentProd(INDArray data, INDArray segmentIds) { - addInputArgument(data, segmentIds); - } - public SegmentProd(){ } + public SegmentProd(INDArray data, INDArray segmentIds){ + super(new INDArray[]{data, segmentIds}, null); + } + @Override public String opName(){ return "segment_prod"; @@ -56,7 +57,7 
@@ public class SegmentProd extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentProdBp(arg(0), arg(1), gradients.get(0))); + return new SegmentProdBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java index 5de847162..9f6269848 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentSumBp; import java.util.Arrays; import java.util.Collections; @@ -38,12 +39,12 @@ public class SegmentSum extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } - public SegmentSum(INDArray data, INDArray segmentIds) { - addInputArgument(data, segmentIds); - } - public SegmentSum(){ } + public SegmentSum(INDArray data, INDArray segmentIds){ + super(new INDArray[]{data, segmentIds}, null); + } + @Override public String opName(){ return "segment_sum"; @@ -56,7 +57,7 @@ public class SegmentSum extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentSumBp(arg(0), arg(1), gradients.get(0))); + return new SegmentSumBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java index d588ef4a8..b168adb43 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java @@ -129,7 +129,7 @@ public class Cast extends BaseDynamicTransformOp { if(arg().dataType().isFPType()){ return Collections.singletonList(i_v.get(0).castTo(arg().dataType())); } else { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java index df5cdbcc7..9471c8bca 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java @@ -75,7 +75,7 @@ public class RSqrt extends BaseTransformFloatOp { @Override public List doDiff(List i_v) { - SDVariable xPowNeg32 = f().pow(arg(), -1.5).mul(-0.5); + SDVariable xPowNeg32 = sameDiff.math.pow(arg(), -1.5).mul(-0.5); return Collections.singletonList(i_v.get(0).mul(xPowNeg32)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java index b00b29b75..fdbbafa99 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -41,6 +42,10 @@ public class SELUDerivative extends BaseTransformStrictOp { private static final double SELU_ALPHA = 1.6732632423543772848170429916717; private static final double SELU_LAMBDA = 1.0507009873554804934193349852946; + public SELUDerivative(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public SELUDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } @@ -79,9 +84,8 @@ public class SELUDerivative extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().div(arg(),f().seluDerivative(arg())); - - return Arrays.asList(ret); + SDVariable ret = sameDiff.math.div(arg(), new SELUDerivative(sameDiff, arg()).outputVariable()); + return Collections.singletonList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java index 16afb4316..2a0d6021a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java @@ -84,7 +84,7 @@ public class TanhDerivative extends DynamicCustomOp { @Override public List doDiff(List i_v) { - SDVariable ret = 
f().div(sameDiff.onesLike(outputVariables()[0]), f().pow(f().cosh(arg()), 2)); + SDVariable ret = sameDiff.math.div(sameDiff.onesLike(outputVariables()[0]), sameDiff.math.pow(sameDiff.math.cosh(arg()), 2)); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java index 672159a3e..4069967c6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.AddBpOp; import java.util.List; @@ -34,14 +36,18 @@ public class AddOp extends BaseDynamicTransformOp { public AddOp() { } - public AddOp(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public AddOp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); } public AddOp(INDArray first, INDArray second, INDArray result){ this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); } + public AddOp(@NonNull INDArray x, @NonNull INDArray y) { + this(new INDArray[]{x,y}, null); + } + public AddOp(INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -63,7 +69,7 @@ public class AddOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().addBp(larg(), rarg(), i_v.get(0)); + return new AddBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java index b76942e95..2ce0101cf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java @@ -16,11 +16,14 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp; +import java.util.Arrays; import java.util.List; /** @@ -33,14 +36,18 @@ public class DivOp extends BaseDynamicTransformOp { public DivOp() {} - public DivOp( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public DivOp( @NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); } public DivOp(INDArray first, INDArray second, INDArray result){ this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); } + public DivOp( @NonNull INDArray x, INDArray y) { + this(new INDArray[]{x,y}, null); + } + public DivOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -65,7 +72,7 @@ public class DivOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().divBp(larg(), rarg(), i_v.get(0)); + return Arrays.asList(new DivBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputVariables()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java index c2314cdc4..408d86a75 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java @@ -22,6 +22,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp; import java.util.List; @@ -83,6 +84,6 @@ public class FModOp extends BaseTransformSameOp { @Override public List doDiff(List f1) { - return f().floorModBp(larg(), rarg(), f1.get(0)); + return new FloorModBpOp(sameDiff, larg(), rarg(), f1.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorDivOp.java index debfc5a5d..7ed2c6c1c 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorDivOp.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp; import java.util.List; @@ -39,6 +41,10 @@ public class FloorDivOp extends BaseDynamicTransformOp { super(sameDiff, args, inPlace); } + public FloorDivOp(@NonNull INDArray x, @NonNull INDArray y) { + this(new INDArray[]{x, y}, null); + } + public FloorDivOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -63,6 +69,6 @@ public class FloorDivOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().floorDivBp(larg(), rarg(), i_v.get(0)); + return new FloorDivBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorModOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorModOp.java index e7286816f..29799a221 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorModOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorModOp.java @@ -16,12 +16,14 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import 
org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp; import org.nd4j.linalg.api.shape.Shape; import java.util.Collections; @@ -39,6 +41,10 @@ public class FloorModOp extends BaseDynamicTransformOp { super(sameDiff, new SDVariable[]{x, y}, false); } + public FloorModOp(@NonNull INDArray x, @NonNull INDArray y) { + this(new INDArray[]{x, y}, null); + } + public FloorModOp(INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -60,7 +66,7 @@ public class FloorModOp extends BaseDynamicTransformOp { @Override public List doDiff(List f1) { - return f().floorModBp(larg(), rarg(), f1.get(0)); + return new FloorModBpOp(sameDiff, larg(), rarg(), f1.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java index fc89333f4..0d634766e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java @@ -18,14 +18,19 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; import lombok.NoArgsConstructor; import lombok.NonNull; +import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MergeAddBp; +import org.nd4j.linalg.util.ArrayUtil; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -70,11 +75,8 @@ public class MergeAddOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - List ret = new ArrayList<>(); - for (int i = 0; i < args().length; i++) - ret.add(gradient); - return ret; + return Arrays.asList(new MergeAddBp(sameDiff, args(), i_v.get(0)).outputVariables()); + } @@ -82,7 +84,7 @@ public class MergeAddOp extends BaseDynamicTransformOp { public List calculateOutputDataTypes(List dataTypes){ DataType first = dataTypes.get(0); for( int i=1; i doDiff(List i_v) { - return f().modBp(larg(), rarg(), i_v.get(0)); + return new ModBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } - - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java index 4636f9bc8..307a46557 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import 
org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp; import java.util.List; @@ -33,12 +35,16 @@ public class MulOp extends BaseDynamicTransformOp { public MulOp() {} - public MulOp( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public MulOp( @NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); + } + + public MulOp(INDArray first, INDArray second){ + this(first, second, null); } public MulOp(INDArray first, INDArray second, INDArray result){ - this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result}); + this(new INDArray[]{first, second}, wrapOrNull(result)); } public MulOp( INDArray[] inputs, INDArray[] outputs) { @@ -66,7 +72,7 @@ public class MulOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().mulBp(larg(), rarg(), i_v.get(0)); + return new MulBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java index d54d91dbc..9891464fd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java @@ -16,12 +16,15 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import 
org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp; +import java.util.Arrays; import java.util.List; /** @@ -34,14 +37,18 @@ public class RDivOp extends BaseDynamicTransformOp { public RDivOp() {} - public RDivOp( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public RDivOp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); } public RDivOp(INDArray first, INDArray second, INDArray result){ this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result}); } + public RDivOp(@NonNull INDArray x, @NonNull INDArray y){ + this(new INDArray[]{x, y}, null); + } + public RDivOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -64,6 +71,6 @@ public class RDivOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().rdivBp(larg(), rarg(), i_v.get(0)); + return Arrays.asList(new RDivBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputVariables()); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java index 12c852949..5b233eb17 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import 
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp; import java.util.List; @@ -45,8 +46,16 @@ public class RSubOp extends BaseDynamicTransformOp { this(sameDiff, new SDVariable[]{i_v1, i_v2}, inPlace); } + public RSubOp(INDArray first, INDArray second){ + this(first, second, null); + } + public RSubOp(INDArray first, INDArray second, INDArray result){ - this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result}); + this(new INDArray[]{first, second}, wrapOrNull(result)); + } + + public RSubOp( INDArray[] inputs, INDArray[] outputs) { + super(inputs, outputs); } public RSubOp() {} @@ -61,13 +70,9 @@ public class RSubOp extends BaseDynamicTransformOp { throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); } - public RSubOp( INDArray[] inputs, INDArray[] outputs) { - super(inputs, outputs); - } - @Override public List doDiff(List i_v) { - return f().rsubBp(larg(), rarg(), i_v.get(0)); + return new RSubBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RealDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RealDivOp.java index 04e72b5db..2fe1f150f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RealDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RealDivOp.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp; import java.util.List; @@ -60,7 +61,7 @@ public class 
RealDivOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().divBp(larg(), rarg(), i_v.get(0)); + return new DivBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java index acbf840f1..e1ad183e7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java @@ -36,14 +36,21 @@ public class SquaredDifferenceOp extends BaseDynamicTransformOp { public SquaredDifferenceOp() {} - public SquaredDifferenceOp(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public SquaredDifferenceOp(SameDiff sameDiff, SDVariable x, SDVariable y, boolean inPlace) { + super(sameDiff, new SDVariable[]{x,y}, inPlace); } - public SquaredDifferenceOp(INDArray[] inputs, INDArray[] outputs) { - super(inputs, outputs); + public SquaredDifferenceOp(SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, x, y, false); } + public SquaredDifferenceOp(INDArray x, INDArray y, INDArray output) { + super(new INDArray[]{x,y}, new INDArray[]{output}); + } + + public SquaredDifferenceOp(INDArray x, INDArray y) { + addInputArgument(new INDArray[]{x,y}); + } @Override public String opName() { @@ -63,8 +70,7 @@ public class SquaredDifferenceOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v1) { - SDVariable[] outputs = new SquaredDifferenceBpOp(f().sameDiff(), new SDVariable[]{larg(), rarg(), i_v1.get(0)}).outputVariables(); - return Arrays.asList(outputs); + return 
new SquaredDifferenceBpOp(sameDiff, new SDVariable[]{larg(), rarg(), i_v1.get(0)}).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java index 0d222329e..da6e77a42 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp; import java.util.List; @@ -33,12 +35,16 @@ public class SubOp extends BaseDynamicTransformOp { public SubOp() {} - public SubOp( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public SubOp( @NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); + } + + public SubOp(INDArray first, INDArray second){ + this(first, second, null); } public SubOp(INDArray first, INDArray second, INDArray result){ - this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); + this(new INDArray[]{first, second}, wrapOrNull(result)); } public SubOp( INDArray[] inputs, INDArray[] outputs) { @@ -65,7 +71,7 @@ public class SubOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().subBp(larg(), rarg(), i_v.get(0)); + return new SubBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/TruncateDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/TruncateDivOp.java index 973ecd7ba..4c99b479b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/TruncateDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/TruncateDivOp.java @@ -64,8 +64,8 @@ public class TruncateDivOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - SDVariable gradWrtX = f().div(i_v.get(0),rarg()); - SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg())); + SDVariable gradWrtX = sameDiff.math.div(i_v.get(0),rarg()); + SDVariable gradWrtY = sameDiff.math.mul(sameDiff.math.neg(gradWrtX),sameDiff.math.div(larg(),rarg())); List ret = new ArrayList<>(2); ret.add(gradWrtX); ret.add(gradWrtY); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java new file mode 100644 index 000000000..b0403ecff --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java @@ -0,0 +1,54 @@ +/* 
****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.apache.commons.lang3.ArrayUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +@NoArgsConstructor +public class MergeAddBp extends DynamicCustomOp { + + public MergeAddBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) { + super("mergeadd_bp", sameDiff, ArrayUtils.add(inputs, gradO)); + } + + @Override + public String opName() { + return "mergeadd_bp"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + ArrayList list = new ArrayList(); + for (int i=0; i< args().length-1;i++){list.add(inputDataTypes.get(0));} + return list; + + } + + @Override + public int getNumOutputs(){ + return args().length-1; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java index 95bd0bf41..d70de3b3e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java @@ -69,6 +69,6 @@ public class Not extends BaseTransformBoolOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMax.java index e5ea01a60..d1e5917c8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMax.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; import java.util.Collections; import java.util.List; @@ -57,7 +58,7 @@ public class AMax extends BaseTransformSameOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable minBp = new MinBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(minBp)); } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMin.java index bf6a37a55..b1ded6b55 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMin.java @@ -22,6 +22,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; import java.util.Collections; import java.util.List; @@ -57,7 +58,7 @@ public class AMin extends BaseTransformSameOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable minBp = new MinBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(minBp)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java index 4c6bf0ad9..ef21623c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java @@ -77,7 +77,7 @@ public class Abs extends BaseTransformSameOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().sign(arg()).mul(i_v.get(0)); + SDVariable ret = sameDiff.math.sign(arg()).mul(i_v.get(0)); return Arrays.asList(ret); } diff 
--git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java index f2d163f3f..bc86ae999 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java @@ -75,6 +75,6 @@ public class Ceil extends BaseTransformSameOp { public List doDiff(List f1) { //not continuously differentiable, but dOut/dIn = 0 in most places - return Arrays.asList(f().zerosLike(arg())); + return Arrays.asList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java index 6422e8df8..b4550cb4e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java @@ -25,6 +25,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp; import java.util.Arrays; import java.util.List; @@ -77,6 +78,6 @@ public class Cube extends BaseTransformSameOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().cubeBp(arg(), f1.get(0))); + return new CubeBp(sameDiff, arg(), f1.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java index dcee02131..6918b8d2c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java @@ -42,7 +42,7 @@ public class Identity extends BaseDynamicTransformOp { } public Identity(INDArray x){ - addInputArgument(x); + super(new INDArray[]{x}, null); } public Identity(){ } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java index db682174c..f5aec6b48 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java @@ -21,6 +21,8 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp; import java.util.Collections; import java.util.List; @@ -56,9 +58,7 @@ public class Max extends BaseTransformSameOp { @Override public List doDiff(List f1) { - SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); - return Collections.singletonList(sgn.mul(minBp)); + return new MaximumBp(sameDiff, larg(), rarg(), f1.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java index 6585ace19..1560a0e80 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java @@ -21,7 +21,9 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -56,9 +58,10 @@ public class Min extends BaseTransformSameOp { @Override public List doDiff(List f1) { - SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); - return Collections.singletonList(sgn.mul(minBp)); + //TODO optimize + SDVariable gt = arg(0).gt(arg(1)).castTo(arg(0).dataType()); + SDVariable lt = arg(0).lt(arg(1)).castTo(arg(1).dataType()); + return Arrays.asList(lt.mul(f1.get(0)), gt.mul(f1.get(0))); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java index 37b370fe9..f03805eb6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java @@ -73,7 +73,7 @@ public class Negative extends BaseTransformSameOp { @Override public List doDiff(List i_v) { - return Arrays.asList(f().neg(i_v.get(0))); + return Arrays.asList(sameDiff.math.neg(i_v.get(0))); } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java index 1e11fa34d..8d2049f25 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java @@ -74,7 +74,7 @@ public class Reciprocal extends BaseTransformSameOp { @Override public List doDiff(List i_v1) { // -1/(x^2) - SDVariable g = f().pow(arg(), 2).rdiv(-1).mul(i_v1.get(0)); + SDVariable g = sameDiff.math.pow(arg(), 2).rdiv(-1).mul(i_v1.get(0)); return Collections.singletonList(g); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java index 375a8acb5..de9e0b685 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java @@ -75,6 +75,6 @@ public class Round extends BaseTransformSameOp { @Override public List doDiff(List f1) { - return Arrays.asList(f().zerosLike(arg())); + return Arrays.asList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java index c63e00114..010955baf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java @@ -21,8 +21,10 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.scalar.PowDerivative; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -72,7 +74,7 @@ public class Square extends BaseTransformSameOp { @Override public List doDiff(List i_v) { - SDVariable g = f().powDerivative(arg(), 2).mul(i_v.get(0)); - return Arrays.asList(g); + SDVariable g = new PowDerivative(sameDiff, arg(), false, 2).outputVariable().mul(i_v.get(0)); + return Collections.singletonList(g); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java index 1506ac5f3..6f0432ba5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMaxBp; import java.util.*; @@ -40,13 +41,14 @@ public class UnsortedSegmentMax extends DynamicCustomOp { addIArgument(numSegments); } - public UnsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments) { - addInputArgument(data, segmentIds); + public UnsortedSegmentMax(){ } + + public UnsortedSegmentMax(INDArray data, INDArray 
segmentIds, int numSegments){ + super(new INDArray[]{data, segmentIds}, null); + this.numSegments = numSegments; addIArgument(numSegments); } - public UnsortedSegmentMax(){ } - @Override public String opName(){ return "unsorted_segment_max"; @@ -59,7 +61,7 @@ public class UnsortedSegmentMax extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentMaxBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentMaxBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java index 4338cf33d..ef39f1c91 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMeanBp; import java.util.Arrays; import java.util.Collections; @@ -44,8 +45,9 @@ public class UnsortedSegmentMean extends DynamicCustomOp { addIArgument(numSegments); } - public UnsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments) { - addInputArgument(data, segmentIds); + public UnsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments){ + super(new INDArray[]{data, segmentIds}, null); + this.numSegments = numSegments; addIArgument(numSegments); } @@ -61,7 +63,7 @@ public class UnsortedSegmentMean extends DynamicCustomOp { @Override 
public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentMeanBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentMeanBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java index 2f8aab0b1..6dc6e7737 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp; import java.util.Arrays; import java.util.Collections; @@ -44,8 +45,9 @@ public class UnsortedSegmentMin extends DynamicCustomOp { addIArgument(numSegments); } - public UnsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments) { - addInputArgument(data, segmentIds); + public UnsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments){ + super(new INDArray[]{data, segmentIds}, null); + this.numSegments = numSegments; addIArgument(numSegments); } @@ -61,7 +63,7 @@ public class UnsortedSegmentMin extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentMinBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentMinBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java index 7afd75fac..f753fe6dc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp; import java.util.Arrays; import java.util.Collections; @@ -44,8 +45,9 @@ public class UnsortedSegmentProd extends DynamicCustomOp { addIArgument(numSegments); } - public UnsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments) { - addInputArgument(data, segmentIds); + public UnsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments){ + super(new INDArray[]{data, segmentIds}, null); + this.numSegments = numSegments; addIArgument(numSegments); } @@ -61,7 +63,7 @@ public class UnsortedSegmentProd extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentProdBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentProdBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java index 77474855c..ea5285f12 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp; import java.util.ArrayList; import java.util.Arrays; @@ -38,18 +39,18 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp { private int numSegments; - public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) { - addInputArgument(data, segmentIds); - addIArgument(numSegments); - this.numSegments = numSegments; - } - public UnsortedSegmentSqrtN(SameDiff sameDiff, SDVariable data, SDVariable segmentIds, int numSegments) { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); this.numSegments = numSegments; addIArgument(numSegments); } + public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments){ + super(new INDArray[]{data, segmentIds}, null); + this.numSegments = numSegments; + addIArgument(numSegments); + } + @Override public String opName(){ return "unsorted_segment_sqrt_n"; @@ -62,7 +63,7 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentSqrtNBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentSqrtNBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java index 336c756ac..d0d5b095b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp; import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; @@ -45,8 +46,9 @@ public class UnsortedSegmentSum extends DynamicCustomOp { addIArgument(numSegments); } - public UnsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments) { - addInputArgument(data, segmentIds); + public UnsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments){ + super(new INDArray[]{data, segmentIds}, null); + this.numSegments = numSegments; addIArgument(numSegments); } @@ -62,7 +64,7 @@ public class UnsortedSegmentSum extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentSumBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentSumBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java index 3e0c60bb0..44dca1ed1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -75,9 +76,9 @@ public class ACos extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //dacos(x)/dx = -1 / sqrt(1-x^2) - SDVariable oneSubSq = f().square(arg()).rsub(1.0); - SDVariable sqrt = f().sqrt(oneSubSq); + SDVariable oneSubSq = sameDiff.math.square(arg()).rsub(1.0); + SDVariable sqrt = sameDiff.math.sqrt(oneSubSq); SDVariable ret = sqrt.rdiv(-1.0).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java index 49ef2fb09..25f20e011 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java @@ -75,8 +75,8 @@ public class ASinh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //dasinh(x)/dx = 1 / sqrt(x^2+1) - SDVariable xSqPlus1 = f().square(arg()).add(1.0); - SDVariable ret = i_v.get(0).div(f().sqrt(xSqPlus1)); + SDVariable xSqPlus1 = sameDiff.math.square(arg()).add(1.0); + SDVariable ret = i_v.get(0).div(sameDiff.math.sqrt(xSqPlus1)); return Arrays.asList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java index 483896dfd..a7a741759 
100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -76,8 +77,8 @@ public class ATan extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //d(atan(x))/dx = 1/(x^2+1) - SDVariable xSqPlus1 = f().square(arg()).add(1.0); + SDVariable xSqPlus1 = sameDiff.math.square(arg()).add(1.0); SDVariable ret = xSqPlus1.rdiv(1.0).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java index 21076ad6e..35ed040c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java @@ -64,7 +64,7 @@ public class Cos extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().neg(f().sin(arg())).mul(i_v.get(0)); + SDVariable ret = sameDiff.math.neg(sameDiff.math.sin(arg())).mul(i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java index dc08ead5f..5144315da 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java @@ -74,7 +74,7 @@ public class Cosh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().sinh(arg()).mul(i_v.get(0)); + SDVariable ret = sameDiff.math.sinh(arg()).mul(i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java index c4fc245b7..cc3a5e116 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java @@ -23,6 +23,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; import java.util.Collections; import java.util.List; @@ -83,7 +84,7 @@ public class ELU extends DynamicCustomOp { public List doDiff(List i_v) { //ELU: e^x-1 if x<0, x otherwise //dL/dIn = dL/Out * dOut/dIn - return Collections.singletonList(f().eluBp(arg(), i_v.get(0), alpha)); + return new EluBp(sameDiff, arg(), i_v.get(0), alpha).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java index 21aa49522..f9288d1d4 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java @@ -73,7 +73,7 @@ public class Exp extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().mul(f().exp(arg()), i_v.get(0)); + SDVariable ret = sameDiff.math.mul(sameDiff.math.exp(arg()), i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java index 538f6a003..5b093a3ec 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java @@ -75,7 +75,7 @@ public class Expm1 extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().mul(f().exp(arg()), i_v.get(0)); + SDVariable ret = sameDiff.math.mul(sameDiff.math.exp(arg()), i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java index b784ddde0..009492924 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java @@ -68,7 +68,7 @@ public class GELU extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().geluDerivative(arg(), false).mul(i_v.get(0)); 
+ SDVariable ret = new GELUDerivative(sameDiff, arg(), false).outputVariable().mul(i_v.get(0)); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java index ddaa8631f..91c4eb8ae 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative; import java.util.Collections; @@ -74,7 +75,7 @@ public class HardSigmoid extends BaseTransformStrictOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().hardSigmoidBp(arg(), f1.get(0))); + return new HardSigmoidBp(sameDiff, arg(), f1.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java index fa80bf880..fc44f3c22 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; 
import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp; import java.util.Arrays; import java.util.List; @@ -75,6 +76,6 @@ public class HardTanh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().hardTanhBp(arg(), i_v.get(0))); + return new HardTanhBp(sameDiff, arg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java index a937e1d63..1fd8ac430 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -75,8 +76,7 @@ public class Log extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - f().validateDifferentialFunctionsameDiff(arg()); - SDVariable toInverse = sameDiff.setupFunction(f().div(i_v.get(0), arg())); - return Arrays.asList(toInverse); + SDVariable toInverse = sameDiff.math.div(i_v.get(0), arg()); + return Collections.singletonList(toInverse); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java index 131986d15..d61504e39 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; @@ -73,7 +74,7 @@ public class Log1p extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - f().validateDifferentialFunctionsameDiff(arg()); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, arg(), this); return Collections.singletonList(i_v.get(0).div(arg().add(1.0))); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java index 353ced004..6a118a062 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; import java.util.Arrays; import java.util.Collections; @@ -74,10 +75,8 @@ public class LogSigmoid extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { -// SDVariable ret = f().logSigmoidDerivative(arg(), i_v.get(0)); -// return Arrays.asList(ret); - SDVariable sigmDeriv = f().sigmoidDerivative(arg(), 
i_v.get(0)).div(f().sigmoid(arg())); - return Collections.singletonList(sigmDeriv); + SDVariable v = new SigmoidDerivative(sameDiff, arg(), i_v.get(0)).outputVariable().div(sameDiff.nn.sigmoid(arg())); + return Collections.singletonList(v); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java index 416f74133..05ca68aa0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -69,8 +70,8 @@ public class Mish extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().mishDerivative(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + SDVariable ret = new MishDerivative(sameDiff, arg(), false).outputVariable().mul(i_v.get(0)); + return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java index ab565b30f..e2d8c8b5a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java @@ -37,6 +37,10 @@ public class PreciseGELU extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public PreciseGELU(SameDiff 
sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false, true); + } + public PreciseGELU() { } @@ -72,7 +76,7 @@ public class PreciseGELU extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().geluDerivative(arg(), true).mul(i_v.get(0)); + SDVariable ret = new PreciseGELUDerivative(sameDiff, arg(), false, true).outputVariable().mul(i_v.get(0)); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java index a05e34637..ecf85c9a9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp; import java.util.Collections; import java.util.List; @@ -35,6 +36,10 @@ public class RationalTanh extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public RationalTanh(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public RationalTanh() {} public RationalTanh(INDArray x, INDArray z) { @@ -68,6 +73,6 @@ public class RationalTanh extends BaseTransformStrictOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().tanhRationalBp(arg(), f1.get(0))); + return new RationalTanhBp(sameDiff, arg(), f1.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java index d5fbf1294..8956f1b66 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -42,6 +43,10 @@ public class RectifiedTanh extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public RectifiedTanh(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public RectifiedTanh() {} public RectifiedTanh(INDArray x, INDArray z) { @@ -85,6 +90,6 @@ public class RectifiedTanh extends BaseTransformStrictOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().tanhRectifiedBp(arg(), f1.get(0))); + return new RectifiedTanhBp(sameDiff, arg(), f1.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java index 00592f0e2..472c4eece 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import 
org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp; import java.util.Arrays; import java.util.List; @@ -81,7 +82,7 @@ public class SELU extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().seluBp(arg(), i_v.get(0))); + return new SeluBp(sameDiff, arg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java index 37ef4b743..2452d5906 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java @@ -22,6 +22,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; import java.util.Arrays; import java.util.List; @@ -74,8 +75,7 @@ public class Sigmoid extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().sigmoidDerivative(arg(), i_v.get(0)); - return Arrays.asList(ret); + return new SigmoidDerivative(sameDiff, arg(), i_v.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java index 0fa918c11..bfdde52d8 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -75,8 +76,8 @@ public class Sin extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().cos(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + SDVariable ret = sameDiff.math.cos(arg()).mul(i_v.get(0)); + return Collections.singletonList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java index d5e3be988..2f5b981bd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -75,8 +76,8 @@ public class Sinh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().cosh(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + SDVariable ret = sameDiff.math.cosh(arg()).mul(i_v.get(0)); + return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java index 11ffb2ef8..f3eeda670 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -73,8 +74,8 @@ public class SoftPlus extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //dL/dIn = dL/Out * dOut/dIn - SDVariable ret = f().sigmoid(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + SDVariable ret = sameDiff.nn.sigmoid(arg()).mul(i_v.get(0)); + return Collections.singletonList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java index 8be5ea2d4..057fda972 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java @@ -16,15 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; -import java.util.Collections; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp; -import java.util.Arrays; import java.util.List; /** 
@@ -78,7 +75,7 @@ public class SoftSign extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().softsignBp(arg(), i_v.get(0))); + return new SoftSignBp(sameDiff, arg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java index 0794e0b57..7f694f481 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java @@ -74,7 +74,7 @@ public class Swish extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().swishDerivative(arg()).mul(i_v.get(0)); + SDVariable ret = new SwishDerivative(sameDiff, arg()).outputVariable().mul(i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SwishDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SwishDerivative.java index 350fb194e..552308859 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SwishDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SwishDerivative.java @@ -39,8 +39,8 @@ public class SwishDerivative extends BaseTransformStrictOp { super(sameDiff, i_v1, i_v2, inPlace); } - public SwishDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); + public SwishDerivative(SameDiff sameDiff, SDVariable i_v) { + super(sameDiff, i_v, false); } public SwishDerivative() {} diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java index 3244925b1..2c9a603ee 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java @@ -76,7 +76,7 @@ public class Tan extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //d(tan(x))/dx = (sec(x))^2 = 1 / (cos(x))^2 - SDVariable cosx = f().cos(arg()); + SDVariable cosx = sameDiff.math.cos(arg()); SDVariable cosSqx = sameDiff.math().square(cosx); return Collections.singletonList(i_v.get(0).div(cosSqx)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java index 136d0bbea..ca45a549a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java @@ -22,6 +22,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative; import java.util.Arrays; import java.util.List; @@ -74,7 +75,6 @@ public class Tanh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().tanhDerivative(arg(), i_v.get(0)); - return Arrays.asList(ret); + return new TanhDerivative(sameDiff, arg(), i_v.get(0)).outputs(); } } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java index 35af113ad..5e8db1cfd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java @@ -30,11 +30,14 @@ public class AmsGradUpdater extends DynamicCustomOp { // } - public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, double lr, double beta1, double beta2, double epsilon, int iteration) { + public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, + double lr, double beta1, double beta2, double epsilon, int iteration) { this(gradients, stateV, stateM, stateH, gradients, stateV, stateM, stateH, lr, beta1, beta2, epsilon, iteration); } - public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, @NonNull INDArray updatedStateH, double lr, double beta1, double beta2, double epsilon, int iteration) { + public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, + @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, + @NonNull INDArray updatedStateH, double lr, double beta1, double beta2, double epsilon, int iteration) { addInputArgument(gradients, stateV, stateM, stateH); addOutputArgument(updates, updatedStateV, updatedStateM, updatedStateH); addTArgument(lr, beta1, beta2, epsilon); diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java index ad4f374b7..325c85af5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java @@ -30,11 +30,14 @@ public class NadamUpdater extends DynamicCustomOp { // } - public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, double lr, + double beta1, double beta2, double epsilon, int iteration) { this(gradients, stateV, stateM, gradients, stateV, stateM, lr, beta1, beta2, epsilon, iteration); } - public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray updates, + @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, + double epsilon, int iteration) { addInputArgument(gradients, stateV, stateM); addOutputArgument(updates, updatedStateV, updatedStateM); addTArgument(lr, beta1, beta2, epsilon); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java index 4986b8277..e9dd0f840 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java @@ -130,13 +130,20 @@ public class NDValidation { " type; got array with non-integer data type " + v.dataType()); } - public static void validateInteger(String opName, String inputName, INDArray[] vars) { - for (INDArray v : vars) { - if (v == null) - return; - if (!v.dataType().isIntType()) + /** + * Validate that the operation is being applied on an integer type INDArray [] + * + * @param opName Operation name to print in the exception + * @param inputName Name of the input to the op to validate + * @param v Variable to validate datatype for (input to operation) + */ + public static void validateInteger(String opName, String inputName, INDArray [] v) { + if (v == null) + return; + for (int i = 0; i < v.length; i++) { + if (!v[i].dataType().isIntType()) throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer" + - " type; got array with non-integer data type " + v.dataType()); + " type; got array with non-integer data type member" + v[i].dataType()); } } @@ -246,11 +253,12 @@ public class NDValidation { } public static boolean isSameType(INDArray[] x) { - DataType firstDataType = x[0].dataType(); - if (x.length > 1) { - for (int i = 1; i < x.length; ++i) { - if (firstDataType != x[i].dataType()) - return false; + if(x.length == 0) + return true; + DataType first = x[0].dataType(); + for( int i=1; i * - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public INDArray all(INDArray x, int... 
dimensions) { - NDValidation.validateBool("all", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.bool.All(x, dimensions)); } @@ -47,12 +46,11 @@ public class NDBase { /** * Boolean or array reduction operation, optionally along specified dimensions
* - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public INDArray any(INDArray x, int... dimensions) { - NDValidation.validateBool("any", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(x, dimensions)); } @@ -114,6 +112,8 @@ public class NDBase { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param in Input variable (NUMERIC type) * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions @@ -138,6 +138,8 @@ public class NDBase { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param in Input variable (NUMERIC type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) @@ -369,6 +371,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -472,6 +476,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -504,6 +510,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -602,6 +610,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -634,6 +644,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -760,6 +772,8 @@ public class NDBase { * Element-wise maximum operation: out[i] = max(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) @@ -812,6 +826,21 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(x, false, dimensions)); } + /** + * The merge operation is a control operation that forwards the either of the inputs to the output, when
+ * the first of them becomes available. If both are available, the output is undefined (either input could
+ * be forwarded to the output)
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output (NUMERIC type) + */ + public INDArray merge(INDArray x, INDArray y) { + NDValidation.validateNumerical("merge", "x", x); + NDValidation.validateNumerical("merge", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge(x, y))[0]; + } + /** * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
* @@ -857,6 +886,8 @@ public class NDBase { * Element-wise minimum operation: out[i] = min(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) @@ -919,6 +950,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1978,6 +2011,18 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(x, false, dimensions)); } + /** + * Switch operation
+ * Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output
+ * + * @param x Input variable (NDARRAY type) + * @param predicate Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output (BOOL type) + */ + public INDArray[] switchOp(INDArray x, INDArray predicate) { + NDValidation.validateBool("switchOp", "predicate", predicate); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch(x, predicate)); + } + /** + * //TODO: Ops must be documented.
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java index 859ad43c3..536633cd2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops; import static org.nd4j.linalg.factory.NDValidation.isSameType; import org.nd4j.base.Preconditions; +import org.nd4j.enums.ImageResizeMethod; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.NDValidation; import org.nd4j.linalg.factory.Nd4j; @@ -134,6 +135,49 @@ public class NDImage { return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.HsvToRgb(input))[0]; } + /** + * Resize images to size using the specified method.
+ * + * @param input 4D image [NCHW] (NUMERIC type) + * @param size new height and width (INT type) + * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. + * @param antialis Whether to use an anti-aliasing filter when downsampling an image + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public INDArray imageResize(INDArray input, INDArray size, boolean preserveAspectRatio, + boolean antialis, ImageResizeMethod ImageResizeMethod) { + NDValidation.validateNumerical("imageResize", "input", input); + NDValidation.validateInteger("imageResize", "size", size); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.ImageResize(input, size, preserveAspectRatio, antialis, ImageResizeMethod))[0]; + } + + /** + * Resize images to size using the specified method.
+ * + * @param input 4D image [NCHW] (NUMERIC type) + * @param size new height and width (INT type) + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public INDArray imageResize(INDArray input, INDArray size, ImageResizeMethod ImageResizeMethod) { + NDValidation.validateNumerical("imageResize", "input", input); + NDValidation.validateInteger("imageResize", "size", size); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.ImageResize(input, size, false, false, ImageResizeMethod))[0]; + } + /** * Greedily selects a subset of bounding boxes in descending order of score
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java index bee0da889..1deddfd0a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops; import static org.nd4j.linalg.factory.NDValidation.isSameType; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PartitionMode; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.NDValidation; @@ -31,6 +32,34 @@ public class NDMath { public NDMath() { } + /** + * Clips tensor values to a maximum average L2-norm.
+ * + * @param x Input variable (NUMERIC type) + * @param clipValue Value for clipping + * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray clipByAvgNorm(INDArray x, double clipValue, int... dimensions) { + NDValidation.validateNumerical("ClipByAvgNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(x, clipValue, dimensions))[0]; + } + + /** + * Looks up ids in a list of embedding tensors.
+ * + * @param x Input tensor (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (NUMERIC type) + * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' + * @return output Shifted output (NUMERIC type) + */ + public INDArray embeddingLookup(INDArray x, INDArray indices, PartitionMode PartitionMode) { + NDValidation.validateNumerical("EmbeddingLookup", "x", x); + NDValidation.validateNumerical("EmbeddingLookup", "indices", indices); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(x, indices, PartitionMode))[0]; + } + /** * Elementwise absolute value operation: out = abs(x)
* @@ -64,6 +93,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(x)); } + /** + * Pairwise addition operation, out = x + y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray add(INDArray x, INDArray y) { + NDValidation.validateNumerical("add", "x", x); + NDValidation.validateNumerical("add", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(x, y))[0]; + } + + /** + * Scalar add operation, out = in + scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray add(INDArray x, double value) { + NDValidation.validateNumerical("add", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(x, value)); + } + /** * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
* @@ -511,6 +569,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.DiagPart(x))[0]; } + /** + * Pairwise division operation, out = x / y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray div(INDArray x, INDArray y) { + NDValidation.validateNumerical("div", "x", x); + NDValidation.validateNumerical("div", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(x, y))[0]; + } + + /** + * Scalar division operation, out = in / scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray div(INDArray x, double value) { + NDValidation.validateNumerical("div", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(x, value)); + } + /** * Entropy reduction: -sum(x * log(x))
* @@ -710,6 +797,52 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(x)); } + /** + * Pairwise floor division operation, out = floor(x / y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray floorDiv(INDArray x, INDArray y) { + NDValidation.validateNumerical("floorDiv", "x", x); + NDValidation.validateNumerical("floorDiv", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(x, y))[0]; + } + + /** + * Pairwise Modulus division operation
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray floorMod(INDArray x, INDArray y) { + NDValidation.validateNumerical("floorMod", "x", x); + NDValidation.validateNumerical("floorMod", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(x, y))[0]; + } + + /** + * Scalar floor modulus operation
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray floorMod(INDArray x, double value) { + NDValidation.validateNumerical("floorMod", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(x, value)); + } + /** * Hamming distance reduction operation. The output contains the cosine distance for each
* tensor/subset along the specified dimensions:
@@ -1040,6 +1173,23 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(in))[0]; } + /** + * Pairwise max operation, out = max(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public INDArray max(INDArray x, INDArray y) { + NDValidation.validateNumerical("max", "x", x); + NDValidation.validateNumerical("max", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(x, y))[0]; + } + /** * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
* out = sum_i in[i]
@@ -1091,6 +1241,40 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(inputs, cartesian)); } + /** + * Pairwise max operation, out = min(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public INDArray min(INDArray x, INDArray y) { + NDValidation.validateNumerical("min", "x", x); + NDValidation.validateNumerical("min", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(x, y))[0]; + } + + /** + * Pairwise modulus (remainder) operation, out = x % y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray mod(INDArray x, INDArray y) { + NDValidation.validateNumerical("mod", "x", x); + NDValidation.validateNumerical("mod", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(x, y))[0]; + } + /** * Calculate the mean and (population) variance for the input variable, for the specified axis
* @@ -1103,6 +1287,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Moments(input, axes)); } + /** + * Pairwise multiplication operation, out = x * y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray mul(INDArray x, INDArray y) { + NDValidation.validateNumerical("mul", "x", x); + NDValidation.validateNumerical("mul", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(x, y))[0]; + } + + /** + * Scalar multiplication operation, out = in * scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray mul(INDArray x, double value) { + NDValidation.validateNumerical("mul", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(x, value)); + } + /** * Elementwise negative operation: out = -x
* @@ -1171,6 +1384,48 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(x, y))[0]; } + /** + * Rational Tanh Approximation elementwise function, as described in the paper:
+ * Compact Convolutional Neural Network Cascade for Face Detection
+ * This is a faster Tanh approximation
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray rationalTanh(INDArray x) { + NDValidation.validateNumerical("rationalTanh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(x)); + } + + /** + * Pairwise reverse division operation, out = y / x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray rdiv(INDArray x, INDArray y) { + NDValidation.validateNumerical("rdiv", "x", x); + NDValidation.validateNumerical("rdiv", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(x, y))[0]; + } + + /** + * Scalar reverse division operation, out = scalar / in
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray rdiv(INDArray x, double value) { + NDValidation.validateNumerical("rdiv", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(x, value)); + } + /** * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
* @@ -1182,6 +1437,17 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(x)); } + /** + * Rectified tanh operation: max(0, tanh(in))
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray rectifiedTanh(INDArray x) { + NDValidation.validateNumerical("rectifiedTanh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(x)); + } + /** * Element-wise round function: out = round(x).
* Rounds (up or down depending on value) to the nearest integer value.
@@ -1205,6 +1471,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(x)); } + /** + * Pairwise reverse subtraction operation, out = y - x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray rsub(INDArray x, INDArray y) { + NDValidation.validateNumerical("rsub", "x", x); + NDValidation.validateNumerical("rsub", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(x, y))[0]; + } + + /** + * Scalar reverse subtraction operation, out = scalar - in
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray rsub(INDArray x, double value) { + NDValidation.validateNumerical("rsub", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(x, value)); + } + /** * Set the diagonal value to the specified values
* If input is
@@ -1297,6 +1592,23 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Square(x)); } + /** + * Pairwise squared difference operation.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray squaredDifference(INDArray x, INDArray y) { + NDValidation.validateNumerical("squaredDifference", "x", x); + NDValidation.validateNumerical("squaredDifference", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(x, y))[0]; + } + /** * Standardize input variable along given axis
*


@@ -1335,6 +1647,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.Step(x, value)); } + /** + * Pairwise subtraction operation, out = x - y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray sub(INDArray x, INDArray y) { + NDValidation.validateNumerical("sub", "x", x); + NDValidation.validateNumerical("sub", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(x, y))[0]; + } + + /** + * Scalar subtraction operation, out = in - scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray sub(INDArray x, double value) { + NDValidation.validateNumerical("sub", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(x, value)); + } + /** * Elementwise tangent operation: out = tan(x)
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java index 3f9e1431a..e2a8af245 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops; import static org.nd4j.linalg.factory.NDValidation.isSameType; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.NDValidation; import org.nd4j.linalg.factory.Nd4j; @@ -29,6 +30,17 @@ public class NDNN { public NDNN() { } + /** + * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray cReLU(INDArray x) { + NDValidation.validateNumerical("CReLU", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(x))[0]; + } + /** * Neural network batch normalization operation.
* For details, see https://arxiv.org/abs/1502.03167
@@ -344,6 +356,21 @@ public class NDNN { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false))[0]; } + /** + * Padding operation
+ * + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param PadMode Padding format + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public INDArray pad(INDArray input, INDArray padding, PadMode PadMode, double constant) { + NDValidation.validateNumerical("pad", "input", input); + NDValidation.validateNumerical("pad", "padding", padding); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, PadMode, constant))[0]; + } + /** * Padding operation
* @@ -355,7 +382,20 @@ public class NDNN { public INDArray pad(INDArray input, INDArray padding, double constant) { NDValidation.validateNumerical("pad", "input", input); NDValidation.validateNumerical("pad", "padding", padding); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, constant))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, PadMode.CONSTANT, constant))[0]; + } + + /** + * GELU activation function - Gaussian Error Linear Units
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This method uses the precise method
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray preciseGelu(INDArray x) { + NDValidation.validateNumerical("preciseGelu", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(x)); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java index 77f56c613..49dfd1458 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java @@ -159,7 +159,7 @@ public class BooleanIndexing { if (to.length() != from.length()) throw new IllegalStateException("Mis matched length for to and from"); - Nd4j.getExecutioner().exec(new CompareAndSet(to, from, condition)); + Nd4j.getExecutioner().exec(new CompareAndSet(to, from, to, condition)); } @@ -177,7 +177,7 @@ public class BooleanIndexing { if (to.length() != from.length()) throw new IllegalStateException("Mis matched length for to and from"); - Nd4j.getExecutioner().exec(new CompareAndReplace(to, from, condition)); + Nd4j.getExecutioner().exec(new CompareAndReplace(to, from, to, condition)); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AMSGradUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AMSGradUpdater.java index 79907a237..37d1cb01d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AMSGradUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AMSGradUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. 
+ * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,14 +20,12 @@ package org.nd4j.linalg.learning; import lombok.Data; import lombok.NonNull; import lombok.val; -import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt; +import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.AMSGrad; -import org.nd4j.linalg.ops.transforms.Transforms; import java.util.HashMap; import java.util.Map; @@ -103,27 +102,11 @@ public class AMSGradUpdater implements GradientUpdater { double epsilon = config.getEpsilon(); //m_t = b_1 * m_{t-1} + (1-b_1) * g_t eq 1 pg 3 - INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1); - m.muli(beta1).addi(oneMinusBeta1Grad); - //v_t = b_2 * v_{t-1} + (1-b_2) * (g_t)^2 eq 1 pg 3 - INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2); - v.muli(beta2).addi(oneMinusBeta2GradSquared); - - double beta1t = FastMath.pow(beta1, iteration + 1); - double beta2t = FastMath.pow(beta2, iteration + 1); - //vHat_t = max(vHat_{t-1}, v_t) - Transforms.max(vHat, v, false); - - double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t); - if (Double.isNaN(alphat) || alphat == 0.0) - alphat = epsilon; - //gradient array contains: sqrt(vHat) + eps - Nd4j.getExecutioner().exec(new Sqrt(vHat, gradient)).addi(epsilon); - //gradient = alphat * m_t / (sqrt(vHat) + eps) - gradient.rdivi(m).muli(alphat); + + Nd4j.exec(new AmsGradUpdater(gradient, v, m, vHat, learningRate, beta1, beta2, epsilon, iteration)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java index ced2a8c84..6aa7d7ab4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,9 +20,9 @@ package org.nd4j.linalg.learning; import lombok.Data; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.AdaDelta; -import org.nd4j.linalg.ops.transforms.Transforms; import java.util.HashMap; import java.util.Map; @@ -104,16 +105,11 @@ public class AdaDeltaUpdater implements GradientUpdater { //Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf //E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t - msg.muli(rho).addi(gradient.mul(gradient).muli(1 - rho)); - //Calculate update: //dX = - g * RMS[delta x]_{t-1} / RMS[g]_t //Note: negative is applied in the DL4J step function: params -= update rather than params += update - INDArray rmsdx_t1 = Transforms.sqrt(msdx.add(epsilon), false); - INDArray rmsg_t = Transforms.sqrt(msg.add(epsilon), false); - INDArray update = gradient.muli(rmsdx_t1.divi(rmsg_t)); - //Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2 - msdx.muli(rho).addi(update.mul(update).muli(1 - rho)); + + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(gradient, msg, msdx, rho, epsilon)); } } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaGradUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaGradUpdater.java index 09a530a51..211022366 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaGradUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaGradUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,16 +19,14 @@ package org.nd4j.linalg.learning; import lombok.Data; -import lombok.EqualsAndHashCode; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.AdaGrad; import java.util.Collections; import java.util.Map; -import static org.nd4j.linalg.ops.transforms.Transforms.sqrt; - /** * Vectorized Learning Rate used per Connection Weight @@ -98,10 +97,6 @@ public class AdaGradUpdater implements GradientUpdater { double learningRate = config.getLearningRate(iteration, epoch); double epsilon = config.getEpsilon(); - historicalGradient.addi(gradient.mul(gradient)); - - INDArray sqrtHistory = sqrt(historicalGradient.dup(gradientReshapeOrder), false).addi(epsilon); - // lr * gradient / (sqrt(sumSquaredGradients) + epsilon) - gradient.muli(sqrtHistory.rdivi(learningRate)); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater(gradient, historicalGradient, learningRate, epsilon)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java index 20a908f1e..06fbde54d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,14 +19,11 @@ package org.nd4j.linalg.learning; import lombok.Data; import lombok.NonNull; -import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Max; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.AdaMax; -import org.nd4j.linalg.ops.transforms.Transforms; import java.util.HashMap; import java.util.Map; @@ -99,22 +97,13 @@ public class AdaMaxUpdater implements GradientUpdater { throw new IllegalStateException("Updater has not been initialized with view state"); //m = B_1 * m + (1-B_1)*grad - m.muli(config.getBeta1()).addi(gradient.mul(1 - config.getBeta1())); - //u = max(B_2 * u, |grad|) - u.muli(config.getBeta2()); - Transforms.abs(gradient, false); //In-place should be OK here, original gradient values aren't used again later - Nd4j.getExecutioner().exec(new Max(u, gradient, u)); - double beta1t = FastMath.pow(config.getBeta1(), iteration + 1); + double lr = config.getLearningRate(iteration, epoch); + double b1 = config.getBeta1(); + double b2 = config.getBeta2(); + double eps = config.getEpsilon(); - double learningRate = config.getLearningRate(iteration, epoch); - double alphat = learningRate / 
(1.0 - beta1t); - if (Double.isNaN(alphat) || Double.isInfinite(alphat) || alphat == 0.0) { - alphat = config.getEpsilon(); - } - - u.addi(1e-32); // prevent NaNs in params - gradient.assign(m).muli(alphat).divi(u); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater(gradient, u, m, lr, b1, b2, eps, iteration)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java index e68af09f7..e72bfe5a4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,12 +19,11 @@ package org.nd4j.linalg.learning; import lombok.Data; import lombok.NonNull; -import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.Adam; -import org.nd4j.linalg.ops.transforms.Transforms; import java.util.HashMap; import java.util.Map; @@ -102,20 +102,6 @@ public class AdamUpdater implements GradientUpdater { double learningRate = config.getLearningRate(iteration, epoch); double epsilon = config.getEpsilon(); - INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1); - m.muli(beta1).addi(oneMinusBeta1Grad); - - INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2); - v.muli(beta2).addi(oneMinusBeta2GradSquared); - - double beta1t = 
FastMath.pow(beta1, iteration + 1); - double beta2t = FastMath.pow(beta2, iteration + 1); - - double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t); - if (Double.isNaN(alphat) || alphat == 0.0) - alphat = epsilon; - INDArray sqrtV = Transforms.sqrt(v.dup(gradientReshapeOrder), false).addi(epsilon); - - gradient.assign(m).muli(alphat).divi(sqrtV); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater(gradient, v, m, learningRate, beta1, beta2, epsilon, iteration)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java index 18a29cc25..1cea01cb7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,12 +19,11 @@ package org.nd4j.linalg.learning; import lombok.Data; import lombok.NonNull; -import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.Nadam; -import org.nd4j.linalg.ops.transforms.Transforms; import java.util.HashMap; import java.util.Map; @@ -101,21 +101,6 @@ public class NadamUpdater implements GradientUpdater { double learningRate = config.getLearningRate(iteration, epoch); double epsilon = config.getEpsilon(); - INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1); - m.muli(beta1).addi(oneMinusBeta1Grad); - - INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1.0 - beta2); - v.muli(beta2).addi(oneMinusBeta2GradSquared); - - double beta1t = FastMath.pow(beta1, iteration + 1); - - INDArray biasCorrectedEstimateOfMomentum = m.mul(beta1).divi(1.0 - beta1t); - INDArray secondTerm = oneMinusBeta1Grad.divi(1 - beta1t); - - INDArray alphat = biasCorrectedEstimateOfMomentum.add(secondTerm).muli(learningRate); - - INDArray sqrtV = Transforms.sqrt(v.dup(gradientReshapeOrder), false).addi(epsilon); - - gradient.assign(alphat).divi(sqrtV); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater(gradient, v, m, learningRate, beta1, beta2, epsilon, iteration)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java index 64a9a6f87..2a18b78d8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,7 +20,6 @@ package org.nd4j.linalg.learning; import lombok.Data; import lombok.NonNull; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Nesterovs; @@ -95,16 +95,8 @@ public class NesterovsUpdater implements GradientUpdater { //DL4J default is negative step function thus we flipped the signs: // x += mu * v_prev + (-1 - mu) * v //i.e., we do params -= updatedGradient, not params += updatedGradient - //v = mu * v - lr * gradient - INDArray vPrev = v.dup(gradientReshapeOrder); - v.muli(momentum).subi(gradient.dup(gradientReshapeOrder).muli(learningRate)); //Modify state array in-place - /* - Next line is equivalent to: - INDArray ret = vPrev.muli(momentum).addi(v.mul(-momentum - 1)); - gradient.assign(ret); - */ - Nd4j.getExecutioner().exec(new AddOp(vPrev.muli(momentum), v.mul(-momentum - 1), gradient)); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater(gradient, v, learningRate, momentum)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/RmsPropUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/RmsPropUpdater.java index e2d68c4bf..866f9ce0d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/RmsPropUpdater.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/RmsPropUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -20,8 +21,8 @@ import lombok.Data; import lombok.NonNull; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.RmsProp; -import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Collections; import java.util.Map; @@ -85,8 +86,7 @@ public class RmsPropUpdater implements GradientUpdater { double rmsDecay = config.getRmsDecay(); double epsilon = config.getEpsilon(); - lastGradient.muli(rmsDecay).addi(gradient.mul(gradient).muli(1 - rmsDecay)); // lr * gradient / (sqrt(cache) + 1e-8) - gradient.muli(learningRate).divi(Transforms.sqrt(lastGradient.dup(gradientReshapeOrder), false).addi(epsilon)); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater(gradient, lastGradient, learningRate, rmsDecay, epsilon)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/SgdUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/SgdUpdater.java index 1eca487c1..a2d0b0214 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/SgdUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/SgdUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,6 +20,7 @@ package org.nd4j.linalg.learning; import lombok.Data; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import java.util.Collections; @@ -56,6 +58,6 @@ public class SgdUpdater implements GradientUpdater { @Override public void applyUpdater(INDArray gradient, int iteration, int epoch) { double lr = config.getLearningRate(iteration, epoch); - gradient.muli(lr); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater(gradient, lr)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/SameDiffLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/SameDiffLoss.java new file mode 100644 index 000000000..2a3a05663 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/SameDiffLoss.java @@ -0,0 +1,186 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.lossfunctions; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.primitives.Pair; +import java.util.HashMap; +import java.util.Map; + +/** + * SameDiff loss function. + * + * This class can be extended to create Deeplearning4j loss functions by defining one single method only: + * {@link #defineLoss(SameDiff, SDVariable, SDVariable)}. This method is used to define the loss function on a + * per example basis - i.e., the output should be an array with shape [minibatch].
+ *
+ * For example, the mean squared error (MSE) loss function can be defined using:
+ * {@code return labels.squaredDifference(layerInput).mean(1);} + * + */ +public abstract class SameDiffLoss implements ILossFunction { + protected transient SameDiff sd; + protected transient SDVariable scoreVariable; + + protected SameDiffLoss() { + + } + + /** + * Define the loss function.
+ * NOTE: The score on a *per example* basis - should return a SDVariable with shape [minibatch], where out[i] + * is the score for the ith minibatch + * + * @param sd SameDiff instance to define the loss on + * @param layerInput Input to the SameDiff loss function + * @param labels Labels placeholder + * @return The score on a per example basis (SDVariable with shape [minibatch]) + */ + public abstract SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels); + + protected void createSameDiffInstance(DataType dataType){ + sd = SameDiff.create(); + SDVariable layerInput = sd.placeHolder("layerInput", dataType, -1); + SDVariable labels = sd.placeHolder("labels", dataType, -1); + scoreVariable = this.defineLoss(sd, layerInput, labels); + sd.createGradFunction("layerInput"); + } + + /** + * Compute the score (loss function value) for the given inputs. + * + * @param labels Label/expected preOutput + * @param preOutput Output of the model (neural network) + * @param activationFn Activation function that should be applied to preOutput + * @param mask Mask array; may be null + * @param average Whether the score should be averaged (divided by number of rows in labels/preOutput) or not @return Loss function value + */ + @Override + public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { + if(sd == null){ + createSameDiffInstance(preOutput.dataType()); + } + + INDArray scoreArr = computeScoreArray(labels, preOutput, activationFn, mask); + + double score = scoreArr.sumNumber().doubleValue(); + if (average) { + score /= scoreArr.size(0); + } + return score; + } + + + /** + * Compute the score (loss function value) for each example individually. 
+ * For input [numExamples,nOut] returns scores as a column vector: [numExamples,1] + * + * @param labels Labels/expected output + * @param preOutput Output of the model (neural network) + * @param activationFn Activation function that should be applied to preOutput + * @param mask @return Loss function value for each example; column vector + */ + @Override + public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + if(sd == null){ + createSameDiffInstance(preOutput.dataType()); + } + + Preconditions.checkArgument((labels.size(1) == preOutput.size(1)), "Labels array numColumns (size(1) = %s) does not match output layer number of outputs (nOut = %s)", labels.size(1), preOutput.size(1)); + + INDArray output = activationFn.getActivation(preOutput.dup(), true); + + Map m = new HashMap<>(); + m.put("labels", labels); + m.put("layerInput", output); + + INDArray scoreArr = sd.outputSingle(m,scoreVariable.name()); + + if (mask != null) { + LossUtil.applyMask(scoreArr, mask); + } + return scoreArr; + } + + + /** + * Compute the gradient of the loss function with respect to the inputs: dL/dOutput + * + * @param labels Label/expected output + * @param preOutput Output of the model (neural network), before the activation function is applied + * @param activationFn Activation function that should be applied to preOutput + * @param mask Mask array; may be null + * @return Gradient dL/dPreOut + */ + @Override + public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + if(sd == null){ + createSameDiffInstance(preOutput.dataType()); + } + + + Map m = new HashMap<>(); + INDArray output = activationFn.getActivation(preOutput.dup(), true); + m.put("labels", labels); + m.put("layerInput", output); + + Map grads = sd.calculateGradients(m, "layerInput"); + + INDArray gradAtActivationOutput = grads.get("layerInput"); + INDArray gradAtInput = 
activationFn.backprop(preOutput.dup(), gradAtActivationOutput).getFirst(); + + if (mask != null) { + LossUtil.applyMask(gradAtInput, mask); + } + return gradAtInput; + } + + /** + * Compute both the score (loss function value) and gradient. This is equivalent to calling {@link #computeScore(INDArray, INDArray, IActivation, INDArray, boolean)} + * and {@link #computeGradient(INDArray, INDArray, IActivation, INDArray)} individually + * + * @param labels Label/expected output + * @param preOutput Output of the model (neural network) + * @param activationFn Activation function that should be applied to preOutput + * @param mask Mask array; may be null + * @param average Whether the score should be averaged (divided by number of rows in labels/output) or not + * @return The score (loss function value) and gradient + */ + @Override + public Pair computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, + INDArray mask, boolean average) { + + Pair GradientAndScore = new Pair<>(); + GradientAndScore.setFirst(this.computeScore(labels, preOutput, activationFn, mask, average)); + GradientAndScore.setSecond(this.computeGradient(labels, preOutput, activationFn, mask)); + + return GradientAndScore; + } + + @Override + public String name() { + return getClass().getSimpleName(); + } +} + + + + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextDeSerializer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextDeSerializer.java index d50143e77..11761a00d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextDeSerializer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextDeSerializer.java @@ -36,6 +36,10 @@ public class NDArrayTextDeSerializer extends JsonDeserializer { @Override public INDArray deserialize(JsonParser jp, 
DeserializationContext deserializationContext) throws IOException { JsonNode n = jp.getCodec().readTree(jp); + return deserialize(n); + } + + public INDArray deserialize(JsonNode n){ //First: check for backward compatilibity (RowVectorSerializer/Deserializer) if(!n.has("dataType")){ diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml index 371386898..c181d4328 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml @@ -274,12 +274,6 @@ junit test - - org.nd4j - nd4j-jackson - ${project.version} - test - org.nd4j nd4j-api diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java index 356df88e0..415fa487f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java @@ -294,7 +294,11 @@ public class CudaAffinityManager extends BasicAffinityManager { @Override public void unsafeSetDevice(Integer deviceId) { + // actually set device NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId); + + // reset saved context, so it will be recreated on first call + AtomicAllocator.getInstance().getMemoryHandler().resetCachedContext(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java index abb919e8c..44d8e2042 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java +++ 
b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java @@ -304,4 +304,6 @@ public interface MemoryHandler { boolean promoteObject(DataBuffer buffer); void relocateObject(DataBuffer buffer); + + void resetCachedContext(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java index a8f3a0a3b..4c6e56bc9 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java @@ -17,6 +17,8 @@ package org.nd4j.jita.handler.impl; import lombok.var; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.nativeblas.OpaqueLaunchContext; import org.nd4j.shade.guava.collect.HashBasedTable; import org.nd4j.shade.guava.collect.Table; @@ -325,6 +327,11 @@ public class CudaZeroHandler implements MemoryHandler { */ @Override public void memcpyAsync(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) { + if (length < 1) + return; + + Preconditions.checkArgument(length <= (dstBuffer.length() * Nd4j.sizeOfDataType(dstBuffer.dataType())), "Length requested is bigger than target DataBuffer length"); + val point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint(); CudaContext tContext = null; @@ -1041,6 +1048,11 @@ public class CudaZeroHandler implements MemoryHandler { return ctx; } + @Override + public void resetCachedContext() { + tlContext.remove(); + } + /** * This method returns if this MemoryHandler instance is device-dependant (i.e. 
CUDA) * diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index bda208ce7..4b8209027 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -792,6 +792,19 @@ public class CudaExecutioner extends DefaultOpExecutioner { } } + boolean keepDims = op.isKeepDims(); + long[] retShape = Shape.reductionShape(x, dimension, true, keepDims); + + if(z == null || x == z) { + val ret = Nd4j.createUninitialized(DataType.LONG, retShape); + + setZ(ret, op, oc); + z = ret; + } else if(!Arrays.equals(retShape, z.shape())){ + throw new IllegalStateException("Z array shape does not match expected return type for op " + op + + ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(z.shape())); + } + long st = profilingConfigurableHookIn(op); checkForCompression(op); @@ -1947,7 +1960,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { val result = new ArrayList(); int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments(); - if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) { + if(nIn == 0 && op.getDescriptor().getNumInputs() >= 1) { if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: number of input args was 0", op.getClass().getName()); @@ -2060,7 +2073,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. 
You can't execute non-inplace CustomOp without outputs being specified"); for (val shape: list) - op.addOutputArgument(Nd4j.create(shape)); + op.addOutputArgument(Nd4j.create(shape, false)); shapeOverride = true; } catch (Exception e) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 8f30cdd82..6753b5ea1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -1,4 +1,4 @@ -// Targeted by JavaCPP version 1.5.3-SNAPSHOT: DO NOT EDIT THIS FILE +// Targeted by JavaCPP version 1.5.3: DO NOT EDIT THIS FILE package org.nd4j.nativeblas; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index f0488636f..e42495e5e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -772,8 +772,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (y != null) { - if (z == null) + if (z == null) { setZ(Nd4j.create(op.resultType(), x.shape()), op, oc); + z = getZ(op, oc); + } op.validateDataTypes(oc, experimentalMode.get()); @@ -1754,7 +1756,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val result = new ArrayList(); int nIn = opContext != null ? 
opContext.numInputArguments() : op.numInputArguments(); - if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) { + if(nIn == 0 && op.getDescriptor().getNumInputs() >= 1) { if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: number of input args was 0", op.getClass().getName()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index e96325460..cd18e0f18 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -1,4 +1,4 @@ -// Targeted by JavaCPP version 1.5.3-SNAPSHOT: DO NOT EDIT THIS FILE +// Targeted by JavaCPP version 1.5.3: DO NOT EDIT THIS FILE package org.nd4j.nativeblas; @@ -16040,6 +16040,41 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif +// #if NOT_EXCLUDED(OP_lstmLayerCell) + @Namespace("sd::ops") public static class lstmLayerCell extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayerCell(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. 
*/ + public lstmLayerCell(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayerCell position(long position) { + return (lstmLayerCell)super.position(position); + } + + public lstmLayerCell() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif +// #if NOT_EXCLUDED(OP_lstmLayerCell) + @Namespace("sd::ops") public static class lstmLayerCellBp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayerCellBp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayerCellBp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayerCellBp position(long position) { + return (lstmLayerCellBp)super.position(position); + } + + public lstmLayerCellBp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + ////////////////////////////////////////////////////////////////////////// /** @@ -16169,6 +16204,25 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + ////////////////////////////////////////////////////////////////////////// +// #if NOT_EXCLUDED(OP_lstmLayer) + @Namespace("sd::ops") public static class lstmLayer_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayer_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. 
*/ + public lstmLayer_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayer_bp position(long position) { + return (lstmLayer_bp)super.position(position); + } + + public lstmLayer_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + ////////////////////////////////////////////////////////////////////////// /** @@ -16336,6 +16390,24 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif +// #if NOT_EXCLUDED(OP_gru) + @Namespace("sd::ops") public static class gru_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gru_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public gru_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gru_bp position(long position) { + return (gru_bp)super.position(position); + } + + public gru_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + ////////////////////////////////////////////////////////////////////////// /** * Implementation of operation "static RNN time sequences" with peep hole connections: @@ -20403,11 +20475,15 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space) * - * input params: + * optional input params: * 0 - startVal - NDArray scalar (float point) * 1 - finishVal - NDArray scalar (float point) * 2 - numOfElements - NDArray scalar (integer) - * + * Optional: + * T args + * 0 - startVal + * 1 - finishVal] + * 2 - numOfElements * 
output: * 0 - 1D NDArray with the same type as input and length as given with numOfElements param. */ diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml index 6a3cc6eda..3e4367992 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml @@ -30,7 +30,6 @@ nd4j-tests-tensorflow - 1.8 1.8 @@ -216,8 +215,10 @@ **/*.java - org.nd4j.linalg.jcublas.JCublasBackend - org.nd4j.linalg.jcublas.JCublasBackend + org.nd4j.linalg.jcublas.JCublasBackend + + org.nd4j.linalg.jcublas.JCublasBackend + - + nd4j-backends org.nd4j @@ -29,7 +30,6 @@ nd4j-tests - 1.8 1.8 @@ -76,12 +76,6 @@ ${project.version} - - org.nd4j - nd4j-jackson - ${project.version} - - org.nd4j @@ -179,7 +173,8 @@ maven-surefire-plugin - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + src/test/java @@ -191,8 +186,10 @@ junit:junit - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + + + com.google.code.findbugs + * + + + diff --git a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/AbstractAssertTestsClass.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/AbstractAssertTestsClass.java new file mode 100644 index 000000000..f229069ae --- /dev/null +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/AbstractAssertTestsClass.java @@ -0,0 +1,82 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j; + +import lombok.extern.slf4j.Slf4j; +import org.junit.Test; +import org.reflections.Reflections; +import org.reflections.scanners.MethodAnnotationsScanner; +import org.reflections.util.ClasspathHelper; +import org.reflections.util.ConfigurationBuilder; + +import java.lang.reflect.Method; +import java.util.*; + +import static org.junit.Assert.assertEquals; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public abstract class AbstractAssertTestsClass extends BaseND4JTest { + + protected abstract Set> getExclusions(); + + protected abstract String getPackageName(); + + protected abstract Class getBaseClass(); + + @Override + public long getTimeoutMilliseconds() { + return 240000L; + } + + @Test + public void checkTestClasses(){ + + Reflections reflections = new Reflections(new ConfigurationBuilder() + .setUrls(ClasspathHelper.forPackage(getPackageName())) + .setScanners(new MethodAnnotationsScanner())); + Set methods = reflections.getMethodsAnnotatedWith(Test.class); + Set> s = new HashSet<>(); + for(Method m : methods){ + s.add(m.getDeclaringClass()); + } + + List> l = new ArrayList<>(s); + Collections.sort(l, new Comparator>() { + @Override + public int compare(Class aClass, Class t1) { + return aClass.getName().compareTo(t1.getName()); + } + }); + + int count = 0; + for(Class c : l){ + if(!getBaseClass().isAssignableFrom(c) && !getExclusions().contains(c)){ + log.error("Test {} does not extend {} (directly or indirectly). 
All tests must extend this class for proper memory tracking and timeouts", + c, getBaseClass()); + count++; + } + } + assertEquals("Number of tests not extending BaseND4JTest", 0, count); + } +} diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml b/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml index 0ece4c8b0..e640ed219 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml @@ -15,7 +15,8 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - + 4.0.0 @@ -34,9 +35,9 @@ ${project.version} - junit - junit - test + junit + junit + test diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml index 619af0d7b..3ba5a156a 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml @@ -42,9 +42,9 @@ - junit - junit - test + junit + junit + test diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml index 4ed7c8b7b..7c2783904 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml @@ -49,9 +49,9 @@ - junit - junit - test + junit + junit + test diff --git a/nd4j/nd4j-remote/nd4j-json-server/pom.xml b/nd4j/nd4j-remote/nd4j-json-server/pom.xml index d7729f179..5537216ca 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/pom.xml +++ b/nd4j/nd4j-remote/nd4j-json-server/pom.xml @@ -1,185 +1,185 @@ - 4.0.0 - jar + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + 4.0.0 + jar - - org.nd4j - nd4j-remote - 1.0.0-SNAPSHOT - + + org.nd4j + nd4j-remote + 1.0.0-SNAPSHOT + - nd4j-json-server - 
nd4j-json-server + nd4j-json-server + nd4j-json-server - - UTF-8 - 1.7 - 1.7 - + + UTF-8 + 1.7 + 1.7 + - - - junit - junit - test - - - - org.nd4j - nd4j-json-client - ${project.version} - - - - org.slf4j - slf4j-api - - - - org.nd4j - nd4j-api - ${project.version} - - - - org.glassfish.jersey.core - jersey-client - ${jersey.version} - - - - org.glassfish.jersey.core - jersey-server - ${jersey.version} - - - - org.eclipse.jetty - jetty-server - 9.4.19.v20190610 - - - - org.eclipse.jetty - jetty-servlet - 9.4.19.v20190610 - - - - org.glassfish.jersey.inject - jersey-hk2 - ${jersey.version} - - - - org.glassfish.jersey.media - jersey-media-json-processing - ${jersey.version} - - - - org.glassfish.jersey.containers - jersey-container-servlet-core - ${jersey.version} - - - - ch.qos.logback - logback-core - ${logback.version} - test - - - - ch.qos.logback - logback-classic - ${logback.version} - test - - - - javax.xml.bind - jaxb-api - 2.3.0 - - - - com.sun.xml.bind - jaxb-impl - 2.3.0 - - - - com.sun.xml.bind - jaxb-core - 2.3.0 - - - - javax.activation - activation - 1.1 - - - - com.google.code.gson - gson - ${gson.version} - test - - - - - org.nd4j - nd4j-common-tests - ${project.version} - test - - - - - - - org.apache.maven.plugins - maven-compiler-plugin - - ${maven.compiler.source} - ${maven.compiler.target} - - - - - - - - nd4j-tests-cpu - + - org.nd4j - nd4j-native - ${project.version} - test + junit + junit + test - - - - nd4j-tests-cuda - - org.nd4j - nd4j-cuda-10.2 - ${project.version} - test + org.nd4j + nd4j-json-client + ${project.version} - - - - testresources - - + + org.slf4j + slf4j-api + + + + org.nd4j + nd4j-api + ${project.version} + + + + org.glassfish.jersey.core + jersey-client + ${jersey.version} + + + + org.glassfish.jersey.core + jersey-server + ${jersey.version} + + + + org.eclipse.jetty + jetty-server + 9.4.19.v20190610 + + + + org.eclipse.jetty + jetty-servlet + 9.4.19.v20190610 + + + + org.glassfish.jersey.inject + jersey-hk2 + 
${jersey.version} + + + + org.glassfish.jersey.media + jersey-media-json-processing + ${jersey.version} + + + + org.glassfish.jersey.containers + jersey-container-servlet-core + ${jersey.version} + + + + ch.qos.logback + logback-core + ${logback.version} + test + + + + ch.qos.logback + logback-classic + ${logback.version} + test + + + + javax.xml.bind + jaxb-api + 2.3.0 + + + + com.sun.xml.bind + jaxb-impl + 2.3.0 + + + + com.sun.xml.bind + jaxb-core + 2.3.0 + + + + javax.activation + activation + 1.1 + + + + com.google.code.gson + gson + ${gson.version} + test + + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + ${maven.compiler.source} + ${maven.compiler.target} + + + + + + + + nd4j-tests-cpu + + + org.nd4j + nd4j-native + ${project.version} + test + + + + + + nd4j-tests-cuda + + + org.nd4j + nd4j-cuda-10.2 + ${project.version} + test + + + + + + testresources + + diff --git a/nd4j/nd4j-serde/nd4j-aeron/pom.xml b/nd4j/nd4j-serde/nd4j-aeron/pom.xml index 827afb23a..c94bf86af 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/pom.xml +++ b/nd4j/nd4j-serde/nd4j-aeron/pom.xml @@ -16,179 +16,186 @@ - 4.0.0 + 4.0.0 - org.nd4j - nd4j-aeron - jar - - nd4j-aeron - - org.nd4j - nd4j-serde - 1.0.0-SNAPSHOT - - - 1.8 - 1.8 - 1.5.4 - 1.4.0 - UTF-8 - + nd4j-aeron + jar - - - jdk9 - - 1.9 - - - 8 - - - - testresources - + nd4j-aeron - - nd4j-tests-cpu - - false - - - - org.nd4j - nd4j-native - ${project.version} - - - - - - org.apache.maven.plugins - maven-surefire-plugin - - - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ - - src/test/java - - *.java - **/*.java - **/Test*.java - **/*Test.java - **/*TestCase.java - - junit:junit - - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend - - - -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g - - - - - + + + jdk9 + + 1.9 + + + 8 + + + + testresources + - - nd4j-tests-cuda - - false - - - - org.nd4j - 
nd4j-cuda-10.2 - ${project.version} - - - - - - org.apache.maven.plugins - maven-surefire-plugin + + nd4j-tests-cpu + + false + - - org.apache.maven.surefire - surefire-junit47 - 2.19.1 - + + org.nd4j + nd4j-native + ${project.version} + - - - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/ - - src/test/java - - *.java - **/*.java - **/Test*.java - **/*Test.java - **/*TestCase.java - - junit:junit - - org.nd4j.linalg.jcublas.JCublasBackend - org.nd4j.linalg.jcublas.JCublasBackend - - - -Ddtype=float -Xmx6g - - - - - - + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + + + -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g + + + + + + + + nd4j-tests-cuda + + false + + + + org.nd4j + nd4j-cuda-10.2 + ${project.version} + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + org.apache.maven.surefire + surefire-junit47 + 2.19.1 + + + + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/ + + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.jcublas.JCublasBackend + + org.nd4j.linalg.jcublas.JCublasBackend + + + + -Ddtype=float -Xmx6g + + + + + + - - - org.nd4j - nd4j-api - ${project.version} - - - io.aeron - aeron-all - ${aeron.version} - - - junit - junit - test - + + + org.nd4j + nd4j-api + ${project.version} + + + io.aeron + aeron-all + ${aeron.version} + + + junit + junit + test + - - ch.qos.logback - logback-classic - ${logback.version} - test - + + ch.qos.logback + logback-classic + ${logback.version} + test + - - ch.qos.logback - logback-core - ${logback.version} - test - + + ch.qos.logback + logback-core + ${logback.version} + test + - - org.nd4j - 
nd4j-common-tests - ${project.version} - test - - + + org.nd4j + nd4j-common-tests + ${project.version} + test + + diff --git a/nd4j/nd4j-serde/nd4j-arrow/pom.xml b/nd4j/nd4j-serde/nd4j-arrow/pom.xml index 3a768c1a5..69879e965 100644 --- a/nd4j/nd4j-serde/nd4j-arrow/pom.xml +++ b/nd4j/nd4j-serde/nd4j-arrow/pom.xml @@ -88,7 +88,8 @@ maven-surefire-plugin - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + src/test/java @@ -100,8 +101,10 @@ junit:junit - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + - + nd4j-camel-routes org.nd4j diff --git a/nd4j/nd4j-serde/nd4j-gson/pom.xml b/nd4j/nd4j-serde/nd4j-gson/pom.xml index f488bfde5..60de01b6e 100644 --- a/nd4j/nd4j-serde/nd4j-gson/pom.xml +++ b/nd4j/nd4j-serde/nd4j-gson/pom.xml @@ -15,7 +15,8 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - + nd4j-serde org.nd4j @@ -41,9 +42,9 @@ - junit - junit - test + junit + junit + test @@ -79,7 +80,8 @@ maven-surefire-plugin - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + src/test/java @@ -91,8 +93,10 @@ junit:junit - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + - + nd4j-serde org.nd4j @@ -101,17 +102,17 @@ ${spark.version} provided - - com.google.code.findbugs - jsr305 - + + com.google.code.findbugs + jsr305 + - junit - junit - test + junit + junit + test @@ -147,7 +148,8 @@ maven-surefire-plugin - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + src/test/java @@ -159,8 
+161,10 @@ junit:junit - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend +