From dafe892e17dafed75c774ef5f262ba3438caaa71 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 3 Jul 2023 09:01:27 +0100 Subject: [PATCH 01/10] refactor fake_tensordict Signed-off-by: Matteo Bettini --- torchrl/envs/common.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 497a8d767e2..9f105a790e2 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1306,18 +1306,22 @@ def fake_tensordict(self) -> TensorDictBase: state_spec = self.state_spec observation_spec = self.observation_spec action_spec = self.input_spec["_action_spec"] + # instantiates reward_spec if needed _ = self.reward_spec reward_spec = self.output_spec["_reward_spec"] + # instantiates done_spec if needed _ = self.done_spec done_spec = self.output_spec["_done_spec"] fake_obs = observation_spec.zero() - fake_input = state_spec.zero() - fake_input = fake_input.update(action_spec.zero()) + + fake_state = state_spec.zero() + fake_action = action_spec.zero() + fake_input = fake_state.update(fake_action) # the input and output key may match, but the output prevails # Hence we generate the input, and override using the output - fake_in_out = fake_input.clone().update(fake_obs) + fake_in_out = fake_input.update(fake_obs) fake_reward = reward_spec.zero() fake_done = done_spec.zero() @@ -1325,17 +1329,11 @@ def fake_tensordict(self) -> TensorDictBase: next_output = fake_obs.clone() next_output.update(fake_reward) next_output.update(fake_done) - fake_in_out.update(fake_done.clone()) - fake_td = TensorDict( - { - **fake_in_out, - "next": next_output, - }, - batch_size=self.batch_size, - device=self.device, - ) + fake_td = fake_in_out.set("next", next_output) + fake_td.batch_size = self.batch_size + fake_td = fake_td.to(self.device) return fake_td From 2cf6616837c759e5ce0090de84c1e170411507cd Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 3 Jul 2023 09:22:29 +0100 Subject: [PATCH 02/10] refactor rollout Signed-off-by: Matteo Bettini --- torchrl/envs/common.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 9f105a790e2..88ec42e7672 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1230,9 +1230,6 @@ def policy(td): return td tensordicts = [] - done_key = self.done_key - if not isinstance(done_key, tuple): - done_key = (done_key,) for i in range(max_steps): if auto_cast_to_device: tensordict = tensordict.to(policy_device, non_blocking=True) @@ -1242,7 +1239,7 @@ def policy(td): tensordict = self.step(tensordict) tensordicts.append(tensordict.clone(False)) - done = tensordict.get(("next", *done_key)) + done = tensordict.get(("next", self.done_key)) truncated = tensordict.get( ("next", "truncated"), default=torch.zeros((), device=done.device, dtype=torch.bool), @@ -1252,8 +1249,12 @@ def policy(td): break tensordict = step_mdp( tensordict, - keep_other=True, - exclude_action=False, + keep_other=False, + exclude_action=True, + exclude_reward=True, + reward_key=self.reward_key, + action_key=self.action_key, + done_key=self.done_key, ) if not break_when_any_done and done.any(): _reset = done.clone() From 0ecf9f332adbc62159074cf11586e6899b8b4326 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 3 Jul 2023 09:41:41 +0100 Subject: [PATCH 03/10] refactor rand_action Signed-off-by: Matteo Bettini --- test/test_env.py | 15 ++++++++++++++- torchrl/envs/common.py | 17 ++++++++--------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 6839be5e885..3b38dea5b45 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -1634,7 +1634,15 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3): env = NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim) - td = env.reset() + td_reset = env.reset() + assert td_reset.batch_size == batch_size + assert td_reset["data"].batch_size == (*batch_size, nested_dim) + + td = env.rand_action() + assert td.batch_size == batch_size + assert td["data"].batch_size == (*batch_size, nested_dim) + + td = env.rand_action(td_reset) assert td.batch_size == batch_size assert td["data"].batch_size == (*batch_size, nested_dim) @@ -1643,6 +1651,11 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3): assert td["data"].batch_size == (*batch_size, nested_dim) assert td["next", "data"].batch_size == (*batch_size, nested_dim) + td = env.rand_step(td_reset) + assert td.batch_size == batch_size + assert td["data"].batch_size == (*batch_size, nested_dim) + assert td["next", "data"].batch_size == (*batch_size, nested_dim) + td = env.rollout(rollout_length) assert td.batch_size == (*batch_size, rollout_length) assert td["data"].batch_size == (*batch_size, rollout_length, nested_dim) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 88ec42e7672..921dbc8dbe8 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -12,7 +12,7 @@ import numpy as np import torch import torch.nn as nn -from tensordict.tensordict import TensorDict, TensorDictBase +from tensordict.tensordict import TensorDictBase from torchrl._utils import prod, seed_generator @@ -1033,13 +1033,10 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None): """ shape = torch.Size([]) - if tensordict is None: - tensordict = TensorDict( - {}, device=self.device, batch_size=self.batch_size, _run_checks=False - ) - - if not self.batch_locked and not self.batch_size: + if not self.batch_locked and not self.batch_size and tensordict is not None: shape = tensordict.shape + elif not self.batch_locked and not self.batch_size: + shape = torch.Size([]) elif not self.batch_locked and tensordict.shape != self.batch_size: raise RuntimeError( "The input tensordict and the env have a different batch size: " @@ -1047,8 +1044,10 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None): f"Non batch-locked environment require the env batch-size to be either empty or to" f" match the tensordict one." ) - action = self.action_spec.rand(shape) - tensordict.set(self.action_key, action) + r = self.input_spec["_action_spec"].rand(shape) + if tensordict is None: + return r + tensordict.update(r) return tensordict def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase: From 9cc8c3e31204fd102d8ce37032c97a93b3b41ac8 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 3 Jul 2023 09:42:57 +0100 Subject: [PATCH 04/10] added test Signed-off-by: Matteo Bettini --- test/test_env.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_env.py b/test/test_env.py index 3b38dea5b45..acd44c37d0b 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -1646,6 +1646,11 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3): assert td.batch_size == batch_size assert td["data"].batch_size == (*batch_size, nested_dim) + td = env.rand_step(td) + assert td.batch_size == batch_size + assert td["data"].batch_size == (*batch_size, nested_dim) + assert td["next", "data"].batch_size == (*batch_size, nested_dim) + td = env.rand_step() assert td.batch_size == batch_size assert td["data"].batch_size == (*batch_size, nested_dim) From 931437561aba7d6c1ce7bce8e85d2c204931dde4 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 3 Jul 2023 10:06:41 +0100 Subject: [PATCH 05/10] make rollout keep_other=True Signed-off-by: Matteo Bettini --- torchrl/envs/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 921dbc8dbe8..15168f20411 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1248,7 +1248,7 @@ def policy(td): break tensordict = step_mdp( tensordict, - keep_other=False, + keep_other=True, exclude_action=True, exclude_reward=True, reward_key=self.reward_key, From baae9c243a3482546aed214ecc84ee11672a83a2 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 3 Jul 2023 10:31:52 +0100 Subject: [PATCH 06/10] parametric keys Signed-off-by: Matteo Bettini --- torchrl/envs/vec_env.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index fc853270a3f..2e412d897ae 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -335,8 +335,8 @@ def _create_td(self) -> None: key = (key,) self.env_output_keys.append(("next", *key)) self.env_obs_keys.append(key) - self.env_output_keys.append(("next", "reward")) - self.env_output_keys.append(("next", "done")) + self.env_output_keys.append(("next", self.reward_key)) + self.env_output_keys.append(("next", self.done_key)) else: env_input_keys = set() for meta_data in self.meta_data: @@ -363,7 +363,7 @@ def _create_td(self) -> None: ) ) env_output_keys = env_output_keys.union( - {("next", "reward"), ("next", "done")} + {("next", self.reward_key), ("next", self.done_key)} ) self.env_obs_keys = sorted(env_obs_keys, key=_sort_keys) self.env_input_keys = sorted(env_input_keys, key=_sort_keys) @@ -374,10 +374,10 @@ def _create_td(self) -> None: .union(self.env_input_keys) .union(self.env_obs_keys) ) - self._selected_keys.add("done") + self._selected_keys.add(self.done_key) self._selected_keys.add("_reset") - self._selected_reset_keys = self.env_obs_keys + ["done"] + ["_reset"] + self._selected_reset_keys = self.env_obs_keys + [self.done_key] + ["_reset"] self._selected_step_keys = self.env_output_keys if self._single_task: @@ -1187,7 +1187,7 @@ def _reset(self, tensordict: TensorDictBase) -> TensorDictBase: @torch.no_grad() def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - action = tensordict.get("action") + action = tensordict.get(self.action_key) # Action needs to be moved to CPU and converted to numpy before being passed to envpool action = action.to(torch.device("cpu")) step_output = self._env.step(action.numpy()) @@ -1285,7 +1285,7 @@ def _transform_reset_output( ) obs = self.obs.clone(False) - obs.update({"done": self.done_spec.zero()}) + obs.update({self.done_key: self.done_spec.zero()}) return obs def _transform_step_output( @@ -1295,7 +1295,7 @@ def _transform_step_output( obs, reward, done, *_ = envpool_output obs = self._treevalue_or_numpy_to_tensor_or_dict(obs) - obs.update({"reward": torch.tensor(reward), "done": done}) + obs.update({self.reward_key: torch.tensor(reward), self.done_key: done}) self.obs = tensordict_out = TensorDict( obs, batch_size=self.batch_size, From 482b29c3d3e71c98f1a3bbb3bceb03bb71ceda32 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 3 Jul 2023 13:39:42 +0100 Subject: [PATCH 07/10] amend Signed-off-by: Matteo Bettini --- test/mocking_classes.py | 14 ++++++++++---- torchrl/envs/utils.py | 8 +++++++- torchrl/envs/vec_env.py | 35 ++++++++++++++++++----------------- 3 files changed, 35 insertions(+), 22 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 9e7fa98cdfa..e86dfb13136 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1127,10 +1127,15 @@ def __init__( shape=self.batch_size, ) - def _reset(self, td): - if self.nested_done and td is not None and "_reset" in td.keys(): - td["_reset"] = td["_reset"].sum(-2, dtype=torch.bool) - td = super()._reset(td) + def _reset(self, tensordict): + if ( + self.nested_done + and tensordict is not None + and "_reset" in tensordict.keys() + ): + tensordict = tensordict.clone() + tensordict["_reset"] = tensordict["_reset"].sum(-2, dtype=torch.bool) + td = super()._reset(tensordict) if self.nested_done: td[self.done_key] = ( td["done"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1) @@ -1149,6 +1154,7 @@ def _reset(self, td): def _step(self, td): if self.nested_obs_action: + td = td.clone() td["data"].batch_size = self.batch_size td[self.action_key] = td[self.action_key].max(-2)[0] td_root = super()._step(td) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index d9a774ea8e6..45b3088984d 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -223,7 +223,10 @@ def step_mdp( return out -def _set_single_key(source, dest, key): +def _set_single_key(source, dest, key, clone=False): + # key should be unraveled + if isinstance(key, str): + key = (key,) for k in key: val = source.get(k) if is_tensor_collection(val): @@ -234,6 +237,8 @@ def _set_single_key(source, dest, key): source = val dest = new_val else: + if clone: + val = val.clone() dest._set(k, val) @@ -482,6 +487,7 @@ def __get__(self, owner_self, owner_cls): def _sort_keys(element): if isinstance(element, tuple): + element = unravel_keys(element) return "_-|-_".join(element) return element diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 2e412d897ae..cf6108763a4 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -20,6 +20,7 @@ import torch from tensordict import TensorDict from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase +from tensordict.utils import unravel_keys from torch import multiprocessing as mp from torchrl._utils import _check_for_faulty_process, VERBOSE @@ -33,8 +34,7 @@ from torchrl.envs.common import _EnvWrapper, EnvBase from torchrl.envs.env_creator import get_env_metadata -from torchrl.envs.utils import _sort_keys - +from torchrl.envs.utils import _set_single_key, _sort_keys _has_envpool = importlib.util.find_spec("envpool") @@ -324,28 +324,26 @@ def _create_td(self) -> None: if self._single_task: self.env_input_keys = sorted( - list(self.input_spec["_action_spec"].keys(True)) - + list(self.state_spec.keys(True)), + list(self.input_spec["_action_spec"].keys(True, True)) + + list(self.state_spec.keys(True, True)), key=_sort_keys, ) self.env_output_keys = [] self.env_obs_keys = [] - for key in self.output_spec["_observation_spec"].keys(True): - if isinstance(key, str): - key = (key,) - self.env_output_keys.append(("next", *key)) + for key in self.output_spec["_observation_spec"].keys(True, True): + self.env_output_keys.append(unravel_keys(("next", key))) self.env_obs_keys.append(key) - self.env_output_keys.append(("next", self.reward_key)) - self.env_output_keys.append(("next", self.done_key)) + self.env_output_keys.append(unravel_keys(("next", self.reward_key))) + self.env_output_keys.append(unravel_keys(("next", self.done_key))) else: env_input_keys = set() for meta_data in self.meta_data: if meta_data.specs["input_spec", "_state_spec"] is not None: env_input_keys = env_input_keys.union( - meta_data.specs["input_spec", "_state_spec"].keys(True) + meta_data.specs["input_spec", "_state_spec"].keys(True, True) ) env_input_keys = env_input_keys.union( - meta_data.specs["input_spec", "_action_spec"].keys(True) + meta_data.specs["input_spec", "_action_spec"].keys(True, True) ) env_output_keys = set() env_obs_keys = set() @@ -353,17 +351,20 @@ def _create_td(self) -> None: env_obs_keys = env_obs_keys.union( key for key in meta_data.specs["output_spec"]["_observation_spec"].keys( - True + True, True ) ) env_output_keys = env_output_keys.union( - ("next", key) if isinstance(key, str) else ("next", *key) + unravel_keys(("next", key)) for key in meta_data.specs["output_spec"]["_observation_spec"].keys( - True + True, True ) ) env_output_keys = env_output_keys.union( - {("next", self.reward_key), ("next", self.done_key)} + { + unravel_keys(("next", self.reward_key)), + unravel_keys(("next", self.done_key)), + } ) self.env_obs_keys = sorted(env_obs_keys, key=_sort_keys) self.env_input_keys = sorted(env_input_keys, key=_sort_keys) @@ -619,7 +620,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape) for key in self._selected_reset_keys: if key != "_reset": - out._set(key, self.shared_tensordict_parent.get(key).clone()) + _set_single_key(self.shared_tensordict_parent, out, key, clone=True) return out else: return self.shared_tensordict_parent.select( From 4ff6eafb4818a45ac14b0547f2e3d3e15d7f82e9 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 3 Jul 2023 13:58:03 +0100 Subject: [PATCH 08/10] added tests and logic Signed-off-by: Matteo Bettini --- test/test_env.py | 53 +++++++++++++++++++++++++++++++++++++++++ torchrl/envs/vec_env.py | 6 ++--- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index acd44c37d0b..e5d01b411f1 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -943,6 +943,59 @@ def test_parallel_env_reset_flag(self, batch_size, n_workers, max_steps=3): assert (td_reset["done"][~_reset] == 1).all() assert (td_reset["observation"][~_reset] == max_steps + 1).all() + @pytest.mark.parametrize("nested_obs_action", [True, False]) + @pytest.mark.parametrize("nested_done", [True, False]) + @pytest.mark.parametrize("nested_reward", [True, False]) + @pytest.mark.parametrize("env_type", ["serial", "parallel"]) + def test_parallel_env_nested( + self, + nested_obs_action, + nested_done, + nested_reward, + env_type, + n_envs=2, + batch_size=(32,), + nested_dim=5, + rollout_length=3, + seed=1, + ): + env_fn = lambda: NestedCountingEnv( + nest_done=nested_done, + nest_reward=nested_reward, + nest_obs_action=nested_obs_action, + batch_size=batch_size, + nested_dim=nested_dim, + ) + if env_type == "serial": + env = SerialEnv(n_envs, env_fn) + else: + env = ParallelEnv(n_envs, env_fn) + env.set_seed(seed) + + batch_size = (n_envs, *batch_size) + + td = env.reset() + assert td.batch_size == batch_size + if nested_done or nested_obs_action: + assert td["data"].batch_size == (*batch_size, nested_dim) + if not nested_done and not nested_reward and not nested_obs_action: + assert "data" not in td.keys() + + policy = CountingEnvCountPolicy(env.action_spec, env.action_key) + td = env.rollout(rollout_length, policy) + assert td.batch_size == (*batch_size, rollout_length) + if nested_done or nested_obs_action: + assert td["data"].batch_size == (*batch_size, rollout_length, nested_dim) + if nested_reward or nested_done or nested_obs_action: + assert td["next", "data"].batch_size == ( + *batch_size, + rollout_length, + nested_dim, + ) + if not nested_done and not nested_reward and not nested_obs_action: + assert "data" not in td.keys() + assert "data" not in td["next"].keys() + @pytest.mark.parametrize("batch_size", [(), (2,), (32, 5)]) def test_env_base_reset_flag(batch_size, max_steps=3): diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index cf6108763a4..b25ec2c17e5 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -551,7 +551,7 @@ def _step( if self._single_task: out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape) for key in self._selected_step_keys: - out._set(key, self.shared_tensordict_parent.get(key).clone()) + _set_single_key(self.shared_tensordict_parent, out, key, clone=True) else: # strict=False ensures that non-homogeneous keys are still there out = self.shared_tensordict_parent.select( @@ -791,7 +791,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if self._single_task: out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape) for key in self._selected_step_keys: - out._set(key, self.shared_tensordict_parent.get(key).clone()) + _set_single_key(self.shared_tensordict_parent, out, key, clone=True) else: # strict=False ensures that non-homogeneous keys are still there out = self.shared_tensordict_parent.select( @@ -854,7 +854,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape) for key in self._selected_reset_keys: if key != "_reset": - out._set(key, self.shared_tensordict_parent.get(key).clone()) + _set_single_key(self.shared_tensordict_parent, out, key, clone=True) return out else: return self.shared_tensordict_parent.select( From ab8196158fc2abf0bf454ceeed5d56089aac49a6 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 3 Jul 2023 13:58:56 +0100 Subject: [PATCH 09/10] typo Signed-off-by: Matteo Bettini --- torchrl/envs/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 45b3088984d..152b65ba3ed 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -224,7 +224,7 @@ def step_mdp( def _set_single_key(source, dest, key, clone=False): - # key should be unraveled + # key should be already unraveled if isinstance(key, str): key = (key,) for k in key: From 3412c1bd1bcce2f7e29b9af74b4284dc01d6446e Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 3 Jul 2023 14:09:02 +0100 Subject: [PATCH 10/10] amend Signed-off-by: Matteo Bettini --- test/test_env.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_env.py b/test/test_env.py index e5d01b411f1..cc2495f3fa9 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -996,6 +996,13 @@ def test_parallel_env_nested( assert "data" not in td.keys() assert "data" not in td["next"].keys() + if nested_obs_action: + assert "observation" not in td.keys() + assert (td[..., -1]["data", "states"] == 2).all() + else: + assert ("data", "states") not in td.keys(True, True) + assert (td[..., -1]["observation"] == 2).all() + @pytest.mark.parametrize("batch_size", [(), (2,), (32, 5)]) def test_env_base_reset_flag(batch_size, max_steps=3):