diff --git a/test/mocking_classes.py b/test/mocking_classes.py index e86dfb13136..d12e3d40069 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -944,13 +944,18 @@ def forward(self, observation, action): return self.linear(torch.cat([observation, action], dim=-1)) -class CountingEnvCountPolicy: +class CountingEnvCountPolicy(nn.Module): def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): + super().__init__() self.action_spec = action_spec self.action_key = action_key - def __call__(self, td: TensorDictBase) -> TensorDictBase: - return td.set(self.action_key, self.action_spec.zero() + 1) + def __call__(self, t): + action = self.action_spec.zero() + 1 + if isinstance(t, torch.Tensor): + return action + elif isinstance(t, TensorDictBase): + return t.set(self.action_key, action) class CountingEnv(EnvBase): diff --git a/test/test_exploration.py b/test/test_exploration.py index 3bdb80f6a1b..f8181406349 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -8,9 +8,13 @@ import pytest import torch from _utils_internal import get_default_devices -from mocking_classes import ContinuousActionVecMockEnv +from mocking_classes import ( + ContinuousActionVecMockEnv, + CountingEnvCountPolicy, + NestedCountingEnv, +) from scipy.stats import ttest_1samp -from tensordict.nn import InteractionType +from tensordict.nn import InteractionType, TensorDictModule from tensordict.tensordict import TensorDict from torch import nn @@ -180,6 +184,59 @@ def test_collector(self, device, parallel_spec, probabilistic, seed=0): pass return + @pytest.mark.parametrize("nested_obs_action", [True, False]) + @pytest.mark.parametrize("nested_done", [True, False]) + @pytest.mark.parametrize("is_init_key", ["some", ("one", "nested")]) + def test_nested( + self, + device, + nested_obs_action, + nested_done, + is_init_key, + seed=0, + n_envs=2, + nested_dim=5, + frames_per_batch=100, + ): + torch.manual_seed(seed) + + env = SerialEnv( + n_envs, + lambda: TransformedEnv( + NestedCountingEnv( + nest_obs_action=nested_obs_action, + nest_done=nested_done, + nested_dim=nested_dim, + ).to(device), + InitTracker(init_key=is_init_key), + ), + ) + + action_spec = env.action_spec + d_act = action_spec.shape[-1] + + net = nn.LazyLinear(d_act).to(device) + policy = TensorDictModule( + CountingEnvCountPolicy(action_spec=action_spec, action_key=env.action_key), + in_keys=[("data", "states") if nested_obs_action else "observation"], + out_keys=[env.action_key], + ) + exploratory_policy = OrnsteinUhlenbeckProcessWrapper( + policy, spec=action_spec, action_key=env.action_key, is_init_key=is_init_key + ) + collector = SyncDataCollector( + create_env_fn=env, + policy=exploratory_policy, + frames_per_batch=frames_per_batch, + total_frames=1000, + device=device, + ) + for _td in collector: + assert _td[is_init_key].shape == _td[env.done_key].shape + break + + return + @pytest.mark.parametrize("device", get_default_devices()) class TestAdditiveGaussian: diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index a8f5953e7ee..c3c960332c6 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3953,7 +3953,7 @@ class InitTracker(Transform): that is set to ``True`` whenever :meth:`~.reset` is called. Args: - init_key (str, optional): the key to be used for the tracker entry. + init_key (NestedKey, optional): the key to be used for the tracker entry. Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -3971,7 +3971,7 @@ def __init__(self, init_key: bool = "is_init"): super().__init__(in_keys=[], out_keys=[init_key]) def _call(self, tensordict: TensorDictBase) -> TensorDictBase: - if self.out_keys[0] not in tensordict.keys(): + if self.out_keys[0] not in tensordict.keys(True, True): device = tensordict.device if device is None: device = torch.device("cpu") diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 90f451c3af9..20d26b7aabd 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -9,7 +9,7 @@ import torch from tensordict.nn import TensorDictModule, TensorDictModuleWrapper from tensordict.tensordict import TensorDictBase -from tensordict.utils import expand_as_right, NestedKey +from tensordict.utils import expand_as_right, expand_right, NestedKey from torchrl.data.tensor_specs import CompositeSpec, TensorSpec from torchrl.envs.utils import exploration_type, ExplorationType @@ -34,7 +34,7 @@ class EGreedyWrapper(TensorDictModuleWrapper): eps_end (scalar, optional): final epsilon value. default: 0.1 annealing_num_steps (int, optional): number of steps it will take for epsilon to reach the eps_end value - action_key (str, Tuple[str], optional): if the policy module has more than one output key, + action_key (NestedKey, optional): if the policy module has more than one output key, its output spec will be of type CompositeSpec. One needs to know where to find the action spec. Default is "action". @@ -81,7 +81,7 @@ def __init__( eps_init: float = 1.0, eps_end: float = 0.1, annealing_num_steps: int = 1000, - action_key: NestedKey = "action", + action_key: Optional[NestedKey] = "action", spec: Optional[TensorSpec] = None, ): super().__init__(policy) @@ -173,7 +173,7 @@ class AdditiveGaussianWrapper(TensorDictModuleWrapper): sigma to reach the :obj:`sigma_end` value. mean (float, optional): mean of each output element’s normal distribution. std (float, optional): standard deviation of each output element’s normal distribution. - action_key (str, optional): if the policy module has more than one output key, + action_key (NestedKey, optional): if the policy module has more than one output key, its output spec will be of type CompositeSpec. One needs to know where to find the action spec. Default is "action". @@ -204,7 +204,7 @@ def __init__( annealing_num_steps: int = 1000, mean: float = 0.0, std: float = 1.0, - action_key: str = "action", + action_key: Optional[NestedKey] = "action", spec: Optional[TensorSpec] = None, safe: Optional[bool] = True, ): @@ -346,8 +346,10 @@ class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper): default: None n_steps_annealing (int): number of steps for the sigma annealing. default: 1000 - action_key (str): key of the action to be modified. + action_key (NestedKey, optional): key of the action to be modified. default: "action" + is_init_key (NestedKey, optional): key where to find the is_init flag used to reset the noise steps. + default: "is_init" spec (TensorSpec, optional): if provided, the sampled action will be projected onto the valid action space once explored. If not provided, the exploration wrapper will attempt to recover it from the policy. @@ -392,10 +394,11 @@ def __init__( x0: Optional[Union[torch.Tensor, np.ndarray]] = None, sigma_min: Optional[float] = None, n_steps_annealing: int = 1000, - action_key: str = "action", + action_key: Optional[NestedKey] = "action", + is_init_key: Optional[NestedKey] = "is_init", spec: TensorSpec = None, safe: bool = True, - key: str = None, + key: Optional[NestedKey] = None, ): if key is not None: action_key = key @@ -423,6 +426,7 @@ def __init__( self.annealing_num_steps = annealing_num_steps self.register_buffer("eps", torch.tensor([eps_init])) self.out_keys = list(self.td_module.out_keys) + self.ou.out_keys + self.is_init_key = is_init_key noise_key = self.ou.noise_key steps_key = self.ou.steps_key @@ -432,11 +436,11 @@ def __init__( self._spec = spec elif hasattr(self.td_module, "_spec"): self._spec = self.td_module._spec.clone() - if action_key not in self._spec.keys(): + if action_key not in self._spec.keys(True, True): self._spec[action_key] = None elif hasattr(self.td_module, "spec"): self._spec = self.td_module.spec.clone() - if action_key not in self._spec.keys(): + if action_key not in self._spec.keys(True, True): self._spec[action_key] = None else: self._spec = CompositeSpec({key: None for key in policy.out_keys}) @@ -481,20 +485,20 @@ def step(self, frames: int = 1) -> None: def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = super().forward(tensordict) if exploration_type() == ExplorationType.RANDOM or exploration_type() is None: - if "is_init" not in tensordict.keys(): + is_init = tensordict.get(self.is_init_key, None) + if is_init is None: warnings.warn( f"The tensordict passed to {self.__class__.__name__} appears to be " - f"missing the 'is_init' entry. This entry is used to " + f"missing the '{self.is_init_key}' entry. This entry is used to " f"reset the noise at the beginning of a trajectory, without it " f"the behaviour of this exploration method is undefined. " f"This is allowed for BC compatibility purposes but it will be deprecated soon! " - f"To create a 'is_init' entry, simply append an torchrl.envs.InitTracker " + f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker " f"transform to your environment with `env = TransformedEnv(env, InitTracker())`." ) - tensordict.set( - "is_init", torch.zeros(*tensordict.shape, 1, dtype=torch.bool) - ) - tensordict = self.ou.add_sample(tensordict, self.eps.item()) + tensordict = self.ou.add_sample( + tensordict, self.eps.item(), is_init=is_init + ) return tensordict @@ -509,7 +513,8 @@ def __init__( x0: Optional[Union[torch.Tensor, np.ndarray]] = None, sigma_min: Optional[float] = None, n_steps_annealing: int = 1000, - key: str = "action", + key: Optional[NestedKey] = "action", + is_init_key: Optional[NestedKey] = "is_init", ): self.mu = mu self.sigma = sigma @@ -528,6 +533,7 @@ def __init__( self.dt = dt self.x0 = x0 if x0 is not None else 0.0 self.key = key + self.is_init_key = is_init_key self._noise_key = "_ou_prev_noise" self._steps_key = "_ou_steps" self.out_keys = [self.noise_key, self.steps_key] @@ -540,43 +546,73 @@ def noise_key(self): def steps_key(self): return self._steps_key # + str(id(self)) - def _make_noise_pair(self, tensordict: TensorDictBase, is_init=None) -> None: + def _make_noise_pair( + self, + action_tensordict: TensorDictBase, + tensordict: TensorDictBase, + is_init: torch.Tensor, + ): + if self.steps_key not in tensordict.keys(): + noise = torch.zeros( + tensordict.get(self.key).shape, device=tensordict.device + ) + steps = torch.zeros( + action_tensordict.batch_size, dtype=torch.long, device=tensordict.device + ) + tensordict.set(self.noise_key, noise) + tensordict.set(self.steps_key, steps) + else: + noise = tensordict.get(self.noise_key) + steps = tensordict.get(self.steps_key) if is_init is not None: - tensordict = tensordict.get_sub_tensordict(is_init.view(tensordict.shape)) - tensordict.set( - self.noise_key, - torch.zeros(tensordict.get(self.key).shape, device=tensordict.device), - inplace=is_init is not None, - ) - tensordict.set( - self.steps_key, - torch.zeros( - torch.Size([*tensordict.batch_size, 1]), - dtype=torch.long, - device=tensordict.device, - ), - inplace=is_init is not None, - ) + noise[is_init] = 0 + steps[is_init] = 0 + return noise, steps def add_sample( - self, tensordict: TensorDictBase, eps: float = 1.0 + self, + tensordict: TensorDictBase, + eps: float = 1.0, + is_init: Optional[torch.Tensor] = None, ) -> TensorDictBase: - if self.noise_key not in tensordict.keys(): - self._make_noise_pair(tensordict) - is_init = tensordict.get("is_init", None) - if is_init is not None and is_init.any(): - self._make_noise_pair(tensordict, is_init.view(tensordict.shape)) - - prev_noise = tensordict.get(self.noise_key) - prev_noise = prev_noise + self.x0 + # Get the nested tensordict where the action lives + if isinstance(self.key, tuple) and len(self.key) > 1: + action_tensordict = tensordict.get(self.key[:-1]) + else: + action_tensordict = tensordict + + if is_init is None: + is_init = tensordict.get(self.is_init_key, None) + if ( + is_init is not None + ): # is_init has the shape of done_spec, let's bring it to the action_tensordict shape + if is_init.ndim > 1 and is_init.shape[-1] == 1: + is_init = is_init.squeeze(-1) # Squeeze dangling dim + if ( + action_tensordict.ndim >= is_init.ndim + ): # if is_init has less dimensions than action_tensordict we expand it + is_init = expand_right(is_init, action_tensordict.shape) + else: + is_init = is_init.sum( + tuple(range(action_tensordict.batch_dims, is_init.ndim)), + dtype=torch.bool, + ) # otherwise we reduce it to that batch_size + if is_init.shape != action_tensordict.shape: + raise ValueError( + f"'{self.is_init_key}' shape not compatible with action tensordict shape, " + f"got {tensordict.get(self.is_init_key).shape} and {action_tensordict.shape}" + ) - n_steps = tensordict.get(self.steps_key) + prev_noise, n_steps = self._make_noise_pair( + action_tensordict, tensordict, is_init + ) + prev_noise = prev_noise + self.x0 noise = ( prev_noise + self.theta * (self.mu - prev_noise) * self.dt - + self.current_sigma(n_steps) + + self.current_sigma(expand_as_right(n_steps, prev_noise)) * np.sqrt(self.dt) * torch.randn_like(prev_noise) )