From 5568b9d72ff519a3cdaa4db773a02cf72726582e Mon Sep 17 00:00:00 2001 From: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com> Date: Wed, 27 May 2020 07:41:02 -0400 Subject: [PATCH] RL4J: Add AgentLearner (#470) Signed-off-by: Alexandre Boulanger --- .../org/deeplearning4j/rl4j/agent/Agent.java | 113 +++++++--- .../rl4j/agent/AgentLearner.java | 115 ++++++++++ .../org/deeplearning4j/rl4j/agent/IAgent.java | 55 +++++ .../rl4j/agent/IAgentLearner.java | 24 ++ .../agent/learning/ILearningBehavior.java | 49 ++++ .../rl4j/agent/learning/LearningBehavior.java | 59 +++++ .../rl4j/agent/listener/AgentListener.java | 47 +++- .../agent/listener/AgentListenerList.java | 39 ++++ .../agent/update/DQNNeuralNetUpdateRule.java | 62 +++++ .../rl4j/agent/update/Gradients.java | 26 +++ .../rl4j/agent/update/IUpdateRule.java | 37 +++ .../rl4j/environment/ActionSchema.java | 9 - .../rl4j/environment/Environment.java | 43 ++++ .../rl4j/environment/IActionSchema.java | 26 +++ .../rl4j/environment/IntegerActionSchema.java | 47 ++++ .../rl4j/environment/Schema.java | 18 +- .../rl4j/environment/StepResult.java | 15 ++ .../rl4j/experience/ExperienceHandler.java | 5 + .../ReplayMemoryExperienceHandler.java | 7 + .../StateActionExperienceHandler.java | 17 +- .../rl4j/helper/INDArrayHelper.java | 31 ++- .../learning/async/AsyncThreadDiscrete.java | 14 +- .../AsyncNStepQLearningThreadDiscrete.java | 3 +- .../discrete/QLearningUpdateAlgorithm.java | 24 +- .../rl4j/learning/sync/ExpReplay.java | 5 + .../rl4j/learning/sync/IExpReplay.java | 5 + .../learning/sync/qlearning/QLearning.java | 19 +- .../qlearning/discrete/QLearningDiscrete.java | 70 +++--- .../rl4j/mdp/CartpoleEnvironment.java | 17 +- .../deeplearning4j/rl4j/policy/EpsGreedy.java | 87 +++++++- .../rl4j/policy/INeuralNetPolicy.java | 7 + .../deeplearning4j/rl4j/policy/Policy.java | 2 +- .../rl4j/agent/AgentLearnerTest.java | 211 ++++++++++++++++++ .../deeplearning4j/rl4j/agent/AgentTest.java | 44 ++-- .../agent/learning/LearningBehaviorTest.java | 133 +++++++++++ .../ReplayMemoryExperienceHandlerTest.java | 100 ++++++--- .../StateActionExperienceHandlerTest.java | 70 +++++- .../rl4j/helper/INDArrayHelperTest.java | 21 ++ .../QLearningUpdateAlgorithmTest.java | 75 ++++--- .../discrete/QLearningDiscreteTest.java | 34 ++- 40 files changed, 1541 insertions(+), 244 deletions(-) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/AgentLearner.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgent.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgentLearner.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/ILearningBehavior.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/LearningBehavior.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/Gradients.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/IUpdateRule.java delete mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/ActionSchema.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IActionSchema.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IntegerActionSchema.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/INeuralNetPolicy.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentLearnerTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/LearningBehaviorTest.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java index 999f12e8c..198c2a1ca 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.deeplearning4j.rl4j.agent; import lombok.AccessLevel; @@ -14,7 +29,13 @@ import org.nd4j.common.base.Preconditions; import java.util.Map; -public class Agent { +/** + * An agent implementation. The Agent will use a {@link IPolicy} to interact with an {@link Environment} and receive + * a reward. + * + * @param The type of action + */ +public class Agent implements IAgent { @Getter private final String id; @@ -37,19 +58,28 @@ public class Agent { private ACTION lastAction; @Getter - private int episodeStepNumber; + private int episodeStepCount; @Getter private double reward; protected boolean canContinue; - private Agent(Builder builder) { - this.environment = builder.environment; - this.transformProcess = builder.transformProcess; - this.policy = builder.policy; - this.maxEpisodeSteps = builder.maxEpisodeSteps; - this.id = builder.id; + /** + * @param environment The {@link Environment} to be used + * @param transformProcess The {@link TransformProcess} to be used to transform the raw observations into usable ones. + * @param policy The {@link IPolicy} to be used + * @param maxEpisodeSteps The maximum number of steps an episode can have before being interrupted. Use null to have no max. + * @param id A user-supplied id to identify the instance. + */ + public Agent(@NonNull Environment environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy policy, Integer maxEpisodeSteps, String id) { + Preconditions.checkArgument(maxEpisodeSteps == null || maxEpisodeSteps > 0, "maxEpisodeSteps must be null (no maximum) or greater than 0, got", maxEpisodeSteps); + + this.environment = environment; + this.transformProcess = transformProcess; + this.policy = policy; + this.maxEpisodeSteps = maxEpisodeSteps; + this.id = id; listeners = buildListenerList(); } @@ -58,10 +88,17 @@ public class Agent { return new AgentListenerList(); } + /** + * Add a {@link AgentListener} that will be notified when agent events happens + * @param listener + */ public void addListener(AgentListener listener) { listeners.add(listener); } + /** + * This will run a single episode + */ public void run() { runEpisode(); } @@ -80,7 +117,7 @@ public class Agent { canContinue = listeners.notifyBeforeEpisode(this); - while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepNumber < maxEpisodeSteps)) { + while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepCount < maxEpisodeSteps)) { performStep(); } @@ -100,9 +137,9 @@ public class Agent { } protected void resetEnvironment() { - episodeStepNumber = 0; + episodeStepCount = 0; Map channelsData = environment.reset(); - this.observation = transformProcess.transform(channelsData, episodeStepNumber, false); + this.observation = transformProcess.transform(channelsData, episodeStepCount, false); } protected void resetPolicy() { @@ -125,7 +162,6 @@ public class Agent { } StepResult stepResult = act(action); - handleStepResult(stepResult); onAfterStep(stepResult); @@ -134,11 +170,11 @@ public class Agent { return; } - incrementEpisodeStepNumber(); + incrementEpisodeStepCount(); } - protected void incrementEpisodeStepNumber() { - ++episodeStepNumber; + protected void incrementEpisodeStepCount() { + ++episodeStepCount; } protected ACTION decideAction(Observation observation) { @@ -150,12 +186,15 @@ public class Agent { } protected StepResult act(ACTION action) { - return environment.step(action); - } + Observation observationBeforeAction = observation; - protected void handleStepResult(StepResult stepResult) { - observation = convertChannelDataToObservation(stepResult, episodeStepNumber + 1); - reward +=computeReward(stepResult); + StepResult stepResult = environment.step(action); + observation = convertChannelDataToObservation(stepResult, episodeStepCount + 1); + reward += computeReward(stepResult); + + onAfterAction(observationBeforeAction, action, stepResult); + + return stepResult; } protected Observation convertChannelDataToObservation(StepResult stepResult, int episodeStepNumberOfObs) { @@ -166,6 +205,10 @@ public class Agent { return stepResult.getReward(); } + protected void onAfterAction(Observation observationBeforeAction, ACTION action, StepResult stepResult) { + // Do Nothing + } + protected void onAfterStep(StepResult stepResult) { // Do Nothing } @@ -174,16 +217,24 @@ public class Agent { // Do Nothing } - public static Builder builder(@NonNull Environment environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy policy) { + /** + * + * @param environment + * @param transformProcess + * @param policy + * @param + * @return + */ + public static Builder builder(@NonNull Environment environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy policy) { return new Builder<>(environment, transformProcess, policy); } - public static class Builder { - private final Environment environment; - private final TransformProcess transformProcess; - private final IPolicy policy; - private Integer maxEpisodeSteps = null; // Default, no max - private String id; + public static class Builder { + protected final Environment environment; + protected final TransformProcess transformProcess; + protected final IPolicy policy; + protected Integer maxEpisodeSteps = null; // Default, no max + protected String id; public Builder(@NonNull Environment environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy policy) { this.environment = environment; @@ -191,20 +242,20 @@ public class Agent { this.policy = policy; } - public Builder maxEpisodeSteps(int maxEpisodeSteps) { + public Builder maxEpisodeSteps(int maxEpisodeSteps) { Preconditions.checkArgument(maxEpisodeSteps > 0, "maxEpisodeSteps must be greater than 0, got", maxEpisodeSteps); this.maxEpisodeSteps = maxEpisodeSteps; return this; } - public Builder id(String id) { + public Builder id(String id) { this.id = id; return this; } - public Agent build() { - return new Agent(this); + public AGENT_TYPE build() { + return (AGENT_TYPE)new Agent(environment, transformProcess, policy, maxEpisodeSteps, id); } } } \ No newline at end of file diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/AgentLearner.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/AgentLearner.java new file mode 100644 index 000000000..8fd963cda --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/AgentLearner.java @@ -0,0 +1,115 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.agent; + +import lombok.Getter; +import lombok.NonNull; +import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior; +import org.deeplearning4j.rl4j.environment.Environment; +import org.deeplearning4j.rl4j.environment.StepResult; +import org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.observation.transform.TransformProcess; +import org.deeplearning4j.rl4j.policy.IPolicy; + +/** + * The ActionLearner is an {@link Agent} that delegate the learning to a {@link ILearningBehavior}. + * @param The type of the action + */ +public class AgentLearner extends Agent implements IAgentLearner { + + @Getter + private int totalStepCount = 0; + + private final ILearningBehavior learningBehavior; + private double rewardAtLastExperience; + + /** + * + * @param environment The {@link Environment} to be used + * @param transformProcess The {@link TransformProcess} to be used to transform the raw observations into usable ones. + * @param policy The {@link IPolicy} to be used + * @param maxEpisodeSteps The maximum number of steps an episode can have before being interrupted. Use null to have no max. + * @param id A user-supplied id to identify the instance. + * @param learningBehavior The {@link ILearningBehavior} that will be used to supervise the learning. + */ + public AgentLearner(Environment environment, TransformProcess transformProcess, IPolicy policy, Integer maxEpisodeSteps, String id, @NonNull ILearningBehavior learningBehavior) { + super(environment, transformProcess, policy, maxEpisodeSteps, id); + + this.learningBehavior = learningBehavior; + } + + @Override + protected void reset() { + super.reset(); + + rewardAtLastExperience = 0; + } + + @Override + protected void onBeforeEpisode() { + super.onBeforeEpisode(); + + learningBehavior.handleEpisodeStart(); + } + + @Override + protected void onAfterAction(Observation observationBeforeAction, ACTION action, StepResult stepResult) { + if(!observationBeforeAction.isSkipped()) { + double rewardSinceLastExperience = getReward() - rewardAtLastExperience; + learningBehavior.handleNewExperience(observationBeforeAction, action, rewardSinceLastExperience, stepResult.isTerminal()); + + rewardAtLastExperience = getReward(); + } + } + + @Override + protected void onAfterEpisode() { + learningBehavior.handleEpisodeEnd(getObservation()); + } + + @Override + protected void incrementEpisodeStepCount() { + super.incrementEpisodeStepCount(); + ++totalStepCount; + } + + // FIXME: parent is still visible + public static AgentLearner.Builder> builder(Environment environment, + TransformProcess transformProcess, + IPolicy policy, + ILearningBehavior learningBehavior) { + return new AgentLearner.Builder>(environment, transformProcess, policy, learningBehavior); + } + + public static class Builder> extends Agent.Builder { + + private final ILearningBehavior learningBehavior; + + public Builder(@NonNull Environment environment, + @NonNull TransformProcess transformProcess, + @NonNull IPolicy policy, + @NonNull ILearningBehavior learningBehavior) { + super(environment, transformProcess, policy); + + this.learningBehavior = learningBehavior; + } + + @Override + public AGENT_TYPE build() { + return (AGENT_TYPE)new AgentLearner(environment, transformProcess, policy, maxEpisodeSteps, id, learningBehavior); + } + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgent.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgent.java new file mode 100644 index 000000000..7cbd68a70 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgent.java @@ -0,0 +1,55 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.agent; + +import org.deeplearning4j.rl4j.environment.Environment; +import org.deeplearning4j.rl4j.policy.IPolicy; + +/** + * The interface of {@link Agent} + * @param + */ +public interface IAgent { + /** + * Will play a single episode + */ + void run(); + + /** + * @return A user-supplied id to identify the IAgent instance. + */ + String getId(); + + /** + * @return The {@link Environment} instance being used by the agent. + */ + Environment getEnvironment(); + + /** + * @return The {@link IPolicy} instance being used by the agent. + */ + IPolicy getPolicy(); + + /** + * @return The step count taken in the current episode. + */ + int getEpisodeStepCount(); + + /** + * @return The cumulative reward received in the current episode. + */ + double getReward(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgentLearner.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgentLearner.java new file mode 100644 index 000000000..b1bdd1646 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgentLearner.java @@ -0,0 +1,24 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.agent; + +public interface IAgentLearner extends IAgent { + + /** + * @return The total count of steps taken by this AgentLearner, for all episodes. + */ + int getTotalStepCount(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/ILearningBehavior.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/ILearningBehavior.java new file mode 100644 index 000000000..0187d8c3a --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/ILearningBehavior.java @@ -0,0 +1,49 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.agent.learning; + +import org.deeplearning4j.rl4j.observation.Observation; + +/** + * The ILearningBehavior implementations are in charge of the training. Through this interface, they are + * notified as new experience is generated. + * + * @param The type of action + */ +public interface ILearningBehavior { + + /** + * This method is called when a new episode has been started. + */ + void handleEpisodeStart(); + + /** + * This method is called when new experience is generated. + * + * @param observation The observation prior to taking the action + * @param action The action that has been taken + * @param reward The reward received by taking the action + * @param isTerminal True if the episode ended after taking the action + */ + void handleNewExperience(Observation observation, ACTION action, double reward, boolean isTerminal); + + /** + * This method is called when the episode ends or the maximum number of episode steps is reached. + * + * @param finalObservation The observation after the last action of the episode has been taken. + */ + void handleEpisodeEnd(Observation finalObservation); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/LearningBehavior.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/LearningBehavior.java new file mode 100644 index 000000000..85c7ec4ce --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/LearningBehavior.java @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.agent.learning; + +import lombok.Builder; +import org.deeplearning4j.rl4j.agent.update.IUpdateRule; +import org.deeplearning4j.rl4j.experience.ExperienceHandler; +import org.deeplearning4j.rl4j.observation.Observation; + +/** + * A generic {@link ILearningBehavior} that delegates the handling of experience to a {@link ExperienceHandler} and + * the update logic to a {@link IUpdateRule} + * + * @param The type of the action + * @param The type of experience the ExperienceHandler needs + */ +@Builder +public class LearningBehavior implements ILearningBehavior { + + @Builder.Default + private int experienceUpdateSize = 64; + + private final ExperienceHandler experienceHandler; + private final IUpdateRule updateRule; + + @Override + public void handleEpisodeStart() { + experienceHandler.reset(); + } + + @Override + public void handleNewExperience(Observation observation, ACTION action, double reward, boolean isTerminal) { + experienceHandler.addExperience(observation, action, reward, isTerminal); + if(experienceHandler.isTrainingBatchReady()) { + updateRule.update(experienceHandler.generateTrainingBatch()); + } + } + + @Override + public void handleEpisodeEnd(Observation finalObservation) { + experienceHandler.setFinalObservation(finalObservation); + if(experienceHandler.isTrainingBatchReady()) { + updateRule.update(experienceHandler.generateTrainingBatch()); + } + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java index 898f89241..f176da144 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java @@ -1,23 +1,66 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.deeplearning4j.rl4j.agent.listener; import org.deeplearning4j.rl4j.agent.Agent; import org.deeplearning4j.rl4j.environment.StepResult; import org.deeplearning4j.rl4j.observation.Observation; +/** + * The base definition of all {@link Agent} event listeners + */ public interface AgentListener { enum ListenerResponse { /** - * Tell the learning process to continue calling the listeners and the training. + * Tell the {@link Agent} to continue calling the listeners and the processing. */ CONTINUE, /** - * Tell the learning process to stop calling the listeners and terminate the training. + * Tell the {@link Agent} to interrupt calling the listeners and stop the processing. */ STOP, } + /** + * Called when a new episode is about to start. + * @param agent The agent that generated the event + * + * @return A {@link ListenerResponse}. + */ AgentListener.ListenerResponse onBeforeEpisode(Agent agent); + + /** + * Called when a step is about to be taken. + * + * @param agent The agent that generated the event + * @param observation The observation before the action is taken + * @param action The action that will be performed + * + * @return A {@link ListenerResponse}. + */ AgentListener.ListenerResponse onBeforeStep(Agent agent, Observation observation, ACTION action); + + /** + * Called after a step has been taken. + * + * @param agent The agent that generated the event + * @param stepResult The {@link StepResult} result of the step. + * + * @return A {@link ListenerResponse}. + */ AgentListener.ListenerResponse onAfterStep(Agent agent, StepResult stepResult); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java index e003934d4..48538aeaf 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.deeplearning4j.rl4j.agent.listener; import org.deeplearning4j.rl4j.agent.Agent; @@ -7,6 +22,10 @@ import org.deeplearning4j.rl4j.observation.Observation; import java.util.ArrayList; import java.util.List; +/** + * A class that manages a list of {@link AgentListener AgentListeners} listening to an {@link Agent}. + * @param + */ public class AgentListenerList { protected final List> listeners = new ArrayList<>(); @@ -18,6 +37,13 @@ public class AgentListenerList { listeners.add(listener); } + /** + * This method will notify all listeners that an episode is about to start. If a listener returns + * {@link AgentListener.ListenerResponse STOP}, any following listener is skipped. + * + * @param agent The agent that generated the event. + * @return False if the processing should be stopped + */ public boolean notifyBeforeEpisode(Agent agent) { for (AgentListener listener : listeners) { if (listener.onBeforeEpisode(agent) == AgentListener.ListenerResponse.STOP) { @@ -28,6 +54,13 @@ public class AgentListenerList { return true; } + /** + * + * @param agent The agent that generated the event. + * @param observation The observation before the action is taken + * @param action The action that will be performed + * @return False if the processing should be stopped + */ public boolean notifyBeforeStep(Agent agent, Observation observation, ACTION action) { for (AgentListener listener : listeners) { if (listener.onBeforeStep(agent, observation, action) == AgentListener.ListenerResponse.STOP) { @@ -38,6 +71,12 @@ public class AgentListenerList { return true; } + /** + * + * @param agent The agent that generated the event. + * @param stepResult The {@link StepResult} result of the step. + * @return False if the processing should be stopped + */ public boolean notifyAfterStep(Agent agent, StepResult stepResult) { for (AgentListener listener : listeners) { if (listener.onAfterStep(agent, stepResult) == AgentListener.ListenerResponse.STOP) { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java new file mode 100644 index 000000000..46123d645 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java @@ -0,0 +1,62 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.agent.update; + +import lombok.Getter; +import org.deeplearning4j.rl4j.learning.sync.Transition; +import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; +import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN; +import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm; +import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN; +import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.nd4j.linalg.dataset.api.DataSet; + +import java.util.List; + +// Temporary class that will be replaced with a more generic class that delegates gradient computation +// and network update to sub components. +public class DQNNeuralNetUpdateRule implements IUpdateRule>, TargetQNetworkSource { + + @Getter + private final IDQN qNetwork; + + @Getter + private IDQN targetQNetwork; + private final int targetUpdateFrequency; + + private final ITDTargetAlgorithm tdTargetAlgorithm; + + @Getter + private int updateCount = 0; + + public DQNNeuralNetUpdateRule(IDQN qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) { + this.qNetwork = qNetwork; + this.targetQNetwork = qNetwork.clone(); + this.targetUpdateFrequency = targetUpdateFrequency; + tdTargetAlgorithm = isDoubleDQN + ? new DoubleDQN(this, gamma, errorClamp) + : new StandardDQN(this, gamma, errorClamp); + } + + @Override + public void update(List> trainingBatch) { + DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch); + qNetwork.fit(targets.getFeatures(), targets.getLabels()); + if(++updateCount % targetUpdateFrequency == 0) { + targetQNetwork = qNetwork.clone(); + } + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/Gradients.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/Gradients.java new file mode 100644 index 000000000..4307efe1e --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/Gradients.java @@ -0,0 +1,26 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.agent.update; + +import lombok.Value; +import org.deeplearning4j.nn.gradient.Gradient; + +// Work in progress +@Value +public class Gradients { + private Gradient[] gradients; // Temporary: we'll need something better than a Gradient[] + private int batchSize; +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/IUpdateRule.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/IUpdateRule.java new file mode 100644 index 000000000..d679cba24 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/IUpdateRule.java @@ -0,0 +1,37 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.agent.update; + +import java.util.List; + +/** + * The role of IUpdateRule implementations is to use an experience batch to improve the accuracy of the policy. + * Used by {@link org.deeplearning4j.rl4j.agent.AgentLearner AgentLearner} + * @param The type of the experience + */ +public interface IUpdateRule { + /** + * Perform the update + * @param trainingBatch A batch of experience + */ + void update(List trainingBatch); + + /** + * @return The total number of times the policy has been updated. In a multi-agent learning context, this total is + * for all the agents. + */ + int getUpdateCount(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/ActionSchema.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/ActionSchema.java deleted file mode 100644 index f6521e734..000000000 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/ActionSchema.java +++ /dev/null @@ -1,9 +0,0 @@ -package org.deeplearning4j.rl4j.environment; - -import lombok.Value; - -@Value -public class ActionSchema { - private ACTION noOp; - //FIXME ACTION randomAction(); -} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java index 95ff7d2b6..7fa84cc51 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java @@ -1,11 +1,54 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.deeplearning4j.rl4j.environment; import java.util.Map; +/** + * An interface for environments used by the {@link org.deeplearning4j.rl4j.agent.Agent Agents}. + * @param The type of actions + */ public interface Environment { + + /** + * @return The {@link Schema} of the environment + */ Schema getSchema(); + + /** + * Reset the environment's state to start a new episode. + * @return + */ Map reset(); + + /** + * Perform a single step. + * + * @param action The action taken + * @return A {@link StepResult} describing the result of the step. + */ StepResult step(ACTION action); + + /** + * @return True if the episode is finished + */ boolean isEpisodeFinished(); + + /** + * Called when the agent is finished using this environment instance. + */ void close(); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IActionSchema.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IActionSchema.java new file mode 100644 index 000000000..9e6e81a7b --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IActionSchema.java @@ -0,0 +1,26 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.environment; + +import lombok.Value; + +// Work in progress +public interface IActionSchema { + ACTION getNoOp(); + + // Review: A schema should be data-only and not have behavior + ACTION getRandomAction(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IntegerActionSchema.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IntegerActionSchema.java new file mode 100644 index 000000000..cdf172da6 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IntegerActionSchema.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.environment; + +import org.nd4j.linalg.api.rng.Random; +import org.nd4j.linalg.factory.Nd4j; + +// Work in progress +public class IntegerActionSchema implements IActionSchema { + + private final int numActions; + private final int noOpAction; + private final Random rnd; + + public IntegerActionSchema(int numActions, int noOpAction) { + this(numActions, noOpAction, Nd4j.getRandom()); + } + + public IntegerActionSchema(int numActions, int noOpAction, Random rnd) { + this.numActions = numActions; + this.noOpAction = noOpAction; + this.rnd = rnd; + } + + @Override + public Integer getNoOp() { + return noOpAction; + } + + @Override + public Integer getRandomAction() { + return rnd.nextInt(numActions); + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java index 5ddea24cd..7768c0553 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java @@ -1,8 +1,24 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.deeplearning4j.rl4j.environment; import lombok.Value; +// Work in progress @Value public class Schema { - private ActionSchema actionSchema; + private IActionSchema actionSchema; } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java index b64dd08f5..4936625db 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.deeplearning4j.rl4j.environment; import lombok.Value; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java index 0017925df..e15c08415 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java @@ -41,6 +41,11 @@ public interface ExperienceHandler { */ int getTrainingBatchSize(); + /** + * @return True if a batch is ready for training. + */ + boolean isTrainingBatchReady(); + /** * The elements are returned in the historical order (i.e. in the order they happened) * @return The list of experience elements diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java index 74b7e3f05..c7f7d51ae 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java @@ -36,6 +36,7 @@ import java.util.List; public class ReplayMemoryExperienceHandler implements ExperienceHandler> { private static final int DEFAULT_MAX_REPLAY_MEMORY_SIZE = 150000; private static final int DEFAULT_BATCH_SIZE = 32; + private final int batchSize; private IExpReplay expReplay; @@ -43,6 +44,7 @@ public class ReplayMemoryExperienceHandler implements ExperienceHandler expReplay) { this.expReplay = expReplay; + this.batchSize = expReplay.getDesignatedBatchSize(); } public ReplayMemoryExperienceHandler(int maxReplayMemorySize, int batchSize, Random random) { @@ -64,6 +66,11 @@ public class ReplayMemoryExperienceHandler implements ExperienceHandler= batchSize; + } + /** * @return A batch of experience selected from the replay memory. The replay memory is unchanged after the call. */ diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java index 4c6b95c89..a8fae47bc 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java @@ -30,10 +30,18 @@ import java.util.List; */ public class StateActionExperienceHandler implements ExperienceHandler> { + private final int batchSize; + + private boolean isFinalObservationSet; + + public StateActionExperienceHandler(int batchSize) { + this.batchSize = batchSize; + } + private List> stateActionPairs = new ArrayList<>(); public void setFinalObservation(Observation observation) { - // Do nothing + isFinalObservationSet = true; } public void addExperience(Observation observation, A action, double reward, boolean isTerminal) { @@ -45,6 +53,12 @@ public class StateActionExperienceHandler implements ExperienceHandler= batchSize + || (isFinalObservationSet && stateActionPairs.size() > 0); + } + /** * The elements are returned in the historical order (i.e. in the order they happened) * Note: the experience store is cleared after calling this method. @@ -62,6 +76,7 @@ public class StateActionExperienceHandler implements ExperienceHandler(); + isFinalObservationSet = false; } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java index b42a7c503..9c35ed6f4 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java @@ -24,17 +24,38 @@ import org.nd4j.linalg.factory.Nd4j; * @author Alexandre Boulanger */ public class INDArrayHelper { - /** - * MultiLayerNetwork and ComputationGraph expects input data to be in NCHW in the case of pixels and NS in case of other data types. - * - * We must have either shape 2 (NK) or shape 4 (NCHW) + * Force the input source to have the correct shape: + *

    + *
  • DL4J requires it to be at least 2D
  • + *
  • RL4J has a convention to have the batch size on dimension 0 to all INDArrays
  • + *

+ * @param source The {@link INDArray} to be corrected. + * @return The corrected INDArray */ public static INDArray forceCorrectShape(INDArray source) { - return source.shape()[0] == 1 && source.shape().length > 1 + return source.shape()[0] == 1 && source.rank() > 1 ? source : Nd4j.expandDims(source, 0); } + + /** + * This will create a INDArray with batchSize as dimension 0 and shape as other dimensions. + * For example, if batchSize is 10 and shape is { 1, 3, 4 }, the resulting INDArray shape will be { 10, 3, 4} + * @param batchSize The size of the batch to create + * @param shape The shape of individual elements. + * Note: all shapes in RL4J should have a batch size as dimension 0; in this case the batch size should be 1. + * @return A INDArray + */ + public static INDArray createBatchForShape(long batchSize, long... shape) { + long[] batchShape; + + batchShape = new long[shape.length]; + System.arraycopy(shape, 0, batchShape, 0, shape.length); + + batchShape[0] = batchSize; + return Nd4j.create(batchShape); + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java index c32be6906..bf8838424 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -25,6 +25,7 @@ import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.experience.ExperienceHandler; import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; @@ -49,7 +50,7 @@ public abstract class AsyncThreadDiscrete asyncGlobal, MDP mdp, @@ -60,6 +61,17 @@ public abstract class AsyncThreadDiscrete ex @Override protected UpdateAlgorithm buildUpdateAlgorithm() { - int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); - return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), configuration.getGamma()); + return new QLearningUpdateAlgorithm(getMdp().getActionSpace().getSize(), configuration.getGamma()); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java index 79c9666a2..f935240dc 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java @@ -17,7 +17,7 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.experience.StateActionPair; -import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.helper.INDArrayHelper; import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.nd4j.linalg.api.ndarray.INDArray; @@ -27,15 +27,12 @@ import java.util.List; public class QLearningUpdateAlgorithm implements UpdateAlgorithm { - private final int[] shape; private final int actionSpaceSize; private final double gamma; - public QLearningUpdateAlgorithm(int[] shape, - int actionSpaceSize, + public QLearningUpdateAlgorithm(int actionSpaceSize, double gamma) { - this.shape = shape; this.actionSpaceSize = actionSpaceSize; this.gamma = gamma; } @@ -44,33 +41,34 @@ public class QLearningUpdateAlgorithm implements UpdateAlgorithm { public Gradient[] computeGradients(IDQN current, List> experience) { int size = experience.size(); - int[] nshape = Learning.makeShape(size, shape); - INDArray input = Nd4j.create(nshape); - INDArray targets = Nd4j.create(size, actionSpaceSize); - StateActionPair stateActionPair = experience.get(size - 1); + INDArray data = stateActionPair.getObservation().getData(); + INDArray features = INDArrayHelper.createBatchForShape(size, data.shape()); + INDArray targets = Nd4j.create(size, actionSpaceSize); + double r; if (stateActionPair.isTerminal()) { r = 0; } else { INDArray[] output = null; - output = current.outputAll(stateActionPair.getObservation().getData()); + output = current.outputAll(data); r = Nd4j.max(output[0]).getDouble(0); } for (int i = size - 1; i >= 0; i--) { stateActionPair = experience.get(i); + data = stateActionPair.getObservation().getData(); - input.putRow(i, stateActionPair.getObservation().getData()); + features.putRow(i, data); r = stateActionPair.getReward() + gamma * r; - INDArray[] output = current.outputAll(stateActionPair.getObservation().getData()); + INDArray[] output = current.outputAll(data); INDArray row = output[0]; row = row.putScalar(stateActionPair.getAction(), r); targets.putRow(i, row); } - return current.gradient(input, targets); + return current.gradient(features, targets); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java index 93b4d1bb5..7bfcad53d 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java @@ -80,6 +80,11 @@ public class ExpReplay
implements IExpReplay { //log.info("size: "+storage.size()); } + @Override + public int getDesignatedBatchSize() { + return batchSize; + } + public int getBatchSize() { int storageSize = storage.size(); return Math.min(storageSize, batchSize); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java index eaef5f0f8..8b2133806 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java @@ -47,4 +47,9 @@ public interface IExpReplay { * @param transition a new transition to store */ void store(Transition transition); + + /** + * @return The desired size of batches + */ + int getDesignatedBatchSize(); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index b2e06dc9c..d9c955e17 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -51,25 +51,16 @@ import java.util.List; @Slf4j public abstract class QLearning> extends SyncLearning - implements TargetQNetworkSource, IEpochTrainer { + implements IEpochTrainer { protected abstract LegacyMDPWrapper getLegacyMDPWrapper(); - protected abstract EpsGreedy getEgPolicy(); + protected abstract EpsGreedy getEgPolicy(); public abstract MDP getMdp(); public abstract IDQN getQNetwork(); - public abstract IDQN getTargetQNetwork(); - - protected abstract void setTargetQNetwork(IDQN dqn); - - protected void updateTargetNetwork() { - log.info("Update target network"); - setTargetQNetwork(getQNetwork().clone()); - } - public IDQN getNeuralNet() { return getQNetwork(); } @@ -101,11 +92,6 @@ public abstract class QLearning scores = new ArrayList<>(); while (currentEpisodeStepCount < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) { - - if (this.getStepCount() % getConfiguration().getTargetDqnUpdateFreq() == 0) { - updateTargetNetwork(); - } - QLStepReturn stepR = trainStep(obs); if (!stepR.getMaxQ().isNaN()) { @@ -146,7 +132,6 @@ public abstract class QLearning refacInitMdp() { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java index 771650340..4e357584d 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java @@ -21,6 +21,10 @@ import lombok.AccessLevel; import lombok.Getter; import lombok.Setter; import org.deeplearning4j.gym.StepReply; +import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior; +import org.deeplearning4j.rl4j.agent.learning.LearningBehavior; +import org.deeplearning4j.rl4j.agent.update.DQNNeuralNetUpdateRule; +import org.deeplearning4j.rl4j.agent.update.IUpdateRule; import org.deeplearning4j.rl4j.experience.ExperienceHandler; import org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; @@ -28,9 +32,6 @@ import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; -import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN; -import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm; -import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.space.Encodable; @@ -41,12 +42,8 @@ import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; -import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.factory.Nd4j; -import java.util.List; - - /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16. @@ -63,22 +60,15 @@ public abstract class QLearningDiscrete extends QLearning policy; @Getter - private EpsGreedy egPolicy; + private EpsGreedy egPolicy; @Getter final private IDQN qNetwork; - @Getter - @Setter(AccessLevel.PROTECTED) - private IDQN targetQNetwork; private int lastAction; private double accuReward = 0; - ITDTargetAlgorithm tdTargetAlgorithm; - - // TODO: User a builder and remove the setter - @Getter(AccessLevel.PROTECTED) @Setter - private ExperienceHandler> experienceHandler; + private final ILearningBehavior learningBehavior; protected LegacyMDPWrapper getLegacyMDPWrapper() { return mdp; @@ -88,21 +78,31 @@ public abstract class QLearningDiscrete extends QLearning mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep, Random random) { + this(mdp, dqn, conf, epsilonNbStep, buildLearningBehavior(dqn, conf, random), random); + } + public QLearningDiscrete(MDP mdp, IDQN dqn, QLearningConfiguration conf, - int epsilonNbStep, Random random) { + int epsilonNbStep, ILearningBehavior learningBehavior, Random random) { this.configuration = conf; this.mdp = new LegacyMDPWrapper<>(mdp, null); qNetwork = dqn; - targetQNetwork = dqn.clone(); policy = new DQNPolicy(getQNetwork()); egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(), this); - tdTargetAlgorithm = conf.isDoubleDQN() - ? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp()) - : new StandardDQN(this, conf.getGamma(), conf.getErrorClamp()); + this.learningBehavior = learningBehavior; + } + + private static ILearningBehavior buildLearningBehavior(IDQN qNetwork, QLearningConfiguration conf, Random random) { + IUpdateRule> updateRule = new DQNNeuralNetUpdateRule(qNetwork, conf.getTargetDqnUpdateFreq(), conf.isDoubleDQN(), conf.getGamma(), conf.getErrorClamp()); + ExperienceHandler> experienceHandler = new ReplayMemoryExperienceHandler(conf.getExpRepMaxSize(), conf.getBatchSize(), random); + return LearningBehavior.>builder() + .experienceHandler(experienceHandler) + .updateRule(updateRule) + .experienceUpdateSize(conf.getBatchSize()) + .build(); - experienceHandler = new ReplayMemoryExperienceHandler(conf.getExpRepMaxSize(), conf.getBatchSize(), random); } public MDP getMdp() { @@ -119,7 +119,7 @@ public abstract class QLearningDiscrete extends QLearning extends QLearning trainStep(Observation obs) { - boolean isHistoryProcessor = getHistoryProcessor() != null; - int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1; - int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1; - int updateStart = this.getConfiguration().getUpdateStart() - + ((this.getConfiguration().getBatchSize() + historyLength) * skipFrame); - Double maxQ = Double.NaN; //ignore if Nan for stats //if step of training, just repeat lastAction @@ -160,29 +154,15 @@ public abstract class QLearningDiscrete extends QLearning updateStart) { - DataSet targets = setTarget(experienceHandler.generateTrainingBatch()); - getQNetwork().fit(targets.getFeatures(), targets.getLabels()); - } } return new QLStepReturn<>(maxQ, getQNetwork().getLatestScore(), stepReply); } - protected DataSet setTarget(List> transitions) { - if (transitions.size() == 0) - throw new IllegalArgumentException("too few transitions"); - - return tdTargetAlgorithm.computeTDTargets(transitions); - } - @Override protected void finishEpoch(Observation observation) { - experienceHandler.setFinalObservation(observation); + learningBehavior.handleEpisodeEnd(observation); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java index 1e1348b4a..86907017b 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java @@ -2,21 +2,19 @@ package org.deeplearning4j.rl4j.mdp; import lombok.Getter; import lombok.Setter; -import org.deeplearning4j.rl4j.environment.ActionSchema; -import org.deeplearning4j.rl4j.environment.Environment; -import org.deeplearning4j.rl4j.environment.Schema; -import org.deeplearning4j.rl4j.environment.StepResult; +import org.deeplearning4j.rl4j.environment.*; +import org.nd4j.linalg.api.rng.Random; +import org.nd4j.linalg.factory.Nd4j; import java.util.HashMap; import java.util.Map; -import java.util.Random; public class CartpoleEnvironment implements Environment { private static final int NUM_ACTIONS = 2; private static final int ACTION_LEFT = 0; private static final int ACTION_RIGHT = 1; - private static final Schema schema = new Schema<>(new ActionSchema<>(ACTION_LEFT)); + private final Schema schema; public enum KinematicsIntegrators { Euler, SemiImplicitEuler }; @@ -48,11 +46,12 @@ public class CartpoleEnvironment implements Environment { private Integer stepsBeyondDone; public CartpoleEnvironment() { - rnd = new Random(); + this(Nd4j.getRandom()); } - public CartpoleEnvironment(int seed) { - rnd = new Random(seed); + public CartpoleEnvironment(Random rnd) { + this.rnd = rnd; + this.schema = new Schema(new IntegerActionSchema(NUM_ACTIONS, ACTION_LEFT, rnd)); } @Override diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java index a7282f139..f7422be92 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java @@ -17,16 +17,19 @@ package org.deeplearning4j.rl4j.policy; -import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.rl4j.environment.IActionSchema; import org.deeplearning4j.rl4j.learning.IEpochTrainer; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; -import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; +import org.nd4j.linalg.factory.Nd4j; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/24/16. @@ -38,18 +41,60 @@ import org.nd4j.linalg.api.rng.Random; * epislon is annealed to minEpsilon over epsilonNbStep steps * */ -@AllArgsConstructor @Slf4j -public class EpsGreedy> extends Policy { +public class EpsGreedy extends Policy { - final private Policy policy; - final private MDP mdp; + final private INeuralNetPolicy policy; final private int updateStart; final private int epsilonNbStep; final private Random rnd; final private double minEpsilon; + + private final IActionSchema actionSchema; + + final private MDP> mdp; final private IEpochTrainer learning; + // Using agent's (learning's) step count is incorrect; frame skipping makes epsilon's value decrease too quickly + private int annealingStep = 0; + + @Deprecated + public > EpsGreedy(Policy policy, + MDP> mdp, + int updateStart, + int epsilonNbStep, + Random rnd, + double minEpsilon, + IEpochTrainer learning) { + this.policy = policy; + this.mdp = mdp; + this.updateStart = updateStart; + this.epsilonNbStep = epsilonNbStep; + this.rnd = rnd; + this.minEpsilon = minEpsilon; + this.learning = learning; + + this.actionSchema = null; + } + + public EpsGreedy(@NonNull Policy policy, @NonNull IActionSchema actionSchema, double minEpsilon, int updateStart, int epsilonNbStep) { + this(policy, actionSchema, minEpsilon, updateStart, epsilonNbStep, null); + } + + @Builder + public EpsGreedy(@NonNull INeuralNetPolicy policy, @NonNull IActionSchema actionSchema, double minEpsilon, int updateStart, int epsilonNbStep, Random rnd) { + this.policy = policy; + + this.rnd = rnd == null ? Nd4j.getRandom() : rnd; + this.minEpsilon = minEpsilon; + this.updateStart = updateStart; + this.epsilonNbStep = epsilonNbStep; + this.actionSchema = actionSchema; + + this.mdp = null; + this.learning = null; + } + public NeuralNet getNeuralNet() { return policy.getNeuralNet(); } @@ -57,6 +102,11 @@ public class EpsGreedy ep) @@ -66,10 +116,31 @@ public class EpsGreedy ep) { + result = policy.nextAction(observation); + } + else { + result = actionSchema.getRandomAction(); + } + + ++annealingStep; + + return result; } public double getEpsilon() { - return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCount() - updateStart) * 1.0 / epsilonNbStep)); + int step = actionSchema != null ? annealingStep : learning.getStepCount(); + return Math.min(1.0, Math.max(minEpsilon, 1.0 - (step - updateStart) * 1.0 / epsilonNbStep)); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/INeuralNetPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/INeuralNetPolicy.java new file mode 100644 index 000000000..c213396c6 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/INeuralNetPolicy.java @@ -0,0 +1,7 @@ +package org.deeplearning4j.rl4j.policy; + +import org.deeplearning4j.rl4j.network.NeuralNet; + +public interface INeuralNetPolicy extends IPolicy { + NeuralNet getNeuralNet(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java index 6a4146c94..cf369e359 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java @@ -34,7 +34,7 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; * * A Policy responsability is to choose the next action given a state */ -public abstract class Policy implements IPolicy { +public abstract class Policy implements INeuralNetPolicy { public abstract NeuralNet getNeuralNet(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentLearnerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentLearnerTest.java new file mode 100644 index 000000000..e0c0685bf --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentLearnerTest.java @@ -0,0 +1,211 @@ +package org.deeplearning4j.rl4j.agent; + +import org.deeplearning4j.rl4j.agent.learning.LearningBehavior; +import org.deeplearning4j.rl4j.environment.Environment; +import org.deeplearning4j.rl4j.environment.IntegerActionSchema; +import org.deeplearning4j.rl4j.environment.Schema; +import org.deeplearning4j.rl4j.environment.StepResult; +import org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.observation.transform.TransformProcess; +import org.deeplearning4j.rl4j.policy.IPolicy; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.junit.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; +import static org.junit.Assert.*; + +@RunWith(MockitoJUnitRunner.class) +public class AgentLearnerTest { + + @Mock + Environment environmentMock; + + @Mock + TransformProcess transformProcessMock; + + @Mock + IPolicy policyMock; + + @Mock + LearningBehavior learningBehaviorMock; + + @Test + public void when_episodeIsStarted_expect_learningBehaviorHandleEpisodeStartCalled() { + // Arrange + AgentLearner sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock) + .maxEpisodeSteps(3) + .build(); + + Schema schema = new Schema(new IntegerActionSchema(0, -1)); + when(environmentMock.reset()).thenReturn(new HashMap<>()); + when(environmentMock.getSchema()).thenReturn(schema); + StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false); + when(environmentMock.step(any(Integer.class))).thenReturn(stepResult); + + when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 }))); + + when(policyMock.nextAction(any(Observation.class))).thenReturn(123); + + // Act + sut.run(); + + // Assert + verify(learningBehaviorMock, times(1)).handleEpisodeStart(); + } + + @Test + public void when_runIsCalled_expect_experienceHandledWithLearningBehavior() { + // Arrange + AgentLearner sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock) + .maxEpisodeSteps(4) + .build(); + + Schema schema = new Schema(new IntegerActionSchema(0, -1)); + when(environmentMock.getSchema()).thenReturn(schema); + when(environmentMock.reset()).thenReturn(new HashMap<>()); + + double[] reward = new double[] { 0.0 }; + when(environmentMock.step(any(Integer.class))) + .thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0)); + + when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0); + + when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())) + .thenAnswer(new Answer() { + public Observation answer(InvocationOnMock invocation) throws Throwable { + int step = (int)invocation.getArgument(1); + boolean isTerminal = (boolean)invocation.getArgument(2); + return (step % 2 == 0 || isTerminal) + ? new Observation(Nd4j.create(new double[] { step * 1.1 })) + : Observation.SkippedObservation; + } + }); + + when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]); + + // Act + sut.run(); + + // Assert + ArgumentCaptor observationCaptor = ArgumentCaptor.forClass(Observation.class); + ArgumentCaptor actionCaptor = ArgumentCaptor.forClass(Integer.class); + ArgumentCaptor rewardCaptor = ArgumentCaptor.forClass(Double.class); + ArgumentCaptor isTerminalCaptor = ArgumentCaptor.forClass(Boolean.class); + + verify(learningBehaviorMock, times(2)).handleNewExperience(observationCaptor.capture(), actionCaptor.capture(), rewardCaptor.capture(), isTerminalCaptor.capture()); + List observations = observationCaptor.getAllValues(); + List actions = actionCaptor.getAllValues(); + List rewards = rewardCaptor.getAllValues(); + List isTerminalList = isTerminalCaptor.getAllValues(); + + assertEquals(0.0, observations.get(0).getData().getDouble(0), 0.00001); + assertEquals(0, (int)actions.get(0)); + assertEquals(0.0 + 1.0, rewards.get(0), 0.00001); + assertFalse(isTerminalList.get(0)); + + assertEquals(2.2, observations.get(1).getData().getDouble(0), 0.00001); + assertEquals(2, (int)actions.get(1)); + assertEquals(2.0 + 3.0, rewards.get(1), 0.00001); + assertFalse(isTerminalList.get(1)); + + ArgumentCaptor finalObservationCaptor = ArgumentCaptor.forClass(Observation.class); + verify(learningBehaviorMock, times(1)).handleEpisodeEnd(finalObservationCaptor.capture()); + assertEquals(4.4, finalObservationCaptor.getValue().getData().getDouble(0), 0.00001); + } + + @Test + public void when_runIsCalledMultipleTimes_expect_totalStepCountCorrect() { + // Arrange + AgentLearner sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock) + .maxEpisodeSteps(4) + .build(); + + Schema schema = new Schema(new IntegerActionSchema(0, -1)); + when(environmentMock.getSchema()).thenReturn(schema); + when(environmentMock.reset()).thenReturn(new HashMap<>()); + + double[] reward = new double[] { 0.0 }; + when(environmentMock.step(any(Integer.class))) + .thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0)); + + when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0); + + when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())) + .thenAnswer(new Answer() { + public Observation answer(InvocationOnMock invocation) throws Throwable { + int step = (int)invocation.getArgument(1); + boolean isTerminal = (boolean)invocation.getArgument(2); + return (step % 2 == 0 || isTerminal) + ? new Observation(Nd4j.create(new double[] { step * 1.1 })) + : Observation.SkippedObservation; + } + }); + + when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]); + + // Act + sut.run(); + reward[0] = 0.0; + sut.run(); + + // Assert + assertEquals(8, sut.getTotalStepCount()); + } + + @Test + public void when_runIsCalledMultipleTimes_expect_rewardSentToLearningBehaviorToBeCorrect() { + // Arrange + AgentLearner sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock) + .maxEpisodeSteps(4) + .build(); + + Schema schema = new Schema(new IntegerActionSchema(0, -1)); + when(environmentMock.getSchema()).thenReturn(schema); + when(environmentMock.reset()).thenReturn(new HashMap<>()); + + double[] reward = new double[] { 0.0 }; + when(environmentMock.step(any(Integer.class))) + .thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0)); + + when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0); + + when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())) + .thenAnswer(new Answer() { + public Observation answer(InvocationOnMock invocation) throws Throwable { + int step = (int)invocation.getArgument(1); + boolean isTerminal = (boolean)invocation.getArgument(2); + return (step % 2 == 0 || isTerminal) + ? new Observation(Nd4j.create(new double[] { step * 1.1 })) + : Observation.SkippedObservation; + } + }); + + when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]); + + // Act + sut.run(); + reward[0] = 0.0; + sut.run(); + + // Assert + ArgumentCaptor rewardCaptor = ArgumentCaptor.forClass(Double.class); + + verify(learningBehaviorMock, times(4)).handleNewExperience(any(Observation.class), any(Integer.class), rewardCaptor.capture(), any(Boolean.class)); + List rewards = rewardCaptor.getAllValues(); + + // rewardAtLastExperience at the end of 1st call to .run() should not leak into 2nd call. + assertEquals(0.0 + 1.0, rewards.get(2), 0.00001); + assertEquals(2.0 + 3.0, rewards.get(3), 0.00001); + } +} \ No newline at end of file diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java index a8beae640..0022e61f0 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java @@ -1,10 +1,7 @@ package org.deeplearning4j.rl4j.agent; import org.deeplearning4j.rl4j.agent.listener.AgentListener; -import org.deeplearning4j.rl4j.environment.ActionSchema; -import org.deeplearning4j.rl4j.environment.Environment; -import org.deeplearning4j.rl4j.environment.Schema; -import org.deeplearning4j.rl4j.environment.StepResult; +import org.deeplearning4j.rl4j.environment.*; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.transform.TransformProcess; import org.deeplearning4j.rl4j.policy.IPolicy; @@ -12,6 +9,7 @@ import org.junit.Rule; import org.junit.Test; import static org.junit.Assert.*; +import org.junit.runner.RunWith; import org.mockito.*; import org.mockito.junit.*; import org.nd4j.linalg.factory.Nd4j; @@ -23,8 +21,8 @@ import java.util.Map; import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.*; +@RunWith(MockitoJUnitRunner.class) public class AgentTest { - @Mock Environment environmentMock; @Mock TransformProcess transformProcessMock; @Mock IPolicy policyMock; @@ -102,7 +100,7 @@ public class AgentTest { public void when_runIsCalled_expect_agentIsReset() { // Arrange Map envResetResult = new HashMap<>(); - Schema schema = new Schema(new ActionSchema<>(-1)); + Schema schema = new Schema(new IntegerActionSchema(0, -1)); when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.getSchema()).thenReturn(schema); @@ -119,7 +117,7 @@ public class AgentTest { sut.run(); // Assert - assertEquals(0, sut.getEpisodeStepNumber()); + assertEquals(0, sut.getEpisodeStepCount()); verify(transformProcessMock).transform(envResetResult, 0, false); verify(policyMock, times(1)).reset(); assertEquals(0.0, sut.getReward(), 0.00001); @@ -130,7 +128,7 @@ public class AgentTest { public void when_runIsCalled_expect_onBeforeAndAfterEpisodeCalled() { // Arrange Map envResetResult = new HashMap<>(); - Schema schema = new Schema(new ActionSchema<>(-1)); + Schema schema = new Schema(new IntegerActionSchema(0, -1)); when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.getSchema()).thenReturn(schema); @@ -152,7 +150,7 @@ public class AgentTest { public void when_onBeforeEpisodeReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() { // Arrange Map envResetResult = new HashMap<>(); - Schema schema = new Schema(new ActionSchema<>(-1)); + Schema schema = new Schema(new IntegerActionSchema(0, -1)); when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.getSchema()).thenReturn(schema); @@ -179,7 +177,7 @@ public class AgentTest { public void when_runIsCalledWithoutMaxStep_expect_agentRunUntilEpisodeIsFinished() { // Arrange Map envResetResult = new HashMap<>(); - Schema schema = new Schema(new ActionSchema<>(-1)); + Schema schema = new Schema(new IntegerActionSchema(0, -1)); when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.getSchema()).thenReturn(schema); @@ -191,10 +189,10 @@ public class AgentTest { final Agent spy = Mockito.spy(sut); doAnswer(invocation -> { - ((Agent)invocation.getMock()).incrementEpisodeStepNumber(); + ((Agent)invocation.getMock()).incrementEpisodeStepCount(); return null; }).when(spy).performStep(); - when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepNumber() >= 5 ); + when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepCount() >= 5 ); // Act spy.run(); @@ -209,7 +207,7 @@ public class AgentTest { public void when_maxStepsIsReachedBeforeEposideEnds_expect_runTerminated() { // Arrange Map envResetResult = new HashMap<>(); - Schema schema = new Schema(new ActionSchema<>(-1)); + Schema schema = new Schema(new IntegerActionSchema(0, -1)); when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.getSchema()).thenReturn(schema); @@ -222,7 +220,7 @@ public class AgentTest { final Agent spy = Mockito.spy(sut); doAnswer(invocation -> { - ((Agent)invocation.getMock()).incrementEpisodeStepNumber(); + ((Agent)invocation.getMock()).incrementEpisodeStepCount(); return null; }).when(spy).performStep(); @@ -239,7 +237,7 @@ public class AgentTest { public void when_initialObservationsAreSkipped_expect_performNoOpAction() { // Arrange Map envResetResult = new HashMap<>(); - Schema schema = new Schema(new ActionSchema<>(-1)); + Schema schema = new Schema(new IntegerActionSchema(0, -1)); when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.getSchema()).thenReturn(schema); @@ -264,7 +262,7 @@ public class AgentTest { public void when_initialObservationsAreSkipped_expect_performNoOpActionAnd() { // Arrange Map envResetResult = new HashMap<>(); - Schema schema = new Schema(new ActionSchema<>(-1)); + Schema schema = new Schema(new IntegerActionSchema(0, -1)); when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.getSchema()).thenReturn(schema); @@ -289,7 +287,7 @@ public class AgentTest { public void when_observationsIsSkipped_expect_performLastAction() { // Arrange Map envResetResult = new HashMap<>(); - Schema schema = new Schema(new ActionSchema<>(-1)); + Schema schema = new Schema(new IntegerActionSchema(0, -1)); when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(envResetResult, 0.0, false)); when(environmentMock.getSchema()).thenReturn(schema); @@ -331,7 +329,7 @@ public class AgentTest { @Test public void when_onBeforeStepReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() { // Arrange - Schema schema = new Schema(new ActionSchema<>(-1)); + Schema schema = new Schema(new IntegerActionSchema(0, -1)); when(environmentMock.reset()).thenReturn(new HashMap<>()); when(environmentMock.getSchema()).thenReturn(schema); @@ -358,7 +356,7 @@ public class AgentTest { @Test public void when_observationIsNotSkipped_expect_policyActionIsSentToEnvironment() { // Arrange - Schema schema = new Schema(new ActionSchema<>(-1)); + Schema schema = new Schema(new IntegerActionSchema(0, -1)); when(environmentMock.reset()).thenReturn(new HashMap<>()); when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 0.0, false)); @@ -381,7 +379,7 @@ public class AgentTest { @Test public void when_stepResultIsReceived_expect_observationAndRewardUpdated() { // Arrange - Schema schema = new Schema(new ActionSchema<>(-1)); + Schema schema = new Schema(new IntegerActionSchema(0, -1)); when(environmentMock.reset()).thenReturn(new HashMap<>()); when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 234.0, false)); @@ -405,7 +403,7 @@ public class AgentTest { @Test public void when_stepIsDone_expect_onAfterStepAndWithStepResult() { // Arrange - Schema schema = new Schema(new ActionSchema<>(-1)); + Schema schema = new Schema(new IntegerActionSchema(0, -1)); when(environmentMock.reset()).thenReturn(new HashMap<>()); when(environmentMock.getSchema()).thenReturn(schema); StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false); @@ -430,7 +428,7 @@ public class AgentTest { @Test public void when_onAfterStepReturnsStop_expect_onAfterEpisodeNotCalled() { // Arrange - Schema schema = new Schema(new ActionSchema<>(-1)); + Schema schema = new Schema(new IntegerActionSchema(0, -1)); when(environmentMock.reset()).thenReturn(new HashMap<>()); when(environmentMock.getSchema()).thenReturn(schema); StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false); @@ -458,7 +456,7 @@ public class AgentTest { @Test public void when_runIsCalled_expect_onAfterEpisodeIsCalled() { // Arrange - Schema schema = new Schema(new ActionSchema<>(-1)); + Schema schema = new Schema(new IntegerActionSchema(0, -1)); when(environmentMock.reset()).thenReturn(new HashMap<>()); when(environmentMock.getSchema()).thenReturn(schema); StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/LearningBehaviorTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/LearningBehaviorTest.java new file mode 100644 index 000000000..1e39c63d5 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/LearningBehaviorTest.java @@ -0,0 +1,133 @@ +package org.deeplearning4j.rl4j.agent.learning; + +import org.deeplearning4j.rl4j.agent.update.IUpdateRule; +import org.deeplearning4j.rl4j.experience.ExperienceHandler; +import org.deeplearning4j.rl4j.observation.Observation; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +@RunWith(MockitoJUnitRunner.class) +public class LearningBehaviorTest { + + @Mock + ExperienceHandler experienceHandlerMock; + + @Mock + IUpdateRule updateRuleMock; + + LearningBehavior sut; + + @Before + public void setup() { + sut = LearningBehavior.builder() + .experienceHandler(experienceHandlerMock) + .updateRule(updateRuleMock) + .build(); + } + + @Test + public void when_callingHandleEpisodeStart_expect_experienceHandlerResetCalled() { + // Arrange + LearningBehavior sut = LearningBehavior.builder() + .experienceHandler(experienceHandlerMock) + .updateRule(updateRuleMock) + .build(); + + // Act + sut.handleEpisodeStart(); + + // Assert + verify(experienceHandlerMock, times(1)).reset(); + } + + @Test + public void when_callingHandleNewExperience_expect_experienceHandlerAddExperienceCalled() { + // Arrange + INDArray observationData = Nd4j.rand(1, 1); + when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(false); + + // Act + sut.handleNewExperience(new Observation(observationData), 1, 2.0, false); + + // Assert + ArgumentCaptor observationCaptor = ArgumentCaptor.forClass(Observation.class); + ArgumentCaptor actionCaptor = ArgumentCaptor.forClass(Integer.class); + ArgumentCaptor rewardCaptor = ArgumentCaptor.forClass(Double.class); + ArgumentCaptor isTerminatedCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(experienceHandlerMock, times(1)).addExperience(observationCaptor.capture(), actionCaptor.capture(), rewardCaptor.capture(), isTerminatedCaptor.capture()); + + assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001); + assertEquals(1, (int)actionCaptor.getValue()); + assertEquals(2.0, (double)rewardCaptor.getValue(), 0.00001); + assertFalse(isTerminatedCaptor.getValue()); + + verify(updateRuleMock, never()).update(any(List.class)); + } + + @Test + public void when_callingHandleNewExperienceAndTrainingBatchIsReady_expect_updateRuleUpdateWithTrainingBatch() { + // Arrange + INDArray observationData = Nd4j.rand(1, 1); + when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(true); + List trainingBatch = new ArrayList(); + when(experienceHandlerMock.generateTrainingBatch()).thenReturn(trainingBatch); + + // Act + sut.handleNewExperience(new Observation(observationData), 1, 2.0, false); + + // Assert + verify(updateRuleMock, times(1)).update(trainingBatch); + } + + @Test + public void when_callingHandleEpisodeEnd_expect_experienceHandlerSetFinalObservationCalled() { + // Arrange + INDArray observationData = Nd4j.rand(1, 1); + when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(false); + + // Act + sut.handleEpisodeEnd(new Observation(observationData)); + + // Assert + ArgumentCaptor observationCaptor = ArgumentCaptor.forClass(Observation.class); + verify(experienceHandlerMock, times(1)).setFinalObservation(observationCaptor.capture()); + + assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001); + + verify(updateRuleMock, never()).update(any(List.class)); + } + + @Test + public void when_callingHandleEpisodeEndAndTrainingBatchIsNotEmpty_expect_updateRuleUpdateWithTrainingBatch() { + // Arrange + INDArray observationData = Nd4j.rand(1, 1); + when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(true); + List trainingBatch = new ArrayList(); + when(experienceHandlerMock.generateTrainingBatch()).thenReturn(trainingBatch); + + // Act + sut.handleEpisodeEnd(new Observation(observationData)); + + // Assert + ArgumentCaptor observationCaptor = ArgumentCaptor.forClass(Observation.class); + verify(experienceHandlerMock, times(1)).setFinalObservation(observationCaptor.capture()); + + assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001); + + verify(updateRuleMock, times(1)).update(trainingBatch); + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java index 765a14c8f..0d90e812d 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java @@ -4,34 +4,44 @@ import org.deeplearning4j.rl4j.learning.sync.IExpReplay; import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.observation.Observation; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.factory.Nd4j; -import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.assertEquals; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; +@RunWith(MockitoJUnitRunner.class) public class ReplayMemoryExperienceHandlerTest { + + @Mock + IExpReplay expReplayMock; + @Test public void when_addingFirstExperience_expect_notAddedToStoreBeforeNextObservationIsAdded() { // Arrange - TestExpReplay expReplayMock = new TestExpReplay(); + when(expReplayMock.getDesignatedBatchSize()).thenReturn(10); + ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); // Act sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); - int numStoredTransitions = expReplayMock.addedTransitions.size(); + boolean isStoreCalledAfterFirstAdd = mockingDetails(expReplayMock).getInvocations().stream().anyMatch(x -> x.getMethod().getName() == "store"); sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false); + boolean isStoreCalledAfterSecondAdd = mockingDetails(expReplayMock).getInvocations().stream().anyMatch(x -> x.getMethod().getName() == "store"); // Assert - assertEquals(0, numStoredTransitions); - assertEquals(1, expReplayMock.addedTransitions.size()); + assertFalse(isStoreCalledAfterFirstAdd); + assertTrue(isStoreCalledAfterSecondAdd); } @Test public void when_addingExperience_expect_transitionsAreCorrect() { // Arrange - TestExpReplay expReplayMock = new TestExpReplay(); ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); // Act @@ -40,24 +50,25 @@ public class ReplayMemoryExperienceHandlerTest { sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 }))); // Assert - assertEquals(2, expReplayMock.addedTransitions.size()); + ArgumentCaptor> argument = ArgumentCaptor.forClass(Transition.class); + verify(expReplayMock, times(2)).store(argument.capture()); + List> transitions = argument.getAllValues(); - assertEquals(1.0, expReplayMock.addedTransitions.get(0).getObservation().getData().getDouble(0), 0.00001); - assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction()); - assertEquals(1.0, expReplayMock.addedTransitions.get(0).getReward(), 0.00001); - assertEquals(2.0, expReplayMock.addedTransitions.get(0).getNextObservation().getDouble(0), 0.00001); + assertEquals(1.0, transitions.get(0).getObservation().getData().getDouble(0), 0.00001); + assertEquals(1, (int)transitions.get(0).getAction()); + assertEquals(1.0, transitions.get(0).getReward(), 0.00001); + assertEquals(2.0, transitions.get(0).getNextObservation().getDouble(0), 0.00001); - assertEquals(2.0, expReplayMock.addedTransitions.get(1).getObservation().getData().getDouble(0), 0.00001); - assertEquals(2, (int)expReplayMock.addedTransitions.get(1).getAction()); - assertEquals(2.0, expReplayMock.addedTransitions.get(1).getReward(), 0.00001); - assertEquals(3.0, expReplayMock.addedTransitions.get(1).getNextObservation().getDouble(0), 0.00001); + assertEquals(2.0, transitions.get(1).getObservation().getData().getDouble(0), 0.00001); + assertEquals(2, (int)transitions.get(1).getAction()); + assertEquals(2.0, transitions.get(1).getReward(), 0.00001); + assertEquals(3.0, transitions.get(1).getNextObservation().getDouble(0), 0.00001); } @Test public void when_settingFinalObservation_expect_nextAddedExperienceDoNotUsePreviousObservation() { // Arrange - TestExpReplay expReplayMock = new TestExpReplay(); ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); // Act @@ -66,42 +77,57 @@ public class ReplayMemoryExperienceHandlerTest { sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false); // Assert - assertEquals(1, expReplayMock.addedTransitions.size()); - assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction()); + ArgumentCaptor> argument = ArgumentCaptor.forClass(Transition.class); + verify(expReplayMock, times(1)).store(argument.capture()); + Transition transition = argument.getValue(); + + assertEquals(1, (int)transition.getAction()); } @Test public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() { // Arrange - TestExpReplay expReplayMock = new TestExpReplay(); - ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); + ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom()); sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false); sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 }))); // Act int size = sut.getTrainingBatchSize(); + // Assert assertEquals(2, size); } - private static class TestExpReplay implements IExpReplay { + @Test + public void when_experienceSizeIsSmallerThanBatchSize_expect_TrainingBatchIsNotReady() { + // Arrange + ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom()); + sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); + sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false); + sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 }))); - public final List> addedTransitions = new ArrayList<>(); + // Act - @Override - public ArrayList> getBatch() { - return null; - } - - @Override - public void store(Transition transition) { - addedTransitions.add(transition); - } - - @Override - public int getBatchSize() { - return addedTransitions.size(); - } + // Assert + assertFalse(sut.isTrainingBatchReady()); } + + @Test + public void when_experienceSizeIsGreaterOrEqualToBatchSize_expect_TrainingBatchIsReady() { + // Arrange + ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom()); + sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); + sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false); + sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false); + sut.addExperience(new Observation(Nd4j.create(new double[] { 4.0 })), 4, 4.0, false); + sut.addExperience(new Observation(Nd4j.create(new double[] { 5.0 })), 5, 5.0, false); + sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 6.0 }))); + + // Act + + // Assert + assertTrue(sut.isTrainingBatchReady()); + } + } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java index 7334ff87a..2ce0d6659 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java @@ -13,7 +13,7 @@ public class StateActionExperienceHandlerTest { @Test public void when_addingExperience_expect_generateTrainingBatchReturnsIt() { // Arrange - StateActionExperienceHandler sut = new StateActionExperienceHandler(); + StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE); sut.reset(); Observation observation = new Observation(Nd4j.zeros(1)); sut.addExperience(observation, 123, 234.0, true); @@ -32,7 +32,7 @@ public class StateActionExperienceHandlerTest { @Test public void when_addingMultipleExperiences_expect_generateTrainingBatchReturnsItInSameOrder() { // Arrange - StateActionExperienceHandler sut = new StateActionExperienceHandler(); + StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE); sut.reset(); sut.addExperience(null, 1, 1.0, false); sut.addExperience(null, 2, 2.0, false); @@ -51,7 +51,7 @@ public class StateActionExperienceHandlerTest { @Test public void when_gettingExperience_expect_experienceStoreIsCleared() { // Arrange - StateActionExperienceHandler sut = new StateActionExperienceHandler(); + StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE); sut.reset(); sut.addExperience(null, 1, 1.0, false); @@ -67,7 +67,7 @@ public class StateActionExperienceHandlerTest { @Test public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() { // Arrange - StateActionExperienceHandler sut = new StateActionExperienceHandler(); + StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE); sut.reset(); sut.addExperience(null, 1, 1.0, false); sut.addExperience(null, 2, 2.0, false); @@ -79,4 +79,66 @@ public class StateActionExperienceHandlerTest { // Assert assertEquals(3, size); } + + @Test + public void when_experienceIsEmpty_expect_TrainingBatchNotReady() { + // Arrange + StateActionExperienceHandler sut = new StateActionExperienceHandler(5); + sut.reset(); + + // Act + boolean isTrainingBatchReady = sut.isTrainingBatchReady(); + + // Assert + assertFalse(isTrainingBatchReady); + } + + @Test + public void when_experienceSizeIsGreaterOrEqualToThanBatchSize_expect_TrainingBatchIsReady() { + // Arrange + StateActionExperienceHandler sut = new StateActionExperienceHandler(5); + sut.reset(); + sut.addExperience(null, 1, 1.0, false); + sut.addExperience(null, 2, 2.0, false); + sut.addExperience(null, 3, 3.0, false); + sut.addExperience(null, 4, 4.0, false); + sut.addExperience(null, 5, 5.0, false); + + // Act + boolean isTrainingBatchReady = sut.isTrainingBatchReady(); + + // Assert + assertTrue(isTrainingBatchReady); + } + + @Test + public void when_experienceSizeIsSmallerThanBatchSizeButFinalObservationIsSet_expect_TrainingBatchIsReady() { + // Arrange + StateActionExperienceHandler sut = new StateActionExperienceHandler(5); + sut.reset(); + sut.addExperience(null, 1, 1.0, false); + sut.addExperience(null, 2, 2.0, false); + sut.setFinalObservation(null); + + // Act + boolean isTrainingBatchReady = sut.isTrainingBatchReady(); + + // Assert + assertTrue(isTrainingBatchReady); + } + + @Test + public void when_experienceSizeIsZeroAndFinalObservationIsSet_expect_TrainingBatchIsNotReady() { + // Arrange + StateActionExperienceHandler sut = new StateActionExperienceHandler(5); + sut.reset(); + sut.setFinalObservation(null); + + // Act + boolean isTrainingBatchReady = sut.isTrainingBatchReady(); + + // Assert + assertFalse(isTrainingBatchReady); + } + } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java index e1c5c64ed..7af15b8c4 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java @@ -49,4 +49,25 @@ public class INDArrayHelperTest { assertEquals(1, output.shape()[1]); } + @Test + public void when_callingCreateBatchForShape_expect_INDArrayWithCorrectShapeAndOriginalShapeUnchanged() { + // Arrange + long[] shape = new long[] { 1, 3, 4}; + + // Act + INDArray output = INDArrayHelper.createBatchForShape(2, shape); + + // Assert + // Output shape + assertEquals(3, output.shape().length); + assertEquals(2, output.shape()[0]); + assertEquals(3, output.shape()[1]); + assertEquals(4, output.shape()[2]); + + // Input should remain unchanged + assertEquals(1, shape[0]); + assertEquals(3, shape[1]); + assertEquals(4, shape[2]); + + } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java index f44437d67..ae83bd1f0 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java @@ -19,10 +19,11 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete; import org.deeplearning4j.rl4j.experience.StateActionPair; import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; +import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.observation.Observation; -import org.deeplearning4j.rl4j.support.MockDQN; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; @@ -32,6 +33,9 @@ import java.util.ArrayList; import java.util.List; import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.*; @RunWith(MockitoJUnitRunner.class) public class QLearningUpdateAlgorithmTest { @@ -39,12 +43,24 @@ public class QLearningUpdateAlgorithmTest { @Mock AsyncGlobal mockAsyncGlobal; + @Mock + IDQN dqnMock; + + private UpdateAlgorithm sut; + + private void setup(double gamma) { + // mock a neural net output -- just invert the sign of the input + when(dqnMock.outputAll(any(INDArray.class))).thenAnswer(invocation -> new INDArray[] { invocation.getArgument(0, INDArray.class).mul(-1.0) }); + + sut = new QLearningUpdateAlgorithm(2, gamma); + } + @Test public void when_isTerminal_expect_initRewardIs0() { // Arrange - MockDQN dqnMock = new MockDQN(); - UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 1 }, 1, 1.0); - final Observation observation = new Observation(Nd4j.zeros(1)); + setup(1.0); + + final Observation observation = new Observation(Nd4j.zeros(1, 2)); List> experience = new ArrayList>() { { add(new StateActionPair(observation, 0, 0.0, true)); @@ -55,59 +71,68 @@ public class QLearningUpdateAlgorithmTest { sut.computeGradients(dqnMock, experience); // Assert - assertEquals(0.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001); + verify(dqnMock, times(1)).gradient(any(INDArray.class), argThat((INDArray x) -> x.getDouble(0) == 0.0)); } @Test public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() { // Arrange - UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, 1.0); - final Observation observation = new Observation(Nd4j.create(new double[] { -123.0, -234.0 })); + setup(1.0); + + final Observation observation = new Observation(Nd4j.create(new double[] { -123.0, -234.0 }).reshape(1, 2)); List> experience = new ArrayList>() { { add(new StateActionPair(observation, 0, 0.0, false)); } }; - MockDQN dqnMock = new MockDQN(); // Act sut.computeGradients(dqnMock, experience); // Assert - assertEquals(2, dqnMock.outputAllParams.size()); - assertEquals(-123.0, dqnMock.outputAllParams.get(0).getDouble(0, 0), 0.00001); - assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001); + ArgumentCaptor argument = ArgumentCaptor.forClass(INDArray.class); + + verify(dqnMock, times(2)).outputAll(argument.capture()); + List values = argument.getAllValues(); + assertEquals(-123.0, values.get(0).getDouble(0, 0), 0.00001); + assertEquals(-123.0, values.get(1).getDouble(0, 0), 0.00001); + + verify(dqnMock, times(1)).gradient(any(INDArray.class), argThat((INDArray x) -> x.getDouble(0) == 234.0)); } @Test public void when_callingWithMultipleExperiences_expect_gradientsAreValid() { // Arrange double gamma = 0.9; - UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, gamma); + setup(gamma); + List> experience = new ArrayList>() { { - add(new StateActionPair(new Observation(Nd4j.create(new double[] { -1.1, -1.2 })), 0, 1.0, false)); - add(new StateActionPair(new Observation(Nd4j.create(new double[] { -2.1, -2.2 })), 1, 2.0, true)); + add(new StateActionPair(new Observation(Nd4j.create(new double[] { -1.1, -1.2 }).reshape(1, 2)), 0, 1.0, false)); + add(new StateActionPair(new Observation(Nd4j.create(new double[] { -2.1, -2.2 }).reshape(1, 2)), 1, 2.0, true)); } }; - MockDQN dqnMock = new MockDQN(); // Act sut.computeGradients(dqnMock, experience); // Assert + ArgumentCaptor features = ArgumentCaptor.forClass(INDArray.class); + ArgumentCaptor targets = ArgumentCaptor.forClass(INDArray.class); + verify(dqnMock, times(1)).gradient(features.capture(), targets.capture()); + // input side -- should be a stack of observations - INDArray input = dqnMock.gradientParams.get(0).getLeft(); - assertEquals(-1.1, input.getDouble(0, 0), 0.00001); - assertEquals(-1.2, input.getDouble(0, 1), 0.00001); - assertEquals(-2.1, input.getDouble(1, 0), 0.00001); - assertEquals(-2.2, input.getDouble(1, 1), 0.00001); + INDArray featuresValues = features.getValue(); + assertEquals(-1.1, featuresValues.getDouble(0, 0), 0.00001); + assertEquals(-1.2, featuresValues.getDouble(0, 1), 0.00001); + assertEquals(-2.1, featuresValues.getDouble(1, 0), 0.00001); + assertEquals(-2.2, featuresValues.getDouble(1, 1), 0.00001); // target side - INDArray target = dqnMock.gradientParams.get(0).getRight(); - assertEquals(1.0 + gamma * 2.0, target.getDouble(0, 0), 0.00001); - assertEquals(1.2, target.getDouble(0, 1), 0.00001); - assertEquals(2.1, target.getDouble(1, 0), 0.00001); - assertEquals(2.0, target.getDouble(1, 1), 0.00001); + INDArray targetsValues = targets.getValue(); + assertEquals(1.0 + gamma * 2.0, targetsValues.getDouble(0, 0), 0.00001); + assertEquals(1.2, targetsValues.getDouble(0, 1), 0.00001); + assertEquals(2.1, targetsValues.getDouble(1, 0), 0.00001); + assertEquals(2.0, targetsValues.getDouble(1, 1), 0.00001); } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index e19af338b..e1424c286 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; import org.deeplearning4j.gym.StepReply; +import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; @@ -74,6 +75,9 @@ public class QLearningDiscreteTest { @Mock QLearningConfiguration mockQlearningConfiguration; + @Mock + ILearningBehavior learningBehavior; + // HWC int[] observationShape = new int[]{3, 10, 10}; int totalObservationSize = 1; @@ -92,18 +96,28 @@ public class QLearningDiscreteTest { } - private void mockTestContext(int maxSteps, int updateStart, int batchSize, double rewardFactor, int maxExperienceReplay) { + private void mockTestContext(int maxSteps, int updateStart, int batchSize, double rewardFactor, int maxExperienceReplay, ILearningBehavior learningBehavior) { when(mockQlearningConfiguration.getBatchSize()).thenReturn(batchSize); when(mockQlearningConfiguration.getRewardFactor()).thenReturn(rewardFactor); when(mockQlearningConfiguration.getExpRepMaxSize()).thenReturn(maxExperienceReplay); when(mockQlearningConfiguration.getSeed()).thenReturn(123L); - qLearningDiscrete = mock( - QLearningDiscrete.class, - Mockito.withSettings() - .useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0) - .defaultAnswer(Mockito.CALLS_REAL_METHODS) - ); + if(learningBehavior != null) { + qLearningDiscrete = mock( + QLearningDiscrete.class, + Mockito.withSettings() + .useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0, learningBehavior, Nd4j.getRandom()) + .defaultAnswer(Mockito.CALLS_REAL_METHODS) + ); + } + else { + qLearningDiscrete = mock( + QLearningDiscrete.class, + Mockito.withSettings() + .useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0) + .defaultAnswer(Mockito.CALLS_REAL_METHODS) + ); + } } private void mockHistoryProcessor(int skipFrames) { @@ -136,7 +150,7 @@ public class QLearningDiscreteTest { public void when_singleTrainStep_expect_correctValues() { // Arrange - mockTestContext(100,0,2,1.0, 10); + mockTestContext(100,0,2,1.0, 10, null); // An example observation and 2 Q values output (2 actions) Observation observation = new Observation(Nd4j.zeros(observationShape)); @@ -162,7 +176,7 @@ public class QLearningDiscreteTest { @Test public void when_singleTrainStepSkippedFrames_expect_correctValues() { // Arrange - mockTestContext(100,0,2,1.0, 10); + mockTestContext(100,0,2,1.0, 10, learningBehavior); Observation skippedObservation = Observation.SkippedObservation; Observation nextObservation = new Observation(Nd4j.zeros(observationShape)); @@ -180,8 +194,8 @@ public class QLearningDiscreteTest { assertEquals(0, stepReply.getReward(), 1e-5); assertFalse(stepReply.isDone()); assertFalse(stepReply.getObservation().isSkipped()); - assertEquals(0, qLearningDiscrete.getExperienceHandler().getTrainingBatchSize()); + verify(learningBehavior, never()).handleNewExperience(any(Observation.class), any(Integer.class), any(Double.class), any(Boolean.class)); verify(mockDQN, never()).output(any(INDArray.class)); }