diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index 9e7fa98cdfa..e86dfb13136 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -1127,10 +1127,15 @@ def __init__(
             shape=self.batch_size,
         )
 
-    def _reset(self, td):
-        if self.nested_done and td is not None and "_reset" in td.keys():
-            td["_reset"] = td["_reset"].sum(-2, dtype=torch.bool)
-        td = super()._reset(td)
+    def _reset(self, tensordict):
+        if (
+            self.nested_done
+            and tensordict is not None
+            and "_reset" in tensordict.keys()
+        ):
+            tensordict = tensordict.clone()
+            tensordict["_reset"] = tensordict["_reset"].sum(-2, dtype=torch.bool)
+        td = super()._reset(tensordict)
         if self.nested_done:
             td[self.done_key] = (
                 td["done"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)
@@ -1149,6 +1154,7 @@ def _reset(self, td):
 
     def _step(self, td):
         if self.nested_obs_action:
+            td = td.clone()
             td["data"].batch_size = self.batch_size
             td[self.action_key] = td[self.action_key].max(-2)[0]
         td_root = super()._step(td)
diff --git a/test/test_env.py b/test/test_env.py
index acd44c37d0b..cc2495f3fa9 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -943,6 +943,66 @@ def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3):
         assert (td_reset["done"][~_reset] == 1).all()
         assert (td_reset["observation"][~_reset] == max_steps + 1).all()
 
+    @pytest.mark.parametrize("nested_obs_action", [True, False])
+    @pytest.mark.parametrize("nested_done", [True, False])
+    @pytest.mark.parametrize("nested_reward", [True, False])
+    @pytest.mark.parametrize("env_type", ["serial", "parallel"])
+    def test_parallel_env_nested(
+        self,
+        nested_obs_action,
+        nested_done,
+        nested_reward,
+        env_type,
+        n_envs=2,
+        batch_size=(32,),
+        nested_dim=5,
+        rollout_length=3,
+        seed=1,
+    ):
+        env_fn = lambda: NestedCountingEnv(
+            nest_done=nested_done,
+            nest_reward=nested_reward,
+            nest_obs_action=nested_obs_action,
+            batch_size=batch_size,
+            nested_dim=nested_dim,
+        )
+        if env_type == "serial":
+            env = SerialEnv(n_envs, env_fn)
+        else:
+            env = ParallelEnv(n_envs, env_fn)
+        env.set_seed(seed)
+
+        batch_size = (n_envs, *batch_size)
+
+        td = env.reset()
+        assert td.batch_size == batch_size
+        if nested_done or nested_obs_action:
+            assert td["data"].batch_size == (*batch_size, nested_dim)
+        if not nested_done and not nested_reward and not nested_obs_action:
+            assert "data" not in td.keys()
+
+        policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
+        td = env.rollout(rollout_length, policy)
+        assert td.batch_size == (*batch_size, rollout_length)
+        if nested_done or nested_obs_action:
+            assert td["data"].batch_size == (*batch_size, rollout_length, nested_dim)
+        if nested_reward or nested_done or nested_obs_action:
+            assert td["next", "data"].batch_size == (
+                *batch_size,
+                rollout_length,
+                nested_dim,
+            )
+        if not nested_done and not nested_reward and not nested_obs_action:
+            assert "data" not in td.keys()
+            assert "data" not in td["next"].keys()
+
+        if nested_obs_action:
+            assert "observation" not in td.keys()
+            assert (td[..., -1]["data", "states"] == 2).all()
+        else:
+            assert ("data", "states") not in td.keys(True, True)
+            assert (td[..., -1]["observation"] == 2).all()
+
 
 @pytest.mark.parametrize("batch_size", [(), (2,), (32, 5)])
 def test_env_base_reset_flag(batch_size, max_steps=3):
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index d9a774ea8e6..152b65ba3ed 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -223,7 +223,10 @@ def step_mdp(
     return out
 
 
-def _set_single_key(source, dest, key):
+def _set_single_key(source, dest, key, clone=False):
+    # key should be already unraveled
+    if isinstance(key, str):
+        key = (key,)
     for k in key:
         val = source.get(k)
         if is_tensor_collection(val):
@@ -234,6 +237,8 @@ def _set_single_key(source, dest, key):
             source = val
             dest = new_val
         else:
+            if clone:
+                val = val.clone()
             dest._set(k, val)
 
 
@@ -482,6 +487,7 @@ def __get__(self, owner_self, owner_cls):
 
 def _sort_keys(element):
     if isinstance(element, tuple):
+        element = unravel_keys(element)
         return "_-|-_".join(element)
     return element
 
diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py
index fc853270a3f..b25ec2c17e5 100644
--- a/torchrl/envs/vec_env.py
+++ b/torchrl/envs/vec_env.py
@@ -20,6 +20,7 @@
 import torch
 from tensordict import TensorDict
 from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
+from tensordict.utils import unravel_keys
 from torch import multiprocessing as mp
 
 from torchrl._utils import _check_for_faulty_process, VERBOSE
@@ -33,8 +34,7 @@
 from torchrl.envs.common import _EnvWrapper, EnvBase
 from torchrl.envs.env_creator import get_env_metadata
 
-from torchrl.envs.utils import _sort_keys
-
+from torchrl.envs.utils import _set_single_key, _sort_keys
 
 _has_envpool = importlib.util.find_spec("envpool")
 
@@ -324,28 +324,26 @@ def _create_td(self) -> None:
 
         if self._single_task:
             self.env_input_keys = sorted(
-                list(self.input_spec["_action_spec"].keys(True))
-                + list(self.state_spec.keys(True)),
+                list(self.input_spec["_action_spec"].keys(True, True))
+                + list(self.state_spec.keys(True, True)),
                 key=_sort_keys,
             )
             self.env_output_keys = []
             self.env_obs_keys = []
-            for key in self.output_spec["_observation_spec"].keys(True):
-                if isinstance(key, str):
-                    key = (key,)
-                self.env_output_keys.append(("next", *key))
+            for key in self.output_spec["_observation_spec"].keys(True, True):
+                self.env_output_keys.append(unravel_keys(("next", key)))
                 self.env_obs_keys.append(key)
-            self.env_output_keys.append(("next", "reward"))
-            self.env_output_keys.append(("next", "done"))
+            self.env_output_keys.append(unravel_keys(("next", self.reward_key)))
+            self.env_output_keys.append(unravel_keys(("next", self.done_key)))
         else:
             env_input_keys = set()
             for meta_data in self.meta_data:
                 if meta_data.specs["input_spec", "_state_spec"] is not None:
                     env_input_keys = env_input_keys.union(
-                        meta_data.specs["input_spec", "_state_spec"].keys(True)
+                        meta_data.specs["input_spec", "_state_spec"].keys(True, True)
                     )
                 env_input_keys = env_input_keys.union(
-                    meta_data.specs["input_spec", "_action_spec"].keys(True)
+                    meta_data.specs["input_spec", "_action_spec"].keys(True, True)
                 )
             env_output_keys = set()
             env_obs_keys = set()
@@ -353,17 +351,20 @@ def _create_td(self) -> None:
                 env_obs_keys = env_obs_keys.union(
                     key
                     for key in meta_data.specs["output_spec"]["_observation_spec"].keys(
-                        True
+                        True, True
                     )
                 )
                 env_output_keys = env_output_keys.union(
-                    ("next", key) if isinstance(key, str) else ("next", *key)
+                    unravel_keys(("next", key))
                     for key in meta_data.specs["output_spec"]["_observation_spec"].keys(
-                        True
+                        True, True
                     )
                 )
             env_output_keys = env_output_keys.union(
-                {("next", "reward"), ("next", "done")}
+                {
+                    unravel_keys(("next", self.reward_key)),
+                    unravel_keys(("next", self.done_key)),
+                }
             )
             self.env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
             self.env_input_keys = sorted(env_input_keys, key=_sort_keys)
@@ -374,10 +375,10 @@ def _create_td(self) -> None:
             .union(self.env_input_keys)
             .union(self.env_obs_keys)
         )
-        self._selected_keys.add("done")
+        self._selected_keys.add(self.done_key)
         self._selected_keys.add("_reset")
 
-        self._selected_reset_keys = self.env_obs_keys + ["done"] + ["_reset"]
+        self._selected_reset_keys = self.env_obs_keys + [self.done_key] + ["_reset"]
         self._selected_step_keys = self.env_output_keys
 
         if self._single_task:
@@ -550,7 +551,7 @@ def _step(
         if self._single_task:
             out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape)
             for key in self._selected_step_keys:
-                out._set(key, self.shared_tensordict_parent.get(key).clone())
+                _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
         else:
             # strict=False ensures that non-homogeneous keys are still there
             out = self.shared_tensordict_parent.select(
@@ -619,7 +620,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
             out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape)
             for key in self._selected_reset_keys:
                 if key != "_reset":
-                    out._set(key, self.shared_tensordict_parent.get(key).clone())
+                    _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
             return out
         else:
             return self.shared_tensordict_parent.select(
@@ -790,7 +791,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         if self._single_task:
             out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape)
             for key in self._selected_step_keys:
-                out._set(key, self.shared_tensordict_parent.get(key).clone())
+                _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
         else:
             # strict=False ensures that non-homogeneous keys are still there
             out = self.shared_tensordict_parent.select(
@@ -853,7 +854,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
             out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape)
             for key in self._selected_reset_keys:
                 if key != "_reset":
-                    out._set(key, self.shared_tensordict_parent.get(key).clone())
+                    _set_single_key(self.shared_tensordict_parent, out, key, clone=True)
             return out
         else:
             return self.shared_tensordict_parent.select(
@@ -1187,7 +1188,7 @@ def _reset(self, tensordict: TensorDictBase) -> TensorDictBase:
 
     @torch.no_grad()
     def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
-        action = tensordict.get("action")
+        action = tensordict.get(self.action_key)
         # Action needs to be moved to CPU and converted to numpy before being passed to envpool
         action = action.to(torch.device("cpu"))
         step_output = self._env.step(action.numpy())
@@ -1285,7 +1286,7 @@ def _transform_reset_output(
             )
 
         obs = self.obs.clone(False)
-        obs.update({"done": self.done_spec.zero()})
+        obs.update({self.done_key: self.done_spec.zero()})
         return obs
 
     def _transform_step_output(
@@ -1295,7 +1296,7 @@ def _transform_step_output(
         obs, reward, done, *_ = envpool_output
 
         obs = self._treevalue_or_numpy_to_tensor_or_dict(obs)
-        obs.update({"reward": torch.tensor(reward), "done": done})
+        obs.update({self.reward_key: torch.tensor(reward), self.done_key: done})
         self.obs = tensordict_out = TensorDict(
             obs,
             batch_size=self.batch_size,
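Reviewer note (not part of the patch): the changes above hinge on every key being "unravelled" before it is sorted or written, i.e. a plain string or a flat tuple of strings. `_sort_keys` joins tuple keys with a separator and `_set_single_key` walks them one element at a time, so a nested tuple such as ("next", ("data", "reward")) would break both; hence the calls to `unravel_keys` from `tensordict.utils`. The following is a minimal standalone sketch of that normalization; `unravel` is a hypothetical stand-in written for illustration, not the library implementation.

# Hypothetical stand-in for tensordict.utils.unravel_keys (assumption:
# the real helper flattens nested tuple keys in the same way).
def unravel(key):
    # A bare string is already a valid leaf key.
    if isinstance(key, str):
        return key
    flat = []
    for part in key:
        sub = unravel(part)
        # Extend with the flattened sub-key, one string element at a time.
        flat.extend([sub] if isinstance(sub, str) else sub)
    return tuple(flat)

# ("next", key) where key is itself a tuple would otherwise stay nested:
assert unravel(("next", ("data", "reward"))) == ("next", "data", "reward")
assert unravel("done") == "done"

# Once flat, tuple keys can be joined for sorting, as _sort_keys does:
assert "_-|-_".join(unravel(("next", ("data", "reward")))) == "next_-|-_data_-|-_reward"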