diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index c949a5e094f..39525fec36f 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -7,7 +7,6 @@
 import torch
 import torch.nn as nn
 from tensordict.tensordict import TensorDict, TensorDictBase
-
 from torchrl.data.tensor_specs import (
     BinaryDiscreteTensorSpec,
     BoundedTensorSpec,
@@ -718,3 +717,66 @@ def __init__(self, in_size, out_size):
 
     def forward(self, observation, action):
         return self.linear(torch.cat([observation, action], dim=-1))
+
+
+class CountingEnv(EnvBase):
+    """Mock environment whose observation is the running sum of its actions.
+
+    ``done`` is ``count > max_steps``. ``_reset`` zeroes the counter, restricted
+    to the entries selected by the boolean ``"_reset"`` entry when the input
+    tensordict carries one.
+    """
+
+    def __init__(self, max_steps: int = 5, **kwargs):
+        super().__init__(**kwargs)
+        self.max_steps = max_steps
+
+        self.observation_spec = CompositeSpec(
+            observation=UnboundedContinuousTensorSpec((1,))
+        )
+        self.reward_spec = UnboundedContinuousTensorSpec((1,))
+        self.input_spec = CompositeSpec(action=BinaryDiscreteTensorSpec(1))
+
+        # one integer counter per batch entry, with a trailing singleton dimension
+        self.count = torch.zeros(
+            (*self.batch_size, 1), device=self.device, dtype=torch.int
+        )
+
+    def _set_seed(self, seed: Optional[int]):
+        torch.manual_seed(seed)
+
+    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+        """Zero the counter (only where ``"_reset"`` is True, if that key is set)."""
+        if tensordict is not None and "_reset" in tensordict.keys():
+            _reset = tensordict.get("_reset")
+            self.count[_reset] = 0
+        else:
+            self.count[:] = 0
+        return TensorDict(
+            source={
+                "observation": self.count.clone(),
+                "done": self.count > self.max_steps,
+            },
+            batch_size=self.batch_size,
+            device=self.device,
+        )
+
+    def _step(
+        self,
+        tensordict: TensorDictBase,
+    ) -> TensorDictBase:
+        """Add the integer-cast ``"action"`` to the counter and report the result."""
+        action = tensordict.get("action")
+        self.count += action.to(torch.int)
+        return TensorDict(
+            source={
+                # clone, as in ``_reset``: ``self.count`` is updated in place, so
+                # handing out the tensor itself would let later steps/resets
+                # rewrite observations that were already returned
+                "observation": self.count.clone(),
+                "done": self.count > self.max_steps,
+                "reward": torch.zeros_like(self.count, dtype=torch.float),
+            },
+            batch_size=self.batch_size,
+            device=self.device,
+        )
diff --git a/test/test_env.py b/test/test_env.py
index 4661211f863..e4d621a567c 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -20,6 +20,7 @@
 )
 from mocking_classes
import (
     ActionObsMergeLinear,
+    CountingEnv,
     DiscreteActionConvMockEnv,
     DiscreteActionVecMockEnv,
     DummyModelBasedEnvBase,
@@ -511,7 +512,7 @@ def test_parallel_env(
         _ = env_parallel.step(td)
 
         td_reset = TensorDict(
-            source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()},
+            source={"_reset": torch.zeros(N, dtype=torch.bool).bernoulli_()},
             batch_size=[
                 N,
             ],
@@ -595,7 +596,7 @@ def test_parallel_env_with_policy(
         _ = env_parallel.step(td)
 
         td_reset = TensorDict(
-            source={"reset_workers": torch.zeros(N, dtype=torch.bool).bernoulli_()},
+            source={"_reset": torch.zeros(N, dtype=torch.bool).bernoulli_()},
             batch_size=[
                 N,
             ],
@@ -900,6 +901,100 @@ def env_fn2(seed):
         env1.close()
         env2.close()
 
+    @pytest.mark.parametrize("batch_size", [(), (1,), (4,), (32, 5)])
+    @pytest.mark.parametrize("n_workers", [1, 2])
+    def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3):
+        """ParallelEnv.reset must only reset the entries flagged in ``_reset``."""
+        torch.manual_seed(1)
+        env = ParallelEnv(
+            n_workers, lambda: CountingEnv(max_steps=max_steps, batch_size=batch_size)
+        )
+        # close the env even when an assertion fails, so that a failing
+        # parametrization does not leave its workers running
+        try:
+            env.set_seed(1)
+            action = env.action_spec.rand(env.batch_size)
+            action[:] = 1
+
+            for i in range(max_steps):
+                td = env.step(
+                    TensorDict(
+                        {"action": action},
+                        batch_size=env.batch_size,
+                        device=env.device,
+                    )
+                )
+                assert (td["done"] == 0).all()
+                assert (td["next"]["observation"] == i + 1).all()
+
+            td = env.step(
+                TensorDict(
+                    {"action": action}, batch_size=env.batch_size, device=env.device
+                )
+            )
+            assert (td["done"] == 1).all()
+            assert (td["next"]["observation"] == max_steps + 1).all()
+
+            # random mask with at least one entry flagged for reset
+            _reset = torch.randint(
+                low=0, high=2, size=env.batch_size, dtype=torch.bool
+            )
+            while not _reset.any():
+                _reset = torch.randint(
+                    low=0, high=2, size=env.batch_size, dtype=torch.bool
+                )
+
+            td_reset = env.reset(
+                TensorDict(
+                    {"_reset": _reset}, batch_size=env.batch_size, device=env.device
+                )
+            )
+        finally:
+            env.close()
+
+        # flagged entries restart from zero, the others keep their final state
+        assert (td_reset["done"][_reset] == 0).all()
+        assert (td_reset["observation"][_reset] == 0).all()
+        assert (td_reset["done"][~_reset] == 1).all()
+        assert (td_reset["observation"][~_reset] == max_steps + 1).all()
+
+
+@pytest.mark.parametrize("batch_size", [(), (2,), (32, 5)])
+def test_env_base_reset_flag(batch_size, max_steps=3):
+    """A single CountingEnv must only reset the entries flagged in ``_reset``."""
+    env = CountingEnv(max_steps=max_steps, batch_size=batch_size)
+    env.set_seed(1)
+
+    action = env.action_spec.rand(env.batch_size)
+    action[:] = 1
+
+    for i in range(max_steps):
+        td = env.step(
+            TensorDict({"action": action}, batch_size=env.batch_size, device=env.device)
+        )
+        assert (td["done"] == 0).all()
+        assert (td["next"]["observation"] == i + 1).all()
+
+    td = env.step(
+        TensorDict({"action": action}, batch_size=env.batch_size, device=env.device)
+    )
+    assert (td["done"] == 1).all()
+    assert (td["next"]["observation"] == max_steps + 1).all()
+
+    # same guard as test_parallel_env_reset_flag: with an all-False mask nothing
+    # is reset and the "was reset" assertions below would pass vacuously
+    _reset = torch.randint(low=0, high=2, size=env.batch_size, dtype=torch.bool)
+    while not _reset.any():
+        _reset = torch.randint(low=0, high=2, size=env.batch_size, dtype=torch.bool)
+    td_reset = env.reset(
+        TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device)
+    )
+
+    assert (td_reset["done"][_reset] == 0).all()
+    assert (td_reset["observation"][_reset] == 0).all()
+    assert (td_reset["done"][~_reset] == 1).all()
+    assert (td_reset["observation"][~_reset] == max_steps + 1).all()
+
 
 @pytest.mark.skipif(not _has_gym, reason="no gym")
 def test_seed():
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 8d3e40f7550..6419f6c579e 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -832,7 +832,7 @@ def test_sum_reward(self, keys, device):
         assert (td.get("episode_reward") == 2 * td.get("reward")).all()
 
         # reset environments
-        td.set("reset_workers", torch.ones((batch, 1), dtype=torch.bool, device=device))
+        td.set("_reset", torch.ones(batch, dtype=torch.bool, device=device))
         rs.reset(td)
 
         # apply a third time, episode_reward should be equal to reward again
@@ -1724,7 +1724,7 @@ def test_step_counter(self, max_steps, device, batch, reset_workers):
             {"done": torch.zeros(*batch, 1, dtype=torch.bool)}, batch, device=device
         )
         if reset_workers:
-            td.set("reset_workers", torch.randn(*batch, 1) < 0)
+
td.set("_reset", torch.randn(batch) < 0) step_counter.reset(td) assert not torch.all(td.get("step_count")) i = 0 @@ -1740,10 +1740,10 @@ def test_step_counter(self, max_steps, device, batch, reset_workers): step_counter.reset(td) if reset_workers: assert torch.all( - torch.masked_select(td.get("step_count"), td.get("reset_workers")) == 0 + torch.masked_select(td.get("step_count"), td.get("_reset")) == 0 ) assert torch.all( - torch.masked_select(td.get("step_count"), ~td.get("reset_workers")) == i + torch.masked_select(td.get("step_count"), ~td.get("_reset")) == i ) else: assert torch.all(td.get("step_count") == 0) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 3a288c589ef..e7742b70c98 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -22,13 +22,11 @@ from tensordict.tensordict import TensorDict, TensorDictBase from torch import multiprocessing as mp from torch.utils.data import IterableDataset - from torchrl._utils import _check_for_faulty_process, prod from torchrl.collectors.utils import split_trajectories from torchrl.data import TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import EnvBase - from torchrl.envs.transforms import TransformedEnv from torchrl.envs.utils import set_exploration_mode, step_mdp from torchrl.envs.vec_env import _BatchedEnv @@ -615,7 +613,7 @@ def _reset_if_necessary(self) -> None: steps = steps.clone() if len(self.env.batch_size): self._tensordict.masked_fill_(done_or_terminated, 0) - self._tensordict.set("reset_workers", done_or_terminated) + self._tensordict.set("_reset", done_or_terminated) else: self._tensordict.zero_() self.env.reset(self._tensordict) @@ -624,8 +622,6 @@ def _reset_if_necessary(self) -> None: raise RuntimeError( f"Got {sum(self._tensordict.get('done'))} done envs after reset." 
) - if len(self.env.batch_size): - self._tensordict.del_("reset_workers") traj_ids[done_or_terminated] = traj_ids.max() + torch.arange( 1, done_or_terminated.sum() + 1, device=traj_ids.device ) @@ -683,15 +679,16 @@ def reset(self, index=None, **kwargs) -> None: # check that the env supports partial reset if prod(self.env.batch_size) == 0: raise RuntimeError("resetting unique env with index is not permitted.") - reset_workers = torch.zeros( - *self.env.batch_size, + _reset = torch.zeros( + self.env.batch_size, dtype=torch.bool, device=self.env.device, ) - reset_workers[index] = 1 - td_in = TensorDict({"reset_workers": reset_workers}, self.env.batch_size) + _reset[index] = 1 + td_in = TensorDict({"_reset": _reset}, self.env.batch_size) self._tensordict[index].zero_() else: + _reset = None td_in = None self._tensordict.zero_() @@ -699,7 +696,10 @@ def reset(self, index=None, **kwargs) -> None: self._tensordict.update(td_in, inplace=True) self._tensordict.update(self.env.reset(**kwargs), inplace=True) - self._tensordict.fill_("step_count", 0) + if _reset is not None: + self._tensordict["step_count"][_reset] = 0 + else: + self._tensordict.fill_("step_count", 0) def shutdown(self) -> None: """Shuts down all workers and/or closes the local environment.""" diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index aec32eff395..e455e4412ef 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -14,11 +14,11 @@ import torch import torch.nn as nn from tensordict.tensordict import TensorDict, TensorDictBase - from torchrl.data import CompositeSpec, TensorSpec from .._utils import prod, seed_generator from ..data.utils import DEVICE_TYPING + from .utils import get_available_libraries, step_mdp LIBRARIES = get_available_libraries() @@ -428,6 +428,12 @@ def reset( a tensordict (or the input tensordict, if any), modified in place with the resulting observations. 
""" + if tensordict is not None and "_reset" in tensordict.keys(): + self._assert_tensordict_shape(tensordict) + _reset = tensordict.get("_reset") + else: + _reset = None + tensordict_reset = self._reset(tensordict, **kwargs) done = tensordict_reset.get("done", None) @@ -457,11 +463,16 @@ def reset( *tensordict_reset.batch_size, 1, dtype=torch.bool, device=self.device ), ) - if tensordict_reset.get("done").any(): + + if (_reset is None and tensordict_reset.get("done").any()) or ( + _reset is not None and tensordict_reset.get("done")[_reset].any() + ): raise RuntimeError( - f"Env {self} was done after reset. This is (currently) not allowed." + f"Env {self} was done after reset on specified '_reset' dimensions. This is (currently) not allowed." ) if tensordict is not None: + if "_reset" in tensordict.keys(): + tensordict.del_("_reset") tensordict.update(tensordict_reset) else: tensordict = tensordict_reset diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index a059f91fd9d..42eba338509 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -2,7 +2,6 @@ import torch from tensordict.tensordict import TensorDict, TensorDictBase - from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec from torchrl.envs.common import _EnvWrapper from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform @@ -203,7 +202,18 @@ def _set_seed(self, seed: Optional[int]): def _reset( self, tensordict: Optional[TensorDictBase] = None, **kwargs ) -> TensorDictBase: - obs, infos = self._env.reset(return_info=True) + if tensordict is not None and "_reset" in tensordict.keys(): + envs_to_reset = tensordict.get("_reset").any(dim=0) + for env_index, to_reset in enumerate(envs_to_reset): + if to_reset: + self._env.reset_at(env_index) + obs = [] + infos = [] + for agent in self.agents: + obs.append(self.scenario.observation(agent)) + infos.append(self.scenario.info(agent)) + else: + obs, infos = self._env.reset(return_info=True) agent_tds = [] for 
i in range(self.n_agents): diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 3dc630b7f31..34286662d94 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -14,7 +14,6 @@ import torch from tensordict.tensordict import TensorDict, TensorDictBase from torch import nn, Tensor - from torchrl.data.tensor_specs import ( BinaryDiscreteTensorSpec, BoundedTensorSpec, @@ -2516,18 +2515,17 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: # Batched environments else: - reset_workers = tensordict.get( - "reset_workers", + _reset = tensordict.get( + "_reset", torch.ones( - *tensordict.batch_size, - 1, + tensordict.batch_size, dtype=torch.bool, device=tensordict.device, ), ) for out_key in self.out_keys: if out_key in tensordict.keys(): - tensordict[out_key][reset_workers] = 0.0 + tensordict[out_key][_reset] = 0.0 return tensordict @@ -2617,20 +2615,19 @@ def __init__(self, max_steps: Optional[int] = None): super().__init__([]) def reset(self, tensordict: TensorDictBase) -> TensorDictBase: - workers = tensordict.get( - "reset_workers", + _reset = tensordict.get( + "_reset", default=torch.ones( - *tensordict.batch_size, 1, dtype=torch.bool, device=tensordict.device + tensordict.batch_size, dtype=torch.bool, device=tensordict.device ), ) tensordict.set( "step_count", - (~workers) + (~_reset) * tensordict.get( "step_count", torch.zeros( - *tensordict.batch_size, - 1, + tensordict.batch_size, dtype=torch.int64, device=tensordict.device, ), @@ -2643,8 +2640,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict.get( "step_count", torch.zeros( - *tensordict.batch_size, - 1, + tensordict.batch_size, dtype=torch.int64, device=tensordict.device, ), @@ -2655,7 +2651,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if self.max_steps is not None: tensordict.set( "done", - tensordict.get("done") | next_step_count >= self.max_steps, + 
tensordict.get("done") + | (next_step_count >= self.max_steps).unsqueeze(-1), ) return tensordict @@ -2667,7 +2664,7 @@ def transform_observation_spec( f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." ) observation_spec["step_count"] = UnboundedDiscreteTensorSpec( - shape=torch.Size([1]), dtype=torch.int64, device=observation_spec.device + shape=torch.Size([]), dtype=torch.int64, device=observation_spec.device ) observation_spec["step_count"].space.minimum = 0 return observation_spec diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 733de7d1c45..b5e26035453 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -626,18 +626,21 @@ def set_seed( @_check_start def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: - if tensordict is not None and "reset_workers" in tensordict.keys(): + + if tensordict is not None and "_reset" in tensordict.keys(): self._assert_tensordict_shape(tensordict) - reset_workers = tensordict.get("reset_workers") + _reset = tensordict.get("_reset") else: - reset_workers = torch.ones(self.num_workers, dtype=torch.bool) + _reset = torch.ones(self.batch_size, dtype=torch.bool) keys = set() for i, _env in enumerate(self._envs): - if not reset_workers[i]: + if not _reset[i].any(): continue _tensordict = tensordict[i] if tensordict is not None else None _td = _env._reset(tensordict=_tensordict, **kwargs) + if "_reset" in _td.keys(): + _td.del_("_reset") keys = keys.union(_td.keys()) self.shared_tensordicts[i].update_(_td) @@ -843,29 +846,32 @@ def set_seed( @_check_start def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: cmd_out = "reset" - if tensordict is not None and "reset_workers" in tensordict.keys(): + if tensordict is not None and "_reset" in tensordict.keys(): self._assert_tensordict_shape(tensordict) - reset_workers = tensordict.get("reset_workers") + _reset = tensordict.get("_reset") else: - reset_workers = 
torch.ones(self.num_workers, dtype=torch.bool) + _reset = torch.ones(self.batch_size, dtype=torch.bool) for i, channel in enumerate(self.parent_channels): - if not reset_workers[i]: + if not _reset[i].any(): continue + kwargs["tensordict"] = tensordict[i] if tensordict is not None else None channel.send((cmd_out, kwargs)) keys = set() for i, channel in enumerate(self.parent_channels): - if not reset_workers[i]: + if not _reset[i].any(): continue cmd_in, new_keys = channel.recv() keys = keys.union(new_keys) if cmd_in != "reset_obs": raise RuntimeError(f"received cmd {cmd_in} instead of reset_obs") check_count = 0 - while self.shared_tensordict_parent.get("done").any(): + while self.shared_tensordict_parent.get("done")[_reset].any(): if check_count == 4: - raise RuntimeError("Envs have just been reset but some are still done") + raise RuntimeError( + "Envs have just been reset but env is done on specified '_reset' dimensions." + ) else: check_count += 1 # there might be some delay between writing the shared tensordict @@ -1008,6 +1014,8 @@ def _run_worker_pipe_shared_mem( raise RuntimeError("call 'init' before resetting") # _td = tensordict.select("observation").to(env.device).clone() _td = env._reset(**reset_kwargs) + if "_reset" in _td.keys(): + _td.del_("_reset") done = _td.get("done", None) if done is None: _td["done"] = done = torch.zeros( @@ -1019,8 +1027,6 @@ def _run_worker_pipe_shared_mem( _td.pin_memory() tensordict.update_(_td) child_pipe.send(("reset_obs", reset_keys)) - if done.any(): - raise RuntimeError(f"{env.__class__.__name__} is done after reset") elif cmd == "step": if not initialized: