diff --git a/test/test_env.py b/test/test_env.py index 8e6e4bf193e..44ef32201e5 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -1402,8 +1402,6 @@ def test_nested_env(self, envclass): assert ("next", *env.done_key) in next_state.keys(True) assert ("next", *env.reward_key) in next_state.keys(True) - # check_env_specs(env) - @pytest.mark.parametrize("batch_size", [(), (32,), (32, 1)]) def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3): @@ -1455,7 +1453,7 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3): MockBatchedLockedEnv, MockBatchedUnLockedEnv, MockSerialEnv, - # NestedCountingEnv, + NestedCountingEnv, ], ) def test_mocking_envs(envclass): diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 77cc263a347..2f4df4e350a 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -28,6 +28,9 @@ "step_mdp", "make_composite_from_td", ] + +from torchrl.data import CompositeSpec + AVAILABLE_LIBRARIES = {pkg.key for pkg in pkg_resources.working_set} @@ -267,11 +270,44 @@ def _check_dmlab(): } +def _per_level_env_check(data0, data1, check_dtype): + """Checks shape and dtype of two tensordicts, accounting for lazy stacks.""" + if isinstance(data0, LazyStackedTensorDict) and isinstance( + data1, LazyStackedTensorDict + ): + if data0.stack_dim != data1.stack_dim: + raise AssertionError(f"Stack dimension mismatch: {data0} vs {data1}.") + for _data0, _data1 in zip(data0.tensordicts, data1.tensordicts): + _per_level_env_check(_data0, _data1, check_dtype=check_dtype) + return + else: + keys0 = set(data0.keys()) + keys1 = set(data1.keys()) + if keys0 != keys1: + raise AssertionError(f"Keys mismatch: {keys0} vs {keys1}") + for key in keys0: + _data0 = data0[key] + _data1 = data1[key] + if _data0.shape != _data1.shape: + raise AssertionError( + f"The shapes of the real and fake tensordict don't match for key {key}. " + f"Got fake={_data0.shape} and real={_data1.shape}." 
+ ) + if isinstance(_data0, TensorDictBase): + _per_level_env_check(_data0, _data1, check_dtype=check_dtype) + else: + if check_dtype and (_data0.dtype != _data1.dtype): + raise AssertionError( + f"The dtypes of the real and fake tensordict don't match for key {key}. " + f"Got fake={_data0.dtype} and real={_data1.dtype}." + ) + + def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): """Tests an environment specs against the results of short rollout. This test function should be used as a sanity check for an env wrapped with - torchrl's EnvBase subclasses: any discrepency between the expected data and + torchrl's EnvBase subclasses: any discrepancy between the expected data and the data collected should raise an assertion error. A broken environment spec will likely make it impossible to use parallel @@ -294,94 +330,63 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0): torch.manual_seed(seed) env.set_seed(seed) - fake_tensordict = env.fake_tensordict().flatten_keys(".") + fake_tensordict = env.fake_tensordict() real_tensordict = env.rollout(3, return_contiguous=return_contiguous) - # # remove private keys - # real_tensordict = real_tensordict.exclude( - # *[ - # key - # for key in real_tensordict.keys(True) - # if (isinstance(key, str) and key.startswith("_")) - # or ( - # isinstance(key, tuple) and any(subkey.startswith("_") for subkey in key) - # ) - # ] - # ) - real_tensordict = real_tensordict.flatten_keys(".") - - keys1 = set(fake_tensordict.keys(True)) - keys2 = set(real_tensordict.keys(True)) - if keys1 != keys2: - raise AssertionError( - "The keys of the fake tensordict and the one collected during rollout do not match:" - f"Got fake-real: {keys1-keys2} and real-fake: {keys2-keys1}" - ) - fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1) - fake_tensordict = fake_tensordict.expand(*real_tensordict.shape) - fake_tensordict = fake_tensordict.to_tensordict() + + if return_contiguous: + 
fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1) + fake_tensordict = fake_tensordict.expand(*real_tensordict.shape) + else: + fake_tensordict = torch.stack([fake_tensordict.clone() for _ in range(3)], -1) + if ( fake_tensordict.apply(lambda x: torch.zeros_like(x)) != real_tensordict.apply(lambda x: torch.zeros_like(x)) - ).all(): + ).any(): raise AssertionError( "zeroing the two tensordicts did not make them identical. " f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}" ) - for key in keys2: - if fake_tensordict[key].shape != real_tensordict[key].shape: + + # Checks shapes and eventually dtypes of keys at all nesting levels + _per_level_env_check(fake_tensordict, real_tensordict, check_dtype=check_dtype) + + # Check specs + last_td = real_tensordict[..., -1] + _action_spec = env.input_spec["_action_spec"] + _state_spec = env.input_spec["_state_spec"] + _obs_spec = env.output_spec["_observation_spec"] + _reward_spec = env.output_spec["_reward_spec"] + _done_spec = env.output_spec["_done_spec"] + for name, spec in ( + ("action", _action_spec), + ("state", _state_spec), + ("done", _done_spec), + ("obs", _obs_spec), + ): + if spec is None: + spec = CompositeSpec(shape=env.batch_size, device=env.device) + td = last_td.select(*spec.keys(True, True), strict=True) + if not spec.is_in(td): raise AssertionError( - f"The shapes of the real and fake tensordict don't match for key {key}. " - f"Got fake={fake_tensordict[key].shape} and real={real_tensordict[key].shape}." + f"spec check failed at root for spec {name}={spec} and data {td}." 
+ ) - if check_dtype and (fake_tensordict[key].dtype != real_tensordict[key].dtype): + for name, spec in ( + ("reward", _reward_spec), + ("done", _done_spec), + ("obs", _obs_spec), + ): + if spec is None: + spec = CompositeSpec(shape=env.batch_size, device=env.device) + td = last_td.get("next").select(*spec.keys(True, True), strict=True) + if not spec.is_in(td): raise AssertionError( - f"The dtypes of the real and fake tensordict don't match for key {key}. " - f"Got fake={fake_tensordict[key].dtype} and real={real_tensordict[key].dtype}." + f"spec check failed at next for spec {name}={spec} and data {td}." ) - # test dtypes - real_tensordict = env.rollout(3) # keep empty structures, for example dict() - for key, value in real_tensordict[..., -1].items(): - _check_isin(key, value, env.observation_spec, env.input_spec) - print("check_env_specs succeeded!") -def _check_isin(key, value, obs_spec, input_spec): - if key in {"reward", "done"}: - return - elif key == "next": - for _key, _value in value.items(): - _check_isin(_key, _value, obs_spec, input_spec) - return - elif key in input_spec["_action_spec"].keys(True): - if not input_spec["_action_spec"][key].is_in(value): - raise AssertionError( - f"action_spec.is_in failed for key {key}. " - f"Got action_spec={input_spec['_action_spec'][key]} and real={value}." - ) - return - - elif key in input_spec.keys(True): - if not input_spec[key].is_in(value): - raise AssertionError( - f"input_spec.is_in failed for key {key}. " - f"Got input_spec={input_spec[key]} and real={value}." - ) - return - elif key in obs_spec.keys(True): - if not obs_spec[key].is_in(value): - raise AssertionError( - f"obs_spec.is_in failed for key {key}. " - f"Got obs_spec={obs_spec[key]} and real={value}." 
- ) - return - else: - raise KeyError( - f"key {key} was not found in input spec with keys {input_spec.keys(True)} or obs spec with keys {obs_spec.keys(True)}" - ) - - def _selective_unsqueeze(tensor: torch.Tensor, batch_size: torch.Size, dim: int = -1): shape_len = len(tensor.shape)