Skip to content
4 changes: 1 addition & 3 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,8 +1402,6 @@ def test_nested_env(self, envclass):
assert ("next", *env.done_key) in next_state.keys(True)
assert ("next", *env.reward_key) in next_state.keys(True)

# check_env_specs(env)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove this, since the same check is already performed at line 1456.


@pytest.mark.parametrize("batch_size", [(), (32,), (32, 1)])
def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3):

Expand Down Expand Up @@ -1455,7 +1453,7 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3):
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
MockSerialEnv,
# NestedCountingEnv,
NestedCountingEnv,
],
)
def test_mocking_envs(envclass):
Expand Down
151 changes: 78 additions & 73 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
"step_mdp",
"make_composite_from_td",
]

from torchrl.data import CompositeSpec

AVAILABLE_LIBRARIES = {pkg.key for pkg in pkg_resources.working_set}


Expand Down Expand Up @@ -267,11 +270,44 @@ def _check_dmlab():
}


def _per_level_env_check(data0, data1, check_dtype):
    """Checks shape and dtype of two tensordicts, accounting for lazy stacks.

    The comparison is recursive: nested tensordicts are checked level by
    level, and when both inputs are lazy stacks their stacked members are
    compared pairwise.

    Args:
        data0: the reference tensordict (reported as "fake" in error messages).
        data1: the tensordict compared against it (reported as "real").
        check_dtype (bool): if ``True``, the dtypes of leaf entries are
            compared in addition to their shapes.

    Raises:
        AssertionError: if the stack dimensions, the key sets, the shapes or
            (when ``check_dtype`` is set) the dtypes differ.
    """
    if isinstance(data0, LazyStackedTensorDict) and isinstance(
        data1, LazyStackedTensorDict
    ):
        if data0.stack_dim != data1.stack_dim:
            raise AssertionError(f"Stack dimension mismatch: {data0} vs {data1}.")
        # Members of a lazy stack are compared one by one rather than through
        # the stacked view.
        for _data0, _data1 in zip(data0.tensordicts, data1.tensordicts):
            _per_level_env_check(_data0, _data1, check_dtype=check_dtype)
        return

    keys0 = set(data0.keys())
    keys1 = set(data1.keys())
    if keys0 != keys1:
        raise AssertionError(f"Keys mismatch: {keys0} vs {keys1}")
    for key in keys0:
        _data0 = data0[key]
        _data1 = data1[key]
        if _data0.shape != _data1.shape:
            # Report both shapes: "fake" comes from data0, "real" from data1.
            raise AssertionError(
                f"The shapes of the real and fake tensordict don't match for key {key}. "
                f"Got fake={_data0.shape} and real={_data1.shape}."
            )
        if isinstance(_data0, TensorDictBase):
            # Nested tensordict: descend one level.
            _per_level_env_check(_data0, _data1, check_dtype=check_dtype)
        elif check_dtype and (_data0.dtype != _data1.dtype):
            raise AssertionError(
                f"The dtypes of the real and fake tensordict don't match for key {key}. "
                f"Got fake={_data0.dtype} and real={_data1.dtype}."
            )


def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0):
"""Tests an environment specs against the results of short rollout.

This test function should be used as a sanity check for an env wrapped with
torchrl's EnvBase subclasses: any discrepency between the expected data and
torchrl's EnvBase subclasses: any discrepancy between the expected data and
the data collected should raise an assertion error.

A broken environment spec will likely make it impossible to use parallel
Expand All @@ -294,94 +330,63 @@ def check_env_specs(env, return_contiguous=True, check_dtype=True, seed=0):
torch.manual_seed(seed)
env.set_seed(seed)

fake_tensordict = env.fake_tensordict().flatten_keys(".")
fake_tensordict = env.fake_tensordict()
real_tensordict = env.rollout(3, return_contiguous=return_contiguous)
# # remove private keys
# real_tensordict = real_tensordict.exclude(
# *[
# key
# for key in real_tensordict.keys(True)
# if (isinstance(key, str) and key.startswith("_"))
# or (
# isinstance(key, tuple) and any(subkey.startswith("_") for subkey in key)
# )
# ]
# )
real_tensordict = real_tensordict.flatten_keys(".")

keys1 = set(fake_tensordict.keys(True))
keys2 = set(real_tensordict.keys(True))
if keys1 != keys2:
raise AssertionError(
"The keys of the fake tensordict and the one collected during rollout do not match:"
f"Got fake-real: {keys1-keys2} and real-fake: {keys2-keys1}"
)
fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1)
fake_tensordict = fake_tensordict.expand(*real_tensordict.shape)
fake_tensordict = fake_tensordict.to_tensordict()

if return_contiguous:
fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1)
fake_tensordict = fake_tensordict.expand(*real_tensordict.shape)
else:
fake_tensordict = torch.stack([fake_tensordict.clone() for _ in range(3)], -1)

if (
fake_tensordict.apply(lambda x: torch.zeros_like(x))
!= real_tensordict.apply(lambda x: torch.zeros_like(x))
).all():
).any():
raise AssertionError(
"zeroing the two tensordicts did not make them identical. "
f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}"
)
for key in keys2:
if fake_tensordict[key].shape != real_tensordict[key].shape:

# Check shapes and, when check_dtype is set, dtypes of keys at all nesting levels
_per_level_env_check(fake_tensordict, real_tensordict, check_dtype=check_dtype)

# Check specs
last_td = real_tensordict[..., -1]
_action_spec = env.input_spec["_action_spec"]
_state_spec = env.input_spec["_state_spec"]
_obs_spec = env.output_spec["_observation_spec"]
_reward_spec = env.output_spec["_reward_spec"]
_done_spec = env.output_spec["_done_spec"]
for name, spec in (
("action", _action_spec),
("state", _state_spec),
("done", _done_spec),
("obs", _obs_spec),
):
if spec is None:
spec = CompositeSpec(shape=env.batch_size, device=env.device)
td = last_td.select(*spec.keys(True, True), strict=True)
if not spec.is_in(td):
raise AssertionError(
f"The shapes of the real and fake tensordict don't match for key {key}. "
f"Got fake={fake_tensordict[key].shape} and real={real_tensordict[key].shape}."
f"spec check failed at root for spec {name}={spec} and data {td}."
)
if check_dtype and (fake_tensordict[key].dtype != real_tensordict[key].dtype):
for name, spec in (
("reward", _reward_spec),
("done", _done_spec),
("obs", _obs_spec),
):
if spec is None:
spec = CompositeSpec(shape=env.batch_size, device=env.device)
td = last_td.get("next").select(*spec.keys(True, True), strict=True)
if not spec.is_in(td):
raise AssertionError(
f"The dtypes of the real and fake tensordict don't match for key {key}. "
f"Got fake={fake_tensordict[key].dtype} and real={real_tensordict[key].dtype}."
f"spec check failed at root for spec {name}={spec} and data {td}."
)

# test dtypes
real_tensordict = env.rollout(3) # keep empty structures, for example dict()
for key, value in real_tensordict[..., -1].items():
_check_isin(key, value, env.observation_spec, env.input_spec)

print("check_env_specs succeeded!")


def _check_isin(key, value, obs_spec, input_spec):
if key in {"reward", "done"}:
return
elif key == "next":
for _key, _value in value.items():
_check_isin(_key, _value, obs_spec, input_spec)
return
elif key in input_spec["_action_spec"].keys(True):
if not input_spec["_action_spec"][key].is_in(value):
raise AssertionError(
f"action_spec.is_in failed for key {key}. "
f"Got action_spec={input_spec['_action_spec'][key]} and real={value}."
)
return

elif key in input_spec.keys(True):
if not input_spec[key].is_in(value):
raise AssertionError(
f"input_spec.is_in failed for key {key}. "
f"Got input_spec={input_spec[key]} and real={value}."
)
return
elif key in obs_spec.keys(True):
if not obs_spec[key].is_in(value):
raise AssertionError(
f"obs_spec.is_in failed for key {key}. "
f"Got obs_spec={obs_spec[key]} and real={value}."
)
return
else:
raise KeyError(
f"key {key} was not found in input spec with keys {input_spec.keys(True)} or obs spec with keys {obs_spec.keys(True)}"
)


def _selective_unsqueeze(tensor: torch.Tensor, batch_size: torch.Size, dim: int = -1):
shape_len = len(tensor.shape)

Expand Down