diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/RegressionScoreFunction.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/RegressionScoreFunction.java index fdcc6934e..51fcd9898 100644 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/RegressionScoreFunction.java +++ b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/RegressionScoreFunction.java @@ -21,12 +21,13 @@ import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.evaluation.regression.RegressionEvaluation; +import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; /** * Score function for regression (including multi-label regression) for a MultiLayerNetwork or ComputationGraph - * on a test set. Supports all regression metrics: {@link RegressionEvaluation.Metric} + * on a test set. 
Supports all regression metrics: {@link Metric} * * @author Alex Black */ @@ -35,13 +36,13 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; @NoArgsConstructor(access = AccessLevel.PROTECTED) //For JSON public class RegressionScoreFunction extends BaseNetScoreFunction { - protected RegressionEvaluation.Metric metric; + protected Metric metric; public RegressionScoreFunction(@NonNull org.deeplearning4j.eval.RegressionEvaluation.Metric metric) { this(metric.toNd4j()); } - public RegressionScoreFunction(@NonNull RegressionEvaluation.Metric metric) { + public RegressionScoreFunction(@NonNull Metric metric) { this.metric = metric; } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java index d88a3785a..53e505e3b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java @@ -51,7 +51,7 @@ import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.ROCBinary; -import org.nd4j.evaluation.regression.RegressionEvaluation; +import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -107,7 +107,7 @@ public class TestEarlyStopping extends BaseDL4JTest { min = false; break; case 3: - sc = new RegressionScoreCalculator(RegressionEvaluation.Metric.MSE, irisIter); + sc = new RegressionScoreCalculator(Metric.MSE, irisIter); min = true; break; case 4: @@ -561,8 +561,8 @@ public class TestEarlyStopping extends BaseDL4JTest { @Test public void testRegressionScoreFunctionSimple() throws 
Exception { - for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE, - RegressionEvaluation.Metric.MAE}) { + for(Metric metric : new Metric[]{Metric.MSE, + Metric.MAE}) { log.info("Metric: " + metric); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() @@ -604,8 +604,8 @@ public class TestEarlyStopping extends BaseDL4JTest { @Test public void testAEScoreFunctionSimple() throws Exception { - for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE, - RegressionEvaluation.Metric.MAE}) { + for(Metric metric : new Metric[]{Metric.MSE, + Metric.MAE}) { log.info("Metric: " + metric); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() @@ -647,8 +647,8 @@ public class TestEarlyStopping extends BaseDL4JTest { @Test public void testVAEScoreFunctionSimple() throws Exception { - for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE, - RegressionEvaluation.Metric.MAE}) { + for(Metric metric : new Metric[]{Metric.MSE, + Metric.MAE}) { log.info("Metric: " + metric); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java index 44767065b..16bc4c31f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java @@ -43,7 +43,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.junit.Test; import org.nd4j.evaluation.classification.Evaluation; -import 
org.nd4j.evaluation.regression.RegressionEvaluation; +import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.dataset.DataSet; @@ -289,8 +289,8 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { @Test public void testRegressionScoreFunctionSimple() throws Exception { - for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE, - RegressionEvaluation.Metric.MAE}) { + for(Metric metric : new Metric[]{Metric.MSE, + Metric.MAE}) { log.info("Metric: " + metric); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() @@ -335,8 +335,8 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { public void testAEScoreFunctionSimple() throws Exception { DataType dt = Nd4j.defaultFloatingPointType(); - for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE, - RegressionEvaluation.Metric.MAE}) { + for(Metric metric : new Metric[]{Metric.MSE, + Metric.MAE}) { log.info("Metric: " + metric); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() @@ -380,8 +380,8 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { @Test public void testVAEScoreFunctionSimple() throws Exception { - for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE, - RegressionEvaluation.Metric.MAE}) { + for(Metric metric : new Metric[]{Metric.MSE, + Metric.MAE}) { log.info("Metric: " + metric); ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/AutoencoderScoreCalculator.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/AutoencoderScoreCalculator.java index 2390d654c..2284efd1b 100644 --- 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/AutoencoderScoreCalculator.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/AutoencoderScoreCalculator.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.evaluation.regression.RegressionEvaluation; +import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -30,16 +31,16 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; /** * Score function for a MultiLayerNetwork or ComputationGraph with a single * {@link org.deeplearning4j.nn.conf.layers.AutoEncoder} layer. - * Calculates the specified {@link RegressionEvaluation.Metric} on the layer's reconstructions. + * Calculates the specified {@link Metric} on the layer's reconstructions. 
* * @author Alex Black */ public class AutoencoderScoreCalculator extends BaseScoreCalculator { - protected final RegressionEvaluation.Metric metric; + protected final Metric metric; protected RegressionEvaluation evaluation; - public AutoencoderScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator){ + public AutoencoderScoreCalculator(Metric metric, DataSetIterator iterator){ super(iterator); this.metric = metric; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/RegressionScoreCalculator.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/RegressionScoreCalculator.java index d398c9ce1..87c9f2e38 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/RegressionScoreCalculator.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/RegressionScoreCalculator.java @@ -19,19 +19,20 @@ package org.deeplearning4j.earlystopping.scorecalc; import org.deeplearning4j.earlystopping.scorecalc.base.BaseIEvaluationScoreCalculator; import org.deeplearning4j.nn.api.Model; import org.nd4j.evaluation.regression.RegressionEvaluation; +import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; /** * Calculate the regression score of the network (MultiLayerNetwork or ComputationGraph) on a test set, using the - * specified regression metric - {@link RegressionEvaluation.Metric} + * specified regression metric - {@link Metric} * * @author Alex Black */ public class RegressionScoreCalculator extends BaseIEvaluationScoreCalculator { - protected final RegressionEvaluation.Metric metric; + protected final Metric metric; - public RegressionScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator){ + public RegressionScoreCalculator(Metric metric, DataSetIterator iterator){ super(iterator); 
this.metric = metric; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/VAEReconErrorScoreCalculator.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/VAEReconErrorScoreCalculator.java index e2f4c3b4e..46eb7d670 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/VAEReconErrorScoreCalculator.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/VAEReconErrorScoreCalculator.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.evaluation.regression.RegressionEvaluation; +import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -35,7 +36,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; */ public class VAEReconErrorScoreCalculator extends BaseScoreCalculator { - protected final RegressionEvaluation.Metric metric; + protected final Metric metric; protected RegressionEvaluation evaluation; /** @@ -44,7 +45,7 @@ public class VAEReconErrorScoreCalculator extends BaseScoreCalculator { * @param metric * @param iterator */ - public VAEReconErrorScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator) { + public VAEReconErrorScoreCalculator(Metric metric, DataSetIterator iterator) { super(iterator); this.metric = metric; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/At.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/At.java index ea43233cd..5427b4cd7 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/At.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/At.java @@ -20,6 +20,22 @@ public class At { private int iteration; private int trainingThreadNum; private long javaThreadNum; + private Operation operation; + + /** + * @return A new instance with everything set to 0, and operation set to INFERENCE + */ + public static At defaultAt(){ + return new At(0, 0, 0, 0, Operation.INFERENCE); + } + + /** + * @param op Operation + * @return A new instance with everything set to 0, except for the specified operation + */ + public static At defaultAt(@NonNull Operation op){ + return new At(0, 0, 0, 0, op); + } /** * @return The current training epoch @@ -48,4 +64,26 @@ public class At { public long javaThreadNum(){ return javaThreadNum; } + + /** + * @return The current operation + */ + public Operation operation(){ + return operation; + } + + /** + * @return A copy of the current At instance + */ + public At copy(){ + return new At(epoch, iteration, trainingThreadNum, javaThreadNum, operation); + } + + /** + * @param operation Operation to set in the new instance + * @return A copy of the current instance, but with the specified operation + */ + public At copy(Operation operation){ + return new At(epoch, iteration, trainingThreadNum, javaThreadNum, operation); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseEvaluationListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseEvaluationListener.java new file mode 100644 index 000000000..c0af6bb5c --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseEvaluationListener.java @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.listeners; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.nd4j.autodiff.listeners.records.EvaluationRecord; +import org.nd4j.autodiff.listeners.records.LossCurve; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.MultiDataSet; + +/** + * A base listener class that will preform the provided evaluations, and provide the results in epochEnd and validationDone + * + * Instead of overriding requiredVariables, epochStart, epochEnd, validationDone, and/or opExecution, + * override otherRequiredVariables, epochStartEvaluations, epochEndEvaluations, validationDoneEvaluations, and/or opExecutionEvaluations + * + * If you want to use Evaluations in your listener, extend this class + */ +public abstract class BaseEvaluationListener extends BaseListener { + + private Map> trainingEvaluations = new HashMap<>(); + private Map> validationEvaluations = new HashMap<>(); + + /** + * Return the requested evaluations. 
New instances of these evaluations will be made each time they are used + */ + public abstract ListenerEvaluations evaluations(); + + @Override + public final ListenerVariables requiredVariables(SameDiff sd) { + return evaluations().requiredVariables().merge(otherRequiredVariables(sd)); + } + + /** + * Return any requested variables that are not part of the evaluations + */ + public ListenerVariables otherRequiredVariables(SameDiff sd){ + return ListenerVariables.empty(); + } + + + @Override + public final void epochStart(SameDiff sd, At at) { + trainingEvaluations = new HashMap<>(); + for(Map.Entry> entry : evaluations().trainEvaluations().entrySet()){ + + List evals = new ArrayList<>(); + for(IEvaluation ie : entry.getValue()) + evals.add(ie.newInstance()); + + trainingEvaluations.put(entry.getKey(), evals); + } + validationEvaluations = new HashMap<>(); + for(Map.Entry> entry : evaluations().validationEvaluations().entrySet()){ + + List evals = new ArrayList<>(); + for(IEvaluation ie : entry.getValue()) + evals.add(ie.newInstance()); + + validationEvaluations.put(entry.getKey(), evals); + } + + epochStartEvaluations(sd, at); + } + + /** + * See {@link Listener#epochStart(SameDiff, At)} + */ + public void epochStartEvaluations(SameDiff sd, At at){ + //No op + } + + @Override + public final ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) { + return epochEndEvaluations(sd, at, lossCurve, epochTimeMillis, new EvaluationRecord(trainingEvaluations)); + } + + /** + * See {@link Listener#epochEnd(SameDiff, At, LossCurve, long)}, also provided the requested evaluations + */ + public ListenerResponse epochEndEvaluations(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis, EvaluationRecord evaluations) { + //No op + return ListenerResponse.CONTINUE; + } + + @Override + public final ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) { + return validationDoneEvaluations(sd, at, 
validationTimeMillis, new EvaluationRecord(validationEvaluations)); + } + + /** + * See {@link Listener#validationDone(SameDiff, At, long)}, also provided the requested evaluations + */ + public ListenerResponse validationDoneEvaluations(SameDiff sd, At at, long validationTimeMillis, EvaluationRecord evaluations) { + //No op + return ListenerResponse.CONTINUE; + } + + @Override + public final void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, + INDArray activation) { + if(at.operation() == Operation.TRAINING) { + if (trainingEvaluations.containsKey(varName)) { + INDArray labels = batch.getLabels(evaluations().trainEvaluationLabels().get(varName)); + INDArray mask = batch.getLabelsMaskArray(evaluations().trainEvaluationLabels().get(varName)); + + for (IEvaluation e : trainingEvaluations.get(varName)) + e.eval(labels, activation, mask); + } + } else if(at.operation() == Operation.TRAINING_VALIDATION) { + if (validationEvaluations.containsKey(varName)) { + INDArray labels = batch.getLabels(evaluations().validationEvaluationLabels().get(varName)); + INDArray mask = batch.getLabelsMaskArray(evaluations().validationEvaluationLabels().get(varName)); + + for (IEvaluation e : validationEvaluations.get(varName)) + e.eval(labels, activation, mask); + } + } + + activationAvailableEvaluations(sd, at, batch, op, varName, activation); + } + + /** + * See {@link Listener#activationAvailable(SameDiff, At, MultiDataSet, SameDiffOp, String, INDArray)} + */ + public void activationAvailableEvaluations(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, + INDArray activation){ + //No op + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java index ea16e31b0..61a5e75a3 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java @@ -1,6 +1,6 @@ package org.nd4j.autodiff.listeners; -import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.listeners.records.LossCurve; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.Variable; @@ -11,18 +11,32 @@ import org.nd4j.linalg.dataset.api.MultiDataSet; * A base/abstract {@link Listener} with all methods implemented as no-op. * Extend this for custom listeners to selectively override only the required methods * + * If you want to use evaluations in your listener, use {@link BaseEvaluationListener} + * * @author Alex Black */ public abstract class BaseListener implements Listener { + + @Override + public ListenerVariables requiredVariables(SameDiff sd){ + return ListenerVariables.empty(); + } + @Override public void epochStart(SameDiff sd, At at) { //No op } @Override - public void epochEnd(SameDiff sd, At at) { + public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) { + return ListenerResponse.CONTINUE; + } + + @Override + public ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) { //No op + return ListenerResponse.CONTINUE; } @Override @@ -36,12 +50,28 @@ public abstract class BaseListener implements Listener { } @Override - public void preOpExecution(SameDiff sd, At at, boolean training, SameDiffOp op) { + public void operationStart(SameDiff sd, Operation op) { //No op } @Override - public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) { + public void operationEnd(SameDiff sd, Operation op) { + //No op + } + + @Override + public void preOpExecution(SameDiff sd, At at, SameDiffOp op) { + //No op + } + + 
@Override + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + //No op + } + + @Override + public void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, + INDArray activation) { //No op } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java index 08503e8c3..8d6a051df 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java @@ -1,6 +1,6 @@ package org.nd4j.autodiff.listeners; -import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.listeners.records.LossCurve; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.Variable; @@ -11,10 +11,29 @@ import org.nd4j.linalg.dataset.api.MultiDataSet; * A {@link SameDiff} listener interface that is called during every iteration of training or inference * * @author Alex Black - * @see BaseListener BaseListener, for extending + * @see BaseListener BaseListener, for extending only the required methods (all others are no-op) + * @see BaseEvaluationListener BaseEvaluationListener, for extending if you want to use evaluations */ public interface Listener { + + /** + * Required variables for this listener. + *

+ * Used to ensure these variables end up in the minimum required subgraph calculated by {@link org.nd4j.autodiff.samediff.internal.InferenceSession}. + * Otherwise, if the variables weren't required by a loss variable, they would not be calculated. + *

+ * Any variables in here are guaranteed to have {@link Listener#activationAvailable(SameDiff, At, MultiDataSet, SameDiffOp, String, INDArray)} + * called for them, regardless of whether they would normally be calculated or not. + */ + ListenerVariables requiredVariables(SameDiff sd); + + /** + * Returns whether this listener is active during the given operation. If this returns false for the given operation, + * those listener methods will not be called. + */ + boolean isActive(Operation operation); + /** * Called at the start of every epoch, when fitting from an iterator * @@ -26,10 +45,23 @@ public interface Listener { /** * Called at the end of every epoch, when fitting from an iterator * - * @param sd The SameDiff instance - * @param at Current iteration/epoch etc + * @param sd The SameDiff instance + * @param at Current iteration/epoch etc + * @param lossCurve The losses so far + * @param epochTimeMillis How long this epoch took + * @return ListenerResponse.STOP to stop training, CONTINUE or null to continue */ - void epochEnd(SameDiff sd, At at); + ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis); + + /** + * Called after the end of every epoch, once validation evaluation is done, when training + * + * @param sd The SameDiff instance + * @param at Current iteration/epoch etc + * @param validationTimeMillis How long validation took for this epoch + * @return ListenerResponse.STOP to stop training, CONTINUE or null to continue + */ + ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis); /** * Called at the start of every iteration (minibatch), before any operations have been executed @@ -45,31 +77,70 @@ public interface Listener { * @param sd The SameDiff instance * @param at Current iteration/epoch etc * @param dataSet The current dataset (minibatch) used for training - * @param loss The loss value for the current minibatch + * @param loss The loss value for the current minibatch. 
Will be null except for during training */ void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss); /** - * Called just before each operation is executed (native code called, etc) - after all inputs etc have been set + * Called at the start of an operation, e.g. training or validation * - * @param sd The SameDiff instance - * @param at Current iteration/epoch etc - * @param op Operation that has just been executed + * @param sd The SameDiff instance + * @param op The operation being started */ - void preOpExecution(SameDiff sd, At at, boolean training, SameDiffOp op); + void operationStart(SameDiff sd, Operation op); /** - * Called at the end of each operation execution + * Called at the end of an operation, e.g. training or validation + * + * @param sd The SameDiff instance + * @param op The operation being started + */ + void operationEnd(SameDiff sd, Operation op); + + /** + * Called just before each operation is executed (native code called, etc) - after all inputs etc have been set + * + * @param sd The SameDiff instance + * @param at Current iteration/epoch etc + * @param op Operation that has just been executed + */ + void preOpExecution(SameDiff sd, At at, SameDiffOp op); + + /** + * Called at the end of each operation execution
+ *

+ * Note: Outputs will most likely be freed later, use detach() if you need to save it. * * @param sd The SameDiff instance * @param at Current iteration/epoch etc + * @param batch The batch's input data. May be null if not called with a batch * @param op Operation that has just been executed * @param outputs The output arrays for the just-executed operation */ - void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs); + void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs); /** - * Called just before each parameter is to be updated - i.e., just before each parameter is modified + * Called when any activation becomes available. + *

+ * The activation will most likely be freed later, use detach() if you need to save it.
+ *
+ * Note that this method will be called when any activation becomes available, not just ones from {@link #requiredVariables(SameDiff)}
+ * It is guaranteed to be called for variables from requiredVariables().
+ *
+ * Note that the activations here overlap with {@link #opExecution(SameDiff, At, MultiDataSet, SameDiffOp, INDArray[])} - + * both contain the same information/arrays + * + * @param sd The SameDiff instance + * @param at Current iteration/epoch etc + * @param batch The batch's input data. May be null if not called with a batch + * @param op Operation that has just been executed + * @param varName The name of the variable + * @param activation The variable's activation + */ + void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, INDArray activation); + + /** + * Called just before each parameter is to be updated - i.e., just before each parameter is modified. * * @param sd SameDiff instance * @param at The current iteration/epoch etc diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerEvaluations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerEvaluations.java new file mode 100644 index 000000000..08722b06e --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerEvaluations.java @@ -0,0 +1,228 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.listeners; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; + +import lombok.Data; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.Setter; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.base.Preconditions; +import org.nd4j.evaluation.IEvaluation; + +/** + * A class to allow Listeners to define what evaluations they need to run during training
+ *

+ * Usage example - does classification ({@link org.nd4j.evaluation.classification.Evaluation}) evaluation on + * the training set (as training proceeds) and also Evaluation/ROCMultiClass evaluation on the test/validation set. + * Assumes that the output predictions are called "softmax" and the first DataSet/MultiDataSet labels are those corresponding + * to the "softmax" node + *

{@code
+ * ListenerEvaluations.builder()
+ *     //trainEvaluations: on the training data (in-line, as training proceeds through the epoch)
+ *     .trainEvaluation("softmax", 0, new Evaluation(), new ROCMultiClass())
+ *     //validationEvaluation: on the test/validation data, at the end of each epoch
+ *     .validationEvaluation("softmax", 0, new Evaluation(), new ROCMultiClass())
+ *     .build();
+ * }
+ */ +@Getter +public class ListenerEvaluations { + private Map> trainEvaluations; + private Map trainEvaluationLabels; + + private Map> validationEvaluations; + private Map validationEvaluationLabels; + + public ListenerEvaluations(Map> trainEvaluations, + Map trainEvaluationLabels, Map> validationEvaluations, + Map validationEvaluationLabels) { + this.trainEvaluations = trainEvaluations; + this.trainEvaluationLabels = trainEvaluationLabels; + this.validationEvaluations = validationEvaluations; + this.validationEvaluationLabels = validationEvaluationLabels; + + Preconditions.checkArgument(trainEvaluations.keySet().equals(trainEvaluationLabels.keySet()), + "Must specify a label index for each train evaluation. Expected: %s, got: %s", + trainEvaluations.keySet(), trainEvaluationLabels.keySet()); + + Preconditions.checkArgument(validationEvaluations.keySet().equals(validationEvaluationLabels.keySet()), + "Must specify a label index for each validation evaluation. Expected: %s, got: %s", + validationEvaluations.keySet(), validationEvaluationLabels.keySet()); + } + + private ListenerEvaluations() { + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Get the requested training evaluations + */ + public Map> trainEvaluations() { + return trainEvaluations; + } + + /** + * Get the label indices for the requested training evaluations + */ + public Map trainEvaluationLabels() { + return trainEvaluationLabels; + } + + /** + * Get the requested validation evaluations + */ + public Map> validationEvaluations() { + return validationEvaluations; + } + + /** + * Get the label indices for the requested validation evaluations + */ + public Map validationEvaluationLabels() { + return validationEvaluationLabels; + } + + /** + * Get the required variables for these evaluations + */ + public ListenerVariables requiredVariables() { + return new ListenerVariables(trainEvaluations.keySet(), validationEvaluations.keySet(), + new HashSet(), new HashSet()); + } 
+ + /** + * @return true if there are no requested evaluations + */ + public boolean isEmpty() { + return trainEvaluations.isEmpty() && validationEvaluations.isEmpty(); + } + + @NoArgsConstructor + @Getter + @Setter + public static class Builder { + private Map> trainEvaluations = new HashMap<>(); + private Map trainEvaluationLabels = new HashMap<>(); + + private Map> validationEvaluations = new HashMap<>(); + private Map validationEvaluationLabels = new HashMap<>(); + + private void addEvaluations(boolean validation, @NonNull Map> evaluationMap, @NonNull Map labelMap, + @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) { + if (evaluationMap.containsKey(variableName) && labelMap.get(variableName) != labelIndex) { + String s; + + if (validation) { + s = "This ListenerEvaluations.Builder already has validation evaluations for "; + } else { + s = "This ListenerEvaluations.Builder already has train evaluations for "; + } + + throw new IllegalArgumentException(s + "variable " + + variableName + " with label index " + labelIndex + ". You can't add " + + " evaluations with a different label index. Got label index " + labelIndex); + } + + if (evaluationMap.containsKey(variableName)) { + evaluationMap.get(variableName).addAll(Arrays.asList(evaluations)); + } else { + evaluationMap.put(variableName, Arrays.asList(evaluations)); + labelMap.put(variableName, labelIndex); + } + } + + /** + * Add requested training evaluations for a parm/variable + * + * @param variableName The variable to evaluate + * @param labelIndex The index of the label to evaluate against + * @param evaluations The evaluations to run + */ + public Builder trainEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... 
evaluations) { + addEvaluations(false, this.trainEvaluations, this.trainEvaluationLabels, variableName, + labelIndex, evaluations); + return this; + } + + /** + * Add requested training evaluations for a parm/variable + * + * @param variable The variable to evaluate + * @param labelIndex The index of the label to evaluate against + * @param evaluations The evaluations to run + */ + public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) { + return trainEvaluation(variable.getVarName(), labelIndex, evaluations); + } + + /** + * Add requested validation evaluations for a parm/variable + * + * @param variableName The variable to evaluate + * @param labelIndex The index of the label to evaluate against + * @param evaluations The evaluations to run + */ + public Builder validationEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) { + addEvaluations(true, this.validationEvaluations, this.validationEvaluationLabels, variableName, + labelIndex, evaluations); + return this; + } + + /** + * Add requested validation evaluations for a parm/variable + * + * @param variable The variable to evaluate + * @param labelIndex The index of the label to evaluate against + * @param evaluations The evaluations to run + */ + public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) { + return validationEvaluation(variable.getVarName(), labelIndex, evaluations); + } + + /** + * Add requested evaluations for a parm/variable, for either training or validation + * + * @param validation Whether to add these evaluations as validation or training + * @param variableName The variable to evaluate + * @param labelIndex The index of the label to evaluate against + * @param evaluations The evaluations to run + */ + public Builder addEvaluations(boolean validation, @NonNull String variableName, int labelIndex, @NonNull IEvaluation... 
evaluations) { + if (validation) { + return validationEvaluation(variableName, labelIndex, evaluations); + } else { + return trainEvaluation(variableName, labelIndex, evaluations); + } + } + + public ListenerEvaluations build() { + return new ListenerEvaluations(trainEvaluations, trainEvaluationLabels, validationEvaluations, validationEvaluationLabels); + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerResponse.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerResponse.java new file mode 100644 index 000000000..c6ff02827 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerResponse.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.listeners; + +/** + * An enum representing feedback given by listeners during the training loop.
+ * CONTINUE: Continue training for more epochs, unless the specified (maximum) number of training epochs have already been completed.
+ * STOP: Terminate training at the current point, irrespective of how many total epochs were specified when calling fit.
+ */ +public enum ListenerResponse { + CONTINUE, STOP; +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java new file mode 100644 index 000000000..1156de102 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java @@ -0,0 +1,236 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.listeners; + +import com.google.common.collect.Sets; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.MultiDataSet; + +/** + * Specifies a Listener's required variables for each operation. + * Used to ensure those variables end up in the minimum required subgraph calculated by {@link org.nd4j.autodiff.samediff.internal.InferenceSession}. + * Otherwise, if the variables weren't required by a loss variable, they would not be calculated. + *

+ * Any variables in here are guaranteed to have {@link Listener#activationAvailable(SameDiff, At, MultiDataSet, SameDiffOp, String, INDArray)} called for them. + */ +@RequiredArgsConstructor +@Getter +public class ListenerVariables { + + public static ListenerVariables empty() { + return ListenerVariables.builder().build(); + } + + @NonNull + private Set trainingVariables; + @NonNull + private Set validationVariables; + @NonNull + private Set evaluationVariables; + @NonNull + private Set inferenceVariables; + + public static Builder builder() { + return new Builder(); + } + + /** + * Get required training variables + */ + public Set trainingVariables() { + return trainingVariables; + } + + /** + * Get required validation variables + */ + public Set validationVariables() { + return validationVariables; + } + + /** + * Get required evaluation variables + */ + public Set evaluationVariables() { + return evaluationVariables; + } + + /** + * Get required inference variables + */ + public Set inferenceVariables() { + return inferenceVariables; + } + + /** + * Get required variables for specified op + */ + public Set requiredVariables(Operation op) { + switch (op) { + case TRAINING: + return trainingVariables; + case TRAINING_VALIDATION: + return validationVariables; + case INFERENCE: + return inferenceVariables; + case EVALUATION: + return evaluationVariables; + } + throw new IllegalArgumentException("Unknown operation " + op); + } + + private ListenerVariables() { + + } + + /** + * Return a new ListenerVariables that contains the variables of this ListenerVariables and of other + */ + public ListenerVariables merge(ListenerVariables other) { + return new ListenerVariables( + Sets.newHashSet(Sets.union(trainingVariables, other.trainingVariables)), + Sets.newHashSet(Sets.union(validationVariables, other.validationVariables)), + Sets.newHashSet(Sets.union(evaluationVariables, other.evaluationVariables)), + Sets.newHashSet(Sets.union(inferenceVariables, 
other.inferenceVariables))); + } + + @NoArgsConstructor + @Getter + @Setter + public static class Builder { + @NonNull + private Set trainingVariables = new HashSet<>(); + @NonNull + private Set validationVariables = new HashSet<>(); + @NonNull + private Set evaluationVariables = new HashSet<>(); + @NonNull + private Set inferenceVariables = new HashSet<>(); + + /** + * Add required variables for the specified op + * + * @param op The op to require the variable for + */ + public Builder requireVariables(@NonNull Operation op, @NonNull String... variables) { + switch (op) { + case TRAINING: + trainingVariables.addAll(Arrays.asList(variables)); + break; + case TRAINING_VALIDATION: + validationVariables.addAll(Arrays.asList(variables)); + break; + case INFERENCE: + inferenceVariables.addAll(Arrays.asList(variables)); + break; + case EVALUATION: + evaluationVariables.addAll(Arrays.asList(variables)); + break; + } + + return this; + } + + /** + * Add required variables for the specified op + * + * @param op The op to require the variable for + */ + public Builder requireVariables(@NonNull Operation op, @NonNull SDVariable... variables) { + String[] names = new String[variables.length]; + + for (int i = 0; i < variables.length; i++) + names[i] = variables[i].getVarName(); + + return requireVariables(op, names); + } + + /** + * Add required variables for training + */ + public Builder trainingVariables(@NonNull String... variables) { + return requireVariables(Operation.TRAINING, variables); + } + + /** + * Add required variables for training + */ + public Builder trainingVariables(@NonNull SDVariable... variables) { + return requireVariables(Operation.TRAINING, variables); + } + + /** + * Add required variables for validation + */ + public Builder validationVariables(@NonNull String... 
variables) { + return requireVariables(Operation.TRAINING_VALIDATION, variables); + } + + /** + * Add required variables for validation + */ + public Builder validationVariables(@NonNull SDVariable... variables) { + return requireVariables(Operation.TRAINING_VALIDATION, variables); + } + + /** + * Add required variables for inference + */ + public Builder inferenceVariables(@NonNull String... variables) { + return requireVariables(Operation.INFERENCE, variables); + } + + /** + * Add required variables for inference + */ + public Builder inferenceVariables(@NonNull SDVariable... variables) { + return requireVariables(Operation.INFERENCE, variables); + } + + /** + * Add required variables for evaluation + */ + public Builder evaluationVariables(@NonNull String... variables) { + return requireVariables(Operation.EVALUATION, variables); + } + + /** + * Add required variables for evaluation + */ + public Builder evaluationVariables(@NonNull SDVariable... variables) { + return requireVariables(Operation.EVALUATION, variables); + } + + public ListenerVariables build() { + return new ListenerVariables(trainingVariables, validationVariables, evaluationVariables, inferenceVariables); + } + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Loss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Loss.java index d9f2dea45..e85a93739 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Loss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Loss.java @@ -1,5 +1,8 @@ package org.nd4j.autodiff.listeners; +import java.util.ArrayList; +import java.util.Collections; + import lombok.Data; import lombok.NonNull; import org.nd4j.base.Preconditions; @@ -7,7 +10,7 @@ import org.nd4j.base.Preconditions; import java.util.List; /** - * Loss class - represents the loss (score) for the network. 
Provides a breakdown of all the loss components + * Loss class - represents the loss (score) for the network, for one iteration. Provides a breakdown of all the loss components * * @author Alex Black */ @@ -70,4 +73,96 @@ public class Loss { } return sum; } + + public Loss copy() { + return new Loss(lossNames, losses); + } + + public static Loss sum(List losses) { + + if (losses.size() == 0) + return new Loss(Collections.emptyList(), new double[0]); + + double[] lossValues = new double[losses.get(0).losses.length]; + List lossNames = new ArrayList<>(losses.get(0).lossNames); + + for (int i = 0; i < losses.size(); i++) { + Loss l = losses.get(i); + Preconditions.checkState(l.losses.length == lossValues.length, + "Loss %s has %s losses, the others before it had %s.", i, l.losses.length, lossValues.length); + + Preconditions.checkState(l.lossNames.equals(lossNames), + "Loss %s has different loss names from the others before it. Expected %s, got %s.", + i, lossNames, l.lossNames); + + for (int j = 0; j < lossValues.length; j++) + lossValues[j] += l.losses[j]; + + } + + return new Loss(lossNames, lossValues); + } + + public static Loss average(List losses) { + Loss sum = sum(losses); + + for (int i = 0; i < sum.losses.length; i++) { + sum.losses[i] /= losses.size(); + } + + return sum; + } + + public static Loss add(Loss a, Loss b) { + Preconditions.checkState(a.lossNames.equals(b.lossNames), + "Loss names differ. First loss has names %s, second has names %s.", + a.lossNames, b.lossNames); + + double[] lossValues = new double[a.losses.length]; + for (int i = 0; i < lossValues.length; i++) + lossValues[i] = a.losses[i] + b.losses[i]; + + return new Loss(a.lossNames, lossValues); + } + + public static Loss sub(Loss a, Loss b) { + Preconditions.checkState(a.lossNames.equals(b.lossNames), + "Loss names differ. 
First loss has names %s, second has names %s.", + a.lossNames, b.lossNames); + + double[] lossValues = new double[a.losses.length]; + for (int i = 0; i < lossValues.length; i++) + lossValues[i] = a.losses[i] - b.losses[i]; + + return new Loss(a.lossNames, lossValues); + } + + public static Loss div(Loss a, Number b) { + double[] lossValues = new double[a.losses.length]; + for (int i = 0; i < lossValues.length; i++) + lossValues[i] = a.losses[i] / b.doubleValue(); + + return new Loss(a.lossNames, lossValues); + } + + public Loss add(Loss other) { + return add(this, other); + } + + public Loss sub(Loss other) { + return sub(this, other); + } + + public Loss plus(Loss other) { + return add(this, other); + } + + public Loss minus(Loss other) { + return sub(this, other); + } + + public Loss div(Number other) { + return div(this, other); + } + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Operation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Operation.java new file mode 100644 index 000000000..8676c4b02 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Operation.java @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.listeners; + +import java.util.Map; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; + +/** + * An enum representing the operation being done on a SameDiff graph.
+ *

+ * TRAINING: {@link SameDiff#fit()} methods training step (everything except validation)
+ * TRAINING_VALIDATION: the validation step during {@link SameDiff#fit()} methods - i.e., test/validation set evaluation,
+ * INFERENCE: {@link SameDiff#output()}, {@link SameDiff#batchOutput()} and {@link SameDiff#exec(Map, String...)} ()} methods, + * including the single batch and placeholder ones. Also {@link SDVariable#eval()}
+ * EVALUATION: {@link SameDiff#evaluate()} methods
+ */ +public enum Operation { + /** + * The training operation: {@link SameDiff#fit()} methods training step (everything except validation). + */ + TRAINING, + /** + * The training validation operation: the validation step during {@link SameDiff#fit()} methods. + */ + TRAINING_VALIDATION, + /** + * Inference operations: {@link SameDiff#output()}, {@link SameDiff#batchOutput()} and {@link SameDiff#exec(Map, String...)} ()} methods, + * as well as {@link SameDiff#execBackwards(Map, Operation, String...)} methods. + */ + INFERENCE, + /** + * Evaluation operations: {@link SameDiff#evaluate()} methods. + */ + EVALUATION; + + public boolean isTrainingPhase() { + return this == TRAINING || this == TRAINING_VALIDATION; + } + + public boolean isValidation() { + return this == TRAINING_VALIDATION; + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java index 432cc7168..1932a6c75 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java @@ -7,13 +7,15 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.ListenerResponse; import org.nd4j.autodiff.listeners.Loss; +import org.nd4j.autodiff.listeners.records.LossCurve; +import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.dataset.api.MultiDataSet; import java.io.*; -import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.text.SimpleDateFormat; import 
java.util.*; @@ -148,12 +150,18 @@ public class CheckpointListener extends BaseListener implements Serializable { } @Override - public void epochEnd(SameDiff sameDiff, At at) { + public ListenerResponse epochEnd(SameDiff sameDiff, At at, LossCurve lossCurve, long epochTimeMillis) { if(saveEveryNEpochs != null && (at.epoch()+1) % saveEveryNEpochs == 0){ //Save: saveCheckpoint(sameDiff, at); } //General saving conditions: don't need to check here - will check in iterationDone + return ListenerResponse.CONTINUE; + } + + @Override + public boolean isActive(Operation operation) { + return operation == Operation.TRAINING; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java index 58ba7fffc..a9862d253 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java @@ -3,6 +3,7 @@ package org.nd4j.autodiff.listeners.debugging; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.linalg.api.buffer.DataType; @@ -70,7 +71,12 @@ public class ExecDebuggingListener extends BaseListener { } @Override - public void preOpExecution(SameDiff sd, At at, boolean training, SameDiffOp op) { + public boolean isActive(Operation operation) { + return true; + } + + @Override + public void preOpExecution(SameDiff sd, At at, SameDiffOp op) { if(lastIter != at.iteration()){ lastIter = at.iteration(); stepThisIter = 0; diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/HistoryListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/HistoryListener.java new file mode 100644 index 000000000..eb0675da5 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/HistoryListener.java @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.listeners.impl; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import lombok.Getter; +import lombok.Setter; +import org.nd4j.autodiff.listeners.At; +import org.nd4j.autodiff.listeners.BaseEvaluationListener; +import org.nd4j.autodiff.listeners.records.EvaluationRecord; +import org.nd4j.autodiff.listeners.records.History; +import org.nd4j.autodiff.listeners.ListenerEvaluations; +import org.nd4j.autodiff.listeners.ListenerResponse; +import org.nd4j.autodiff.listeners.records.LossCurve; +import org.nd4j.autodiff.listeners.Operation; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.TrainingConfig; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.MultiDataSet; + +/** + * HistoryListener is mainly used internally 
to collect information such as the loss curve and evaluations, + * which will be reported later in a {@link History} instance + */ +public class HistoryListener extends BaseEvaluationListener { + + @Getter + @Setter + private ListenerEvaluations evaluations; + + private List trainingHistory = new ArrayList<>(); + private List validationHistory = new ArrayList<>(); + private LossCurve loss = null; + + private long startTime; + private long endTime; + + private List validationTimes = new ArrayList<>(); + private long validationStartTime; + + + public HistoryListener(TrainingConfig tc) { + this.evaluations = new ListenerEvaluations(tc.getTrainEvaluations(), tc.getTrainEvaluationLabels(), + tc.getValidationEvaluations(), tc.getValidationEvaluationLabels()); + } + + public HistoryListener(ListenerEvaluations evaluations) { + this.evaluations = evaluations; + } + + public HistoryListener newInstance() { + return new HistoryListener(evaluations); + } + + @Override + public ListenerEvaluations evaluations() { + return evaluations; + } + + @Override + public boolean isActive(Operation operation) { + return operation.isTrainingPhase(); + } + + @Override + public ListenerResponse epochEndEvaluations(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis, EvaluationRecord evaluations) { + trainingHistory.add(evaluations); + loss = lossCurve; + + return ListenerResponse.CONTINUE; + } + + @Override + public ListenerResponse validationDoneEvaluations(SameDiff sd, At at, long validationTimeMillis, EvaluationRecord evaluations) { + validationHistory.add(evaluations); + return ListenerResponse.CONTINUE; + } + + @Override + public void operationStart(SameDiff sd, Operation op) { + if (op == Operation.TRAINING) { + startTime = System.currentTimeMillis(); + } else if (op == Operation.TRAINING_VALIDATION) { + validationStartTime = System.currentTimeMillis(); + } + } + + @Override + public void operationEnd(SameDiff sd, Operation op) { + if (op == Operation.TRAINING) { + endTime = 
System.currentTimeMillis(); + } else if (op == Operation.TRAINING_VALIDATION) { + validationTimes.add(System.currentTimeMillis() - validationStartTime); + } + } + + public History getReport() { + return new History(trainingHistory, validationHistory, loss, endTime - startTime, validationTimes); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/ScoreListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/ScoreListener.java index fc8528791..90d8f5739 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/ScoreListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/ScoreListener.java @@ -3,7 +3,10 @@ package org.nd4j.autodiff.listeners.impl; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.ListenerResponse; import org.nd4j.autodiff.listeners.Loss; +import org.nd4j.autodiff.listeners.records.LossCurve; +import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -32,7 +35,6 @@ public class ScoreListener extends BaseListener { private final boolean reportEpochs; private final boolean reportIterPerformance; - private long epochStart; private long epochExampleCount; private int epochBatchCount; private long etlTotalTimeEpoch; @@ -72,10 +74,14 @@ public class ScoreListener extends BaseListener { } + @Override + public boolean isActive(Operation operation) { + return operation == Operation.TRAINING; + } + @Override public void epochStart(SameDiff sd, At at) { if (reportEpochs) { - epochStart = System.currentTimeMillis(); epochExampleCount = 0; epochBatchCount = 0; etlTotalTimeEpoch = 0; @@ -85,17 +91,18 @@ public class ScoreListener 
extends BaseListener { } @Override - public void epochEnd(SameDiff sd, At at) { + public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) { if (reportEpochs) { - long epochDuration = System.currentTimeMillis() - epochStart; - double batchesPerSec = epochBatchCount / (epochDuration / 1000.0); - double examplesPerSec = epochExampleCount / (epochDuration / 1000.0); - double pcEtl = 100.0 * etlTotalTimeEpoch / (double) epochDuration; + double batchesPerSec = epochBatchCount / (epochTimeMillis / 1000.0); + double examplesPerSec = epochExampleCount / (epochTimeMillis / 1000.0); + double pcEtl = 100.0 * etlTotalTimeEpoch / (double) epochTimeMillis; String etl = formatDurationMs(etlTotalTimeEpoch) + " ETL time" + (etlTotalTimeEpoch > 0 ? "(" + format2dp(pcEtl) + " %)" : ""); log.info("Epoch {} complete on iteration {} - {} batches ({} examples) in {} - {} batches/sec, {} examples/sec, {}", - at.epoch(), at.iteration(), epochBatchCount, epochExampleCount, formatDurationMs(epochDuration), + at.epoch(), at.iteration(), epochBatchCount, epochExampleCount, formatDurationMs(epochTimeMillis), format2dp(batchesPerSec), format2dp(examplesPerSec), etl); } + + return ListenerResponse.CONTINUE; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java index 36ac88189..452636d57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java @@ -4,7 +4,10 @@ import com.google.flatbuffers.Table; import lombok.NonNull; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.ListenerResponse; import org.nd4j.autodiff.listeners.Loss; 
+import org.nd4j.autodiff.listeners.records.LossCurve; +import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; @@ -279,13 +282,18 @@ public class UIListener extends BaseListener { writer.writeFinishStaticMarker(); } + @Override + public boolean isActive(Operation operation) { + return operation == Operation.TRAINING; + } + @Override public void epochStart(SameDiff sd, At at) { epochTrainEval = null; } @Override - public void epochEnd(SameDiff sd, At at) { + public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) { //If any training evaluation, report it here: if(epochTrainEval != null){ @@ -315,6 +323,7 @@ public class UIListener extends BaseListener { } epochTrainEval = null; + return ListenerResponse.CONTINUE; } @Override @@ -401,13 +410,13 @@ public class UIListener extends BaseListener { @Override - public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { //Do training set evaluation, if required //Note we'll do it in opExecution not iterationDone because we can't be sure arrays will be stil be around in the future //i.e., we'll eventually add workspaces and clear activation arrays once they have been consumed - if(training && trainEvalMetrics != null && trainEvalMetrics.size() > 0){ + if(at.operation() == Operation.TRAINING && trainEvalMetrics != null && trainEvalMetrics.size() > 0){ long time = System.currentTimeMillis(); //First: check if this op is relevant at all to evaluation... 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java new file mode 100644 index 000000000..2f0dbd5b5 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java @@ -0,0 +1,240 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.listeners.records; + +import com.google.common.base.Predicates; +import com.google.common.collect.Collections2; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import lombok.Getter; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.base.Preconditions; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.IMetric; + +/** + * A helper class to hold evaluations and provide methods to easily query them + */ +@Getter +public class EvaluationRecord { + + private ImmutableMap> evaluations; + private Map, IEvaluation> classEvaluations = new HashMap<>(); + private boolean isEmpty = true; + + public EvaluationRecord(Map> evaluations) { + this.evaluations = ImmutableMap.copyOf(evaluations); + + for (List le : evaluations.values()) { + for (IEvaluation e : le) { + isEmpty = false; + if (classEvaluations.containsKey(e.getClass())) + classEvaluations.remove(e.getClass()); + else + classEvaluations.put(e.getClass(), e); + } + } + } + + private EvaluationRecord() { + + } + + public boolean isEmpty() { + return isEmpty; + } + + /** + * Get all evaluations + */ + public ImmutableMap> evaluations() { + return evaluations; + } + + /** + * Get evaluations for a given param/variable + * + * @param param The target param/variable + */ + public List evaluations(String param) { + Preconditions.checkArgument(evaluations.containsKey(param), + "No evaluations for %s.", param); + + return evaluations.get(param); + } + + /** + * Get evaluations for a given param/variable + * + * @param param The target param/variable + */ + public List evaluations(SDVariable param) { + return evaluations(param.getVarName()); + } + + /** + * Get the evaluation for param at the specified index + */ + public IEvaluation 
evaluation(String param, int index) { + return evaluations(param).get(index); + } + + /** + * Get the evaluation for param at the specified index + */ + public IEvaluation evaluation(SDVariable param, int index) { + return evaluation(param.getVarName(), index); + } + + /** + * Get the evaluation for a given param/variable + *

+ * Will throw an exception if there are more than one or no evaluations for the param + * + * @param param The target param/variable + */ + public T evaluation(String param) { + Preconditions.checkArgument(evaluations.containsKey(param), + "No evaluations for %s.", param); + Preconditions.checkArgument(evaluations.get(param).size() == 1, + "Multiple evaluations for %s. Use evaluations().", param); + + return (T) evaluations.get(param).get(0); + } + + /** + * Get the evaluation for a given param/variable + *

+ * Will throw an exception if there are more than one or no evaluations for the param + * + * @param param The target param/variable + */ + public T evaluation(SDVariable param) { + return evaluation(param.getVarName()); + } + + /** + * Get the evaluation of a given type + *

+ * Will throw an exception if there are more than one or no evaluations of that type + * + * @param evalClass The type of evaluation to look for + */ + public > T evaluation(Class evalClass) { + Preconditions.checkArgument(classEvaluations.containsKey(evalClass), + "Can't get evaluation for %s. Either no evaluations with that class are present, or more than one are.", evalClass); + + return (T) classEvaluations.get(evalClass); + } + + /** + * Get the evaluation of a given type, for a given param/variable + *

+ * Will throw an exception if there are more than one or no evaluations of that type for the given param + * + * @param param The target param/variable + * @param evalClass The type of evaluation to look for + */ + public > T evaluation(String param, Class evalClass) { + Collection evals = Collections2.filter(evaluations(param), Predicates.instanceOf(evalClass)); + + Preconditions.checkArgument(evals.size() == 1, "Multiple or no evaluations of type %s for param %s.", evalClass, param); + + return (T) evals.iterator().next(); + } + + /** + * Get the evaluation of a given type, for a given param/variable + *

+ * Will throw an exception if there are more than one or no evaluations of that type for the given param + * + * @param param The target param/variable + * @param evalClass The type of evaluation to look for + */ + public > T evaluation(SDVariable param, Class evalClass) { + return evaluation(param.getVarName(), evalClass); + } + + /** + * Get the metric's value for the evaluation of the metric's type + *

+ * Will throw an exception if there are more than one or no evaluations of that type + * + * @param metric The metric to calculate + */ + public double getValue(IMetric metric) { + return evaluation(metric.getEvaluationClass()).getValue(metric); + } + + /** + * Get the metric's value for the evaluation of the metric's type, for a given param/variable + *

+ * Will throw an exception if there are more than one or no evaluations of that type for the given param + * + * @param param The target param/variable + * @param metric The metric to calculate + */ + public double getValue(String param, IMetric metric) { + return evaluation(param, metric.getEvaluationClass()).getValue(metric); + } + + /** + * Get the metric's value for the evaluation of the metric's type, for a given param/variable + *

+ * Will throw an exception if there are more than one or no evaluations of that type for the given param + * + * @param param The target param/variable + * @param metric The metric to calculate + */ + public double getValue(SDVariable param, IMetric metric) { + return getValue(param.getVarName(), metric); + } + + /** + * Get the metric's value for the evaluation for a given param/variable at the given index + *

+ * Will throw an exception if the target evaluation doesn't support the given metric + * + * @param param The target param/variable + * @param index The index of the target evaluation on the param + * @param metric The metric to calculate + */ + public double getValue(String param, int index, IMetric metric) { + return evaluation(param, index).getValue(metric); + } + + /** + * Get the metric's value for the evaluation for a given param/variable at the given index + *

+ * Will throw an exception if the target evaluation doesn't support the given metric + * + * @param param The target param/variable + * @param index The index of the target evaluation on the param + * @param metric The metric to calculate + */ + public double getValue(SDVariable param, int index, IMetric metric) { + return getValue(param.getVarName(), index, metric); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java new file mode 100644 index 000000000..f43d41841 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java @@ -0,0 +1,356 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.listeners.records; + +import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.List; + +import lombok.Getter; +import org.nd4j.autodiff.listeners.Listener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.IMetric; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; + +/** + * An object containing training history for a SameDiff.fit call, such as {@link SameDiff#fit()}, {@link SameDiff#fit(DataSetIterator, int, Listener...)}, etc.
+ * Contains information including:
+ * - Evaluations performed (training set and test set)
+ * - Loss curve (score values at each iteration)
+ * - Training times, and validation times
+ * - Number of epochs performed
+ */ +@Getter +public class History { + + private List trainingHistory; + private List validationHistory; + + private LossCurve lossCurve; + + private long trainingTimeMillis; + private List validationTimesMillis; + + public History(List training, List validation, LossCurve loss, + long trainingTimeMillis, List validationTimesMillis){ + trainingHistory = ImmutableList.copyOf(training); + validationHistory = ImmutableList.copyOf(validation); + this.lossCurve = loss; + this.trainingTimeMillis = trainingTimeMillis; + this.validationTimesMillis = ImmutableList.copyOf(validationTimesMillis); + } + + /** + * Get the training evaluations + */ + public List trainingEval(){ + return trainingHistory; + } + + /** + * Get the validation evaluations + */ + public List validationEval(){ + return validationHistory; + } + + /** + * Get the loss curve + */ + public LossCurve lossCurve(){ + return lossCurve; + } + + /** + * Get the total training time, in milliseconds + */ + public long trainingTimeMillis(){ + return trainingTimeMillis; + } + + /** + * Get the total validation time, in milliseconds + */ + public List validationTimesMillis(){ + return validationTimesMillis; + } + + /** + * Get the number of epochs trained for + */ + public int trainingEpochs(){ + return trainingHistory.size(); + } + + /** + * Get the number of epochs validation was ran on + */ + public int validationEpochs(){ + return validationHistory.size(); + } + + /** + * Get the results of a training evaluation on a given parameter for a given metric + * + * Only works if there is only one evaluation with the given metric for param + */ + public List trainingEval(String param, IMetric metric){ + List data = new ArrayList<>(); + for(EvaluationRecord er : trainingHistory) + data.add(er.getValue(param, metric)); + + return data; + } + + /** + * Get the results of a training evaluation on a given parameter for a given metric + * + * Only works if there is only one evaluation with the given metric for param + */ + 
public List trainingEval(SDVariable param, IMetric metric){ + return trainingEval(param.getVarName(), metric); + } + + /** + * Get the results of a training evaluation on a given parameter at a given index, for a given metric + * + * Note that it returns all recorded evaluations. + * Index determines the evaluation used not the epoch's results to return. + */ + public List trainingEval(String param, int index, IMetric metric){ + List data = new ArrayList<>(); + for(EvaluationRecord er : trainingHistory) + data.add(er.getValue(param, index, metric)); + + return data; + } + + /** + * Get the results of a training evaluation on a given parameter at a given index, for a given metric + * + * Note that it returns all recorded evaluations. + * Index determines the evaluation used not the epoch's results to return. + */ + public List trainingEval(SDVariable param, int index, IMetric metric){ + return trainingEval(param.getVarName(), index, metric); + } + + /** + * Get the results of a training evaluation for a given metric + * + * Only works if there is only one evaluation with the given metric + */ + public List trainingEval(IMetric metric){ + List data = new ArrayList<>(); + for(EvaluationRecord er : trainingHistory) + data.add(er.getValue(metric)); + + return data; + } + + /** + * Get the results of a training evaluation on a given parameter + * + * Only works if there is only one evaluation for param. + */ + public List trainingEval(String param){ + List data = new ArrayList<>(); + for(EvaluationRecord er : trainingHistory) + data.add(er.evaluation(param)); + + return data; + } + + /** + * Get the results of a training evaluation on a given parameter + * + * Only works if there is only one evaluation for param. + */ + public List trainingEval(SDVariable param){ + return trainingEval(param.getVarName()); + } + + /** + * Get the results of a training evaluation on a given parameter at a given index + * + * Note that it returns all recorded evaluations. 
+ * Index determines the evaluation used not the epoch's results to return. + */ + public List trainingEval(String param, int index){ + List data = new ArrayList<>(); + for(EvaluationRecord er : trainingHistory) + data.add(er.evaluation(param, index)); + + return data; + } + + /** + * Get the results of a training evaluation on a given parameter at a given index + * + * Note that it returns all recorded evaluations. + * Index determines the evaluation used not the epoch's results to return. + */ + public List trainingEval(SDVariable param, int index){ + return trainingEval(param.getVarName(), index); + } + + /** + * Get the results of a validation evaluation on a given parameter for a given metric + * + * Only works if there is only one evaluation with the given metric for param + */ + public List validationEval(String param, IMetric metric){ + List data = new ArrayList<>(); + for(EvaluationRecord er : validationHistory) + data.add(er.getValue(param, metric)); + + return data; + } + + /** + * Get the results of a validation evaluation on a given parameter for a given metric + * + * Only works if there is only one evaluation with the given metric for param + */ + public List validationEval(SDVariable param, IMetric metric){ + return validationEval(param.getVarName(), metric); + } + + /** + * Get the results of a validation evaluation on a given parameter at a given index, for a given metric + * + * Note that it returns all recorded evaluations. + * Index determines the evaluation used not the epoch's results to return. + */ + public List validationEval(String param, int index, IMetric metric){ + List data = new ArrayList<>(); + for(EvaluationRecord er : validationHistory) + data.add(er.getValue(param, index, metric)); + + return data; + } + + /** + * Get the results of a validation evaluation on a given parameter at a given index, for a given metric + * + * Note that it returns all recorded evaluations. 
+ * Index determines the evaluation used not the epoch's results to return. + */ + public List validationEval(SDVariable param, int index, IMetric metric){ + return validationEval(param.getVarName(), index, metric); + } + + /** + * Get the results of a validation evaluation for a given metric + * + * Only works if there is only one evaluation with the given metric + */ + public List validationEval(IMetric metric){ + List data = new ArrayList<>(); + for(EvaluationRecord er : validationHistory) + data.add(er.getValue(metric)); + + return data; + } + + /** + * Get the results of a validation evaluation on a given parameter + * + * Only works if there is only one evaluation for param. + */ + public List validationEval(String param){ + List data = new ArrayList<>(); + for(EvaluationRecord er : validationHistory) + data.add(er.evaluation(param)); + + return data; + } + + /** + * Get the results of a validation evaluation on a given parameter + * + * Only works if there is only one evaluation for param. + */ + public List validationEval(SDVariable param){ + return validationEval(param.getVarName()); + } + + /** + * Get the results of a validation evaluation on a given parameter at a given index + * + * Note that it returns all recorded evaluations. + * Index determines the evaluation used not the epoch's results to return. + */ + public List validationEval(String param, int index){ + List data = new ArrayList<>(); + for(EvaluationRecord er : validationHistory) + data.add(er.evaluation(param, index)); + + return data; + } + + /** + * Get the results of a validation evaluation on a given parameter at a given index + * + * Note that it returns all recorded evaluations. + * Index determines the evaluation used not the epoch's results to return. 
+ */ + public List validationEval(SDVariable param, int index){ + return validationEval(param.getVarName(), index); + } + + /** + * Gets the training evaluations ran during the last epoch + */ + public EvaluationRecord finalTrainingEvaluations(){ + return trainingHistory.get(trainingHistory.size() - 1); + } + + /** + * Gets the validation evaluations ran during the last epoch + */ + public EvaluationRecord finalValidationEvaluations(){ + return validationHistory.get(validationHistory.size() - 1); + } + + /** + * Gets the evaluation record for a given epoch. + * @param epoch The epoch to get results for. If negative, returns results for the epoch that many epochs from the end. + */ + public EvaluationRecord trainingEvaluations(int epoch){ + if(epoch >= 0){ + return trainingHistory.get(epoch); + } else { + return trainingHistory.get(trainingHistory.size() - epoch); + } + } + + /** + * Gets the evaluation record for a given epoch. + * @param epoch The epoch to get results for. If negative, returns results for the epoch that many epochs from the end. + */ + public EvaluationRecord validationEvaluations(int epoch){ + if(epoch >= 0){ + return trainingHistory.get(epoch); + } else { + return validationHistory.get(validationHistory.size() - epoch); + } + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java new file mode 100644 index 000000000..a65efe180 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.listeners.records; + +import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.List; +import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.listeners.Loss; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +public class LossCurve { + @Getter + private List lossNames; + @Getter + private INDArray lossValues; + + public LossCurve(List losses){ + lossNames = ImmutableList.copyOf(losses.get(0).getLossNames()); + int numLossValues = losses.get(0).lossValues().length; + lossValues = Nd4j.create(DataType.FLOAT, losses.size(), losses.get(0).lossValues().length); + + for(int i = 0 ; i < losses.size() ; i++){ + Loss l = losses.get(i); + Preconditions.checkArgument(l.getLossNames().equals(lossNames), + "Loss names for loss %s differ from others. Expected %s, got %s", + i, lossNames, l.getLossNames()); + + Preconditions.checkArgument(l.getLosses().length == numLossValues, + "Number of loss values for loss %s differ from others. 
Expected %s, got %s", + i, numLossValues, l.getLosses().length); + + lossValues = lossValues.putRow(i, Nd4j.createFromArray(l.getLosses()).castTo(DataType.FLOAT)); + } + } + + public LossCurve(double[] lossValues, List lossNames){ + this.lossValues = Nd4j.createFromArray(new double[][]{ lossValues}).castTo(DataType.FLOAT); + this.lossNames = lossNames; + } + + protected LossCurve(INDArray lossValues, List lossNames){ + Preconditions.checkArgument(lossValues.rank() == 2, "lossValues must have a rank of 2, got %s", lossValues.rank()); + Preconditions.checkArgument(lossValues.dataType() == DataType.FLOAT, "lossValues must be type FLOAT, got %s", lossValues.dataType()); + this.lossValues = lossValues; + this.lossNames = lossNames; + } + + public List losses(){ + List losses = new ArrayList<>(); + for(int i = 0 ; i < lossValues.size(0) ; i++){ + losses.add(new Loss(lossNames, lossValues.getRow(i).toDoubleVector())); + } + return losses; + } + + /** + * Get the mean loss for a given epoch + * + * If epoch is negative, counts backwards from the end. + * E.g. losses(-1) gets the last epoch. + * + * @param epoch The epoch to get. If negative, returns results for the epoch that many epochs from the end + */ + public Loss meanLoss(int epoch){ + if(epoch >= 0){ + return new Loss(lossNames, lossValues.getRow(epoch).toDoubleVector()); + } else { + return new Loss(lossNames, lossValues.getRow(lossValues.rows() + epoch).toDoubleVector()); + } + } + + /** + * Get the mean loss for the last epoch. + */ + public Loss lastMeanLoss(){ + return meanLoss(-1); + } + + /** + * Return all mean loss values for a given variable + */ + public float[] meanLoss(@NonNull String lossName){ + + int idx = lossNames.indexOf(lossName); + + Preconditions.checkArgument(idx >= 0, "No loss value for %s. 
Existing losses: %s", lossName, lossNames); + + float[] loss = new float[(int) lossValues.size(0)]; + for(int i = 0 ; i < lossValues.size(0) ; i++){ + loss[i] = lossValues.getFloat(i, idx); + } + return loss; + } + + /** + * Return all mean loss values for a given variable + */ + public float[] meanLoss(@NonNull SDVariable loss){ + return meanLoss(loss.getVarName()); + } + + /** + * Return the mean loss value for a given variable on a given epoch. + * + * See {@link #meanLoss(int)} + */ + public float meanLoss(@NonNull String lossName, int epoch){ + + int idx = lossNames.indexOf(lossName); + + Preconditions.checkArgument(idx >= 0, "No loss value for %s. Existing losses: %s", lossName, lossNames); + + if(epoch >= 0) { + return lossValues.getFloat(epoch, idx); + } else { + return lossValues.getFloat(lossValues.rows() + epoch, idx); + } + } + + /** + * Return the mean loss value for a given variable on a given epoch. + * + * See {@link #meanLoss(int)} + */ + public float meanLoss(@NonNull SDVariable loss, int epoch){ + return meanLoss(loss.getVarName(), epoch); + } + + /** + * Return the mean loss value for a given variable on the last epoch. + */ + public float lastMeanLoss(@NonNull String lossName){ + + int idx = lossNames.indexOf(lossName); + + Preconditions.checkArgument(idx >= 0, "No loss value for %s. Existing losses: %s", lossName, lossNames); + + return lossValues.getFloat(lossValues.rows() - 1, idx); + } + + /** + * Return the mean loss value for a given variable on the last epoch. + */ + public float lastMeanLoss(@NonNull SDVariable loss){ + return lastMeanLoss(loss.getVarName()); + } + + /** + * Return the loss delta between the last epoch and the one before it. + * Equivalent to meanLoss(-1) - meanLoss(-2). + * A positive delta means the loss is increasing, and a negative delta means it is decreasing. 
+ */ + public Loss lastMeanDelta(){ + return lastMeanLoss().sub(meanLoss(-2)); + } + + /** + * Return the loss delta between the last epoch and the one before it, for a given variable. + * Equivalent to meanLoss(-1) - meanLoss(-2). + * A positive delta means the loss is increasing, and a negative delta means it is decreasing. + */ + public double lastMeanDelta(String lossName){ + return lastMeanDelta().getLoss(lossName); + } + + /** + * Return the loss delta between the last epoch and the one before it, for a given variable. + * Equivalent to meanLoss(-1) - meanLoss(-2). + * A positive delta means the loss is increasing, and a negative delta means it is decreasing. + */ + public double lastMeanDelta(SDVariable loss){ + return lastMeanDelta(loss.getVarName()); + } + + /** + * Return a new LossCurve with the given losses added on as the most recent epoch + */ + public LossCurve addLossAndCopy(Loss loss){ + return addLossAndCopy(loss.getLosses(), loss.lossNames()); + } + + /** + * Return a new LossCurve with the given losses added on as the most recent epoch + */ + public LossCurve addLossAndCopy(double[] values, List lossNames){ + return new LossCurve( + Nd4j.concat(0, lossValues, + Nd4j.createFromArray(new double[][]{values}).castTo(DataType.FLOAT)), + lossNames); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index b3abf7b00..9cfe87822 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -16,7 +16,9 @@ package org.nd4j.autodiff.samediff; +import com.google.common.base.Predicates; import com.google.common.collect.HashBasedTable; +import com.google.common.collect.Maps; import com.google.common.collect.Table; import 
com.google.common.primitives.Ints; import com.google.flatbuffers.FlatBufferBuilder; @@ -30,9 +32,14 @@ import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunctionFactory; -import org.nd4j.autodiff.listeners.At; -import org.nd4j.autodiff.listeners.Listener; -import org.nd4j.autodiff.listeners.Loss; +import org.nd4j.autodiff.listeners.*; +import org.nd4j.autodiff.listeners.impl.HistoryListener; +import org.nd4j.autodiff.listeners.records.History; +import org.nd4j.autodiff.listeners.records.LossCurve; +import org.nd4j.autodiff.samediff.config.BatchOutputConfig; +import org.nd4j.autodiff.samediff.config.EvaluationConfig; +import org.nd4j.autodiff.samediff.config.FitConfig; +import org.nd4j.autodiff.samediff.config.OutputConfig; import org.nd4j.autodiff.samediff.internal.*; import org.nd4j.autodiff.samediff.ops.*; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; @@ -40,6 +47,8 @@ import org.nd4j.autodiff.util.cloner.DataBufferFastCloner; import org.nd4j.autodiff.util.cloner.INDArrayFastCloner; import org.nd4j.base.Preconditions; import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.evaluation.classification.ROC; import org.nd4j.graph.*; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; @@ -94,7 +103,6 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Matcher; import java.util.regex.Pattern; -import static org.nd4j.autodiff.util.TrainingUtils.getSingleOutput; import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; /** @@ -114,19 +122,19 @@ public class SameDiff extends SDBaseOps { //Fields for graph structure and execution @Getter //TODO use package private instead of public getters? 
- private final Map variables = new LinkedHashMap<>(); //Use linked hash map to guarantee iteration order based on order they were added. Used in inputs() and flatbuffers serde + private final Map variables = new LinkedHashMap<>(); //Use linked hash map to guarantee iteration order based on order they were added. Used in inputs() and flatbuffers serde @Getter - private final Map ops = new LinkedHashMap<>(); + private final Map ops = new LinkedHashMap<>(); @Getter - private final Map sessions = new ConcurrentHashMap<>(); //Key: thread ID + private final Map sessions = new ConcurrentHashMap<>(); //Key: thread ID - private final Map constantArrays = new ConcurrentHashMap<>(); - private final Map variablesArrays = new ConcurrentHashMap<>(); //TODO issues with DeviceLocal + mutable / changed during training? - private final Map> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them + private final Map constantArrays = new ConcurrentHashMap<>(); + private final Map variablesArrays = new ConcurrentHashMap<>(); //TODO issues with DeviceLocal + mutable / changed during training? 
+ private final Map> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them private final List lossVariables = new ArrayList<>(); - private List listeners = new ArrayList<>(); + private final List listeners = new ArrayList<>(); private final List nameScopes = new ArrayList<>(); //Used as a stack @@ -137,7 +145,7 @@ public class SameDiff extends SDBaseOps { @Getter private boolean initializedTraining; //True if training setup has been done @Getter - private Map updaterMap; //GradientUpdater instance for each trainable parameter + private Map updaterMap; //GradientUpdater instance for each trainable parameter //////////////////////////////////////// //map a function's instance id to a base name, used for propagating variable names @@ -155,53 +163,81 @@ public class SameDiff extends SDBaseOps { //////////////////////////////////////// - /** Op creator object for math operations */ + /** + * Op creator object for math operations + */ public final SDMath math = new SDMath(this); - /** Op creator object for random number generation operations */ + /** + * Op creator object for random number generation operations + */ public final SDRandom random = new SDRandom(this); - /** Op creator object for general neural network operations */ + /** + * Op creator object for general neural network operations + */ public final SDNN nn = new SDNN(this); - /** Op creator object for convolutional neural network operations */ + /** + * Op creator object for convolutional neural network operations + */ public final SDCNN cnn = new SDCNN(this); - /** Op creator object for recurrent neural network operations */ + /** + * Op creator object for recurrent neural network operations + */ public final SDRNN rnn = new SDRNN(this); - /** Op creator object for loss function operations */ + /** + * Op creator object for loss function operations + */ public final SDLoss loss = new SDLoss(this); - /** Op creator object for image operations */ + /** + * 
Op creator object for image operations + */ public final SDImage image = new SDImage(this); - /** Op creator object for math operations */ - public SDMath math(){ + /** + * Op creator object for math operations + */ + public SDMath math() { return math; } - /** Op creator object for random number generation operations */ - public SDRandom random(){ + /** + * Op creator object for random number generation operations + */ + public SDRandom random() { return random; } - /** Op creator object for general neural network operations */ - public SDNN nn(){ + /** + * Op creator object for general neural network operations + */ + public SDNN nn() { return nn; } - /** Op creator object for convolutional neural network operations */ - public SDCNN cnn(){ + /** + * Op creator object for convolutional neural network operations + */ + public SDCNN cnn() { return cnn; } - /** Op creator object for recurrent neural network operations */ - public SDRNN rnn(){ + /** + * Op creator object for recurrent neural network operations + */ + public SDRNN rnn() { return rnn; } - /** Op creator object for loss function operations */ - public SDLoss loss(){ + /** + * Op creator object for loss function operations + */ + public SDLoss loss() { return loss; } - /** Op creator object for image operations */ - public SDImage image(){ + /** + * Op creator object for image operations + */ + public SDImage image() { return image; } @@ -252,7 +288,6 @@ public class SameDiff extends SDBaseOps { private boolean resolvedVariables = false; - @Getter private Stack argumentInterceptors = new Stack<>(); @Getter @@ -325,9 +360,9 @@ public class SameDiff extends SDBaseOps { v.setName(withName); variables.put(withName, v); - for(SameDiffOp op : ops.values()){ + for (SameDiffOp op : ops.values()) { List outputsOfOp = op.getOutputsOfOp(); - if(outputsOfOp != null && !outputsOfOp.isEmpty()) { + if (outputsOfOp != null && !outputsOfOp.isEmpty()) { for (int i = 0; i < outputsOfOp.size(); i++) { if 
(outputsOfOp.get(i).equals(oldVarName)) { outputsOfOp.set(i, withName); @@ -336,7 +371,7 @@ public class SameDiff extends SDBaseOps { } List inputsToOp = op.getInputsToOp(); - if(inputsToOp != null && !inputsToOp.isEmpty()) { + if (inputsToOp != null && !inputsToOp.isEmpty()) { for (int i = 0; i < inputsToOp.size(); i++) { if (inputsToOp.get(i).equals(oldVarName)) { inputsToOp.set(i, withName); @@ -436,40 +471,41 @@ public class SameDiff extends SDBaseOps { * * @param listeners Listeners */ - public void setListeners(Listener... listeners){ + public void setListeners(Listener... listeners) { this.listeners.clear(); addListeners(listeners); } - public void setListeners(Collection listeners){ + public void setListeners(Collection listeners) { this.listeners.clear(); addListeners(listeners); } - public void addListeners(Listener... listeners){ + public void addListeners(Listener... listeners) { addListeners(Arrays.asList(listeners)); } - public void addListeners(Collection listeners){ + public void addListeners(Collection listeners) { this.listeners.addAll(listeners); } - public List getListeners(){ + public List getListeners() { return listeners; } + /** * @return The current name scope, if any (null otherwise). See {@link #withNameScope(String)} for more details. */ - public String currentNameScope(){ - if(nameScopes.isEmpty()) + public String currentNameScope() { + if (nameScopes.isEmpty()) return null; //Would use String.join but that is Java 8+ StringBuilder sb = new StringBuilder(); boolean first = true; - for(NameScope ns : nameScopes){ - if(!first){ + for (NameScope ns : nameScopes) { + if (!first) { sb.append("/"); } sb.append(ns.getName()); @@ -481,30 +517,30 @@ public class SameDiff extends SDBaseOps { /** * @return The name with the current name scope (if any) appended. 
See {@link #withNameScope(String)} */ - protected String nameWithScope(String name){ + protected String nameWithScope(String name) { String scope = currentNameScope(); - if(scope == null){ + if (scope == null) { return name; } - if(!name.startsWith(scope + "/")) + if (!name.startsWith(scope + "/")) return scope + "/" + name; else return name; } //Intentionally package private - void addNameScope(NameScope nameScope){ + void addNameScope(NameScope nameScope) { nameScopes.add(nameScope); } //Intentionally package private - void closeNameScope(NameScope nameScope){ + void closeNameScope(NameScope nameScope) { //Check that the name scope is closed correctly/in order Preconditions.checkState(!nameScopes.isEmpty(), "Cannot close name scope: no name scopes are currently defined"); - Preconditions.checkState(nameScopes.get(nameScopes.size()-1).equals(nameScope), + Preconditions.checkState(nameScopes.get(nameScopes.size() - 1).equals(nameScope), "Cannot close name scope %s: Name scopes must be closed in order. Current name scopes: \"%s\"", nameScope, currentNameScope()); - nameScopes.remove(nameScopes.size()-1); + nameScopes.remove(nameScopes.size() - 1); } /** @@ -524,7 +560,7 @@ public class SameDiff extends SDBaseOps { * String zName = z.getVarName(); //RESULT: "z" * } * - * + *

* Note that name scopes can also be nested: *

      *  {@code
@@ -539,30 +575,29 @@ public class SameDiff extends SDBaseOps {
      *  }
      * 
* - * * @param nameScope Name of the name scope to open/create * @return The NameScope object */ - public NameScope withNameScope(String nameScope){ + public NameScope withNameScope(String nameScope) { NameScope ns = new NameScope(this, nameScope); addNameScope(ns); return ns; } - public List getOpsInScope(NameScope scope){ + public List getOpsInScope(NameScope scope) { ArrayList ops = new ArrayList<>(); - for(SameDiffOp v : this.ops.values()){ - if(v.getName().startsWith(scope.getName())) + for (SameDiffOp v : this.ops.values()) { + if (v.getName().startsWith(scope.getName())) ops.add(v); } return ops; } - public List getVariablesInScope(NameScope scope){ + public List getVariablesInScope(NameScope scope) { ArrayList vars = new ArrayList<>(); - for(SDVariable v : variables()){ - if(v.getVarName().startsWith(scope.getName())) + for (SDVariable v : variables()) { + if (v.getVarName().startsWith(scope.getName())) vars.add(v); } return vars; @@ -641,11 +676,11 @@ public class SameDiff extends SDBaseOps { return ops.containsKey(id); } - public DifferentialFunction functionOutputFor(String varName){ - if(variables.get(varName).getOutputOfOp() == null) + public DifferentialFunction functionOutputFor(String varName) { + if (variables.get(varName).getOutputOfOp() == null) return null; String outName = variables.get(varName).getOutputOfOp(); - if(outName == null) + if (outName == null) return null; return ops.get(outName).getOp(); } @@ -677,7 +712,7 @@ public class SameDiff extends SDBaseOps { throw new ND4JIllegalStateException("Function must not be a variable!"); } - if(ops.containsKey(id)){ + if (ops.containsKey(id)) { } else { ops.put(id, SameDiffOp.builder().name(id).op(function).build()); @@ -757,17 +792,17 @@ public class SameDiff extends SDBaseOps { } - public void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr){ + public void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr) { 
Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", varName); SDVariable v = getVariable(varName); - if(v.isConstant()) { + if (v.isConstant()) { constantArrays.put(varName, new DeviceLocalNDArray(arr)); - } else if(v.getVariableType() == VariableType.VARIABLE) { + } else if (v.getVariableType() == VariableType.VARIABLE) { variablesArrays.put(varName, new DeviceLocalNDArray(arr)); - } else if(v.isPlaceHolder()){ + } else if (v.isPlaceHolder()) { long tid = Thread.currentThread().getId(); - if(!placeholdersPerThread.containsKey(tid)){ + if (!placeholdersPerThread.containsKey(tid)) { placeholdersPerThread.put(tid, new HashMap()); } placeholdersPerThread.get(tid).put(varName, arr); @@ -835,6 +870,7 @@ public class SameDiff extends SDBaseOps { /** * Put or update the shape for the given variable name. Optionally supports clearing the specified variable's * INDArray if it's shape does not match the new shape + * * @param varName Variable name * @param shape Shape to put * @param clearArrayOnShapeMismatch If false: no change to arrays. 
If true: if an INDArray is defined for the specified @@ -842,9 +878,9 @@ public class SameDiff extends SDBaseOps { * its shape does not match the specified shape */ @Deprecated - public void putOrUpdateShapeForVarName(String varName, long[] shape, boolean clearArrayOnShapeMismatch){ + public void putOrUpdateShapeForVarName(String varName, long[] shape, boolean clearArrayOnShapeMismatch) { Preconditions.checkNotNull(shape, "Cannot put null shape for variable: %s", varName); - if(variableNameToShape.containsKey(varName)){ + if (variableNameToShape.containsKey(varName)) { // updateShapeForVarName(varName, shape, clearArrayOnShapeMismatch); //TODO } else { @@ -871,7 +907,7 @@ public class SameDiff extends SDBaseOps { */ public boolean arrayAlreadyExistsForVarName(String varName) { SDVariable var = getVariable(varName); - switch(var.getVariableType()){ + switch (var.getVariableType()) { case VARIABLE: return variablesArrays.containsKey(varName); case ARRAY: @@ -896,27 +932,27 @@ public class SameDiff extends SDBaseOps { public INDArray getArrForVarName(@NonNull String varName) { Preconditions.checkState(variables.containsKey(varName), "No variable found with name \"%s\"", varName); SDVariable v = variables.get(varName).getVariable(); - switch(v.getVariableType()){ + switch (v.getVariableType()) { case VARIABLE: - if(!variablesArrays.containsKey(varName)) { + if (!variablesArrays.containsKey(varName)) { //VARIBALE type arrays should have a parameter initializer... // we should use this to azy init the array if none is present v.storeAndAllocateNewArray(); } return variablesArrays.get(varName).get(); case CONSTANT: - if(!constantArrays.containsKey(varName)) + if (!constantArrays.containsKey(varName)) return null; return constantArrays.get(varName).get(); case ARRAY: //Only stored in inference session... 
InferenceSession s = sessions.get(Thread.currentThread().getId()); - if(s == null) + if (s == null) return null; return s.get(varName, InferenceSession.OUTER_FRAME, 0, null, false); case PLACEHOLDER: long tid = Thread.currentThread().getId(); - if(placeholdersPerThread.get(tid) == null || !placeholdersPerThread.get(tid).containsKey(varName)) + if (placeholdersPerThread.get(tid) == null || !placeholdersPerThread.get(tid).containsKey(varName)) return null; return placeholdersPerThread.get(tid).get(varName); default: @@ -931,8 +967,8 @@ public class SameDiff extends SDBaseOps { * @param variable the name of the variable to associate the array with */ public void associateArrayWithVariable(INDArray arr, @NonNull String variable) { - Preconditions.checkState(variables.containsKey(variable), "Cannot associate array with variable \"%s\": " + - "variable \"%s\" does not exist in this SameDiff instance", variable, variable); + Preconditions.checkState(variables.containsKey(variable), "Cannot associate array with variable \"%s\": " + + "variable \"%s\" does not exist in this SameDiff instance", variable, variable); associateArrayWithVariable(arr, this.getVariable(variable)); } @@ -961,16 +997,16 @@ public class SameDiff extends SDBaseOps { } boolean duped = false; - if(arr.isAttached()) { + if (arr.isAttached()) { arr = arr.detach(); duped = true; } - if(arr.isView()) { + if (arr.isView()) { arr = arr.dup(); duped = true; } - if(!duped && variable.getVariableType() == VariableType.VARIABLE) { + if (!duped && variable.getVariableType() == VariableType.VARIABLE) { for (DeviceLocalNDArray otherArr : variablesArrays.values()) { if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour) arr = arr.dup(); @@ -979,7 +1015,7 @@ public class SameDiff extends SDBaseOps { } } - switch(variable.getVariableType()){ + switch (variable.getVariableType()) { case VARIABLE: variablesArrays.put(variable.getVarName(), new 
DeviceLocalNDArray(arr)); break; @@ -1002,7 +1038,7 @@ public class SameDiff extends SDBaseOps { long tid = Thread.currentThread().getId(); - if(!placeholdersPerThread.containsKey(tid)){ + if (!placeholdersPerThread.containsKey(tid)) { placeholdersPerThread.put(tid, new HashMap()); } placeholdersPerThread.get(tid).put(variable.getVarName(), arr); @@ -1014,11 +1050,11 @@ public class SameDiff extends SDBaseOps { //putOrUpdateShapeForVarName(variable.getVarName(), arr.shape(), true); //Also update nested SameDiff instances (such as gradient function) - if(sameDiffFunctionInstances != null && sameDiffFunctionInstances.size() > 0){ - for(Map.Entry e : sameDiffFunctionInstances.entrySet()){ + if (sameDiffFunctionInstances != null && sameDiffFunctionInstances.size() > 0) { + for (Map.Entry e : sameDiffFunctionInstances.entrySet()) { SameDiff sd = e.getValue(); SDVariable v = sd.getVariable(variable.getVarName()); - if(v != null){ + if (v != null) { sd.associateArrayWithVariable(arr, v); } } @@ -1047,8 +1083,8 @@ public class SameDiff extends SDBaseOps { * @return Map of variables by name */ public Map variableMap() { - Map ret = new LinkedHashMap<>(); - for(Variable v : variables.values()){ + Map ret = new LinkedHashMap<>(); + for (Variable v : variables.values()) { ret.put(v.getName(), v.getVariable()); } return ret; @@ -1275,7 +1311,7 @@ public class SameDiff extends SDBaseOps { * Also checks for input arguments and updates the graph adding an appropriate edge when the full graph is declared. * * @param variables Variables - arguments for the specified differential function - * @param function Differential function + * @param function Differential function */ public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function) { String[] varNames = new String[variables.length]; @@ -1322,31 +1358,31 @@ public class SameDiff extends SDBaseOps { /** * Add a new argument interceptor to the interceptor stack - * + *

* For internal use only. - * + *

* When a op is added with arguments, most recent argument interceptor is called on it. * If ops are added in that interceptor, the next most recent will be called on their args, and so on. * - * @param interceptor the argument interceptor to add + * @param interceptor the argument interceptor to add */ - public void addArgumentInterceptor(@NonNull ArgumentInterceptor interceptor){ + public void addArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) { argumentInterceptors.push(interceptor); } - private boolean isArgumentInterceptorPaused(@NonNull ArgumentInterceptor interceptor){ + private boolean isArgumentInterceptorPaused(@NonNull ArgumentInterceptor interceptor) { return pausedArgumentInterceptors.contains(interceptor); } - private ArgumentInterceptor getArgumentInterceptorToUse(){ + private ArgumentInterceptor getArgumentInterceptorToUse() { - if(argumentInterceptors.isEmpty()) + if (argumentInterceptors.isEmpty()) return null; ArgumentInterceptor use = argumentInterceptors.peek(); int i = 1; - while(isArgumentInterceptorPaused(use)){ - if(argumentInterceptors.size() - i < 0) + while (isArgumentInterceptorPaused(use)) { + if (argumentInterceptors.size() - i < 0) return null; use = argumentInterceptors.elementAt(argumentInterceptors.size() - i); @@ -1358,51 +1394,51 @@ public class SameDiff extends SDBaseOps { /** * Remote the top (most recently added) argument interceptor - * + *

* For internal use only. */ - public void removeArgumentInterceptor(){ - if(!argumentInterceptors.isEmpty()) + public void removeArgumentInterceptor() { + if (!argumentInterceptors.isEmpty()) argumentInterceptors.pop(); } /** * Pause the top (most recently added) argument interceptor - * + *

* For internal use only. */ - public void pauseArgumentInterceptor(){ + public void pauseArgumentInterceptor() { pausedArgumentInterceptors.add(argumentInterceptors.peek()); } /** * Pause the given argument interceptor - * + *

* For internal use only. * - * @param interceptor the argument interceptor to pause + * @param interceptor the argument interceptor to pause */ - public void pauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor){ + public void pauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) { pausedArgumentInterceptors.add(interceptor); } /** * Unpause the top (most recently added) argument interceptor - * + *

* For internal use only. */ - public void unpauseArgumentInterceptor(){ + public void unpauseArgumentInterceptor() { pausedArgumentInterceptors.remove(argumentInterceptors.peek()); } /** * Unpause the top given argument interceptor - * + *

* For internal use only. * - * @param interceptor the argument interceptor to unpause + * @param interceptor the argument interceptor to unpause */ - public void unpauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor){ + public void unpauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) { pausedArgumentInterceptors.remove(interceptor); } @@ -1416,7 +1452,7 @@ public class SameDiff extends SDBaseOps { ArgumentInterceptor interceptor = getArgumentInterceptorToUse(); - if(interceptor != null) { + if (interceptor != null) { pauseArgumentInterceptor(interceptor); for (int i = 0; i < variables.length; i++) { variables[i] = interceptor.intercept(getVariable(variables[i])).getVarName(); @@ -1436,7 +1472,7 @@ public class SameDiff extends SDBaseOps { //Add function if it doesn't exist //TODO could "not existing" be a bug sometimes? - if(!ops.containsKey(function.getOwnName())){ + if (!ops.containsKey(function.getOwnName())) { ops.put(function.getOwnName(), SameDiffOp.builder().name(function.getOwnName()).op(function).build()); } @@ -1449,7 +1485,7 @@ public class SameDiff extends SDBaseOps { funcs = new ArrayList<>(); this.variables.get(variableName).setInputsForOp(funcs); } - if(!funcs.contains(function.getOwnName())) //Avoid duplicates for function names. + if (!funcs.contains(function.getOwnName())) //Avoid duplicates for function names. 
funcs.add(function.getOwnName()); } } @@ -1475,7 +1511,7 @@ public class SameDiff extends SDBaseOps { * Replaces the argument at i with newArg for function * Does not use (or remove) ArgumentInterceptor stuff */ - public void replaceArgFor(int i, @NonNull SDVariable newArg, @NonNull DifferentialFunction function){ + public void replaceArgFor(int i, @NonNull SDVariable newArg, @NonNull DifferentialFunction function) { Preconditions.checkArgument(i < function.args().length, "Index out of range: function " + function.getOwnName() + " only has " + function.args().length + " args but you are trying" + @@ -1484,20 +1520,20 @@ public class SameDiff extends SDBaseOps { String oldName = function.arg(i).getVarName(); String newName = newArg.getVarName(); - if(function.arg(i).isPlaceHolder() && !newArg.isPlaceHolder()){ + if (function.arg(i).isPlaceHolder() && !newArg.isPlaceHolder()) { boolean otherPlaceholders = false; - for(int j = 0 ; j < function.argNames().length ; j++){ - if(j == i) + for (int j = 0; j < function.argNames().length; j++) { + if (j == i) continue; - if(function.arg(j).isPlaceHolder()) + if (function.arg(j).isPlaceHolder()) otherPlaceholders = true; } - if(!otherPlaceholders) + if (!otherPlaceholders) placeHolderFunctions.remove(function.getOwnName()); - } else if(!function.arg(i).isPlaceHolder() && newArg.isPlaceHolder()){ - if(!placeHolderFunctions.contains(function.getOwnName())) + } else if (!function.arg(i).isPlaceHolder() && newArg.isPlaceHolder()) { + if (!placeHolderFunctions.contains(function.getOwnName())) placeHolderFunctions.add(function.getOwnName()); } @@ -1512,12 +1548,12 @@ public class SameDiff extends SDBaseOps { funcs = new ArrayList<>(); this.variables.get(newName).setInputsForOp(funcs); } - if(!funcs.contains(function.getOwnName())) //Avoid duplicates for function names. + if (!funcs.contains(function.getOwnName())) //Avoid duplicates for function names. 
funcs.add(function.getOwnName()); List oldFuncs = this.variables.get(oldName).getInputsForOp(); - if(oldFuncs != null) { - if(!ArrayUtils.contains(function.argNames(), oldName)) + if (oldFuncs != null) { + if (!ArrayUtils.contains(function.argNames(), oldName)) oldFuncs.remove(function.getOwnName()); } @@ -1531,7 +1567,7 @@ public class SameDiff extends SDBaseOps { */ public DifferentialFunction getVariableOutputFunction(String variableName) { Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName); - if(variables.get(variableName).getOutputOfOp() == null) + if (variables.get(variableName).getOutputOfOp() == null) return null; return ops.get(variables.get(variableName).getOutputOfOp()).getOp(); } @@ -1551,16 +1587,16 @@ public class SameDiff extends SDBaseOps { /** * Clear the placeholder arrays from the SameDiff instance * - * @param allThreads If true: clear the placeholders for all threads. False: clear only for current thread + * @param allThreads If true: clear the placeholders for all threads. False: clear only for current thread */ - public void clearPlaceholders(boolean allThreads){ - if(allThreads){ + public void clearPlaceholders(boolean allThreads) { + if (allThreads) { this.placeholdersPerThread.clear(); } else { long tid = Thread.currentThread().getId(); this.placeholdersPerThread.remove(tid); } - for(SameDiff sd : this.sameDiffFunctionInstances.values()){ + for (SameDiff sd : this.sameDiffFunctionInstances.values()) { sd.clearPlaceholders(allThreads); } } @@ -1569,31 +1605,32 @@ public class SameDiff extends SDBaseOps { * Clear the input arrays to each op. 
* This is usually not required, under normal SameDiff use */ - public void clearOpInputs(){ - for(SameDiffOp op : ops.values()){ - if(op.getOp() instanceof Op){ + public void clearOpInputs() { + for (SameDiffOp op : ops.values()) { + if (op.getOp() instanceof Op) { Op o = ((Op) op.getOp()); o.setX(null); - if(o.y() != null) { + if (o.y() != null) { o.setY(null); } - } else if(op.getOp() instanceof DynamicCustomOp ){ - DynamicCustomOp o = (DynamicCustomOp)op.getOp(); - o.setInputArguments((INDArray[])null); + } else if (op.getOp() instanceof DynamicCustomOp) { + DynamicCustomOp o = (DynamicCustomOp) op.getOp(); + o.setInputArguments((INDArray[]) null); } } - for(SameDiff sd : this.sameDiffFunctionInstances.values()){ + for (SameDiff sd : this.sameDiffFunctionInstances.values()) { sd.clearOpInputs(); } } /** * Get an array of differential functions that have been defined for this SameDiff instance + * * @return Array of differential functions */ public DifferentialFunction[] functions() { List out = new ArrayList<>(ops.size()); - for(SameDiffOp op : ops.values()){ + for (SameDiffOp op : ops.values()) { out.add(op.getOp()); } return out.toArray(new DifferentialFunction[out.size()]); @@ -1642,6 +1679,7 @@ public class SameDiff extends SDBaseOps { /** * Create a new (empty) SameDiff instance without any functions or variables + * * @return New SameDiff instance */ public static SameDiff create() { @@ -1651,6 +1689,7 @@ public class SameDiff extends SDBaseOps { /** * Clone/duplicate the SameDiff instance, including arrays etc. 
The returned SameDiff instance should have no * shared state with the original instance + * * @return The cloned SameDiff instance */ public SameDiff dup() { @@ -1664,13 +1703,14 @@ public class SameDiff extends SDBaseOps { /** * Count the number of elements in all arrays, according to {@link SDVariable#getShape()} + * * @return Number of array elements for all variables */ public long numElements() { long ret = 0; for (SDVariable variable : variables()) { long[] shape = variable.getShape(); - if(shape != null) { + if (shape != null) { ret += ArrayUtil.prod(shape); } } @@ -1679,12 +1719,13 @@ public class SameDiff extends SDBaseOps { /** * Returns the inputs (placeholders) for the SameDiff graph + * * @return the inputs for this graph */ public List inputs() { List out = new ArrayList<>(); - for(String s : variables.keySet()){ - if(isPlaceHolder(s)) + for (String s : variables.keySet()) { + if (isPlaceHolder(s)) out.add(s); } return out; @@ -1694,12 +1735,13 @@ public class SameDiff extends SDBaseOps { * Outputs are those variables (not placeholders, constants, etc) that are the output of a function that aren't the * input to any other ops. * Usually these are the output of the last function(s) in the SameDiff instance. 
+ * * @return The (inferred) outputs of the SameDiff instance, in no particular order */ - public List outputs(){ + public List outputs() { List out = new ArrayList<>(); - for(Variable v : variables.values()){ - if(v.getVariable().isConstant() || v.getVariable().isPlaceHolder() || //Exclude constants and placeholders + for (Variable v : variables.values()) { + if (v.getVariable().isConstant() || v.getVariable().isPlaceHolder() || //Exclude constants and placeholders (v.getInputsForOp() != null && !v.getInputsForOp().isEmpty()) || //Exclude variables that are inputs to ops (v.getControlDepsForOp() != null && !v.getControlDepsForOp().isEmpty()) || //Exclude variables that are control dependency inputs to ops (v.getControlDepsForVar() != null && !v.getControlDepsForVar().isEmpty())) { //Exclude variables that are control dependency inputs to other variables (mainly for import of cond etc ops) @@ -1707,17 +1749,17 @@ public class SameDiff extends SDBaseOps { } //Also exclude assert etc ops - doesn't make sense to return these "outputs" to user - if(v.getOutputOfOp() != null){ + if (v.getOutputOfOp() != null) { String opName = v.getOutputOfOp(); SameDiffOp o = ops.get(opName); - if(o.getOp() instanceof Assert){ + if (o.getOp() instanceof Assert) { continue; } //A bit of a hack for TF import: some TF graphs have Switch ops, where the output of one branch isn't consumed // by any ops. Consequently, during execution this "output" might never be available. So we'll exclude the output of execution here // This applies to SameDiff while loops as well - if(o.getOp() instanceof Switch){ + if (o.getOp() instanceof Switch) { continue; } } @@ -1745,18 +1787,19 @@ public class SameDiff extends SDBaseOps { * (b) Via {@link #setLossVariables(String...)}, @link #addLossVariable(String)} or {@link SDVariable#markAsLoss()}
* (c) Via {@link TrainingConfig#setLossVariables(List)}
*/ - public List getLossVariables(){ + public List getLossVariables() { return Collections.unmodifiableList(this.lossVariables); } /** * Clear/remove any existing loss variables, and set the loss variables to the specified variable names.
* See {@link #addLossVariable(String)} for more details + * * @param lossVariableNames Names of variables to be loss function variables */ - public void setLossVariables(String... lossVariableNames){ + public void setLossVariables(@NonNull String... lossVariableNames) { this.lossVariables.clear(); - for(String s : lossVariableNames){ + for (String s : lossVariableNames) { addLossVariable(s); } //After changing loss function variables, we (probably) need to recreate gradient function - as gradient @@ -1764,6 +1807,17 @@ public class SameDiff extends SDBaseOps { sameDiffFunctionInstances.remove("grad"); } + /** + * See {@link #setLossVariables(String...)} + */ + public void setLossVariables(@NonNull SDVariable... lossVariables) { + String[] varNames = new String[lossVariables.length]; + for (int i = 0; i < lossVariables.length; i++) + varNames[i] = lossVariables[i].getVarName(); + + setLossVariables(varNames); + } + /** * Mark the specified variable as a loss function variable. This means that this variable will be minimized via backprop during training.
* This will add the variable as a loss to any others - i.e., if multiple variables are marked as losses, their values will be summed @@ -1772,24 +1826,32 @@ public class SameDiff extends SDBaseOps { * Note also that only ARRAY type SDVariables can be marked as losses to be minimized. That is, we cannot mark the value * of a constant, variable or placeholder to be minimized as doing so would not make sense.
*/ - public void addLossVariable(@NonNull String variableName){ + public void addLossVariable(@NonNull String variableName) { Preconditions.checkState(hasVariable(variableName), "No variable with name \"%s\" exists", variableName); SDVariable v = getVariable(variableName); Preconditions.checkState(v.dataType().isFPType(), "Only floating point type variables can be marked as losses to be minimized." + " SDVariable \"%s\" has datatype %s", variableName, v.dataType()); Preconditions.checkState(v.getVariableType() == VariableType.ARRAY, "Only ARRAY type SDVariables can be marked as losses to be minimized." + " SDVariable \"%s\" has variable type %s", variableName, v.getVariableType()); - if(!lossVariables.contains(variableName)){ + if (!lossVariables.contains(variableName)) { lossVariables.add(variableName); } } + /** + * See {@link #addLossVariable(String)} + */ + public void addLossVariable(@NonNull SDVariable variable) { + addLossVariable(variable.getVarName()); + } + /** * Set the training configuration ({@link TrainingConfig}) for the SameDiff instance. * A TrainingConfig must be set before the SameDiff instance can be trained via the fit methods + * * @param trainingConfig Training configuration */ - public void setTrainingConfig(TrainingConfig trainingConfig){ + public void setTrainingConfig(TrainingConfig trainingConfig) { this.trainingConfig = trainingConfig; } @@ -1800,10 +1862,14 @@ public class SameDiff extends SDBaseOps { * Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can * be performed. * - * @param dataSet The DataSet (single minibatch) to peform training on + * @param dataSet The DataSet (single minibatch) to peform training on + * @param listeners Additional listeners to use during this operation + * @return a {@link History} object containing the history information for this training operation + * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information). 
*/ - public void fit(DataSet dataSet){ - fit(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), 1, false); + public History fit(@NonNull DataSet dataSet, @NonNull Listener... listeners) { + return fit(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), 1, false, + null, 1, listeners); } /** @@ -1811,10 +1877,14 @@ public class SameDiff extends SDBaseOps { * Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can * be performed. * - * @param dataSet The DataSet (single minibatch) to peform training on + * @param dataSet The MultiDataSet (single minibatch) to peform training on + * @param listeners Additional listeners to use during this operation + * @return a {@link History} object containing the history information for this training operation + * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information). */ - public void fit(MultiDataSet dataSet){ - fit(new SingletonMultiDataSetIterator(dataSet), 1, false); + public History fit(@NonNull MultiDataSet dataSet, @NonNull Listener... listeners) { + return fit(new SingletonMultiDataSetIterator(dataSet), 1, false, + null, 1, listeners); } /** @@ -1823,12 +1893,34 @@ public class SameDiff extends SDBaseOps { * single input and a single output.
* Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can * be performed. + *

+ * A special case of {@link #fit()}. + * + * @param iter The iterator to train the SameDiff instance with + * @param numEpochs The number of epochs for training. Must be > 0 + * @param validationIter The DataSetIterator to use for validation (null to skip validation) + * @param validationFrequency The frequency with which to run validation. 1 is every epoch, 2 is every other, etc. + * @param listeners Additional listeners to use during this operation + * @return a {@link History} object containing the history information for this training operation + * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information). + */ + public History fit(@NonNull DataSetIterator iter, int numEpochs, DataSetIterator validationIter, int validationFrequency, @NonNull Listener... listeners) { + return fit().train(iter, numEpochs).validate(validationIter, validationFrequency).listeners(listeners).exec(); + } + + /** + * See {@link #fit(DataSetIterator, int, DataSetIterator, int, Listener...)}, does not preform validation. + *

+ * A special case of {@link #fit()}. * * @param iter The iterator to train the SameDiff instance with * @param numEpochs The number of epochs for training. Must be > 0 + * @param listeners Additional listeners to use during this operation + * @return a {@link History} object containing the history information for this training operation + * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information). */ - public void fit(DataSetIterator iter, int numEpochs) { - fit(new MultiDataSetIteratorAdapter(iter), numEpochs, true); + public History fit(@NonNull DataSetIterator iter, int numEpochs, @NonNull Listener... listeners) { + return fit().train(iter, numEpochs).listeners(listeners).exec(); } /** @@ -1836,31 +1928,92 @@ public class SameDiff extends SDBaseOps { * This method can both singe input, single output and multi-input, multi-output SameDiff instances
* Note that a {@link TrainingConfig} must be set via {@link #setTrainingConfig(TrainingConfig)} before training can * be performed. + *

+ * A special case of {@link #fit()}. + * + * @param iter The iterator to train the SameDiff instance with + * @param numEpochs The number of epochs for training. Must be > 0 + * @param validationIter The MultiDataSetIterator to use for validation (null to skip validation) + * @param validationFrequency The frequency with which to run validation. 1 is every epoch, 2 is every other, etc. + * @param listeners Additional listeners to use during this operation + * @return a {@link History} object containing the history information for this training operation + * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information). + */ + public History fit(@NonNull MultiDataSetIterator iter, int numEpochs, MultiDataSetIterator validationIter, int validationFrequency, @NonNull Listener... listeners) { + return fit(iter, numEpochs, true, validationIter, validationFrequency, listeners); + } + + /** + * See {@link #fit(MultiDataSetIterator, int, MultiDataSetIterator, int, Listener...)}, does not preform validation. + *

+ * A special case of {@link #fit()}. * * @param iter The iterator to train the SameDiff instance with * @param numEpochs The number of epochs for training. Must be > 0 + * @param listeners Additional listeners to use during this operation + * @return a {@link History} object containing the history information for this training operation + * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information). */ - public void fit(MultiDataSetIterator iter, int numEpochs){ - fit(iter, numEpochs, true); + public History fit(@NonNull MultiDataSetIterator iter, int numEpochs, @NonNull Listener... listeners) { + return fit().train(iter, numEpochs).listeners(listeners).exec(); + } + + /** + * Set up for a fit operation using {@link FitConfig}. + *

+ * Supports the setting of training data ({@link MultiDataSetIterator} or {@link DataSetIterator}), number of epochs, + * validation data ({@link MultiDataSetIterator} or {@link DataSetIterator}), validation frequency, and additional listeners. + *

+ * Example: train on data for 5 epochs, validating on valData every 2nd epoch + *

+     *     {@code
+     *     SameDiff sd = ...;
+     *     MultiDataSet data = ...;
+     *     MultiDataSet valData = ...;
+     *
+     *     History hist = sd.fit()
+     *         .train(data, 5)
+     *         .validate(valData, 2)
+     *         .exec();
+     *     }
+     * 
+ */ + public FitConfig fit() { + return new FitConfig(this); } //Synchronized for thread safety - protected synchronized void fit(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount) { + protected synchronized History fit(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, + MultiDataSetIterator validationData, int validationFrequency, @NonNull Listener... listeners) { boolean async = iter.asyncSupported(); - if(async){ + + boolean validationAsync = false; + if (validationData != null) + validationAsync = validationData.asyncSupported(); + + if (async) { iter = new AsyncMultiDataSetIterator(iter, 3, true); } - try{ - fitHelper(iter, numEpochs, incrementEpochCount); + + if (validationAsync) { + validationData = new AsyncMultiDataSetIterator(validationData, 3, true); + } + + try { + return fitHelper(iter, numEpochs, incrementEpochCount, validationData, validationFrequency, Arrays.asList(listeners)); } finally { - if(async){ - ((AsyncMultiDataSetIterator)iter).shutdown(); + if (async) { + ((AsyncMultiDataSetIterator) iter).shutdown(); + } + if (validationAsync) { + ((AsyncMultiDataSetIterator) validationData).shutdown(); } } } //fitHelper should only be called from fit method above - protected synchronized void fitHelper(MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount){ + protected synchronized History fitHelper(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, + MultiDataSetIterator validationData, int validationFrequency, @NonNull List listeners) { Preconditions.checkNotNull(iter, "Iterator must not be null"); Preconditions.checkState(numEpochs > 0, "Number of training epochs must be a positive number. Got: %s", numEpochs); Preconditions.checkState(trainingConfig != null, "No training configuration has been set. 
A training configuration must " + @@ -1868,36 +2021,68 @@ public class SameDiff extends SDBaseOps { Preconditions.checkState(numEpochs == 1 || iter.resetSupported(), "Cannot train for multiple epochs on an iterator that" + " does not support resetting"); - if(!iter.hasNext() && iter.resetSupported()) + HistoryListener history = new HistoryListener(trainingConfig); + + List activeListeners = new ArrayList<>(); + + if (!history.evaluations().isEmpty()) + activeListeners.add(history); + + for (Listener l : this.listeners) + if (l.isActive(Operation.TRAINING)) + activeListeners.add(l); + + for (Listener l : listeners) + if (l.isActive(Operation.TRAINING)) + activeListeners.add(l); + + validateListenerActivations(activeListeners, Operation.TRAINING); + validateListenerActivations(activeListeners, Operation.TRAINING_VALIDATION); + + if (!iter.hasNext() && iter.resetSupported()) iter.reset(); boolean performedValidation = false; int trainThreadNum = 0; long jThreadId = Thread.currentThread().getId(); - boolean hasListeners = !listeners.isEmpty(); + boolean hasListeners = !activeListeners.isEmpty(); At at = At.builder() .epoch(trainingConfig.getEpochCount()) .iteration(trainingConfig.getIterationCount()) .trainingThreadNum(trainThreadNum) .javaThreadNum(jThreadId) + .operation(Operation.TRAINING) .build(); + LossCurve lossCurve = null; - for(int i = 0; i < numEpochs; i++) { - if(incrementEpochCount && hasListeners){ + Set requiredVars = new HashSet<>(); + for (Listener l : activeListeners) { + requiredVars.addAll(l.requiredVariables(this).trainingVariables()); + } + + for (int i = 0; i < numEpochs; i++) { + + if (incrementEpochCount && hasListeners) { at.setEpoch(trainingConfig.getEpochCount()); - for(Listener l : listeners){ + for (Listener l : activeListeners) { l.epochStart(this, at); } } + long epochStartTime = System.currentTimeMillis(); + + double[] lossSums = null; + List lossNames = null; + int lossCount = 0; while (iter.hasNext()) { long dataStart = hasListeners ? 
System.currentTimeMillis() : 0; org.nd4j.linalg.dataset.api.MultiDataSet ds = iter.next(); + long dataEnd = hasListeners ? System.currentTimeMillis() : 0; - if(!performedValidation){ + if (!performedValidation) { Preconditions.checkState(trainingConfig.getDataSetFeatureMapping().size() == ds.numFeatureArrays(), "The number of dataset feature mapping variables set in the training configuration (%s) must match" + " the number of dataset feature arrays (%s)", trainingConfig.getDataSetFeatureMapping().size(), ds.numFeatureArrays()); @@ -1910,10 +2095,10 @@ public class SameDiff extends SDBaseOps { performedValidation = true; } - if(hasListeners){ + if (hasListeners) { at.setIteration(trainingConfig.getIterationCount()); - for(Listener l : listeners){ - l.iterationStart(this, at, ds, (dataEnd-dataStart)); + for (Listener l : activeListeners) { + l.iterationStart(this, at, ds, (dataEnd - dataStart)); } } @@ -1924,7 +2109,7 @@ public class SameDiff extends SDBaseOps { resolveVariablesWith(placeholders); //Calculate gradients: - execBackwards(placeholders); + execBackwards(placeholders, at.operation(), ds, requiredVars, activeListeners); //Apply updater: @@ -1932,22 +2117,22 @@ public class SameDiff extends SDBaseOps { initializeTraining(); Map, AtomicDouble> regScore = null; //Holds regularization scores for later reporting to listeners - if(hasListeners){ + if (hasListeners) { regScore = new HashMap<>(); } int iteration = trainingConfig.getIterationCount(); int e = trainingConfig.getEpochCount(); - for(Variable v : variables.values()){ + for (Variable v : variables.values()) { //Only update trainable params - float type parameters (variable type vars) SDVariable sdv = v.getVariable(); - if(sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType()) + if (sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType()) continue; INDArray param = sdv.getArr(); SDVariable gradVar = sdv.getGradient(); - if(gradVar == null){ + if (gradVar == 
null) { //Not all trainable parameters have gradients defined. //Consider graph: in1->loss1; in2->loss2, where we optimize only loss1. //No gradient will be present for in2, because in2 doesn't impact loss1 at all @@ -1962,9 +2147,9 @@ public class SameDiff extends SDBaseOps { int iterCount = trainingConfig.getIterationCount(); int epochCount = trainingConfig.getEpochCount(); double lr = trainingConfig.getUpdater().hasLearningRate() ? trainingConfig.getUpdater().getLearningRate(iteration, epochCount) : 1.0; - if(r != null && r.size() > 0){ - for(Regularization reg : r){ - if(reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER){ + if (r != null && r.size() > 0) { + for (Regularization reg : r) { + if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER) { reg.apply(param, grad, lr, iterCount, epochCount); } } @@ -1982,13 +2167,13 @@ public class SameDiff extends SDBaseOps { } //Post-apply regularization (weight decay) - if(r != null && r.size() > 0){ - for(Regularization reg : r){ - if(reg.applyStep() == Regularization.ApplyStep.POST_UPDATER){ + if (r != null && r.size() > 0) { + for (Regularization reg : r) { + if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER) { reg.apply(param, grad, lr, iterCount, epochCount); - if(hasListeners){ + if (hasListeners) { double score = reg.score(param, iterCount, epochCount); - if(!regScore.containsKey(reg.getClass())){ + if (!regScore.containsKey(reg.getClass())) { regScore.put(reg.getClass(), new AtomicDouble()); } regScore.get(reg.getClass()).addAndGet(score); @@ -1997,9 +2182,10 @@ public class SameDiff extends SDBaseOps { } } - if(hasListeners){ - for(Listener l : listeners){ - l.preUpdate(this, at, v, reshapedView); + if (hasListeners) { + for (Listener l : activeListeners) { + if (l.isActive(at.operation())) + l.preUpdate(this, at, v, reshapedView); } } @@ -2011,15 +2197,15 @@ public class SameDiff extends SDBaseOps { } } - if(hasListeners){ + if (hasListeners) { double[] d = new 
double[lossVariables.size() + regScore.size()]; List lossVars; - if(regScore.size() > 0){ + if (regScore.size() > 0) { lossVars = new ArrayList<>(lossVariables.size() + regScore.size()); lossVars.addAll(lossVariables); - int s=regScore.size(); + int s = regScore.size(); //Collect regularization losses - for(Map.Entry,AtomicDouble> entry : regScore.entrySet()){ + for (Map.Entry, AtomicDouble> entry : regScore.entrySet()) { lossVars.add(entry.getKey().getSimpleName()); d[s] = entry.getValue().get(); } @@ -2030,35 +2216,153 @@ public class SameDiff extends SDBaseOps { //Collect the losses... SameDiff gradFn = sameDiffFunctionInstances.get(GRAD_FN_KEY); - int count=0; - for(String s : lossVariables){ + int count = 0; + for (String s : lossVariables) { INDArray arr = gradFn.getArrForVarName(s); double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue(); d[count++] = l; } Loss loss = new Loss(lossVars, d); - for(Listener l : listeners){ + + if (lossNames == null) { + lossNames = lossVars; + } else { + Preconditions.checkState(lossNames.equals(lossVars), + "Loss names mismatch, expected: %s, got: %s", lossNames, lossVars); + } + + if (lossSums == null) { + lossSums = d; + } else { + Preconditions.checkState(lossNames.equals(lossVars), + "Loss size mismatch, expected: %s, got: %s", lossSums.length, d.length); + + for (int j = 0; j < lossSums.length; j++) { + lossSums[j] += d[j]; + } + } + lossCount++; + + for (Listener l : activeListeners) { l.iterationDone(this, at, ds, loss); } + } trainingConfig.incrementIterationCount(); } - if(incrementEpochCount) { - if(hasListeners){ - for(Listener l : listeners){ - l.epochEnd(this, at); + long epochTime = System.currentTimeMillis() - epochStartTime; + + if (incrementEpochCount && hasListeners) { + for (int j = 0; j < lossSums.length; j++) + lossSums[j] /= lossCount; + + if (lossCurve != null) + lossCurve = lossCurve.addLossAndCopy(lossSums, lossNames); + else + lossCurve = new LossCurve(lossSums, lossNames); + } + 
+ if (incrementEpochCount) { + if (hasListeners) { + + boolean doStop = false; + Listener stopped = null; + + for (Listener l : activeListeners) { + + ListenerResponse res = l.epochEnd(this, at, lossCurve, epochTime); + + if (res == ListenerResponse.STOP && (i < numEpochs - 1)) { + doStop = true; + stopped = l; + } } + + if (doStop) { + log.info("Stopping training early. Listener " + stopped + " gave a STOP signal at epoch " + at.epoch() + " and iteration " + at.iteration()); + + for (Listener l1 : activeListeners) + l1.operationEnd(this, Operation.TRAINING); + + if (i < numEpochs - 1) { + iter.reset(); + } + + if (incrementEpochCount) + trainingConfig.incrementEpochCount(); + return history.getReport(); + } + + + //validation evaluation + if (validationData != null && (validationFrequency <= 0 || i % validationFrequency == 0)) { + + long validationStart = System.currentTimeMillis(); + outputHelper(validationData, new At(at.epoch(), 0, 0, 0, Operation.TRAINING_VALIDATION), + listeners); + + long validationTime = System.currentTimeMillis() - validationStart; + + boolean doStopV = false; + Listener stoppedV = null; + for (Listener l : activeListeners) { + + ListenerResponse res = l.validationDone(this, at, validationTime); + + if (res == ListenerResponse.STOP && (i < numEpochs - 1)) { + doStopV = true; + stoppedV = l; + } + } + + if (doStopV) { + log.info("Stopping training early from validation. 
Listener " + stoppedV + " gave a STOP signal at epoch " + at.epoch() + " and iteration " + at.iteration()); + + for (Listener l1 : activeListeners) + l1.operationEnd(this, Operation.TRAINING); + + if (i < numEpochs - 1) { + iter.reset(); + } + + if (incrementEpochCount) + trainingConfig.incrementEpochCount(); + + return history.getReport(); + } + + } + } + trainingConfig.incrementEpochCount(); } - if(i < numEpochs - 1) { + if (i < numEpochs - 1) { iter.reset(); } } + + for (Listener l1 : activeListeners) + l1.operationEnd(this, Operation.TRAINING); + + return history.getReport(); + } + + /** + * Ensure the specified listeners do not request any activations that aren't present for the given operation + */ + private void validateListenerActivations(List listeners, Operation op) { + for (Listener l : listeners) { + for (String s : l.requiredVariables(this).requiredVariables(op)) { + if (!variables.containsKey(s)) { + Preconditions.checkState(false, "Listener %s requested variable %s that is not defined in this SameDiff graph", l, s); + } + } + } } /** @@ -2072,19 +2376,19 @@ public class SameDiff extends SDBaseOps { Preconditions.checkState(trainingConfig != null, "No training configuration has been set. A training configuration must " + "be set before calculating the L2 loss. 
Use setTrainingConfig(TrainingConfig)"); - if(trainingConfig.getRegularization() == null || trainingConfig.getRegularization().isEmpty()){ + if (trainingConfig.getRegularization() == null || trainingConfig.getRegularization().isEmpty()) { return 0.0; } List l = trainingConfig.getRegularization(); double loss = 0.0; - for(Variable v : variables.values()){ + for (Variable v : variables.values()) { SDVariable sdv = v.getVariable(); - if(sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType()){ + if (sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType()) { //Only trainable parameters (FP and variable type vars) contribute to regularization score continue; } - for(Regularization r : l){ + for (Regularization r : l) { INDArray arr = sdv.getArr(); loss += r.score(arr, trainingConfig.getIterationCount(), trainingConfig.getEpochCount()); } @@ -2097,14 +2401,14 @@ public class SameDiff extends SDBaseOps { * 1. Infer the set of trainable parameters - unless specified manually by the user * 2. 
Set up the updaters */ - protected void initializeTraining(){ - if(!initializedTraining) { - if(trainingConfig == null) { + protected void initializeTraining() { + if (!initializedTraining) { + if (trainingConfig == null) { throw new ND4JIllegalStateException("Please specify a training config with setTrainingConfig"); } updaterMap = new HashMap<>(); - for(Variable v : variables.values()){ - if(v.getVariable().getVariableType() != VariableType.VARIABLE || !v.getVariable().dataType().isFPType()){ + for (Variable v : variables.values()) { + if (v.getVariable().getVariableType() != VariableType.VARIABLE || !v.getVariable().dataType().isFPType()) { //Skip non-trainable parameters continue; } @@ -2126,24 +2430,24 @@ public class SameDiff extends SDBaseOps { * @param ds MultiDataSet - source of the features/labels * @return MultiDataSet converted to a Map, based on TrainingConfig */ - private Map toPlaceholderMap(org.nd4j.linalg.dataset.api.MultiDataSet ds) { - Map placeholders = new HashMap<>(); + private Map toPlaceholderMap(org.nd4j.linalg.dataset.api.MultiDataSet ds) { + Map placeholders = new HashMap<>(); int count = 0; - for(String s : trainingConfig.getDataSetFeatureMapping()){ + for (String s : trainingConfig.getDataSetFeatureMapping()) { placeholders.put(s, ds.getFeatures(count++)); } count = 0; - if(trainingConfig.getDataSetLabelMapping() != null) { + if (trainingConfig.getDataSetLabelMapping() != null) { //Labels may be null in some models (unsupervised etc) for (String s : trainingConfig.getDataSetLabelMapping()) { placeholders.put(s, ds.getLabels(count++)); } } - if(trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().size() > 0){ + if (trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().size() > 0) { count = 0; - for(String s : trainingConfig.getDataSetFeatureMaskMapping()){ - if(s == null) { + for (String s : trainingConfig.getDataSetFeatureMaskMapping()) { 
+ if (s == null) { count++; continue; } @@ -2151,10 +2455,10 @@ public class SameDiff extends SDBaseOps { } } - if(trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().size() > 0){ + if (trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().size() > 0) { count = 0; - for(String s : trainingConfig.getDataSetLabelMaskMapping()){ - if(s == null) { + for (String s : trainingConfig.getDataSetLabelMaskMapping()) { + if (s == null) { count++; continue; } @@ -2171,41 +2475,57 @@ public class SameDiff extends SDBaseOps { * {@code Evaluation e = new Evaluation(); * sameDiff.evaluate(iterator, "softmax", e);} * + *

+ * A special case of {@link #evaluate()}. * * @param iterator Iterator as source of data to evaluate * @param outputVariable The variable to evaluate + * @param listeners Additional listeners to use during this operation. * @param evaluations The evaluations to perform */ - public void evaluate(DataSetIterator iterator, String outputVariable, IEvaluation... evaluations) { + public void evaluate(@NonNull DataSetIterator iterator, @NonNull String outputVariable, @NonNull List listeners, @NonNull IEvaluation... evaluations) { Preconditions.checkArgument(evaluations != null && evaluations.length > 0, "No evaluations were passed to the evaluate method"); - evaluate(new MultiDataSetIteratorAdapter(iterator), Collections.singletonMap(outputVariable, Arrays.asList(evaluations)), - Collections.singletonMap(outputVariable, 0)); + + evaluate().data(iterator).evaluate(outputVariable, evaluations).listeners(listeners.toArray(new Listener[0])).exec(); + } + + /** + * See {@link #evaluate(DataSetIterator, String, List, IEvaluation[])}. + *

+ * A special case of {@link #evaluate()}. + */ + public void evaluate(@NonNull DataSetIterator iterator, @NonNull String outputVariable, @NonNull IEvaluation... evaluations) { + evaluate().data(iterator).evaluate(outputVariable, evaluations).exec(); } /** * Evaluation for multiple-output networks.
- * See {@link #evaluate(MultiDataSetIterator, Map, Map)} + * See {@link #evaluate(MultiDataSetIterator, Map, Map, Listener[])}. + *

+ * A special case of {@link #evaluate()}. */ - public void evaluate(DataSetIterator iterator, Map variableEvals){ - Map map = new HashMap<>(); - Map> variableEvalsList = new HashMap<>(); - for(String s : variableEvals.keySet()){ + public void evaluate(@NonNull DataSetIterator iterator, @NonNull Map variableEvals, @NonNull Listener... listeners) { + Map map = new HashMap<>(); + Map> variableEvalsList = new HashMap<>(); + for (String s : variableEvals.keySet()) { map.put(s, 0); //Only 1 possible output here with DataSetIterator variableEvalsList.put(s, Collections.singletonList(variableEvals.get(s))); } - evaluate(new MultiDataSetIteratorAdapter(iterator), variableEvalsList, map); + evaluate(new MultiDataSetIteratorAdapter(iterator), variableEvalsList, map, listeners); } /** - * Evaluation for multiple output networks - one ore more - * See {@link #evaluate(MultiDataSetIterator, Map, Map)} + * Evaluation for multiple output networks - one or more. + * See {@link #evaluate(MultiDataSetIterator, Map, Map, Listener[])}. + *

+ * A special case of {@link #evaluate()}. */ - public void evaluateMultiple(DataSetIterator iterator, Map> variableEvals){ - Map map = new HashMap<>(); - for(String s : variableEvals.keySet()){ + public void evaluateMultiple(DataSetIterator iterator, Map> variableEvals, @NonNull Listener... listeners) { + Map map = new HashMap<>(); + for (String s : variableEvals.keySet()) { map.put(s, 0); //Only 1 possible output here with DataSetIterator } - evaluate(new MultiDataSetIteratorAdapter(iterator), variableEvals, map); + evaluate(new MultiDataSetIteratorAdapter(iterator), variableEvals, map, listeners); } /** @@ -2215,25 +2535,38 @@ public class SameDiff extends SDBaseOps { * {@code Evaluation e = new Evaluation(); * sameDiff.evaluate(iterator, "softmax", e);} * + *

+ * A special case of {@link #evaluate()}. * * @param iterator Iterator as source of data to evaluate * @param outputVariable The variable to evaluate * @param labelIndex The index of the target variable's labels in the iterator + * @param listeners Additional listeners to use during this operation. * @param evaluations The evaluations to perform */ - public void evaluate(MultiDataSetIterator iterator, String outputVariable, int labelIndex, IEvaluation... evaluations) { + public void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex, + @NonNull List listeners, @NonNull IEvaluation... evaluations) { Preconditions.checkArgument(evaluations != null && evaluations.length > 0, "No evaluations were passed to the evaluate method"); - evaluate(iterator, Collections.singletonMap(outputVariable, Arrays.asList(evaluations)), - Collections.singletonMap(outputVariable, labelIndex)); + + evaluate().data(iterator).evaluate(outputVariable, labelIndex, evaluations).listeners(listeners.toArray(new Listener[0])).exec(); } /** - * Perform evaluation using classes such as {@link org.nd4j.evaluation.classification.Evaluation} for classifier outputs + * See {@link #evaluate(MultiDataSetIterator, String, int, List, IEvaluation[])}. + *

+ * A special case of {@link #evaluate()}. + */ + public void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex, @NonNull IEvaluation... evaluations) { + evaluate().data(iterator).evaluate(outputVariable, labelIndex, evaluations).exec(); + } + + /** + * Perform evaluation using classes such as {@link Evaluation} for classifier outputs * and {@link org.nd4j.evaluation.regression.RegressionEvaluation} for regression outputs.
*
* Example: classifier evaluation
* Predictions variable name: "softmaxOutput"
- * Evaluations to perform: {@link org.nd4j.evaluation.classification.Evaluation}
+ * Evaluations to perform: {@link Evaluation}
* Data: single input, single output MultiDataSets
* Code:
*

@@ -2243,61 +2576,174 @@ public class SameDiff extends SDBaseOps {
      * Map labelMapping = Collections.singletonMap("softmaxOutput",0);  //Compare: "softmaxOutput" vs. MultiDataSet.getLabels(0)
      * }
      * 
+ *

+ * A special case of {@link #evaluate()}. * * @param iterator The iterator - the source of the data for evaluation * @param variableEvals The evaluations to perform. Key: the name of the variable. Value: the evaluations to perform * @param predictionLabelMapping The output/label mapping. Key: the name of the variable. + * @param listeners Additional listeners to use during this operation. */ - public void evaluate(MultiDataSetIterator iterator, Map> variableEvals, Map predictionLabelMapping){ + public void evaluate(MultiDataSetIterator iterator, Map> variableEvals, Map predictionLabelMapping, Listener... listeners) { + evaluateHelper(iterator, variableEvals, predictionLabelMapping, At.defaultAt(Operation.EVALUATION), listeners); + } + + + /** + * Set up for a evaluation operation using EvaluationConfig. + *

+ * Supports the setting of the data ({@link MultiDataSetIterator} or {@link DataSetIterator}), + * adding evaluations for variables (with optional label index setting), setting label indices, + * and setting additional listeners. + * Does not require setting label indices when using a {@link DataSetIterator}. + *

+ * Also supports using {@link SDVariable} instances instead of variable names. + * + *

+ * Example: evaluate "pred" with {@link Evaluation} and {@link ROC}, using label 0. + *

+     *      {@code
+     *     SameDiff sd = ...;
+     *     MultiDataSetIterator data = ...;
+     *
+     *     EvaluationRecord results = sd.evaluate()
+     *         .data(data)
+     *         .evaluate("pred", 0, new Evaluation(), new ROC())
+     *         .exec();
+     *      }
+     *  
+ * Example: evaluate "pred" with {@link Evaluation}, using the only label from a DataSetIterator. + *
+     *      {@code
+     *     SameDiff sd = ...;
+     *     DataSetIterator singleData = ...;
+     *
+     *     EvaluationRecord results = sd.evaluate()
+     *         .data(singleData)
+     *         .evaluate("pred", new Evaluation())
+     *         .exec();
+     *      }
+     *  
+ */ + public EvaluationConfig evaluate() { + return new EvaluationConfig(this); + } + + /** + * Helper method for evaluations. Should only be called from the above evaluate method + */ + private void evaluateHelper(MultiDataSetIterator iterator, + Map> variableEvals, Map predictionLabelMapping, At at, @NonNull Listener... listeners) { Preconditions.checkState(trainingConfig != null, "Training config has not been set"); Preconditions.checkState(variableEvals.keySet().equals(predictionLabelMapping.keySet()), "Keysets for variable evaluations" + " and for the prediction label mapping must be equal. Keys for variables to evaluate: %s vs. keys for label mapping: %s", variableEvals.keySet(), predictionLabelMapping.keySet()); - if(!iterator.hasNext() && iterator.resetSupported()) + List activeListeners = new ArrayList<>(); + + for (Listener l : listeners) + if (l.isActive(at.operation())) + activeListeners.add(l); + + for (Listener l : this.listeners) + if (l.isActive(at.operation())) + activeListeners.add(l); + + validateListenerActivations(activeListeners, at.operation()); + + for (Listener l : activeListeners) + l.operationStart(this, at.operation()); + + boolean hasListeners = !activeListeners.isEmpty(); + + if (!iterator.hasNext() && iterator.resetSupported()) iterator.reset(); + Set requiredVars = new HashSet<>(variableEvals.keySet()); - List reqVars = new ArrayList<>(variableEvals.keySet()); - - while(iterator.hasNext()){ - MultiDataSet ds = iterator.next(); - Map placeholderMap = toPlaceholderMap(ds); - - Map m = output(placeholderMap, reqVars); - - for(Map.Entry> e : variableEvals.entrySet()){ - INDArray prediction = m.get(e.getKey()); - for(IEvaluation eval : e.getValue()){ - //TODO masking, time series, etc - - INDArray label = ds.getLabels(predictionLabelMapping.get(e.getKey())); - eval.eval(label, prediction); - } + if (hasListeners) { + for (Listener l : activeListeners) { + requiredVars.addAll(l.requiredVariables(this).evaluationVariables()); } } + + 
String[] requiredVarsArr = requiredVars.toArray(new String[0]); + + while (iterator.hasNext()) { + long dataStart = hasListeners ? System.currentTimeMillis() : 0; + MultiDataSet ds = iterator.next(); + long dataEnd = hasListeners ? System.currentTimeMillis() : 0; + Map placeholderMap = toPlaceholderMap(ds); + + Map m; + Map outs = null; + if (hasListeners) { + + for (Listener l : activeListeners) { + l.iterationStart(this, at, ds, (dataEnd - dataStart)); + } + + m = directExecHelper(placeholderMap, at, ds, Collections.emptyList(), activeListeners, requiredVarsArr); + } else { + m = directExecHelper(placeholderMap, at, ds, Collections.emptyList(), activeListeners, requiredVarsArr); + } + + + for (Map.Entry> e : variableEvals.entrySet()) { + INDArray prediction = m.get(e.getKey()); + for (IEvaluation eval : e.getValue()) { + //TODO time series, etc + + INDArray label = ds.getLabels(predictionLabelMapping.get(e.getKey())); + INDArray mask = ds.getLabelsMaskArray(predictionLabelMapping.get(e.getKey())); + eval.eval(label, prediction, mask); + } + } + + if (hasListeners) { + for (Listener l : activeListeners) { + Map outVars = Maps.newHashMap( + Maps.filterKeys(outs, + Predicates.in(l.requiredVariables(this).evaluationVariables()))); + l.iterationDone(this, at, ds, null); + } + } + + at.setIteration(at.iteration() + 1); + } + + + for (Listener l : activeListeners) + l.operationEnd(this, at.operation()); } /** - * Do inference on a network with a single input.
+ * Do a single batch inference on a network with a single input.
* For example, if the variable to infer was called "softmax" you would use: *
      * {@code
      * sameDiff.output(iterator, "softmax");}
      * 
* - * @param dataSet The data to evaluate - * @param outputs The variables to evaluate + * @param dataSet The data to evaluate + * @param outputs The variables to evaluate */ - public Map output(DataSet dataSet, String... outputs){ + public Map output(@NonNull DataSet dataSet, @NonNull String... outputs) { return outputBatches(new SingletonMultiDataSetIterator(dataSet.toMultiDataSet()), outputs).get(0); } /** - * Single output inference. - * See {@link #output(DataSet, String...)} + * Do a single batch inference on a network.
+ * For example, if the variable to infer was called "softmax" you would use: + *
+     * {@code
+     * sameDiff.output(dataSet, "softmax");}
+     * 
+ * + * @param dataSet The data to evaluate + * @param outputs The variables to evaluate */ - public INDArray outputSingle(DataSet dataSet, String output){ - return output(dataSet, output).get(output); + public Map output(@NonNull MultiDataSet dataSet, @NonNull String... outputs) { + return outputBatches(new SingletonMultiDataSetIterator(dataSet), outputs).get(0); } /** @@ -2307,39 +2753,47 @@ public class SameDiff extends SDBaseOps { * {@code * sameDiff.output(iterator, "softmax");} * - * + *

* Uses concatenation on the outputs of {@link #outputBatches(DataSetIterator, String...)} which may cause issues with some inputs. * RNNs with variable time series length and CNNs with variable image sizes will most likely have issues. + *

+ * Special case of {@link #output()}. * - * @param iterator Iterator as source of data to evaluate - * @param outputs The variables to evaluate + * @param iterator Iterator as source of data to evaluate + * @param listeners Additional listeners to use during this operation. + * @param outputs The variables to evaluate */ - public Map output(DataSetIterator iterator, String... outputs){ - return output(new MultiDataSetIteratorAdapter(iterator), outputs); + public Map output(@NonNull DataSetIterator iterator, @NonNull List listeners, @NonNull String... outputs) { + return output().data(iterator).output(outputs).listeners(listeners.toArray(new Listener[0])).exec(); } + /** + * See {@link #output(DataSetIterator, List, String...)}. No additional listeners. + *

+ * Special case of {@link #output()}. + */ + public Map output(@NonNull DataSetIterator dataSet, @NonNull String... outputs) { + return output().data(dataSet).output(outputs).exec(); + } + + + /** + * See {@link #output(DataSetIterator, List, String...)}, but without the concatenation of batches. + *

+ * Special case of {@link #output()}. + */ + public List> outputBatches(DataSetIterator iterator, List listeners, String... outputs) { + return output().data(iterator).output(outputs).listeners(listeners.toArray(new Listener[0])).execBatches(); + } + + /** * See {@link #output(DataSetIterator, String...)}, but without the concatenation of batches. - * + *

+ * Special case of {@link #output()}. */ - public List> outputBatches(DataSetIterator iterator, String... outputs){ - return outputBatches(new MultiDataSetIteratorAdapter(iterator), outputs); - } - - /** - * Single output inference. - * See {@link #output(DataSetIterator, String...)} - */ - public INDArray outputSingle(DataSetIterator dataSet, String output){ - return output(dataSet, output).get(output); - } - - /** - * Single batched output inference. - * See {@link #output(DataSetIterator, String...)} - */ - public List outputSingleBatches(DataSetIterator dataSet, String output){ - return getSingleOutput(outputBatches(dataSet, output), output); + public List> outputBatches(DataSetIterator iterator, String... outputs) { + return output().data(iterator).output(outputs).execBatches(); } /** @@ -2347,7 +2801,7 @@ public class SameDiff extends SDBaseOps { *
* Example: classifier inference
* Predictions variable name: "softmaxOutput"
- * Evaluations to perform: {@link org.nd4j.evaluation.classification.Evaluation}
+ * Evaluations to perform: {@link Evaluation}
* Data: single output MultiDataSets
* Code:
*

@@ -2356,138 +2810,328 @@ public class SameDiff extends SDBaseOps {
      * sameDiff.output(iterator, "softmaxOutput);
      * }
      * 
- * - * Uses concatenation on the outputs of {@link #outputBatches(MultiDataSetIterator, String...)} which may cause issues with some inputs. - * RNNs with variable time series length and CNNs with variable image sizes will most likely have issues. + *

+ * Special case of {@link #output()}. * * @param iterator The iterator - the source of the data for inference + * @param listeners Additional listeners to use during this operation. * @param outputs The set of outputs to report. If null, defaults to all outputs of this SameDiff. */ - public Map output(MultiDataSetIterator iterator, String... outputs){ - return stackOutputs(outputBatches(iterator, outputs)); + public Map output(@NonNull MultiDataSetIterator iterator, @NonNull List listeners, @NonNull String... outputs) { + return stackOutputs(outputHelper(iterator, At.defaultAt(Operation.INFERENCE), listeners, outputs)); } /** - * See {@link #output(MultiDataSetIterator, String...)}, but without the concatenation of batches. + * See {@link #output(MultiDataSetIterator, List, String...)}. No additional listeners. + *

+ * Special case of {@link #output()}. */ - public List> outputBatches(MultiDataSetIterator iterator, String... outputs){ + public Map output(@NonNull MultiDataSetIterator dataSet, @NonNull String... outputs) { + return output().data(dataSet).output(outputs).exec(); + } + + /** + * Perform inference.
+ *
+ * Example: classifier inference
+ * Predictions variable name: "softmaxOutput"
+ * Evaluations to perform: {@link Evaluation}
+ * Data: single output MultiDataSets
+ * Code:
+ *

+     * {@code
+     * MultiDataSetIterator data = ...
+     * sameDiff.output(iterator, "softmaxOutput");
+     * }
+     * 
+ *

+ * Uses concatenation on the outputs of {@link #outputBatches(MultiDataSetIterator, List, String...)} which may cause issues with some inputs. + * RNNs with variable time series length and CNNs with variable image sizes will most likely have issues. + *

+ * Special case of {@link #output()}. + * + * @param iterator The iterator - the source of the data for inference + * @param listeners Additional listeners to use during this operation. + * @param outputs The set of outputs to report. If null, defaults to all outputs of this SameDiff. + */ + public List> outputBatches(MultiDataSetIterator iterator, List listeners, String... outputs) { + return outputHelper(iterator, At.defaultAt(Operation.INFERENCE), listeners, outputs); + } + + /** + * See {@link #outputBatches(MultiDataSetIterator, List, String...)}. No additional listeners. + *

+ * Special case of {@link #output()}. + */ + public List> outputBatches(MultiDataSetIterator iterator, String... outputs) { + return output().data(iterator).output(outputs).execBatches(); + } + + /** + * Set up for an inference operation using OutputConfig. + * Supports the setting of variables to output, the input data ({@link MultiDataSetIterator} or {@link DataSetIterator}), + * and additional listeners. + * Has exec methods to get results in batches or concatenated, or to get results when there is only + * a single output (again in batches or concatenated). + *

+ * Also supports using {@link SDVariable} instances instead of variable names. + * + *

+ * Example: get the output of pred, with batches concatenated together + *

+     *     {@code
+     *     SameDiff sd = ...;
+     *     MultiDataSet data = ...;
+     *
+     *     INDArray out = sd.output()
+     *         .data(data)
+     *         .output("pred")
+     *         .execSingle();
+     *     }
+     * 
+ */ + public OutputConfig output() { + return new OutputConfig(this); + } + + /** + * Helper method to run inference. Also used for validation + */ + private List> outputHelper(MultiDataSetIterator iterator, At at, @NonNull List listeners, @NonNull String... outputs) { Preconditions.checkState(trainingConfig != null, "Training config has not been set"); - List reqVars; + List activeListeners = new ArrayList<>(); - if(outputs != null){ - reqVars = Arrays.asList(outputs); + for (Listener l : listeners) + if (l.isActive(at.operation())) + activeListeners.add(l); + + for (Listener l : this.listeners) + if (l.isActive(at.operation())) + activeListeners.add(l); + + validateListenerActivations(activeListeners, at.operation()); + + for (Listener l : activeListeners) + l.operationStart(this, at.operation()); + + boolean hasListeners = !activeListeners.isEmpty(); + + List neededOutputs; + + if (outputs != null) { + neededOutputs = Arrays.asList(outputs); } else { - reqVars = outputs(); + neededOutputs = outputs(); } + String[] neededOutputsArr = neededOutputs.toArray(new String[0]); + List> predictions = new ArrayList<>(); - if(!iterator.hasNext() && iterator.resetSupported()) + if (!iterator.hasNext() && iterator.resetSupported()) iterator.reset(); - while(iterator.hasNext()){ - MultiDataSet ds = iterator.next(); - Map placeholderMap = toPlaceholderMap(ds); + Set requiredVars = new HashSet<>(); - predictions.add(output(placeholderMap, reqVars)); + for (Listener l : activeListeners) { + if (at.operation() == Operation.TRAINING_VALIDATION) + requiredVars.addAll(l.requiredVariables(this).validationVariables()); + else + requiredVars.addAll(l.requiredVariables(this).inferenceVariables()); } + while (iterator.hasNext()) { + long dataStart = hasListeners ? System.currentTimeMillis() : 0; + MultiDataSet ds = iterator.next(); + long dataEnd = hasListeners ? 
System.currentTimeMillis() : 0; + Map placeholderMap = toPlaceholderMap(ds); + + if (hasListeners) { + + for (Listener l : activeListeners) { + l.iterationStart(this, at, ds, (dataEnd - dataStart)); + } + + Map outs = directExecHelper(placeholderMap, at, ds, requiredVars, activeListeners, neededOutputsArr); + + for (Listener l : activeListeners) { + l.iterationDone(this, at, ds, null); + } + + predictions.add(outs); + } else { + predictions.add(directExecHelper(placeholderMap, at, ds, requiredVars, activeListeners, neededOutputsArr)); + } + at.setIteration(at.iteration() + 1); + } + + + for (Listener l : activeListeners) + l.operationEnd(this, at.operation()); + return predictions; } /** - * Single output inference. - * See {@link #output(MultiDataSetIterator, String...)} + * Set up for a single batch inference operation using OutputConfig. + * Supports the setting of placeholder inputs, outputs, and additional listeners. + * Has exec methods to get the single output if only one is requested, or all requested outputs. + *

+ * Also supports using {@link SDVariable} instances instead of variable names. + *

+ * Example: get the value of "out" with placeholders x and y + *

+     *     {@code
+     *     SameDiff sd = ...;
+     *     INDArray xValue = ...;
+     *     INDArray yValue = ...;
+     *     SDVariable y = ...;
+     *
+     *     INDArray outValue = sd.batchOutput()
+     *         .output("out")
+     *         .input("x", xValue)
+     *         .input(y, yValue)
+     *         .execSingle();
+     *     }
+     * 
*/ - public INDArray outputSingle(MultiDataSetIterator dataSet, String output){ - return output(dataSet, output).get(output); + public BatchOutputConfig batchOutput() { + return new BatchOutputConfig(this); } /** - * Single batched output inference. - * See {@link #output(MultiDataSetIterator, String...)} - */ - public List outputSingleBatches(MultiDataSetIterator dataSet, String output){ - return getSingleOutput(outputBatches(dataSet, output), output); - } - - /** - * @deprecated See {@link #outputAll(Map)} + * @deprecated See {@link #outputAll(Map)} and {@link #batchOutput()} */ @Deprecated - public Map execAll(Map placeholders){ + public Map execAll(Map placeholders) { return outputAll(placeholders); } /** - * Do inference for all variables for a single batch + * Do inference for all variables for a single batch. + *

+ * See {@link #output(Map, List, String...)}. + *

+ * Special case of {@link #batchOutput()}. */ - public Map outputAll(Map placeholders){ - List allVars = new ArrayList<>(); - for(Variable v : variables.values()){ - allVars.add(v.getName()); - } - return output(placeholders, allVars.toArray(new String[0])); + public Map outputAll(Map placeholders) { + return batchOutput().outputAll().inputs(placeholders).exec(); } + /** - * @deprecated See {@link #outputSingle(Map, String)} + * @deprecated See {@link #outputSingle(Map, String)} and {@link #batchOutput()} */ @Deprecated - public INDArray execSingle(Map placeholders, String output){ + public INDArray execSingle(Map placeholders, String output) { return outputSingle(placeholders, output); } /** - * Do inference for a single variable for a single batch + * Do inference for a single variable for a single batch. + *

+ * See {@link #output(Map, List, String...)}. + *

+ * Special case of {@link #batchOutput()}. */ - public INDArray outputSingle(Map placeholders, String output){ - return output(placeholders, output).get(output); + public INDArray outputSingle(Map placeholders, String output) { + return batchOutput().output(output).inputs(placeholders).execSingle(); } + /** - * @deprecated See {@link #output(Map, List)} + * @deprecated See {@link #output(Map, List)} and {@link #batchOutput()} */ @Deprecated - public Map exec(Map placeholders, List outputs){ + public Map exec(Map placeholders, List outputs) { return output(placeholders, outputs); } /** - * Do inference for the given variables for a single batch + * Do inference for the given variables for a single batch. + *

+ * See {@link #output(Map, List, String...)}. + *

+ * Special case of {@link #batchOutput()}. */ - public Map output(Map placeholders, List outputs){ - return output(placeholders, outputs.toArray(new String[outputs.size()])); + public Map output(Map placeholders, List outputs) { + return batchOutput().output(outputs.toArray(new String[0])).inputs(placeholders).exec(); } /** - * @deprecated See {@link #output(Map, String...)} + * @deprecated See {@link #output(Map, String...)} and {@link #batchOutput()} */ @Deprecated - public Map exec(Map placeholders, String... outputs) { + public Map exec(Map placeholders, String... outputs) { return output(placeholders, outputs); } + /** + * Do inference for the given variables for a single batch. + *

+ * See {@link #output(Map, List, String...)}. + *

+ * Special case of {@link #batchOutput()}. + */ + public Map output(Map placeholders, String... outputs) { + return batchOutput().output(outputs).inputs(placeholders).exec(); + } + /** - * Do inference for the given variables for a single batch + * Do inference for the given variables for a single batch. + *

+ * Special case of {@link #batchOutput()}. + * + * @param placeholders The values to use for placeholders. + * @param listeners Additional listeners to use during this operation. + * @param outputs The variables to output and return. */ - public Map output(Map placeholders, String... outputs) { - return output(placeholders, false, null, outputs); + public Map output(Map placeholders, @NonNull List listeners, String... outputs) { + return batchOutputHelper(placeholders, listeners, outputs); + } + + protected Map batchOutputHelper(Map placeholders, @NonNull List listeners, String... outputs) { + List activeListeners = new ArrayList<>(); + + for (Listener l : this.listeners) + if (l.isActive(Operation.INFERENCE)) + activeListeners.add(l); + + for (Listener l : listeners) + if (l.isActive(Operation.INFERENCE)) + activeListeners.add(l); + + for (Listener l : activeListeners) { + l.operationStart(this, Operation.INFERENCE); + } + + validateListenerActivations(activeListeners, Operation.INFERENCE); + + Map ret = directExecHelper(placeholders, At.defaultAt(Operation.INFERENCE), null, Collections.emptyList(), activeListeners, outputs); + + for (Listener l : activeListeners) { + l.operationEnd(this, Operation.INFERENCE); + } + return ret; } /** * Do inference for the given variables for a single batch, with training information */ - protected Map output(Map placeholders, boolean training, At at, String... outputs){ + protected Map directExecHelper(Map placeholders, At at, MultiDataSet batch, + Collection requiredActivations, List activeListeners, String... 
outputs) { + if (at == null) + at = At.defaultAt(); + Preconditions.checkState(outputs != null && outputs.length > 0, "No outputs were specified"); long threadId = Thread.currentThread().getId(); - if(!sessions.containsKey(threadId)){ + if (!sessions.containsKey(threadId)) { log.info("Creating new InferenceSession for thread {}", threadId); sessions.put(threadId, new InferenceSession(this)); } List phNames = inputs(); - if(placeholders == null && phNames != null){ + if (placeholders == null && phNames != null) { //Maybe user set placeholders before calling exec method? placeholders = placeholdersPerThread.get(Thread.currentThread().getId()); } @@ -2495,15 +3139,15 @@ public class SameDiff extends SDBaseOps { //Placeholder validation is performed in InferenceSession InferenceSession is = sessions.get(threadId); - Map ret = is.output(Arrays.asList(outputs), placeholders, listeners, training, at); - return ret; + return is.output(outputs == null ? Collections.emptyList() : Arrays.asList(outputs), + placeholders, batch, requiredActivations, activeListeners, at); } - public SDVariable one(String name, int... shape){ + public SDVariable one(String name, int... shape) { return one(name, Nd4j.defaultFloatingPointType(), shape); } - public SDVariable one(String name, long... shape){ + public SDVariable one(String name, long... shape) { return one(name, Nd4j.defaultFloatingPointType(), shape); } @@ -2531,12 +3175,11 @@ public class SameDiff extends SDBaseOps { } - - public SDVariable zero(String name, long... shape){ + public SDVariable zero(String name, long... shape) { return zero(name, Nd4j.defaultFloatingPointType(), shape); } - public SDVariable zero(String name, int... shape){ + public SDVariable zero(String name, int... shape) { return zero(name, Nd4j.defaultFloatingPointType(), shape); } @@ -2565,21 +3208,23 @@ public class SameDiff extends SDBaseOps { /** * Create an SDVariable with a fixed/constant value, with a generated name
* Constants are not modified by training/backprop. See {@link VariableType} for more details. + * * @param constant Value for the constant SDVariable * @return The created variable */ - public SDVariable constant(@NonNull INDArray constant){ + public SDVariable constant(@NonNull INDArray constant) { return constant(getNewVarName(), constant); } /** * Create an SDVariable with a fixed/constant value
* Constants are not modified by training/backprop. See {@link VariableType} for more details. - * @param name Name of the constant SDVariable + * + * @param name Name of the constant SDVariable * @param constant Value for the constant SDVariable * @return The created variable */ - public SDVariable constant(String name, @NonNull INDArray constant){ + public SDVariable constant(String name, @NonNull INDArray constant) { Preconditions.checkState(!variables.containsKey(name), "Variable with name \"%s\" already exists", name); if (name == null || name.length() < 1) name = getNewVarName(); @@ -2627,7 +3272,7 @@ public class SameDiff extends SDBaseOps { * @param shape the shape of the variable if any * @return SDVariable placeholder */ - public SDVariable placeHolder(@NonNull String name, org.nd4j.linalg.api.buffer.DataType dataType, long...shape) { + public SDVariable placeHolder(@NonNull String name, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { Preconditions.checkState(!variables.containsKey(name), "Variable already exists with name %s", name); SDVariable ret = new SDVariable(name, VariableType.PLACEHOLDER, this, shape, dataType, null); variables.put(name, Variable.builder().name(name).variable(ret).build()); @@ -2649,7 +3294,7 @@ public class SameDiff extends SDBaseOps { //TODO only allowing null datatype for TF import (it's fixed in a later step) - don't want this in the public API! public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, - org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { + org.nd4j.linalg.api.buffer.DataType dataType, long... 
shape) { if (name == null || name.length() < 1) @@ -2658,7 +3303,7 @@ public class SameDiff extends SDBaseOps { name = generateNewVarName(name, 0); if (variables.containsKey(name)) { - if(nameScopes.isEmpty()){ + if (nameScopes.isEmpty()) { throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \"" + currentNameScope() + "\""); } else { @@ -2670,7 +3315,7 @@ public class SameDiff extends SDBaseOps { SDVariable ret = new SDVariable(name, variableType, this, shape, dataType, weightInitScheme); addVariable(ret); - if(variableType == VariableType.PLACEHOLDER){ + if (variableType == VariableType.PLACEHOLDER) { setOriginalPlaceHolderShape(name, shape); putShapeForVarName(name, shape); } @@ -2682,8 +3327,8 @@ public class SameDiff extends SDBaseOps { * The underlying array will be initialized using the specified weight initilization scheme
* This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details. * - * @param name the name of the variable - * @param shape the shape of the variable + * @param name the name of the variable + * @param shape the shape of the variable * @param weightInitScheme Weight initialization scheme to use to initialize the underlying array * @return the created variable */ @@ -2703,7 +3348,7 @@ public class SameDiff extends SDBaseOps { */ public SDVariable var(String name, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { Preconditions.checkNotNull(shape != null, "Invalid shape: shape may not be null"); - if(Shape.isPlaceholderShape(shape)){ + if (Shape.isPlaceholderShape(shape)) { return placeHolder(name, dataType, shape); } return var(name, new ZeroInitScheme(), dataType, shape); @@ -2732,7 +3377,7 @@ public class SameDiff extends SDBaseOps { * @param shape the shape of the variable * @return the created variable */ - public SDVariable var(String name, int... shape){ + public SDVariable var(String name, int... shape) { return var(name, Nd4j.defaultFloatingPointType(), shape); } @@ -2745,7 +3390,7 @@ public class SameDiff extends SDBaseOps { * @param shape the shape of the variable * @return the created variable */ - public SDVariable var(String name, long... shape){ + public SDVariable var(String name, long... shape) { return var(name, Nd4j.defaultFloatingPointType(), shape); } @@ -2759,7 +3404,7 @@ public class SameDiff extends SDBaseOps { */ public SDVariable var(String name, org.nd4j.linalg.api.buffer.DataType dataType, int... 
shape) { Preconditions.checkNotNull(shape, "Invalid shape: shape may not be null"); - if(Shape.isPlaceholderShape(shape)){ + if (Shape.isPlaceholderShape(shape)) { return placeHolder(name, dataType, ArrayUtil.toLongArray(shape)); } return var(name, new ZeroInitScheme(), dataType, ArrayUtil.toLongArray(shape)); @@ -2784,7 +3429,7 @@ public class SameDiff extends SDBaseOps { VariableType vt = v.getVariableType(); NDArraySupplierInitScheme s = null; - switch(vt){ + switch (vt) { case VARIABLE: s = new NDArraySupplierInitScheme(v.getArr()); //Intentional fallthrough @@ -2843,6 +3488,7 @@ public class SameDiff extends SDBaseOps { /** * Create an {@link SDVariable} with a generated name, and assocate the specified array with it.
* This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details. + * * @param arr Array to associate with the new variable * @return New SDVariable * @see #var(String, INDArray) @@ -2854,6 +3500,7 @@ public class SameDiff extends SDBaseOps { /** * Create an {@link SDVariable} with the specified name, and associate the specified array with it
* This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details. + * * @param arr Array to associate with the new variable * @return New SDVariable with the specified name and array */ @@ -2869,16 +3516,16 @@ public class SameDiff extends SDBaseOps { name = getNewVarName(); boolean duped = false; - if(arr.isAttached()) { + if (arr.isAttached()) { arr = arr.detach(); duped = true; } - if(arr.isView()) { + if (arr.isView()) { arr = arr.dup(); duped = true; } - if(!duped) { + if (!duped) { for (DeviceLocalNDArray otherArr : variablesArrays.values()) { if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour) arr = arr.dup(); @@ -2891,7 +3538,7 @@ public class SameDiff extends SDBaseOps { associateArrayWithVariable(arr, ret); if (ArrayUtil.prod(arr.shape()) == 1) { - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { ret.setScalarValue(Nd4j.scalar(arr.getDouble(0))); } } @@ -2927,17 +3574,17 @@ public class SameDiff extends SDBaseOps { * @param variables Variables to convert to constants * @return The (now constant) SDVariables */ - public void convertToConstants(List variables){ - if(variables.size() == 0) + public void convertToConstants(List variables) { + if (variables.size() == 0) return; boolean allConst = true; - for(SDVariable variable : variables) { + for (SDVariable variable : variables) { if (variable.getVariableType() != VariableType.CONSTANT) { allConst = false; Preconditions.checkState(variable.getVariableType() != VariableType.ARRAY, "Cannot convert variable of type ARRAY to a constant: %s", variable); } } - if(allConst){ + if (allConst) { return; //No op } @@ -2947,15 +3594,15 @@ public class SameDiff extends SDBaseOps { //If gradient function has been defined, remove it (so it will be recreated later) 
sameDiffFunctionInstances.remove(GRAD_FN_KEY); - for(SDVariable variable : variables ) { + for (SDVariable variable : variables) { String n = variable.getVarName(); INDArray arr = variable.getArr(); Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable); constantArrays.put(n, new DeviceLocalNDArray(arr)); variablesArrays.remove(n); - if(!placeholdersPerThread.isEmpty()){ - for(Map m : placeholdersPerThread.values()){ + if (!placeholdersPerThread.isEmpty()) { + for (Map m : placeholdersPerThread.values()) { m.remove(n); } } @@ -2968,34 +3615,34 @@ public class SameDiff extends SDBaseOps { //Remove updater state for now constant variables for (SDVariable v : variables) { GradientUpdater gu = updaterMap.remove(v.getVarName()); - Map m = gu == null ? null : gu.getState(); - if(m != null){ - for(INDArray arr : m.values()){ - if(arr.closeable()) + Map m = gu == null ? null : gu.getState(); + if (m != null) { + for (INDArray arr : m.values()) { + if (arr.closeable()) arr.close(); } } //Also check dataset feature/label mapping - remove any placeholders here... 
- if(trainingConfig.getDataSetFeatureMapping() != null && trainingConfig.getDataSetFeatureMapping().contains(v.getVarName())){ + if (trainingConfig.getDataSetFeatureMapping() != null && trainingConfig.getDataSetFeatureMapping().contains(v.getVarName())) { List newFM = new ArrayList<>(trainingConfig.getDataSetFeatureMapping()); //New list in case of immutable list newFM.remove(v.getVarName()); trainingConfig.setDataSetFeatureMapping(newFM); } - if(trainingConfig.getDataSetLabelMapping() != null && trainingConfig.getDataSetLabelMapping().contains(v.getVarName())){ + if (trainingConfig.getDataSetLabelMapping() != null && trainingConfig.getDataSetLabelMapping().contains(v.getVarName())) { List newLM = new ArrayList<>(trainingConfig.getDataSetLabelMapping()); newLM.remove(v.getVarName()); trainingConfig.setDataSetLabelMapping(newLM); } - if(trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().contains(v.getVarName())){ + if (trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().contains(v.getVarName())) { List newFMM = new ArrayList<>(trainingConfig.getDataSetFeatureMaskMapping()); newFMM.remove(v.getVarName()); trainingConfig.setDataSetFeatureMaskMapping(newFMM); } - if(trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().contains(v.getVarName())){ + if (trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().contains(v.getVarName())) { List newLMM = new ArrayList<>(trainingConfig.getDataSetLabelMaskMapping()); newLMM.remove(v.getVarName()); trainingConfig.setDataSetLabelMaskMapping(newLMM); @@ -3025,17 +3672,17 @@ public class SameDiff extends SDBaseOps { * As variables, this variable will modified during any subsequent training.
* See also: {@link VariableType} */ - public void convertToVariables(@NonNull List constants){ - if(constants.size() == 0) + public void convertToVariables(@NonNull List constants) { + if (constants.size() == 0) return; boolean allConst = true; - for(SDVariable variable : constants) { + for (SDVariable variable : constants) { if (variable.getVariableType() != VariableType.VARIABLE) { allConst = false; } Preconditions.checkState(variable.getVariableType() != VariableType.ARRAY, "Cannot convert variable of type ARRAY to a variable: %s", variable); } - if(allConst){ + if (allConst) { return; //No op } @@ -3045,15 +3692,15 @@ public class SameDiff extends SDBaseOps { //If gradient function has been defined, remove it (so it will be recreated later) sameDiffFunctionInstances.remove(GRAD_FN_KEY); - for(SDVariable variable : constants) { + for (SDVariable variable : constants) { String n = variable.getVarName(); INDArray arr = variable.getArr(); Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable); variablesArrays.put(n, new DeviceLocalNDArray(arr)); constantArrays.remove(n); - if(!placeholdersPerThread.isEmpty()){ - for(Map m : placeholdersPerThread.values()){ + if (!placeholdersPerThread.isEmpty()) { + for (Map m : placeholdersPerThread.values()) { m.remove(n); } } @@ -3075,7 +3722,7 @@ public class SameDiff extends SDBaseOps { GradientUpdater u = trainingConfig.getUpdater().instantiate(stateArr, true); updaterMap.put(v.getVarName(), u); } else { - GradientUpdater u = trainingConfig.getUpdater().instantiate((INDArray)null, true); + GradientUpdater u = trainingConfig.getUpdater().instantiate((INDArray) null, true); updaterMap.put(v.getVarName(), u); } } @@ -3096,18 +3743,18 @@ public class SameDiff extends SDBaseOps { * * @param dataTypeMap Map of SDVariables to change the datatype for. 
Key = SDVariable name, Value = new datatype */ - public void convertDataTypes(@NonNull Map dataTypeMap){ - if(dataTypeMap.isEmpty()) + public void convertDataTypes(@NonNull Map dataTypeMap) { + if (dataTypeMap.isEmpty()) return; //First: check these are all either constants, variables or placeholders. - for(Map.Entry e : dataTypeMap.entrySet()){ + for (Map.Entry e : dataTypeMap.entrySet()) { String s = e.getKey(); Preconditions.checkState(variables.containsKey(s), "Cannot change datatype of variable \"%s\": No variable with this name exists", s); SDVariable v = variables.get(s).getVariable(); Preconditions.checkState(v.getVariableType() != VariableType.ARRAY, "Cannot change datatype of ARRAY type variable \"%s\": " + "datatype of ARRAY type variables is determined by the datatypes of their inputs plus corresponding "); - if(v.getVariableType() != VariableType.PLACEHOLDER){ + if (v.getVariableType() != VariableType.PLACEHOLDER) { //Can't convert constant or variable between numerical and non-numerical type (not possible to cast) Preconditions.checkState(v.dataType().isNumerical() == e.getValue().isNumerical(), "Cannot convert variables between numerical " + "and non-numerical types: attempting to convert variable \"%s\" from %s to %s", e.getKey(), v.dataType(), e.getValue()); @@ -3115,16 +3762,16 @@ public class SameDiff extends SDBaseOps { } boolean anyChanged = false; - for(Map.Entry e : dataTypeMap.entrySet()){ + for (Map.Entry e : dataTypeMap.entrySet()) { String s = e.getKey(); DataType d = e.getValue(); SDVariable v = variables.get(s).getVariable(); - if(v.dataType() == d) + if (v.dataType() == d) continue; //No-op v.setDataType(d); - switch (v.getVariableType()){ + switch (v.getVariableType()) { case VARIABLE: DeviceLocalNDArray dl = variablesArrays.remove(e.getKey()); INDArray arr = dl.get(); @@ -3138,8 +3785,8 @@ public class SameDiff extends SDBaseOps { constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2)); break; case PLACEHOLDER: - Map m = 
placeholdersPerThread.get(Thread.currentThread().getId()); - if(m != null && m.containsKey(e.getKey())){ + Map m = placeholdersPerThread.get(Thread.currentThread().getId()); + if (m != null && m.containsKey(e.getKey())) { m.put(e.getKey(), m.get(e.getKey()).castTo(d)); } break; @@ -3152,7 +3799,7 @@ public class SameDiff extends SDBaseOps { anyChanged = true; } - if(anyChanged){ + if (anyChanged) { sessions.clear(); //Recalculate datatypes of outputs, and dynamically update them @@ -3166,61 +3813,61 @@ public class SameDiff extends SDBaseOps { * @param from The variable to rename - this variable must exist * @param to The new name for the variable - no variable with this name must already exist */ - public void renameVariable(String from, String to){ + public void renameVariable(String from, String to) { Preconditions.checkState(variables.containsKey(from), "Cannot rename variable \"%s\": no variable with this name exists", from); Preconditions.checkState(!variables.containsKey(to), "Cannot rename variable \"%s\" to name \"%s\": a variable with name \"%s\" already exists", from, to, to); Variable v = variables.get(from); v.setName(to); v.getVariable().setVarName(to); - if(v.getInputsForOp() != null){ - for(String opName : v.getInputsForOp()){ + if (v.getInputsForOp() != null) { + for (String opName : v.getInputsForOp()) { SameDiffOp op = ops.get(opName); List newInputs = new ArrayList<>(op.getInputsToOp()); - while(newInputs.contains(from)){ + while (newInputs.contains(from)) { newInputs.set(newInputs.indexOf(from), to); } op.setInputsToOp(newInputs); } } - if(v.getControlDepsForOp() != null){ - for(String opName : v.getControlDepsForOp()){ + if (v.getControlDepsForOp() != null) { + for (String opName : v.getControlDepsForOp()) { SameDiffOp op = ops.get(opName); List newCDs = new ArrayList<>(op.getControlDeps()); - while(newCDs.contains(from)){ + while (newCDs.contains(from)) { newCDs.set(newCDs.indexOf(from), to); } op.setControlDeps(newCDs); } } - 
if(v.getControlDepsForVar() != null){ - for(String varName : v.getControlDepsForVar()){ + if (v.getControlDepsForVar() != null) { + for (String varName : v.getControlDepsForVar()) { Variable var = variables.get(varName); List newCDs = new ArrayList<>(var.getControlDeps()); - while(newCDs.contains(from)){ + while (newCDs.contains(from)) { newCDs.set(newCDs.indexOf(from), to); } var.setControlDeps(newCDs); } } - if(v.getControlDeps() != null){ - for(String varName : v.getControlDeps()){ + if (v.getControlDeps() != null) { + for (String varName : v.getControlDeps()) { Variable var = variables.get(varName); List newCDsFor = new ArrayList<>(var.getControlDepsForVar()); - while(newCDsFor.contains(from)){ + while (newCDsFor.contains(from)) { newCDsFor.set(newCDsFor.indexOf(from), to); } var.setControlDepsForVar(newCDsFor); } } - if(v.getOutputOfOp() != null){ + if (v.getOutputOfOp() != null) { SameDiffOp op = ops.get(v.getOutputOfOp()); List newOuts = new ArrayList<>(op.getOutputsOfOp()); - while(newOuts.contains(from)){ + while (newOuts.contains(from)) { newOuts.set(newOuts.indexOf(from), to); } op.setOutputsOfOp(newOuts); @@ -3229,50 +3876,50 @@ public class SameDiff extends SDBaseOps { variables.remove(from); variables.put(to, v); - if(trainingConfig != null){ - if(trainingConfig.getDataSetFeatureMapping() != null && trainingConfig.getDataSetFeatureMapping().contains(from)){ + if (trainingConfig != null) { + if (trainingConfig.getDataSetFeatureMapping() != null && trainingConfig.getDataSetFeatureMapping().contains(from)) { List l = new ArrayList<>(trainingConfig.getDataSetFeatureMapping()); - while(l.contains(from)){ + while (l.contains(from)) { l.set(l.indexOf(from), to); } trainingConfig.setDataSetFeatureMapping(l); } - if(trainingConfig.getDataSetLabelMapping() != null && trainingConfig.getDataSetLabelMapping().contains(from)){ + if (trainingConfig.getDataSetLabelMapping() != null && trainingConfig.getDataSetLabelMapping().contains(from)) { List l = new 
ArrayList<>(trainingConfig.getDataSetLabelMapping()); - while(l.contains(from)){ + while (l.contains(from)) { l.set(l.indexOf(from), to); } trainingConfig.setDataSetLabelMapping(l); } - if(trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().contains(from)){ + if (trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().contains(from)) { List l = new ArrayList<>(trainingConfig.getDataSetFeatureMaskMapping()); - while(l.contains(from)){ + while (l.contains(from)) { l.set(l.indexOf(from), to); } trainingConfig.setDataSetFeatureMaskMapping(l); } - if(trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().contains(from)){ + if (trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().contains(from)) { List l = new ArrayList<>(trainingConfig.getDataSetLabelMaskMapping()); - while(l.contains(from)){ + while (l.contains(from)) { l.set(l.indexOf(from), to); } trainingConfig.setDataSetLabelMaskMapping(l); } - if(trainingConfig.getLossVariables() != null && trainingConfig.getLossVariables().contains(from)){ + if (trainingConfig.getLossVariables() != null && trainingConfig.getLossVariables().contains(from)) { List l = new ArrayList<>(trainingConfig.getLossVariables()); - while(l.contains(from)){ + while (l.contains(from)) { l.set(l.indexOf(from), to); } trainingConfig.setLossVariables(l); } } - for(SameDiff sd : sameDiffFunctionInstances.values()){ - if(sd.hasVariable(from)){ + for (SameDiff sd : sameDiffFunctionInstances.values()) { + if (sd.hasVariable(from)) { sd.renameVariable(from, to); } } @@ -3320,7 +3967,7 @@ public class SameDiff extends SDBaseOps { return v == null ? 
null : v.getVariable(); } - public boolean hasVariable(String name){ + public boolean hasVariable(String name) { return variables.containsKey(name); } @@ -3347,7 +3994,7 @@ public class SameDiff extends SDBaseOps { //Gradients are being placed in the inner "grad" function SameDiff instance, but not the outer one if (variables.containsKey(varName) && variables.get(varName).getGradient() != null) { return variables.get(varName).getGradient(); - } else if(sameDiffFunctionInstances.containsKey(GRAD_FN_KEY) && sameDiffFunctionInstances.get(GRAD_FN_KEY).variables.containsKey(varName)){ + } else if (sameDiffFunctionInstances.containsKey(GRAD_FN_KEY) && sameDiffFunctionInstances.get(GRAD_FN_KEY).variables.containsKey(varName)) { return sameDiffFunctionInstances.get(GRAD_FN_KEY).variables.get(varName).getGradient(); } return null; @@ -3364,10 +4011,10 @@ public class SameDiff extends SDBaseOps { * @param varName Name of the variable to check the existence of a gradient variable for * @return True if a gradient variable exists for the specified variable, for the current loss */ - public boolean variableHasGradient(String varName){ + public boolean variableHasGradient(String varName) { Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", varName); SDVariable v = getVariable(varName); - if(!v.dataType().isFPType() || v.isConstant()) + if (!v.dataType().isFPType() || v.isConstant()) return false; return getGradForVariable(varName) != null; @@ -3399,7 +4046,7 @@ public class SameDiff extends SDBaseOps { /** * Get the gradient for the variable with the specified variable name. - * Note that in order to run this function, {@link #execBackwards()} must be executed first. + * Note that in order to run this function, {@link #execBackwards(Map, Operation, MultiDataSet, Collection, List)} must be executed first. * All gradient functions are obtained from the results of the execBackwards call. 
* * @param varName the variable name to get the gradient variable for. @@ -3418,30 +4065,33 @@ public class SameDiff extends SDBaseOps { /** * Create a new double scalar (rank 0) SDVariable with the specified value + * * @param name Name of the SDVariable * @param value Value to initialize the variable with * @return SDVariable */ public SDVariable scalar(String name, double value) { - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return var(name, Nd4j.scalar(value)); } } /** * Create a new float scalar (rank 0) SDVariable with the specified value + * * @param name Name of the SDVariable * @param value Value to initialize the variable with * @return SDVariable */ public SDVariable scalar(String name, float value) { - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return var(name, Nd4j.scalar(value)); } } /** * Create a new integer scalar (rank 0) SDVariable with the specified value + * * @param name Name of the SDVariable * @param value Value to initialize the variable with * @return SDVariable @@ -3454,12 +4104,13 @@ public class SameDiff extends SDBaseOps { /** * Create a new long scalar (rank 0) SDVariable with the specified value + * * @param name Name of the SDVariable * @param value Value to initialize the variable with * @return SDVariable */ public SDVariable scalar(String name, long value) { - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return var(name, Nd4j.scalar(value)); } } @@ -3473,7 +4124,7 @@ public class SameDiff extends SDBaseOps { * @return SDVariable */ public SDVariable scalar(String name, DataType dataType, Number value) { - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + try (MemoryWorkspace ws = 
Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return var(name, Nd4j.scalar(dataType, value)); } } @@ -3481,21 +4132,23 @@ public class SameDiff extends SDBaseOps { /** * Create a new double scalar constant (rank 0) with the specified value.
* Constants are not modified by training/backprop. See {@link VariableType} for more details. + * * @param value Value to initialize the constant with * @return SDVariable */ - public SDVariable constant(double value){ + public SDVariable constant(double value) { return constant(null, value); } /** * Create a new double scalar constant (rank 0) with the specified value + * * @param name Name of the SDVariable * @param value Value to initialize the constant with * @return SDVariable */ public SDVariable constant(String name, double value) { - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return constant(name, Nd4j.scalar(value)); } } @@ -3503,6 +4156,7 @@ public class SameDiff extends SDBaseOps { /** * Create a new float scalar constant (rank 0) with the specified value
* Constants are not modified by training/backprop. See {@link VariableType} for more details. + * * @param value Value to initialize the constant with * @return SDVariable */ @@ -3512,18 +4166,20 @@ public class SameDiff extends SDBaseOps { /** * Create a new float scalar constant (rank 0) with the specified value + * * @param name Name of the SDVariable * @param value Value to initialize the constant with * @return SDVariable */ public SDVariable constant(String name, float value) { - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return constant(name, Nd4j.scalar(value)); } } /** * Create a new integer scalar constant (rank 0) with the specified value + * * @param value Value to initialize the constant with */ public SDVariable constant(int value) { @@ -3532,6 +4188,7 @@ public class SameDiff extends SDBaseOps { /** * Create a new integer scalar constant (rank 0) with the specified value + * * @param name Name of the SDVariable * @param value Value to initialize the constant with * @return SDVariable @@ -3544,6 +4201,7 @@ public class SameDiff extends SDBaseOps { /** * Create a new long scalar constant (rank 0) with the specified value + * * @param value Value to initialize the constant with */ public SDVariable constant(long value) { @@ -3552,11 +4210,12 @@ public class SameDiff extends SDBaseOps { /** * Create a new long scalar constant (rank 0) with the specified value + * * @param name Name of the SDVariable * @param value Value to initialize the constant with */ public SDVariable constant(String name, long value) { - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return constant(name, Nd4j.scalar(value)); } } @@ -3569,13 +4228,14 @@ public class SameDiff extends SDBaseOps { * @param value Value to initialize the constant with */ public SDVariable 
constant(String name, DataType dataType, Number value) { - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { return constant(name, Nd4j.scalar(dataType, value)); } } /** * Add the specified variable to this SameDiff instance + * * @param variable Variable to add */ public SDVariable addVariable(SDVariable variable) { @@ -3614,7 +4274,7 @@ public class SameDiff extends SDBaseOps { //are not available - *except for sometimes during import, until all ops/variables have been added* List outputDataTypes = null; - if(!isImport) { + if (!isImport) { List inputDataTypes = new ArrayList<>(); List fnInputs = ops.get(function.getOwnName()).getInputsToOp(); if (fnInputs != null) { @@ -3655,8 +4315,8 @@ public class SameDiff extends SDBaseOps { //Generate new variable name if one with the specified name doesn't exist //Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme - org.nd4j.linalg.api.buffer.DataType dataType = isImport ? null : outputDataTypes.get(i); - var = var(generateNewVarName(baseName, i), VariableType.ARRAY, null, dataType, (long[])null); + org.nd4j.linalg.api.buffer.DataType dataType = isImport ? null : outputDataTypes.get(i); + var = var(generateNewVarName(baseName, i), VariableType.ARRAY, null, dataType, (long[]) null); } var.setOutputIndex(i); var.setCreator(function); @@ -3681,14 +4341,14 @@ public class SameDiff extends SDBaseOps { } if (checkGet == null) { //Note: output of an op is ARRAY type - activations, not a trainable parameter. 
Thus has no weight init scheme - org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0); - checkGet = var(baseName, VariableType.ARRAY, null, dataType, (long[])null); + org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0); + checkGet = var(baseName, VariableType.ARRAY, null, dataType, (long[]) null); } if (checkGet == null) { //Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme - org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0); - checkGet = var(baseName, VariableType.ARRAY, null, dataType, (long[])null); + org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0); + checkGet = var(baseName, VariableType.ARRAY, null, dataType, (long[]) null); } checkGet.setOutputIndex(0); @@ -3838,11 +4498,11 @@ public class SameDiff extends SDBaseOps { } @Deprecated - public INDArray execAndEndResult(){ + public INDArray execAndEndResult() { List outputs = outputs(); Preconditions.checkState(outputs.size() == 1, "Method can only be used with SameDiff instances with a single output"); long tid = Thread.currentThread().getId(); - Map placeholders = placeholdersPerThread.get(tid); + Map placeholders = placeholdersPerThread.get(tid); return execSingle(placeholders, outputs.get(0)); } @@ -3851,24 +4511,37 @@ public class SameDiff extends SDBaseOps { * After execution, the gradient arrays can be accessed using {@code myVariable.getGradient().getArr()}
* Note: This method by default calculates VARIABLE type SDVariable gradients only (as well as any other * gradients needed to calculate the variable gradients). That is, placeholder, constant, etc gradients are not - * calculated. If these gradients are required, they can be calculated using {@link #execBackwards(Map, List)} instead, + * calculated. If these gradients are required, they can be calculated using {@link #execBackwards(Map, List, Operation, MultiDataSet, Collection, List)} instead, * which allows specifying the set of SDVariables to calculate the gradients for. For example, * {@code execBackwards(placeholders, Arrays.asList(myPlaceholder.gradient().getVarName())}. In some cases, * {@link #createGradFunction()} may need to be called first * * @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map */ - public void execBackwards(Map placeholders) { + public void execBackwards(Map placeholders, Operation op) { + execBackwards(placeholders, op, null, Collections.emptyList(), Collections.emptyList()); + } + + /** + * See {@link #execBackwards(Map, Operation)}. + *

+ * Uses {@link Operation#INFERENCE}. + */ + public void execBackwards(Map placeholders) { + execBackwards(placeholders, Operation.INFERENCE); + } + + protected void execBackwards(Map placeholders, Operation op, MultiDataSet batch, Collection requiredActivations, List activeListeners) { if (getFunction(GRAD_FN_KEY) == null) { createGradFunction(); } //Collect (unique) list of gradient names... Set varGradNames = new HashSet<>(); - for(Variable v : variables.values()){ - if(v.getVariable().getVariableType() == VariableType.VARIABLE){ + for (Variable v : variables.values()) { + if (v.getVariable().getVariableType() == VariableType.VARIABLE) { SDVariable g = v.getVariable().gradient(); - if(g != null) { + if (g != null) { //Not all variables can have gradients... for example: suppose graph has 2 independent loss functions, // optimizing only 1 might not require changing all variables varGradNames.add(g.getVarName()); @@ -3877,34 +4550,59 @@ public class SameDiff extends SDBaseOps { } //Also add loss values - we need these so we can report them to listeners... - if(!listeners.isEmpty()){ + if (!listeners.isEmpty()) { varGradNames.addAll(lossVariables); } //Edge case: if no variables, no variable gradients to calculate... - if(varGradNames.isEmpty()){ + if (varGradNames.isEmpty()) { log.warn("Skipping gradient execution (backward pass) - no variables to be calculated (graph does not contain any VARIABLE type SDVariables).\n" + "If gradients for other variables (such as placeholders) are required, use execBackwards(Map, List) instead"); - return; } List vargradNamesList = new ArrayList<>(varGradNames); - execBackwards(placeholders, vargradNamesList); - } - - public void execBackwards(Map placeholders, String... 
variableGradNamesList){ - execBackwards(placeholders, Arrays.asList(variableGradNamesList)); + execBackwards(placeholders, vargradNamesList, op, batch, requiredActivations, activeListeners); } /** - * As per {@link #execBackwards(Map)}, but the set of gradients to calculate can be specified manually.
+ * See {@link #execBackwards(Map, List, Operation)} + */ + public Map execBackwards(Map placeholders, Operation op, String... variableGradNamesList) { + return execBackwards(placeholders, Arrays.asList(variableGradNamesList), op, null, Collections.emptyList(), Collections.emptyList()); + } + + /** + * See {@link #execBackwards(Map, Operation, String...)}. + *

+ * Uses {@link Operation#INFERENCE}. + */ + public Map execBackwards(Map placeholders, String... variableGradNamesList) { + return execBackwards(placeholders, Operation.INFERENCE, variableGradNamesList); + } + + /** + * As per {@link #execBackwards(Map, Operation, MultiDataSet, Collection, List)}, but the set of gradients to calculate can be specified manually.
* For example, to calculate the gradient for placeholder variable "myPlaceholder", use * {@code execBackwards(placeholders, Arrays.asList(myPlaceholder.gradient().getVarName())}. * - * @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map + * @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map * @param variableGradNamesList Names of the gradient variables to calculate */ - public void execBackwards(Map placeholders, List variableGradNamesList){ + public Map execBackwards(Map placeholders, List variableGradNamesList, Operation operation) { + return execBackwards(placeholders, variableGradNamesList, operation, null, Collections.emptyList(), Collections.emptyList()); + } + + /** + * See {@link #execBackwards(Map, List, Operation)}. + *

+ * Uses {@link Operation#INFERENCE}. + */ + public Map execBackwards(Map placeholders, List variableGradNamesList) { + return execBackwards(placeholders, variableGradNamesList, Operation.INFERENCE); + } + + protected Map execBackwards(Map placeholders, List variableGradNamesList, Operation operation, + MultiDataSet batch, Collection requiredActivations, List activeListeners) { if (getFunction(GRAD_FN_KEY) == null) { createGradFunction(); } @@ -3912,35 +4610,36 @@ public class SameDiff extends SDBaseOps { log.trace("About to execute backward function"); //Edge case: if no variables, no variable gradients to calculate... - if(variableGradNamesList.isEmpty()){ + if (variableGradNamesList.isEmpty()) { log.warn("Skipping gradient calculation (backward pass) - no variables to be calculated (variableGradNamesList is empty)"); - return; + return Collections.emptyMap(); } SameDiff sd = sameDiffFunctionInstances.get(GRAD_FN_KEY); - sd.listeners = listeners; + sd.listeners.clear(); + sd.listeners.addAll(activeListeners); - At at = new At(0, 0, 0, Thread.currentThread().getId()); - if(trainingConfig != null){ + At at = new At(0, 0, 0, Thread.currentThread().getId(), operation); + if (trainingConfig != null) { at.setIteration(trainingConfig.getIterationCount()); at.setEpoch(trainingConfig.getEpochCount()); } - //TODO is this 'train' flag the best approach? 
- sd.output(placeholders, trainingConfig != null, at, variableGradNamesList.toArray(new String[variableGradNamesList.size()])); + return sd.directExecHelper(placeholders, at, batch, requiredActivations, activeListeners, variableGradNamesList.toArray(new String[0])); } /** * Returns true if the gradient function has been created - i.e., {@link #createGradFunction()} or {@link #createGradFunction(String...)} * has been called at all + * * @return True if gradient (backprop) function exists */ - public boolean hasGradientFunction(){ + public boolean hasGradientFunction() { return sameDiffFunctionInstances.containsKey(GRAD_FN_KEY); } /** - * Create the gradient function (for calculating gradients via {@link #execBackwards(Map)}) if it is not already defined. + * Create the gradient function (for calculating gradients via {@link #execBackwards(Map, Operation, String[])}) if it is not already defined. * Users do not usually need to call this function manually, as it is called as required in the aforementioned method. *

* If the gradient function already exists, this method is a no-op.
@@ -3949,7 +4648,7 @@ public class SameDiff extends SDBaseOps { * Note that the gradient array (after execBackwards has been called) can be accessed via {@code SDVariable.gradient().getArr()} */ public void createGradFunction() { - createGradFunction((String[])null); + createGradFunction((String[]) null); } /** @@ -3963,15 +4662,15 @@ public class SameDiff extends SDBaseOps { * be calculated and available after backprop has been done */ public void createGradFunction(final String... variablesRequiringGradients) { - if(lossVariables.isEmpty()){ - if(trainingConfig != null && trainingConfig.getLossVariables() != null && !trainingConfig.getLossVariables().isEmpty()){ + if (lossVariables.isEmpty()) { + if (trainingConfig != null && trainingConfig.getLossVariables() != null && !trainingConfig.getLossVariables().isEmpty()) { lossVariables.addAll(trainingConfig.getLossVariables()); } else { List outputs = outputs(); if (outputs.size() == 1) { String outName = outputs.get(0); String opName = variables.get(outName).getOutputOfOp(); - if(opName == null || !(ops.get(opName).getOp() instanceof ExternalErrorsFunction)){ + if (opName == null || !(ops.get(opName).getOp() instanceof ExternalErrorsFunction)) { log.info("Inferring output \"{}\" as loss variable as none were previously set. Use SameDiff.setLossVariables() to override", outputs.get(0)); } lossVariables.add(outputs.get(0)); @@ -3988,9 +4687,9 @@ public class SameDiff extends SDBaseOps { log.trace("Defining function \"grad\""); } - if(variablesRequiringGradients != null && variablesRequiringGradients.length > 0){ + if (variablesRequiringGradients != null && variablesRequiringGradients.length > 0) { //Check that they are FP variables... 
- for(String s : variablesRequiringGradients){ + for (String s : variablesRequiringGradients) { Preconditions.checkArgument(variables.containsKey(s), "Cannot ensure gradient exists for variable: no variable with name \"%s\" exists", s); DataType dt = variables.get(s).getVariable().dataType(); Preconditions.checkState(dt.isFPType(), "Cannot ensure gradient exists for variable \"%s\": variable is not a floating point SDVariable." + @@ -4032,7 +4731,6 @@ public class SameDiff extends SDBaseOps { */ - final SameDiff outer = this; defineFunction(GRAD_FN_KEY, new SameDiffFunctionDefinition() { @@ -4071,7 +4769,7 @@ public class SameDiff extends SDBaseOps { List finalOutputs = new ArrayList<>(lossVariables.size()); SDVariable initialGrad = sameDiff.var("one-var", Nd4j.scalar(1.0f)); - for(String s : lossVariables){ + for (String s : lossVariables) { Preconditions.checkNotNull(s, "Encountered null value in loss variables. Null loss variables are not allowed." + " Use SameDiff.setLossVariables with non-null array names to fix"); Preconditions.checkState(variables.containsKey(s), "Specified loss function variable \"%s\" does not exist", s); @@ -4079,12 +4777,12 @@ public class SameDiff extends SDBaseOps { Preconditions.checkState(v.dataType().isFPType(), "Specified loss function variable \"%s\" is not a floating" + "point variable (datatype: %s). Only floating point variables may be used as loss function variable", s, v.dataType()); v = v.sum(); //If output is not a scalar: we'll use loss = v.sum(), same as adding loss for multiple outputs. 
We don't always know for sure if output is scalar at this point - if(v.dataType() == initialGrad.dataType()){ + if (v.dataType() == initialGrad.dataType()) { sameDiff.setGradientForVariableName(v.getVarName(), initialGrad); } else { sameDiff.setGradientForVariableName(v.getVarName(), initialGrad.castTo(v.dataType())); } - if(finalOutputs.contains(v)){ + if (finalOutputs.contains(v)) { log.warn("Loss function variable \"{}\" appears multiple times in list of loss variables - using only first instance", s); } else { finalOutputs.add(v); @@ -4102,26 +4800,26 @@ public class SameDiff extends SDBaseOps { // Find all FP variables that are connected to loss by an floating point (FP16/32/64) path Set allFpVarsConnectedToLoss = new HashSet<>(); Queue toProcess = new LinkedList<>(); - for(String s : lossVariables){ - if(!toProcess.contains(s)){ + for (String s : lossVariables) { + if (!toProcess.contains(s)) { toProcess.add(s); } } - while(!toProcess.isEmpty()){ + while (!toProcess.isEmpty()) { String next = toProcess.remove(); - if(!allFpVarsConnectedToLoss.contains(next)){ + if (!allFpVarsConnectedToLoss.contains(next)) { Variable v = variables.get(next); - if(v.getVariable().dataType().isFPType()){ + if (v.getVariable().dataType().isFPType()) { allFpVarsConnectedToLoss.add(v.getName()); //Work out what op (if any) this is an output of... 
and add the inputs to that op to be processed - if(v.getOutputOfOp() != null){ + if (v.getOutputOfOp() != null) { String opName = v.getOutputOfOp(); SameDiffOp op = ops.get(opName); List opInputs = op.getInputsToOp(); - if(opInputs != null){ - for(String s : opInputs){ + if (opInputs != null) { + for (String s : opInputs) { Variable inputVar = variables.get(s); - if(inputVar.getVariable().dataType().isFPType()){ + if (inputVar.getVariable().dataType().isFPType()) { //Add this connected floating point type to the list to be processed toProcess.add(s); } @@ -4136,36 +4834,36 @@ public class SameDiff extends SDBaseOps { // Keep removing leaf nodes until only Variable type SDVariables remain Set minimalSubgraphVars = new HashSet<>(allFpVarsConnectedToLoss); Queue leafFPVars = new LinkedList<>(); - for(String s : allFpVarsConnectedToLoss){ + for (String s : allFpVarsConnectedToLoss) { //First: determine if is a FP leaf (Array type SDVariable) Variable v = variables.get(s); - if(v.getVariable().getVariableType() == VariableType.ARRAY){ + if (v.getVariable().getVariableType() == VariableType.ARRAY) { String opName = v.getOutputOfOp(); //Always defined for array type SameDiffOp op = ops.get(opName); List inputsToOp = op.getInputsToOp(); boolean anyInputsInSubgraph = false; - if(inputsToOp != null){ - for(String s2 : inputsToOp){ - if(allFpVarsConnectedToLoss.contains(s2)){ + if (inputsToOp != null) { + for (String s2 : inputsToOp) { + if (allFpVarsConnectedToLoss.contains(s2)) { //Connection s2 -> s exists... 
therefore s is not a leaf (yet) anyInputsInSubgraph = true; break; } } } - if(!anyInputsInSubgraph){ + if (!anyInputsInSubgraph) { //Mark s as a leaf to be removed leafFPVars.add(s); } } VariableType vt = v.getVariable().getVariableType(); boolean isUserRequested = variablesRequiringGradients != null && ArrayUtils.contains(variablesRequiringGradients, s); - if((vt == VariableType.CONSTANT || vt == VariableType.PLACEHOLDER) && !isUserRequested ){ + if ((vt == VariableType.CONSTANT || vt == VariableType.PLACEHOLDER) && !isUserRequested) { leafFPVars.add(s); } } - while(!leafFPVars.isEmpty()){ + while (!leafFPVars.isEmpty()) { String nextLeaf = leafFPVars.remove(); Variable v = variables.get(nextLeaf); minimalSubgraphVars.remove(nextLeaf); @@ -4176,24 +4874,24 @@ public class SameDiff extends SDBaseOps { //Note that any time we remove a variable, the only possible new leafs are those that this one // is connected to. List inputsTo = v.getInputsForOp(); - if( inputsTo != null && !inputsTo.isEmpty()) { + if (inputsTo != null && !inputsTo.isEmpty()) { for (String opName : inputsTo) { SameDiffOp op = ops.get(opName); List inputsToOp = op.getInputsToOp(); boolean anyPresent = false; - for(String s : inputsToOp){ - if(minimalSubgraphVars.contains(s) || (variablesRequiringGradients != null && ArrayUtils.contains(variablesRequiringGradients, s))){ + for (String s : inputsToOp) { + if (minimalSubgraphVars.contains(s) || (variablesRequiringGradients != null && ArrayUtils.contains(variablesRequiringGradients, s))) { //Note second condition: means user explicitly specified that they want gradients for that input variable... hence we need to diff this op anyPresent = true; break; } } - if(!anyPresent){ + if (!anyPresent) { //All inputs to op X are not in subgraph. 
Therefore outputs of op must be new leaves List outVars = op.getOutputsOfOp(); - if(outVars != null) { + if (outVars != null) { for (String s : outVars) { - if(!leafFPVars.contains(s)){ + if (!leafFPVars.contains(s)) { //Mark this variable to be processed next leafFPVars.add(s); } @@ -4209,9 +4907,9 @@ public class SameDiff extends SDBaseOps { //At this point: we know the set of variables that are connected to the loss - these all (and only) need gradients Queue availableForDiff = new LinkedList<>(); - for(SDVariable lossVar : finalOutputs){ + for (SDVariable lossVar : finalOutputs) { Variable v = sameDiff.variables.get(lossVar.getVarName()); - if(v.getOutputOfOp() != null){ + if (v.getOutputOfOp() != null) { String opName = v.getOutputOfOp(); availableForDiff.add(opName); } @@ -4222,28 +4920,28 @@ public class SameDiff extends SDBaseOps { //For example, if we have X -> op -> Y, and Y -> (A,B) we need gradient contribution from BOTH // Y->A and Y->B connections before we can do differentiation of op "op" final HashMap> prerequisites = new HashMap<>(); //Key: variable name. 
Value: list of op names - for(String var : minimalSubgraphVars){ + for (String var : minimalSubgraphVars) { Variable variable = variables.get(var); // Copy the collection, as the original one will be modified during backprop final List inputsForOp = variable.getInputsForOp(); if (inputsForOp != null) { List req = new ArrayList<>(); - for(String opName : inputsForOp){ + for (String opName : inputsForOp) { //Need to filter ops here //For example, if we have: var -> Op1, and var -> Op2 //we might not need to differentiate Op2 if output of Op2 doesn't impact loss function SameDiffOp o = ops.get(opName); List opOutputs = o.getOutputsOfOp(); boolean anyOpOutputsRequired = false; - if(opOutputs != null) { + if (opOutputs != null) { for (String s : opOutputs) { - if(minimalSubgraphVars.contains(s)) { + if (minimalSubgraphVars.contains(s)) { anyOpOutputsRequired = true; break; } } } - if(anyOpOutputsRequired){ + if (anyOpOutputsRequired) { req.add(opName); } } @@ -4252,14 +4950,14 @@ public class SameDiff extends SDBaseOps { } Set differentiatedOps = new HashSet<>(); - while(!availableForDiff.isEmpty()){ + while (!availableForDiff.isEmpty()) { String dfName = availableForDiff.remove(); DifferentialFunction df = sameDiff.ops.get(dfName).getOp(); //Get the inputs and outputs of the op List inputsToOp; List outputsOfOp; - if(df instanceof GradientBackwardsMarker){ + if (df instanceof GradientBackwardsMarker) { SameDiffOp op = sameDiff.ops.get(df.getOwnName()); inputsToOp = op.getInputsToOp(); outputsOfOp = Collections.emptyList(); @@ -4271,18 +4969,18 @@ public class SameDiff extends SDBaseOps { //Get gradients for all output variables: List grads = new ArrayList<>(); - for(String s : outputsOfOp){ + for (String s : outputsOfOp) { SDVariable v = sameDiff.getVariable(s); SDVariable g = v.hasGradient() ? 
v.gradient() : null; - if(g == null){ + if (g == null) { //If no gradient exists at this point, 3 possibilities: // (a) we have a bug // (b) output of this op isn't used in calculating the loss // (c) output isn't a FP type //In the FP case, we should create a zero variable to backprop, because we can't perform backprop // for this op otherwise... - if(!v.dataType().isFPType()){ + if (!v.dataType().isFPType()) { grads.add(null); } else { //See "Step 3: Differentiate ops in minimal subgraph" above for explanation on why this should be zerosLike here... @@ -4299,10 +4997,10 @@ public class SameDiff extends SDBaseOps { differentiatedOps.add(df.getOwnName()); //Check the inputs to this op, see if we can differentiate those ops now (and if so: add to queue) - for(String s : inputsToOp){ + for (String s : inputsToOp) { Variable v = sameDiff.variables.get(s); String opName = v.getOutputOfOp(); - if(opName == null || differentiatedOps.contains(opName)){ + if (opName == null || differentiatedOps.contains(opName)) { //Skip placeholder/constant etc; also skip if we've previously differentiated this op continue; } @@ -4316,23 +5014,23 @@ public class SameDiff extends SDBaseOps { boolean isRequiredOp = false; SameDiffOp op = ops.get(opName); - if(op.getInputsToOp() != null){ + if (op.getInputsToOp() != null) { List opInputs = op.getInputsToOp(); boolean anyInputsRequired = false; - for(String s2 : opInputs){ - if(minimalSubgraphVars.contains(s2)){ + for (String s2 : opInputs) { + if (minimalSubgraphVars.contains(s2)) { anyInputsRequired = true; break; } } - if(anyInputsRequired){ - if(!differentiatedOps.contains(op.getName())){ + if (anyInputsRequired) { + if (!differentiatedOps.contains(op.getName())) { isRequiredOp = true; } } } - if(!isRequiredOp){ + if (!isRequiredOp) { continue; } @@ -4346,20 +5044,20 @@ public class SameDiff extends SDBaseOps { boolean allAvailable = true; SameDiffOp o = sameDiff.ops.get(opName); - for(String opOutput : o.getOutputsOfOp()){ + for (String 
opOutput : o.getOutputsOfOp()) { Variable outVar = variables.get(opOutput); - if(outVar.getVariable().dataType().isFPType()){ - if(minimalSubgraphVars.contains(outVar.getName())){ + if (outVar.getVariable().dataType().isFPType()) { + if (minimalSubgraphVars.contains(outVar.getName())) { //Need gradient for this variable to be available before we can differentiate - if(outVar.getVariable().gradient() == null){ + if (outVar.getVariable().gradient() == null) { allAvailable = false; break; } //However, when a variable is used multiple times, we need ALL gradient contributions available: List prereqs = prerequisites.get(outVar.getName()); - if(prereqs != null){ + if (prereqs != null) { allAvailable &= differentiatedOps.containsAll(prereqs); - if(!allAvailable) + if (!allAvailable) break; } } @@ -4367,19 +5065,19 @@ public class SameDiff extends SDBaseOps { } } - if(allAvailable && !availableForDiff.contains(o.getOp().getOwnName())){ + if (allAvailable && !availableForDiff.contains(o.getOp().getOwnName())) { availableForDiff.add(o.getOp().getOwnName()); } } } //Let's validate we actually differentiated everything correctly: - for(String s : minimalSubgraphVars){ - if(lossVariables.contains(s)) + for (String s : minimalSubgraphVars) { + if (lossVariables.contains(s)) continue; SDVariable v = variables.get(s).getVariable(); SDVariable g = v.gradient(); - if(g == null){ + if (g == null) { throw new IllegalStateException("Error encountered during differentiation: no gradient for required variable \"" + s + "\" was calculated"); } } @@ -4430,7 +5128,7 @@ public class SameDiff extends SDBaseOps { /** * Get the original shape for the vertex id if one was set (other wise returns null).
* This is mainly for use in validating passed in arrays as arguments to {@link #resolveVariablesWith(Map)} - * usually when executing using {@link #execWithPlaceHolder(Map)} + * usually when executing using {@link #execAll(Map)} * * @param varName the vertex id to get the original shape for. * @return the set vertex @@ -4460,7 +5158,7 @@ public class SameDiff extends SDBaseOps { * @param arrays the arrays to resolve. */ public void resolveVariablesWith(Map arrays) { - for (Map.Entry e : arrays.entrySet()) { + for (Map.Entry e : arrays.entrySet()) { SDVariable varForName = getVariable(e.getKey()); if (varForName == null) { throw new ND4JIllegalStateException("A placeholder array was provided for variable with name \"" + e.getKey() + @@ -4468,7 +5166,7 @@ public class SameDiff extends SDBaseOps { } Variable v = variables.get(e.getKey()); - if(varForName.getVariableType() == VariableType.PLACEHOLDER){ + if (varForName.getVariableType() == VariableType.PLACEHOLDER) { //Check shape: long[] shape = varForName.placeholderShape(); long[] newShape = e.getValue().shape(); @@ -4513,7 +5211,7 @@ public class SameDiff extends SDBaseOps { throw new NullPointerException("Null input: No variable found for updating!"); } - if(newVarName != null) { + if (newVarName != null) { String nameScope = currentNameScope(); if (nameScope != null) { if (!newVarName.startsWith(nameScope + "/")) { @@ -4522,7 +5220,7 @@ public class SameDiff extends SDBaseOps { } } - if(newVarName != null && variables.containsKey(newVarName) && varToUpdate != variables.get(newVarName).getVariable()){ + if (newVarName != null && variables.containsKey(newVarName) && varToUpdate != variables.get(newVarName).getVariable()) { throw new IllegalStateException("Variable name \"" + newVarName + "\" already exists for a different SDVariable"); } @@ -4577,12 +5275,12 @@ public class SameDiff extends SDBaseOps { * as "grad" - the backward function) we have the correct SameDiff instance set for all ops/SDVariables.
* If this is not done, arrays and shapes could be fetched from the incorrect SameDiff instance for some methods */ - protected void associateSameDiffWithOpsAndVariables(){ - for(SDVariable var : variableMap().values()){ + protected void associateSameDiffWithOpsAndVariables() { + for (SDVariable var : variableMap().values()) { var.setSameDiff(this); } // for(DifferentialFunction df : functionInstancesById.values()){ - for(SameDiffOp op : ops.values()){ + for (SameDiffOp op : ops.values()) { DifferentialFunction df = op.getOp(); df.setSameDiff(this); @@ -4592,15 +5290,15 @@ public class SameDiff extends SDBaseOps { // to another SameDiff instance. At which point, they could fetch shapes and arrays from some other instance // (i.e., not from this one that is currently executing) SDVariable[] args = df.args(); - if(args != null){ - for(SDVariable arg : args){ + if (args != null) { + for (SDVariable arg : args) { arg.setSameDiff(this); } } SDVariable[] outputs = df.outputVariables(); - if(outputs != null){ - for(SDVariable out : outputs){ + if (outputs != null) { + for (SDVariable out : outputs) { out.setSameDiff(this); } } @@ -4625,7 +5323,7 @@ public class SameDiff extends SDBaseOps { 0, 0, -1, - 0, 0, 0, 0,0, 0); + 0, 0, 0, 0, 0, 0); return flatNode; } @@ -4669,8 +5367,8 @@ public class SameDiff extends SDBaseOps { //log.info("Exporting node: [{}:<{}> ; OpType: {}; Hash/opNum: {}]", node.opName(), node.tensorflowName(), node.opType(), hash); double[] extras; - if(node.opType() == Op.Type.CUSTOM){ - CustomOp op = (CustomOp)node; + if (node.opType() == Op.Type.CUSTOM) { + CustomOp op = (CustomOp) node; extras = op.tArgs(); } else { Object[] eArgs = node.getExtraArgs(); @@ -4734,7 +5432,7 @@ public class SameDiff extends SDBaseOps { for (SDVariable input : inputs) { String varName = input.getVarName(); int outIdx; - if(this.variables.get(varName).getOutputOfOp() != null){ + if (this.variables.get(varName).getOutputOfOp() != null) { DifferentialFunction df = 
ops.get(this.variables.get(varName).getOutputOfOp()).getOp(); outIdx = ops.get(df.getOwnName()).getOutputsOfOp().indexOf(varName); } else { @@ -4759,21 +5457,21 @@ public class SameDiff extends SDBaseOps { log.trace("Own Name: {}", node.getOwnName()); int ownId = id != null ? id : idCounter.incrementAndGet(); //forwardMap.containsKey(node.getOwnName()) ? forwardMap.get(node.getOwnName()) : idCounter.incrementAndGet(); String[] outNames = node.outputVariablesNames(); - for(String s : outNames){ - if(!reverseMap.containsKey(s)){ + for (String s : outNames) { + if (!reverseMap.containsKey(s)) { reverseMap.put(s, ownId); } } int[] dims; - if(node.opType() == Op.Type.REDUCE_FLOAT || node.opType() == Op.Type.REDUCE_SAME || node.opType() == Op.Type.REDUCE_BOOL || node.opType() == Op.Type.REDUCE_LONG || node.opType() == Op.Type.INDEXREDUCE || node.opType() == Op.Type.REDUCE3){ + if (node.opType() == Op.Type.REDUCE_FLOAT || node.opType() == Op.Type.REDUCE_SAME || node.opType() == Op.Type.REDUCE_BOOL || node.opType() == Op.Type.REDUCE_LONG || node.opType() == Op.Type.INDEXREDUCE || node.opType() == Op.Type.REDUCE3) { dims = node.getDimensions(); - if(dims == null) + if (dims == null) dims = new int[0]; } else { dims = new int[0]; } - Map fnProps = node.propertiesForFunction(); + Map fnProps = node.propertiesForFunction(); int[] flatProperties = FlatBuffersMapper.mapFunctionPropertiesToFlatProperties(bufferBuilder, fnProps); int propIdx = FlatNode.createPropertiesVector(bufferBuilder, flatProperties); @@ -4787,10 +5485,10 @@ public class SameDiff extends SDBaseOps { int fname = bufferBuilder.createString(node.getOwnName()); int scopeName = bufferBuilder.createString(""); int scalar = 0; - if(node instanceof ScalarOp){ - ScalarOp sOp = (ScalarOp)node; + if (node instanceof ScalarOp) { + ScalarOp sOp = (ScalarOp) node; INDArray s = sOp.scalar(); - if(s != null){ + if (s != null) { scalar = s.toFlatArray(bufferBuilder); } } @@ -4802,7 +5500,7 @@ public class SameDiff extends 
SDBaseOps { List outVarNames = node.getSameDiff().ops.get(node.getOwnName()).getOutputsOfOp(); int[] outVarNamesStringsOffsets = new int[outVarNames == null ? 0 : outVarNames.size()]; - for( int i=0; i(); int idx = 0; - val idxForOps = new IdentityHashMap(); + val idxForOps = new IdentityHashMap(); List allVars = variables(); for (SDVariable variable : allVars) { INDArray arr = variable.getArr(); @@ -4890,10 +5588,10 @@ public class SameDiff extends SDBaseOps { String varName = variable.getVarName(); int varIdx; int outputNum; - if(variables.get(varName).getOutputOfOp() != null){ + if (variables.get(varName).getOutputOfOp() != null) { //This variable is the output of a node DifferentialFunction df = ops.get(variables.get(varName).getOutputOfOp()).getOp(); - if(!idxForOps.containsKey(df)){ + if (!idxForOps.containsKey(df)) { varIdx = idCounter.incrementAndGet(); idxForOps.put(df, varIdx); } else { @@ -4916,7 +5614,7 @@ public class SameDiff extends SDBaseOps { int array = 0; int id = IntPair.createIntPair(bufferBuilder, varIdx, outputNum); byte varType = (byte) variable.getVariableType().ordinal(); - if(variable.isConstant() || variable.isPlaceHolder() || variable.getVariableType() == VariableType.VARIABLE) { + if (variable.isConstant() || variable.isPlaceHolder() || variable.getVariableType() == VariableType.VARIABLE) { //Don't export array type (i.e., activations), these are always replaced/re-calculated on each step array = arr == null ? 
0 : arr.toFlatArray(bufferBuilder); } @@ -4926,12 +5624,12 @@ public class SameDiff extends SDBaseOps { shape = FlatVariable.createShapeVector(bufferBuilder, shp); } - int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape, array, -1, varType); + int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape, array, -1, varType); flatVariables.add(flatVariable); } //add functions - for(SameDiffOp op : ops.values()){ + for (SameDiffOp op : ops.values()) { DifferentialFunction func = op.getOp(); Integer fnId = idxForOps.get(func); flatNodes.add(asFlatNode(func, bufferBuilder, variableList, reverseMap, forwardMap, framesMap, idCounter, fnId)); @@ -4939,7 +5637,7 @@ public class SameDiff extends SDBaseOps { // we're dumping scopes now for (Map.Entry scope : sameDiffFunctionInstances.entrySet()) { - if(scope.getKey().equalsIgnoreCase(GRAD_FN_KEY)){ + if (scope.getKey().equalsIgnoreCase(GRAD_FN_KEY)) { //Skip the gradient function for export continue; } @@ -4962,13 +5660,13 @@ public class SameDiff extends SDBaseOps { log.trace("Adding [{}] as [{}]", pair.getFirst(), idx); - byte varType = (byte)node.getVariableType().ordinal(); - int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(arr.dataType()),0, array, -1, varType); + byte varType = (byte) node.getVariableType().ordinal(); + int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(arr.dataType()), 0, array, -1, varType); flatVariables.add(flatVariable); } //add functions - for(SameDiffOp op : scope.getValue().ops.values()){ + for (SameDiffOp op : scope.getValue().ops.values()) { DifferentialFunction func = op.getOp(); flatNodes.add(asFlatNode(func, bufferBuilder, currVarList, reverseMap, forwardMap, framesMap, idCounter, null)); } @@ -4979,17 
+5677,17 @@ public class SameDiff extends SDBaseOps { int nodesOffset = FlatGraph.createNodesVector(bufferBuilder, Ints.toArray(flatNodes)); int numPlaceholders = 0; - for(SDVariable v : variables()){ - if(v.isPlaceHolder()){ + for (SDVariable v : variables()) { + if (v.isPlaceHolder()) { numPlaceholders++; } } int[] placeholderOffsets = new int[numPlaceholders]; - if(numPlaceholders > 0){ - int i=0; - for(SDVariable v : variables()){ - if(!v.isPlaceHolder()) + if (numPlaceholders > 0) { + int i = 0; + for (SDVariable v : variables()) { + if (!v.isPlaceHolder()) continue; placeholderOffsets[i++] = bufferBuilder.createString(v.getVarName()); } @@ -4998,30 +5696,30 @@ public class SameDiff extends SDBaseOps { List lossVars = getLossVariables(); int[] lossVarOffsets = new int[lossVars == null ? 0 : lossVars.size()]; - for( int i=0; i g : updaterMap.entrySet()){ + for (Map.Entry g : updaterMap.entrySet()) { int paramNameOffset = bufferBuilder.createString(g.getKey()); int stateKeyOffset = 0; int stateValuesOffset = 0; - Map state = g.getValue().getState(); - if(state != null && !state.isEmpty()){ + Map state = g.getValue().getState(); + if (state != null && !state.isEmpty()) { int[] keysOffsets = new int[state.size()]; int[] valuesOffsets = new int[state.size()]; - int i=0; - for(Map.Entry e : state.entrySet()){ + int i = 0; + for (Map.Entry e : state.entrySet()) { keysOffsets[i] = bufferBuilder.createString(e.getKey()); valuesOffsets[i] = e.getValue().toFlatArray(bufferBuilder); i++; @@ -5041,7 +5739,7 @@ public class SameDiff extends SDBaseOps { bufferBuilder.finish(fg); synchronized (this) { - for(Map.Entry e : reverseMap.entrySet()){ + for (Map.Entry e : reverseMap.entrySet()) { this.variables.get(e.getKey()).setVariableIndex(e.getValue()); } } @@ -5194,7 +5892,7 @@ public class SameDiff extends SDBaseOps { /** * This method converts SameDiff instance to FlatBuffers and saves it to file which can be restored later * - * @param file File to save the FlatBuffers 
serialized graph (including arrays) to + * @param file File to save the FlatBuffers serialized graph (including arrays) to * @param includeUpdaterState If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc) */ public void asFlatFile(@NonNull File file, @NonNull ExecutorConfiguration configuration, boolean includeUpdaterState) throws IOException { @@ -5212,6 +5910,7 @@ public class SameDiff extends SDBaseOps { /** * Create a {@link SameDiff} instance from a file, including the updater state * The method to save the file is {@link #save(File, boolean)} + * * @param file the file to load from * @return the loaded same diff instance * @throws IOException @@ -5223,7 +5922,8 @@ public class SameDiff extends SDBaseOps { /** * Create a {@link SameDiff} instance from a file, optionally also loading the updater state * The method to save the file is {@link #save(File, boolean)} - * @param file the file to load from + * + * @param file the file to load from * @param loadUpdaterState If true, load the updater state (Adam etc state). For training, use true. For inference, use false * @return the loaded same diff instance * @throws IOException @@ -5242,6 +5942,7 @@ public class SameDiff extends SDBaseOps { * Create a {@link SameDiff} * instance from a byte buffers * instance. 
+ * * @param bbIn the input byte buffer * @return the created samediff instance * @throws IOException @@ -5257,11 +5958,11 @@ public class SameDiff extends SDBaseOps { int numOps = fg.nodesLength(); int numVars = fg.variablesLength(); List ops = new ArrayList<>(numOps); - for( int i=0; i vars = new ArrayList<>(numVars); - for( int i = 0; i < numVars; i++) { + for (int i = 0; i < numVars; i++) { vars.add(fg.variables(i)); } @@ -5277,18 +5978,18 @@ public class SameDiff extends SDBaseOps { //Reconstruct placeholders int numPlaceholders = fg.placeholdersLength(); Set ph = new LinkedHashSet<>(); - for(int i=0; i varNodeIds = new HashMap<>(); - Map, SDVariable> variablesByNodeAndOutNum = new HashMap<>(); - Map> variablesByName = new HashMap<>(); - for(FlatVariable v : vars){ + Map varNodeIds = new HashMap<>(); + Map, SDVariable> variablesByNodeAndOutNum = new HashMap<>(); + Map> variablesByName = new HashMap<>(); + for (FlatVariable v : vars) { int shapeLength = v.shapeLength(); long[] shape = new long[shapeLength]; - for( int i = 0; i < shapeLength; i++) { + for (int i = 0; i < shapeLength; i++) { shape[i] = v.shape(i); } @@ -5305,9 +6006,9 @@ public class SameDiff extends SDBaseOps { FlatArray fa = v.ndarray(); - if(fa != null && vt != VariableType.ARRAY){ + if (fa != null && vt != VariableType.ARRAY) { INDArray arr; - try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { arr = Nd4j.createFromFlatArray(fa); } sd.setArrayForVariable(n, arr); @@ -5316,7 +6017,7 @@ public class SameDiff extends SDBaseOps { IntPair id = v.id(); //First value: node (op) id. 
Second: output number variablesByNodeAndOutNum.put(new Pair<>(id.first(), id.second()), var); - if(!variablesByName.containsKey(n)){ + if (!variablesByName.containsKey(n)) { variablesByName.put(n, new ArrayList()); } @@ -5325,12 +6026,12 @@ public class SameDiff extends SDBaseOps { } //Reconstruct ops: - for(FlatNode fn : ops){ + for (FlatNode fn : ops) { DifferentialFunction df = FlatBuffersMapper.fromFlatNode(fn); String name = fn.name(); df.setSameDiff(sd); df.setOwnName(name); - if(sd.ops.containsKey(name)){ + if (sd.ops.containsKey(name)) { sd.ops.get(name).setOp(df); } else { sd.ops.put(name, SameDiffOp.builder().name(name).op(df).build()); @@ -5338,7 +6039,7 @@ public class SameDiff extends SDBaseOps { int outLength = fn.outputLength(); int[] outs = new int[outLength]; - for( int i=0; i> intPairList = new ArrayList<>(); + List> intPairList = new ArrayList<>(); for (int i = 0; i < inputPaired.length; i++) { inputPaired[i] = fn.inputPaired(i); intPairList.add(new Pair<>(inputPaired[i].first(), inputPaired[i].second())); } String[] inputNames = new String[inputPaired.length]; - for(int i=0; i(nodeId, nodeOutNum)); - if(varIn == null){ + if (varIn == null) { //The variable corresponding to this op was not } inputNames[i] = varIn.getVarName(); @@ -5373,12 +6074,12 @@ public class SameDiff extends SDBaseOps { sd.ops.get(df.getOwnName()).setInputsToOp(Arrays.asList(inputNames)); //Record that input variables are input to this op - for(String inName : inputNames) { + for (String inName : inputNames) { Variable v = sd.getVariables().get(inName); - if(v.getInputsForOp() == null){ + if (v.getInputsForOp() == null) { v.setInputsForOp(new ArrayList()); } - if(!v.getInputsForOp().contains(df.getOwnName())){ + if (!v.getInputsForOp().contains(df.getOwnName())) { v.getInputsForOp( ).add(df.getOwnName()); @@ -5391,14 +6092,14 @@ public class SameDiff extends SDBaseOps { //In theory, we can reconstruct the output variables (minus names) if we know the number of op outputs 
//And we can calculate the op outputs - in most cases - after the op has been created and parameters set int numOutputs = df.getNumOutputs(); - if(numOutputs <= 0){ + if (numOutputs <= 0) { numOutputs = fn.outputLength(); } String[] varNames = null; - if(varsForOp != null && varsForOp.size() == numOutputs){ + if (varsForOp != null && varsForOp.size() == numOutputs) { varNames = new String[varsForOp.size()]; - for( int i=0; i can only be VARIABLE type SDVariable var = new SDVariable(n, VariableType.VARIABLE, sd, null, null, null); sd.variables.put(n, Variable.builder().name(n).variable(var).build()); @@ -5423,28 +6124,28 @@ public class SameDiff extends SDBaseOps { //Check the op mapping int he variablesByNodeAndOutputNum //For multi-output ops, variables will have their own index, not related to the op index - for( int i=0; i p = new Pair<>(opId, i); - if(!variablesByNodeAndOutNum.containsKey(p)){ + for (int i = 0; i < varNames.length; i++) { + Pair p = new Pair<>(opId, i); + if (!variablesByNodeAndOutNum.containsKey(p)) { variablesByNodeAndOutNum.put(p, sd.getVariable(varNames[i])); } } } //Reconstruct loss variables - if(fg.lossVariablesLength() > 0){ - for(int i=0; i 0) { + for (int i = 0; i < fg.lossVariablesLength(); i++) { sd.addLossVariable(fg.lossVariables(i)); } } //Reconstruct training config String tc = fg.trainingConfig(); - if(tc != null){ + if (tc != null) { sd.trainingConfig = TrainingConfig.fromJson(tc); } - if(loadUpdaterState) { + if (loadUpdaterState) { //Reconstruct updater state if (fg.updaterStateLength() > 0) { sd.updaterMap = new HashMap<>(); @@ -5489,32 +6190,32 @@ public class SameDiff extends SDBaseOps { for (int e = 0; e < graph.variablesLength(); e++) { FlatVariable var = graph.variables(e); INDArray ndarray = null; - try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { FlatArray fa = var.ndarray(); - if(fa != null) { + if (fa != 
null) { ndarray = Nd4j.createFromFlatArray(fa); } } sb.append(var.id().first()) .append(":<").append(var.name()).append("> "); - if(ndarray == null){ + if (ndarray == null) { sb.append("").append("; Values: ").append("").append(";\n"); } else { sb.append(Arrays.toString(ndarray.shapeInfoDataBuffer().asInt())).append("; Values: "); - if(ndarray.data() == null){ + if (ndarray.data() == null) { //Empty array sb.append(""); - } else if(ndarray.dataType() == DataType.UTF8) { + } else if (ndarray.dataType() == DataType.UTF8) { sb.append(""); } else { - if(ndarray.length() < 50){ - sb.append(Arrays.toString(ndarray.data().asFloat()).replaceAll(" ","")); + if (ndarray.length() < 50) { + sb.append(Arrays.toString(ndarray.data().asFloat()).replaceAll(" ", "")); } else { //Array is too long - only tak. last few values... sb.append("["); - for( int i=0; i<50; i++ ){ - if(i > 0) + for (int i = 0; i < 50; i++) { + if (i > 0) sb.append(","); sb.append(ndarray.data().getFloat(i)); } @@ -5612,7 +6313,7 @@ public class SameDiff extends SDBaseOps { int maxLengthOfName = 8; //Length of "- Name -" for (String s : varMap.keySet()) { String outputOf = null; - for(SameDiffOp op : ops.values()){ + for (SameDiffOp op : ops.values()) { List outputsOfOp = op.getOutputsOfOp(); if (outputsOfOp != null && outputsOfOp.contains(s)) { outputOf = op.getName(); @@ -5641,10 +6342,10 @@ public class SameDiff extends SDBaseOps { String arrayShape = "-"; if (arr != null) { arrayShape = Arrays.toString(arr.shape()); - } else if(varMap.get(s).isPlaceHolder()){ + } else if (varMap.get(s).isPlaceHolder()) { SDVariable v = varMap.get(s); long[] phShape = v.placeholderShape(); - if(phShape != null){ + if (phShape != null) { arrayShape = Arrays.toString(phShape); } } @@ -5722,22 +6423,23 @@ public class SameDiff extends SDBaseOps { } - public Map calculateOutputDataTypes() { + public Map calculateOutputDataTypes() { return calculateOutputDataTypes(false); } - public Map calculateOutputDataTypes(boolean 
dynamicUpdate){ + public Map calculateOutputDataTypes(boolean dynamicUpdate) { List allVars = new ArrayList<>(variables.keySet()); DataTypesSession session = new DataTypesSession(this, dynamicUpdate); - Map phValues = new HashMap<>(); - for(Variable v : variables.values()){ - if(v.getVariable().isPlaceHolder()){ + Map phValues = new HashMap<>(); + for (Variable v : variables.values()) { + if (v.getVariable().isPlaceHolder()) { org.nd4j.linalg.api.buffer.DataType dt = v.getVariable().dataType(); Preconditions.checkNotNull(dt, "Placeholder variable %s has null datatype", v.getName()); phValues.put(v.getName(), dt); } } - Map out = session.output(allVars, phValues, null, false, null); + Map out = session.output(allVars, phValues, null, + Collections.emptyList(), Collections.emptyList(), At.defaultAt(Operation.INFERENCE)); return out; } @@ -5746,17 +6448,17 @@ public class SameDiff extends SDBaseOps { * Creates a new discinct block name from baseName. * Block names are used by If and While */ - public String newBlockName(String baseName){ + public String newBlockName(String baseName) { - if(baseName == null) + if (baseName == null) return null; - if(!blockNames.contains(baseName)){ + if (!blockNames.contains(baseName)) { blockNames.add(baseName); return baseName; } else { int i = 1; - while(blockNames.contains(baseName + "_" + i)){ + while (blockNames.contains(baseName + "_" + i)) { i++; } blockNames.add(baseName + "_" + i); @@ -5770,68 +6472,68 @@ public class SameDiff extends SDBaseOps { * @param graphFile The text or binary file containing the graph * @return The imported graph */ - public static SameDiff importFrozenTF(File graphFile){ + public static SameDiff importFrozenTF(File graphFile) { return TFGraphMapper.getInstance().importGraph(graphFile); } /** * See {@link #importFrozenTF(File)} */ - public static SameDiff importFrozenTF(GraphDef graphDef){ + public static SameDiff importFrozenTF(GraphDef graphDef) { return 
TFGraphMapper.getInstance().importGraph(graphDef); } /** * See {@link #importFrozenTF(File)} - * + *

* Again, the input can be text or binary. */ - public static SameDiff importFrozenTF(InputStream graph){ + public static SameDiff importFrozenTF(InputStream graph) { return TFGraphMapper.getInstance().importGraph(graph); } /** * Generate a new, distinct op name of the form <base>_#. - * + *

* Applies name scope if active. * - * @param base The base name to use + * @param base The base name to use * @param force Whether to force the result name to be the same as base. */ - public String getOpName(String base, boolean force){ + public String getOpName(String base, boolean force) { base = nameWithScope(base); - if(force && ops.containsKey(base)) + if (force && ops.containsKey(base)) throw new IllegalArgumentException("Op with name \"" + base + "\" already exists"); - else if(force) + else if (force) return base; int start = 1; // if we already have a name like "op_2", start from trying "op_3" - if(base.contains("_")){ + if (base.contains("_")) { // extract number used to generate base Matcher num = Pattern.compile("(.*)_(\\d+)").matcher(base); // extract argIndex used to generate base - if(num.find()) { + if (num.find()) { start = Integer.parseInt(num.group(2)); base = num.group(1); } } String name = base; - for(int i = start ; true ; i++) { + for (int i = start; true; i++) { // ensure that there are no variables that look like they are outputs of this op boolean varWithName = false; - for(String varName : variables.keySet()) - if(varName.startsWith(name + ":") || varName.equals(name)) + for (String varName : variables.keySet()) + if (varName.startsWith(name + ":") || varName.equals(name)) varWithName = true; - if(!ops.containsKey(name) && !varWithName) + if (!ops.containsKey(name) && !varWithName) break; name = base + "_" + i; @@ -5843,51 +6545,50 @@ public class SameDiff extends SDBaseOps { * See {@link #getOpName(String, boolean)} * force is false */ - public String getOpName(String base){ + public String getOpName(String base) { return getOpName(base, false); } /** * Generate a new, distinct variable name of the form <base>_#[:#]. - * + *

* Applies name scopes if active. * - * @param base The base of the name. - * @param argIndex The argument index, used in the ":#". A value of 0 (or negative) does not include the ":#" part. + * @param base The base of the name. + * @param argIndex The argument index, used in the ":#". A value of 0 (or negative) does not include the ":#" part. * @param existingOp Whether to generate an distinct operation name from base (if false), or just use base (if true). */ - public String generateNewVarName(String base, int argIndex, boolean existingOp){ + public String generateNewVarName(String base, int argIndex, boolean existingOp) { base = nameWithScope(base); - if(argIndex > 0 && base.contains(":")){ + if (argIndex > 0 && base.contains(":")) { Matcher num = Pattern.compile("(.*):(\\d+)").matcher(base); // extract argIndex used to generate base - if(num.find()) { + if (num.find()) { argIndex = Integer.parseInt(num.group(2)) + 1; base = num.group(1); } } - if(!existingOp) + if (!existingOp) base = getOpName(base); - if(argIndex > 0) + if (argIndex > 0) base += ":" + argIndex; - if(variables.containsKey(base)) + if (variables.containsKey(base)) throw new IllegalArgumentException("Variable with name \"" + base + "\" already exists"); return base; } /** - * * See {@link #generateNewVarName(String, int, boolean)} * existingOp is true. 
*/ @Override - public String generateNewVarName(String base, int argIndex){ + public String generateNewVarName(String base, int argIndex) { return generateNewVarName(base, argIndex, true); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java index cbb71199c..353c2d1e1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java @@ -18,7 +18,9 @@ package org.nd4j.autodiff.samediff; import lombok.*; import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.listeners.ListenerEvaluations; import org.nd4j.base.Preconditions; +import org.nd4j.evaluation.IEvaluation; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.regularization.L1Regularization; import org.nd4j.linalg.learning.regularization.L2Regularization; @@ -62,6 +64,12 @@ public class TrainingConfig { private int iterationCount; private int epochCount; + private Map> trainEvaluations = new HashMap<>(); + private Map trainEvaluationLabels = new HashMap<>(); + + private Map> validationEvaluations = new HashMap<>(); + private Map validationEvaluationLabels = new HashMap<>(); + /** * Create a training configuration suitable for training a single input, single output network.
* See also the {@link Builder} for creating a TrainingConfig @@ -106,6 +114,17 @@ public class TrainingConfig { this.lossVariables = lossVariables; } + protected TrainingConfig(IUpdater updater, List regularization, boolean minimize, List dataSetFeatureMapping, List dataSetLabelMapping, + List dataSetFeatureMaskMapping, List dataSetLabelMaskMapping, List lossVariables, + Map> trainEvaluations, Map trainEvaluationLabels, + Map> validationEvaluations, Map validationEvaluationLabels){ + this(updater, regularization, minimize, dataSetFeatureMapping, dataSetLabelMapping, dataSetFeatureMaskMapping, dataSetLabelMaskMapping, lossVariables); + this.trainEvaluations = trainEvaluations; + this.trainEvaluationLabels = trainEvaluationLabels; + this.validationEvaluations = validationEvaluations; + this.validationEvaluationLabels = validationEvaluationLabels; + } + /** * Increment the iteration count by 1 */ @@ -146,6 +165,12 @@ public class TrainingConfig { private boolean skipValidation = false; private boolean markLabelsUnused = false; + private Map> trainEvaluations = new HashMap<>(); + private Map trainEvaluationLabels = new HashMap<>(); + + private Map> validationEvaluations = new HashMap<>(); + private Map validationEvaluationLabels = new HashMap<>(); + /** * Set the updater (such as {@link org.nd4j.linalg.learning.config.Adam}, {@link org.nd4j.linalg.learning.config.Nesterovs} * etc. This is also how the learning rate (or learning rate schedule) is set. @@ -327,7 +352,7 @@ public class TrainingConfig { * Set the name of the placeholders/variables that should be set using the feature mask INDArray(s) from the * DataSet or MultiDataSet. For example, if the network had 2 mask variables called "mask1" and "mask2" * and the MultiDataSet features masks should be mapped with {@code MultiDataSet.getFeatureMaskArray(0)->"mask1"} - * and {@code MultiDataSet.getFeatureMaskArray(1)->"mask2"}, then this should be set to {@code "mask2", "mask2"}. 
+ * and {@code MultiDataSet.getFeatureMaskArray(1)->"mask2"}, then this should be set to {@code "mask1", "mask2"}. * * @param dataSetFeatureMaskMapping Name of the variables/placeholders that the feature arrays should be mapped to */ @@ -347,7 +372,7 @@ public class TrainingConfig { * Set the name of the placeholders/variables that should be set using the label mask INDArray(s) from the * DataSet or MultiDataSet. For example, if the network had 2 mask variables called "mask1" and "mask2" * and the MultiDataSet label masks should be mapped with {@code MultiDataSet.getLabelMaskArray(0)->"mask1"} - * and {@code MultiDataSet.getLabelMaskArray(1)->"mask2"}, then this should be set to {@code "mask2", "mask2"}. + * and {@code MultiDataSet.getLabelMaskArray(1)->"mask2"}, then this should be set to {@code "mask1", "mask2"}. * * @param dataSetLabelMaskMapping Name of the variables/placeholders that the feature arrays should be mapped to */ @@ -366,6 +391,104 @@ public class TrainingConfig { return this; } + private void addEvaluations(boolean validation, @NonNull Map> evaluationMap, @NonNull Map labelMap, + @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations){ + if(evaluationMap.containsKey(variableName) && labelMap.get(variableName) != labelIndex){ + String s; + + if(validation){ + s = "This ListenerEvaluations.Builder already has validation evaluations for "; + } else { + s = "This ListenerEvaluations.Builder already has train evaluations for "; + } + + throw new IllegalArgumentException(s + "variable " + + variableName + " with label index " + labelIndex + ". You can't add " + + " evaluations with a different label index. 
Got label index " + labelIndex); + } + + if(evaluationMap.containsKey(variableName)){ + evaluationMap.get(variableName).addAll(Arrays.asList(evaluations)); + } else { + evaluationMap.put(variableName, Arrays.asList(evaluations)); + labelMap.put(variableName, labelIndex); + } + } + + /** + * Add requested History training evaluations for a parm/variable. + * + * These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit. + * + * @param variableName The variable to evaluate + * @param labelIndex The index of the label to evaluate against + * @param evaluations The evaluations to run + */ + public Builder trainEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations){ + addEvaluations(false, this.trainEvaluations, this.trainEvaluationLabels, variableName, + labelIndex, evaluations); + return this; + } + + /** + * Add requested History training evaluations for a parm/variable. + * + * These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit. + * + * @param variable The variable to evaluate + * @param labelIndex The index of the label to evaluate against + * @param evaluations The evaluations to run + */ + public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){ + return trainEvaluation(variable.getVarName(), labelIndex, evaluations); + } + + /** + * Add requested History validation evaluations for a parm/variable. + * + * These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit. + * + * @param variableName The variable to evaluate + * @param labelIndex The index of the label to evaluate against + * @param evaluations The evaluations to run + */ + public Builder validationEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... 
evaluations){ + addEvaluations(true, this.validationEvaluations, this.validationEvaluationLabels, variableName, + labelIndex, evaluations); + return this; + } + + /** + * Add requested History validation evaluations for a parm/variable. + * + * These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit. + * + * @param variable The variable to evaluate + * @param labelIndex The index of the label to evaluate against + * @param evaluations The evaluations to run + */ + public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){ + return validationEvaluation(variable.getVarName(), labelIndex, evaluations); + } + + /** + * Add requested evaluations for a parm/variable, for either training or validation. + * + * These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit. + * + * @param validation Whether to add these evaluations as validation or training + * @param variableName The variable to evaluate + * @param labelIndex The index of the label to evaluate against + * @param evaluations The evaluations to run + */ + public Builder addEvaluations(boolean validation, @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations){ + if(validation){ + return validationEvaluation(variableName, labelIndex, evaluations); + } else{ + return trainEvaluation(variableName, labelIndex, evaluations); + } + } + public TrainingConfig build(){ if(!skipValidation) { Preconditions.checkState(updater != null, "Updater (optimizer) must not be null. Use updater(IUpdater) to set an updater"); @@ -374,10 +497,20 @@ public class TrainingConfig { Preconditions.checkState(markLabelsUnused || dataSetLabelMapping != null, "No DataSet label mapping has been provided. A " + "mapping between DataSet array positions and variables/placeholders must be provided - use dataSetLabelMapping(...) 
to set this," + " or use markLabelsUnused() to mark labels as unused (for example, for unsupervised learning)"); + + + Preconditions.checkArgument(trainEvaluations.keySet().equals(trainEvaluationLabels.keySet()), + "Must specify a label index for each train evaluation. Expected: %s, got: %s", + trainEvaluations.keySet(), trainEvaluationLabels.keySet()); + + Preconditions.checkArgument(validationEvaluations.keySet().equals(validationEvaluationLabels.keySet()), + "Must specify a label index for each validation evaluation. Expected: %s, got: %s", + validationEvaluations.keySet(), validationEvaluationLabels.keySet()); } return new TrainingConfig(updater, regularization, minimize, dataSetFeatureMapping, dataSetLabelMapping, - dataSetFeatureMaskMapping, dataSetLabelMaskMapping, lossVariables); + dataSetFeatureMaskMapping, dataSetLabelMaskMapping, lossVariables, + trainEvaluations, trainEvaluationLabels, validationEvaluations, validationEvaluationLabels); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/BatchOutputConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/BatchOutputConfig.java new file mode 100644 index 000000000..ba5aaa234 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/BatchOutputConfig.java @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.samediff.config; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; +import org.nd4j.autodiff.listeners.Listener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; + +/** + * Configuration for a single batch {@link SameDiff} inference operation. + * + * Used in {@link SameDiff#batchOutput()}. + */ +@Getter +@Setter +public class BatchOutputConfig { + + @Setter(AccessLevel.NONE) + private SameDiff sd; + + @NonNull + private List outputs = new ArrayList<>(); + + private Map placeholders = new HashMap<>(); + + @NonNull + private List listeners = new ArrayList<>(); + + public BatchOutputConfig(@NonNull SameDiff sd){ + this.sd = sd; + } + + /** + * Add required outputs + */ + public BatchOutputConfig output(@NonNull String... outputs){ + this.outputs.addAll(Arrays.asList(outputs)); + return this; + } + + /** + * Add required outputs + */ + public BatchOutputConfig output(@NonNull SDVariable... 
outputs){ + String[] outNames = new String[outputs.length]; + for(int i = 0 ; i < outputs.length ; i++){ + outNames[i] = outputs[i].getVarName(); + } + + return output(outNames); + } + + /** + * Add all variables as required outputs + */ + public BatchOutputConfig outputAll(){ + return output(sd.variables().toArray(new SDVariable[0])); + } + + /** + * Add a placeholder value for a specified variable + */ + public BatchOutputConfig input(@NonNull String variable, @NonNull INDArray placeholder){ + Preconditions.checkState(!placeholders.containsKey(variable), + "Placeholder for variable %s already specified", variable); + + Preconditions.checkNotNull(sd.getVariable(variable), + "Variable %s does not exist in this SameDiff graph", variable); + + placeholders.put(variable, placeholder); + return this; + } + + /** + * See {@link #input(String, INDArray)} + */ + public BatchOutputConfig input(@NonNull SDVariable variable, @NonNull INDArray placeholder){ + return input(variable.getVarName(), placeholder); + } + + /** + * Calls {@link #input(String, INDArray)} on each entry in the map. + */ + public BatchOutputConfig inputs(Map placeholders){ + + if(placeholders == null) { + this.placeholders = null; + return this; + } + + for(Map.Entry e : placeholders.entrySet()){ + input(e.getKey(), e.getValue()); + } + + return this; + } + + /** + * Add listeners for this operation + */ + public BatchOutputConfig listeners(@NonNull Listener... 
listeners){ + this.listeners.addAll(Arrays.asList(listeners)); + return this; + } + + /** + * Do inference and return the results + */ + public Map exec(){ + return sd.output(placeholders, listeners, outputs.toArray(new String[0])); + } + + /** + * Do inference and return the results for the single output + * + * Only works if exactly one output is specified + */ + public INDArray execSingle(){ + Preconditions.checkState(outputs.size() == 1, + "Can only use execSingle() when exactly one output is specified, there were %s", outputs.size()); + return exec().get(outputs.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/EvaluationConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/EvaluationConfig.java new file mode 100644 index 000000000..f4477adad --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/EvaluationConfig.java @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.samediff.config; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; +import org.nd4j.autodiff.listeners.Listener; +import org.nd4j.autodiff.listeners.records.EvaluationRecord; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; + +/** + * Configuration for a {@link SameDiff} evaluation operation. + * + * Used in {@link SameDiff#evaluate()}. + */ +@Getter +@Setter +public class EvaluationConfig { + + @NonNull + private Map> evaluations = new HashMap<>(); + + @NonNull + private Map labelIndices = new HashMap<>(); + + private MultiDataSetIterator data; + + @NonNull + private List listeners = new ArrayList<>(); + + private boolean singleInput = false; + + @Setter(AccessLevel.NONE) + private SameDiff sd; + + public EvaluationConfig(@NonNull SameDiff sd){ + this.sd = sd; + } + + /** + * Add evaluations to be preformed on a specified variable, and set that variable's label index. + * + * Setting a label index is required if using a MultiDataSetIterator. + * + * @param param The param to evaluate + * @param labelIndex The label index of that parameter + * @param evaluations The evaluations to preform + */ + public EvaluationConfig evaluate(@NonNull String param, int labelIndex, @NonNull IEvaluation... 
evaluations){ + return evaluate(param, evaluations).labelIndex(param, labelIndex); + } + + /** + * See {@link #evaluate(String, int, IEvaluation[])} + */ + public EvaluationConfig evaluate(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){ + return evaluate(variable.getVarName(), labelIndex, evaluations); + } + + /** + * Add evaluations to be preformed on a specified variable, without setting a label index. + * + * Setting a label index (which is not done here) is required if using a MultiDataSetIterator. + * + * @param param The param to evaluate + * @param evaluations The evaluations to preform + */ + public EvaluationConfig evaluate(@NonNull String param, @NonNull IEvaluation... evaluations){ + if(this.evaluations.get(param) == null){ + this.evaluations.put(param, new ArrayList()); + } + + this.evaluations.get(param).addAll(Arrays.asList(evaluations)); + return this; + } + + + /** + * See {@link #evaluate(String, IEvaluation[])} + */ + public EvaluationConfig evaluate(@NonNull SDVariable variable, @NonNull IEvaluation... evaluations){ + return evaluate(variable.getVarName(), evaluations); + } + + /** + * Set the label index for a parameter + */ + public EvaluationConfig labelIndex(@NonNull String param, int labelIndex){ + if(this.labelIndices.get(param) != null){ + int existingIndex = this.labelIndices.get(param); + Preconditions.checkArgument(existingIndex == labelIndex, + "Different label index already specified for param %s. Already specified: %s, given: %s", + param, existingIndex, labelIndex); + } + + labelIndices.put(param, labelIndex); + + return this; + } + + /** + * See {@link #labelIndex(String, int)} + */ + public EvaluationConfig labelIndex(@NonNull SDVariable variable, int labelIndex){ + return labelIndex(variable.getVarName(), labelIndex); + } + + /** + * Add listeners for this operation + */ + public EvaluationConfig listeners(@NonNull Listener... 
listeners){ + this.listeners.addAll(Arrays.asList(listeners)); + return this; + } + + /** + * Set the data to evaluate on. + * + * Setting a label index for each variable to evaluate is required + */ + public EvaluationConfig data(@NonNull MultiDataSetIterator data){ + this.data = data; + singleInput = false; + return this; + } + + /** + * Set the data to evaluate on. + * + * Setting a label index for each variable to evaluate is NOT required (since there is only one input) + */ + public EvaluationConfig data(@NonNull DataSetIterator data){ + this.data = new MultiDataSetIteratorAdapter(data); + singleInput = true; + return this; + } + + private void validateConfig(){ + Preconditions.checkNotNull(data, "Must specify data. It may not be null."); + + if(!singleInput){ + for(String param : this.evaluations.keySet()){ + Preconditions.checkState(labelIndices.containsKey(param), + "Using multiple input dataset iterator without specifying a label index for %s", param); + } + } + + for(String param : this.evaluations.keySet()){ + Preconditions.checkState(sd.variableMap().containsKey(param), + "Parameter %s not present in this SameDiff graph", param); + } + } + + /** + * Run the evaluation. + * + * Note that the evaluations in the returned {@link EvaluationRecord} are the evaluations set using {@link #evaluate(String, int, IEvaluation[])}, + * it does not matter which you use to access results. + * + * @return The specified listeners, in an {@link EvaluationRecord} for easy access. 
+ */ + public EvaluationRecord exec(){ + validateConfig(); + + if(singleInput){ + for(String param : this.evaluations.keySet()){ + labelIndices.put(param, 0); + } + } + + sd.evaluate(data, evaluations, labelIndices, listeners.toArray(new Listener[0])); + return new EvaluationRecord(evaluations); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/FitConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/FitConfig.java new file mode 100644 index 000000000..8386c05e6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/FitConfig.java @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.samediff.config; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import lombok.AccessLevel; +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; +import org.nd4j.autodiff.listeners.records.History; +import org.nd4j.autodiff.listeners.Listener; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.TrainingConfig; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; + +/** + * Configuration for a {@link SameDiff} training operation. + *

+ * Used in {@link SameDiff#fit()}. + */ +@Getter +@Setter +public class FitConfig { + + @Setter(AccessLevel.NONE) + private SameDiff sd; + + private MultiDataSetIterator trainingData; + + private MultiDataSetIterator validationData = null; + + private int epochs = -1; + + private int validationFrequency = 1; + + @NonNull + private List listeners = new ArrayList<>(); + + public FitConfig(@NonNull SameDiff sd) { + this.sd = sd; + } + + /** + * Set the number of epochs to train for + */ + public FitConfig epochs(int epochs) { + this.epochs = epochs; + return this; + } + + /** + * Set the training data + */ + public FitConfig train(@NonNull MultiDataSetIterator trainingData) { + this.trainingData = trainingData; + return this; + } + + /** + * Set the training data + */ + public FitConfig train(@NonNull DataSetIterator trainingData) { + return train(new MultiDataSetIteratorAdapter(trainingData)); + } + + /** + * Set the training data and number of epochs + */ + public FitConfig train(@NonNull MultiDataSetIterator trainingData, int epochs) { + return train(trainingData).epochs(epochs); + } + + /** + * Set the training data and number of epochs + */ + public FitConfig train(@NonNull DataSetIterator trainingData, int epochs) { + return train(trainingData).epochs(epochs); + } + + /** + * Set the validation data + */ + public FitConfig validate(MultiDataSetIterator validationData) { + this.validationData = validationData; + return this; + } + + /** + * Set the validation data + */ + public FitConfig validate(DataSetIterator validationData) { + if (validationData == null) { + return validate((MultiDataSetIterator) null); + } else { + return validate(new MultiDataSetIteratorAdapter(validationData)); + } + } + + /** + * Set the validation frequency. Validation will be preformed once every so many epochs. + *

+ * Specifically, validation will be preformed when i % validationFrequency == 0 + */ + public FitConfig validationFrequency(int validationFrequency) { + this.validationFrequency = validationFrequency; + return this; + } + + /** + * Set the validation data and frequency + *

+ * Specifically, validation will be preformed when i % validationFrequency == 0 + */ + public FitConfig validate(MultiDataSetIterator validationData, int validationFrequency) { + return validate(validationData).validationFrequency(validationFrequency); + } + + /** + * Set the validation data and frequency + *

+ * Specifically, validation will be preformed when i % validationFrequency == 0 + */ + public FitConfig validate(DataSetIterator validationData, int validationFrequency) { + return validate(validationData).validationFrequency(validationFrequency); + } + + /** + * Add listeners for this operation + */ + public FitConfig listeners(@NonNull Listener... listeners) { + this.listeners.addAll(Arrays.asList(listeners)); + return this; + } + + + private void validateConfig() { + Preconditions.checkNotNull(trainingData, "Training data must not be null"); + Preconditions.checkState(epochs > 0, "Epochs must be > 0, got %s", epochs); + + if (validationData != null) + Preconditions.checkState(validationFrequency > 0, "Validation Frequency must be > 0 if validation data is given, got %s", validationFrequency); + } + + /** + * Do the training. + * + * @return a {@link History} object containing the history information for this training operation + * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information). + */ + public History exec() { + validateConfig(); + + return sd.fit(trainingData, epochs, validationData, validationFrequency, listeners.toArray(new Listener[0])); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java new file mode 100644 index 000000000..e56ab5988 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.samediff.config; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; +import org.nd4j.autodiff.listeners.Listener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.TrainingUtils; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; + +/** + * Configuration for a {@link SameDiff} inference operation. + * + * Used in {@link SameDiff#output()}. + */ +@Getter +@Setter +public class OutputConfig { + + @Setter(AccessLevel.NONE) + private SameDiff sd; + + @NonNull + private List outputs = new ArrayList<>(); + + @NonNull + private List listeners = new ArrayList<>(); + + private MultiDataSetIterator data; + + public OutputConfig(@NonNull SameDiff sd) { + this.sd = sd; + } + + /** + * Add required outputs + */ + public OutputConfig output(@NonNull String... outputs) { + this.outputs.addAll(Arrays.asList(outputs)); + return this; + } + + /** + * Add required outputs + */ + public OutputConfig output(@NonNull SDVariable... 
outputs) { + String[] outNames = new String[outputs.length]; + for (int i = 0; i < outputs.length; i++) { + outNames[i] = outputs[i].getVarName(); + } + + return output(outNames); + } + + /** + * Set the data to use as input. + */ + public OutputConfig data(@NonNull MultiDataSetIterator data) { + this.data = data; + return this; + } + + /** + * Set the data to use as input. + */ + public OutputConfig data(@NonNull DataSetIterator data) { + this.data = new MultiDataSetIteratorAdapter(data); + return this; + } + + /** + * Add listeners for this operation + */ + public OutputConfig listeners(@NonNull Listener... listeners) { + this.listeners.addAll(Arrays.asList(listeners)); + return this; + } + + private void validateConfig() { + Preconditions.checkNotNull(data, "Must specify data. It may not be null."); + } + + /** + * Do inference and return the results. + * + * Uses concatenation on the outputs of {@link #execBatches()} which may cause issues with some inputs. RNNs with + * variable time series length and CNNs with variable image sizes will most likely have issues. + */ + public Map exec() { + return sd.output(data, listeners, outputs.toArray(new String[0])); + } + + /** + * Do inference and return the results in batches. + */ + public List> execBatches() { + return sd.outputBatches(data, listeners, outputs.toArray(new String[0])); + } + + /** + * Do inference and return the results for the single output variable specified. + * + * Only works if exactly one output is specified. + * + * Uses concatenation on the outputs of {@link #execBatches()} which may cause issues with some inputs. RNNs with + * variable time series length and CNNs with variable image sizes will most likely have issues. 
+ */ + public INDArray execSingle() { + Preconditions.checkState(outputs.size() == 1, + "Can only use execSingle() when exactly one output is specified, there were %s", outputs.size()); + + return sd.output(data, listeners, outputs.toArray(new String[0])).get(outputs.get(0)); + } + + + /** + * Do inference and return the results (in batches) for the single output variable specified. + * + * Only works if exactly one output is specified. + */ + public List execSingleBatches() { + Preconditions.checkState(outputs.size() == 1, + "Can only use execSingleBatches() when exactly one output is specified, there were %s", outputs.size()); + + return TrainingUtils + .getSingleOutput(sd.outputBatches(data, listeners, outputs.toArray(new String[0])), outputs.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index 1b520e7aa..8d806249d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -24,6 +24,7 @@ import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.Listener; +import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.VariableType; @@ -31,6 +32,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; import java.util.*; +import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; /** @@ -133,16 +135,42 @@ public abstract class AbstractSession { return newVarId(variable, 
frameIter.getFrame(), frameIter.getIteration(), frameIter.getParentFrame()); } + /** + * @deprecated Use {@link #output(List, Map, MultiDataSet, Collection, List, At)}. + * + * @param training Uses Operation.TRAINING if true, otherwise Operation.INFERENCE + */ + @Deprecated + public Map output(@NonNull List variables, Map placeholderValues, + MultiDataSet batch, Collection requiredActivations, boolean training, At at){ + if(at == null){ + if(training) + at = At.defaultAt(Operation.TRAINING); + else + at = At.defaultAt(Operation.INFERENCE); + } + return output(variables, placeholderValues, batch, requiredActivations, Collections.emptyList(), at); + } + /** * Get the output of the session - i.e., perform inference/forward pass * * @param variables Name of the variables we want the arrays/activations for * @param placeholderValues The placeholder values (if any). + * @param batch The batch data, used to call Listener.opExecution + * @param requiredActivations Additional activations that are required. Won't be outputed, but opExecution will be called. May be null. 
* @return The specified variable values, optionally in the specified workspace */ - public Map output(@NonNull List variables, Map placeholderValues, List listeners, boolean training, At at) { + public Map output(@NonNull List variables, Map placeholderValues, + MultiDataSet batch, Collection requiredActivations, List listeners, At at) { + Preconditions.checkState(!variables.isEmpty(), "Variables to perform forward pass for must not be empty"); + if(requiredActivations == null) + requiredActivations = Collections.emptyList(); + + if(at == null) + at = At.defaultAt(); //Step 0: validation - that variables exist, placeholders have arrays, etc for (String s : variables) { @@ -164,7 +192,9 @@ public abstract class AbstractSession { //Step 1: determine subgraph structure we actually need to execute //Basic plan: work backwards from the variables we want, based on the graph structure, to work out what // we actually need to execute - initSubgraph(variables); + List allRequired = new ArrayList<>(requiredActivations); + allRequired.addAll(variables); + initSubgraph(allRequired); //Step 1a: Check that we have required placeholders List phNames = sameDiff.inputs(); @@ -198,7 +228,7 @@ public abstract class AbstractSession { // Some Keras layers (like GRU) do different things depending on whether the model is training. // We provide this value directly. 
if(s.endsWith("keras_learning_phase")){ - placeholderValues.put(s, (T) Nd4j.scalar(training)); + placeholderValues.put(s, (T) Nd4j.scalar(at.operation().isTrainingPhase())); } else { throw new IllegalStateException( "An input placeholder \"" + s + "\" is required to calculate the requested outputs," + @@ -302,7 +332,7 @@ public abstract class AbstractSession { //Execute op FrameIter frameIter = varToExec.toFrameIter(); O parameterizedOp = getAndParameterizeOp(opName, frameIter, inputsToVar, inputsToVarAllIter, constPhForVar, placeholderValues); - T[] opOutputValues = getOutputs(parameterizedOp, frameIter, inputsToVar, inputsToVarAllIter, constPhForVar, listeners, training, at); + T[] opOutputValues = getOutputs(parameterizedOp, frameIter, inputsToVar, inputsToVarAllIter, constPhForVar, listeners, at, batch); //Post execution: work out what is now available for exec @@ -831,7 +861,7 @@ public abstract class AbstractSession { * @return The outputs of the op */ public abstract T[] getOutputs(O op, FrameIter outputFrameIter, Set inputs, Set allIterInputs, Set constAndPhInputs, - List listeners, boolean training, At at); + List listeners, At at, MultiDataSet batch); /** * This method is used to record that the specified input is required for calculating the specified output. 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DataTypesSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DataTypesSession.java index 2a8303303..6eff336e8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DataTypesSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/DataTypesSession.java @@ -30,6 +30,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; +import org.nd4j.linalg.dataset.api.MultiDataSet; /** * Infer datatypes for all variables. @@ -80,7 +81,7 @@ public class DataTypesSession extends AbstractSession inputs, Set allIterInputs, - Set constAndPhInputs, List listeners, boolean training, At at) { + Set constAndPhInputs, List listeners, At at, MultiDataSet batch) { List outTypes = op.getFn().calculateOutputDataTypes(op.getInputTypes()); if(dynamicUpdate) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index 86e107031..a98b03566 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -16,6 +16,7 @@ package org.nd4j.autodiff.samediff.internal; +import com.google.common.collect.ImmutableMap; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.functions.DifferentialFunction; @@ -38,6 +39,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; import org.nd4j.linalg.api.ops.impl.transforms.same.Identity; import 
org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; @@ -106,19 +108,34 @@ public class InferenceSession extends AbstractSession opInputs, Set allIterInputs, - Set constAndPhInputs, List listeners, boolean training, At at) { + Set constAndPhInputs, List listeners, At at, MultiDataSet batch) { if(listeners != null && listeners.size() > 0){ SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName()); for(Listener l : listeners){ - l.preOpExecution(sameDiff, at, training, sdOp); + if(l.isActive(at.operation())) + l.preOpExecution(sameDiff, at, sdOp); } } INDArray[] out = getOutputsHelper(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs); if(listeners != null && listeners.size() > 0){ SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName()); + + ImmutableMap.Builder namedOutsBuilder = ImmutableMap.builder(); + + for(int i = 0 ; i < out.length ; i++) + namedOutsBuilder.put(sdOp.outputsOfOp.get(i), out[i]); + + Map namedOuts = namedOutsBuilder.build(); + for(Listener l : listeners){ - l.opExecution(sameDiff, at, training, sdOp, out); + if(l.isActive(at.operation())) { + l.opExecution(sameDiff, at, batch, sdOp, out); + + for(String varName : namedOuts.keySet()){ + l.activationAvailable(sameDiff, at, batch, sdOp, varName, namedOuts.get(varName)); + } + } } } return out; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/ActivationGradientCheckListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/ActivationGradientCheckListener.java index fdbde5f77..a8865f972 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/ActivationGradientCheckListener.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/ActivationGradientCheckListener.java @@ -5,12 +5,14 @@ import lombok.NoArgsConstructor; import lombok.Setter; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.List; +import org.nd4j.linalg.dataset.api.MultiDataSet; /** * A listener used for debugging and testing purposes - specifically for gradient checking activations internally in @@ -29,7 +31,12 @@ public class ActivationGradientCheckListener extends BaseListener { private double eps; @Override - public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) { + public boolean isActive(Operation operation) { + return true; + } + + @Override + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { Preconditions.checkState(variableName != null, "No variable name has been set yet. 
Variable name must be set before using this listener"); Preconditions.checkState(eps != 0.0, "Epsilon has not been set"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java index 6adc71ec3..f5ae0693d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java @@ -1,9 +1,9 @@ package org.nd4j.autodiff.validation.listeners; import lombok.Getter; -import org.bytedeco.javacpp.Pointer; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.base.Preconditions; @@ -14,6 +14,7 @@ import org.nd4j.linalg.api.ops.Op; import java.security.MessageDigest; import java.util.Arrays; import java.util.concurrent.atomic.AtomicInteger; +import org.nd4j.linalg.dataset.api.MultiDataSet; public class NonInplaceValidationListener extends BaseListener { @Getter @@ -30,7 +31,7 @@ public class NonInplaceValidationListener extends BaseListener { } @Override - public void preOpExecution(SameDiff sd, At at, boolean training, SameDiffOp op) { + public void preOpExecution(SameDiff sd, At at, SameDiffOp op) { if(op.getOp().isInPlace()){ //Don't check inplace op return; @@ -57,7 +58,7 @@ public class NonInplaceValidationListener extends BaseListener { } @Override - public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { 
if(op.getOp().isInPlace()){ //Don't check inplace op return; @@ -124,4 +125,8 @@ public class NonInplaceValidationListener extends BaseListener { } } + @Override + public boolean isActive(Operation operation) { + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/IEvaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/IEvaluation.java index abb51d4c4..edb35a32b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/IEvaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/IEvaluation.java @@ -99,4 +99,14 @@ public interface IEvaluation extends Serializable { * @return */ String toYaml(); + + /** + * Get the value of a given metric for this evaluation. + */ + double getValue(IMetric metric); + + /** + * Get a new instance of this evaluation, with the same configuration but no data. + */ + T newInstance(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/IMetric.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/IMetric.java new file mode 100644 index 000000000..dca00d47d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/IMetric.java @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.evaluation; + +/** + * A metric used to get a double value from an {@link IEvaluation}. + * + * Examples: {@link org.nd4j.evaluation.classification.Evaluation.Metric#ACCURACY}, {@link org.nd4j.evaluation.classification.ROC.Metric#AUPRC}. + */ +public interface IMetric { + + /** + * The {@link IEvaluation} class this metric is for + */ + public Class getEvaluationClass(); + + /** + * Whether this metric should be minimized (aka whether lower values are better). + */ + public boolean minimize(); +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java index 75be58371..cb0f5ac24 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java @@ -22,7 +22,11 @@ import org.nd4j.base.Preconditions; import org.nd4j.evaluation.BaseEvaluation; import org.nd4j.evaluation.EvaluationAveraging; import org.nd4j.evaluation.EvaluationUtils; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.IMetric; import org.nd4j.evaluation.meta.Prediction; +import org.nd4j.evaluation.regression.RegressionEvaluation; +import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.evaluation.serde.ConfusionMatrixDeserializer; import org.nd4j.evaluation.serde.ConfusionMatrixSerializer; import org.nd4j.linalg.api.buffer.DataType; @@ -83,7 +87,18 @@ import java.util.*; @JsonIgnoreProperties({"confusionMatrixMetaData"}) public class Evaluation extends BaseEvaluation { - public enum Metric {ACCURACY, F1, PRECISION, RECALL, GMEASURE, MCC} + public enum Metric implements IMetric {ACCURACY, F1, PRECISION, RECALL, GMEASURE, MCC; + + @Override + public 
Class getEvaluationClass() { + return Evaluation.class; + } + + @Override + public boolean minimize() { + return false; + } + } //What to output from the precision/recall function when we encounter an edge case protected static final double DEFAULT_EDGE_VALUE = 0.0; @@ -122,6 +137,17 @@ public class Evaluation extends BaseEvaluation { @Getter @Setter protected int maxWarningClassesToPrint = 16; + protected Evaluation(int axis, Integer binaryPositiveClass, int topN, List labelsList, + Double binaryDecisionThreshold, INDArray costArray, int maxWarningClassesToPrint){ + this.axis = axis; + this.binaryPositiveClass = binaryPositiveClass; + this.topN = topN; + this.labelsList = labelsList; + this.binaryDecisionThreshold = binaryDecisionThreshold; + this.costArray = costArray; + this.maxWarningClassesToPrint = maxWarningClassesToPrint; + } + // Empty constructor public Evaluation() { this.topN = 1; @@ -190,6 +216,7 @@ public class Evaluation extends BaseEvaluation { if (labels != null) { createConfusion(labels.size()); } + this.topN = topN; if(labels != null && labels.size() == 2){ this.binaryPositiveClass = 1; @@ -1869,4 +1896,17 @@ public class Evaluation extends BaseEvaluation { public static Evaluation fromYaml(String yaml) { return fromYaml(yaml, Evaluation.class); } + + @Override + public double getValue(IMetric metric){ + if(metric instanceof Metric){ + return scoreForMetric((Metric) metric); + } else + throw new IllegalStateException("Can't get value for non-evaluation Metric " + metric); + } + + @Override + public Evaluation newInstance() { + return new Evaluation(axis, binaryPositiveClass, topN, labelsList, binaryDecisionThreshold, costArray, maxWarningClassesToPrint); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java index bb4ad2396..7f6474163 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java @@ -21,6 +21,9 @@ import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import org.nd4j.evaluation.BaseEvaluation; import org.nd4j.evaluation.EvaluationUtils; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.IMetric; +import org.nd4j.evaluation.classification.Evaluation.Metric; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan; @@ -58,7 +61,18 @@ import java.util.List; @Data public class EvaluationBinary extends BaseEvaluation { - public enum Metric {ACCURACY, F1, PRECISION, RECALL, GMEASURE, MCC, FAR} + public enum Metric implements IMetric {ACCURACY, F1, PRECISION, RECALL, GMEASURE, MCC, FAR; + + @Override + public Class getEvaluationClass() { + return EvaluationBinary.class; + } + + @Override + public boolean minimize() { + return false; + } + } public static final int DEFAULT_PRECISION = 4; public static final double DEFAULT_EDGE_VALUE = 0.0; @@ -80,6 +94,13 @@ public class EvaluationBinary extends BaseEvaluation { @JsonDeserialize(using = NDArrayTextDeSerializer.class) private INDArray decisionThreshold; + protected EvaluationBinary(int axis, ROCBinary rocBinary, List labels, INDArray decisionThreshold){ + this.axis = axis; + this.rocBinary = rocBinary; + this.labels = labels; + this.decisionThreshold = decisionThreshold; + } + /** * Create an EvaulationBinary instance with an optional decision threshold array. * @@ -452,10 +473,25 @@ public class EvaluationBinary extends BaseEvaluation { } /** - * Calculate the G-measure for the given output + * Macro average of the Matthews correlation coefficient (MCC) (see {@link #matthewsCorrelation(int)}) for all labels. 
+ * + * @return The macro average of the MCC for all labels. + */ + public double averageMatthewsCorrelation() { + double ret = 0.0; + for (int i = 0; i < numLabels(); i++) { + ret += matthewsCorrelation(i); + } + + ret /= (double) numLabels(); + return ret; + } + + /** + * Calculate the macro average G-measure for the given output * * @param output The specified output - * @return The G-measure for the specified output + * @return The macro average of the G-measure for the specified output */ public double gMeasure(int output) { double precision = precision(output); @@ -463,6 +499,21 @@ public class EvaluationBinary extends BaseEvaluation { return EvaluationUtils.gMeasure(precision, recall); } + /** + * Average G-measure (see {@link #gMeasure(int)}) for all labels. + * + * @return The G-measure for all labels. + */ + public double averageGMeasure() { + double ret = 0.0; + for (int i = 0; i < numLabels(); i++) { + ret += gMeasure(i); + } + + ret /= (double) numLabels(); + return ret; + } + /** * Returns the false positive rate for a given label * @@ -679,5 +730,37 @@ public class EvaluationBinary extends BaseEvaluation { return fromYaml(yaml, EvaluationBinary.class); } + @Override + public double getValue(IMetric metric){ + if(metric instanceof Metric){ + switch ((Metric) metric){ + case ACCURACY: + return averageAccuracy(); + case F1: + return averageF1(); + case PRECISION: + return averagePrecision(); + case RECALL: + return averageRecall(); + case GMEASURE: + return averageGMeasure(); + case MCC: + return averageMatthewsCorrelation(); + case FAR: + return averageFalseAlarmRate(); + default: + throw new IllegalStateException("Can't get value for non-binary evaluation Metric " + metric); + } + } else + throw new IllegalStateException("Can't get value for non-binary evaluation Metric " + metric); + } + @Override + public EvaluationBinary newInstance() { + if(rocBinary != null) { + return new EvaluationBinary(axis, rocBinary.newInstance(), labels, 
decisionThreshold); + } else { + return new EvaluationBinary(axis, null, labels, decisionThreshold); + } + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java index 8d5ff279f..08528be10 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java @@ -21,6 +21,8 @@ import lombok.Getter; import lombok.val; import org.nd4j.base.Preconditions; import org.nd4j.evaluation.BaseEvaluation; +import org.nd4j.evaluation.IMetric; +import org.nd4j.evaluation.classification.Evaluation.Metric; import org.nd4j.evaluation.curves.Histogram; import org.nd4j.evaluation.curves.ReliabilityDiagram; import org.nd4j.linalg.api.buffer.DataType; @@ -105,6 +107,13 @@ public class EvaluationCalibration extends BaseEvaluation @JsonDeserialize(using = NDArrayDeSerializer.class) private INDArray probHistogramByLabelClass; //Histogram - for each label class separately + protected EvaluationCalibration(int axis, int reliabilityDiagNumBins, int histogramNumBins, boolean excludeEmptyBins) { + this.axis = axis; + this.reliabilityDiagNumBins = reliabilityDiagNumBins; + this.histogramNumBins = histogramNumBins; + this.excludeEmptyBins = excludeEmptyBins; + } + /** * Create an EvaluationCalibration instance with the default number of bins */ @@ -476,4 +485,14 @@ public class EvaluationCalibration extends BaseEvaluation public static EvaluationCalibration fromJson(String json){ return fromJson(json, EvaluationCalibration.class); } + + @Override + public double getValue(IMetric metric){ + throw new IllegalStateException("Can't get value for non-calibration Metric " + metric); + } + + @Override + public 
EvaluationCalibration newInstance() { + return new EvaluationCalibration(axis, reliabilityDiagNumBins, histogramNumBins, excludeEmptyBins); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java index b124f14d7..7e4de14c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java @@ -20,6 +20,9 @@ import lombok.*; import org.apache.commons.lang3.ArrayUtils; import org.nd4j.base.Preconditions; import org.nd4j.evaluation.BaseEvaluation; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.IMetric; +import org.nd4j.evaluation.classification.Evaluation.Metric; import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.RocCurve; import org.nd4j.evaluation.serde.ROCSerializer; @@ -77,11 +80,24 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval; @JsonSerialize(using = ROCSerializer.class) @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY) public class ROC extends BaseEvaluation { + /** * AUROC: Area under ROC curve
* AUPRC: Area under Precision-Recall Curve */ - public enum Metric {AUROC, AUPRC} + public enum Metric implements IMetric { + AUROC, AUPRC; + + @Override + public Class getEvaluationClass() { + return ROC.class; + } + + @Override + public boolean minimize() { + return false; + } + } private static final int DEFAULT_EXACT_ALLOC_BLOCK_SIZE = 2048; private final Map counts = new LinkedHashMap<>(); @@ -100,6 +116,13 @@ public class ROC extends BaseEvaluation { private int exactAllocBlockSize; protected int axis = 1; + + + public ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize, int axis) { + this(thresholdSteps, rocRemoveRedundantPts, exactAllocBlockSize); + this.axis = axis; + } + public ROC() { //Default to exact this(0); @@ -811,4 +834,17 @@ public class ROC extends BaseEvaluation { throw new IllegalStateException("Unknown metric: " + metric); } } + + @Override + public double getValue(IMetric metric){ + if(metric instanceof Metric){ + return scoreForMetric((Metric) metric); + } else + throw new IllegalStateException("Can't get value for non-ROC Metric " + metric); + } + + @Override + public ROC newInstance() { + return new ROC(thresholdSteps, rocRemoveRedundantPts, exactAllocBlockSize, axis); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCBinary.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCBinary.java index e58b019a8..3695a692b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCBinary.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCBinary.java @@ -21,6 +21,9 @@ import lombok.EqualsAndHashCode; import lombok.val; import org.nd4j.base.Preconditions; import org.nd4j.evaluation.BaseEvaluation; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.IMetric; +import 
org.nd4j.evaluation.classification.ROCMultiClass.Metric; import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.RocCurve; import org.nd4j.evaluation.serde.ROCArraySerializer; @@ -53,7 +56,18 @@ public class ROCBinary extends BaseEvaluation { * AUROC: Area under ROC curve
* AUPRC: Area under Precision-Recall Curve */ - public enum Metric {AUROC, AUPRC} + public enum Metric implements IMetric {AUROC, AUPRC; + + @Override + public Class getEvaluationClass() { + return ROCBinary.class; + } + + @Override + public boolean minimize() { + return false; + } + } @JsonSerialize(using = ROCArraySerializer.class) private ROC[] underlying; @@ -65,6 +79,13 @@ public class ROCBinary extends BaseEvaluation { @EqualsAndHashCode.Exclude //Exclude axis: otherwise 2 Evaluation instances could contain identical stats and fail equality protected int axis = 1; + protected ROCBinary(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List labels) { + this.thresholdSteps = thresholdSteps; + this.rocRemoveRedundantPts = rocRemoveRedundantPts; + this.axis = axis; + this.labels = labels; + } + public ROCBinary() { this(0); } @@ -410,4 +431,22 @@ public class ROCBinary extends BaseEvaluation { throw new IllegalStateException("Unknown metric: " + metric); } } + + @Override + public double getValue(IMetric metric){ + if(metric instanceof Metric){ + if(metric == Metric.AUPRC) + return calculateAverageAUCPR(); + else if(metric == Metric.AUROC) + return calculateAverageAuc(); + else + throw new IllegalStateException("Can't get value for non-binary ROC Metric " + metric); + } else + throw new IllegalStateException("Can't get value for non-binary ROC Metric " + metric); + } + + @Override + public ROCBinary newInstance() { + return new ROCBinary(axis, thresholdSteps, rocRemoveRedundantPts, labels); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java index 8cf2c3aca..07266399a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java @@ -20,6 +20,9 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.nd4j.base.Preconditions; import org.nd4j.evaluation.BaseEvaluation; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.IMetric; +import org.nd4j.evaluation.classification.ROC.Metric; import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.RocCurve; import org.nd4j.evaluation.serde.ROCArraySerializer; @@ -49,7 +52,18 @@ public class ROCMultiClass extends BaseEvaluation { * AUROC: Area under ROC curve
* AUPRC: Area under Precision-Recall Curve */ - public enum Metric {AUROC, AUPRC} + public enum Metric implements IMetric {AUROC, AUPRC; + + @Override + public Class getEvaluationClass() { + return ROCMultiClass.class; + } + + @Override + public boolean minimize() { + return false; + } + } private int thresholdSteps; private boolean rocRemoveRedundantPts; @@ -60,6 +74,13 @@ public class ROCMultiClass extends BaseEvaluation { @EqualsAndHashCode.Exclude //Exclude axis: otherwise 2 Evaluation instances could contain identical stats and fail equality protected int axis = 1; + protected ROCMultiClass(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List labels) { + this.thresholdSteps = thresholdSteps; + this.rocRemoveRedundantPts = rocRemoveRedundantPts; + this.axis = axis; + this.labels = labels; + } + public ROCMultiClass() { //Default to exact this(0); @@ -362,4 +383,22 @@ public class ROCMultiClass extends BaseEvaluation { throw new IllegalStateException("Unknown metric: " + metric); } } + + @Override + public double getValue(IMetric metric){ + if(metric instanceof Metric){ + if(metric == Metric.AUPRC) + return calculateAverageAUCPR(); + else if(metric == Metric.AUROC) + return calculateAverageAUC(); + else + throw new IllegalStateException("Can't get value for non-ROC Metric " + metric); + } else + throw new IllegalStateException("Can't get value for non-ROC Metric " + metric); + } + + @Override + public ROCMultiClass newInstance() { + return new ROCMultiClass(axis, thresholdSteps, rocRemoveRedundantPts, labels); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/CustomEvaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/CustomEvaluation.java new file mode 100644 index 000000000..fed8c63b4 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/CustomEvaluation.java @@ -0,0 +1,182 @@ +/* + * Copyright 
(c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.evaluation.custom; + +import com.google.common.collect.Lists; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import org.nd4j.evaluation.BaseEvaluation; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.IMetric; +import org.nd4j.evaluation.regression.RegressionEvaluation; +import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * A evaluation using lambdas to calculate the score. + * + * Uses 3 lambdas:
+ * EvaluationLambda: takes in the labels, predictions, mask, and metadata and returns a value of type T
+ * MergeLambda: takes in two lists of Ts, returns one. Used in merging for distributed training
+* ResultLambda (in Metric): takes a list of Ts, returns a double value
+ *
+ * The EvaluationLambda will be called on each batch, and the results will be stored in a list. + * MergeLambda merges two of those lists for distributed training (think Spark or Map-Reduce). + * ResultLambda gets a score from that list. + * + */ +@Data +@EqualsAndHashCode(callSuper = true) +public class CustomEvaluation extends BaseEvaluation { + + /** + * The metric used to get a score for the CustomEvaluation. Uses a ResultLambda + */ + @AllArgsConstructor + @RequiredArgsConstructor + public static class Metric implements IMetric{ + + @Getter + @NonNull private ResultLambda getResult; + + private boolean minimize = false; + + @Override + public Class getEvaluationClass() { + return CustomEvaluation.class; + } + + @Override + public boolean minimize() { + return minimize; + } + + /** + * A metric that takes the average of a list of doubles + */ + public static Metric doubleAverage(boolean minimize){ + return new Metric<>(new ResultLambda() { + @Override + public double toResult(List data) { + int count = 0; + double sum = 0; + for (Double d : data) { + count++; + sum += d; + } + return sum / count; + } + }, minimize); + } + + + /** + * A metric that takes the max of a list of doubles + */ + public static Metric doubleMax(boolean minimize){ + return new Metric<>(new ResultLambda() { + @Override + public double toResult(List data) { + double max = 0; + for (Double d : data) { + if(d > max) + max = d; + } + return max; + } + }, minimize); + } + + + /** + * A metric that takes the min of a list of doubles + */ + public static Metric doubleMin(boolean minimize){ + return new Metric<>(new ResultLambda() { + @Override + public double toResult(List data) { + double max = 0; + for (Double d : data) { + if(d < max) + max = d; + } + return max; + } + }, minimize); + } + } + + /** + * A MergeLambda that merges by concatenating the two lists + */ + public static MergeLambda mergeConcatenate(){ + return new MergeLambda() { + @Override + public List merge(List a, List b) { + 
List res = Lists.newArrayList(a); + res.addAll(b); + return res; + } + }; + } + + @NonNull private EvaluationLambda evaluationLambda; + @NonNull private MergeLambda mergeLambda; + + private List evaluations = new ArrayList<>(); + + @Override + public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, + List recordMetaData) { + evaluations.add(evaluationLambda.eval(labels, networkPredictions, maskArray, recordMetaData)); + } + + @Override + public void merge(CustomEvaluation other) { + evaluations = mergeLambda.merge(evaluations, other.evaluations); + } + + @Override + public void reset() { + evaluations = new ArrayList<>(); + } + + @Override + public String stats() { + return ""; + } + + @Override + public double getValue(IMetric metric) { + if(metric instanceof Metric){ + return ((Metric) metric).getGetResult().toResult(evaluations); + } else + throw new IllegalStateException("Can't get value for non-regression Metric " + metric); + } + + @Override + public CustomEvaluation newInstance() { + return new CustomEvaluation(evaluationLambda, mergeLambda); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/EvaluationLambda.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/EvaluationLambda.java new file mode 100644 index 000000000..47eeb3918 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/EvaluationLambda.java @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.evaluation.custom; + +import java.io.Serializable; +import java.util.List; + +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * A lambda used to get an evaluation result for a batch + * See {@link CustomEvaluation} + */ +public interface EvaluationLambda { + public T eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, + List recordMetaData); +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java new file mode 100644 index 000000000..cbad73da3 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.evaluation.custom; + +import com.google.common.collect.Lists; + +import java.util.List; + +/** + * A lambda used to merge two lists of evaluation results + * See {@link CustomEvaluation} + */ +public interface MergeLambda { + public List merge(List a, List b); +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/ResultLambda.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/ResultLambda.java new file mode 100644 index 000000000..33e784175 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/ResultLambda.java @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.evaluation.custom; + +import java.util.List; + +/** + * A lambda used to get a score from a list of evaluation results + * See {@link CustomEvaluation} + */ +public interface ResultLambda { + public double toResult(List data); +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java index 67c8e94d5..e1a0d1f82 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java @@ -20,6 +20,8 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.val; import org.nd4j.evaluation.BaseEvaluation; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.IMetric; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.same.ASum; @@ -54,12 +56,18 @@ import java.util.List; @EqualsAndHashCode(callSuper = true) public class RegressionEvaluation extends BaseEvaluation { - public enum Metric { MSE, MAE, RMSE, RSE, PC, R2; + public enum Metric implements IMetric { MSE, MAE, RMSE, RSE, PC, R2; + + @Override + public Class getEvaluationClass() { + return RegressionEvaluation.class; + } /** * @return True if the metric should be minimized, or false if the metric should be maximized. 
* For example, MSE of 0 is best, but R^2 of 1.0 is best */ + @Override public boolean minimize(){ if(this == R2 || this == PC){ return false; @@ -106,6 +114,12 @@ public class RegressionEvaluation extends BaseEvaluation { @JsonDeserialize(using = NDArrayTextDeSerializer.class) private INDArray sumLabels; + protected RegressionEvaluation(int axis, List columnNames, long precision){ + this.axis = axis; + this.columnNames = columnNames; + this.precision = precision; + } + public RegressionEvaluation() { this(null, DEFAULT_PRECISION); } @@ -568,6 +582,14 @@ public class RegressionEvaluation extends BaseEvaluation { return ret / (double) numColumns(); } + @Override + public double getValue(IMetric metric){ + if(metric instanceof Metric){ + return scoreForMetric((Metric) metric); + } else + throw new IllegalStateException("Can't get value for non-regression Metric " + metric); + } + public double scoreForMetric(Metric metric){ switch (metric){ case MSE: @@ -590,4 +612,9 @@ public class RegressionEvaluation extends BaseEvaluation { public static RegressionEvaluation fromJson(String json){ return fromJson(json, RegressionEvaluation.class); } + + @Override + public RegressionEvaluation newInstance() { + return new RegressionEvaluation(axis, columnNames, precision); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java index a46cfdcab..15ed777d4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestSessions.java @@ -72,7 +72,8 @@ public class TestSessions extends BaseNd4jTest { m.put("x", x); m.put("y", y); - Map outMap = is.output(Collections.singletonList("out"), m, null, true, null); + Map outMap = is.output(Collections.singletonList("out"), m, null, + Collections.emptyList(), true, null); assertEquals(1, outMap.size()); 
assertEquals(outExp, outMap.get("out")); @@ -109,7 +110,8 @@ public class TestSessions extends BaseNd4jTest { m.put("y", y); System.out.println("----------------------------------"); - Map outMap = is.output(Collections.singletonList("d"), m, null, false, null); + Map outMap = is.output(Collections.singletonList("d"), m, null, + Collections.emptyList(), false, null); assertEquals(1, outMap.size()); assertEquals(dExp, outMap.get("d")); @@ -143,7 +145,8 @@ public class TestSessions extends BaseNd4jTest { InferenceSession is = new InferenceSession(sd); // String outName = merge.getVarName(); String outName = outVar.getVarName(); - Map outMap = is.output(Collections.singletonList(outName), m, null, false, null); + Map outMap = is.output(Collections.singletonList(outName), m, null, + Collections.emptyList(), false, null); assertEquals(1, outMap.size()); INDArray out = outMap.get(outName); @@ -178,7 +181,8 @@ public class TestSessions extends BaseNd4jTest { String n = merge.getVarName(); System.out.println("----------------------------------"); - Map outMap = is.output(Collections.singletonList(n), m, null, false, null); + Map outMap = is.output(Collections.singletonList(n), m, null, Collections.emptyList(), + false, null); assertEquals(1, outMap.size()); assertEquals(expTrue, outMap.get(n)); @@ -187,7 +191,7 @@ public class TestSessions extends BaseNd4jTest { //Check false case: bArr.assign(0); is = new InferenceSession(sd); - outMap = is.output(Collections.singletonList(n), m, null, false, null); + outMap = is.output(Collections.singletonList(n), m, null, Collections.emptyList(), false, null); assertEquals(1, outMap.size()); assertEquals(expFalse, outMap.get(n)); } @@ -218,7 +222,8 @@ public class TestSessions extends BaseNd4jTest { String n = "while/Exit"; String n2 = "while/Exit_1"; - Map m = is.output(Arrays.asList(n, n2), Collections.emptyMap(), null, false, null); + Map m = is.output(Arrays.asList(n, n2), Collections.emptyMap(), null, + Collections.emptyList(), 
false, null); assertEquals(2, m.size()); INDArray exp = Nd4j.scalar((float)numIter); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java new file mode 100644 index 000000000..8f1ab016f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.samediff.listeners; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; +import java.util.List; + +import org.junit.Test; +import org.nd4j.autodiff.listeners.Operation; +import org.nd4j.autodiff.listeners.impl.ScoreListener; +import org.nd4j.autodiff.listeners.records.History; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.TrainingConfig; +import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.evaluation.classification.Evaluation.Metric; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.IrisDataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.learning.config.IUpdater; +import org.nd4j.weightinit.impl.XavierInitScheme; + +public class ListenerTest extends BaseNd4jTest { + + public ListenerTest(Nd4jBackend backend) { + super(backend); + } + + @Test + public void irisHistoryTest() { + + DataSetIterator iter = new IrisDataSetIterator(150, 150); + NormalizerStandardize std = new NormalizerStandardize(); + std.fit(iter); + iter.setPreProcessor(std); + + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + + SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3); + + SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.FLOAT, 4, 10); + SDVariable b0 = sd.zero("b0", DataType.FLOAT, 1, 10); + + SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), 
DataType.FLOAT, 10, 3); + SDVariable b1 = sd.zero("b1", DataType.FLOAT, 1, 3); + + SDVariable z0 = in.mmul(w0).add(b0); + SDVariable a0 = sd.nn().relu(z0, 0); + SDVariable z1 = a0.mmul(w1).add(b1); + SDVariable predictions = sd.nn().softmax("predictions", z1, 1); + + SDVariable loss = sd.loss.softmaxCrossEntropy("loss", label, predictions); + + sd.setLossVariables("loss"); + + IUpdater updater = new Adam(1e-2); + + Evaluation e = new Evaluation(); + + TrainingConfig conf = new TrainingConfig.Builder() + .l2(1e-4) + .updater(updater) + .dataSetFeatureMapping("input") + .dataSetLabelMapping("label") + .trainEvaluation(predictions, 0, e) + .build(); + + sd.setTrainingConfig(conf); + + sd.setListeners(new ScoreListener(1)); + + History hist = sd.fit(iter, 50); +// Map> evalMap = new HashMap<>(); +// evalMap.put("prediction", Collections.singletonList(e)); +// +// sd.evaluateMultiple(iter, evalMap); + + e = (Evaluation) hist.finalTrainingEvaluations().evaluation(predictions); + + System.out.println(e.stats()); + + float[] losses = hist.lossCurve().meanLoss(loss); + + System.out.println("Losses: " + Arrays.toString(losses)); + + double acc = hist.finalTrainingEvaluations().getValue(Metric.ACCURACY); + assertTrue("Accuracy < 75%, was " + acc, acc >= 0.75); + } + + @Override + public char ordering() { + return 'c'; + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/CustomEvaluationTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/CustomEvaluationTest.java new file mode 100644 index 000000000..43583d570 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/CustomEvaluationTest.java @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.evaluation; + +import static org.junit.Assert.assertEquals; + +import java.util.Arrays; +import org.junit.Test; +import org.nd4j.evaluation.custom.CustomEvaluation; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.primitives.Pair; + +public class CustomEvaluationTest extends BaseNd4jTest { + + public CustomEvaluationTest(Nd4jBackend backend) { + super(backend); + } + + @Override + public char ordering() { + return 'c'; + } + + @Test + public void customEvalTest(){ + CustomEvaluation accuracyEval = new CustomEvaluation>( + (labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)), + CustomEvaluation.mergeConcatenate()); + + accuracyEval.eval(Nd4j.createFromArray(1, 1, 2, 1, 3), Nd4j.createFromArray(1, 1, 4, 1, 2)); + + double acc = accuracyEval.getValue(new CustomEvaluation.Metric>( + (list) -> { + int sum = 0; + int count = 0; + for(Pair p : list){ + sum += p.getFirst().intValue(); + count += p.getSecond(); + } + return ((double) (sum)) / count; + } + )); + + assertEquals("Accuracy", acc, 3.0/5, 0.001); + + } + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java index 4d39fcbae..372112784 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java +++ 
b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java @@ -3,6 +3,7 @@ package org.nd4j.evaluation; import org.junit.Test; import org.nd4j.evaluation.classification.*; import org.nd4j.evaluation.regression.RegressionEvaluation; +import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.factory.Nd4jBackend; @@ -40,7 +41,7 @@ public class EmptyEvaluationTests extends BaseNd4jTest { RegressionEvaluation re = new RegressionEvaluation(); re.stats(); - for (RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()) { + for (Metric m : Metric.values()) { try { re.scoreForMetric(m); } catch (Throwable t){ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/NewInstanceTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/NewInstanceTest.java new file mode 100644 index 000000000..8bf42674a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/NewInstanceTest.java @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.evaluation; + +import static org.junit.Assert.assertEquals; + +import org.junit.Test; +import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.evaluation.classification.EvaluationBinary; +import org.nd4j.evaluation.classification.EvaluationCalibration; +import org.nd4j.evaluation.classification.ROC; +import org.nd4j.evaluation.classification.ROCBinary; +import org.nd4j.evaluation.classification.ROCMultiClass; +import org.nd4j.evaluation.regression.RegressionEvaluation; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; + +public class NewInstanceTest extends BaseNd4jTest { + + public NewInstanceTest(Nd4jBackend backend) { + super(backend); + } + + @Override + public char ordering() { + return 'c'; + } + + @Test + public void testNewInstances() { + boolean print = true; + Nd4j.getRandom().setSeed(12345); + + Evaluation evaluation = new Evaluation(); + EvaluationBinary evaluationBinary = new EvaluationBinary(); + ROC roc = new ROC(2); + ROCBinary roc2 = new ROCBinary(2); + ROCMultiClass roc3 = new ROCMultiClass(2); + RegressionEvaluation regressionEvaluation = new RegressionEvaluation(); + EvaluationCalibration ec = new EvaluationCalibration(); + + + IEvaluation[] arr = new IEvaluation[] {evaluation, evaluationBinary, roc, roc2, roc3, regressionEvaluation, ec}; + + INDArray evalLabel1 = Nd4j.create(10, 3); + for (int i = 0; i < 10; i++) { + evalLabel1.putScalar(i, i % 3, 1.0); + } + INDArray evalProb1 = Nd4j.rand(10, 3); + evalProb1.diviColumnVector(evalProb1.sum(1)); + + evaluation.eval(evalLabel1, evalProb1); + roc3.eval(evalLabel1, evalProb1); + ec.eval(evalLabel1, evalProb1); + + INDArray evalLabel2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 3), 0.5)); + 
INDArray evalProb2 = Nd4j.rand(10, 3); + evaluationBinary.eval(evalLabel2, evalProb2); + roc2.eval(evalLabel2, evalProb2); + + INDArray evalLabel3 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 1), 0.5)); + INDArray evalProb3 = Nd4j.rand(10, 1); + roc.eval(evalLabel3, evalProb3); + + INDArray reg1 = Nd4j.rand(10, 3); + INDArray reg2 = Nd4j.rand(10, 3); + + regressionEvaluation.eval(reg1, reg2); + + Evaluation evaluation2 = evaluation.newInstance(); + EvaluationBinary evaluationBinary2 = evaluationBinary.newInstance(); + ROC roc_2 = roc.newInstance(); + ROCBinary roc22 = roc2.newInstance(); + ROCMultiClass roc32 = roc3.newInstance(); + RegressionEvaluation regressionEvaluation2 = regressionEvaluation.newInstance(); + EvaluationCalibration ec2 = ec.newInstance(); + + IEvaluation[] arr2 = new IEvaluation[] {evaluation2, evaluationBinary2, roc_2, roc22, roc32, regressionEvaluation2, ec2}; + + evaluation2.eval(evalLabel1, evalProb1); + roc32.eval(evalLabel1, evalProb1); + ec2.eval(evalLabel1, evalProb1); + + evaluationBinary2.eval(evalLabel2, evalProb2); + roc22.eval(evalLabel2, evalProb2); + + roc_2.eval(evalLabel3, evalProb3); + + regressionEvaluation2.eval(reg1, reg2); + + for (int i = 0 ; i < arr.length ; i++) { + + IEvaluation e = arr[i]; + IEvaluation e2 = arr2[i]; + assertEquals("Json not equal ", e.toJson(), e2.toJson()); + assertEquals("Stats not equal ", e.stats(), e2.stats()); + } + } + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java index 92acb1a20..5c3ca3e21 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java @@ -17,8 +17,8 @@ package org.nd4j.evaluation; import org.junit.Test; -import 
org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.evaluation.regression.RegressionEvaluation; +import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; @@ -256,7 +256,7 @@ public class RegressionEvalTest extends BaseNd4jTest { e3d.eval(label, prediction); e2d.eval(l2d, p2d); - for (RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()) { + for (Metric m : Metric.values()) { double d1 = e3d.scoreForMetric(m); double d2 = e2d.scoreForMetric(m); assertEquals(m.toString(), d2, d1, 1e-6); @@ -288,7 +288,7 @@ public class RegressionEvalTest extends BaseNd4jTest { e4d.eval(label, prediction); e2d.eval(l2d, p2d); - for (RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()) { + for (Metric m : Metric.values()) { double d1 = e4d.scoreForMetric(m); double d2 = e2d.scoreForMetric(m); assertEquals(m.toString(), d2, d1, 1e-6); @@ -347,7 +347,7 @@ public class RegressionEvalTest extends BaseNd4jTest { RegressionEvaluation e2d_m2 = new RegressionEvaluation(); e4d_m2.eval(label, prediction, perOutMask); e2d_m2.eval(l2d, p2d, m2d); - for(RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()){ + for(Metric m : Metric.values()){ double d1 = e4d_m2.scoreForMetric(m); double d2 = e2d_m2.scoreForMetric(m); assertEquals(m.toString(), d2, d1, 1e-6); @@ -382,7 +382,7 @@ public class RegressionEvalTest extends BaseNd4jTest { RegressionEvaluation e2d_m1 = new RegressionEvaluation(); e4d_m1.eval(label, prediction, mask1dPerEx); e2d_m1.eval(l2d, p2d); - for(RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()){ + for(Metric m : Metric.values()){ double d1 = e4d_m1.scoreForMetric(m); double d2 = e2d_m1.scoreForMetric(m); assertEquals(m.toString(), d2, d1, 1e-6); @@ -409,7 +409,7 @@ public class RegressionEvalTest extends BaseNd4jTest { RegressionEvaluation e2d_m2 = new 
RegressionEvaluation(); e4d_m2.eval(label, prediction, perOutMask); e2d_m2.eval(l2d, p2d, m2d); - for(RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()){ + for(Metric m : Metric.values()){ double d1 = e4d_m2.scoreForMetric(m); double d2 = e2d_m2.scoreForMetric(m); assertEquals(m.toString(), d2, d1, 1e-6); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/listener/OpExecOrderListener.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/listener/OpExecOrderListener.java index 28e815982..0653defa0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/listener/OpExecOrderListener.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/listener/OpExecOrderListener.java @@ -4,11 +4,13 @@ import lombok.Getter; import lombok.Setter; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.*; +import org.nd4j.linalg.dataset.api.MultiDataSet; public class OpExecOrderListener extends BaseListener { @@ -22,7 +24,7 @@ public class OpExecOrderListener extends BaseListener { } @Override - public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { String opName = op.getName(); if(!opSet.contains(opName)){ opNamesList.add(opName); @@ -30,4 +32,8 @@ public class OpExecOrderListener extends BaseListener { } } + @Override + public boolean isActive(Operation operation) { + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportDebugListener.java 
b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportDebugListener.java index 23ad66cee..a35d54cbd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportDebugListener.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ImportDebugListener.java @@ -20,9 +20,11 @@ import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.listeners.At; import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import java.io.File; @@ -30,6 +32,11 @@ import java.io.File; @Slf4j public class ImportDebugListener extends BaseListener { + @Override + public boolean isActive(Operation operation) { + return true; + } + public enum OnFailure {EXCEPTION, LOG}; private File baseDir; @@ -49,7 +56,7 @@ public class ImportDebugListener extends BaseListener { } @Override - public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { //No op for( int i=0; i