From ebcb1a301af0eb36688e071ac618206466549e8d Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 29 May 2023 21:03:59 +0100 Subject: [PATCH 1/3] fix --- test/mocking_classes.py | 64 ++++++++++++++++++++++++++++++++++++ torchrl/data/tensor_specs.py | 17 +++++++--- torchrl/envs/common.py | 17 +++++----- 3 files changed, 86 insertions(+), 12 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 30aee19fa82..d35a31f64ca 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1152,3 +1152,67 @@ def _step( device=self.device, ) return tensordict.select().set("next", tensordict) + + +class HeteroEnv(EnvBase): + """A heterogeneous, counting Env.""" + + def __init__(self, device="cpu"): + self.observation_spec = CompositeSpec( + agent_features=torch.stack( + [CompositeSpec( + observation=UnboundedContinuousTensorSpec(shape=(3,)) + ), + CompositeSpec( + observation=UnboundedContinuousTensorSpec(shape=(2,)) + )], + dim=0 + ), + common_features=UnboundedContinuousTensorSpec(shape=(5,)), + shape=() + ) + self.action_spec = CompositeSpec( + agent_features=torch.stack( + [CompositeSpec( + action=UnboundedContinuousTensorSpec(shape=(3,), ) + ), + CompositeSpec( + action=UnboundedContinuousTensorSpec(shape=(2,), ) + )], dim=0 + ), shape=()) + self.reward_spec = CompositeSpec( + agent_features=CompositeSpec( + reward=UnboundedContinuousTensorSpec(shape=(2,)), + shape=(2,) + ), shape=() + ) + self.done_spec = CompositeSpec( + agent_features=CompositeSpec( + done=DiscreteTensorSpec(n=2, shape=(2,), dtype=torch.bool) + ) + ) + super().__init__(device=device) + + def _set_seed(self, seed): + return seed + + def _reset( + self, + tensordict: TensorDictBase = None, + **kwargs, + ) -> TensorDictBase: + self.counter = 0 + td = self.observation_spec.zero() + td.update(self.output_spec['_done_spec'].zero()) + td.update(self.output_spec['_reward_spec'].zero()) + return td + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + td = self.observation_spec.zero() + td.apply_(lambda x: x + self.counter) + td.update(self.output_spec['_done_spec'].zero()) + td.update(self.output_spec['_reward_spec'].zero()) + return td.select().set("next", td) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index a6bfe0cbe1f..c5833e46a76 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -3088,12 +3088,21 @@ def to_numpy(self, val: TensorDict, safe: bool = True) -> dict: def __len__(self): pass - def values(self): - for key in self.keys(): + def values(self, + include_nested: bool = False, + leaves_only: bool = False, + ): + for key in self.keys(include_nested=include_nested, leaves_only=leaves_only): yield self[key] - def items(self): - for key in self.keys(): + def items(self, + include_nested: bool = False, + leaves_only: bool = False, + ): + for key in self.keys( + include_nested=include_nested, + leaves_only=leaves_only + ): yield key, self[key] def keys( diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 19d41b4e133..2b6e26c7d7a 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1035,13 +1035,10 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None): """ shape = torch.Size([]) - if tensordict is None: - tensordict = TensorDict( - {}, device=self.device, batch_size=self.batch_size, _run_checks=False - ) - - if not self.batch_locked and not self.batch_size: + if not self.batch_locked and not self.batch_size and tensordict is not None: shape = tensordict.shape + elif not 
+            self.batch_locked and not self.batch_size:
    MockBatchedUnLockedEnv,
             Dict[str, Any], ignore_device: bool = False
             If ``False``, it will be copied (and replaced)
@@ -295,56 +334,54 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): torch.manual_seed(seed) env.set_seed(seed) - fake_tensordict = env.fake_tensordict().flatten_keys(".") + fake_tensordict = env.fake_tensordict() # .flatten_keys(".") real_tensordict = env.rollout(3, return_contiguous=return_contiguous) - # # remove private keys - # real_tensordict = real_tensordict.exclude( - # *[ - # key - # for key in real_tensordict.keys(True) - # if (isinstance(key, str) and key.startswith("_")) - # or ( - # isinstance(key, tuple) and any(subkey.startswith("_") for subkey in key) - # ) - # ] - # ) - real_tensordict = real_tensordict.flatten_keys(".") - - keys1 = set(fake_tensordict.keys(True)) - keys2 = set(real_tensordict.keys(True)) - if keys1 != keys2: - raise AssertionError( - "The keys of the fake tensordict and the one collected during rollout do not match:" - f"Got fake-real: {keys1-keys2} and real-fake: {keys2-keys1}" - ) - fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1) - fake_tensordict = fake_tensordict.expand(*real_tensordict.shape) - fake_tensordict = fake_tensordict.to_tensordict() + + if return_contiguous: + fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1) + fake_tensordict = fake_tensordict.expand(*real_tensordict.shape) + else: + fake_tensordict = torch.stack([fake_tensordict.clone() for _ in range(3)], -1) + if ( fake_tensordict.apply(lambda x: torch.zeros_like(x)) != real_tensordict.apply(lambda x: torch.zeros_like(x)) - ).all(): + ).any(): raise AssertionError( "zeroing the two tensordicts did not make them identical. " f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}" ) - for key in keys2: - if fake_tensordict[key].shape != real_tensordict[key].shape: + _per_level_env_check(fake_tensordict, real_tensordict, check_dtype=check_dtype) + + # test dtypes + # real_tensordict = env.rollout(3, return_contiguous=return_contiguous) # keep empty structures, for example dict() + last_td = real_tensordict[..., -1] + _action_spec = env.input_spec["_action_spec"] + _state_spec = env.input_spec["_state_spec"] + _obs_spec = env.output_spec["_observation_spec"] + _reward_spec = env.output_spec["_reward_spec"] + _done_spec = env.output_spec["_done_spec"] + for name, spec in ( + ("action", _action_spec), + ("state", _state_spec), + ("obs", _obs_spec), + ): + td = last_td.select(*spec.keys(True, True), strict=True) + if not spec.is_in(td): raise AssertionError( - f"The shapes of the real and fake tensordict don't match for key {key}. " - f"Got fake={fake_tensordict[key].shape} and real={real_tensordict[key].shape}." + f"spec check failed at root for spec {name}={spec} and data {td}." ) - if check_dtype and (fake_tensordict[key].dtype != real_tensordict[key].dtype): + for name, spec in ( + ("reward", _reward_spec), + ("done", _done_spec), + ("obs", _obs_spec), + ): + td = last_td.get("next").select(*spec.keys(True, True), strict=True) + if not spec.is_in(td): raise AssertionError( - f"The dtypes of the real and fake tensordict don't match for key {key}. " - f"Got fake={fake_tensordict[key].dtype} and real={real_tensordict[key].dtype}." + f"spec check failed at root for spec {name}={spec} and data {td}." 
+            )
+            assert out[done_key] is tensordict["next", done_key]
@@ -162,6 +169,9 @@ def step_mdp( exclude_reward=exclude_reward, exclude_done=exclude_done, exclude_action=exclude_action, + reward_key=reward_key, + done_key=done_key, + action_key=action_key, ) for td, ntd in zip(tensordict.tensordicts, next_tensordicts) ], @@ -173,9 +183,6 @@ def step_mdp( return out out = tensordict.get("next").clone(False) excluded = None - done_key = "done" if exclude_done is True else exclude_done - reward_key = "reward" if exclude_reward is True else exclude_reward - action_key = "action" if exclude_action is True else exclude_action if exclude_done and exclude_reward: excluded = {done_key, reward_key} elif exclude_reward: @@ -184,21 +191,16 @@ def step_mdp( excluded = {done_key} if excluded: out = out.exclude(*excluded, inplace=True) - # TODO: make it work with LazyStackedTensorDict - # def _valid_key(key): - # if key == "next" or key in out.keys(): - # return False - # if exclude_action and key == "action": - # return False - # if keep_other or key == "action": - # return True - # return False td_keys = None if keep_other: - out_keys = set(out.keys()) - td_keys = set(tensordict.keys()) - out_keys - {"next"} - if exclude_action: - td_keys = td_keys - {action_key} + out_keys = set(out.keys(True, True)) + td_keys = { + key + for key in tensordict.keys(True, True) + if not (isinstance(key, tuple) and key[0] == "next") + and not (key in out_keys) + and (not exclude_action or key != action_key) + } elif not exclude_action: td_keys = {action_key} @@ -366,6 +368,8 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): ("state", _state_spec), ("obs", _obs_spec), ): + if spec is None: + spec = CompositeSpec(shape=env.batch_size, device=env.device) td = last_td.select(*spec.keys(True, True), strict=True) if not spec.is_in(td): raise AssertionError( @@ -376,6 +380,8 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): ("done", _done_spec), ("obs", _obs_spec), ): + if spec is None: + spec = CompositeSpec(shape=env.batch_size, device=env.device) td = last_td.get("next").select(*spec.keys(True, True), strict=True) if not spec.is_in(td): raise AssertionError(
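A minimal usage sketch (not part of the diff), distilled from the
test_steptensordict changes above, showing how the reward_key/done_key/action_key
arguments that this series threads through step_mdp are meant to be used with
nested keys; it assumes torchrl with this series applied and the tensordict
package installed:

    import torch
    from tensordict import TensorDict
    from torchrl.envs.utils import step_mdp

    reward_key = ("some", "other", "reward")
    done_key = ("some", "other", "done")
    action_key = ("some", "other", "action")

    td = TensorDict(
        {
            "ledzep": torch.randn(4, 2),
            "next": {
                "ledzep": torch.randn(4, 2),
                reward_key: torch.randn(4, 1),
                done_key: torch.zeros(4, 1, dtype=torch.bool),
            },
            action_key: torch.randn(4, 2),
        },
        [4],
    )
    out = step_mdp(
        td,
        exclude_reward=True,   # drop the nested reward from the result
        exclude_done=False,    # copy done over from the "next" entry
        exclude_action=True,   # drop the root action entry
        reward_key=reward_key,
        done_key=done_key,
        action_key=action_key,
    )
    assert reward_key not in out.keys(True)
    assert done_key in out.keys(True)
    assert action_key not in out.keys(True)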