From 0e38ed46dff5fa7eb5d325d05ed946f22d0a8c81 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 6 Jan 2023 19:01:23 +0100 Subject: [PATCH 01/26] Changed flag in EnvBase --- torchrl/envs/common.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index aec32eff395..f4c25e4aad6 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -14,12 +14,11 @@ import torch import torch.nn as nn from tensordict.tensordict import TensorDict, TensorDictBase - from torchrl.data import CompositeSpec, TensorSpec +from .utils import get_available_libraries, step_mdp from .._utils import prod, seed_generator from ..data.utils import DEVICE_TYPING -from .utils import get_available_libraries, step_mdp LIBRARIES = get_available_libraries() @@ -457,9 +456,15 @@ def reset( *tensordict_reset.batch_size, 1, dtype=torch.bool, device=self.device ), ) - if tensordict_reset.get("done").any(): + if tensordict is not None and "_reset" in tensordict.keys(): + self._assert_tensordict_shape(tensordict) + _reset = tensordict.get("_reset") + else: + _reset = None + + if (_reset is None and tensordict_reset.get("done").any()) or (_reset is not None and tensordict_reset.get("done")[_reset].any()): raise RuntimeError( - f"Env {self} was done after reset. This is (currently) not allowed." + f"Env {self} was done after reset on specified '_reset' dimensions. This is (currently) not allowed." ) if tensordict is not None: tensordict.update(tensordict_reset) From 7ef1b478e4b29e5914f42fa0f0c54bb091bc90eb Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 6 Jan 2023 19:02:17 +0100 Subject: [PATCH 02/26] Changed flag in Vec Envs --- torchrl/envs/vec_env.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 733de7d1c45..c490ab1c43e 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -626,15 +626,16 @@ def set_seed( @_check_start def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: - if tensordict is not None and "reset_workers" in tensordict.keys(): + + if tensordict is not None and "_reset" in tensordict.keys(): self._assert_tensordict_shape(tensordict) - reset_workers = tensordict.get("reset_workers") + _reset = tensordict.get("_reset") else: - reset_workers = torch.ones(self.num_workers, dtype=torch.bool) + _reset = torch.ones((*self.batch_size,1), dtype=torch.bool) keys = set() for i, _env in enumerate(self._envs): - if not reset_workers[i]: + if not _reset[i].any(): continue _tensordict = tensordict[i] if tensordict is not None else None _td = _env._reset(tensordict=_tensordict, **kwargs) @@ -843,29 +844,30 @@ def set_seed( @_check_start def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: cmd_out = "reset" - if tensordict is not None and "reset_workers" in tensordict.keys(): + if tensordict is not None and "_reset" in tensordict.keys(): self._assert_tensordict_shape(tensordict) - reset_workers = tensordict.get("reset_workers") + _reset = tensordict.get("_reset") else: - reset_workers = torch.ones(self.num_workers, dtype=torch.bool) + _reset = torch.ones((*self.batch_size,1), dtype=torch.bool) for i, channel in enumerate(self.parent_channels): - if not reset_workers[i]: + if not _reset[i].any(): continue + kwargs["tensordict"] = tensordict[i] if tensordict is not None else None channel.send((cmd_out, kwargs)) keys = set() for i, channel in enumerate(self.parent_channels): - if not reset_workers[i]: + if not _reset[i].any(): continue cmd_in, new_keys = channel.recv() keys = keys.union(new_keys) if cmd_in != "reset_obs": raise RuntimeError(f"received cmd {cmd_in} instead of reset_obs") check_count = 0 - while self.shared_tensordict_parent.get("done").any(): + while self.shared_tensordict_parent.get("done")[_reset].any(): if check_count == 4: - raise RuntimeError("Envs have just been reset but some are still done") + raise RuntimeError("Envs have just been reset bur env is done on specified '_reset' dimensions.") else: check_count += 1 # there might be some delay between writing the shared tensordict @@ -1019,8 +1021,6 @@ def _run_worker_pipe_shared_mem( _td.pin_memory() tensordict.update_(_td) child_pipe.send(("reset_obs", reset_keys)) - if done.any(): - raise RuntimeError(f"{env.__class__.__name__} is done after reset") elif cmd == "step": if not initialized: From 4cf67cbccadd05e7bc60257a0cd9a73e38067a06 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 6 Jan 2023 23:11:54 +0100 Subject: [PATCH 03/26] Fixed collectors --- torchrl/collectors/collectors.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 3a288c589ef..09dbabe451f 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -22,13 +22,11 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import multiprocessing as mp from torch.utils.data import IterableDataset - from torchrl._utils import _check_for_faulty_process, prod from torchrl.collectors.utils import split_trajectories from torchrl.data import TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import EnvBase - from torchrl.envs.transforms import TransformedEnv from torchrl.envs.utils import set_exploration_mode, step_mdp from torchrl.envs.vec_env import _BatchedEnv @@ -615,7 +613,7 @@ def _reset_if_necessary(self) -> None: steps = steps.clone() if len(self.env.batch_size): self._tensordict.masked_fill_(done_or_terminated, 0) - self._tensordict.set("reset_workers", done_or_terminated) + self._tensordict.set("_reset", done_or_terminated) else: self._tensordict.zero_() self.env.reset(self._tensordict) @@ -625,7 +623,7 @@ def _reset_if_necessary(self) -> None: f"Got {sum(self._tensordict.get('done'))} done envs after reset." ) if len(self.env.batch_size): - self._tensordict.del_("reset_workers") + self._tensordict.del_("_reset") traj_ids[done_or_terminated] = traj_ids.max() + torch.arange( 1, done_or_terminated.sum() + 1, device=traj_ids.device ) @@ -683,13 +681,13 @@ def reset(self, index=None, **kwargs) -> None: # check that the env supports partial reset if prod(self.env.batch_size) == 0: raise RuntimeError("resetting unique env with index is not permitted.") - reset_workers = torch.zeros( - *self.env.batch_size, + _reset = torch.zeros( + (*self.env.batch_size, 1), dtype=torch.bool, device=self.env.device, ) - reset_workers[index] = 1 - td_in = TensorDict({"reset_workers": reset_workers}, self.env.batch_size) + _reset[index] = 1 + td_in = TensorDict({"_reset": _reset}, self.env.batch_size) self._tensordict[index].zero_() else: td_in = None From 197c354ed1473ac3c2f765c8f3dec1b94d82329c Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 6 Jan 2023 23:12:18 +0100 Subject: [PATCH 04/26] Fixed transforms --- test/test_transforms.py | 8 ++++---- torchrl/envs/transforms/transforms.py | 9 ++++----- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 8d3e40f7550..b20afb0e843 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -832,7 +832,7 @@ def test_sum_reward(self, keys, device): assert (td.get("episode_reward") == 2 * td.get("reward")).all() # reset environments - td.set("reset_workers", torch.ones((batch, 1), dtype=torch.bool, device=device)) + td.set("_reset", torch.ones((batch, 1), dtype=torch.bool, device=device)) rs.reset(td) # apply a third time, episode_reward should be equal to reward again @@ -1724,7 +1724,7 @@ def test_step_counter(self, max_steps, device, batch, reset_workers): {"done": torch.zeros(*batch, 1, dtype=torch.bool)}, batch, device=device ) if reset_workers: - td.set("reset_workers", torch.randn(*batch, 1) < 0) + td.set("_reset", torch.randn(*batch, 1) < 0) step_counter.reset(td) assert not torch.all(td.get("step_count")) i = 0 @@ -1740,10 +1740,10 @@ def test_step_counter(self, max_steps, device, batch, reset_workers): step_counter.reset(td) if reset_workers: assert torch.all( - torch.masked_select(td.get("step_count"), td.get("reset_workers")) == 0 + torch.masked_select(td.get("step_count"), td.get("_reset")) == 0 ) assert torch.all( - torch.masked_select(td.get("step_count"), ~td.get("reset_workers")) == i + torch.masked_select(td.get("step_count"), ~td.get("_reset")) == i ) else: assert torch.all(td.get("step_count") == 0) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 3dc630b7f31..c07a5df5882 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -14,7 +14,6 @@ import torch from tensordict.tensordict import TensorDict, TensorDictBase from torch import nn, Tensor - from torchrl.data.tensor_specs import ( BinaryDiscreteTensorSpec, BoundedTensorSpec, @@ -2516,8 +2515,8 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: # Batched environments else: - reset_workers = tensordict.get( - "reset_workers", + _reset = tensordict.get( + "_reset", torch.ones( *tensordict.batch_size, 1, @@ -2527,7 +2526,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: ) for out_key in self.out_keys: if out_key in tensordict.keys(): - tensordict[out_key][reset_workers] = 0.0 + tensordict[out_key][_reset] = 0.0 return tensordict @@ -2618,7 +2617,7 @@ def __init__(self, max_steps: Optional[int] = None): def reset(self, tensordict: TensorDictBase) -> TensorDictBase: workers = tensordict.get( - "reset_workers", + "_reset", default=torch.ones( *tensordict.batch_size, 1, dtype=torch.bool, device=tensordict.device ), From 9d191e1f8822d7e34d02be147890f7624a2185b1 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 6 Jan 2023 23:12:56 +0100 Subject: [PATCH 05/26] Fixed env tests and lint --- test/test_env.py | 4 ++-- torchrl/envs/common.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 4661211f863..902e4861e9f 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -511,7 +511,7 @@ def test_parallel_env( _ = env_parallel.step(td) td_reset = TensorDict( - source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()}, + source={"_reset": torch.zeros(N, 1, dtype=torch.bool).bernoulli_()}, batch_size=[ N, ], @@ -595,7 +595,7 @@ def test_parallel_env_with_policy( _ = env_parallel.step(td) td_reset = TensorDict( - source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()}, + source={"_reset": torch.zeros(N, 1, dtype=torch.bool).bernoulli_()}, batch_size=[ N, ], diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index f4c25e4aad6..75a06c67125 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -462,7 +462,9 @@ def reset( else: _reset = None - if (_reset is None and tensordict_reset.get("done").any()) or (_reset is not None and tensordict_reset.get("done")[_reset].any()): + if (_reset is None and tensordict_reset.get("done").any()) or ( + _reset is not None and tensordict_reset.get("done")[_reset].any() + ): raise RuntimeError( f"Env {self} was done after reset on specified '_reset' dimensions. This is (currently) not allowed." ) From 53204146bd85b43df45e7e0236e5762644946d12 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 6 Jan 2023 23:13:04 +0100 Subject: [PATCH 06/26] Fixed env tests and lint --- torchrl/envs/vec_env.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index c490ab1c43e..1b505658f6b 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -631,7 +631,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: self._assert_tensordict_shape(tensordict) _reset = tensordict.get("_reset") else: - _reset = torch.ones((*self.batch_size,1), dtype=torch.bool) + _reset = torch.ones((*self.batch_size, 1), dtype=torch.bool) keys = set() for i, _env in enumerate(self._envs): @@ -848,7 +848,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: self._assert_tensordict_shape(tensordict) _reset = tensordict.get("_reset") else: - _reset = torch.ones((*self.batch_size,1), dtype=torch.bool) + _reset = torch.ones((*self.batch_size, 1), dtype=torch.bool) for i, channel in enumerate(self.parent_channels): if not _reset[i].any(): @@ -867,7 +867,9 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: check_count = 0 while self.shared_tensordict_parent.get("done")[_reset].any(): if check_count == 4: - raise RuntimeError("Envs have just been reset bur env is done on specified '_reset' dimensions.") + raise RuntimeError( + "Envs have just been reset bur env is done on specified '_reset' dimensions." + ) else: check_count += 1 # there might be some delay between writing the shared tensordict From 2cf04f9a3277940fe786172109a8874a6b094205 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 6 Jan 2023 23:13:55 +0100 Subject: [PATCH 07/26] Linting --- torchrl/envs/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 75a06c67125..bf0ecccfa79 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -16,10 +16,11 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.data import CompositeSpec, TensorSpec -from .utils import get_available_libraries, step_mdp from .._utils import prod, seed_generator from ..data.utils import DEVICE_TYPING +from .utils import get_available_libraries, step_mdp + LIBRARIES = get_available_libraries() From 155efd67840fbe55078512bfec40076f62c9887e Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sat, 7 Jan 2023 00:01:08 +0100 Subject: [PATCH 08/26] Added "_reset" kay deletion after use --- torchrl/envs/common.py | 5 +++-- torchrl/envs/vec_env.py | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index bf0ecccfa79..105e7e2ce11 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -16,11 +16,10 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.data import CompositeSpec, TensorSpec +from .utils import get_available_libraries, step_mdp from .._utils import prod, seed_generator from ..data.utils import DEVICE_TYPING -from .utils import get_available_libraries, step_mdp - LIBRARIES = get_available_libraries() @@ -470,6 +469,8 @@ def reset( f"Env {self} was done after reset on specified '_reset' dimensions. This is (currently) not allowed." ) if tensordict is not None: + if "_reset" in tensordict.keys(): + tensordict.del_("_reset") tensordict.update(tensordict_reset) else: tensordict = tensordict_reset diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 1b505658f6b..744400d18d6 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -1012,6 +1012,8 @@ def _run_worker_pipe_shared_mem( raise RuntimeError("call 'init' before resetting") # _td = tensordict.select("observation").to(env.device).clone() _td = env._reset(**reset_kwargs) + if _td.get("_reset", None) is not None: + _td.del_("_reset") done = _td.get("done", None) if done is None: _td["done"] = done = torch.zeros( From 01a737709bf369e57a1fe842b615007d17d0621d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sat, 7 Jan 2023 00:01:29 +0100 Subject: [PATCH 09/26] Linting --- torchrl/envs/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 105e7e2ce11..1f65a606e94 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -16,10 +16,11 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.data import CompositeSpec, TensorSpec -from .utils import get_available_libraries, step_mdp from .._utils import prod, seed_generator from ..data.utils import DEVICE_TYPING +from .utils import get_available_libraries, step_mdp + LIBRARIES = get_available_libraries() From 23a82fc585d111ef4fc0c29b298ee98896be9532 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sat, 7 Jan 2023 00:04:40 +0100 Subject: [PATCH 10/26] Added "_reset" kay deletion after use --- torchrl/envs/vec_env.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 744400d18d6..12505f9df37 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -639,6 +639,8 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: continue _tensordict = tensordict[i] if tensordict is not None else None _td = _env._reset(tensordict=_tensordict, **kwargs) + if _td.get("_reset", None) is not None: + _td.del_("_reset") keys = keys.union(_td.keys()) self.shared_tensordicts[i].update_(_td) From f8e3d3386e60ad33f8b5b7701093abb926f07b36 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 8 Jan 2023 08:41:36 +0100 Subject: [PATCH 11/26] removed deletion of reset flag from collector as the flag is deleted by the envs --- torchrl/collectors/collectors.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 09dbabe451f..9b47b4df46c 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -622,8 +622,6 @@ def _reset_if_necessary(self) -> None: raise RuntimeError( f"Got {sum(self._tensordict.get('done'))} done envs after reset." ) - if len(self.env.batch_size): - self._tensordict.del_("_reset") traj_ids[done_or_terminated] = traj_ids.max() + torch.arange( 1, done_or_terminated.sum() + 1, device=traj_ids.device ) From 2c4eb10dc1b2d309d2254e1c5c00138086d71230 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 8 Jan 2023 09:07:33 +0100 Subject: [PATCH 12/26] Moved _reset check before calling wrapped env --- torchrl/envs/common.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 1f65a606e94..37566f18f38 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -16,11 +16,10 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.data import CompositeSpec, TensorSpec +from .utils import get_available_libraries, step_mdp from .._utils import prod, seed_generator from ..data.utils import DEVICE_TYPING -from .utils import get_available_libraries, step_mdp - LIBRARIES = get_available_libraries() @@ -428,6 +427,12 @@ def reset( a tensordict (or the input tensordict, if any), modified in place with the resulting observations. """ + if tensordict is not None and "_reset" in tensordict.keys(): + self._assert_tensordict_shape(tensordict) + _reset = tensordict.get("_reset") + else: + _reset = None + tensordict_reset = self._reset(tensordict, **kwargs) done = tensordict_reset.get("done", None) @@ -457,11 +462,6 @@ def reset( *tensordict_reset.batch_size, 1, dtype=torch.bool, device=self.device ), ) - if tensordict is not None and "_reset" in tensordict.keys(): - self._assert_tensordict_shape(tensordict) - _reset = tensordict.get("_reset") - else: - _reset = None if (_reset is None and tensordict_reset.get("done").any()) or ( _reset is not None and tensordict_reset.get("done")[_reset].any() From 87228e41dd9030a6dd8c305e78e8241a3e833fe6 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 8 Jan 2023 09:09:26 +0100 Subject: [PATCH 13/26] refactor how to check if flag is present --- torchrl/envs/vec_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 12505f9df37..69072806cd9 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -639,7 +639,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: continue _tensordict = tensordict[i] if tensordict is not None else None _td = _env._reset(tensordict=_tensordict, **kwargs) - if _td.get("_reset", None) is not None: + if "_reset" in _td.keys(): _td.del_("_reset") keys = keys.union(_td.keys()) self.shared_tensordicts[i].update_(_td) @@ -1014,7 +1014,7 @@ def _run_worker_pipe_shared_mem( raise RuntimeError("call 'init' before resetting") # _td = tensordict.select("observation").to(env.device).clone() _td = env._reset(**reset_kwargs) - if _td.get("_reset", None) is not None: + if "_reset" in _td.keys(): _td.del_("_reset") done = _td.get("done", None) if done is None: From 34317c0efd1d605cb96ec1c7808d81958ce62538 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 8 Jan 2023 10:34:49 +0100 Subject: [PATCH 14/26] added tests --- test/mocking_classes.py | 51 +++++++++++++++++++++++++++++- test/test_env.py | 70 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 1 deletion(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index c949a5e094f..39525fec36f 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -7,7 +7,6 @@ import torch import torch.nn as nn from tensordict.tensordict import TensorDict, TensorDictBase - from torchrl.data.tensor_specs import ( BinaryDiscreteTensorSpec, BoundedTensorSpec, @@ -718,3 +717,53 @@ def __init__(self, in_size, out_size): def forward(self, observation, action): return self.linear(torch.cat([observation, action], dim=-1)) + + +class CountingEnv(EnvBase): + def __init__(self, max_steps: int = 5, **kwargs): + super().__init__(**kwargs) + self.max_steps = max_steps + + self.observation_spec = CompositeSpec( + observation=UnboundedContinuousTensorSpec((1,)) + ) + self.reward_spec = UnboundedContinuousTensorSpec((1,)) + self.input_spec = CompositeSpec(action=BinaryDiscreteTensorSpec(1)) + + self.count = torch.zeros( + (*self.batch_size, 1), device=self.device, dtype=torch.int + ) + + def _set_seed(self, seed: Optional[int]): + torch.manual_seed(seed) + + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + if tensordict is not None and "_reset" in tensordict.keys(): + _reset = tensordict.get("_reset") + self.count[_reset] = 0 + else: + self.count[:] = 0 + return TensorDict( + source={ + "observation": self.count.clone(), + "done": self.count > self.max_steps, + }, + batch_size=self.batch_size, + device=self.device, + ) + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + action = tensordict.get("action") + self.count += action.to(torch.int) + return TensorDict( + source={ + "observation": self.count, + "done": self.count > self.max_steps, + "reward": torch.zeros_like(self.count, dtype=torch.float), + }, + batch_size=self.batch_size, + device=self.device, + ) diff --git a/test/test_env.py b/test/test_env.py index 902e4861e9f..a144e677dca 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -20,6 +20,7 @@ ) from mocking_classes import ( ActionObsMergeLinear, + CountingEnv, DiscreteActionConvMockEnv, DiscreteActionVecMockEnv, DummyModelBasedEnvBase, @@ -900,6 +901,75 @@ def env_fn2(seed): env1.close() env2.close() + @pytest.mark.parametrize("batch_size", [(), (1,), (4,), (32, 5)]) + @pytest.mark.parametrize("n_workers", [1, 2]) + def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3): + env = ParallelEnv( + n_workers, lambda: CountingEnv(max_steps=max_steps, batch_size=batch_size) + ) + env.set_seed(1) + action = env.action_spec.rand(env.batch_size) + action[:] = 1 + + for i in range(max_steps): + td = env.step( + TensorDict( + {"action": action}, batch_size=env.batch_size, device=env.device + ) + ) + assert (td["done"] == 0).all() + assert (td["next"]["observation"] == i + 1).all() + + td = env.step( + TensorDict({"action": action}, batch_size=env.batch_size, device=env.device) + ) + assert (td["done"] == 1).all() + assert (td["next"]["observation"] == max_steps + 1).all() + + _reset = torch.randint( + low=0, high=2, size=(*env.batch_size, 1), dtype=torch.bool + ) + td_reset = env.reset( + TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device) + ) + if _reset.any(): + assert (td_reset["done"][_reset] == 0).all() + assert (td_reset["observation"][_reset] == 0).all() + assert (td_reset["done"][~_reset] == 1).all() + assert (td_reset["observation"][~_reset] == max_steps + 1).all() + + +@pytest.mark.parametrize("batch_size", [(), (2,), (32, 5)]) +def test_env_base_reset_flag(batch_size, max_steps=3): + env = CountingEnv(max_steps=max_steps, batch_size=batch_size) + env.set_seed(1) + + action = env.action_spec.rand(env.batch_size) + action[:] = 1 + + for i in range(max_steps): + td = env.step( + TensorDict({"action": action}, batch_size=env.batch_size, device=env.device) + ) + assert (td["done"] == 0).all() + assert (td["next"]["observation"] == i + 1).all() + + td = env.step( + TensorDict({"action": action}, batch_size=env.batch_size, device=env.device) + ) + assert (td["done"] == 1).all() + assert (td["next"]["observation"] == max_steps + 1).all() + + _reset = torch.randint(low=0, high=2, size=(*env.batch_size, 1), dtype=torch.bool) + td_reset = env.reset( + TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device) + ) + + assert (td_reset["done"][_reset] == 0).all() + assert (td_reset["observation"][_reset] == 0).all() + assert (td_reset["done"][~_reset] == 1).all() + assert (td_reset["observation"][~_reset] == max_steps + 1).all() + @pytest.mark.skipif(not _has_gym, reason="no gym") def test_seed(): From 9b3d4e79b474ec816d2ed4e194c4f4b47c77cb5d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 8 Jan 2023 10:35:30 +0100 Subject: [PATCH 15/26] Linting --- torchrl/envs/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 37566f18f38..e455e4412ef 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -16,10 +16,11 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.data import CompositeSpec, TensorSpec -from .utils import get_available_libraries, step_mdp from .._utils import prod, seed_generator from ..data.utils import DEVICE_TYPING +from .utils import get_available_libraries, step_mdp + LIBRARIES = get_available_libraries() From a721446e236d9e73e720c1ac1bc692a1807b7398 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 8 Jan 2023 11:15:31 +0100 Subject: [PATCH 16/26] vmas support for _reset flag --- torchrl/envs/libs/vmas.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index a059f91fd9d..35de76df4c7 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -2,7 +2,6 @@ import torch from tensordict.tensordict import TensorDict, TensorDictBase - from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec from torchrl.envs.common import _EnvWrapper from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform @@ -203,7 +202,18 @@ def _set_seed(self, seed: Optional[int]): def _reset( self, tensordict: Optional[TensorDictBase] = None, **kwargs ) -> TensorDictBase: - obs, infos = self._env.reset(return_info=True) + if tensordict is not None and "_reset" in tensordict.keys(): + envs_to_reset = tensordict.get("_reset").any(dim=0).squeeze(-1) + for env_index, to_reset in enumerate(envs_to_reset): + if to_reset: + self._env.reset_at(env_index) + obs = [] + infos = [] + for agent in self.agents: + obs.append(self.scenario.observation(agent)) + infos.append(self.scenario.info(agent)) + else: + obs, infos = self._env.reset(return_info=True) agent_tds = [] for i in range(self.n_agents): From ac34b084f6cd5af8e3d00c4ab066269fa6cfd7d4 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 8 Jan 2023 16:50:43 +0100 Subject: [PATCH 17/26] refactor --- torchrl/envs/transforms/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index c07a5df5882..5a55c7d2b57 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2616,7 +2616,7 @@ def __init__(self, max_steps: Optional[int] = None): super().__init__([]) def reset(self, tensordict: TensorDictBase) -> TensorDictBase: - workers = tensordict.get( + _reset = tensordict.get( "_reset", default=torch.ones( *tensordict.batch_size, 1, dtype=torch.bool, device=tensordict.device @@ -2624,7 +2624,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: ) tensordict.set( "step_count", - (~workers) + (~_reset) * tensordict.get( "step_count", torch.zeros( From 469b6d5e3583efa2ddc31f5223411b8263d8f60b Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 8 Jan 2023 18:28:15 +0100 Subject: [PATCH 18/26] close ParallelEnv --- test/test_env.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_env.py b/test/test_env.py index a144e677dca..99a65198b66 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -932,6 +932,7 @@ def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3): td_reset = env.reset( TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device) ) + env.close() if _reset.any(): assert (td_reset["done"][_reset] == 0).all() assert (td_reset["observation"][_reset] == 0).all() From 416b06621e912bd4cc4a6e0f42087a425a26d58a Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 9 Jan 2023 14:57:03 +0100 Subject: [PATCH 19/26] partial reset mirrored in step_count --- torchrl/collectors/collectors.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 9b47b4df46c..0108de33642 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -688,6 +688,7 @@ def reset(self, index=None, **kwargs) -> None: td_in = TensorDict({"_reset": _reset}, self.env.batch_size) self._tensordict[index].zero_() else: + _reset = None td_in = None self._tensordict.zero_() @@ -695,7 +696,10 @@ def reset(self, index=None, **kwargs) -> None: self._tensordict.update(td_in, inplace=True) self._tensordict.update(self.env.reset(**kwargs), inplace=True) - self._tensordict.fill_("step_count", 0) + if _reset is not None: + self._tensordict["step_count"][_reset] = 0 + else: + self._tensordict.fill_("step_count", 0) def shutdown(self) -> None: """Shuts down all workers and/or closes the local environment.""" From 9d5754e986ef9063a80521d68c7620431920526c Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 10 Jan 2023 13:35:01 +0100 Subject: [PATCH 20/26] Added torch seeding --- test/test_env.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_env.py b/test/test_env.py index 99a65198b66..6dd809ea10a 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -904,6 +904,7 @@ def env_fn2(seed): @pytest.mark.parametrize("batch_size", [(), (1,), (4,), (32, 5)]) @pytest.mark.parametrize("n_workers", [1, 2]) def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3): + torch.manual_seed(1) env = ParallelEnv( n_workers, lambda: CountingEnv(max_steps=max_steps, batch_size=batch_size) ) @@ -942,6 +943,7 @@ def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3): @pytest.mark.parametrize("batch_size", [(), (2,), (32, 5)]) def test_env_base_reset_flag(batch_size, max_steps=3): + torch.manual_seed(1) env = CountingEnv(max_steps=max_steps, batch_size=batch_size) env.set_seed(1) From e1dbaffab5364c2ce2b2d58b5680d324761dba58 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 10 Jan 2023 13:42:09 +0100 Subject: [PATCH 21/26] Modified tests --- test/test_env.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 6dd809ea10a..286a9239b29 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -930,20 +930,24 @@ def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3): _reset = torch.randint( low=0, high=2, size=(*env.batch_size, 1), dtype=torch.bool ) + while not _reset.any(): + _reset = torch.randint( + low=0, high=2, size=(*env.batch_size, 1), dtype=torch.bool + ) + td_reset = env.reset( TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device) ) env.close() - if _reset.any(): - assert (td_reset["done"][_reset] == 0).all() - assert (td_reset["observation"][_reset] == 0).all() - assert (td_reset["done"][~_reset] == 1).all() - assert (td_reset["observation"][~_reset] == max_steps + 1).all() + + assert (td_reset["done"][_reset] == 0).all() + assert (td_reset["observation"][_reset] == 0).all() + assert (td_reset["done"][~_reset] == 1).all() + assert (td_reset["observation"][~_reset] == max_steps + 1).all() @pytest.mark.parametrize("batch_size", [(), (2,), (32, 5)]) def test_env_base_reset_flag(batch_size, max_steps=3): - torch.manual_seed(1) env = CountingEnv(max_steps=max_steps, batch_size=batch_size) env.set_seed(1) From d9dc888cae9e334f7f98ec546f2c7b7dfe987298 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 10 Jan 2023 15:15:47 +0100 Subject: [PATCH 22/26] Removed last dim of 1 --- test/test_env.py | 14 +++++--------- test/test_transforms.py | 4 ++-- torchrl/collectors/collectors.py | 2 +- torchrl/envs/libs/vmas.py | 2 +- torchrl/envs/transforms/transforms.py | 2 +- torchrl/envs/vec_env.py | 4 ++-- 6 files changed, 12 insertions(+), 16 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 286a9239b29..e4d621a567c 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -512,7 +512,7 @@ def test_parallel_env( _ = env_parallel.step(td) td_reset = TensorDict( - source={"_reset": torch.zeros(N, 1, dtype=torch.bool).bernoulli_()}, + source={"_reset": torch.zeros(N, dtype=torch.bool).bernoulli_()}, batch_size=[ N, ], @@ -596,7 +596,7 @@ def test_parallel_env_with_policy( _ = env_parallel.step(td) td_reset = TensorDict( - source={"_reset": torch.zeros(N, 1, dtype=torch.bool).bernoulli_()}, + source={"_reset": torch.zeros(N, dtype=torch.bool).bernoulli_()}, batch_size=[ N, ], @@ -927,13 +927,9 @@ def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3): assert (td["done"] == 1).all() assert (td["next"]["observation"] == max_steps + 1).all() - _reset = torch.randint( - low=0, high=2, size=(*env.batch_size, 1), dtype=torch.bool - ) + _reset = torch.randint(low=0, high=2, size=env.batch_size, dtype=torch.bool) while not _reset.any(): - _reset = torch.randint( - low=0, high=2, size=(*env.batch_size, 1), dtype=torch.bool - ) + _reset = torch.randint(low=0, high=2, size=env.batch_size, dtype=torch.bool) td_reset = env.reset( TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device) @@ -967,7 +963,7 @@ def test_env_base_reset_flag(batch_size, max_steps=3): assert (td["done"] == 1).all() assert (td["next"]["observation"] == max_steps + 1).all() - _reset = torch.randint(low=0, high=2, size=(*env.batch_size, 1), dtype=torch.bool) + _reset = torch.randint(low=0, high=2, size=env.batch_size, dtype=torch.bool) td_reset = env.reset( TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device) ) diff --git a/test/test_transforms.py b/test/test_transforms.py index b20afb0e843..6419f6c579e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -832,7 +832,7 @@ def test_sum_reward(self, keys, device): assert (td.get("episode_reward") == 2 * td.get("reward")).all() # reset environments - td.set("_reset", torch.ones((batch, 1), dtype=torch.bool, device=device)) + td.set("_reset", torch.ones(batch, dtype=torch.bool, device=device)) rs.reset(td) # apply a third time, episode_reward should be equal to reward again @@ -1724,7 +1724,7 @@ def test_step_counter(self, max_steps, device, batch, reset_workers): {"done": torch.zeros(*batch, 1, dtype=torch.bool)}, batch, device=device ) if reset_workers: - td.set("_reset", torch.randn(*batch, 1) < 0) + td.set("_reset", torch.randn(batch) < 0) step_counter.reset(td) assert not torch.all(td.get("step_count")) i = 0 diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 0108de33642..e7742b70c98 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -680,7 +680,7 @@ def reset(self, index=None, **kwargs) -> None: if prod(self.env.batch_size) == 0: raise RuntimeError("resetting unique env with index is not permitted.") _reset = torch.zeros( - (*self.env.batch_size, 1), + self.env.batch_size, dtype=torch.bool, device=self.env.device, ) diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 35de76df4c7..42eba338509 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -203,7 +203,7 @@ def _reset( self, tensordict: Optional[TensorDictBase] = None, **kwargs ) -> TensorDictBase: if tensordict is not None and "_reset" in tensordict.keys(): - envs_to_reset = tensordict.get("_reset").any(dim=0).squeeze(-1) + envs_to_reset = tensordict.get("_reset").any(dim=0) for env_index, to_reset in enumerate(envs_to_reset): if to_reset: self._env.reset_at(env_index) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 5a55c7d2b57..6a05a7f5e29 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2619,7 +2619,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: _reset = tensordict.get( "_reset", default=torch.ones( - *tensordict.batch_size, 1, dtype=torch.bool, device=tensordict.device + tensordict.batch_size, dtype=torch.bool, device=tensordict.device ), ) tensordict.set( diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 69072806cd9..b5e26035453 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -631,7 +631,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: self._assert_tensordict_shape(tensordict) _reset = tensordict.get("_reset") else: - _reset = torch.ones((*self.batch_size, 1), dtype=torch.bool) + _reset = torch.ones(self.batch_size, dtype=torch.bool) keys = set() for i, _env in enumerate(self._envs): @@ -850,7 +850,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: self._assert_tensordict_shape(tensordict) _reset = tensordict.get("_reset") else: - _reset = torch.ones((*self.batch_size, 1), dtype=torch.bool) + _reset = torch.ones(self.batch_size, dtype=torch.bool) for i, channel in enumerate(self.parent_channels): if not _reset[i].any(): From a244c95bf9eda41174847147c821327b445baefb Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 10 Jan 2023 15:25:15 +0100 Subject: [PATCH 23/26] removed lid dim of 1 to "step_count" --- torchrl/envs/transforms/transforms.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 6a05a7f5e29..b51f55f0c12 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2628,8 +2628,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: * tensordict.get( "step_count", torch.zeros( - *tensordict.batch_size, - 1, + tensordict.batch_size, dtype=torch.int64, device=tensordict.device, ), @@ -2642,8 +2641,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict.get( "step_count", torch.zeros( - *tensordict.batch_size, - 1, + tensordict.batch_size, dtype=torch.int64, device=tensordict.device, ), @@ -2654,7 +2652,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if self.max_steps is not None: tensordict.set( "done", - tensordict.get("done") | next_step_count >= self.max_steps, + tensordict.get("done").squeeze(-1) | next_step_count >= self.max_steps, ) return tensordict From 5daadeead01a0e6c56bd99e975da2e559f6ed223 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 10 Jan 2023 15:29:15 +0100 Subject: [PATCH 24/26] removed another 1 --- torchrl/envs/transforms/transforms.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b51f55f0c12..f3fd2c575a5 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2518,8 +2518,7 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: _reset = tensordict.get( "_reset", torch.ones( - *tensordict.batch_size, - 1, + tensordict.batch_size, dtype=torch.bool, device=tensordict.device, ), From 00c351b7907f0b66d4d9ffe4f66c1141c9e972a4 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 10 Jan 2023 15:32:35 +0100 Subject: [PATCH 25/26] change spec of StepCount --- torchrl/envs/transforms/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f3fd2c575a5..7170e52002a 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2663,7 +2663,7 @@ def transform_observation_spec( f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." ) observation_spec["step_count"] = UnboundedDiscreteTensorSpec( - shape=torch.Size([1]), dtype=torch.int64, device=observation_spec.device + shape=torch.Size([]), dtype=torch.int64, device=observation_spec.device ) observation_spec["step_count"].space.minimum = 0 return observation_spec From 6f9913339917a5ee134581087af4a8e405fbb50a Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 10 Jan 2023 15:35:59 +0100 Subject: [PATCH 26/26] set the done properly in StepCount --- torchrl/envs/transforms/transforms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7170e52002a..34286662d94 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2651,7 +2651,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if self.max_steps is not None: tensordict.set( "done", - tensordict.get("done").squeeze(-1) | next_step_count >= self.max_steps, + tensordict.get("done") + | (next_step_count >= self.max_steps).unsqueeze(-1), ) return tensordict