diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index 76dc7060de9..9e7fa98cdfa 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -7,6 +7,8 @@
 import torch
 import torch.nn as nn
 from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict.utils import NestedKey
+
 from torchrl.data.tensor_specs import (
     BinaryDiscreteTensorSpec,
     BoundedTensorSpec,
@@ -14,6 +16,7 @@
     DiscreteTensorSpec,
     MultiOneHotDiscreteTensorSpec,
     OneHotDiscreteTensorSpec,
+    TensorSpec,
     UnboundedContinuousTensorSpec,
 )
 from torchrl.envs.common import EnvBase
@@ -941,6 +944,15 @@ def forward(self, observation, action):
         return self.linear(torch.cat([observation, action], dim=-1))
 
 
+class CountingEnvCountPolicy:
+    def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"):
+        self.action_spec = action_spec
+        self.action_key = action_key
+
+    def __call__(self, td: TensorDictBase) -> TensorDictBase:
+        return td.set(self.action_key, self.action_spec.zero() + 1)
+
+
 class CountingEnv(EnvBase):
     """An env that is done after a given number of steps.
 
@@ -1011,7 +1023,7 @@ def _step(
         self,
         tensordict: TensorDictBase,
     ) -> TensorDictBase:
-        action = tensordict.get("action")
+        action = tensordict.get(self.action_key)
         self.count += action.to(torch.int).to(self.device)
         tensordict = TensorDict(
             source={
@@ -1025,38 +1037,149 @@
         return tensordict.select().set("next", tensordict)
 
 
-class NestedRewardEnv(CountingEnv):
+class NestedCountingEnv(CountingEnv):
     # an env with nested reward and done states
-    def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
+    def __init__(
+        self,
+        max_steps: int = 5,
+        start_val: int = 0,
+        nest_obs_action: bool = True,
+        nest_done: bool = True,
+        nest_reward: bool = True,
+        nested_dim: int = 3,
+        **kwargs,
+    ):
         super().__init__(max_steps=max_steps, start_val=start_val, **kwargs)
-        self.observation_spec = CompositeSpec(
-            {("data", "states"): self.observation_spec["observation"].clone()},
-            shape=self.batch_size,
-        )
-        self.reward_spec = CompositeSpec(
-            {("data", "reward"): self.reward_spec.clone()}, shape=self.batch_size
-        )
-        self.done_spec = CompositeSpec(
-            {("data", "done"): self.done_spec.clone()}, shape=self.batch_size
-        )
+
+        self.nested_dim = nested_dim
+
+        self.nested_obs_action = nest_obs_action
+        self.nested_done = nest_done
+        self.nested_reward = nest_reward
+
+        if self.nested_obs_action:
+            self.observation_spec = CompositeSpec(
+                {
+                    "data": CompositeSpec(
+                        {
+                            "states": self.observation_spec["observation"]
+                            .unsqueeze(-1)
+                            .expand(*self.batch_size, self.nested_dim, 1)
+                        },
+                        shape=(
+                            *self.batch_size,
+                            self.nested_dim,
+                        ),
+                    )
+                },
+                shape=self.batch_size,
+            )
+            self.action_spec = CompositeSpec(
+                {
+                    "data": CompositeSpec(
+                        {
+                            "action": self.action_spec.unsqueeze(-1).expand(
+                                *self.batch_size, self.nested_dim, 1
+                            )
+                        },
+                        shape=(
+                            *self.batch_size,
+                            self.nested_dim,
+                        ),
+                    )
+                },
+                shape=self.batch_size,
+            )
+
+        if self.nested_reward:
+            self.reward_spec = CompositeSpec(
+                {
+                    "data": CompositeSpec(
+                        {
+                            "reward": self.reward_spec.unsqueeze(-1).expand(
+                                *self.batch_size, self.nested_dim, 1
+                            )
+                        },
+                        shape=(
+                            *self.batch_size,
+                            self.nested_dim,
+                        ),
+                    )
+                },
+                shape=self.batch_size,
+            )
+
+        if self.nested_done:
+            self.done_spec = CompositeSpec(
+                {
+                    "data": CompositeSpec(
+                        {
+                            "done": self.done_spec.unsqueeze(-1).expand(
+                                *self.batch_size, self.nested_dim, 1
+                            )
+                        },
+                        shape=(
+                            *self.batch_size,
+                            self.nested_dim,
+                        ),
+                    )
+                },
+                shape=self.batch_size,
+            )
 
     def _reset(self, td):
+        if self.nested_done and td is not None and "_reset" in td.keys():
+            td["_reset"] = td["_reset"].sum(-2, dtype=torch.bool)
         td = super()._reset(td)
-        td[self.done_key] = td["done"]
-        del td["done"]
-        td["data", "states"] = td["observation"]
-        del td["observation"]
+        if self.nested_done:
+            td[self.done_key] = (
+                td["done"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)
+            )
+            del td["done"]
+        if self.nested_obs_action:
+            td["data", "states"] = (
+                td["observation"]
+                .unsqueeze(-1)
+                .expand(*self.batch_size, self.nested_dim, 1)
+            )
+            del td["observation"]
+        if "data" in td.keys():
+            td["data"].batch_size = (*self.batch_size, self.nested_dim)
         return td
 
     def _step(self, td):
+        if self.nested_obs_action:
+            td["data"].batch_size = self.batch_size
+            td[self.action_key] = td[self.action_key].max(-2)[0]
         td_root = super()._step(td)
+        if self.nested_obs_action:
+            td[self.action_key] = (
+                td[self.action_key]
+                .unsqueeze(-1)
+                .expand(*self.batch_size, self.nested_dim, 1)
+            )
+            if "data" in td.keys():
+                td["data"].batch_size = (*self.batch_size, self.nested_dim)
         td = td_root["next"]
-        td[self.reward_key] = td["reward"]
-        del td["reward"]
-        td[self.done_key] = td["done"]
-        del td["done"]
-        td["data", "states"] = td["observation"]
-        del td["observation"]
+        if self.nested_done:
+            td[self.done_key] = (
+                td["done"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)
+            )
+            del td["done"]
+        if self.nested_obs_action:
+            td["data", "states"] = (
+                td["observation"]
+                .unsqueeze(-1)
+                .expand(*self.batch_size, self.nested_dim, 1)
+            )
+            del td["observation"]
+        if self.nested_reward:
+            td[self.reward_key] = (
+                td["reward"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1)
+            )
+            del td["reward"]
+        if "data" in td.keys():
+            td["data"].batch_size = (*self.batch_size, self.nested_dim)
         return td_root
 
 
diff --git a/test/test_collector.py b/test/test_collector.py
index 5ed8b4891ef..5d41d0b6905 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -14,16 +14,18 @@
     ContinuousActionVecMockEnv,
     CountingBatchedEnv,
     CountingEnv,
+    CountingEnvCountPolicy,
     DiscreteActionConvMockEnv,
     DiscreteActionConvPolicy,
     DiscreteActionVecMockEnv,
     DiscreteActionVecPolicy,
     MockSerialEnv,
+    NestedCountingEnv,
 )
 from tensordict.nn import TensorDictModule
 from tensordict.tensordict import assert_allclose_td, TensorDict
 from torch import nn
-from torchrl._utils import seed_generator
+from torchrl._utils import prod, seed_generator
 from torchrl.collectors import aSyncDataCollector, SyncDataCollector
 from torchrl.collectors.collectors import (
     _Interruptor,
@@ -33,7 +35,14 @@
 )
 from torchrl.collectors.utils import split_trajectories
 from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
-from torchrl.envs import EnvBase, EnvCreator, ParallelEnv, SerialEnv, StepCounter
+from torchrl.envs import (
+    EnvBase,
+    EnvCreator,
+    InitTracker,
+    ParallelEnv,
+    SerialEnv,
+    StepCounter,
+)
 from torchrl.envs.libs.gym import _has_gym, GymEnv
 from torchrl.envs.transforms import TransformedEnv, VecNorm
 from torchrl.modules import Actor, LSTMNet, OrnsteinUhlenbeckProcessWrapper, SafeModule
@@ -1346,6 +1355,119 @@ def test_reset_heterogeneous_envs():
     ).all()
 
 
+class TestNestedEnvsCollector:
+    def test_multi_collector_nested_env_consistency(self, seed=1):
+        env = NestedCountingEnv()
+        torch.manual_seed(seed)
+        env_fn = lambda: TransformedEnv(env, InitTracker())
+        policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
+
+        ccollector = MultiaSyncDataCollector(
+            create_env_fn=[env_fn],
+            policy=policy,
+            frames_per_batch=20,
+            total_frames=100,
+            device="cpu",
+        )
+        for i, d in enumerate(ccollector):
+            if i == 0:
+                c1 = d
+            elif i == 1:
+                c2 = d
+            else:
+                break
+        assert d.names[-1] == "time"
+        with pytest.raises(AssertionError):
+            assert_allclose_td(c1, c2)
+        ccollector.shutdown()
+
+        ccollector = MultiSyncDataCollector(
+            create_env_fn=[env_fn],
+            policy=policy,
+            frames_per_batch=20,
+            total_frames=100,
+            device="cpu",
+        )
+        for i, d in enumerate(ccollector):
+            if i == 0:
+                d1 = d
+            elif i == 1:
+                d2 = d
+            else:
+                break
+        assert d.names[-1] == "time"
+        with pytest.raises(AssertionError):
+            assert_allclose_td(d1, d2)
+        ccollector.shutdown()
+
+        assert_allclose_td(c1, d1)
+        assert_allclose_td(c2, d2)
+
+    @pytest.mark.parametrize("nested_obs_action", [True, False])
+    @pytest.mark.parametrize("nested_done", [True, False])
+    @pytest.mark.parametrize("nested_reward", [True, False])
+    def test_collector_nested_env_combinations(
+        self,
+        nested_obs_action,
+        nested_done,
+        nested_reward,
+        seed=1,
+        frames_per_batch=20,
+    ):
+        env = NestedCountingEnv(
+            nest_reward=nested_reward,
+            nest_done=nested_done,
+            nest_obs_action=nested_obs_action,
+        )
+        torch.manual_seed(seed)
+        policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
+        ccollector = SyncDataCollector(
+            create_env_fn=env,
+            policy=policy,
+            frames_per_batch=frames_per_batch,
+            total_frames=100,
+            device="cpu",
+        )
+
+        for _td in ccollector:
+            break
+        ccollector.shutdown()
+
+    @pytest.mark.parametrize("batch_size", [(), (5,), (5, 2)])
+    def test_nested_env_dims(self, batch_size, nested_dim=5, frames_per_batch=20):
+        from mocking_classes import CountingEnvCountPolicy, NestedCountingEnv
+
+        env = NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim)
+        env_fn = lambda: NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim)
+        torch.manual_seed(0)
+        policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
+        policy(env.reset())
+        ccollector = SyncDataCollector(
+            create_env_fn=env_fn,
+            policy=policy,
+            frames_per_batch=frames_per_batch,
+            total_frames=100,
+            device="cpu",
+        )
+
+        for _td in ccollector:
+            break
+        ccollector.shutdown()
+
+        # assert ("data","reward") not in td.keys(True)  # this can be activated once step_mdp is fixed for nested keys
+        assert _td.batch_size == (*batch_size, frames_per_batch // prod(batch_size))
+        assert _td["data"].batch_size == (
+            *batch_size,
+            frames_per_batch // prod(batch_size),
+            nested_dim,
+        )
+        assert _td["next", "data"].batch_size == (
+            *batch_size,
+            frames_per_batch // prod(batch_size),
+            nested_dim,
+        )
+
+
 @pytest.mark.skipif(not torch.cuda.device_count(), reason="No casting if no cuda")
 class TestUpdateParams:
     class DummyEnv(EnvBase):
diff --git a/test/test_env.py b/test/test_env.py
index 3fdccb014c3..46d3d8ad8b6 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -28,6 +28,7 @@
     ContinuousActionVecMockEnv,
     CountingBatchedEnv,
     CountingEnv,
+    CountingEnvCountPolicy,
     DiscreteActionConvMockEnv,
     DiscreteActionConvMockEnvNumpy,
     DiscreteActionVecMockEnv,
@@ -35,7 +36,7 @@
     MockBatchedLockedEnv,
     MockBatchedUnLockedEnv,
     MockSerialEnv,
-    NestedRewardEnv,
+    NestedCountingEnv,
 )
 from packaging import version
 from tensordict.nn import TensorDictModuleBase
@@ -1368,14 +1369,13 @@ def test_mp_collector(self, nproc):
 
 
 class TestNestedSpecs:
-    @pytest.mark.parametrize("envclass", ["CountingEnv", "NestedRewardEnv"])
-    def test_nested_reward(self, envclass):
-        from mocking_classes import NestedRewardEnv
+    @pytest.mark.parametrize("envclass", ["CountingEnv", "NestedCountingEnv"])
+    def test_nested_env(self, envclass):
 
         if envclass == "CountingEnv":
             env = CountingEnv()
-        elif envclass == "NestedRewardEnv":
-            env = NestedRewardEnv()
+        elif envclass == "NestedCountingEnv":
+            env = NestedCountingEnv()
         else:
             raise NotImplementedError
         reset = env.reset()
@@ -1383,7 +1383,7 @@ def test_nested_reward(self, envclass):
         assert not isinstance(env.reward_spec, CompositeSpec)
         assert env.done_spec == env.output_spec[("_done_spec", *env.done_key)]
         assert env.reward_spec == env.output_spec[("_reward_spec", *env.reward_key)]
-        if envclass == "NestedRewardEnv":
+        if envclass == "NestedCountingEnv":
             assert env.done_key == ("data", "done")
             assert env.reward_key == ("data", "reward")
             assert ("data", "done") in reset.keys(True)
@@ -1393,14 +1393,47 @@ def test_nested_reward(self, envclass):
             assert env.reward_key not in reset.keys(True)
 
         next_state = env.rand_step()
-        if envclass == "NestedRewardEnv":
+        if envclass == "NestedCountingEnv":
             assert ("next", "data", "done") in next_state.keys(True)
             assert ("next", "data", "states") in next_state.keys(True)
             assert ("next", "data", "reward") in next_state.keys(True)
         assert ("next", *env.done_key) in next_state.keys(True)
         assert ("next", *env.reward_key) in next_state.keys(True)
 
-        check_env_specs(env)
+        # check_env_specs(env)
+
+    @pytest.mark.parametrize("batch_size", [(), (32,), (32, 1)])
+    def test_nested_env_dims(self, batch_size, nested_dim=5, rollout_length=3):
+
+        env = NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim)
+
+        td = env.reset()
+        assert td.batch_size == batch_size
+        assert td["data"].batch_size == (*batch_size, nested_dim)
+
+        td = env.rand_step()
+        assert td.batch_size == batch_size
+        assert td["data"].batch_size == (*batch_size, nested_dim)
+        assert td["next", "data"].batch_size == (*batch_size, nested_dim)
+
+        td = env.rollout(rollout_length)
+        assert td.batch_size == (*batch_size, rollout_length)
+        assert td["data"].batch_size == (*batch_size, rollout_length, nested_dim)
+        assert td["next", "data"].batch_size == (
+            *batch_size,
+            rollout_length,
+            nested_dim,
+        )
+
+        policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
+        td = env.rollout(rollout_length, policy)
+        assert td.batch_size == (*batch_size, rollout_length)
+        assert td["data"].batch_size == (*batch_size, rollout_length, nested_dim)
+        assert td["next", "data"].batch_size == (
+            *batch_size,
+            rollout_length,
+            nested_dim,
+        )
 
 
 @pytest.mark.parametrize(
@@ -1420,7 +1453,7 @@ def test_nested_reward(self, envclass):
         MockBatchedLockedEnv,
         MockBatchedUnLockedEnv,
         MockSerialEnv,
-        NestedRewardEnv,
+        # NestedCountingEnv,
     ],
 )
 def test_mocking_envs(envclass):
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 6c8d1cc114b..8509f153766 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -24,6 +24,7 @@
 import torch.nn as nn
 from tensordict.nn import TensorDictModule, TensorDictModuleBase
 from tensordict.tensordict import TensorDict, TensorDictBase
+from tensordict.utils import NestedKey
 from torch import multiprocessing as mp
 from torch.utils.data import IterableDataset
 
@@ -72,11 +73,12 @@ class RandomPolicy:
         >>> td = actor(TensorDict(batch_size=[])) # selects a random action in the cube [-1; 1]
     """
 
-    def __init__(self, action_spec: TensorSpec):
+    def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"):
         self.action_spec = action_spec
+        self.action_key = action_key
 
     def __call__(self, td: TensorDictBase) -> TensorDictBase:
-        return td.set("action", self.action_spec.rand())
+        return td.set(self.action_key, self.action_spec.rand())
 
 
 class _Interruptor:
@@ -208,7 +210,7 @@ def _get_policy_and_device(
                 raise ValueError(
                     "env must be provided to _get_policy_and_device if policy is None"
                 )
-            policy = RandomPolicy(self.env.action_spec)
+            policy = RandomPolicy(self.env.action_spec, self.env.action_key)
         elif isinstance(policy, nn.Module):
             # TODO: revisit these checks when we have determined whether arbitrary
             # callables should be supported as policies.
@@ -240,7 +242,10 @@ def _get_policy_and_device(
                 # we check if all the mandatory params are there
                 if not required_params.difference(set(next_observation)):
                     in_keys = [str(k) for k in sig.parameters if k in next_observation]
-                    out_keys = ["action"]
+                    if not hasattr(self, "env") or self.env is None:
+                        out_keys = ["action"]
+                    else:
+                        out_keys = [self.env.action_key]
                     output = policy(**next_observation)
 
                     if isinstance(output, tuple):
@@ -766,7 +771,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
                 break
 
     def _step_and_maybe_reset(self) -> None:
-        done = self._tensordict.get(("next", "done"))
+        done = self._tensordict.get(("next", *self.env.done_key))
         truncated = self._tensordict.get(("next", "truncated"), None)
         traj_ids = self._tensordict.get(("collector", "traj_ids"))
 
@@ -796,7 +801,7 @@ def _step_and_maybe_reset(self) -> None:
             else:
                 self._tensordict.update(td_reset, inplace=True)
 
-            done = self._tensordict.get("done")
+            done = self._tensordict.get(self.env.done_key)
             if done.any():
                 raise RuntimeError(
                     f"Env {self.env} was done after reset on specified '_reset' dimensions. This is (currently) not allowed."
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index eb6f4f4883f..497a8d767e2 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -77,7 +77,7 @@ def specs(self, value: CompositeSpec):
 
     @staticmethod
     def metadata_from_env(env) -> EnvMetaData:
         tensordict = env.fake_tensordict().clone()
-        tensordict.set("_reset", torch.zeros_like(tensordict.get("done")))
+        tensordict.set("_reset", torch.zeros_like(tensordict.get(env.done_key)))
         specs = env.specs.to("cpu")