diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java index 7b353f00b..7242dd64b 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/Transition.java @@ -16,6 +16,7 @@ package org.deeplearning4j.rl4j.learning.sync; +import lombok.AllArgsConstructor; import lombok.Value; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -27,6 +28,7 @@ import org.nd4j.linalg.factory.Nd4j; * State, Action, Reward, (isTerminal), State */ @Value +@AllArgsConstructor public class Transition { INDArray[] observation; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index 564f654fc..363cd5e87 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -43,7 +43,7 @@ import java.util.List; */ @Slf4j public abstract class QLearning> - extends SyncLearning { + extends SyncLearning implements TargetQNetworkSource { // FIXME Changed for refac // @Getter @@ -61,28 +61,19 @@ public abstract class QLearning getMdp(); - protected abstract IDQN getCurrentDQN(); + public abstract IDQN getQNetwork(); - protected abstract IDQN getTargetDQN(); + public abstract IDQN getTargetQNetwork(); - protected abstract void setTargetDQN(IDQN dqn); - - protected INDArray dqnOutput(INDArray input) { - return getCurrentDQN().output(input); - } - - protected INDArray targetDqnOutput(INDArray input) { - return getTargetDQN().output(input); - } + protected abstract void setTargetQNetwork(IDQN dqn); protected void updateTargetNetwork() { log.info("Update 
target network"); - setTargetDQN(getCurrentDQN().clone()); + setTargetQNetwork(getQNetwork().clone()); } - public IDQN getNeuralNet() { - return getCurrentDQN(); + return getQNetwork(); } public abstract QLConfiguration getConfiguration(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QNetworkSource.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QNetworkSource.java new file mode 100644 index 000000000..e22d368e4 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QNetworkSource.java @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning;

import org.deeplearning4j.rl4j.network.dqn.IDQN;

/**
 * An interface for all implementations capable of supplying a Q-Network.
 *
 * @author Alexandre Boulanger
 */
public interface QNetworkSource {

    /**
     * @return The Q-Network (used e.g. by TD target algorithms to compute Q-Values from observations).
     */
    IDQN getQNetwork();
}
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning;

import org.deeplearning4j.rl4j.network.dqn.IDQN;

/**
 * An interface that is an extension of {@link QNetworkSource} for all implementations
 * capable of supplying a target Q-Network in addition to the regular Q-Network.
 *
 * @author Alexandre Boulanger
 */
public interface TargetQNetworkSource extends QNetworkSource {

    /**
     * @return The target Q-Network (a periodically-synchronized copy of the Q-Network,
     *         used to stabilize TD target computation).
     */
    IDQN getTargetQNetwork();
}
org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.util.ArrayUtil; import java.util.ArrayList; @@ -53,29 +52,38 @@ public abstract class QLearningDiscrete extends QLearning mdp; @Getter - final private IDQN currentDQN; - @Getter private DQNPolicy policy; @Getter private EpsGreedy egPolicy; + @Getter - @Setter - private IDQN targetDQN; + final private IDQN qNetwork; + @Getter + @Setter(AccessLevel.PROTECTED) + private IDQN targetQNetwork; + private int lastAction; private INDArray[] history = null; private double accuReward = 0; + ITDTargetAlgorithm tdTargetAlgorithm; + public QLearningDiscrete(MDP mdp, IDQN dqn, QLConfiguration conf, int epsilonNbStep) { super(conf); this.configuration = conf; this.mdp = mdp; - currentDQN = dqn; - targetDQN = dqn.clone(); - policy = new DQNPolicy(getCurrentDQN()); + qNetwork = dqn; + targetQNetwork = dqn.clone(); + policy = new DQNPolicy(getQNetwork()); egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, getRandom(), conf.getMinEpsilon(), this); mdp.getActionSpace().setSeed(conf.getSeed()); + + tdTargetAlgorithm = conf.isDoubleDQN() + ? 
new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp()) + : new StandardDQN(this, conf.getGamma(), conf.getErrorClamp()); + } public void postEpoch() { @@ -134,7 +142,7 @@ public abstract class QLearningDiscrete extends QLearning 2) hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()))); - INDArray qs = getCurrentDQN().output(hstack); + INDArray qs = getQNetwork().output(hstack); int maxAction = Learning.getMaxAction(qs); maxQ = qs.getDouble(maxAction); @@ -160,96 +168,31 @@ public abstract class QLearningDiscrete extends QLearning updateStart) { - Pair targets = setTarget(getExpReplay().getBatch()); - getCurrentDQN().fit(targets.getFirst(), targets.getSecond()); + DataSet targets = setTarget(getExpReplay().getBatch()); + getQNetwork().fit(targets.getFeatures(), targets.getLabels()); } history = nhistory; accuReward = 0; } - - return new QLStepReturn(maxQ, getCurrentDQN().getLatestScore(), stepReply); - + return new QLStepReturn(maxQ, getQNetwork().getLatestScore(), stepReply); } - protected Pair setTarget(ArrayList> transitions) { + protected DataSet setTarget(ArrayList> transitions) { if (transitions.size() == 0) throw new IllegalArgumentException("too few transitions"); - int size = transitions.size(); - + // TODO: Remove once we use DataSets in observations int[] shape = getHistoryProcessor() == null ? 
getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); - int[] nshape = makeShape(size, shape); - INDArray obs = Nd4j.create(nshape); - INDArray nextObs = Nd4j.create(nshape); - int[] actions = new int[size]; - boolean[] areTerminal = new boolean[size]; + ((BaseTDTargetAlgorithm) tdTargetAlgorithm).setNShape(makeShape(transitions.size(), shape)); - for (int i = 0; i < size; i++) { - Transition trans = transitions.get(i); - areTerminal[i] = trans.isTerminal(); - actions[i] = trans.getAction(); - - INDArray[] obsArray = trans.getObservation(); - if (obs.rank() == 2) { - obs.putRow(i, obsArray[0]); - } else { - for (int j = 0; j < obsArray.length; j++) { - obs.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, obsArray[j]); - } - } - - INDArray[] nextObsArray = Transition.append(trans.getObservation(), trans.getNextObservation()); - if (nextObs.rank() == 2) { - nextObs.putRow(i, nextObsArray[0]); - } else { - for (int j = 0; j < nextObsArray.length; j++) { - nextObs.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, nextObsArray[j]); - } - } - } - if (getHistoryProcessor() != null) { - obs.muli(1.0 / getHistoryProcessor().getScale()); - nextObs.muli(1.0 / getHistoryProcessor().getScale()); + // TODO: Remove once we use DataSets in observations + if(getHistoryProcessor() != null) { + ((BaseTDTargetAlgorithm) tdTargetAlgorithm).setScale(getHistoryProcessor().getScale()); } - INDArray dqnOutputAr = dqnOutput(obs); - - INDArray dqnOutputNext = dqnOutput(nextObs); - INDArray targetDqnOutputNext = targetDqnOutput(nextObs); - - INDArray tempQ = null; - INDArray getMaxAction = null; - if (getConfiguration().isDoubleDQN()) { - getMaxAction = Nd4j.argMax(dqnOutputNext, 1); - } else { - tempQ = Nd4j.max(targetDqnOutputNext, 1); - } - - - for (int i = 0; i < size; i++) { - double yTar = transitions.get(i).getReward(); - if (!areTerminal[i]) { - double q = 0; - if (getConfiguration().isDoubleDQN()) { - q += 
targetDqnOutputNext.getDouble(i, getMaxAction.getInt(i)); - } else - q += tempQ.getDouble(i); - - yTar += getConfiguration().getGamma() * q; - - } - - double previousV = dqnOutputAr.getDouble(i, actions[i]); - double lowB = previousV - getConfiguration().getErrorClamp(); - double highB = previousV + getConfiguration().getErrorClamp(); - double clamped = Math.min(highB, Math.max(yTar, lowB)); - - dqnOutputAr.putScalar(i, actions[i], clamped); - } - - return new Pair(obs, dqnOutputAr); + return tdTargetAlgorithm.computeTDTargets(transitions); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java new file mode 100644 index 000000000..3f27f954c --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java @@ -0,0 +1,62 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * The base of all DQN-based TD target algorithms.
 * <p>
 * Before each batch computation it caches the outputs of both the Q-Network and the
 * target Q-Network on the next observations, so that subclasses can compute their
 * per-transition targets from these pre-computed arrays.
 *
 * @author Alexandre Boulanger
 *
 */
public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {

    // Kept as the concrete TargetQNetworkSource type (the superclass only sees a
    // QNetworkSource) so that the target network is accessible here.
    private final TargetQNetworkSource qTargetNetworkSource;

    /**
     * In literature, this corresponds to Q{net}(s(t+1), a) — the Q-Network's output
     * on the next observations.
     */
    protected INDArray qNetworkNextObservation;

    /**
     * In literature, this corresponds to Q{tnet}(s(t+1), a) — the target Q-Network's
     * output on the next observations.
     */
    protected INDArray targetQNetworkNextObservation;

    /**
     * @param qTargetNetworkSource The source of both the Q-Network and the target Q-Network
     * @param gamma The discount factor (error clamping disabled with this ctor)
     */
    protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma) {
        super(qTargetNetworkSource, gamma);
        this.qTargetNetworkSource = qTargetNetworkSource;
    }

    /**
     * @param qTargetNetworkSource The source of both the Q-Network and the target Q-Network
     * @param gamma The discount factor
     * @param errorClamp Limits how far a new Q-Value may move from the previous value
     */
    protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
        super(qTargetNetworkSource, gamma, errorClamp);
        this.qTargetNetworkSource = qTargetNetworkSource;
    }

    @Override
    protected void initComputation(INDArray observations, INDArray nextObservations) {
        super.initComputation(observations, nextObservations);

        // Cache both networks' outputs on s(t+1) once per batch; subclasses read these.
        qNetworkNextObservation = qNetworkSource.getQNetwork().output(nextObservations);

        IDQN targetQNetwork = qTargetNetworkSource.getTargetQNetwork();
        targetQNetworkNextObservation = targetQNetwork.output(nextObservations);
    }
}
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import lombok.Setter;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.util.List;

/**
 * The base of all TD target calculation algorithms that use deep learning.
 * <p>
 * Template method: {@link #computeTDTargets(List)} stacks the transitions'
 * observations and next-observations into two batched INDArrays, calls the
 * {@link #initComputation(INDArray, INDArray)} hook, then asks the subclass for a
 * per-transition target via {@link #computeTarget(int, double, boolean)}, optionally
 * clamping it around the Q-Network's previous estimate.
 *
 * @author Alexandre Boulanger
 */
public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {

    // The source of the Q-Network whose outputs are updated in-place to form the labels.
    protected final QNetworkSource qNetworkSource;
    // The discount factor applied to the estimated future Q-Value.
    protected final double gamma;

    // Maximum allowed distance between the new target and the previous Q-Value.
    private final double errorClamp;
    // True when errorClamp is a real number (clamping enabled); see ctor.
    private final boolean isClamped;

    // Shape of the batched observation arrays, [batchSize, ...observationShape].
    @Setter
    private int[] nShape; // TODO: Remove once we use DataSets in observations
    // Divisor applied to observations (e.g. pixel scale from the history processor); 1.0 = no-op.
    @Setter
    private double scale = 1.0; // TODO: Remove once we use DataSets in observations

    /**
     *
     * @param qNetworkSource The source of the Q-Network
     * @param gamma The discount factor
     * @param errorClamp Will prevent the new Q-Value from being farther than errorClamp away from the previous value. Double.NaN will disable the clamping.
     */
    protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) {
        this.qNetworkSource = qNetworkSource;
        this.gamma = gamma;

        this.errorClamp = errorClamp;
        isClamped = !Double.isNaN(errorClamp);
    }

    /**
     *
     * @param qNetworkSource The source of the Q-Network
     * @param gamma The discount factor
     * Note: Error clamping is disabled with this ctor
     */
    protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma) {
        this(qNetworkSource, gamma, Double.NaN);
    }

    /**
     * Called just before the calculation starts; subclasses may pre-compute batch-wide
     * values here (e.g. network outputs on the next observations).
     *
     * @param observations A INDArray of all observations stacked on dimension 0
     * @param nextObservations A INDArray of all next observations stacked on dimension 0
     */
    protected void initComputation(INDArray observations, INDArray nextObservations) {
        // Do nothing
    }

    /**
     * Compute the new estimated Q-Value for every transition in the batch
     *
     * @param batchIdx The index in the batch of the current transition
     * @param reward The reward of the current transition
     * @param isTerminal True if it's the last transition of the "game"
     * @return The estimated Q-Value
     */
    protected abstract double computeTarget(int batchIdx, double reward, boolean isTerminal);

    @Override
    public DataSet computeTDTargets(List<Transition<Integer>> transitions) {

        int size = transitions.size();

        INDArray observations = Nd4j.create(nShape);
        INDArray nextObservations = Nd4j.create(nShape);

        // Stack each transition's observation frames into the batch arrays.
        // Rank 2 means one frame per observation (plain putRow); otherwise each frame j
        // is written at [i, j, ...].
        // TODO: Remove once we use DataSets in observations
        for (int i = 0; i < size; i++) {
            Transition<Integer> trans = transitions.get(i);

            INDArray[] obsArray = trans.getObservation();
            if (observations.rank() == 2) {
                observations.putRow(i, obsArray[0]);
            } else {
                for (int j = 0; j < obsArray.length; j++) {
                    observations.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, obsArray[j]);
                }
            }

            // The next observation reuses the history frames shifted by one step.
            INDArray[] nextObsArray = Transition.append(trans.getObservation(), trans.getNextObservation());
            if (nextObservations.rank() == 2) {
                nextObservations.putRow(i, nextObsArray[0]);
            } else {
                for (int j = 0; j < nextObsArray.length; j++) {
                    nextObservations.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, nextObsArray[j]);
                }
            }
        }

        // TODO: Remove once we use DataSets in observations
        if(scale != 1.0) {
            observations.muli(1.0 / scale);
            nextObservations.muli(1.0 / scale);
        }

        initComputation(observations, nextObservations);

        // Start from the Q-Network's current estimates and overwrite only the taken
        // action's Q-Value, so the fit() error is zero for all other actions.
        INDArray updatedQValues = qNetworkSource.getQNetwork().output(observations);

        for (int i = 0; i < size; ++i) {
            Transition<Integer> transition = transitions.get(i);
            double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal());

            if(isClamped) {
                // Keep the target within errorClamp of the previous estimate.
                double previousQValue = updatedQValues.getDouble(i, transition.getAction());
                double lowBound = previousQValue - errorClamp;
                double highBound = previousQValue + errorClamp;
                yTarget = Math.min(highBound, Math.max(yTarget, lowBound));
            }
            updatedQValues.putScalar(i, transition.getAction(), yTarget);
        }

        return new org.nd4j.linalg.dataset.DataSet(observations, updatedQValues);
    }
}
b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java new file mode 100644 index 000000000..3203af1b8 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java @@ -0,0 +1,67 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; + +import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +/** + * The Double-DQN algorithm based on "Deep Reinforcement Learning with Double Q-learning" (https://arxiv.org/abs/1509.06461) + * + * @author Alexandre Boulanger + */ +public class DoubleDQN extends BaseDQNAlgorithm { + + private static final int ACTION_DIMENSION_IDX = 1; + + // In litterature, this corresponds to: max_{a}Q(s_{t+1}, a) + private INDArray maxActionsFromQNetworkNextObservation; + + public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) { + super(qTargetNetworkSource, gamma); + } + + public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) { + 
super(qTargetNetworkSource, gamma, errorClamp); + } + + @Override + protected void initComputation(INDArray observations, INDArray nextObservations) { + super.initComputation(observations, nextObservations); + + maxActionsFromQNetworkNextObservation = Nd4j.argMax(qNetworkNextObservation, ACTION_DIMENSION_IDX); + } + + /** + * In litterature, this corresponds to:
+ * Q(s_t, a_t) = R_{t+1} + \gamma * Q_{tar}(s_{t+1}, max_{a}Q(s_{t+1}, a)) + * @param batchIdx The index in the batch of the current transition + * @param reward The reward of the current transition + * @param isTerminal True if it's the last transition of the "game" + * @return The estimated Q-Value + */ + @Override + protected double computeTarget(int batchIdx, double reward, boolean isTerminal) { + double yTarget = reward; + if (!isTerminal) { + yTarget += gamma * targetQNetworkNextObservation.getDouble(batchIdx, maxActionsFromQNetworkNextObservation.getInt(batchIdx)); + } + + return yTarget; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/ITDTargetAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/ITDTargetAlgorithm.java new file mode 100644 index 000000000..199c0e7e3 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/ITDTargetAlgorithm.java @@ -0,0 +1,38 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.nd4j.linalg.dataset.api.DataSet;

import java.util.List;

/**
 * The interface of all TD target calculation algorithms.
 *
 * @param <A> The type of actions
 *
 * @author Alexandre Boulanger
 */
public interface ITDTargetAlgorithm<A> {
    /**
     * Compute the updated estimated Q-Values for every transition
     * @param transitions The transitions from the experience replay
     * @return A DataSet where every element is the observation and the estimated Q-Values for all actions
     */
    DataSet computeTDTargets(List<Transition<A>> transitions);
}
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; + +import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +/** + * The Standard DQN algorithm based on "Playing Atari with Deep Reinforcement Learning" (https://arxiv.org/abs/1312.5602) + * + * @author Alexandre Boulanger + */ +public class StandardDQN extends BaseDQNAlgorithm { + + private static final int ACTION_DIMENSION_IDX = 1; + + // In litterature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a) + private INDArray maxActionsFromQTargetNextObservation; + + public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) { + super(qTargetNetworkSource, gamma); + } + + public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) { + super(qTargetNetworkSource, gamma, errorClamp); + } + + @Override + protected void initComputation(INDArray observations, INDArray nextObservations) { + super.initComputation(observations, nextObservations); + + maxActionsFromQTargetNextObservation = Nd4j.max(targetQNetworkNextObservation, ACTION_DIMENSION_IDX); + } + + /** + * In litterature, this corresponds to:
+ * Q(s_t, a_t) = R_{t+1} + \gamma * max_{a}Q_{tar}(s_{t+1}, a) + * @param batchIdx The index in the batch of the current transition + * @param reward The reward of the current transition + * @param isTerminal True if it's the last transition of the "game" + * @return The estimated Q-Value + */ + @Override + protected double computeTarget(int batchIdx, double reward, boolean isTerminal) { + double yTarget = reward; + if (!isTerminal) { + yTarget += gamma * maxActionsFromQTargetNextObservation.getDouble(batchIdx); + } + + return yTarget; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java index ee843dd3f..c6ae2f5ac 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java @@ -35,7 +35,7 @@ public interface IDQN extends NeuralNet { void fit(INDArray input, INDArray labels); void fit(INDArray input, INDArray[] labels); - + INDArray output(INDArray batch); INDArray[] outputAll(INDArray batch); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index 1a02d6e50..2982a1d21 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -12,8 +12,8 @@ import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.primitives.Pair; import java.util.ArrayList; 
import java.util.List; @@ -139,8 +139,8 @@ public class QLearningDiscreteTest { } @Override - protected Pair setTarget(ArrayList> transitions) { - return new Pair<>(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 })); + protected DataSet setTarget(ArrayList> transitions) { + return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 })); } public void setExpReplay(IExpReplay exp){ diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java new file mode 100644 index 000000000..e598b66ca --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java @@ -0,0 +1,105 @@ +package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; + +import org.deeplearning4j.rl4j.learning.sync.Transition; +import org.deeplearning4j.rl4j.learning.sync.support.MockDQN; +import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class DoubleDQNTest { + + @Test + public void when_isTerminal_expect_rewardValueAtIdx0() { + + // Assemble + MockDQN qNetwork = new MockDQN(); + MockDQN targetQNetwork = new MockDQN(); + MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); + + List> transitions = new ArrayList>() { + { + add(new Transition(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, true, Nd4j.create(new double[]{11.0, 22.0}))); + } + }; + + DoubleDQN sut = new 
DoubleDQN(targetQNetworkSource, 0.5); + sut.setNShape(new int[] { 1, 2 }); + + // Act + DataSet result = sut.computeTDTargets(transitions); + + // Assert + INDArray evaluatedQValues = result.getLabels(); + assertEquals(1.0, evaluatedQValues.getDouble(0, 0), 0.0001); + assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001); + } + + @Test + public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() { + + // Assemble + MockDQN qNetwork = new MockDQN(); + MockDQN targetQNetwork = new MockDQN(-1.0); + MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); + + List> transitions = new ArrayList>() { + { + add(new Transition(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0}))); + } + }; + + DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); + sut.setNShape(new int[] { 1, 2 }); + + // Act + DataSet result = sut.computeTDTargets(transitions); + + // Assert + INDArray evaluatedQValues = result.getLabels(); + assertEquals(1.0 + 0.5 * -22.0, evaluatedQValues.getDouble(0, 0), 0.0001); + assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001); + } + + @Test + public void when_batchHasMoreThanOne_expect_everySampleEvaluated() { + + // Assemble + MockDQN qNetwork = new MockDQN(); + MockDQN targetQNetwork = new MockDQN(-1.0); + MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); + + List> transitions = new ArrayList>() { + { + add(new Transition(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0}))); + add(new Transition(new INDArray[]{Nd4j.create(new double[]{3.3, 4.4})}, 1, 2.0, false, Nd4j.create(new double[]{33.0, 44.0}))); + add(new Transition(new INDArray[]{Nd4j.create(new double[]{5.5, 6.6})}, 0, 3.0, true, Nd4j.create(new double[]{55.0, 66.0}))); + } + }; + + DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); + sut.setNShape(new int[] { 3, 
2 }); + + // Act + DataSet result = sut.computeTDTargets(transitions); + + // Assert + INDArray evaluatedQValues = result.getLabels(); + assertEquals(1.0 + 0.5 * -22.0, evaluatedQValues.getDouble(0, 0), 0.0001); + assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001); + + assertEquals(3.3, evaluatedQValues.getDouble(1, 0), 0.0001); + assertEquals(2.0 + 0.5 * -44.0, evaluatedQValues.getDouble(1, 1), 0.0001); + + assertEquals(3.0, evaluatedQValues.getDouble(2, 0), 0.0001); // terminal: reward only + assertEquals(6.6, evaluatedQValues.getDouble(2, 1), 0.0001); + + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java new file mode 100644 index 000000000..02dcdf6fd --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java @@ -0,0 +1,104 @@ +package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; + +import org.deeplearning4j.rl4j.learning.sync.Transition; +import org.deeplearning4j.rl4j.learning.sync.support.MockDQN; +import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.*; + +public class StandardDQNTest { + @Test + public void when_isTerminal_expect_rewardValueAtIdx0() { + + // Assemble + MockDQN qNetwork = new MockDQN(); + MockDQN targetQNetwork = new MockDQN(); + MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); + + List> transitions = new ArrayList>() { + { + add(new Transition(new INDArray[]{Nd4j.create(new 
double[]{1.1, 2.2})}, 0, 1.0, true, Nd4j.create(new double[]{11.0, 22.0}))); + } + }; + + StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); + sut.setNShape(new int[] { 1, 2 }); + + // Act + DataSet result = sut.computeTDTargets(transitions); + + // Assert + INDArray evaluatedQValues = result.getLabels(); + assertEquals(1.0, evaluatedQValues.getDouble(0, 0), 0.0001); + assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001); + } + + @Test + public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() { + + // Assemble + MockDQN qNetwork = new MockDQN(); + MockDQN targetQNetwork = new MockDQN(); + MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); + + List> transitions = new ArrayList>() { + { + add(new Transition(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0}))); + } + }; + + StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); + sut.setNShape(new int[] { 1, 2 }); + + // Act + DataSet result = sut.computeTDTargets(transitions); + + // Assert + INDArray evaluatedQValues = result.getLabels(); + assertEquals(1.0 + 0.5 * 22.0, evaluatedQValues.getDouble(0, 0), 0.0001); + assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001); + } + + @Test + public void when_batchHasMoreThanOne_expect_everySampleEvaluated() { + + // Assemble + MockDQN qNetwork = new MockDQN(); + MockDQN targetQNetwork = new MockDQN(); + MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); + + List> transitions = new ArrayList>() { + { + add(new Transition(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0}))); + add(new Transition(new INDArray[]{Nd4j.create(new double[]{3.3, 4.4})}, 1, 2.0, false, Nd4j.create(new double[]{33.0, 44.0}))); + add(new Transition(new INDArray[]{Nd4j.create(new double[]{5.5, 6.6})}, 0, 3.0, true, Nd4j.create(new double[]{55.0, 
66.0}))); + } + }; + + StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); + sut.setNShape(new int[] { 3, 2 }); + + // Act + DataSet result = sut.computeTDTargets(transitions); + + // Assert + INDArray evaluatedQValues = result.getLabels(); + assertEquals((1.0 + 0.5 * 22.0), evaluatedQValues.getDouble(0, 0), 0.0001); + assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001); + + assertEquals(3.3, evaluatedQValues.getDouble(1, 0), 0.0001); + assertEquals((2.0 + 0.5 * 44.0), evaluatedQValues.getDouble(1, 1), 0.0001); + + assertEquals(3.0, evaluatedQValues.getDouble(2, 0), 0.0001); // terminal: reward only + assertEquals(6.6, evaluatedQValues.getDouble(2, 1), 0.0001); + + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java index 7d088f060..08957fee5 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java @@ -1,15 +1,28 @@ package org.deeplearning4j.rl4j.learning.sync.support; +import lombok.Setter; import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import java.io.IOException; import java.io.OutputStream; public class MockDQN implements IDQN { + + private final double mult; + + public MockDQN() { + this(1.0); + } + + public MockDQN(double mult) { + this.mult = mult; + } + @Override public NeuralNetwork[] getNeuralNetworks() { return new NeuralNetwork[0]; @@ -37,7 +50,11 @@ public class MockDQN implements IDQN { @Override public INDArray output(INDArray batch) { - return null; + if(mult != 1.0) { + return batch.dup().muli(mult); + } + + return batch; } 
@Override diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java new file mode 100644 index 000000000..ce756aa88 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java @@ -0,0 +1,26 @@ +package org.deeplearning4j.rl4j.learning.sync.support; + +import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; +import org.deeplearning4j.rl4j.network.dqn.IDQN; + +public class MockTargetQNetworkSource implements TargetQNetworkSource { + + + private final IDQN qNetwork; + private final IDQN targetQNetwork; + + public MockTargetQNetworkSource(IDQN qNetwork, IDQN targetQNetwork) { + this.qNetwork = qNetwork; + this.targetQNetwork = targetQNetwork; + } + + @Override + public IDQN getTargetQNetwork() { + return targetQNetwork; + } + + @Override + public IDQN getQNetwork() { + return qNetwork; + } +}