From 7c420d7ce56a4e05edc4d2d9c6db44400001ea99 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 6 Sep 2023 17:38:21 +0100 Subject: [PATCH 01/15] init Signed-off-by: Matteo Bettini --- torchrl/modules/tensordict_module/actors.py | 47 +++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index da719102179..8ac86f4ae08 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -327,6 +327,8 @@ class QValueModule(TensorDictModuleBase): conditions the action_space. action_value_key (str or tuple of str, optional): The input key representing the action value. Defaults to ``"action_value"``. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). out_keys (list of str or tuple of str, optional): The output keys representing the actions, action values and chosen action value. Defaults to ``["action", "action_value", "chosen_action_value"]``. @@ -378,6 +380,7 @@ def __init__( self, action_space: Optional[str], action_value_key: Optional[NestedKey] = None, + action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, var_nums: Optional[int] = None, spec: Optional[TensorSpec] = None, @@ -408,6 +411,9 @@ def __init__( if action_value_key is None: action_value_key = "action_value" self.in_keys = [action_value_key] + self.action_mask_key = action_mask_key + if self.action_mask_key is not None: + self.in_keys.append(self.action_mask_key) if out_keys is None: out_keys = ["action", action_value_key, "chosen_action_value"] elif action_value_key not in out_keys: @@ -446,6 +452,13 @@ def forward(self, tensordict: torch.Tensor) -> TensorDictBase: raise KeyError( f"Action value key {self.action_value_key} not found in {tensordict}." ) + if self.action_mask_key is not None: + action_mask = tensordict.get(self.action_mask_key, None) + if action_mask is None: + raise KeyError( + f"Action mask key {self.action_mask_key} not found in {tensordict}." + ) + action_values[action_mask] = torch.finfo(action_values.dtype).min action = self.action_func_mapping[self.action_space](action_values) @@ -528,6 +541,8 @@ class DistributionalQValueModule(QValueModule): support (torch.Tensor): support of the action values. action_value_key (str or tuple of str, optional): The input key representing the action value. Defaults to ``"action_value"``. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). out_keys (list of str or tuple of str, optional): The output keys representing the actions and action values. Defaults to ``["action", "action_value"]``. @@ -583,6 +598,7 @@ def __init__( action_space: Optional[str], support: torch.Tensor, action_value_key: Optional[NestedKey] = None, + action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, var_nums: Optional[int] = None, spec: TensorSpec = None, @@ -595,6 +611,7 @@ def __init__( super().__init__( action_space=action_space, action_value_key=action_value_key, + action_mask_key=action_mask_key, out_keys=out_keys, var_nums=var_nums, spec=spec, @@ -609,6 +626,13 @@ def forward(self, tensordict: torch.Tensor) -> TensorDictBase: raise KeyError( f"Action value key {self.action_value_key} not found in {tensordict}." ) + if self.action_mask_key is not None: + action_mask = tensordict.get(self.action_mask_key, None) + if action_mask is None: + raise KeyError( + f"Action mask key {self.action_mask_key} not found in {tensordict}." + ) + action_values[action_mask] = torch.finfo(action_values.dtype).min action = self.action_func_mapping[self.action_space](action_values) @@ -698,6 +722,8 @@ class QValueHook: action_value_key (str or tuple of str, optional): to be used when hooked on a TensorDictModule. The input key representing the action value. Defaults to ``"action_value"``. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). out_keys (list of str or tuple of str, optional): to be used when hooked on a TensorDictModule. The output keys representing the actions, action values and chosen action value. Defaults to ``["action", "action_value", "chosen_action_value"]``. @@ -733,6 +759,7 @@ def __init__( action_space: str, var_nums: Optional[int] = None, action_value_key: Optional[NestedKey] = None, + action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, ): if isinstance(action_space, TensorSpec): @@ -747,6 +774,7 @@ def __init__( action_space=action_space, var_nums=var_nums, action_value_key=action_value_key, + action_mask_key=action_mask_key, out_keys=out_keys, ) action_value_key = self.qvalue_model.in_keys[0] @@ -776,6 +804,11 @@ class DistributionalQValueHook(QValueHook): Args: action_space (str): Action space. Must be one of ``"one-hot"``, ``"mult-one-hot"``, ``"binary"`` or ``"categorical"``. + action_value_key (str or tuple of str, optional): to be used when hooked on + a TensorDictModule. The input key representing the action value. Defaults + to ``"action_value"``. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). support (torch.Tensor): support of the action values. var_nums (int, optional): if ``action_space = "mult-one-hot"``, this value represents the cardinality of each @@ -823,6 +856,7 @@ def __init__( support: torch.Tensor, var_nums: Optional[int] = None, action_value_key: Optional[NestedKey] = None, + action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None, ): if isinstance(action_space, TensorSpec): @@ -837,6 +871,7 @@ def __init__( var_nums=var_nums, support=support, action_value_key=action_value_key, + action_mask_key=action_mask_key, out_keys=out_keys, ) action_value_key = self.qvalue_model.in_keys[0] @@ -884,6 +919,8 @@ class QValueActor(SafeSequential): is a :class:`tensordict.nn.TensorDictModuleBase` instance, it must match one of its output keys. Otherwise, this string represents the name of the action-value entry in the output tensordict. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). .. note:: ``out_keys`` cannot be passed. If the module is a :class:`tensordict.nn.TensorDictModule` @@ -942,6 +979,7 @@ def __init__( safe=False, action_space: Optional[str] = None, action_value_key=None, + action_mask_key: Optional[NestedKey] = None, ): if isinstance(action_space, TensorSpec): warnings.warn( @@ -987,6 +1025,7 @@ def __init__( spec=spec, safe=safe, action_space=action_space, + action_mask_key=action_mask_key, ) super().__init__(module, qvalue) @@ -1035,6 +1074,12 @@ class DistributionalQValueActor(QValueActor): make_log_softmax (bool, optional): if ``True`` and if the module is not of type :class:`torchrl.modules.DistributionalDQNnet`, a log-softmax operation will be applied along dimension -2 of the action value tensor. + action_value_key (str or tuple of str, optional): if the input module + is a :class:`tensordict.nn.TensorDictModuleBase` instance, it must + match one of its output keys. Otherwise, this string represents + the name of the action-value entry in the output tensordict. + action_mask_key (str or tuple of str, optional): The input key + representing the action mask. Defaults to ``"None"`` (equivalent to no masking). Examples: >>> import torch @@ -1079,6 +1124,7 @@ def __init__( var_nums: Optional[int] = None, action_space: Optional[str] = None, action_value_key: str = "action_value", + action_mask_key: Optional[NestedKey] = None, make_log_softmax: bool = True, ): if isinstance(action_space, TensorSpec): @@ -1121,6 +1167,7 @@ def __init__( spec=spec, safe=safe, action_space=action_space, + action_mask_key=action_mask_key, support=support, var_nums=var_nums, ) From 3b79ee4ba028931fdd4b9136b3dcf629f9bdbacd Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 6 Sep 2023 18:43:23 +0100 Subject: [PATCH 02/15] amend Signed-off-by: Matteo Bettini --- test/test_actors.py | 31 +++++++++++++++++++++ torchrl/modules/tensordict_module/actors.py | 2 +- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/test/test_actors.py b/test/test_actors.py index d16c95731d5..0fcbdd5d19a 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -613,6 +613,37 @@ def test_qvalue_hook_categorical_1_dim_batch(self, action_space, expected_action assert values.shape == in_values.shape assert (values == in_values).all() + @pytest.mark.parametrize("action_space", ["categorical", "one-hot"]) + def test_qvalue_mask(self, action_space): + torch.manual_seed(0) + shape = (3, 4, 6) + action_values = torch.randn(size=shape) + td = TensorDict({"action_value": action_values.clone()}, [3]) + module = QValueModule( + action_space=action_space, + action_value_key="action_value", + action_mask_key="action_mask", + ) + with pytest.raises(KeyError, match="Action mask key "): + module(td) + + action_mask = torch.randint(high=2, size=shape).to(torch.bool) + while not action_mask.any(dim=-1).all() or action_mask.all(): + action_mask = torch.randint(high=2, size=shape).to(torch.bool) + + td.set("action_mask", action_mask) + module(td) + new_action_values = td.get("action_value") + + assert (new_action_values[~action_mask] != action_values[~action_mask]).all() + assert (new_action_values[action_mask] == action_values[action_mask]).all() + + if action_space == "one-hot": + assert (td.get("action")[action_mask]).any() + assert not (td.get("action")[~action_mask]).any() + else: + assert action_mask.gather(-1, td.get("action").unsqueeze(-1)).all() + @pytest.mark.parametrize("device", get_default_devices()) def test_value_based_policy(device): diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 8ac86f4ae08..578bd667cbf 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -458,7 +458,7 @@ def forward(self, tensordict: torch.Tensor) -> TensorDictBase: raise KeyError( f"Action mask key {self.action_mask_key} not found in {tensordict}." ) - action_values[action_mask] = torch.finfo(action_values.dtype).min + action_values[~action_mask] = torch.finfo(action_values.dtype).min action = self.action_func_mapping[self.action_space](action_values) From eb541cd0035c3f5aab62945987b13b784d5a2581 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 6 Sep 2023 18:49:50 +0100 Subject: [PATCH 03/15] test Signed-off-by: Matteo Bettini --- test/test_actors.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_actors.py b/test/test_actors.py index 0fcbdd5d19a..940d1f5dd96 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -614,9 +614,10 @@ def test_qvalue_hook_categorical_1_dim_batch(self, action_space, expected_action assert (values == in_values).all() @pytest.mark.parametrize("action_space", ["categorical", "one-hot"]) - def test_qvalue_mask(self, action_space): + @pytest.mark.parametrize("action_n", [2, 3, 4, 5]) + def test_qvalue_mask(self, action_space, action_n): torch.manual_seed(0) - shape = (3, 4, 6) + shape = (3, 4, 3, action_n) action_values = torch.randn(size=shape) td = TensorDict({"action_value": action_values.clone()}, [3]) module = QValueModule( @@ -637,6 +638,7 @@ def test_qvalue_mask(self, action_space): assert (new_action_values[~action_mask] != action_values[~action_mask]).all() assert (new_action_values[action_mask] == action_values[action_mask]).all() + assert (td.get("chosen_action_value") > torch.finfo(torch.float).min).all() if action_space == "one-hot": assert (td.get("action")[action_mask]).any() From c328249b0c6bdf3ddb9d4beb2481be252412f685 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 6 Sep 2023 20:58:02 +0100 Subject: [PATCH 04/15] amend Signed-off-by: Matteo Bettini --- docs/source/reference/modules.rst | 1 + examples/multiagent/iql.py | 2 +- examples/multiagent/qmix_vdn.py | 2 +- torchrl/modules/__init__.py | 1 + torchrl/modules/tensordict_module/__init__.py | 1 + torchrl/modules/tensordict_module/actors.py | 5 +- .../modules/tensordict_module/exploration.py | 168 +++++++++++++++++- 7 files changed, 174 insertions(+), 6 deletions(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 704c8e6276a..21edd5b2804 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -69,6 +69,7 @@ other cases, the action written in the tensordict is simply the network output. AdditiveGaussianWrapper EGreedyWrapper + EGreedyModule OrnsteinUhlenbeckProcessWrapper Probabilistic actors diff --git a/examples/multiagent/iql.py b/examples/multiagent/iql.py index 00c7bf5fc87..0970e9607f2 100644 --- a/examples/multiagent/iql.py +++ b/examples/multiagent/iql.py @@ -101,7 +101,7 @@ def train(cfg: "DictConfig"): # noqa: F821 eps_end=0, annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), action_key=env.action_key, - spec=env.unbatched_action_spec[env.action_key], + spec=env.input_spec["full_action_spec"], ) collector = SyncDataCollector( diff --git a/examples/multiagent/qmix_vdn.py b/examples/multiagent/qmix_vdn.py index 55c5ef012ba..648b2ba17eb 100644 --- a/examples/multiagent/qmix_vdn.py +++ b/examples/multiagent/qmix_vdn.py @@ -102,7 +102,7 @@ def train(cfg: "DictConfig"): # noqa: F821 eps_end=0, annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), action_key=env.action_key, - spec=env.unbatched_action_spec[env.action_key], + spec=env.input_spec["full_action_spec"], ) if cfg.loss.mixer_type == "qmix": diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index ad654dbc7c9..604bb3bdca7 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -53,6 +53,7 @@ DistributionalQValueActor, DistributionalQValueHook, DistributionalQValueModule, + EGreedyModule, EGreedyWrapper, LMHeadActorValueOperator, LSTMModule, diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 645c7b6f122..d1930855ab2 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -23,6 +23,7 @@ from .common import SafeModule, VmapModule from .exploration import ( AdditiveGaussianWrapper, + EGreedyModule, EGreedyWrapper, OrnsteinUhlenbeckProcessWrapper, ) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 578bd667cbf..91cd4fcd896 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -410,10 +410,11 @@ def __init__( ) if action_value_key is None: action_value_key = "action_value" - self.in_keys = [action_value_key] self.action_mask_key = action_mask_key + in_keys = [action_value_key] if self.action_mask_key is not None: - self.in_keys.append(self.action_mask_key) + in_keys.append(self.action_mask_key) + self.in_keys = in_keys if out_keys is None: out_keys = ["action", action_value_key, "chosen_action_value"] elif action_value_key not in out_keys: diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 20d26b7aabd..064212838ed 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -7,7 +7,12 @@ import numpy as np import torch -from tensordict.nn import TensorDictModule, TensorDictModuleWrapper + +from tensordict.nn import ( + TensorDictModule, + TensorDictModuleBase, + TensorDictModuleWrapper, +) from tensordict.tensordict import TensorDictBase from tensordict.utils import expand_as_right, expand_right, NestedKey @@ -17,11 +22,155 @@ __all__ = [ "EGreedyWrapper", + "EGreedyModule", "AdditiveGaussianWrapper", "OrnsteinUhlenbeckProcessWrapper", ] +class EGreedyModule(TensorDictModuleBase): + """Epsilon-Greedy module. + + This module updates the action in a tensordict to an epsilon greedy one. + + Keyword Args: + eps_init (scalar, optional): initial epsilon value. + default: 1.0 + eps_end (scalar, optional): final epsilon value. + default: 0.1 + annealing_num_steps (int, optional): number of steps it will take for epsilon to reach the eps_end value + action_key (NestedKey, optional): if the policy module has more than one output key, + its output spec will be of type CompositeSpec. One needs to know where to + find the action spec. + Default is ``"action"``. + action_mask_key (NestedKey, optional): the key where the action maskcan be found in the tensordict. + Default is ``"None"`` (corresponding to no mask). + spec (TensorSpec, optional): if provided, the sampled action will be + projected onto the valid action space once explored. If not provided, + the exploration wrapper will attempt to recover it from the policy. + + .. note:: + It is crucial to incorporate a call to :meth:`~.step` in the training loop + to update the exploration factor. + Since it is not easy to capture this omission no warning or exception + will be raised if this is ommitted! + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictSequential + >>> from torchrl.modules import EGreedyModule, Actor + >>> from torchrl.data import BoundedTensorSpec + >>> torch.manual_seed(0) + >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) + >>> module = torch.nn.Linear(4, 4, bias=False) + >>> policy = Actor(spec=spec, module=module) + >>> explorative_policy = TensorDictSequential(policy, EGreedyModule(eps_init=0.2)) + >>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10]) + >>> print(explorative_policy(td).get("action")) + tensor([[ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.9055, -0.9277, -0.6295, -0.2532], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000], + [ 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=) + + """ + + def __init__( + self, + eps_init: float = 1.0, + eps_end: float = 0.1, + annealing_num_steps: int = 1000, + action_key: Optional[NestedKey] = "action", + action_mask_key: Optional[NestedKey] = None, + spec: Optional[TensorSpec] = None, + ): + self.register_buffer("eps_init", torch.tensor([eps_init])) + self.register_buffer("eps_end", torch.tensor([eps_end])) + if self.eps_end > self.eps_init: + raise RuntimeError("eps should decrease over time or be constant") + self.annealing_num_steps = annealing_num_steps + self.register_buffer("eps", torch.tensor([eps_init])) + self.action_key = action_key + self.action_mask_key = action_mask_key + in_keys = [action_key] + if self.action_mask_key is not None: + in_keys.append(self.action_mask_key) + self.in_keys = in_keys + self.out_keys = [self.action_key] + if spec is not None: + if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: + spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + self._spec = spec + else: + self._spec = CompositeSpec({action_key: None}) + + @property + def spec(self): + return self._spec + + def step(self, frames: int = 1) -> None: + """A step of epsilon decay. + + After self.annealing_num_steps, this function is a no-op. + + Args: + frames (int): number of frames since last step. + + """ + for _ in range(frames): + self.eps.data[0] = max( + self.eps_end.item(), + ( + self.eps - (self.eps_init - self.eps_end) / self.annealing_num_steps + ).item(), + ) + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if exploration_type() == ExplorationType.RANDOM or exploration_type() is None: + if isinstance(self.action_key, tuple) and len(self.action_key) > 1: + action_tensordict = tensordict.get(self.action_key[:-1]) + action_key = self.action_key[-1] + else: + action_tensordict = tensordict + action_key = self.action_key + + out = action_tensordict.get(action_key) + eps = self.eps.item() + cond = ( + torch.rand(action_tensordict.shape, device=action_tensordict.device) + < eps + ).to(out.dtype) + cond = expand_as_right(cond, out) + spec = self.spec + if spec is not None: + if isinstance(spec, CompositeSpec): + spec = spec[self.action_key] + if spec.shape != out.shape: + raise ValueError( + "Action spec shape does not match the action shape" + ) + if self.action_mask_key is not None: + action_mask = tensordict.get(self.action_mask_key, None) + if action_mask is None: + raise KeyError( + f"Action mask key {self.action_mask_key} not found in {tensordict}." + ) + spec.update_mask(action_mask) + out = cond * spec.rand().to(out.device) + (1 - cond) * out + else: + raise RuntimeError( + "spec must be provided by the policy or directly to the exploration wrapper." + ) + action_tensordict.set(action_key, out) + return tensordict + + class EGreedyWrapper(TensorDictModuleWrapper): """Epsilon-Greedy PO wrapper. @@ -38,12 +187,14 @@ class EGreedyWrapper(TensorDictModuleWrapper): its output spec will be of type CompositeSpec. One needs to know where to find the action spec. Default is "action". + action_mask_key (NestedKey, optional): the key where the action maskcan be found in the tensordict. + Default is ``"None"`` (corresponding to no mask). spec (TensorSpec, optional): if provided, the sampled action will be projected onto the valid action space once explored. If not provided, the exploration wrapper will attempt to recover it from the policy. .. note:: - Once an environment has been wrapped in :class:`EGreedyWrapper`, it is + Once a module has been wrapped in :class:`EGreedyWrapper`, it is crucial to incorporate a call to :meth:`~.step` in the training loop to update the exploration factor. Since it is not easy to capture this omission no warning or exception @@ -82,6 +233,7 @@ def __init__( eps_end: float = 0.1, annealing_num_steps: int = 1000, action_key: Optional[NestedKey] = "action", + action_mask_key: Optional[NestedKey] = None, spec: Optional[TensorSpec] = None, ): super().__init__(policy) @@ -92,6 +244,7 @@ def __init__( self.annealing_num_steps = annealing_num_steps self.register_buffer("eps", torch.tensor([eps_init])) self.action_key = action_key + self.action_mask_key = action_mask_key if spec is not None: if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) @@ -149,6 +302,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if spec is not None: if isinstance(spec, CompositeSpec): spec = spec[self.action_key] + if spec.shape != out.shape: + raise ValueError( + "Action spec shape does not match the action shape" + ) + if self.action_mask_key is not None: + action_mask = tensordict.get(self.action_mask_key, None) + if action_mask is None: + raise KeyError( + f"Action mask key {self.action_mask_key} not found in {tensordict}." + ) + spec.update_mask(action_mask) out = cond * spec.rand().to(out.device) + (1 - cond) * out else: raise RuntimeError( From 2c021e3c2377ebc30e6864e92d815b95b2a7f1d7 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 6 Sep 2023 21:04:03 +0100 Subject: [PATCH 05/15] amend Signed-off-by: Matteo Bettini --- .../modules/tensordict_module/exploration.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 064212838ed..7448f8407a7 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -90,19 +90,23 @@ def __init__( action_mask_key: Optional[NestedKey] = None, spec: Optional[TensorSpec] = None, ): - self.register_buffer("eps_init", torch.tensor([eps_init])) - self.register_buffer("eps_end", torch.tensor([eps_end])) - if self.eps_end > self.eps_init: - raise RuntimeError("eps should decrease over time or be constant") - self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.tensor([eps_init])) self.action_key = action_key self.action_mask_key = action_mask_key - in_keys = [action_key] + in_keys = [self.action_key] if self.action_mask_key is not None: in_keys.append(self.action_mask_key) self.in_keys = in_keys self.out_keys = [self.action_key] + + super().__init__() + + self.register_buffer("eps_init", torch.tensor([eps_init])) + self.register_buffer("eps_end", torch.tensor([eps_end])) + if self.eps_end > self.eps_init: + raise RuntimeError("eps should decrease over time or be constant") + self.annealing_num_steps = annealing_num_steps + self.register_buffer("eps", torch.tensor([eps_init])) + if spec is not None: if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) From ff8e05830757bae03c29102f07246a4553f5a4fd Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 6 Sep 2023 21:18:37 +0100 Subject: [PATCH 06/15] fix Signed-off-by: Matteo Bettini --- examples/multiagent/iql.py | 2 +- examples/multiagent/qmix_vdn.py | 2 +- .../modules/tensordict_module/exploration.py | 22 ++++++++++++++----- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/examples/multiagent/iql.py b/examples/multiagent/iql.py index 0970e9607f2..4d36614f199 100644 --- a/examples/multiagent/iql.py +++ b/examples/multiagent/iql.py @@ -101,7 +101,7 @@ def train(cfg: "DictConfig"): # noqa: F821 eps_end=0, annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), action_key=env.action_key, - spec=env.input_spec["full_action_spec"], + spec=env.unbatched_action_spec, ) collector = SyncDataCollector( diff --git a/examples/multiagent/qmix_vdn.py b/examples/multiagent/qmix_vdn.py index 648b2ba17eb..222e0434db2 100644 --- a/examples/multiagent/qmix_vdn.py +++ b/examples/multiagent/qmix_vdn.py @@ -102,7 +102,7 @@ def train(cfg: "DictConfig"): # noqa: F821 eps_end=0, annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), action_key=env.action_key, - spec=env.input_spec["full_action_spec"], + spec=env.unbatched_action_spec, ) if cfg.loss.mixer_type == "qmix": diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 7448f8407a7..1c2c55bdc18 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -156,9 +156,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if isinstance(spec, CompositeSpec): spec = spec[self.action_key] if spec.shape != out.shape: - raise ValueError( - "Action spec shape does not match the action shape" - ) + # In batched envs if the spec is passed unbatched, the rand() will not + # cover all batched dims + if out.shape[-len(spec.shape) :] == spec.shape: + spec = spec.expand(out.shape) + else: + raise ValueError( + "Action spec shape does not match the action shape" + ) if self.action_mask_key is not None: action_mask = tensordict.get(self.action_mask_key, None) if action_mask is None: @@ -307,9 +312,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if isinstance(spec, CompositeSpec): spec = spec[self.action_key] if spec.shape != out.shape: - raise ValueError( - "Action spec shape does not match the action shape" - ) + # In batched envs if the spec is passed unbatched, the rand() will not + # cover all batched dims + if out.shape[-len(spec.shape) :] == spec.shape: + spec = spec.expand(out.shape) + else: + raise ValueError( + "Action spec shape does not match the action shape" + ) if self.action_mask_key is not None: action_mask = tensordict.get(self.action_mask_key, None) if action_mask is None: From 00365ca2d649c9f703ceab97be4342d4d979e1a9 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 7 Sep 2023 09:59:32 +0100 Subject: [PATCH 07/15] amend Signed-off-by: Matteo Bettini --- test/test_exploration.py | 97 ++++++++++++++++++- .../modules/tensordict_module/exploration.py | 4 - 2 files changed, 92 insertions(+), 9 deletions(-) diff --git a/test/test_exploration.py b/test/test_exploration.py index c823dbaf4f4..57c16fbd5d3 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -14,12 +14,18 @@ NestedCountingEnv, ) from scipy.stats import ttest_1samp -from tensordict.nn import InteractionType, TensorDictModule + +from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential from tensordict.tensordict import TensorDict from torch import nn from torchrl.collectors import SyncDataCollector -from torchrl.data import BoundedTensorSpec, CompositeSpec +from torchrl.data import ( + BoundedTensorSpec, + CompositeSpec, + DiscreteTensorSpec, + OneHotDiscreteTensorSpec, +) from torchrl.envs import SerialEnv from torchrl.envs.transforms.transforms import gSDENoise, InitTracker, TransformedEnv from torchrl.envs.utils import set_exploration_type @@ -30,10 +36,15 @@ NormalParamWrapper, ) from torchrl.modules.models.exploration import LazygSDEModule -from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor +from torchrl.modules.tensordict_module.actors import ( + Actor, + ProbabilisticActor, + QValueActor, +) from torchrl.modules.tensordict_module.exploration import ( _OrnsteinUhlenbeckProcess, AdditiveGaussianWrapper, + EGreedyModule, EGreedyWrapper, OrnsteinUhlenbeckProcessWrapper, ) @@ -41,12 +52,21 @@ @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0]) class TestEGreedy: - def test_egreedy(self, eps_init): + @pytest.mark.parametrize("module", [True, False]) + def test_egreedy(self, eps_init, module): torch.manual_seed(0) spec = BoundedTensorSpec(1, 1, torch.Size([4])) module = torch.nn.Linear(4, 4, bias=False) + policy = Actor(spec=spec, module=module) - explorative_policy = EGreedyWrapper(policy, eps_init=eps_init, eps_end=eps_init) + if module: + explorative_policy = TensorDictSequential( + policy, EGreedyModule(eps_init=eps_init, eps_end=eps_init, spec=spec) + ) + else: + explorative_policy = EGreedyWrapper( + policy, eps_init=eps_init, eps_end=eps_init + ) td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10]) action = explorative_policy(td).get("action") if eps_init == 0: @@ -58,6 +78,73 @@ def test_egreedy(self, eps_init): assert (action == 0).any() assert ((action == 1) | (action == 0)).all() + @pytest.mark.parametrize("module", [True]) + @pytest.mark.parametrize("spec_class", ["discrete", "one_hot"]) + def test_egreedy_masked(self, module, eps_init, spec_class): + torch.manual_seed(0) + action_size = 4 + batch_size = (3, 4, 2) + module = torch.nn.Linear(action_size, action_size, bias=False) + if spec_class == "discrete": + spec = DiscreteTensorSpec(action_size, shape=batch_size) + else: + spec = OneHotDiscreteTensorSpec( + action_size, shape=batch_size + (action_size,) + ) + policy = QValueActor(spec=spec, module=module, action_mask_key="action_mask") + if module: + explorative_policy = TensorDictSequential( + policy, + EGreedyModule( + eps_init=eps_init, + eps_end=eps_init, + spec=spec, + action_mask_key="action_mask", + ), + ) + else: + explorative_policy = EGreedyWrapper( + policy, + eps_init=eps_init, + eps_end=eps_init, + action_mask_key="action_mask", + ) + torch.manual_seed(0) + action_mask = torch.ones(*batch_size, action_size).to(torch.bool) + td = TensorDict( + { + "observation": torch.zeros(*batch_size, action_size), + "action_mask": action_mask, + }, + batch_size=batch_size, + ) + action = explorative_policy(td).get("action") + + torch.manual_seed(0) + action_mask = torch.randint(high=2, size=(*batch_size, action_size)).to( + torch.bool + ) + while not action_mask.any(dim=-1).all() or action_mask.all(): + action_mask = torch.randint(high=2, size=(*batch_size, action_size)).to( + torch.bool + ) + + td = TensorDict( + { + "observation": torch.zeros(*batch_size, action_size), + "action_mask": action_mask, + }, + batch_size=batch_size, + ) + masked_action = explorative_policy(td).get("action") + + if spec_class == "discrete": + action = spec.to_one_hot(action) + masked_action = spec.to_one_hot(masked_action) + + assert not (action[~action_mask] == 0).all() + assert (masked_action[~action_mask] == 0).all() + @pytest.mark.parametrize("device", get_default_devices()) class TestOrnsteinUhlenbeckProcessWrapper: diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 1c2c55bdc18..21ac27d9341 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -111,8 +111,6 @@ def __init__( if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) self._spec = spec - else: - self._spec = CompositeSpec({action_key: None}) @property def spec(self): @@ -266,8 +264,6 @@ def __init__( self._spec = self.td_module.spec.clone() if action_key not in self._spec.keys(): self._spec[action_key] = None - else: - self._spec = CompositeSpec({key: None for key in policy.out_keys}) @property def spec(self): From b9b1bb626a7f0643e1d0045ed89d81dd75e00b2d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 7 Sep 2023 10:03:38 +0100 Subject: [PATCH 08/15] test typo Signed-off-by: Matteo Bettini --- test/test_exploration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_exploration.py b/test/test_exploration.py index 57c16fbd5d3..c08bba1e10e 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -78,7 +78,7 @@ def test_egreedy(self, eps_init, module): assert (action == 0).any() assert ((action == 1) | (action == 0)).all() - @pytest.mark.parametrize("module", [True]) + @pytest.mark.parametrize("module", [True, False]) @pytest.mark.parametrize("spec_class", ["discrete", "one_hot"]) def test_egreedy_masked(self, module, eps_init, spec_class): torch.manual_seed(0) From fc4d2eb305201e5ba6260f32c30250eeeace3f26 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 7 Sep 2023 10:51:25 +0100 Subject: [PATCH 09/15] review Signed-off-by: Matteo Bettini --- docs/source/reference/modules.rst | 1 - test/test_exploration.py | 68 ++++++++++++++++++- torchrl/modules/tensordict_module/actors.py | 8 ++- .../modules/tensordict_module/exploration.py | 25 +++++-- 4 files changed, 89 insertions(+), 13 deletions(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 21edd5b2804..0281061b007 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -68,7 +68,6 @@ other cases, the action written in the tensordict is simply the network output. :template: rl_template_noinherit.rst AdditiveGaussianWrapper - EGreedyWrapper EGreedyModule OrnsteinUhlenbeckProcessWrapper diff --git a/test/test_exploration.py b/test/test_exploration.py index c08bba1e10e..23e02558dfe 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -50,8 +50,8 @@ ) -@pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0]) class TestEGreedy: + @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0]) @pytest.mark.parametrize("module", [True, False]) def test_egreedy(self, eps_init, module): torch.manual_seed(0) @@ -78,6 +78,7 @@ def test_egreedy(self, eps_init, module): assert (action == 0).any() assert ((action == 1) | (action == 0)).all() + @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1.0]) @pytest.mark.parametrize("module", [True, False]) @pytest.mark.parametrize("spec_class", ["discrete", "one_hot"]) def test_egreedy_masked(self, module, eps_init, spec_class): @@ -86,10 +87,11 @@ def test_egreedy_masked(self, module, eps_init, spec_class): batch_size = (3, 4, 2) module = torch.nn.Linear(action_size, action_size, bias=False) if spec_class == "discrete": - spec = DiscreteTensorSpec(action_size, shape=batch_size) + spec = DiscreteTensorSpec(action_size) else: spec = OneHotDiscreteTensorSpec( - action_size, shape=batch_size + (action_size,) + action_size, + shape=(action_size,), ) policy = QValueActor(spec=spec, module=module, action_mask_key="action_mask") if module: @@ -109,6 +111,14 @@ def test_egreedy_masked(self, module, eps_init, spec_class): eps_end=eps_init, action_mask_key="action_mask", ) + + td = TensorDict( + {"observation": torch.zeros(*batch_size, action_size)}, + batch_size=batch_size, + ) + with pytest.raises(KeyError, match="Action mask key action_mask not found in"): + explorative_policy(td) + torch.manual_seed(0) action_mask = torch.ones(*batch_size, action_size).to(torch.bool) td = TensorDict( @@ -145,6 +155,58 @@ def test_egreedy_masked(self, module, eps_init, spec_class): assert not (action[~action_mask] == 0).all() assert (masked_action[~action_mask] == 0).all() + def test_egreedy_wrapper_deprecation(self): + torch.manual_seed(0) + spec = BoundedTensorSpec(1, 1, torch.Size([4])) + module = torch.nn.Linear(4, 4, bias=False) + policy = Actor(spec=spec, module=module) + with pytest.deprecated_call(): + EGreedyWrapper(policy) + + def test_no_spec_error( + self, + ): + torch.manual_seed(0) + action_size = 4 + batch_size = (3, 4, 2) + module = torch.nn.Linear(action_size, action_size, bias=False) + spec = OneHotDiscreteTensorSpec(action_size, shape=(action_size,)) + policy = QValueActor(spec=spec, module=module) + explorative_policy = TensorDictSequential( + policy, + EGreedyModule(), + ) + td = TensorDict( + { + "observation": torch.zeros(*batch_size, action_size), + }, + batch_size=batch_size, + ) + + with pytest.raises( + RuntimeError, match="spec must be provided to the exploration wrapper." + ): + explorative_policy(td) + + @pytest.mark.parametrize("module", [True, False]) + def test_wrong_action_shape(self, module): + torch.manual_seed(0) + spec = BoundedTensorSpec(1, 1, torch.Size([4])) + module = torch.nn.Linear(4, 5, bias=False) + + policy = Actor(spec=spec, module=module) + if module: + explorative_policy = TensorDictSequential(policy, EGreedyModule(spec=spec)) + else: + explorative_policy = EGreedyWrapper( + policy, + ) + td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10]) + with pytest.raises( + ValueError, match="Action spec shape does not match the action shape" + ): + explorative_policy(td) + @pytest.mark.parametrize("device", get_default_devices()) class TestOrnsteinUhlenbeckProcessWrapper: diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 91cd4fcd896..7606836caa0 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -459,7 +459,9 @@ def forward(self, tensordict: torch.Tensor) -> TensorDictBase: raise KeyError( f"Action mask key {self.action_mask_key} not found in {tensordict}." ) - action_values[~action_mask] = torch.finfo(action_values.dtype).min + action_values = torch.where( + action_mask, action_values, torch.finfo(action_values.dtype).min + ) action = self.action_func_mapping[self.action_space](action_values) @@ -633,7 +635,9 @@ def forward(self, tensordict: torch.Tensor) -> TensorDictBase: raise KeyError( f"Action mask key {self.action_mask_key} not found in {tensordict}." ) - action_values[action_mask] = torch.finfo(action_values.dtype).min + action_values = torch.where( + action_mask, action_values, torch.finfo(action_values.dtype).min + ) action = self.action_func_mapping[self.action_space](action_values) diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 21ac27d9341..3128bce037c 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -110,7 +110,7 @@ def __init__( if spec is not None: if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) - self._spec = spec + self._spec = spec @property def spec(self): @@ -156,7 +156,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if spec.shape != out.shape: # In batched envs if the spec is passed unbatched, the rand() will not # cover all batched dims - if out.shape[-len(spec.shape) :] == spec.shape: + if ( + not len(spec.shape) + or out.shape[-len(spec.shape) :] == spec.shape + ): spec = spec.expand(out.shape) else: raise ValueError( @@ -171,15 +174,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: spec.update_mask(action_mask) out = cond * spec.rand().to(out.device) + (1 - cond) * out else: - raise RuntimeError( - "spec must be provided by the policy or directly to the exploration wrapper." - ) + raise RuntimeError("spec must be provided to the exploration wrapper.") action_tensordict.set(action_key, out) return tensordict class EGreedyWrapper(TensorDictModuleWrapper): - """Epsilon-Greedy PO wrapper. + """[Deprecated] ]Epsilon-Greedy PO wrapper. Args: policy (TensorDictModule): a deterministic policy. @@ -243,6 +244,11 @@ def __init__( action_mask_key: Optional[NestedKey] = None, spec: Optional[TensorSpec] = None, ): + warnings.warn( + "EGreedyWrapper is deprecated and it will be removed in v0.3. Please use torchrl/modules.EGreedyModule instead.", + category=DeprecationWarning, + ) + super().__init__(policy) self.register_buffer("eps_init", torch.tensor([eps_init])) self.register_buffer("eps_end", torch.tensor([eps_end])) @@ -264,6 +270,8 @@ def __init__( self._spec = self.td_module.spec.clone() if action_key not in self._spec.keys(): self._spec[action_key] = None + else: + self._spec = spec @property def spec(self): @@ -310,7 +318,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if spec.shape != out.shape: # In batched envs if the spec is passed unbatched, the rand() will not # cover all batched dims - if out.shape[-len(spec.shape) :] == spec.shape: + if ( + not len(spec.shape) + or out.shape[-len(spec.shape) :] == spec.shape + ): spec = spec.expand(out.shape) else: raise ValueError( From ba18d3a67c7de6b806d63e0b3db04d5bc0db0d2d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 7 Sep 2023 10:52:49 +0100 Subject: [PATCH 10/15] review Signed-off-by: Matteo Bettini --- test/test_actors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_actors.py b/test/test_actors.py index 940d1f5dd96..06d59de0a48 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -619,7 +619,7 @@ def test_qvalue_mask(self, action_space, action_n): torch.manual_seed(0) shape = (3, 4, 3, action_n) action_values = torch.randn(size=shape) - td = TensorDict({"action_value": action_values.clone()}, [3]) + td = TensorDict({"action_value": action_values}, [3]) module = QValueModule( action_space=action_space, action_value_key="action_value", From 32ef6442bcd1546ebc1f3fcf2fdedc1d80bd6fd9 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 7 Sep 2023 10:54:46 +0100 Subject: [PATCH 11/15] typo Signed-off-by: Matteo Bettini --- torchrl/modules/tensordict_module/exploration.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 3128bce037c..c8df7c70ae2 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -245,7 +245,8 @@ def __init__( spec: Optional[TensorSpec] = None, ): warnings.warn( - "EGreedyWrapper is deprecated and it will be removed in v0.3. Please use torchrl/modules.EGreedyModule instead.", + "EGreedyWrapper is deprecated and it will be removed in v0.3. " + "Please use torchrl.modules.EGreedyModule instead.", category=DeprecationWarning, ) From 4d06ed9c543b539225dc9518f8e522d631ce0bef Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Thu, 7 Sep 2023 13:39:25 +0100 Subject: [PATCH 12/15] Update torchrl/modules/tensordict_module/exploration.py Co-authored-by: Vincent Moens --- torchrl/modules/tensordict_module/exploration.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index c8df7c70ae2..856f7413792 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -29,9 +29,10 @@ class EGreedyModule(TensorDictModuleBase): - """Epsilon-Greedy module. + """Epsilon-Greedy exploration module. - This module updates the action in a tensordict to an epsilon greedy one. + This module randomly updates the action(s) in a tensordict given an epsilon greedy exploration strategy. + At each call, random draws (one per action) are executed given a certain probability threshold. If successful, the corresponding actions are being replaced by random samples drawn from the action spec provided. Others are left unchanged. Keyword Args: eps_init (scalar, optional): initial epsilon value. From 17da46e3369595579e1f35f91f45a0937ffcb6b5 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Thu, 7 Sep 2023 13:42:54 +0100 Subject: [PATCH 13/15] Apply suggestions from code review Co-authored-by: Vincent Moens --- .../modules/tensordict_module/exploration.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 856f7413792..f90aa7f3eb4 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -34,18 +34,21 @@ class EGreedyModule(TensorDictModuleBase): This module randomly updates the action(s) in a tensordict given an epsilon greedy exploration strategy. At each call, random draws (one per action) are executed given a certain probability threshold. If successful, the corresponding actions are being replaced by random samples drawn from the action spec provided. Others are left unchanged. - Keyword Args: + Args: eps_init (scalar, optional): initial epsilon value. default: 1.0 eps_end (scalar, optional): final epsilon value. default: 0.1 - annealing_num_steps (int, optional): number of steps it will take for epsilon to reach the eps_end value + annealing_num_steps (int, optional): number of steps it will take for epsilon to reach + the ``eps_end`` value. Defaults to `1000`. + + Keyword Args: action_key (NestedKey, optional): if the policy module has more than one output key, - its output spec will be of type CompositeSpec. One needs to know where to + its output spec will be of type :class:`torchrl.data.CompositeSpec`. One needs to know where to find the action spec. Default is ``"action"``. - action_mask_key (NestedKey, optional): the key where the action maskcan be found in the tensordict. - Default is ``"None"`` (corresponding to no mask). + action_mask_key (NestedKey, optional): the key where the action mask can be found in the input tensordict. + Default is ``None`` (corresponding to no mask). spec (TensorSpec, optional): if provided, the sampled action will be projected onto the valid action space once explored. If not provided, the exploration wrapper will attempt to recover it from the policy. @@ -87,6 +90,7 @@ def __init__( eps_init: float = 1.0, eps_end: float = 0.1, annealing_num_steps: int = 1000, + *, action_key: Optional[NestedKey] = "action", action_mask_key: Optional[NestedKey] = None, spec: Optional[TensorSpec] = None, @@ -120,10 +124,10 @@ def spec(self): def step(self, frames: int = 1) -> None: """A step of epsilon decay. - After self.annealing_num_steps, this function is a no-op. + After `self.annealing_num_steps` calls to this method, calls result in no-op. Args: - frames (int): number of frames since last step. + frames (int, optional): number of frames since last step. Defaults to ``1``. """ for _ in range(frames): @@ -181,7 +185,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: class EGreedyWrapper(TensorDictModuleWrapper): - """[Deprecated] ]Epsilon-Greedy PO wrapper. + """[Deprecated] Epsilon-Greedy PO wrapper. Args: policy (TensorDictModule): a deterministic policy. @@ -196,8 +200,8 @@ class EGreedyWrapper(TensorDictModuleWrapper): its output spec will be of type CompositeSpec. One needs to know where to find the action spec. Default is "action". - action_mask_key (NestedKey, optional): the key where the action maskcan be found in the tensordict. - Default is ``"None"`` (corresponding to no mask). + action_mask_key (NestedKey, optional): the key where the action mask can be found in the input tensordict. + Default is ``None`` (corresponding to no mask). spec (TensorSpec, optional): if provided, the sampled action will be projected onto the valid action space once explored. If not provided, the exploration wrapper will attempt to recover it from the policy. From 005a83ff8c13838745aee8355c86d5818cb27f72 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 7 Sep 2023 13:50:53 +0100 Subject: [PATCH 14/15] review Signed-off-by: Matteo Bettini --- test/test_exploration.py | 2 +- .../modules/tensordict_module/exploration.py | 18 ++++++------------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/test/test_exploration.py b/test/test_exploration.py index 23e02558dfe..c4cd44f0692 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -174,7 +174,7 @@ def test_no_spec_error( policy = QValueActor(spec=spec, module=module) explorative_policy = TensorDictSequential( policy, - EGreedyModule(), + EGreedyModule(spec=None), ) td = TensorDict( { diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index f90aa7f3eb4..73d33f27c27 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -35,6 +35,7 @@ class EGreedyModule(TensorDictModuleBase): At each call, random draws (one per action) are executed given a certain probability threshold. If successful, the corresponding actions are being replaced by random samples drawn from the action spec provided. Others are left unchanged. Args: + spec (TensorSpec): the spec used for samppling actions. eps_init (scalar, optional): initial epsilon value. default: 1.0 eps_end (scalar, optional): final epsilon value. @@ -43,15 +44,10 @@ class EGreedyModule(TensorDictModuleBase): the ``eps_end`` value. Defaults to `1000`. Keyword Args: - action_key (NestedKey, optional): if the policy module has more than one output key, - its output spec will be of type :class:`torchrl.data.CompositeSpec`. One needs to know where to - find the action spec. + action_key (NestedKey, optional): the key where the action can be found in the input tensordict. Default is ``"action"``. action_mask_key (NestedKey, optional): the key where the action mask can be found in the input tensordict. Default is ``None`` (corresponding to no mask). - spec (TensorSpec, optional): if provided, the sampled action will be - projected onto the valid action space once explored. If not provided, - the exploration wrapper will attempt to recover it from the policy. .. note:: It is crucial to incorporate a call to :meth:`~.step` in the training loop @@ -87,13 +83,13 @@ class EGreedyModule(TensorDictModuleBase): def __init__( self, + spec: TensorSpec, eps_init: float = 1.0, eps_end: float = 0.1, annealing_num_steps: int = 1000, *, action_key: Optional[NestedKey] = "action", action_mask_key: Optional[NestedKey] = None, - spec: Optional[TensorSpec] = None, ): self.action_key = action_key self.action_mask_key = action_mask_key @@ -196,14 +192,12 @@ class EGreedyWrapper(TensorDictModuleWrapper): eps_end (scalar, optional): final epsilon value. default: 0.1 annealing_num_steps (int, optional): number of steps it will take for epsilon to reach the eps_end value - action_key (NestedKey, optional): if the policy module has more than one output key, - its output spec will be of type CompositeSpec. One needs to know where to - find the action spec. - Default is "action". + action_key (NestedKey, optional): the key where the action can be found in the input tensordict. + Default is ``"action"``. action_mask_key (NestedKey, optional): the key where the action mask can be found in the input tensordict. Default is ``None`` (corresponding to no mask). spec (TensorSpec, optional): if provided, the sampled action will be - projected onto the valid action space once explored. If not provided, + taken from this action space. If not provided, the exploration wrapper will attempt to recover it from the policy. .. note:: From 364ee712dfc65411d17977c0b74ffe358e72ac0b Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 7 Sep 2023 13:52:23 +0100 Subject: [PATCH 15/15] typo Signed-off-by: Matteo Bettini --- torchrl/modules/tensordict_module/exploration.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 73d33f27c27..d2e8ed8e3a1 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -32,10 +32,12 @@ class EGreedyModule(TensorDictModuleBase): """Epsilon-Greedy exploration module. This module randomly updates the action(s) in a tensordict given an epsilon greedy exploration strategy. - At each call, random draws (one per action) are executed given a certain probability threshold. If successful, the corresponding actions are being replaced by random samples drawn from the action spec provided. Others are left unchanged. + At each call, random draws (one per action) are executed given a certain probability threshold. If successful, + the corresponding actions are being replaced by random samples drawn from the action spec provided. + Others are left unchanged. Args: - spec (TensorSpec): the spec used for samppling actions. + spec (TensorSpec): the spec used for sampling actions. eps_init (scalar, optional): initial epsilon value. default: 1.0 eps_end (scalar, optional): final epsilon value.