From cd70b16cbd04c91100c70a4f2c8338dbe29d14ee Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 28 Jul 2023 15:58:02 +0100 Subject: [PATCH 01/29] init --- torchrl/envs/common.py | 4 +++- torchrl/envs/transforms/transforms.py | 21 +++++++++++++++++++++ torchrl/envs/vec_env.py | 5 ----- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 15168f20411..d487729bc6e 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -818,7 +818,7 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: """ # sanity check self._assert_tensordict_shape(tensordict) - + next_preset = tensordict.get("next", None) tensordict_out = self._step(tensordict) # this tensordict should contain a "next" key try: @@ -829,6 +829,8 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: "values at t+1 have been written under a 'next' entry. This " f"tensordict couldn't be found in the output, got: {tensordict_out}." ) + if next_preset is not None: + next_tensordict_out.update(next_preset) if tensordict_out is tensordict: raise RuntimeError( "EnvBase._step should return outplace changes to the input " diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7156c4cf571..bfaf354a53c 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2962,6 +2962,27 @@ def transform_observation_spec( observation_spec[key] = spec.to(device) return observation_spec + # def transform_input_spec( + # self, input_spec: CompositeSpec + # ) -> CompositeSpec: + # if not isinstance(input_spec, CompositeSpec): + # raise ValueError( + # f"input_spec was expected to be of type CompositeSpec. Got {type(input_spec)} instead." + # ) + # state_spec = input_spec["_state_spec"] + # for key, spec in self.primers.items(): + # if spec.shape[: len(state_spec.shape)] != state_spec.shape: + # raise RuntimeError( + # f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. " + # f"Got state_spec.shape={state_spec.shape} but the '{key}' entry's shape is {spec.shape}." + # ) + # try: + # device = state_spec.device + # except RuntimeError: + # device = self.device + # state_spec[key] = spec.to(device) + # return input_spec + @property def _batch_size(self): return self.parent.batch_size diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 108066522e1..ccdfe4b67d7 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -762,11 +762,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if self._single_task: # this is faster than update_ but won't work for lazy stacks for key in self.env_input_keys: - # self.shared_tensordict_parent.set( - # key, - # tensordict.get(key), - # inplace=True, - # ) key = _unravel_key_to_tuple(key) self.shared_tensordict_parent._set_tuple( key, From 4126035e24776e123cdc9c61275bf7aa0cd4b996 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 28 Jul 2023 16:00:34 +0100 Subject: [PATCH 02/29] init --- torchrl/envs/common.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index d487729bc6e..b7a213f6867 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -830,7 +830,9 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: f"tensordict couldn't be found in the output, got: {tensordict_out}." 
            )
        if next_preset is not None:
-            next_tensordict_out.update(next_preset)
+            next_tensordict_out.update(
+                next_preset.exclude(*next_tensordict_out.keys(True, True))
+            )
        if tensordict_out is tensordict:
            raise RuntimeError(
                "EnvBase._step should return outplace changes to the input "

From 888fee5cabb9276ddf7734718a1f3f48f1840e4e Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 10 Aug 2023 10:29:05 -0400
Subject: [PATCH 03/29] amend

---
 test/mocking_classes.py                  |  2 +-
 test/test_tensordictmodules.py           | 54 +++++++++++++++++++++-
 torchrl/envs/common.py                   | 11 ++++--
 torchrl/envs/transforms/transforms.py    | 58 ++++++++++++------------
 torchrl/modules/tensordict_module/rnn.py | 14 ++++--
 5 files changed, 99 insertions(+), 40 deletions(-)

diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index 6d5107fcc64..3fdfc24b1cb 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -435,7 +435,7 @@ def __new__(
            if categorical_action_encoding
            else OneHotDiscreteTensorSpec
        )
-        action_spec = action_spec_cls(n=7, shape=(*batch_size, 7))
+        action_spec = action_spec_cls(n=7, shape=batch_size)
    if reward_spec is None:
        reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
    if done_spec is None:
diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py
index c3f30f05fb3..391240b20be 100644
--- a/test/test_tensordictmodules.py
+++ b/test/test_tensordictmodules.py
@@ -4,7 +4,9 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
-
+from mocking_classes import DiscreteActionVecMockEnv
+from tensordict.nn import TensorDictSequential
+from torchrl.modules import MLP, ProbabilisticActor
 import pytest
 import torch
 from tensordict import pad, TensorDict, unravel_key_list
@@ -1759,6 +1761,56 @@ def test_multi_consecutive(self, shape):
            td_ss["intermediate"], td["intermediate"][..., -1, :]
        )

+    def test_lstm_parallel_env(self):
+        from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv
+
+        # tests that hidden states are carried over with parallel envs
+        lstm_module = LSTMModule(
+            input_size=7,
+            hidden_size=12,
+            num_layers=2,
+            in_key="observation",
+            out_key="features",
+        )
+
+        def create_transformed_env():
+            primer = lstm_module.make_tensordict_primer()
+            env = DiscreteActionVecMockEnv(categorical_action_encoding=True)
+            env = TransformedEnv(env)
+            env.append_transform(InitTracker())
+            env.append_transform(primer)
+            return env
+
+        env = ParallelEnv(
+            create_env_fn=create_transformed_env,
+            num_workers=2,
+        )
+
+        mlp = TensorDictModule(
+            MLP(
+                in_features=12,
+                out_features=7,
+                num_cells=[],
+            ),
+            in_keys=["features"],
+            out_keys=["logits"],
+        )
+
+        actor_model = TensorDictSequential(lstm_module, mlp)
+
+        actor = ProbabilisticActor(
+            module=actor_model,
+            in_keys=["logits"],
+            out_keys=["action"],
+            distribution_class=torch.distributions.Categorical,
+            return_log_prob=True,
+        )
+        for break_when_any_done in [False, True]:
+            data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
+            assert (data.get("recurrent_state_c") != 0.0).any()
+            assert (data.get(("next", "recurrent_state_c")) != 0.0).all()
+

 def test_safe_specs():
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index b7a213f6867..2fba61b2661 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -1334,11 +1334,14 @@ def fake_tensordict(self) -> TensorDictBase:
        next_output.update(fake_reward)
        next_output.update(fake_done)
        fake_in_out.update(fake_done.clone())
+        if "next" not in fake_in_out.keys():
+            fake_in_out.set("next", next_output)
+        else:
+            fake_in_out.get("next").update(next_output)

-        fake_td = fake_in_out.set("next", next_output)
-        fake_td.batch_size = self.batch_size
-        fake_td = fake_td.to(self.device)
-        return fake_td
+        fake_in_out.batch_size = self.batch_size
+        fake_in_out = fake_in_out.to(self.device)
+        return fake_in_out


 class _EnvWrapper(EnvBase, metaclass=abc.ABCMeta):
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index bfaf354a53c..ee457f056cc 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -2942,46 +2942,44 @@ def to(self, dtype_or_device):
        self.device = dtype_or_device
        return super().to(dtype_or_device)

-    def transform_observation_spec(
-        self, observation_spec: CompositeSpec
-    ) -> CompositeSpec:
-        if not isinstance(observation_spec, CompositeSpec):
-            raise ValueError(
-                f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead."
-            )
-        for key, spec in self.primers.items():
-            if spec.shape[: len(observation_spec.shape)] != observation_spec.shape:
-                raise RuntimeError(
-                    f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. "
-                    f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}."
-                )
-            try:
-                device = observation_spec.device
-            except RuntimeError:
-                device = self.device
-            observation_spec[key] = spec.to(device)
-        return observation_spec
-
-    # def transform_input_spec(
-    #     self, input_spec: CompositeSpec
+    # def transform_observation_spec(
+    #     self, observation_spec: CompositeSpec
    # ) -> CompositeSpec:
-    #     if not isinstance(input_spec, CompositeSpec):
+    #     if not isinstance(observation_spec, CompositeSpec):
    #         raise ValueError(
-    #             f"input_spec was expected to be of type CompositeSpec. Got {type(input_spec)} instead."
+    #             f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead."
    #         )
-    #     state_spec = input_spec["_state_spec"]
    #     for key, spec in self.primers.items():
-    #         if spec.shape[: len(state_spec.shape)] != state_spec.shape:
+    #         if spec.shape[: len(observation_spec.shape)] != observation_spec.shape:
    #             raise RuntimeError(
    #                 f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. "
-    #                 f"Got state_spec.shape={state_spec.shape} but the '{key}' entry's shape is {spec.shape}."
+    #                 f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}."
    #             )
    #             try:
-    #                 device = state_spec.device
+    #                 device = observation_spec.device
    #             except RuntimeError:
    #                 device = self.device
-    #             state_spec[key] = spec.to(device)
-    #         return input_spec
+    #             observation_spec[key] = spec.to(device)
+    #         return observation_spec
+
+    def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec:
+        state_spec = input_spec["_state_spec"]
+        if state_spec is None:
+            state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device)
+        for key, spec in self.primers.items():
+            if spec.shape[: len(state_spec.shape)] != state_spec.shape:
+                raise RuntimeError(
+                    f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. "
+                    f"Got state_spec.shape={state_spec.shape} but the '{key}' entry's shape is {spec.shape}."
+                )
+            try:
+                device = state_spec.device
+            except RuntimeError:
+                device = self.device
+            state_spec[key] = spec.to(device)
+        input_spec["_state_spec"] = state_spec
+        return input_spec

    @property
    def _batch_size(self):
        return self.parent.batch_size
diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py
index d511b069612..0ca52e024b2 100644
--- a/torchrl/modules/tensordict_module/rnn.py
+++ b/torchrl/modules/tensordict_module/rnn.py
@@ -5,11 +5,11 @@
 from typing import Optional, Tuple

 import torch
-from tensordict import unravel_key_list
+from tensordict import TensorDictBase, unravel_key, unravel_key_list

 from tensordict.nn import TensorDictModuleBase as ModuleBase

-from tensordict.tensordict import NO_DEFAULT, TensorDictBase
+from tensordict.tensordict import NO_DEFAULT

 from tensordict.utils import prod
 from torch import nn
@@ -227,10 +227,16 @@ def make_tuple(key):
            )
        return TensorDictPrimer(
            {
-                in_key1: UnboundedContinuousTensorSpec(
+                # in_key1: UnboundedContinuousTensorSpec(
+                #     shape=(self.lstm.num_layers, self.lstm.hidden_size)
+                # ),
+                # in_key2: UnboundedContinuousTensorSpec(
+                #     shape=(self.lstm.num_layers, self.lstm.hidden_size)
+                # ),
+                unravel_key(("next", in_key1)): UnboundedContinuousTensorSpec(
                    shape=(self.lstm.num_layers, self.lstm.hidden_size)
                ),
-                in_key2: UnboundedContinuousTensorSpec(
+                unravel_key(("next", in_key2)): UnboundedContinuousTensorSpec(
                    shape=(self.lstm.num_layers, self.lstm.hidden_size)
                ),
            }

From 5fcb42cbdcefc20c67bdda4107bc926bd7212357 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Fri, 11 Aug 2023 05:52:40 -0400
Subject: [PATCH 04/29] init

---
 torchrl/collectors/collectors.py      |  17 +-
 torchrl/data/tensor_specs.py          |   6 +-
 torchrl/envs/common.py                |  69 ++---
 torchrl/envs/gym_like.py              |  37 ++-
 torchrl/envs/libs/brax.py             |   1 -
 torchrl/envs/libs/isaacgym.py         |   4 +-
 torchrl/envs/libs/jumanji.py          |   2 +-
 torchrl/envs/libs/openml.py           |   2 +-
 torchrl/envs/model_based/common.py    |   5 +-
 torchrl/envs/transforms/rlhf.py       |   7 +-
 torchrl/envs/transforms/transforms.py | 288 +++++++++++++++-----
 torchrl/envs/utils.py                 | 111 +++++++-
 torchrl/envs/vec_env.py               | 373 +++++++++++++-------------
 13 files changed, 583 insertions(+), 339 deletions(-)

diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index de5b4c9a075..98a8416d7b3 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -419,9 +419,6 @@ class SyncDataCollector(DataCollectorBase):
            The _Interruptor class has methods `start_collection` and `stop_collection`,
            which allow implementing strategies such as preemptively stopping rollout
            collection. Default is ``False``.
-        reset_when_done (bool, optional): if ``True`` (default), an environment
-            that return a ``True`` value in its ``"done"`` or ``"truncated"``
-            entry will be reset at the corresponding indices.

    Examples:
        >>> from torchrl.envs.libs.gym import GymEnv
@@ -542,6 +539,8 @@ def __init__(
        self.storing_device = torch.device(storing_device)
        self.env: EnvBase = env
        self.closed = False
+        if not reset_when_done:
+            raise ValueError("reset_when_done is deprecated.")
        self.reset_when_done = reset_when_done
        self.n_env = self.env.batch_size.numel()

@@ -687,10 +686,6 @@ def __init__(
        if split_trajs is None:
            split_trajs = False
-        elif not self.reset_when_done and split_trajs:
-            raise RuntimeError(
-                "Cannot split trajectories when reset_when_done is False."
-            )
        self.split_trajs = split_trajs
        self._exclude_private_keys = True
        self.interruptor = interruptor
@@ -801,9 +796,6 @@ def _step_and_maybe_reset(self) -> None:
            action_key=self.env.action_key,
        )

-        if not self.reset_when_done:
-            return
-
        done_or_terminated = (
            (done | truncated) if truncated is not None else done.clone()
        )
@@ -902,7 +894,7 @@ def rollout(self) -> TensorDictBase:
    def reset(self, index=None, **kwargs) -> None:
        """Resets the environments to a new initial state."""
        # metadata
-        md = self._tensordict["collector"].clone()
+        md = self._tensordict.get("collector").clone()
        if index is not None:
            # check that the env supports partial reset
            if prod(self.env.batch_size) == 0:
                raise RuntimeError(
                    "resetting unique env with collectors is not permitted."
                )
            _reset[index] = 1
            self._tensordict[index].zero_()
-            self._tensordict["_reset"] = _reset
+            self._tensordict.set("_reset", _reset)
        else:
            _reset = None
            self._tensordict.zero_()
@@ -1336,6 +1328,7 @@ def _run_processes(self) -> None:
            pipe_child.close()
            self.procs.append(proc)
            self.pipes.append(pipe_parent)
+        for pipe_parent in self.pipes:
            msg = pipe_parent.recv()
            if msg != "instantiated":
                raise RuntimeError(msg)
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index b11a2e2ebef..be5ab037975 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -2093,7 +2093,11 @@ def __eq__(self, other):
            and self.domain == other.domain
        )

-    def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
+    def to_numpy(self, val: torch.Tensor, safe: bool = None) -> dict:
+        if safe is None:
+            safe = _CHECK_SPEC_ENCODE
+        if not self.shape and not safe:
+            return val.item()
        return super().to_numpy(val, safe)

    def to_one_hot(self, val: torch.Tensor, safe: bool = None) -> torch.Tensor:
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 8453664dc43..622ce121c0d 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -7,11 +7,12 @@
 import abc
 from copy import deepcopy
-from typing import Any, Callable, Dict, Iterator, Optional, Union
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union

 import numpy as np
 import torch
 import torch.nn as nn
+from tensordict._tensordict import _unravel_key_to_tuple
 from tensordict.tensordict import TensorDictBase

 from torchrl._utils import prod, seed_generator
@@ -23,7 +24,7 @@
     UnboundedContinuousTensorSpec,
 )
 from torchrl.data.utils import DEVICE_TYPING
-from torchrl.envs.utils import get_available_libraries, step_mdp
+from torchrl.envs.utils import _fuse_tensordicts, get_available_libraries, step_mdp

 LIBRARIES = get_available_libraries()
@@ -819,23 +820,13 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
        # sanity check
        self._assert_tensordict_shape(tensordict)

-        tensordict_out = 
self._step(tensordict) - # this tensordict should contain a "next" key - try: - next_tensordict_out = tensordict_out.get("next") - except KeyError: - raise RuntimeError( - "The value returned by env._step must be a tensordict where the " - "values at t+1 have been written under a 'next' entry. This " - f"tensordict couldn't be found in the output, got: {tensordict_out}." - ) - if tensordict_out is tensordict: - raise RuntimeError( - "EnvBase._step should return outplace changes to the input " - "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or " - "tensordict.select()) inside _step before writing new tensors onto this new instance." - ) + next_tensordict = self._step(tensordict) + next_tensordict = self._step_proc_data(next_tensordict) + # tensordict could already have a "next" key + tensordict.set("next", next_tensordict) + return tensordict + def _step_proc_data(self, next_tensordict_out): # TODO: Refactor this using reward spec reward = next_tensordict_out.get(self.reward_key) # unsqueeze rewards if needed @@ -864,11 +855,11 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: if actual_done_shape != expected_done_shape: done = done.view(expected_done_shape) next_tensordict_out.set(self.done_key, done) - tensordict_out.set("next", next_tensordict_out) if self.run_type_checks: - for key in self._select_observation_keys(tensordict_out): - obs = tensordict_out.get(key) + # TODO: check these errors + for key in self._select_observation_keys(next_tensordict_out): + obs = next_tensordict_out.get(key) self.observation_spec.type_check(obs, key) if ( @@ -877,17 +868,14 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: ): raise TypeError( f"expected reward.dtype to be {self.reward_spec.dtype} " - f"but got {tensordict_out.get(self.reward_key).dtype}" + f"but got {next_tensordict_out.get(self.reward_key).dtype}" ) if next_tensordict_out.get(self.done_key).dtype is not self.done_spec.dtype: raise TypeError( - f"expected done.dtype to be torch.bool but got {tensordict_out.get(self.done_key).dtype}" + f"expected done.dtype to be torch.bool but got {next_tensordict_out.get(self.done_key).dtype}" ) - # tensordict could already have a "next" key - tensordict.update(tensordict_out) - - return tensordict + return next_tensordict_out def _get_in_keys_to_exclude(self, tensordict): if self._cache_in_keys is None: @@ -942,8 +930,9 @@ def reset( _reset = None tensordict_reset = self._reset(tensordict, **kwargs) - if tensordict_reset.device != self.device: - tensordict_reset = tensordict_reset.to(self.device) + # We assume that this is done properly + # if tensordict_reset.device != self.device: + # tensordict_reset = tensordict_reset.to(self.device, non_blocking=True) if tensordict_reset is tensordict: raise RuntimeError( "EnvBase._reset should return outplace changes to the input " @@ -959,16 +948,14 @@ def reset( leading_dim = tensordict_reset.shape[: -len(self.batch_size)] else: leading_dim = tensordict_reset.shape - if self.done_spec is not None and self.done_key not in tensordict_reset.keys( - True, True - ): - tensordict_reset.set( - self.done_key, - self.done_spec.zero(leading_dim), - ) - - if (_reset is None and tensordict_reset.get(self.done_key).any()) or ( - _reset is not None and tensordict_reset.get(self.done_key)[_reset].any() + done_spec = self.done_spec + done = tensordict_reset.get(self.done_key, None) + if done is None: + done = done_spec.zero(leading_dim) + key = self.done_key + tensordict_reset.set(key, done) + elif (_reset is None 
and done.any()) or ( + _reset is not None and done[_reset].any() ): raise RuntimeError( f"Env {self} was done after reset on specified '_reset' dimensions. This is (currently) not allowed." @@ -1014,7 +1001,7 @@ def set_state(self): def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None: if ( - self.batch_locked or self.batch_size != torch.Size([]) + self.batch_locked or self.batch_size != () ) and tensordict.batch_size != self.batch_size: raise RuntimeError( f"Expected a tensordict with shape==env.shape, " @@ -1088,7 +1075,7 @@ def rollout( break_when_any_done: bool = True, return_contiguous: bool = True, tensordict: Optional[TensorDictBase] = None, - ) -> TensorDictBase: + ): """Executes a rollout in the environment. The function will stop as soon as one of the contained environments diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 7f641989b7a..2042511bd99 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -154,7 +154,9 @@ def read_reward(self, total_reward, step_reward): step_reward (reward in the format provided by the inner env): reward of this particular step """ - return total_reward + self.reward_spec.encode(step_reward, ignore_device=True) + return ( + total_reward + step_reward + ) # self.reward_spec.encode(step_reward, ignore_device=True) def read_obs( self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray] @@ -183,7 +185,7 @@ def read_obs( return observations def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - action = tensordict.get("action") + action = tensordict.get(self.action_key) action_np = self.read_action(action) reward = 0 @@ -218,7 +220,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: isinstance(done, np.ndarray) and not len(done) ): done = torch.tensor([done]) - done, do_break = self.read_done(done) if do_break: break @@ -227,19 +228,14 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if reward is None: reward = torch.tensor(np.nan).expand(self.reward_spec.shape) - # reward = self._to_tensor(reward, dtype=self.reward_spec.dtype) - # done = self._to_tensor(done, dtype=torch.bool) - obs_dict["reward"] = reward - obs_dict["done"] = done - obs_dict = {("next", key): val for key, val in obs_dict.items()} + obs_dict[self.reward_key] = reward + obs_dict[self.done_key] = done - tensordict_out = TensorDict( - obs_dict, batch_size=tensordict.batch_size, device=self.device - ) + tensordict_out = TensorDict(obs_dict, batch_size=tensordict.batch_size) if self.info_dict_reader is not None and info is not None: - self.info_dict_reader(info, tensordict_out.get("next")) - + self.info_dict_reader(info, tensordict_out) + tensordict_out = tensordict_out.to(self.device, non_blocking=True) return tensordict_out def _reset( @@ -253,10 +249,13 @@ def _reset( if len(other) == 1: info = other[0] + source = self.read_obs(obs) + + # if self.done_key not in source: + # source[self.done_key] = self.done_spec.zero() tensordict_out = TensorDict( - source=self.read_obs(obs), + source=source, batch_size=self.batch_size, - device=self.device, ) if self.info_dict_reader is not None and info is not None: self.info_dict_reader(info, tensordict_out) @@ -264,12 +263,8 @@ def _reset( # populate the reset with the items we have not seen from info for key, item in self.observation_spec.items(): if key not in tensordict_out.keys(): - tensordict_out[key] = item.zero() - - tensordict_out.setdefault( - "done", - self.done_spec.zero(), - ) + source[key] = item.zero() + tensordict_out = 
tensordict_out.to(self.device, non_blocking=True) return tensordict_out def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple: diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index 06c5e4db28a..0119a17daa4 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -282,7 +282,6 @@ def _step( out = self._step_with_grad(tensordict) else: out = self._step_without_grad(tensordict) - out = out.select().set("next", out) return out diff --git a/torchrl/envs/libs/isaacgym.py b/torchrl/envs/libs/isaacgym.py index 3d181982d6c..6e7f47df189 100644 --- a/torchrl/envs/libs/isaacgym.py +++ b/torchrl/envs/libs/isaacgym.py @@ -36,8 +36,8 @@ class IsaacGymWrapper(GymWrapper): """ def __init__( - self, env: "isaacgymenvs.tasks.base.vec_task.Env", **kwargs - ): # noqa: F821 + self, env: "isaacgymenvs.tasks.base.vec_task.Env", **kwargs # noqa: F821 + ): warnings.warn( "IsaacGym environment support is an experimental feature that may change in the future." ) diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 690b81f2c47..70181971f05 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -282,7 +282,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_out.set("done", done) tensordict_out["state"] = state_dict - return tensordict_out.select().set("next", tensordict_out) + return tensordict_out def _reset( self, tensordict: Optional[TensorDictBase] = None, **kwargs diff --git a/torchrl/envs/libs/openml.py b/torchrl/envs/libs/openml.py index 8cbc9dfb5b4..c7cc7f5cf16 100644 --- a/torchrl/envs/libs/openml.py +++ b/torchrl/envs/libs/openml.py @@ -127,7 +127,7 @@ def _step( self.batch_size, device=self.device, ) - return td.select().set("next", td) + return td def _set_seed(self, seed): self.rng = torch.random.manual_seed(seed) diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index 1a63b0f5c45..8bb4673baec 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -167,10 +167,7 @@ def _step( dtype=torch.bool, device=tensordict_out.device, ) - return tensordict_out.select().set( - "next", - tensordict_out.select(*self.observation_spec.keys(), "reward", "done"), - ) + return tensordict_out.select(*self.observation_spec.keys(), "reward", "done") @abc.abstractmethod def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict: diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 51b8104b4ea..d057f20205b 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -177,7 +177,12 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict.set(("next", *self.out_keys[0]), reward + self.coef * kl) return tensordict - _step = _call + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: + with tensordict.unlock_(): + return self._call(tensordict.set("next", next_tensordict)).pop("next") + forward = _call def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec: diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7156c4cf571..4007869a5d8 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -9,6 +9,7 @@ import multiprocessing as mp import warnings from copy import copy, deepcopy +from functools import wraps from textwrap import indent from typing import Any, List, Optional, OrderedDict, Sequence, Tuple, Union 
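The recurring change in the hunks above (common.py, gym_like.py, brax.py, jumanji.py, openml.py, model_based/common.py) is a new `_step` contract: `_step` now returns the t+1 data directly, and `EnvBase.step` is the single place where that result gets nested under `"next"`. Below is a minimal sketch of the resulting control flow; the names and shapes are illustrative, not code from the patch:

    import torch
    from tensordict import TensorDict

    def _step(tensordict: TensorDict) -> TensorDict:
        # post-refactor: return the t+1 entries alone, without a "next" wrapper
        return TensorDict(
            {
                "observation": torch.randn(3),
                "reward": torch.zeros(1),
                "done": torch.zeros(1, dtype=torch.bool),
            },
            batch_size=[],
        )

    def step(tensordict: TensorDict) -> TensorDict:
        # EnvBase.step performs the nesting once, for every backend
        next_tensordict = _step(tensordict)
        tensordict.set("next", next_tensordict)
        return tensordict

    td = step(TensorDict({"action": torch.randn(2)}, batch_size=[]))
    assert ("next", "observation") in td.keys(True)

This is why the wrappers above can drop their `out.select().set("next", out)` boilerplate.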
@@ -65,6 +66,7 @@ def interpolation_fn(interpolation): # noqa: D103 def _apply_to_composite(function): + @wraps(function) def new_fun(self, observation_spec): if isinstance(observation_spec, CompositeSpec): d = observation_spec._specs @@ -215,17 +217,17 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Reads the input tensordict, and for the selected keys, applies the transform.""" for in_key, out_key in zip(self.in_keys, self.out_keys): - if in_key in tensordict.keys(include_nested=True): - observation = self._apply_transform(tensordict.get(in_key)) - tensordict.set( - out_key, - observation, - ) + data = tensordict.get(in_key, None) + if data is not None: + data = self._apply_transform(data) + tensordict.set(out_key, data) elif not self.missing_tolerance: raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") return tensordict - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: """The parent method of a transform during the ``env.step`` execution. This method should be overwritten whenever the :meth:`~._step` needs to be @@ -237,11 +239,14 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: :meth:`~._step` will only be called by :meth:`TransformedEnv.step` and not by :meth:`TransformedEnv.reset`. + Args: + tensordict (TensorDictBase): data at time t + next_tensordict (TensorDictBase): data at time t+1 + + Returns: the data at t+1 """ - next_tensordict = tensordict.get("next") next_tensordict = self._call(next_tensordict) - tensordict.set("next", next_tensordict) - return tensordict + return next_tensordict def _inv_apply_transform(self, obs: torch.Tensor) -> torch.Tensor: if self.invertible: @@ -254,12 +259,10 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: # # exposed to the user: we'd like that the input keys remain unchanged # # in the originating script if they're being transformed. 
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): - if in_key in tensordict.keys(include_nested=True): - item = self._inv_apply_transform(tensordict.get(in_key)) - tensordict.set( - out_key, - item, - ) + data = tensordict.get(in_key, None) + if data is not None: + item = self._inv_apply_transform(data) + tensordict.set(out_key, item) elif not self.missing_tolerance: raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") @@ -391,10 +394,7 @@ def parent(self) -> Optional[EnvBase]: compose_parent = TransformedEnv( compose.__dict__["_container"].base_env ) - if compose_parent.transform is not compose: - comp_parent_trans = compose_parent.transform.clone() - else: - comp_parent_trans = None + comp_parent_trans = compose_parent.transform.clone() out = TransformedEnv( compose_parent.base_env, transform=comp_parent_trans, @@ -628,11 +628,10 @@ def done_spec(self) -> TensorSpec: def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = tensordict.clone(False) tensordict_in = self.transform.inv(tensordict) - tensordict_out = self.base_env._step(tensordict_in) + next_tensordict = self.base_env._step(tensordict_in) # we want the input entries to remain unchanged - tensordict_out = tensordict.update(tensordict_out) - tensordict_out = self.transform._step(tensordict_out) - return tensordict_out + next_tensordict = self.transform._step(tensordict, next_tensordict) + return next_tensordict def set_seed( self, seed: Optional[int] = None, static_seed: bool = False @@ -841,10 +840,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = t(tensordict) return tensordict - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: for t in self.transforms: - tensordict = t._step(tensordict) - return tensordict + next_tensordict = t._step(tensordict, next_tensordict) + return next_tensordict def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: for t in reversed(self.transforms): @@ -1162,10 +1163,12 @@ def _call(self, tensordict: TensorDict) -> TensorDict: raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") return tensordict - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: for out_key in self.out_keys: - tensordict.set(("next", out_key), tensordict.get(out_key)) - return super()._step(tensordict) + next_tensordict.set(out_key, tensordict.get(out_key)) + return super()._step(tensordict, next_tensordict) def _apply_transform( self, reward: torch.Tensor, target_return: torch.Tensor @@ -2350,6 +2353,21 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: class DoubleToFloat(Transform): """Maps actions float to double before they are called on the environment. + Depending on whether the ``in_keys`` or ``in_keys_inv`` are provided + during construction, the class behaviour will change: + + * If the keys are provided, those entries and those entries only will be + transformed from ``float64`` to ``float32`` entries; + * If the keys are not provided and the object is within an environment + register of transforms, the input and output specs that have a dtype + set to ``float64`` will be used as in_keys_inv / in_keys respectively. 
+    * If the keys are not provided and the object is used without an
+      environment, the ``forward`` / ``inverse`` pass will scan through the
+      input tensordict for all float64 values and map them to a float32
+      tensor. For large data structures, this can impact performance, as the
+      scanning doesn't come for free. The keys to be transformed will not be
+      cached.

    Args:
        in_keys (sequence of NestedKey, optional): list of double keys to be converted to
            float before being exposed to external objects and functions.

    Examples:
        >>> td = TensorDict(
-        ...     {'obs': torch.ones(1, dtype=torch.double)}, [])
+        ...     {'obs': torch.ones(1, dtype=torch.double),
+        ...      'not_transformed': torch.ones(1, dtype=torch.double),
+        ...     }, [])
        >>> transform = DoubleToFloat(in_keys=["obs"])
        >>> _ = transform(td)
        >>> print(td.get("obs").dtype)
        torch.float32
+        >>> print(td.get("not_transformed").dtype)
+        torch.float64
+
+    In "automatic" mode, all float64 entries are transformed:
+
+    Examples:
+        >>> td = TensorDict(
+        ...     {'obs': torch.ones(1, dtype=torch.double),
+        ...      'not_transformed': torch.ones(1, dtype=torch.double),
+        ...     }, [])
+        >>> transform = DoubleToFloat()
+        >>> _ = transform(td)
+        >>> print(td.get("obs").dtype)
+        torch.float32
+        >>> print(td.get("not_transformed").dtype)
+        torch.float32
+
+    The same behaviour applies when environments are constructed without
+    specifying the transform keys:
+
+    Examples:
+        >>> class MyEnv(EnvBase):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.observation_spec = CompositeSpec(obs=UnboundedContinuousTensorSpec((), dtype=torch.float64))
+        ...         self.action_spec = UnboundedContinuousTensorSpec((), dtype=torch.float64)
+        ...         self.reward_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.float64)
+        ...         self.done_spec = UnboundedContinuousTensorSpec((1,), dtype=torch.bool)
+        ...     def _reset(self, data=None):
+        ...         return TensorDict({"done": torch.zeros((1,), dtype=torch.bool), **self.observation_spec.rand()}, [])
+        ...     def _step(self, data):
+        ...         assert data["action"].dtype == torch.float64
+        ...         reward = self.reward_spec.rand()
+        ...         done = torch.zeros((1,), dtype=torch.bool)
+        ...         obs = self.observation_spec.rand()
+        ...         assert reward.dtype == torch.float64
+        ...         assert obs["obs"].dtype == torch.float64
+        ...         return obs.update({"reward": reward, "done": done})
+        ...     def _set_seed(self, seed):
+        ...
pass + >>> env = TransformedEnv(MyEnv(), DoubleToFloat()) + >>> assert env.action_spec.dtype == torch.float32 + >>> assert env.observation_spec["obs"].dtype == torch.float32 + >>> assert env.reward_spec.dtype == torch.float32, env.reward_spec.dtype + >>> print(env.rollout(2)) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False), + obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=cpu, + is_shared=False), + obs: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=cpu, + is_shared=False) + >>> assert env.transform.in_keys == ["obs", "reward"] + >>> assert env.transform.in_keys_inv == ["action"] """ @@ -2373,8 +2456,75 @@ def __init__( in_keys: Optional[Sequence[NestedKey]] = None, in_keys_inv: Optional[Sequence[NestedKey]] = None, ): + if in_keys is None: + self._keys_unset = True + in_keys = [] + else: + self._keys_unset = False + if in_keys_inv is None: + self._keys_inv_unset = True + in_keys_inv = [] + else: + self._keys_inv_unset = False + super().__init__(in_keys=in_keys, in_keys_inv=in_keys_inv) + def _set_in_keys(self): + env_base = self.parent + if env_base is not None: + # retrieve the specs that are float32 + if self._keys_unset: + in_keys = [] + observation_spec = env_base.observation_spec + for key, spec in observation_spec.items(True, True): + if spec.dtype == torch.float64: + in_keys.append(unravel_key(key)) + reward_spec = env_base.reward_spec + if reward_spec.dtype == torch.float64: + in_keys.append(unravel_key(env_base.reward_key)) + + self.in_keys = self.out_keys = in_keys + self._keys_unset = False + if self._keys_inv_unset: + in_keys_inv = [] + state_spec = env_base.state_spec + if state_spec is not None: + for key, spec in state_spec.items(True, True): + if spec.dtype == torch.float64: + in_keys_inv.append(unravel_key(key)) + action_spec = env_base.action_spec + if action_spec.dtype == torch.float64: + in_keys_inv.append(unravel_key(env_base.action_key)) + self.in_keys_inv = self.out_keys_inv = in_keys_inv + self._keys_inv_unset = False + self._container.empty_cache() + + @dispatch(source="in_keys", dest="out_keys") + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """Reads the input tensordict, and for the selected keys, applies the transform.""" + if self._keys_unset: + self._set_in_keys() + for in_key, data in tensordict.items(True, True): + if data.dtype == torch.float64: + out_key = in_key + data = self._apply_transform(data) + tensordict.set(out_key, data) + return tensordict + return super().forward(tensordict) + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + if self._keys_inv_unset: + self._set_in_keys() + # we can't differentiate between content of forward and inverse + tensordict = tensordict.clone(False) + for in_key, data in tensordict.items(True, True): + if data.dtype == torch.float32: + out_key = in_key + data = self._inv_apply_transform(data) + tensordict.set(out_key, data) + return tensordict + return super()._inv_call(tensordict) + def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor: return 
obs.to(torch.float) @@ -2393,6 +2543,8 @@ def _transform_spec(self, spec: TensorSpec) -> None: space.maximum = space.maximum.to(torch.float) def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + if self._keys_inv_unset: + self._set_in_keys() action_spec = input_spec["_action_spec"] state_spec = input_spec["_state_spec"] for key in self.in_keys_inv: @@ -2411,6 +2563,8 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: @_apply_to_composite def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + if self._keys_unset: + self._set_in_keys() reward_key = self.parent.reward_key if self.parent is not None else "reward" if unravel_key(reward_key) in self.in_keys: if reward_spec.dtype is not torch.double: @@ -2419,8 +2573,13 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: self._transform_spec(reward_spec) return reward_spec - @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + if self._keys_unset: + self._set_in_keys() + return self._transform_observation_spec(observation_spec) + + @_apply_to_composite + def _transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: self._transform_spec(observation_spec) return observation_spec @@ -2728,16 +2887,18 @@ def __init__(self, frame_skip: int = 1): raise ValueError("frame_skip should have a value greater or equal to one.") self.frame_skip = frame_skip - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: parent = self.parent if parent is None: raise RuntimeError("parent not found for FrameSkipTransform") reward_key = parent.reward_key - reward = tensordict.get(("next", reward_key)) + reward = next_tensordict.get(reward_key) for _ in range(self.frame_skip - 1): - tensordict = parent._step(tensordict) - reward = reward + tensordict.get(("next", reward_key)) - return tensordict.set(("next", reward_key), reward) + next_tensordict = parent._step(tensordict) + reward = reward + next_tensordict.get(reward_key) + return next_tensordict.set(reward_key, reward) def forward(self, tensordict): raise RuntimeError( @@ -2985,10 +3146,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict.set(key, value) return tensordict - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: for key in self.primers.keys(): - tensordict.setdefault(("next", key), tensordict.get(key, default=None)) - return tensordict + next_tensordict.setdefault(key, tensordict.get(key, default=None)) + return next_tensordict def reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Sets the default values in the input tensordict. 
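For custom transforms, the new two-argument `_step` introduced above means both the time-t data and the time-t+1 data are visible, and the method returns the latter. A sketch of a hypothetical transform written against this signature (the class and key names are made up for illustration):

    from tensordict import TensorDictBase
    from torchrl.envs.transforms import Transform

    class CopyActionToNext(Transform):
        """Hypothetical transform: copies the action taken at t into the t+1 data."""

        def __init__(self):
            super().__init__(in_keys=[])

        def _step(
            self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
        ) -> TensorDictBase:
            # tensordict holds the root (time t) entries; next_tensordict holds
            # the entries that EnvBase.step will write under "next"
            next_tensordict.set("last_action", tensordict.get("action"))
            return next_tensordict

This mirrors how `FrameSkipTransform`, `TensorDictPrimer`, `RewardSum` and `StepCounter` are adapted in this diff.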
@@ -3444,10 +3607,11 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: ) from err return tensordict - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: """Updates the episode rewards with the step rewards.""" # Update episode rewards - next_tensordict = tensordict.get("next") for in_key, out_key in zip(self.in_keys, self.out_keys): if in_key in next_tensordict.keys(include_nested=True): reward = next_tensordict.get(in_key) @@ -3456,8 +3620,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: next_tensordict.set(out_key, tensordict.get(out_key) + reward) elif not self.missing_tolerance: raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") - tensordict.set("next", next_tensordict) - return tensordict + return next_tensordict def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: """Transforms the observation spec, adding the new keys generated by RewardSum.""" @@ -3549,7 +3712,7 @@ def __init__( self.step_count_key = step_count_key super().__init__([]) - def reset(self, tensordict: TensorDictBase) -> TensorDictBase: + def _get_done(self, tensordict): done_key = self.parent.done_key if self.parent else "done" done = tensordict.get(done_key, None) if done is None: @@ -3558,41 +3721,42 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: dtype=self.parent.done_spec.dtype, device=self.parent.done_spec.device, ) + return done + + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: + done = None _reset = tensordict.get( "_reset", # TODO: decide if using done here, or using a default `True` tensor default=None, ) if _reset is None: + done = self._get_done(tensordict) _reset = torch.ones_like(done) - step_count = tensordict.get( - self.step_count_key, - default=None, - ) + step_count = tensordict.get(self.step_count_key, default=None) if step_count is None: + if done is None: + # avoid getting done if not needed + done = self._get_done(tensordict) step_count = torch.zeros_like(done, dtype=torch.int64) - - step_count[_reset] = 0 - tensordict.set( - self.step_count_key, - step_count, - ) + step_count = torch.where(~_reset, step_count, 0) + tensordict.set(self.step_count_key, step_count) if self.max_steps is not None: truncated = step_count >= self.max_steps tensordict.set(self.truncated_key, truncated) return tensordict - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: tensordict = tensordict.clone(False) - step_count = tensordict.get( - self.step_count_key, - ) + step_count = tensordict.get(self.step_count_key) next_step_count = step_count + 1 - tensordict.set(("next", self.step_count_key), next_step_count) + next_tensordict.set(self.step_count_key, next_step_count) if self.max_steps is not None: truncated = next_step_count >= self.max_steps - tensordict.set(("next", self.truncated_key), truncated) - return tensordict + next_tensordict.set(self.truncated_key, truncated) + return next_tensordict def transform_observation_spec( self, observation_spec: CompositeSpec diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 1f671dda1b2..dec48e5d5a3 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -4,7 +4,10 @@ # LICENSE file in the root directory of this source tree. 
 from __future__ import annotations

+import contextlib
+
 import importlib.util
+import os
 import re

 import torch
@@ -20,7 +23,12 @@
     set_interaction_mode as set_exploration_mode,
     set_interaction_type as set_exploration_type,
 )
-from tensordict.tensordict import LazyStackedTensorDict, NestedKey, TensorDictBase
+from tensordict.tensordict import (
+    LazyStackedTensorDict,
+    NestedKey,
+    TensorDict,
+    TensorDictBase,
+)

 __all__ = [
     "exploration_mode",
@@ -221,19 +229,17 @@ def _set_single_key(source, dest, key, clone=False):
        key = (key,)
    for k in key:
        try:
-            val = source.get(k)
+            val = source._get_str(k, None)
            if is_tensor_collection(val):
-                new_val = dest.get(k, None)
+                new_val = dest._get_str(k, None)
                if new_val is None:
                    new_val = val.empty()
-                # dest.set(k, new_val)
                dest._set_str(k, new_val, inplace=False, validated=True)
                source = val
                dest = new_val
            else:
                if clone:
                    val = val.clone()
-                # dest.set(k, val)
                dest._set_str(k, val, inplace=False, validated=True)
        # This is a temporary solution to understand if a key is heterogeneous
        # while not having performance impact when the exception is not raised
@@ -551,3 +557,98 @@ def make_composite_from_td(data):
            shape=data.shape,
        )
    return composite
+
+
+def _fuse_tensordicts(*tds, excluded, selected=None, total=None):
+    """Fuses tensordicts with rank-wise priority.
+
+    The first tensordicts of the list will have a higher priority than those
+    coming after, in such a way that if a key is present in both the first and
+    second tensordict, the first value is guaranteed to end up in the output.
+
+    Args:
+        tds (sequence of TensorDictBase): tensordicts to fuse.
+        excluded (sequence of tuples): keys to ignore. Must be tuples, no
+            strings allowed.
+        selected (sequence of tuples): keys to accept. Must be tuples, no
+            strings allowed.
+        total (tuple): the root key of the tds. Used for recursive calls.
+
+    Examples:
+        >>> td1 = TensorDict({
+        ...     "a": 0,
+        ...     "b": {"c": 0},
+        ... }, [])
+        >>> td2 = TensorDict({
+        ...     "a": 1,
+        ...     "b": {"c": 1, "d": 1},
+        ... }, [])
+        >>> td3 = TensorDict({
+        ...     "a": 2,
+        ...     "b": {"c": 2, "d": 2, "e": {"f": 2}},
+        ...     "g": 2,
+        ...     "h": {"i": 2},
+        ... }, [])
+        >>> out = _fuse_tensordicts(td1, td2, td3, excluded=(("h", "i"),))
+        >>> assert out["a"] == 0
+        >>> assert out["b", "c"] == 0
+        >>> assert out["b", "d"] == 1
+        >>> assert out["b", "e", "f"] == 2
+        >>> assert out["g"] == 2
+        >>> assert ("h", "i") not in out.keys(True, True)
+
+    """
+    out = TensorDict({}, batch_size=tds[0].batch_size, device=tds[0].device)
+    if total is None:
+        total = ()
+
+    keys = set()
+    for i, td in enumerate(tds):
+        if td is None:
+            continue
+        for key in td.keys():
+            cur_total = total + (key,)
+            if cur_total in excluded:
+                continue
+            if selected is not None and cur_total not in selected:
+                continue
+            if key in keys:
+                continue
+            keys.add(key)
+            val = td._get_str(key, None)
+            if is_tensor_collection(val):
+                val = _fuse_tensordicts(
+                    val,
+                    *[_td._get_str(key, None) for _td in tds[i + 1 :]],
+                    total=cur_total,
+                    excluded=excluded,
+                    selected=selected,
+                )
+            out._set_str(key, val, validated=True, inplace=False)
+    return out
+
+
+@contextlib.contextmanager
+def clear_mpi_env_vars():
+    """Clears MPI environment variables.
+
+    `from mpi4py import MPI` will call `MPI_Init` by default.
+    If the child process has MPI environment variables, MPI will think that the child process
+    is an MPI process just like the parent and do bad things such as hang.
+ + This context manager is a hacky way to clear those environment variables + temporarily such as when we are starting multiprocessing Processes. + + Yields: + Yields for the context manager + """ + removed_environment = {} + for k, v in list(os.environ.items()): + for prefix in ["OMPI_", "PMI_"]: + if k.startswith(prefix): + removed_environment[k] = v + del os.environ[k] + try: + yield + finally: + os.environ.update(removed_environment) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 108066522e1..60e4e0bb411 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -19,7 +19,7 @@ import numpy as np import torch -from tensordict import TensorDict, unravel_key +from tensordict import TensorDict from tensordict._tensordict import _unravel_key_to_tuple from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from torch import multiprocessing as mp @@ -34,7 +34,12 @@ from torchrl.envs.common import _EnvWrapper, EnvBase from torchrl.envs.env_creator import get_env_metadata -from torchrl.envs.utils import _set_single_key, _sort_keys +from torchrl.envs.utils import ( + _fuse_tensordicts, + _set_single_key, + _sort_keys, + clear_mpi_env_vars, +) _has_envpool = importlib.util.find_spec("envpool") @@ -117,7 +122,6 @@ class _BatchedEnv(EnvBase): if a list of callable is provided, the environment will be executed as if multiple, diverse tasks were needed, which comes with a slight compute overhead; create_env_kwargs (dict or list of dicts, optional): kwargs to be used with the environments being created; - pin_memory (bool): if True and device is "cpu", calls :obj:`pin_memory` on the tensordicts when created. share_individual_td (bool, optional): if ``True``, a different tensordict is created for every process/worker and a lazy stack is returned. default = None (False if single task); @@ -130,9 +134,6 @@ class _BatchedEnv(EnvBase): It is assumed that all environments will run on the same device as a common shared tensordict will be used to pass data from process to process. The device can be changed after instantiation using :obj:`env.to(device)`. - allow_step_when_done (bool, optional): if ``True``, batched environments can - execute steps after a done state is encountered. - Defaults to ``False``. """ @@ -195,10 +196,15 @@ def __init__( self.create_env_fn = create_env_fn self.create_env_kwargs = create_env_kwargs self.pin_memory = pin_memory + if pin_memory: + raise ValueError("pin_memory for batched envs is deprecated") + self.share_individual_td = bool(share_individual_td) self._share_memory = shared_memory self._memmap = memmap self.allow_step_when_done = allow_step_when_done + if allow_step_when_done: + raise ValueError("allow_step_when_done is deprecated") if self._share_memory and self._memmap: raise RuntimeError( "memmap and shared memory are mutually exclusive features." 
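As a usage note for the `clear_mpi_env_vars` helper added in torchrl/envs/utils.py above: it is meant to bracket child-process creation, as `ParallelEnv._start_workers` does below. A minimal sketch under that assumption (the worker function and count are placeholders):

    import torch.multiprocessing as mp

    from torchrl.envs.utils import clear_mpi_env_vars

    def start_workers(worker_fn, num_workers):
        ctx = mp.get_context("spawn")
        procs = []
        # OMPI_*/PMI_* variables are hidden while the children are spawned and
        # restored afterwards, so workers do not mistake themselves for MPI ranks
        with clear_mpi_env_vars():
            for idx in range(num_workers):
                proc = ctx.Process(target=worker_fn, args=(idx,))
                proc.daemon = True
                proc.start()
                procs.append(proc)
        return procs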
@@ -323,18 +329,18 @@ def _create_td(self) -> None: shared_tensordict_parent = self._env_tensordict.clone() if self._single_task: - self.env_input_keys = sorted( + self._env_input_keys = sorted( list(self.input_spec["_action_spec"].keys(True, True)) + list(self.state_spec.keys(True, True)), key=_sort_keys, ) - self.env_output_keys = [] - self.env_obs_keys = [] + self._env_output_keys = [] + self._env_obs_keys = [] for key in self.output_spec["_observation_spec"].keys(True, True): - self.env_output_keys.append(unravel_key(("next", key))) - self.env_obs_keys.append(key) - self.env_output_keys.append(unravel_key(("next", self.reward_key))) - self.env_output_keys.append(unravel_key(("next", self.done_key))) + self._env_output_keys.append(key) + self._env_obs_keys.append(key) + self._env_output_keys.append(self.reward_key) + self._env_output_keys.append(self.done_key) else: env_input_keys = set() for meta_data in self.meta_data: @@ -355,42 +361,49 @@ def _create_td(self) -> None: ) ) env_output_keys = env_output_keys.union( - unravel_key(("next", key)) + key for key in meta_data.specs["output_spec"]["_observation_spec"].keys( True, True ) ) env_output_keys = env_output_keys.union( { - unravel_key(("next", self.reward_key)), - unravel_key(("next", self.done_key)), + self.reward_key, + self.done_key, } ) - self.env_obs_keys = sorted(env_obs_keys, key=_sort_keys) - self.env_input_keys = sorted(env_input_keys, key=_sort_keys) - self.env_output_keys = sorted(env_output_keys, key=_sort_keys) + self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys) + self._env_input_keys = sorted(env_input_keys, key=_sort_keys) + self._env_output_keys = sorted(env_output_keys, key=_sort_keys) self._selected_keys = ( - set(self.env_output_keys) - .union(self.env_input_keys) - .union(self.env_obs_keys) + set(self._env_output_keys) + .union(self._env_input_keys) + .union(self._env_obs_keys) ) self._selected_keys.add(self.done_key) self._selected_keys.add("_reset") - self._selected_reset_keys = self.env_obs_keys + [self.done_key] + ["_reset"] - self._selected_step_keys = self.env_output_keys + # input keys + self._selected_input_keys = {_unravel_key_to_tuple(key) for key in self._env_input_keys} + # output keys after reset + self._selected_reset_keys = {_unravel_key_to_tuple(key) for key in self._env_obs_keys + [self.done_key] + ["_reset"]} + # output keys after step + self._selected_step_keys = {_unravel_key_to_tuple(key) for key in self._env_output_keys} if self._single_task: shared_tensordict_parent = shared_tensordict_parent.select( *self._selected_keys, + "next", strict=False, ) self.shared_tensordict_parent = shared_tensordict_parent.to(self.device) else: # Multi-task: we share tensordict that *may* have different keys shared_tensordict_parent = [ - tensordict.select(*self._selected_keys, strict=False).to(self.device) + tensordict.select(*self._selected_keys, "next", strict=False).to( + self.device + ) for tensordict in shared_tensordict_parent ] shared_tensordict_parent = torch.stack( @@ -409,12 +422,13 @@ def _create_td(self) -> None: # Multi-task: we share tensordict that *may* have different keys # LazyStacked already stores this so we don't need to do anything self.shared_tensordicts = self.shared_tensordict_parent - if self._share_memory: - for td in self.shared_tensordicts: - td.share_memory_() - elif self._memmap: - for td in self.shared_tensordicts: - td.memmap_() + if self.device.type == "cpu": + if self._share_memory: + for td in self.shared_tensordicts: + td.share_memory_() + elif self._memmap: + for 
td in self.shared_tensordicts: + td.memmap_() else: if self._share_memory: self.shared_tensordict_parent.share_memory_() @@ -424,10 +438,7 @@ def _create_td(self) -> None: self.shared_tensordict_parent.memmap_() if not self.shared_tensordict_parent.is_memmap(): raise RuntimeError("memmap_() failed") - self.shared_tensordicts = self.shared_tensordict_parent.unbind(0) - if self.pin_memory: - self.shared_tensordict_parent.pin_memory() def _start_workers(self) -> None: """Starts the various envs.""" @@ -536,27 +547,24 @@ def _step( self, tensordict: TensorDict, ) -> TensorDict: - self._assert_tensordict_shape(tensordict) tensordict_in = tensordict.clone(False) + next_td = self.shared_tensordict_parent.get("next") for i in range(self.num_workers): # shared_tensordicts are locked, and we need to select the keys since we update in-place. # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. out_td = self._envs[i]._step(tensordict_in[i]) - out_td.update(tensordict_in[i].select(*self.env_input_keys)) - self.shared_tensordicts[i].update_( - out_td.select(*self.env_input_keys, *self.env_output_keys) - ) + next_td[i].update_(out_td.select(*self._env_output_keys)) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps if self._single_task: - out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape) + out = TensorDict( + {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + ) for key in self._selected_step_keys: - _set_single_key(self.shared_tensordict_parent, out, key, clone=True) + _set_single_key(next_td, out, key, clone=True) else: # strict=False ensures that non-homogeneous keys are still there - out = self.shared_tensordict_parent.select( - *self._selected_step_keys, strict=False - ).clone() + out = next_td.select(*self._selected_step_keys, strict=False).clone() return out def _shutdown_workers(self) -> None: @@ -576,15 +584,18 @@ def set_seed( @_check_start def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: - if tensordict is not None and "_reset" in tensordict.keys(): + _reset = None + if tensordict is not None: self._assert_tensordict_shape(tensordict) - _reset = tensordict.get("_reset") + _reset = tensordict.get("_reset", None) if _reset.shape[-len(self.done_spec.shape) :] != self.done_spec.shape: raise RuntimeError( "_reset flag in tensordict should follow env.done_spec" ) - else: - _reset = torch.ones(self.done_spec.shape, dtype=torch.bool) + if _reset is None: + _reset = torch.ones((), dtype=torch.bool, device=self.device).expand( + self.done_spec.shape + ) for i, _env in enumerate(self._envs): if tensordict is not None: @@ -601,9 +612,9 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: # step at the root (since the shared_tensordict did not go through # step_mdp). 
self.shared_tensordicts[i].update_( - self.shared_tensordicts[i]["next"].select( - *self._selected_reset_keys, strict=False - ) + self.shared_tensordicts[i] + .get("next") + .select(*self._selected_reset_keys, strict=False) ) if tensordict_ is not None: self.shared_tensordicts[i].update_( @@ -612,12 +623,14 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: continue _td = _env._reset(tensordict=tensordict_, **kwargs) self.shared_tensordicts[i].update_( - _td.select(*self._selected_keys, strict=False) + _td.select(*self._selected_reset_keys, strict=False) ) if self._single_task: # select + clone creates 2 tds, but we can create one only - out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape) + out = TensorDict( + {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + ) for key in self._selected_reset_keys: if key != "_reset": _set_single_key(self.shared_tensordict_parent, out, key, clone=True) @@ -682,52 +695,59 @@ class ParallelEnv(_BatchedEnv): __doc__ += _BatchedEnv.__doc__ def _start_workers(self) -> None: - _num_workers = self.num_workers + from torchrl.envs.env_creator import EnvCreator + ctx = mp.get_context("spawn") + _num_workers = self.num_workers + self.parent_channels = [] self._workers = [] + self._events = [] if self.device.type == "cuda": self.event = torch.cuda.Event() else: self.event = None - for idx in range(_num_workers): - if self._verbose: - print(f"initiating worker {idx}") - # No certainty which module multiprocessing_context is - channel1, channel2 = ctx.Pipe() - env_fun = self.create_env_fn[idx] - if env_fun.__class__.__name__ != "EnvCreator": - env_fun = CloudpickleWrapper(env_fun) - - w = mp.Process( - target=_run_worker_pipe_shared_mem, - args=( - idx, - channel1, - channel2, - env_fun, - self.create_env_kwargs[idx], - False, - self.env_input_keys, - self.device, - self.allow_step_when_done, - ), - ) - w.daemon = True - w.start() - channel2.close() - self.parent_channels.append(channel1) - self._workers.append(w) - for channel1 in self.parent_channels: - msg = channel1.recv() + with clear_mpi_env_vars(): + for idx in range(_num_workers): + if self._verbose: + print(f"initiating worker {idx}") + # No certainty which module multiprocessing_context is + parent_pipe, child_pipe = ctx.Pipe() + event = ctx.Event() + self._events.append(event) + env_fun = self.create_env_fn[idx] + if not isinstance(env_fun, EnvCreator): + env_fun = CloudpickleWrapper(env_fun) + + process = ctx.Process( + target=_run_worker_pipe_shared_mem, + args=( + parent_pipe, + child_pipe, + env_fun, + self.create_env_kwargs[idx], + self.device, + event, + self.shared_tensordicts[idx], + self._selected_input_keys, + self._selected_reset_keys, + self._selected_step_keys, + ), + ) + process.daemon = True + process.start() + child_pipe.close() + self.parent_channels.append(parent_pipe) + self._workers.append(process) + + for parent_pipe in self.parent_channels: + msg = parent_pipe.recv() assert msg == "started" # send shared tensordict to workers - for channel, shared_tensordict in zip( - self.parent_channels, self.shared_tensordicts - ): - channel.send(("init", shared_tensordict)) + for channel in self.parent_channels: + channel.send(("init", None)) self.is_closed = False @_check_start @@ -751,22 +771,15 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: ) for i, channel in enumerate(self.parent_channels): channel.send(("load_state_dict", state_dict[f"worker{i}"])) - for channel in self.parent_channels: - msg, _ = channel.recv() - 
if msg != "loaded": - raise RuntimeError(f"Expected 'loaded' but received {msg}") + for event in self._events: + event.wait() + event.clear() @_check_start def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - self._assert_tensordict_shape(tensordict) if self._single_task: # this is faster than update_ but won't work for lazy stacks - for key in self.env_input_keys: - # self.shared_tensordict_parent.set( - # key, - # tensordict.get(key), - # inplace=True, - # ) + for key in self._env_input_keys: key = _unravel_key_to_tuple(key) self.shared_tensordict_parent._set_tuple( key, @@ -776,7 +789,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: ) else: self.shared_tensordict_parent.update_( - tensordict.select(*self.env_input_keys, strict=False) + tensordict.select(*self._env_input_keys, strict=False) ) if self.event is not None: self.event.record() @@ -784,43 +797,44 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: for i in range(self.num_workers): self.parent_channels[i].send(("step", None)) - # keys = set() - for i in range(self.num_workers): - msg, data = self.parent_channels[i].recv() - if msg != "step_result": - raise RuntimeError( - f"Expected 'step_result' but received {msg} from worker {i}" - ) - if data is not None: - self.shared_tensordicts[i].update_(data) + completed = set() + while len(completed) < self.num_workers: + for i, event in enumerate(self._events): + if i in completed: + continue + if event.is_set(): + completed.add(i) + event.clear() + # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps + next_td = self.shared_tensordict_parent.get("next") if self._single_task: - out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape) + out = TensorDict( + {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + ) for key in self._selected_step_keys: - _set_single_key(self.shared_tensordict_parent, out, key, clone=True) + _set_single_key(next_td, out, key, clone=True) else: # strict=False ensures that non-homogeneous keys are still there - out = self.shared_tensordict_parent.select( - *self._selected_step_keys, strict=False - ).clone() + out = next_td.select(*self._selected_step_keys, strict=False).clone() return out @_check_start def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: - cmd_out = "reset" - if tensordict is not None and "_reset" in tensordict.keys(): + _reset = None + if tensordict is not None: self._assert_tensordict_shape(tensordict) - _reset = tensordict.get("_reset") + _reset = tensordict.get("_reset", None) if _reset.shape[-len(self.done_spec.shape) :] != self.done_spec.shape: raise RuntimeError( "_reset flag in tensordict should follow env.done_spec" ) - else: - _reset = torch.ones( - self.done_spec.shape, dtype=torch.bool, device=self.device + if _reset is None: + _reset = torch.ones((), dtype=torch.bool, device=self.device).expand( + self.done_spec.shape ) - + workers = [] for i, channel in enumerate(self.parent_channels): if tensordict is not None: tensordict_ = tensordict[i] @@ -836,29 +850,34 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: # step at the root (since the shared_tensordict did not go through # step_mdp). 
self.shared_tensordicts[i].update_( - self.shared_tensordicts[i]["next"].select( - *self._selected_reset_keys, strict=False - ) + self.shared_tensordicts[i] + .get("next") + .select(*self._selected_reset_keys, strict=False) ) if tensordict_ is not None: self.shared_tensordicts[i].update_( tensordict_.select(*self._selected_reset_keys, strict=False) ) continue - out = (cmd_out, tensordict_) + out = ("reset", tensordict_) channel.send(out) + workers.append(i) + + completed = set() + while len(completed) < len(workers): + for i in workers: + event = self._events[i] + if i in completed: + continue + if event.is_set(): + completed.add(i) + event.clear() - for i, channel in enumerate(self.parent_channels): - if not _reset[i].any(): - continue - cmd_in, data = channel.recv() - if cmd_in != "reset_obs": - raise RuntimeError(f"received cmd {cmd_in} instead of reset_obs") - if data is not None: - self.shared_tensordicts[i].update_(data) if self._single_task: # select + clone creates 2 tds, but we can create one only - out = TensorDict({}, batch_size=self.shared_tensordict_parent.shape) + out = TensorDict( + {}, batch_size=self.shared_tensordict_parent.shape, device=self.device + ) for key in self._selected_reset_keys: if key != "_reset": _set_single_key(self.shared_tensordict_parent, out, key, clone=True) @@ -878,15 +897,9 @@ def _shutdown_workers(self) -> None: for i, channel in enumerate(self.parent_channels): if self._verbose: print(f"closing {i}") - # try: channel.send(("close", None)) - # except: - # raise RuntimeError(f"closing {channel} number {i} failed") - msg, _ = channel.recv() - if msg != "closing": - raise RuntimeError( - f"Expected 'closing' but received {msg} from worker {i}" - ) + self._events[i].wait() + self._events[i].clear() del self.shared_tensordicts, self.shared_tensordict_parent @@ -974,15 +987,16 @@ def _recursively_strip_locks_from_state_dict(state_dict: OrderedDict) -> Ordered def _run_worker_pipe_shared_mem( - idx: int, parent_pipe: connection.Connection, child_pipe: connection.Connection, env_fun: Union[EnvBase, Callable], env_fun_kwargs: Dict[str, Any], - pin_memory: bool, - env_input_keys: Dict[str, Any], device: DEVICE_TYPING = None, - allow_step_when_done: bool = False, + mp_event: mp.Event = None, + shared_tensordict: TensorDictBase = None, + _selected_input_keys=None, + _selected_reset_keys=None, + _selected_step_keys=None, verbose: bool = False, ) -> None: if device is None: @@ -1003,16 +1017,19 @@ def _run_worker_pipe_shared_mem( ) env = env_fun env = env.to(device) + del env_fun i = -1 initialized = False - # make sure that process can be closed - shared_tensordict = None - local_tensordict = None - child_pipe.send("started") + _excluded_reset_keys = { + _unravel_key_to_tuple(env.reward_key), + # _unravel_key_to_tuple(env.done_key), + _unravel_key_to_tuple(env.action_key), + } + while True: try: cmd, data = child_pipe.recv() @@ -1032,8 +1049,10 @@ def _run_worker_pipe_shared_mem( if initialized: raise RuntimeError("worker already initialized") i = 0 - shared_tensordict = data next_shared_tensordict = shared_tensordict.get("next") + shared_tensordict = shared_tensordict.clone(False) + del shared_tensordict["next"] + if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): raise RuntimeError( "tensordict must be placed in shared memory (share_memory_() or memmap_())" @@ -1045,55 +1064,34 @@ def _run_worker_pipe_shared_mem( print(f"resetting worker {pid}") if not initialized: raise RuntimeError("call 'init' before resetting") - local_tensordict = 
data - local_tensordict = env._reset(tensordict=local_tensordict) - - if "_reset" in local_tensordict.keys(): - local_tensordict.del_("_reset") - if pin_memory: - local_tensordict.pin_memory() - shared_tensordict.update_(local_tensordict) + cur_td = env._reset(tensordict=data) + + if "_reset" in cur_td.keys(): + cur_td.del_("_reset") + shared_tensordict.update_(cur_td) if event is not None: event.record() event.synchronize() - out = ("reset_obs", None) - child_pipe.send(out) + mp_event.set() elif cmd == "step": if not initialized: raise RuntimeError("called 'init' before step") i += 1 - if local_tensordict is not None: - for key in env_input_keys: - # local_tensordict.set(key, shared_tensordict.get(key)) - key = _unravel_key_to_tuple(key) - local_tensordict._set_tuple( - key, - shared_tensordict._get_tuple(key, None), - inplace=False, - validated=True, - ) - else: - local_tensordict = shared_tensordict.clone(recurse=False) - local_tensordict = env._step(local_tensordict) - if pin_memory: - local_tensordict.pin_memory() - msg = "step_result" - next_shared_tensordict.update_(local_tensordict.get("next")) + next_td = env._step(shared_tensordict) + next_shared_tensordict.update_(next_td) if event is not None: event.record() event.synchronize() - out = (msg, None) - child_pipe.send(out) + mp_event.set() elif cmd == "close": - del shared_tensordict, local_tensordict, data + del shared_tensordict, data if not initialized: raise RuntimeError("call 'init' before closing") env.close() del env - - child_pipe.send(("closing", None)) + mp_event.set() child_pipe.close() if verbose: print(f"{pid} closed") @@ -1101,8 +1099,7 @@ def _run_worker_pipe_shared_mem( elif cmd == "load_state_dict": env.load_state_dict(data) - msg = "loaded" - child_pipe.send((msg, None)) + mp_event.set() elif cmd == "state_dict": state_dict = _recursively_strip_locks_from_state_dict(env.state_dict()) @@ -1207,7 +1204,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: action = action.to(torch.device("cpu")) step_output = self._env.step(action.numpy()) tensordict_out = self._transform_step_output(step_output) - return tensordict_out.select().set("next", tensordict_out) + return tensordict_out def _get_action_spec(self) -> TensorSpec: # local import to avoid importing gym in the script @@ -1296,7 +1293,9 @@ def _transform_reset_output( else: # All workers were reset - rewrite the whole observation buffer self.obs = TensorDict( - self._treevalue_or_numpy_to_tensor_or_dict(observation), self.batch_size + self._treevalue_or_numpy_to_tensor_or_dict(observation), + self.batch_size, + device=self.device, ) obs = self.obs.clone(False) From 871589ee86d62f878eaef640ab969ada8527af63 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 11 Aug 2023 05:58:08 -0400 Subject: [PATCH 05/29] amend --- test/mocking_classes.py | 30 ++++++++++-------------------- test/test_env.py | 16 ++++++++++++---- test/test_transforms.py | 36 ++++++++++++++++++++---------------- 3 files changed, 42 insertions(+), 40 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 1af0ba59c6c..7f12eb9ed01 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -203,11 +203,7 @@ def _step(self, tensordict): done = self.counter >= self.max_val done = torch.tensor([done], dtype=torch.bool, device=self.device) return TensorDict( - { - "next": TensorDict( - {"reward": n, "done": done, "observation": n.clone()}, batch_size=[] - ) - }, + {"reward": n, "done": done, "observation": n.clone()}, batch_size=[], ) @@ -338,13 +334,7 @@ 
def _step(self, tensordict): device=self.device, ) return TensorDict( - { - "next": TensorDict( - {"reward": n, "done": done, "observation": n}, - tensordict.batch_size, - device=self.device, - ) - }, + {"reward": n, "done": done, "observation": n}, batch_size=tensordict.batch_size, device=self.device, ) @@ -501,7 +491,7 @@ def _step( done = torch.zeros_like(done).all(-1).unsqueeze(-1) tensordict.set("reward", reward.to(torch.get_default_dtype())) tensordict.set("done", done) - return tensordict.select().set("next", tensordict) + return tensordict class ContinuousActionVecMockEnv(_MockEnv): @@ -603,7 +593,7 @@ def _step( done = reward = done.unsqueeze(-1) tensordict.set("reward", reward.to(torch.get_default_dtype())) tensordict.set("done", done) - return tensordict.select().set("next", tensordict) + return tensordict def _obs_step(self, obs, a): return obs + a / self.maxstep @@ -1044,7 +1034,7 @@ def _step( batch_size=self.batch_size, device=self.device, ) - return tensordict.select().set("next", tensordict) + return tensordict class NestedCountingEnv(CountingEnv): @@ -1167,7 +1157,7 @@ def _step(self, td): td = td.clone() td["data"].batch_size = self.batch_size td[self.action_key] = td[self.action_key].max(-2)[0] - td_root = super()._step(td) + next_td = super()._step(td) if self.nested_obs_action: td[self.action_key] = ( td[self.action_key] @@ -1176,7 +1166,7 @@ def _step(self, td): ) if "data" in td.keys(): td["data"].batch_size = (*self.batch_size, self.nested_dim) - td = td_root["next"] + td = next_td if self.nested_done: td[self.done_key] = ( td["done"].unsqueeze(-1).expand(*self.batch_size, self.nested_dim, 1) @@ -1196,7 +1186,7 @@ def _step(self, td): del td["reward"] if "data" in td.keys(): td["data"].batch_size = (*self.batch_size, self.nested_dim) - return td_root + return td class CountingBatchedEnv(EnvBase): @@ -1290,7 +1280,7 @@ def _step( batch_size=self.batch_size, device=self.device, ) - return tensordict.select().set("next", tensordict) + return tensordict class HeteroCountingEnvPolicy: @@ -1479,7 +1469,7 @@ def _step( self.count > self.max_steps, self.done_spec.shape ) - return td.select().set("next", td) + return td def _set_seed(self, seed: Optional[int]): torch.manual_seed(seed) diff --git a/test/test_env.py b/test/test_env.py index 49604665563..8ae5091864c 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -5,6 +5,7 @@ import argparse import os.path +import re from collections import defaultdict from functools import partial @@ -244,9 +245,16 @@ def test_rollout_reset(env_name, frame_skip, parallel, truncated_key, seed=0): else: env = SerialEnv(3, envs) env.set_seed(100) + # out = env._single_rollout(100, break_when_any_done=False) out = env.rollout(100, break_when_any_done=False) assert out.names[-1] == "time" assert out.shape == torch.Size([3, 100]) + assert ( + out[..., -1]["step_count"].squeeze().cpu() == torch.tensor([19, 9, 19]) + ).all() + assert ( + out[..., -1]["next", "step_count"].squeeze().cpu() == torch.tensor([20, 10, 20]) + ).all() assert ( out["next", truncated_key].squeeze().sum(-1) == torch.tensor([5, 3, 2]) ).all() @@ -319,7 +327,9 @@ def test_mb_env_batch_lock(self, device, seed=0): td_expanded = td.unsqueeze(-1).expand(10, 2).reshape(-1).to_tensordict() mb_env.step(td) - with pytest.raises(RuntimeError, match="Expected a tensordict with shape"): + with pytest.raises( + RuntimeError, match=re.escape("Expected a tensordict with shape==env.shape") + ): mb_env.step(td_expanded) mb_env = DummyModelBasedEnvBase( @@ -1573,9 +1583,7 @@ def 
test_batch_unlocked_with_batch_size(device): td_expanded = td.expand(2, 2).reshape(-1).to_tensordict() td = env.step(td) - with pytest.raises( - RuntimeError, match="Expected a tensordict with shape==env.shape, " - ): + with pytest.raises(RuntimeError, match="Expected a tensordict with shape"): env.step(td_expanded) diff --git a/test/test_transforms.py b/test/test_transforms.py index 8241fc8d6a2..dbb622eb6e0 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1238,7 +1238,8 @@ def test_transform_compose(self, max_steps, device, batch, reset_workers): assert not torch.all(td.get("step_count")) i = 0 while max_steps is None or i < max_steps: - td = step_counter._step(td) + next_td = step_counter._step(td, td.get("next")) + td.set("next", next_td) i += 1 assert torch.all(td.get(("next", "step_count")) == i), ( td.get(("next", "step_count")), @@ -1291,7 +1292,7 @@ def test_transform_no_env(self, max_steps, device, batch, reset_workers): assert not torch.all(td.get("step_count")) i = 0 while max_steps is None or i < max_steps: - td = step_counter._step(td) + td.set("next", step_counter._step(td, td.get("next"))) i += 1 assert torch.all(td.get(("next", "step_count")) == i), ( td.get(("next", "step_count")), @@ -2699,7 +2700,7 @@ def test_transform_no_env(self): with pytest.raises( RuntimeError, match="parent not found for FrameSkipTransform" ): - t._step(tensordict) + t._step(tensordict, tensordict.get("next")) def test_transform_compose(self): t = Compose(FrameSkipTransform(2)) @@ -2707,7 +2708,7 @@ def test_transform_compose(self): with pytest.raises( RuntimeError, match="parent not found for FrameSkipTransform" ): - t._step(tensordict) + t._step(tensordict, tensordict.get("next")) @pytest.mark.skipif(not _has_gym, reason="gym not installed") @pytest.mark.parametrize("skip", [-1, 1, 2, 3]) @@ -3023,7 +3024,8 @@ def test_transform_no_env(self): match="NoopResetEnv.parent not found. Make sure that the parent is set.", ): t.reset(TensorDict({"next": {}}, [])) - t._step(TensorDict({"next": {}}, [])) + td = TensorDict({"next": {}}, []) + t._step(td, td.get("next")) def test_transform_compose(self): t = Compose(NoopResetEnv()) @@ -3032,7 +3034,8 @@ def test_transform_compose(self): match="NoopResetEnv.parent not found. 
Make sure that the parent is set.", ): t.reset(TensorDict({"next": {}}, [])) - t._step(TensorDict({"next": {}}, [])) + td = TensorDict({"next": {}}, []) + t._step(td, td.get("next")) def test_transform_model(self): t = nn.Sequential(NoopResetEnv(), nn.Identity()) @@ -4168,15 +4171,13 @@ def test_sum_reward(self, keys, device): ) # apply one time, episode_reward should be equal to reward again - td = rs._step(td) - td_next = td["next"] + td_next = rs._step(td, td.get("next")) assert "episode_reward" in td.keys() assert (td_next.get("episode_reward") == td_next.get("reward")).all() # apply a second time, episode_reward should twice the reward td["episode_reward"] = td["next", "episode_reward"] - td = rs._step(td) - td_next = td["next"] + td_next = rs._step(td, td.get("next")) assert (td_next.get("episode_reward") == 2 * td_next.get("reward")).all() # reset environments @@ -4184,8 +4185,7 @@ def test_sum_reward(self, keys, device): rs.reset(td) # apply a third time, episode_reward should be equal to reward again - td = rs._step(td) - td_next = td["next"] + td_next = rs._step(td, td.get("next")) assert (td_next.get("episode_reward") == td_next.get("reward")).all() # test transform_observation_spec @@ -5055,7 +5055,9 @@ def test_transform_compose(self, batch, mode, device): batch_size=batch, ) td = t.reset(td) - td = t._step(td) + next_td = td.get("next") + next_td = t._step(td, next_td) + td.set("next", next_td) if mode == "reduce": assert (td["next", "target_return"] + td["next", "reward"] == 10.0).all() @@ -5132,7 +5134,8 @@ def test_transform_no_env(self, mode, in_key, out_key): reward = torch.randn(10) td = TensorDict({("next", in_key): reward}, []) td = t.reset(td) - td = t._step(td) + td_next = t._step(td, td.get("next")) + td.set("next", td_next) if mode == "reduce": assert (td["next", out_key] + td["next", in_key] == 10.0).all() else: @@ -7884,12 +7887,13 @@ def test_transform_no_env(self, in_key, out_key): { "action": torch.randn(*batch, 7), "observation": torch.randn(*batch, 7), - "next": {t.in_keys[0]: torch.zeros(*batch, 1)}, "sample_log_prob": torch.randn(*batch), }, batch, ) - t._step(tensordict) + next_td = TensorDict({t.in_keys[0]: torch.zeros(*batch, 1)}, batch) + next_td = t._step(tensordict, next_td) + tensordict.set("next", next_td) assert (tensordict.get("next").get(t.out_keys[0]) != 0).all() def test_transform_compose(self): From 858da29bfab4bbbf0ca219d122a248da8e209565 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 11 Aug 2023 06:07:35 -0400 Subject: [PATCH 06/29] amend --- torchrl/envs/vec_env.py | 4 ++-- tutorials/sphinx-tutorials/pendulum.py | 13 +++++-------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 60e4e0bb411..e656a509f1c 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -632,7 +632,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: {}, batch_size=self.shared_tensordict_parent.shape, device=self.device ) for key in self._selected_reset_keys: - if key != "_reset": + if key != ("_reset",): _set_single_key(self.shared_tensordict_parent, out, key, clone=True) return out else: @@ -879,7 +879,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: {}, batch_size=self.shared_tensordict_parent.shape, device=self.device ) for key in self._selected_reset_keys: - if key != "_reset": + if key != ("_reset",): _set_single_key(self.shared_tensordict_parent, out, key, clone=True) return out else: diff --git 
a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index 17f41430217..85c4226dfb2 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -241,16 +241,13 @@ def _step(tensordict): new_th = th + new_thdot * dt reward = -costs.view(*tensordict.shape, 1) done = torch.zeros_like(reward, dtype=torch.bool) - # The output must be written in a ``"next"`` entry out = TensorDict( { - "next": { - "th": new_th, - "thdot": new_thdot, - "params": tensordict["params"], - "reward": reward, - "done": done, - } + "th": new_th, + "thdot": new_thdot, + "params": tensordict["params"], + "reward": reward, + "done": done, }, tensordict.shape, ) From 871bdc732a27e113a10f498c4bce665fad48ce5f Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 31 Aug 2023 08:34:20 -0400 Subject: [PATCH 07/29] amend --- torchrl/collectors/collectors.py | 4 +++- torchrl/envs/common.py | 5 +---- torchrl/envs/vec_env.py | 37 ++++++++++++++++++-------------- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 7e963a718b0..26516ad6dd1 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -843,7 +843,9 @@ def _step_and_maybe_reset(self) -> None: if td_reset.batch_dims: # better cloning here than when passing the td for stacking # cloning is necessary to avoid modifying entries in-place - self._tensordict = torch.where(traj_done_or_terminated, td_reset, self._tensordict) + self._tensordict = torch.where( + traj_done_or_terminated, td_reset, self._tensordict + ) else: self._tensordict.update(td_reset) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 98fa58e5774..e4d522f437d 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -7,12 +7,11 @@ import abc from copy import deepcopy -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Union import numpy as np import torch import torch.nn as nn -from tensordict._tensordict import _unravel_key_to_tuple from tensordict import unravel_key from tensordict.tensordict import TensorDictBase from tensordict.utils import NestedKey @@ -31,8 +30,6 @@ get_available_libraries, step_mdp, ) -from torchrl.data.utils import DEVICE_TYPING -from torchrl.envs.utils import _fuse_tensordicts, get_available_libraries, step_mdp LIBRARIES = get_available_libraries() diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 802518d72b7..9910f9d3a5c 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -20,7 +20,7 @@ import torch from tensordict import TensorDict -from tensordict._tensordict import _unravel_key_to_tuple +from tensordict._tensordict import _unravel_key_to_tuple, unravel_keys from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from torch import multiprocessing as mp from torchrl._utils import _check_for_faulty_process, VERBOSE @@ -35,12 +35,11 @@ from torchrl.envs.env_creator import get_env_metadata from torchrl.envs.utils import ( - _fuse_tensordicts, + _replace_last, _set_single_key, _sort_keys, clear_mpi_env_vars, ) -from torchrl.envs.utils import _replace_last, _set_single_key, _sort_keys _has_envpool = importlib.util.find_spec("envpool") @@ -342,7 +341,7 @@ def _create_td(self) -> None: self._env_output_keys.append(key) self._env_obs_keys.append(key) self._env_output_keys += [ - unravel_key(("next", key)) for key in self.reward_keys + 
self.done_keys + unravel_keys(("next", key)) for key in self.reward_keys + self.done_keys ] else: env_input_keys = set() @@ -371,9 +370,7 @@ def _create_td(self) -> None: "full_observation_spec" ].keys(True, True) ) - env_output_keys = env_output_keys.union( - self.reward_keys + self.done_keys - ) + env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys) self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys) self._env_input_keys = sorted(env_input_keys, key=_sort_keys) self._env_output_keys = sorted(env_output_keys, key=_sort_keys) @@ -387,11 +384,18 @@ def _create_td(self) -> None: self._selected_keys.add("_reset") # input keys - self._selected_input_keys = {_unravel_key_to_tuple(key) for key in self._env_input_keys} + self._selected_input_keys = { + _unravel_key_to_tuple(key) for key in self._env_input_keys + } # output keys after reset - self._selected_reset_keys = {_unravel_key_to_tuple(key) for key in self._env_obs_keys + self.done_keys + ["_reset"]} + self._selected_reset_keys = { + _unravel_key_to_tuple(key) + for key in self._env_obs_keys + self.done_keys + ["_reset"] + } # output keys after step - self._selected_step_keys = {_unravel_key_to_tuple(key) for key in self._env_output_keys} + self._selected_step_keys = { + _unravel_key_to_tuple(key) for key in self._env_output_keys + } if self._single_task: shared_tensordict_parent = shared_tensordict_parent.select( @@ -740,7 +744,8 @@ def _start_workers(self) -> None: self._selected_input_keys, self._selected_reset_keys, self._selected_step_keys, - self.has_lazy_inputs,), + self.has_lazy_inputs, + ), ) process.daemon = True process.start() @@ -1037,11 +1042,11 @@ def _run_worker_pipe_shared_mem( child_pipe.send("started") - _excluded_reset_keys = { - _unravel_key_to_tuple(env.reward_key), - # _unravel_key_to_tuple(env.done_key), - _unravel_key_to_tuple(env.action_key), - } + # _excluded_reset_keys = { + # _unravel_key_to_tuple(env.reward_key), + # # _unravel_key_to_tuple(env.done_key), + # _unravel_key_to_tuple(env.action_key), + # } while True: try: From bb352e5fee588540bd8e7f31330ab4d71252696b Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 31 Aug 2023 08:42:02 -0400 Subject: [PATCH 08/29] amend --- torchrl/envs/libs/vmas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index e70d5d29bf5..6ce6ebd222e 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -354,7 +354,7 @@ def _step( if not self.het_specs: agent_tds = agent_tds.to_tensordict() tensordict_out = TensorDict( - source={"next": {"agents": agent_tds, "done": dones}}, + source={"agents": agent_tds, "done": dones}, batch_size=self.batch_size, device=self.device, ) From 28cb428a3075579e4047bcd9dbaf3cccec7a52b8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 31 Aug 2023 08:54:37 -0400 Subject: [PATCH 09/29] amend --- torchrl/envs/common.py | 4 ++++ torchrl/envs/vec_env.py | 18 ++++++++---------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index e4d522f437d..469d6bd65bc 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -822,6 +822,10 @@ def _get_done_keys(self): self.__dict__["_done_keys"] = keys return keys + @property + def reset_keys(self) -> List[NestedKey]: + return [_replace_last(done_key, "_reset") for done_key in self.done_keys] + @property def done_keys(self) -> List[NestedKey]: """The done keys of an environment. 
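For reference, the `reset_keys` property added above derives one "_reset" entry per done key by replacing the last element of the (possibly nested) key, which is how `_replace_last` is used throughout this series. A minimal sketch of that mapping, with `replace_last` as an illustrative stand-in rather than the library helper:

def replace_last(key, new_leaf):
    # nested keys are plain strings or tuples of strings
    if isinstance(key, str):
        return new_leaf
    return (*key[:-1], new_leaf)

# "done" -> "_reset"; ("agents", "done") -> ("agents", "_reset")
assert replace_last("done", "_reset") == "_reset"
assert replace_last(("agents", "done"), "_reset") == ("agents", "_reset")
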
diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 9910f9d3a5c..5e09706cf0b 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -365,23 +365,23 @@ def _create_td(self) -> None: ].keys(True, True) ) env_output_keys = env_output_keys.union( - key - for key in meta_data.specs["output_spec"][ - "full_observation_spec" - ].keys(True, True) + meta_data.specs["output_spec"]["full_observation_spec"].keys( + True, True + ) ) env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys) self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys) self._env_input_keys = sorted(env_input_keys, key=_sort_keys) self._env_output_keys = sorted(env_output_keys, key=_sort_keys) + reset_keys = self.reset_keys self._selected_keys = ( set(self._env_output_keys) .union(self._env_input_keys) .union(self._env_obs_keys) .union(set(self.done_keys)) ) - self._selected_keys.add("_reset") + self._selected_keys = self._selected_keys.union(reset_keys) # input keys self._selected_input_keys = { @@ -390,7 +390,7 @@ def _create_td(self) -> None: # output keys after reset self._selected_reset_keys = { _unravel_key_to_tuple(key) - for key in self._env_obs_keys + self.done_keys + ["_reset"] + for key in self._env_obs_keys + self.done_keys + reset_keys } # output keys after step self._selected_step_keys = { @@ -594,8 +594,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: missing_reset = False if tensordict is not None: needs_resetting = [False] * self.num_workers - for done_key in self.done_keys: - _reset_key = _replace_last(done_key, "_reset") + for _reset_key in self.reset_keys: _reset = tensordict.get(_reset_key, default=None) if _reset is not None: for i in range(self.num_workers): @@ -838,8 +837,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: missing_reset = False if tensordict is not None: needs_resetting = [False for _ in range(self.num_workers)] - for done_key in self.done_keys: - _reset_key = _replace_last(done_key, "_reset") + for _reset_key in self.reset_keys: _reset = tensordict.get(_reset_key, default=None) if _reset is not None: for i in range(self.num_workers): From 26b53bced2b9c3de3f04acb3b76b462738d6ce48 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 31 Aug 2023 09:06:44 -0400 Subject: [PATCH 10/29] amend --- torchrl/envs/utils.py | 69 ------------------------------------------- 1 file changed, 69 deletions(-) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 84af7d51238..7adef6aeafd 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -568,75 +568,6 @@ def make_composite_from_td(data): return composite -def _fuse_tensordicts(*tds, excluded, selected=None, total=None): - """Fuses tensordicts with rank-wise priority. - - The first tensordicts of the list will have a higher priority than those - coming after, in such a way that if a key is present in both the first and - second tensordict, the first value is guaranteed to result in the output. - - Args: - tds (sequence of TensorDictBase): tensordicts to fuse. - excluded (sequence of tuples): keys to ignore. Must be tuples, no string - allowed. - selected (sequence of tuples): keys to accept. Must be tuples, no string - allowed. - total (tuple): the root key of the tds. Used for recursive calls. - - Examples: - >>> td1 = TensorDict({ - ... "a": 0, - ... "b": {"c": 0}, - ... }, []) - >>> td2 = TensorDict({ - ... "a": 1, - ... "b": {"c": 1, "d": 1}, - ... }, []) - >>> td3 = TensorDict({ - ... "a": 2, - ... 
"b": {"c": 2, "d": 2, "e": {"f": 2}}, - ... "g": 2, - ... "h": {"i": 2}, - ... }, []) - >>> out = fuse_tensordicts(td1, td2, td3, excluded=("h", "i")) - >>> assert out["a"] == 0 - >>> assert out["b", "c"] == 0 - >>> assert out["b", "d"] == 1 - >>> assert out["b", "e", "f"] == 2 - >>> assert out["g"] == 2 - >>> assert ("h", "i") not in out.keys(True, True) - - """ - out = TensorDict({}, batch_size=tds[0].batch_size, device=tds[0].device) - if total is None: - total = () - - keys = set() - for i, td in enumerate(tds): - if td is None: - continue - for key in td.keys(): - cur_total = total + (key,) - if cur_total in excluded: - continue - if selected is not None and cur_total not in selected: - continue - if key in keys: - continue - keys.add(key) - val = td._get_str(key, None) - if is_tensor_collection(val): - val = _fuse_tensordicts( - val, - *[_td._get_str(key, None) for _td in tds[i + 1 :]], - total=cur_total, - excluded=excluded, - selected=selected, - ) - out._set_str(key, val, validated=True, inplace=False) - return out - - @contextlib.contextmanager def clear_mpi_env_vars(): """Clears the MPI of environment variables. From 77872a6fa0bb59b8f52028ebc96b86ec17fab05e Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 31 Aug 2023 09:07:58 -0400 Subject: [PATCH 11/29] amend --- test/mocking_classes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 9a96f52c8fb..91676699997 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1703,7 +1703,7 @@ def _step( td.update(reward) assert td.batch_size == self.batch_size - return td.select().set("next", td) + return td def _set_seed(self, seed: Optional[int]): torch.manual_seed(seed) From c7ed82e86a11303cc0d802b21324526f898f6e8a Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 31 Aug 2023 11:31:07 -0400 Subject: [PATCH 12/29] amend --- examples/decision_transformer/utils.py | 38 ++++++++++++++++------ torchrl/envs/gym_like.py | 14 ++++---- torchrl/envs/libs/jumanji.py | 3 +- torchrl/envs/transforms/transforms.py | 44 +++++++++++++++----------- torchrl/envs/utils.py | 1 - torchrl/envs/vec_env.py | 7 ++-- 6 files changed, 63 insertions(+), 44 deletions(-) diff --git a/examples/decision_transformer/utils.py b/examples/decision_transformer/utils.py index c181b32ca5d..768237178c9 100644 --- a/examples/decision_transformer/utils.py +++ b/examples/decision_transformer/utils.py @@ -74,14 +74,17 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): transformed_env = TransformedEnv(base_env) transformed_env.append_transform( RewardScaling( - loc=0, scale=env_cfg.reward_scaling, in_keys="reward", standard_normal=False + loc=0, + scale=env_cfg.reward_scaling, + in_keys=["reward"], + standard_normal=False, ) ) if train: transformed_env.append_transform( TargetReturn( env_cfg.collect_target_return * env_cfg.reward_scaling, - out_keys=["return_to_go"], + out_keys=["return_to_go_single"], mode=env_cfg.target_return_mode, ) ) @@ -89,7 +92,7 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): transformed_env.append_transform( TargetReturn( env_cfg.eval_target_return * env_cfg.reward_scaling, - out_keys=["return_to_go"], + out_keys=["return_to_go_single"], mode=env_cfg.target_return_mode, ) ) @@ -107,7 +110,11 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): ) transformed_env.append_transform(obsnorm) transformed_env.append_transform( - UnsqueezeTransform(-2, in_keys=["observation", "action", 
"return_to_go"]) + UnsqueezeTransform( + -2, + in_keys=["observation", "action", "return_to_go_single"], + out_keys=["observation", "action", "return_to_go"], + ) ) transformed_env.append_transform( CatFrames( @@ -158,6 +165,8 @@ def make_collector(cfg, policy): exclude_target_return = ExcludeTransform( "return_to_go", ("next", "return_to_go"), + "return_to_go_single", + ("next", "return_to_go_single"), ("next", "action"), ("next", "observation"), "scale", @@ -183,9 +192,15 @@ def make_collector(cfg, policy): def make_offline_replay_buffer(rb_cfg, reward_scaling): - r2g = Reward2GoTransform(gamma=1.0, in_keys=["reward"], out_keys=["return_to_go"]) + r2g = Reward2GoTransform( + gamma=1.0, in_keys=["reward"], out_keys=["return_to_go_single"] + ) reward_scale = RewardScaling( - loc=0, scale=reward_scaling, in_keys="return_to_go", standard_normal=False + loc=0, + scale=reward_scaling, + in_keys="return_to_go_single", + out_keys=["return_to_go"], + standard_normal=False, ) crop_seq = RandomCropTensorDict(sub_seq_len=rb_cfg.stacked_frames, sample_dim=-1) @@ -230,12 +245,17 @@ def make_offline_replay_buffer(rb_cfg, reward_scaling): def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001): - r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go"]) + r2g = Reward2GoTransform(gamma=1.0, out_keys=["return_to_go_single"]) reward_scale = RewardScaling( - loc=0, scale=reward_scaling, in_keys="return_to_go", standard_normal=False + loc=0, + scale=reward_scaling, + in_keys=["return_to_go_single"], + out_keys=["return_to_go"], + standard_normal=False, ) catframes = CatFrames( - in_keys=["return_to_go"], + in_keys=["return_to_go_single"], + out_keys=["return_to_go"], N=rb_cfg.stacked_frames, dim=-2, padding="zeros", diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 2042511bd99..aa6a0485261 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -146,17 +146,14 @@ def read_done(self, done): """ return done, done - def read_reward(self, total_reward, step_reward): - """Reads a reward and the total reward so far (in the frame skip loop) and returns a sum of the two. + def read_reward(self, reward): + """Reads the reward and maps it to the reward space. Args: - total_reward (torch.Tensor or TensorDict): total reward so far in the step - step_reward (reward in the format provided by the inner env): reward of this particular step + reward (torch.Tensor or TensorDict): reward to be mapped. 
""" - return ( - total_reward + step_reward - ) # self.reward_spec.encode(step_reward, ignore_device=True) + return self.reward_spec.encode(reward) def read_obs( self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray] @@ -214,7 +211,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if _reward is None: _reward = self.reward_spec.zero() - reward = self.read_reward(reward, _reward) + reward = reward + _reward if isinstance(done, bool) or ( isinstance(done, np.ndarray) and not len(done) @@ -224,6 +221,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if do_break: break + reward = self.read_reward(reward) obs_dict = self.read_obs(obs) if reward is None: diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index 70181971f05..d27663628b3 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -252,7 +252,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # prepare inputs state = _tensordict_to_object(tensordict.get("state"), self._state_example) action = self.read_action(tensordict.get("action")) - reward = self.reward_spec.zero() # flatten batch size into vector state = _tree_flatten(state, self.batch_size) @@ -268,7 +267,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # collect outputs state_dict = self.read_state(state) obs_dict = self.read_obs(timestep.observation) - reward = self.read_reward(reward, np.asarray(timestep.reward)) + reward = self.read_reward(np.asarray(timestep.reward)) done = timestep.step_type == self.lib.types.StepType.LAST done = _ndarray_to_tensor(done).view(torch.bool).to(self.device) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index cb721d962b9..c458ea8d885 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -151,8 +151,10 @@ def __init__( out_keys_inv: Optional[Sequence[NestedKey]] = None, ): super().__init__() - if isinstance(in_keys, str): + if isinstance(in_keys, (str, tuple)): in_keys = [in_keys] + if isinstance(out_keys, (str, tuple)): + out_keys = [out_keys] self.in_keys = in_keys if out_keys is None: @@ -1132,17 +1134,17 @@ def __init__( self.mode = mode def reset(self, tensordict: TensorDict): - init_target_return = torch.full( - size=(*tensordict.batch_size, 1), - fill_value=self.target_return, - dtype=torch.float32, - device=tensordict.device, - ) for out_key in self.out_keys: target_return = tensordict.get(out_key, default=None) if target_return is None: + init_target_return = torch.full( + size=(*tensordict.batch_size, 1), + fill_value=self.target_return, + dtype=torch.float32, + device=tensordict.device, + ) target_return = init_target_return tensordict.set( @@ -1173,18 +1175,18 @@ def _apply_transform( self, reward: torch.Tensor, target_return: torch.Tensor ) -> torch.Tensor: if self.mode == "reduce": - if reward.ndim == 1 and target_return.ndim == 2: - # if target is stacked - target_return = target_return[-1] - reward - else: - target_return = target_return - reward + # if reward.ndim == 1 and target_return.ndim == 2: + # # if target is stacked + # target_return = target_return[-1] - reward + # else: + target_return = target_return - reward return target_return elif self.mode == "constant": - if reward.ndim == 1 and target_return.ndim == 2: - # if target is stacked - target_return = target_return[-1] - else: - target_return = target_return + # if reward.ndim == 1 and target_return.ndim == 2: + # # if target is stacked + # target_return 
= target_return[-1] + # else: + target_return = target_return return target_return else: raise ValueError("Unknown mode: {}".format(self.mode)) @@ -2127,7 +2129,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: for in_key, out_key in zip(self.in_keys, self.out_keys): # Lazy init of buffers buffer_name = f"_cat_buffers_{in_key}" - data = tensordict[in_key] + data = tensordict.get(in_key) d = data.size(self.dim) buffer = getattr(self, buffer_name) if isinstance(buffer, torch.nn.parameter.UninitializedBuffer): @@ -2297,11 +2299,15 @@ def __init__( loc: Union[float, torch.Tensor], scale: Union[float, torch.Tensor], in_keys: Optional[Sequence[NestedKey]] = None, + out_keys: Optional[Sequence[NestedKey]] = None, standard_normal: bool = False, ): if in_keys is None: in_keys = ["reward"] - super().__init__(in_keys=in_keys) + if out_keys is None: + out_keys = in_keys + + super().__init__(in_keys=in_keys, out_keys=out_keys) if not isinstance(standard_normal, torch.Tensor): standard_normal = torch.tensor(standard_normal) self.register_buffer("standard_normal", standard_normal) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 7adef6aeafd..a4b217bb922 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -27,7 +27,6 @@ from tensordict.tensordict import ( LazyStackedTensorDict, NestedKey, - TensorDict, TensorDictBase, ) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 5e09706cf0b..c4cbf1fe74a 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -20,7 +20,7 @@ import torch from tensordict import TensorDict -from tensordict._tensordict import _unravel_key_to_tuple, unravel_keys +from tensordict._tensordict import _unravel_key_to_tuple from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from torch import multiprocessing as mp from torchrl._utils import _check_for_faulty_process, VERBOSE @@ -35,7 +35,6 @@ from torchrl.envs.env_creator import get_env_metadata from torchrl.envs.utils import ( - _replace_last, _set_single_key, _sort_keys, clear_mpi_env_vars, @@ -340,9 +339,7 @@ def _create_td(self) -> None: for key in self.output_spec["full_observation_spec"].keys(True, True): self._env_output_keys.append(key) self._env_obs_keys.append(key) - self._env_output_keys += [ - unravel_keys(("next", key)) for key in self.reward_keys + self.done_keys - ] + self._env_output_keys += self.reward_keys + self.done_keys else: env_input_keys = set() for meta_data in self.meta_data: From fd089d1f3e502d0eed32cd032328d2c57b66cc79 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 31 Aug 2023 11:37:19 -0400 Subject: [PATCH 13/29] lint --- torchrl/envs/utils.py | 6 +----- torchrl/envs/vec_env.py | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index a4b217bb922..7b7a0f9a615 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -24,11 +24,7 @@ set_interaction_mode as set_exploration_mode, set_interaction_type as set_exploration_type, ) -from tensordict.tensordict import ( - LazyStackedTensorDict, - NestedKey, - TensorDictBase, -) +from tensordict.tensordict import LazyStackedTensorDict, NestedKey, TensorDictBase __all__ = [ "exploration_mode", diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index c4cbf1fe74a..23cfb8d42d4 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -34,11 +34,7 @@ from torchrl.envs.common import _EnvWrapper, EnvBase from torchrl.envs.env_creator import get_env_metadata -from torchrl.envs.utils 
import ( - _set_single_key, - _sort_keys, - clear_mpi_env_vars, -) +from torchrl.envs.utils import _set_single_key, _sort_keys, clear_mpi_env_vars _has_envpool = importlib.util.find_spec("envpool") From d6f304a7e5f97848369f376f94d4713352dbaa51 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 31 Aug 2023 11:47:00 -0400 Subject: [PATCH 14/29] fixes --- test/test_collector.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 17eed15409c..504eb807fbd 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -375,7 +375,6 @@ def make_env(seed): num_workers=num_env, create_env_fn=make_env, create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], - allow_step_when_done=True, ) env.set_seed(seed) return env @@ -424,7 +423,6 @@ def make_env(seed): num_workers=num_env, create_env_fn=make_env, create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], - allow_step_when_done=True, ) env.set_seed(seed) return env From df0210e76cbc96f0f0c12434d7d4e027951bef63 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 31 Aug 2023 11:47:46 -0400 Subject: [PATCH 15/29] amend --- torchrl/envs/transforms/transforms.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index c458ea8d885..518926efd85 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1175,17 +1175,9 @@ def _apply_transform( self, reward: torch.Tensor, target_return: torch.Tensor ) -> torch.Tensor: if self.mode == "reduce": - # if reward.ndim == 1 and target_return.ndim == 2: - # # if target is stacked - # target_return = target_return[-1] - reward - # else: target_return = target_return - reward return target_return elif self.mode == "constant": - # if reward.ndim == 1 and target_return.ndim == 2: - # # if target is stacked - # target_return = target_return[-1] - # else: target_return = target_return return target_return else: From 78a06d580332ad5f231ae782e6aa072c387549d5 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 31 Aug 2023 11:53:47 -0400 Subject: [PATCH 16/29] amend --- test/test_collector.py | 92 +++++++++++++-------------- torchrl/envs/transforms/transforms.py | 4 ++ 2 files changed, 50 insertions(+), 46 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 504eb807fbd..c9aa7269a79 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -352,52 +352,52 @@ def make_env(): _data = split_trajectories(_data, prefix="collector") assert _data["next", "reward"].sum(-2).min() == -21 - -@pytest.mark.parametrize("num_env", [1, 2]) -@pytest.mark.parametrize("env_name", ["vec"]) -def test_collector_done_persist(num_env, env_name, seed=5): - if num_env == 1: - - def env_fn(seed): - env = MockSerialEnv(device="cpu") - env.set_seed(seed) - return env - - else: - - def env_fn(seed): - def make_env(seed): - env = MockSerialEnv(device="cpu") - env.set_seed(seed) - return env - - env = ParallelEnv( - num_workers=num_env, - create_env_fn=make_env, - create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], - ) - env.set_seed(seed) - return env - - policy = make_policy(env_name) - - collector = SyncDataCollector( - create_env_fn=env_fn, - create_env_kwargs={"seed": seed}, - policy=policy, - frames_per_batch=200 * num_env, - max_frames_per_traj=2000, - total_frames=20000, - device="cpu", - reset_when_done=False, - ) - for _, d in enumerate(collector): # noqa - break - - assert (d["done"].sum(-2) >= 
1).all() - assert torch.unique(d["collector", "traj_ids"], dim=-1).shape[-1] == 1 - - del collector +# Deprecated reset_when_done +# @pytest.mark.parametrize("num_env", [1, 2]) +# @pytest.mark.parametrize("env_name", ["vec"]) +# def test_collector_done_persist(num_env, env_name, seed=5): +# if num_env == 1: +# +# def env_fn(seed): +# env = MockSerialEnv(device="cpu") +# env.set_seed(seed) +# return env +# +# else: +# +# def env_fn(seed): +# def make_env(seed): +# env = MockSerialEnv(device="cpu") +# env.set_seed(seed) +# return env +# +# env = ParallelEnv( +# num_workers=num_env, +# create_env_fn=make_env, +# create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], +# ) +# env.set_seed(seed) +# return env +# +# policy = make_policy(env_name) +# +# collector = SyncDataCollector( +# create_env_fn=env_fn, +# create_env_kwargs={"seed": seed}, +# policy=policy, +# frames_per_batch=200 * num_env, +# max_frames_per_traj=2000, +# total_frames=20000, +# device="cpu", +# reset_when_done=False, +# ) +# for _, d in enumerate(collector): # noqa +# break +# +# assert (d["done"].sum(-2) >= 1).all() +# assert torch.unique(d["collector", "traj_ids"], dim=-1).shape[-1] == 1 +# +# del collector @pytest.mark.parametrize("frames_per_batch", [200, 10]) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 518926efd85..81781b04dec 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1174,6 +1174,10 @@ def _step( def _apply_transform( self, reward: torch.Tensor, target_return: torch.Tensor ) -> torch.Tensor: + if target_return.shape != reward.shape: + raise ValueError( + f"The shape of the reward ({reward.shape}) and target return ({target_return.shape}) must match." + ) if self.mode == "reduce": target_return = target_return - reward return target_return From 45615c3688d63097e25cfe70c2aa2b26bc846fe5 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 31 Aug 2023 12:19:20 -0400 Subject: [PATCH 17/29] amend --- test/test_collector.py | 8 +++----- test/test_transforms.py | 12 ++++++------ 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index c9aa7269a79..e3aee2f2ea6 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1654,11 +1654,9 @@ def _step( self.state += action return TensorDict( { - "next": { - "state": self.state.clone(), - "reward": self.reward_spec.zero(), - "done": self.done_spec.zero(), - } + "state": self.state.clone(), + "reward": self.reward_spec.zero(), + "done": self.done_spec.zero(), }, self.batch_size, ) diff --git a/test/test_transforms.py b/test/test_transforms.py index c5fa69d4aa2..f20e505145a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -5220,8 +5220,8 @@ def test_transform_no_env(self, mode, in_key, out_key): t = TargetReturn( target_return=10.0, mode=mode, in_keys=[in_key], out_keys=[out_key] ) - reward = torch.randn(10) - td = TensorDict({("next", in_key): reward}, []) + reward = torch.randn(10, 1) + td = TensorDict({("next", in_key): reward}, [10]) td = t.reset(td) td_next = t._step(td, td.get("next")) td.set("next", td_next) @@ -5235,8 +5235,8 @@ def test_transform_model( ): t = TargetReturn(target_return=10.0) model = nn.Sequential(t, nn.Identity()) - reward = torch.randn(10) - td = TensorDict({("next", "reward"): reward}, []) + reward = torch.randn(10, 1) + td = TensorDict({("next", "reward"): reward}, [10]) with pytest.raises( NotImplementedError, match="cannot be executed without a 
parent" ): @@ -5249,8 +5249,8 @@ def test_transform_rb( ): t = TargetReturn(target_return=10.0) rb = rbclass(storage=LazyTensorStorage(10)) - reward = torch.randn(10) - td = TensorDict({("next", "reward"): reward}, []).expand(10) + reward = torch.randn(10, 1) + td = TensorDict({("next", "reward"): reward}, [10]) rb.append_transform(t) rb.extend(td) with pytest.raises( From bc3abd2a2d22d938cfed6dac1aabb891210836f4 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 31 Aug 2023 12:26:09 -0400 Subject: [PATCH 18/29] amend --- torchrl/envs/transforms/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 0b54ecd9f7f..959105ade19 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3333,7 +3333,7 @@ def to(self, dtype_or_device): # return observation_spec def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: - state_spec = input_spec['_state_spec'] + state_spec = input_spec['full_state_spec'] if state_spec is None: state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device) for key, spec in self.primers.items(): @@ -3348,7 +3348,7 @@ def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: device = self.device print('state spec key', key) state_spec[key] = spec.to(device) - input_spec["_state_spec"] = state_spec + input_spec["full_state_spec"] = state_spec return input_spec @property From e0d81ef8728d7e266d65bc7278536f5d31898278 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 31 Aug 2023 15:54:07 -0400 Subject: [PATCH 19/29] tmp --- test/mocking_classes.py | 2 +- test/test_collector.py | 1 + torchrl/envs/common.py | 2 -- torchrl/envs/transforms/transforms.py | 1 - 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 61ec05bc9f2..91676699997 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -426,7 +426,7 @@ def __new__( if categorical_action_encoding else OneHotDiscreteTensorSpec ) - action_spec = action_spec_cls(n=7, shape=batch_size) + action_spec = action_spec_cls(n=7, shape=(*batch_size, 7)) if reward_spec is None: reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) if done_spec is None: diff --git a/test/test_collector.py b/test/test_collector.py index e3aee2f2ea6..7d8979b43c0 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -352,6 +352,7 @@ def make_env(): _data = split_trajectories(_data, prefix="collector") assert _data["next", "reward"].sum(-2).min() == -21 + # Deprecated reset_when_done # @pytest.mark.parametrize("num_env", [1, 2]) # @pytest.mark.parametrize("env_name", ["vec"]) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 3ca2c90b572..a2caa5449ef 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1565,9 +1565,7 @@ def policy(td): tensordict = policy(tensordict) if auto_cast_to_device: tensordict = tensordict.to(env_device, non_blocking=True) - print("before", tensordict["next", "recurrent_state_c"]) tensordict = self.step(tensordict) - print("after", tensordict["next", "recurrent_state_c"]) tensordicts.append(tensordict.clone(False)) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 959105ade19..fd97e31fb4c 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3346,7 +3346,6 @@ def transform_input_spec(self, input_spec: CompositeSpec) -> 
CompositeSpec: device = state_spec.device except RuntimeError: device = self.device - print('state spec key', key) state_spec[key] = spec.to(device) input_spec["full_state_spec"] = state_spec return input_spec From 9fc95fcb36f035f0d09d10652a1ab0ac5cd992a9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Sep 2023 07:17:00 -0400 Subject: [PATCH 20/29] amend --- test/mocking_classes.py | 2 +- test/test_tensordictmodules.py | 4 ++-- torchrl/envs/transforms/transforms.py | 5 +++-- torchrl/envs/vec_env.py | 13 ++++++++----- torchrl/modules/tensordict_module/rnn.py | 16 ++++++++-------- 5 files changed, 22 insertions(+), 18 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 91676699997..61ec05bc9f2 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -426,7 +426,7 @@ def __new__( if categorical_action_encoding else OneHotDiscreteTensorSpec ) - action_spec = action_spec_cls(n=7, shape=(*batch_size, 7)) + action_spec = action_spec_cls(n=7, shape=batch_size) if reward_spec is None: reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) if done_spec is None: diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index c942bf0de02..c6ea7b4ec67 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -1814,8 +1814,8 @@ def create_transformed_env(): ) for break_when_any_done in [False, True]: data = env.rollout(10, actor, break_when_any_done=break_when_any_done) - assert (data.get("recurrent_state_c") != 0.0).any() - assert (data.get("next", "recurrent_state_c") != 0.0).all() + # assert (data.get("recurrent_state_c") != 0.0).any() + assert (data.get(("next", "recurrent_state_c")) != 0.0).all() def test_safe_specs(): diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index fd97e31fb4c..4237ee84322 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3376,9 +3376,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def _step( self, tensordict: TensorDictBase, next_tensordict: TensorDictBase ) -> TensorDictBase: - for key in self.primers.keys(): - next_tensordict.setdefault(key, tensordict.get(key, default=None)) return next_tensordict + # for key in self.primers.keys(): + # next_tensordict.setdefault(key, tensordict.get(key, default=None)) + # return next_tensordict def reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Sets the default values in the input tensordict. 
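A minimal sketch of what the primer wiring above provides, assuming a Gym backend is available (the recurrent_state_h / recurrent_state_c names follow the LSTMModule defaults exercised by the tests in this series): once the primer returned by make_tensordict_primer() is appended to the env, reset() pre-populates zeroed hidden states, so a recurrent policy never encounters missing entries and batched envs can allocate shared buffers for them.

from torchrl.envs import GymEnv, InitTracker, TransformedEnv
from torchrl.modules import LSTMModule

# Pendulum observations have 3 features; hidden_size is illustrative.
lstm_module = LSTMModule(
    input_size=3,
    hidden_size=12,
    in_key="observation",
    out_key="features",
)
env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
# the primer advertises the recurrent-state specs to the env
env.append_transform(lstm_module.make_tensordict_primer())

td = env.reset()
# both recurrent entries exist from the very first step, zero-initialized
assert (td["recurrent_state_h"] == 0).all()
assert (td["recurrent_state_c"] == 0).all()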
diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index 23cfb8d42d4..d9edc071455 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -20,7 +20,7 @@ import torch from tensordict import TensorDict -from tensordict._tensordict import _unravel_key_to_tuple +from tensordict._tensordict import _unravel_key_to_tuple, unravel_keys from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from torch import multiprocessing as mp from torchrl._utils import _check_for_faulty_process, VERBOSE @@ -394,13 +394,14 @@ def _create_td(self) -> None: shared_tensordict_parent = shared_tensordict_parent.select( *self._selected_keys, "next", + *[unravel_keys(("next", key)) for key in self._env_output_keys], strict=False, ) self.shared_tensordict_parent = shared_tensordict_parent.to(self.device) else: # Multi-task: we share tensordict that *may* have different keys shared_tensordict_parent = [ - tensordict.select(*self._selected_keys, "next", strict=False).to( + tensordict.select(*self._selected_keys, "next", *[unravel_keys(("next", key)) for key in self._env_output_keys], strict=False).to( self.device ) for tensordict in shared_tensordict_parent @@ -785,9 +786,12 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # this is faster than update_ but won't work for lazy stacks for key in self._env_input_keys: key = _unravel_key_to_tuple(key) + val = tensordict._get_tuple(key, None) + if val is None: + continue self.shared_tensordict_parent._set_tuple( key, - tensordict._get_tuple(key, None), + val, inplace=True, validated=True, ) @@ -1058,9 +1062,8 @@ def _run_worker_pipe_shared_mem( if initialized: raise RuntimeError("worker already initialized") i = 0 - next_shared_tensordict = shared_tensordict.get("next") + next_shared_tensordict = shared_tensordict.get("next").clone(False) shared_tensordict = shared_tensordict.clone(False) - del shared_tensordict["next"] if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): raise RuntimeError( diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 0ca52e024b2..b724064d15e 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -227,16 +227,16 @@ def make_tuple(key): ) return TensorDictPrimer( { - # in_key1: UnboundedContinuousTensorSpec( - # shape=(self.lstm.num_layers, self.lstm.hidden_size) - # ), - # in_key2: UnboundedContinuousTensorSpec( - # shape=(self.lstm.num_layers, self.lstm.hidden_size) - # ), - unravel_key(("next", in_key1)): UnboundedContinuousTensorSpec( + in_key1: UnboundedContinuousTensorSpec( shape=(self.lstm.num_layers, self.lstm.hidden_size) ), - unravel_key(("next", in_key2)): UnboundedContinuousTensorSpec( + in_key2: UnboundedContinuousTensorSpec( + shape=(self.lstm.num_layers, self.lstm.hidden_size) + ), + out_key1: UnboundedContinuousTensorSpec( + shape=(self.lstm.num_layers, self.lstm.hidden_size) + ), + out_key2: UnboundedContinuousTensorSpec( shape=(self.lstm.num_layers, self.lstm.hidden_size) ), } From 8f3ed5eecaaa602cf6e306484c3ca0fb5644f7eb Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Sep 2023 08:41:16 -0400 Subject: [PATCH 21/29] amend --- test/test_tensordictmodules.py | 2 +- torchrl/envs/transforms/transforms.py | 43 +++++++++++++-------------- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index c6ea7b4ec67..58e92d3fc28 100644 --- a/test/test_tensordictmodules.py +++ 
b/test/test_tensordictmodules.py @@ -1814,7 +1814,7 @@ def create_transformed_env(): ) for break_when_any_done in [False, True]: data = env.rollout(10, actor, break_when_any_done=break_when_any_done) - # assert (data.get("recurrent_state_c") != 0.0).any() + assert (data.get("recurrent_state_c") != 0.0).any() assert (data.get(("next", "recurrent_state_c")) != 0.0).all() diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 4237ee84322..86d3516fe12 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3312,25 +3312,25 @@ def to(self, dtype_or_device): self.device = dtype_or_device return super().to(dtype_or_device) - # def transform_observation_spec( - # self, observation_spec: CompositeSpec - # ) -> CompositeSpec: - # if not isinstance(observation_spec, CompositeSpec): - # raise ValueError( - # f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." - # ) - # for key, spec in self.primers.items(): - # if spec.shape[: len(observation_spec.shape)] != observation_spec.shape: - # raise RuntimeError( - # f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. " - # f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}." - # ) - # try: - # device = observation_spec.device - # except RuntimeError: - # device = self.device - # observation_spec[key] = spec.to(device) - # return observation_spec + def transform_observation_spec( + self, observation_spec: CompositeSpec + ) -> CompositeSpec: + if not isinstance(observation_spec, CompositeSpec): + raise ValueError( + f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead." + ) + for key, spec in self.primers.items(): + if spec.shape[: len(observation_spec.shape)] != observation_spec.shape: + raise RuntimeError( + f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. " + f"Got observation_spec.shape={observation_spec.shape} but the '{key}' entry's shape is {spec.shape}." + ) + try: + device = observation_spec.device + except RuntimeError: + device = self.device + observation_spec[key] = spec.to(device) + return observation_spec def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: state_spec = input_spec['full_state_spec'] @@ -3376,10 +3376,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def _step( self, tensordict: TensorDictBase, next_tensordict: TensorDictBase ) -> TensorDictBase: + for key in self.primers.keys(): + next_tensordict.setdefault(key, tensordict.get(key, default=None)) return next_tensordict - # for key in self.primers.keys(): - # next_tensordict.setdefault(key, tensordict.get(key, default=None)) - # return next_tensordict def reset(self, tensordict: TensorDictBase) -> TensorDictBase: """Sets the default values in the input tensordict. 
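The reinstated TensorDictPrimer._step above encodes the hand-over rule for recurrent states: an entry the policy has already written at t+1 is left untouched, and only missing primed entries are copied forward from the root tensordict. A self-contained sketch of that rule (the key name and shapes are illustrative, mirroring the recurrent_state_c entry used in the tests):

import torch
from tensordict import TensorDict

primer_keys = ["recurrent_state_c"]
root = TensorDict({"recurrent_state_c": torch.zeros(2, 12)}, batch_size=[2])

# case 1: the policy wrote a fresh hidden state at t+1; setdefault keeps it
next_td = TensorDict({"recurrent_state_c": torch.randn(2, 12)}, batch_size=[2])
for key in primer_keys:
    next_td.setdefault(key, root.get(key, default=None))
assert (next_td["recurrent_state_c"] != 0).any()

# case 2: nothing was written at t+1; the zeroed root value is carried over
next_td = TensorDict({}, batch_size=[2])
for key in primer_keys:
    next_td.setdefault(key, root.get(key, default=None))
assert (next_td["recurrent_state_c"] == 0).all()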
From e6fa755f1323c7ec8dc5071fd4aac6faa80daa72 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Sep 2023 09:41:54 -0400 Subject: [PATCH 22/29] amend --- test/mocking_classes.py | 2 +- torchrl/envs/common.py | 5 ++--- torchrl/envs/transforms/transforms.py | 18 ------------------ torchrl/modules/tensordict_module/rnn.py | 6 ------ 4 files changed, 3 insertions(+), 28 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 6fcc3426816..0a70483f9ba 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1053,7 +1053,7 @@ def _step( batch_size=self.batch_size, device=self.device, ) - return tensordict.select().set("next", tensordict) + return tensordict class NestedCountingEnv(CountingEnv): diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index a2caa5449ef..cc2137c6948 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1131,10 +1131,9 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: next_tensordict = self._step(tensordict) next_tensordict = self._step_proc_data(next_tensordict) if next_preset is not None: - next_preset.update(next_tensordict) - else: # tensordict could already have a "next" key - tensordict.set("next", next_tensordict) + next_tensordict.update(next_preset) + tensordict.set("next", next_tensordict) return tensordict def _step_proc_data(self, next_tensordict_out): diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index d5cdfc18d46..6a726883971 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3333,24 +3333,6 @@ def transform_observation_spec( observation_spec[key] = spec.to(device) return observation_spec - def transform_input_spec(self, input_spec: CompositeSpec) -> CompositeSpec: - state_spec = input_spec['full_state_spec'] - if state_spec is None: - state_spec = CompositeSpec(shape=input_spec.shape, device=input_spec.device) - for key, spec in self.primers.items(): - if spec.shape[: len(state_spec.shape)] != state_spec.shape: - raise RuntimeError( - f"The leading shape of the primer specs ({self.__class__}) should match the one of the parent env. " - f"Got state_spec.shape={state_spec.shape} but the '{key}' entry's shape is {spec.shape}." 
- ) - try: - device = state_spec.device - except RuntimeError: - device = self.device - state_spec[key] = spec.to(device) - input_spec["full_state_spec"] = state_spec - return input_spec - @property def _batch_size(self): return self.parent.batch_size diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index b724064d15e..da8ca65ae1e 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -233,12 +233,6 @@ def make_tuple(key): in_key2: UnboundedContinuousTensorSpec( shape=(self.lstm.num_layers, self.lstm.hidden_size) ), - out_key1: UnboundedContinuousTensorSpec( - shape=(self.lstm.num_layers, self.lstm.hidden_size) - ), - out_key2: UnboundedContinuousTensorSpec( - shape=(self.lstm.num_layers, self.lstm.hidden_size) - ), } ) From 38f74a210c9227c6c090703060fca640efbcbe7a Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Sep 2023 11:55:49 -0400 Subject: [PATCH 23/29] amend --- torchrl/envs/vec_env.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/torchrl/envs/vec_env.py b/torchrl/envs/vec_env.py index d9edc071455..23cfb8d42d4 100644 --- a/torchrl/envs/vec_env.py +++ b/torchrl/envs/vec_env.py @@ -20,7 +20,7 @@ import torch from tensordict import TensorDict -from tensordict._tensordict import _unravel_key_to_tuple, unravel_keys +from tensordict._tensordict import _unravel_key_to_tuple from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase from torch import multiprocessing as mp from torchrl._utils import _check_for_faulty_process, VERBOSE @@ -394,14 +394,13 @@ def _create_td(self) -> None: shared_tensordict_parent = shared_tensordict_parent.select( *self._selected_keys, "next", - *[unravel_keys(("next", key)) for key in self._env_output_keys], strict=False, ) self.shared_tensordict_parent = shared_tensordict_parent.to(self.device) else: # Multi-task: we share tensordict that *may* have different keys shared_tensordict_parent = [ - tensordict.select(*self._selected_keys, "next", *[unravel_keys(("next", key)) for key in self._env_output_keys], strict=False).to( + tensordict.select(*self._selected_keys, "next", strict=False).to( self.device ) for tensordict in shared_tensordict_parent @@ -786,12 +785,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # this is faster than update_ but won't work for lazy stacks for key in self._env_input_keys: key = _unravel_key_to_tuple(key) - val = tensordict._get_tuple(key, None) - if val is None: - continue self.shared_tensordict_parent._set_tuple( key, - val, + tensordict._get_tuple(key, None), inplace=True, validated=True, ) @@ -1062,8 +1058,9 @@ def _run_worker_pipe_shared_mem( if initialized: raise RuntimeError("worker already initialized") i = 0 - next_shared_tensordict = shared_tensordict.get("next").clone(False) + next_shared_tensordict = shared_tensordict.get("next") shared_tensordict = shared_tensordict.clone(False) + del shared_tensordict["next"] if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): raise RuntimeError( From b568c3b6451003b9c399ddf828d1aa4cc5d1d550 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Sep 2023 12:06:06 -0400 Subject: [PATCH 24/29] amend --- test/test_transforms.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 40037085e8d..37348328352 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -6693,6 +6693,7 @@ def _test_vecnorm_subproc_auto( tensordict = env.reset() for _ 
in range(10): tensordict = env.rand_step(tensordict) + tensordict = step_mdp(tensordict) queue_out.put(True) msg = queue_in.get(timeout=TIMEOUT) assert msg == "all_done" @@ -6800,11 +6801,13 @@ def _run_parallelenv(parallel_env, queue_in, queue_out): assert msg == "start" for _ in range(10): tensordict = parallel_env.rand_step(tensordict) + tensordict = step_mdp(tensordict) queue_out.put("first round") msg = queue_in.get(timeout=TIMEOUT) assert msg == "start" for _ in range(10): tensordict = parallel_env.rand_step(tensordict) + tensordict = step_mdp(tensordict) queue_out.put("second round") parallel_env.close() queue_out.close() @@ -6884,6 +6887,7 @@ def test_vecnorm_rollout(self, parallel, thr=0.2, N=200): for _ in range(N): td = env_t.rand_step(td) tds.append(td.clone()) + td = step_mdp(td) if td.get("done").any(): td = env_t.reset() tds = torch.stack(tds, 0) From 71d1076791349543bf67334b46963b3059f70abe Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 1 Sep 2023 12:07:38 -0400 Subject: [PATCH 25/29] amend --- test/test_tensordictmodules.py | 32 +++++++++++++----------- torchrl/modules/tensordict_module/rnn.py | 2 +- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 58e92d3fc28..ca1e0e46e57 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -4,13 +4,17 @@ # LICENSE file in the root directory of this source tree. import argparse -from mocking_classes import DiscreteActionVecMockEnv -from tensordict.nn import TensorDictSequential -from torchrl.modules import MLP, ProbabilisticActor + import pytest import torch +from mocking_classes import DiscreteActionVecMockEnv from tensordict import pad, TensorDict, unravel_key_list -from tensordict.nn import InteractionType, make_functional, TensorDictModule +from tensordict.nn import ( + InteractionType, + make_functional, + TensorDictModule, + TensorDictSequential, +) from torch import nn from torchrl.data.tensor_specs import ( BoundedTensorSpec, @@ -23,6 +27,7 @@ DecisionTransformerInferenceWrapper, DTActor, LSTMModule, + MLP, NormalParamWrapper, OnlineDTActor, ProbabilisticActor, @@ -1768,21 +1773,20 @@ def test_multi_consecutive(self, shape): ) def test_lstm_parallel_env(self): - from torchrl.envs import ParallelEnv, TransformedEnv, InitTracker + from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv + # tests that hidden states are carried over with parallel envs lstm_module = LSTMModule( - input_size=7, - hidden_size=12, - num_layers=2, - in_key="observation", - out_key="features", - ) + input_size=7, + hidden_size=12, + num_layers=2, + in_key="observation", + out_key="features", + ) def create_transformed_env(): primer = lstm_module.make_tensordict_primer() - env = DiscreteActionVecMockEnv( - categorical_action_encoding=True - ) + env = DiscreteActionVecMockEnv(categorical_action_encoding=True) env = TransformedEnv(env) env.append_transform(InitTracker()) env.append_transform(primer) diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index da8ca65ae1e..6baa4ad267d 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -5,7 +5,7 @@ from typing import Optional, Tuple import torch -from tensordict import unravel_key_list, unravel_key, TensorDictBase +from tensordict import TensorDictBase, unravel_key_list from tensordict.nn import TensorDictModuleBase as ModuleBase From 5f7885e27d1cb3db6fa0b100adcec56cd651c998 Mon Sep 17 
00:00:00 2001
From: vmoens
Date: Fri, 1 Sep 2023 12:17:31 -0400
Subject: [PATCH 26/29] init

---
 torchrl/modules/tensordict_module/rnn.py | 365 +++++++++++++++++++++++
 1 file changed, 365 insertions(+)

diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py
index 6baa4ad267d..76d68802786 100644
--- a/torchrl/modules/tensordict_module/rnn.py
+++ b/torchrl/modules/tensordict_module/rnn.py
@@ -399,3 +399,368 @@ def _lstm(
                 1,
             )
         return tuple(out)
+
+
+class GRUModule(ModuleBase):
+    """An embedder for a GRU module.
+
+    This class adds the following functionality to :class:`torch.nn.GRU`:
+
+    - Compatibility with TensorDict: the hidden states are reshaped to match
+      the tensordict batch size.
+    - Optional multi-step execution: with torch.nn, one has to choose between
+      :class:`torch.nn.GRUCell` and :class:`torch.nn.GRU`, the former being
+      compatible with single step inputs and the latter being compatible with
+      multi-step. This class enables both usages.
+
+
+    After construction, the module is *not* set in temporal mode, ie. it will
+    expect single steps inputs.
+
+    If in temporal mode, it is expected that the last dimension of the tensordict
+    marks the number of steps. There is no constrain on the dimensionality of the
+    tensordict (except that it must be greater than one for temporal inputs).
+
+    Args:
+        input_size: The number of expected features in the input `x`
+        hidden_size: The number of features in the hidden state `h`
+        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
+            would mean stacking two LSTMs together to form a `stacked LSTM`,
+            with the second LSTM taking in outputs of the first LSTM and
+            computing the final results. Default: 1
+        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
+            Default: ``True``
+        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
+            LSTM layer except the last layer, with dropout probability equal to
+            :attr:`dropout`. Default: 0
+        proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0
+
+    Keyword Args:
+        in_key (str or tuple of str): the input key of the module. Exclusive use
+            with ``in_keys``. If provided, the recurrent keys are assumed to be
+            ["recurrent_state"] and the ``in_key`` will be
+            appended before this.
+        in_keys (list of str): a pair of strings corresponding to the input value and recurrent entry.
+            Exclusive with ``in_key``.
+        out_key (str or tuple of str): the output key of the module. Exclusive use
+            with ``out_keys``. If provided, the recurrent keys are assumed to be
+            [("next", "recurrent_state")] and the ``out_key`` will be
+            appended before these.
+        out_keys (list of str): a pair of strings corresponding to the output value
+            and the hidden key.
+            .. note::
+              For a better integration with TorchRL's environments, the best naming
+              for the output hidden key is ``("next", )``, such
+              that the hidden values are passed from step to step during a rollout.
+        device (torch.device or compatible): the device of the module.
+        gru (torch.nn.GRU, optional): a GRU instance to be wrapped.
+            Exclusive with other nn.GRU arguments.
+
+    Attributes:
+        temporal_mode: Returns the temporal mode of the module.
+
+    Methods:
+        set_temporal_mode: controls whether the module should be executed in
+            temporal mode.
+ + Examples: + >>> from torchrl.envs import TransformedEnv, InitTracker + >>> from torchrl.envs.libs.gym import GymEnv + >>> from torchrl.modules import MLP + >>> from torch import nn + >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) + >>> gru_module = GRUModule( + ... input_size=env.observation_spec["observation"].shape[-1], + ... hidden_size=64, + ... in_keys=["observation", "rs"], + ... out_keys=["intermediate", ("next", "rs")]) + >>> mlp = MLP(num_cells=[64], out_features=1) + >>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) + >>> policy(env.reset()) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + intermediate: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False), + is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + rs: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False), + batch_size=torch.Size([]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + + """ + + DEFAULT_IN_KEYS = ["recurrent_state"] + DEFAULT_OUT_KEYS = [("next", "recurrent_state")] + + def __init__( + self, + input_size: int = None, + hidden_size: int = None, + num_layers: int = 1, + bias: bool = True, + batch_first=True, + dropout=0, + proj_size=0, + bidirectional=False, + *, + in_key=None, + in_keys=None, + out_key=None, + out_keys=None, + device=None, + gru=None, + ): + super().__init__() + if gru is not None: + if not gru.batch_first: + raise ValueError("The input lstm must have batch_first=True.") + if gru.bidirectional: + raise ValueError("The input lstm cannot be bidirectional.") + if input_size is not None or hidden_size is not None: + raise ValueError( + "An LSTM instance cannot be passed along with class argument." + ) + else: + if not batch_first: + raise ValueError("The input lstm must have batch_first=True.") + if bidirectional: + raise ValueError("The input lstm cannot be bidirectional.") + gru = nn.GRU( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + dropout=dropout, + proj_size=proj_size, + device=device, + batch_first=True, + bidirectional=False, + ) + if not ((in_key is None) ^ (in_keys is None)): + raise ValueError( + f"Either in_keys or in_key must be specified but not both or none. Got {in_keys} and {in_key} respectively." + ) + elif in_key: + in_keys = [in_key, *self.DEFAULT_IN_KEYS] + + if not ((out_key is None) ^ (out_keys is None)): + raise ValueError( + f"Either out_keys or out_key must be specified but not both or none. Got {out_keys} and {out_key} respectively." + ) + elif out_key: + out_keys = [out_key, *self.DEFAULT_OUT_KEYS] + + in_keys = unravel_key_list(in_keys) + out_keys = unravel_key_list(out_keys) + if not isinstance(in_keys, (tuple, list)) or ( + len(in_keys) != 2 and not (len(in_keys) == 3 and in_keys[-1] == "is_init") + ): + raise ValueError( + f"LSTMModule expects 3 inputs: a value, and two hidden states (and potentially an 'is_init' marker). Got in_keys {in_keys} instead." 
+ ) + if not isinstance(out_keys, (tuple, list)) or len(out_keys) != 2: + raise ValueError( + f"LSTMModule expects 3 outputs: a value, and two hidden states. Got out_keys {out_keys} instead." + ) + self.gru = gru + if "is_init" not in in_keys: + in_keys = in_keys + ["is_init"] + self.in_keys = in_keys + self.out_keys = out_keys + self._temporal_mode = False + + def make_tensordict_primer(self): + from torchrl.envs import TensorDictPrimer + + def make_tuple(key): + if isinstance(key, tuple): + return key + return (key,) + + out_key1 = make_tuple(self.out_keys[1]) + in_key1 = make_tuple(self.in_keys[1]) + if out_key1 != ("next", *in_key1): + raise RuntimeError( + "make_tensordict_primer is supposed to work with in_keys/out_keys that " + "have compatible names, ie. the out_keys should be named after ('next', ). Got " + f"in_keys={self.in_keys} and out_keys={self.out_keys} instead." + ) + return TensorDictPrimer( + { + in_key1: UnboundedContinuousTensorSpec( + shape=(self.gru.num_layers, self.gru.hidden_size) + ), + } + ) + + @property + def temporal_mode(self): + return self._temporal_mode + + @temporal_mode.setter + def temporal_mode(self, value): + raise RuntimeError("temporal_mode cannot be changed in-place. Call `module.set") + + def set_recurrent_mode(self, mode: bool = True): + """Returns a new copy of the module that shares the same lstm model but with a different ``temporal_mode`` attribute (if it differs). + + A copy is created such that the module can be used with divergent behaviour + in various parts of the code (inference vs training): + + Examples: + >>> from torchrl.envs import TransformedEnv, InitTracker, step_mdp + >>> from torchrl.envs.libs.gym import GymEnv + >>> from torchrl.modules import MLP + >>> from tensordict import TensorDict + >>> from torch import nn + >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) + >>> lstm = nn.LSTM(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True) + >>> lstm_module = LSTMModule(lstm, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")]) + >>> mlp = MLP(num_cells=[64], out_features=1) + >>> # building two policies with different behaviours: + >>> policy_inference = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) + >>> policy_training = Seq(lstm_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) + >>> traj_td = env.rollout(3) # some random temporal data + >>> traj_td = policy_training(traj_td) + >>> # let's check that both return the same results + >>> td_inf = TensorDict({}, traj_td.shape[:-1]) + >>> for td in traj_td.unbind(-1): + ... td_inf = td_inf.update(td.select("is_init", "observation", ("next", "observation"))) + ... td_inf = policy_inference(td_inf) + ... td_inf = step_mdp(td_inf) + ... 
+ >>> torch.testing.assert_close(td_inf["hidden0"], traj_td[..., -1]["next", "hidden0"]) + """ + if mode is self._temporal_mode: + return self + out = LSTMModule(lstm=self.gru, in_keys=self.in_keys, out_keys=self.out_keys) + out._temporal_mode = mode + return out + + def forward(self, tensordict: TensorDictBase): + # we want to get an error if the value input is missing, but not the hidden states + defaults = [NO_DEFAULT, None, None] + shape = tensordict.shape + tensordict_shaped = tensordict + if self.temporal_mode: + # if less than 2 dims, unsqueeze + ndim = tensordict_shaped.get(self.in_keys[0]).ndim + while ndim < 3: + tensordict_shaped = tensordict_shaped.unsqueeze(0) + ndim += 1 + if ndim > 3: + dims_to_flatten = ndim - 3 + # we assume that the tensordict can be flattened like this + nelts = prod(tensordict_shaped.shape[: dims_to_flatten + 1]) + tensordict_shaped = tensordict_shaped.apply( + lambda value: value.flatten(0, dims_to_flatten), + batch_size=[nelts, tensordict_shaped.shape[-1]], + ) + else: + tensordict_shaped = tensordict.reshape(-1).unsqueeze(-1) + + is_init = tensordict_shaped.get("is_init").squeeze(-1) + splits = None + if self.temporal_mode and is_init[..., 1:].any(): + # if we have consecutive trajectories, things get a little more complicated + # we have a tensordict of shape [B, T] + # we will split / pad things such that we get a tensordict of shape + # [N, T'] where T' <= T and N >= B is the new batch size, such that + # each index of N is an independent trajectory. We'll need to keep + # track of the indices though, as we want to put things back together in the end. + splits = _get_num_per_traj_init(is_init) + tensordict_shaped_shape = tensordict_shaped.shape + tensordict_shaped = _split_and_pad_sequence( + tensordict_shaped.select(*self.in_keys, strict=False), splits + ) + is_init = tensordict_shaped.get("is_init").squeeze(-1) + + value, hidden0, hidden1 = ( + tensordict_shaped.get(key, default) + for key, default in zip(self.in_keys, defaults) + ) + batch, steps = value.shape[:2] + device = value.device + dtype = value.dtype + # packed sequences do not help to get the accurate last hidden values + # if splits is not None: + # value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True) + if is_init.any() and hidden0 is not None: + hidden0[is_init] = 0 + hidden1[is_init] = 0 + val, hidden0, hidden1 = self._lstm( + value, batch, steps, device, dtype, hidden0, hidden1 + ) + tensordict_shaped.set(self.out_keys[0], val) + tensordict_shaped.set(self.out_keys[1], hidden0) + tensordict_shaped.set(self.out_keys[2], hidden1) + if splits is not None: + # let's recover our original shape + tensordict_shaped = _inv_pad_sequence(tensordict_shaped, splits).reshape( + tensordict_shaped_shape + ) + + if shape != tensordict_shaped.shape or tensordict_shaped is not tensordict: + tensordict.update(tensordict_shaped.reshape(shape)) + return tensordict + + def _lstm( + self, + input: torch.Tensor, + batch, + steps, + device, + dtype, + hidden0_in: Optional[torch.Tensor] = None, + hidden1_in: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + if not self.temporal_mode and steps != 1: + raise ValueError("Expected a single step") + + if hidden1_in is None and hidden0_in is None: + shape = (batch, steps) + hidden0_in, hidden1_in = [ + torch.zeros( + *shape, + self.gru.num_layers, + self.gru.hidden_size, + device=device, + dtype=dtype, + ) + for _ in range(2) + ] + elif hidden1_in is None or hidden0_in is None: + raise 
RuntimeError( + f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}" + ) + + # we only need the first hidden state + _hidden0_in = hidden0_in[:, 0] + _hidden1_in = hidden1_in[:, 0] + hidden = ( + _hidden0_in.transpose(-3, -2).contiguous(), + _hidden1_in.transpose(-3, -2).contiguous(), + ) + + y, hidden = self.gru(input, hidden) + # dim 0 in hidden is num_layers, but that will conflict with tensordict + hidden = tuple(_h.transpose(0, 1) for _h in hidden) + + out = [y, *hidden] + # we pad the hidden states with zero to make tensordict happy + for i in range(1, 3): + out[i] = torch.stack( + [torch.zeros_like(out[i]) for _ in range(steps - 1)] + [out[i]], + 1, + ) + return tuple(out) From b053beebaf7af0127210ed6ed5b8e7b86491e5a2 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 4 Oct 2023 15:31:33 +0100 Subject: [PATCH 27/29] amend --- docs/source/reference/modules.rst | 1 + test/test_tensordictmodules.py | 258 ++++++++++++++++++ torchrl/modules/__init__.py | 1 + torchrl/modules/tensordict_module/__init__.py | 2 +- torchrl/modules/tensordict_module/rnn.py | 138 +++++----- 5 files changed, 336 insertions(+), 64 deletions(-) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 32f50244771..bf9ccab7d5f 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -332,6 +332,7 @@ algorithms, such as DQN, DDPG or Dreamer. DistributionalDQNnet DreamerActor DuelingCnnDQNet + GRUModule LSTMModule ObsDecoder ObsEncoder diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index ca1e0e46e57..bcf6fa76a3a 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -26,6 +26,7 @@ AdditiveGaussianWrapper, DecisionTransformerInferenceWrapper, DTActor, + GRUModule, LSTMModule, MLP, NormalParamWrapper, @@ -1822,6 +1823,263 @@ def create_transformed_env(): assert (data.get(("next", "recurrent_state_c")) != 0.0).all() +class TestGRUModule: + def test_errs(self): + with pytest.raises(ValueError, match="batch_first"): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=False, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + ) + with pytest.raises(ValueError, match="in_keys"): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=[ + "observation", + "hidden0", + "hidden1", + ], + out_keys=["intermediate", ("next", "hidden")], + ) + with pytest.raises(TypeError, match="incompatible function arguments"): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys="abc", + out_keys=["intermediate", ("next", "hidden")], + ) + with pytest.raises(ValueError, match="in_keys"): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_key="smth", + in_keys=["observation", "hidden0", "hidden1"], + out_keys=["intermediate", ("next", "hidden")], + ) + with pytest.raises(ValueError, match="out_keys"): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden"), "other"], + ) + with pytest.raises(TypeError, match="incompatible function arguments"): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys="abc", + ) + with pytest.raises(ValueError, match="out_keys"): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", 
"hidden"], + out_key="smth", + out_keys=["intermediate", ("next", "hidden"), "other"], + ) + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + ) + td = TensorDict({"observation": torch.randn(3)}, []) + with pytest.raises(KeyError, match="is_init"): + gru_module(td) + + def test_set_temporal_mode(self): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + ) + assert gru_module.set_recurrent_mode(False) is gru_module + assert not gru_module.set_recurrent_mode(False).temporal_mode + assert gru_module.set_recurrent_mode(True) is not gru_module + assert gru_module.set_recurrent_mode(True).temporal_mode + assert set(gru_module.set_recurrent_mode(True).parameters()) == set( + gru_module.parameters() + ) + + def test_noncontiguous(self): + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["bork", "h"], + out_keys=["dork", ("next", "h")], + ) + td = TensorDict( + { + "bork": torch.randn(3, 3), + "is_init": torch.zeros(3, 1, dtype=torch.bool), + }, + [3], + ) + padded = pad(td, [0, 5]) + gru_module(padded) + + @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]]) + def test_singel_step(self, shape): + td = TensorDict( + { + "observation": torch.zeros(*shape, 3), + "is_init": torch.zeros(*shape, 1, dtype=torch.bool), + }, + shape, + ) + gru_module = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + ) + td = gru_module(td) + td_next = step_mdp(td, keep_other=True) + td_next = gru_module(td_next) + + assert not torch.isclose(td_next["next", "hidden"], td["next", "hidden"]).any() + + @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]]) + @pytest.mark.parametrize("t", [1, 10]) + def test_single_step_vs_multi(self, shape, t): + td = TensorDict( + { + "observation": torch.arange(t, dtype=torch.float32) + .unsqueeze(-1) + .expand(*shape, t, 3), + "is_init": torch.zeros(*shape, t, 1, dtype=torch.bool), + }, + [*shape, t], + ) + gru_module_ss = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + ) + gru_module_ms = gru_module_ss.set_recurrent_mode() + gru_module_ms(td) + td_ss = TensorDict( + { + "observation": torch.zeros(*shape, 3), + "is_init": torch.zeros(*shape, 1, dtype=torch.bool), + }, + shape, + ) + for _t in range(t): + gru_module_ss(td_ss) + td_ss = step_mdp(td_ss, keep_other=True) + td_ss["observation"][:] = _t + 1 + torch.testing.assert_close(td_ss["hidden"], td["next", "hidden"][..., -1, :, :]) + + @pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]]) + def test_multi_consecutive(self, shape): + t = 20 + td = TensorDict( + { + "observation": torch.arange(t, dtype=torch.float32) + .unsqueeze(-1) + .expand(*shape, t, 3), + "is_init": torch.zeros(*shape, t, 1, dtype=torch.bool), + }, + [*shape, t], + ) + if shape: + td["is_init"][0, ..., 13, :] = True + else: + td["is_init"][13, :] = True + + gru_module_ss = GRUModule( + input_size=3, + hidden_size=12, + batch_first=True, + in_keys=["observation", "hidden"], + out_keys=["intermediate", ("next", "hidden")], + ) + gru_module_ms = gru_module_ss.set_recurrent_mode() + gru_module_ms(td) + td_ss = TensorDict( + { + "observation": torch.zeros(*shape, 3), + 
"is_init": torch.zeros(*shape, 1, dtype=torch.bool), + }, + shape, + ) + for _t in range(t): + td_ss["is_init"][:] = td["is_init"][..., _t, :] + gru_module_ss(td_ss) + td_ss = step_mdp(td_ss, keep_other=True) + td_ss["observation"][:] = _t + 1 + torch.testing.assert_close( + td_ss["intermediate"], td["intermediate"][..., -1, :] + ) + + def test_gru_parallel_env(self): + from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv + + # tests that hidden states are carried over with parallel envs + gru_module = GRUModule( + input_size=7, + hidden_size=12, + num_layers=2, + in_key="observation", + out_key="features", + ) + + def create_transformed_env(): + primer = gru_module.make_tensordict_primer() + env = DiscreteActionVecMockEnv(categorical_action_encoding=True) + env = TransformedEnv(env) + env.append_transform(InitTracker()) + env.append_transform(primer) + return env + + env = ParallelEnv( + create_env_fn=create_transformed_env, + num_workers=2, + ) + + mlp = TensorDictModule( + MLP( + in_features=12, + out_features=7, + num_cells=[], + ), + in_keys=["features"], + out_keys=["logits"], + ) + + actor_model = TensorDictSequential(gru_module, mlp) + + actor = ProbabilisticActor( + module=actor_model, + in_keys=["logits"], + out_keys=["action"], + distribution_class=torch.distributions.Categorical, + return_log_prob=True, + ) + for break_when_any_done in [False, True]: + data = env.rollout(10, actor, break_when_any_done=break_when_any_done) + assert (data.get("recurrent_state") != 0.0).any() + assert (data.get(("next", "recurrent_state")) != 0.0).all() + + def test_safe_specs(): out_key = ("a", "b") diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 604bb3bdca7..a4d69cd9cea 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -55,6 +55,7 @@ DistributionalQValueModule, EGreedyModule, EGreedyWrapper, + GRUModule, LMHeadActorValueOperator, LSTMModule, OrnsteinUhlenbeckProcessWrapper, diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index d1930855ab2..7605238f99a 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -31,6 +31,6 @@ SafeProbabilisticModule, SafeProbabilisticTensorDictSequential, ) -from .rnn import LSTMModule +from .rnn import GRUModule, LSTMModule from .sequence import SafeSequential from .world_models import WorldModelWrapper diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 7c7e8b4dc4c..aeff7f83e24 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -61,7 +61,6 @@ class LSTMModule(ModuleBase): dropout: If non-zero, introduces a `Dropout` layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 - proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0 Keyword Args: in_key (str or tuple of str): the input key of the module. Exclusive use @@ -89,12 +88,12 @@ class LSTMModule(ModuleBase): temporal_mode: Returns the temporal mode of the module. Methods: - set_temporal_mode: controls whether the module should be executed in + set_recurrent_mode: controls whether the module should be executed in temporal mode. 
Examples: >>> from torchrl.envs import TransformedEnv, InitTracker - >>> from torchrl.envs.libs.gym import GymEnv + >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod @@ -121,6 +120,8 @@ class LSTMModule(ModuleBase): device=cpu, is_shared=False), observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False) @@ -252,14 +253,14 @@ def set_recurrent_mode(self, mode: bool = True): Examples: >>> from torchrl.envs import TransformedEnv, InitTracker, step_mdp - >>> from torchrl.envs.libs.gym import GymEnv + >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP >>> from tensordict import TensorDict >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) >>> lstm = nn.LSTM(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True) - >>> lstm_module = LSTMModule(lstm, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")]) + >>> lstm_module = LSTMModule(lstm=lstm, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> # building two policies with different behaviours: >>> policy_inference = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) @@ -425,15 +426,15 @@ class GRUModule(ModuleBase): input_size: The number of expected features in the input `x` hidden_size: The number of features in the hidden state `h` num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` - would mean stacking two LSTMs together to form a `stacked LSTM`, - with the second LSTM taking in outputs of the first LSTM and + would mean stacking two GRUs together to form a `stacked GRU`, + with the second GRU taking in outputs of the first GRU and computing the final results. Default: 1 - bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + bias: If ``False``, then the layer does not use bias weights. Default: ``True`` dropout: If non-zero, introduces a `Dropout` layer on the outputs of each - LSTM layer except the last layer, with dropout probability equal to + GRU layer except the last layer, with dropout probability equal to :attr:`dropout`. Default: 0 - proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0 + proj_size: If ``> 0``, will use GRU with projections of corresponding size. Default: 0 Keyword Args: in_key (str or tuple of str): the input key of the module. Exclusive use @@ -460,12 +461,12 @@ class GRUModule(ModuleBase): temporal_mode: Returns the temporal mode of the module. Methods: - set_temporal_mode: controls whether the module should be executed in + set_recurrent_mode: controls whether the module should be executed in temporal mode. 
Examples: >>> from torchrl.envs import TransformedEnv, InitTracker - >>> from torchrl.envs.libs.gym import GymEnv + >>> from torchrl.envs import GymEnv >>> from torchrl.modules import MLP >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod @@ -486,14 +487,45 @@ class GRUModule(ModuleBase): is_init: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ - rs: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False), + rs: Tensor(shape=torch.Size([1, 64]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False), - observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=cpu, is_shared=False) + >>> gru_module_training = gru_module.set_recurrent_mode() + >>> policy_training = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) + >>> traj_td = env.rollout(3) # some random temporal data + >>> traj_td = policy_training(traj_td) + >>> print(traj_td) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + intermediate: Tensor(shape=torch.Size([3, 64]), device=cpu, dtype=torch.float32, is_shared=False), + is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + is_init: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False), + rs: Tensor(shape=torch.Size([3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]), + device=cpu, + is_shared=False) """ @@ -508,7 +540,6 @@ def __init__( bias: bool = True, batch_first=True, dropout=0, - proj_size=0, bidirectional=False, *, in_key=None, @@ -539,7 +570,6 @@ def __init__( num_layers=num_layers, bias=bias, dropout=dropout, - proj_size=proj_size, device=device, batch_first=True, bidirectional=False, @@ -616,19 +646,18 @@ def set_recurrent_mode(self, mode: bool = True): in various parts of the code (inference vs training): Examples: - >>> from torchrl.envs import TransformedEnv, InitTracker, step_mdp - >>> from torchrl.envs.libs.gym import GymEnv + >>> from torchrl.envs import GymEnv, TransformedEnv, InitTracker, step_mdp >>> from 
torchrl.modules import MLP >>> from tensordict import TensorDict >>> from torch import nn >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) - >>> lstm = nn.LSTM(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True) - >>> lstm_module = LSTMModule(lstm, in_keys=["observation", "hidden0", "hidden1"], out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")]) + >>> gru = nn.GRU(input_size=env.observation_spec["observation"].shape[-1], hidden_size=64, batch_first=True) + >>> gru_module = GRUModule(gru=gru, in_keys=["observation", "hidden"], out_keys=["intermediate", ("next", "hidden")]) >>> mlp = MLP(num_cells=[64], out_features=1) >>> # building two policies with different behaviours: - >>> policy_inference = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) - >>> policy_training = Seq(lstm_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) + >>> policy_inference = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) + >>> policy_training = Seq(gru_module.set_recurrent_mode(True), Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) >>> traj_td = env.rollout(3) # some random temporal data >>> traj_td = policy_training(traj_td) >>> # let's check that both return the same results @@ -638,17 +667,17 @@ def set_recurrent_mode(self, mode: bool = True): ... td_inf = policy_inference(td_inf) ... td_inf = step_mdp(td_inf) ... - >>> torch.testing.assert_close(td_inf["hidden0"], traj_td[..., -1]["next", "hidden0"]) + >>> torch.testing.assert_close(td_inf["hidden"], traj_td[..., -1]["next", "hidden"]) """ if mode is self._temporal_mode: return self - out = LSTMModule(lstm=self.gru, in_keys=self.in_keys, out_keys=self.out_keys) + out = GRUModule(gru=self.gru, in_keys=self.in_keys, out_keys=self.out_keys) out._temporal_mode = mode return out def forward(self, tensordict: TensorDictBase): # we want to get an error if the value input is missing, but not the hidden states - defaults = [NO_DEFAULT, None, None] + defaults = [NO_DEFAULT, None] shape = tensordict.shape tensordict_shaped = tensordict if self.temporal_mode: @@ -684,7 +713,7 @@ def forward(self, tensordict: TensorDictBase): ) is_init = tensordict_shaped.get("is_init").squeeze(-1) - value, hidden0, hidden1 = ( + value, hidden = ( tensordict_shaped.get(key, default) for key, default in zip(self.in_keys, defaults) ) @@ -694,15 +723,11 @@ def forward(self, tensordict: TensorDictBase): # packed sequences do not help to get the accurate last hidden values # if splits is not None: # value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True) - if is_init.any() and hidden0 is not None: - hidden0[is_init] = 0 - hidden1[is_init] = 0 - val, hidden0, hidden1 = self._lstm( - value, batch, steps, device, dtype, hidden0, hidden1 - ) + if is_init.any() and hidden is not None: + hidden[is_init] = 0 + val, hidden = self._lstm(value, batch, steps, device, dtype, hidden) tensordict_shaped.set(self.out_keys[0], val) - tensordict_shaped.set(self.out_keys[1], hidden0) - tensordict_shaped.set(self.out_keys[2], hidden1) + tensordict_shaped.set(self.out_keys[1], hidden) if splits is not None: # let's recover our original shape tensordict_shaped = _inv_pad_sequence(tensordict_shaped, splits).reshape( @@ -720,47 +745,34 @@ def _lstm( steps, device, dtype, - hidden0_in: Optional[torch.Tensor] = None, - hidden1_in: 
Optional[torch.Tensor] = None, + hidden_in: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if not self.temporal_mode and steps != 1: raise ValueError("Expected a single step") - if hidden1_in is None and hidden0_in is None: + if hidden_in is None: shape = (batch, steps) - hidden0_in, hidden1_in = [ - torch.zeros( - *shape, - self.gru.num_layers, - self.gru.hidden_size, - device=device, - dtype=dtype, - ) - for _ in range(2) - ] - elif hidden1_in is None or hidden0_in is None: - raise RuntimeError( - f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}" + hidden_in = torch.zeros( + *shape, + self.gru.num_layers, + self.gru.hidden_size, + device=device, + dtype=dtype, ) # we only need the first hidden state - _hidden0_in = hidden0_in[:, 0] - _hidden1_in = hidden1_in[:, 0] - hidden = ( - _hidden0_in.transpose(-3, -2).contiguous(), - _hidden1_in.transpose(-3, -2).contiguous(), - ) + _hidden_in = hidden_in[:, 0] + hidden = _hidden_in.transpose(-3, -2).contiguous() y, hidden = self.gru(input, hidden) # dim 0 in hidden is num_layers, but that will conflict with tensordict - hidden = tuple(_h.transpose(0, 1) for _h in hidden) + hidden = hidden.transpose(0, 1) - out = [y, *hidden] # we pad the hidden states with zero to make tensordict happy - for i in range(1, 3): - out[i] = torch.stack( - [torch.zeros_like(out[i]) for _ in range(steps - 1)] + [out[i]], - 1, - ) + hidden = torch.stack( + [torch.zeros_like(hidden) for _ in range(steps - 1)] + [hidden], + 1, + ) + out = [y, hidden] return tuple(out) From c40e3db78912c296989d8cff7ce24340ae29dd6f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 5 Oct 2023 08:52:19 +0100 Subject: [PATCH 28/29] Update test/mocking_classes.py --- test/mocking_classes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 650f8770e9e..d71a0b5cbb3 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -444,7 +444,7 @@ def __new__( action_spec = action_spec_cls(n=7, shape=batch_size) else: action_spec_cls = OneHotDiscreteTensorSpec - action_spec = action_spec_cls(n=7, shape=batch_size) + action_spec = action_spec_cls(n=7, shape=(*batch_size, 7)) if reward_spec is None: reward_spec = CompositeSpec( reward=UnboundedContinuousTensorSpec(shape=(1,)) From da7e173f3d2dff4537ffdc63dc76feb7888542cf Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 5 Oct 2023 09:45:47 +0100 Subject: [PATCH 29/29] amend --- test/test_tensordictmodules.py | 8 +- torchrl/modules/tensordict_module/rnn.py | 103 ++++++++++++++--------- 2 files changed, 66 insertions(+), 45 deletions(-) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index bcf6fa76a3a..4e1fbfcd1c1 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -1646,9 +1646,9 @@ def test_set_temporal_mode(self): out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")], ) assert lstm_module.set_recurrent_mode(False) is lstm_module - assert not lstm_module.set_recurrent_mode(False).temporal_mode + assert not lstm_module.set_recurrent_mode(False).recurrent_mode assert lstm_module.set_recurrent_mode(True) is not lstm_module - assert lstm_module.set_recurrent_mode(True).temporal_mode + assert lstm_module.set_recurrent_mode(True).recurrent_mode assert set(lstm_module.set_recurrent_mode(True).parameters()) == set( lstm_module.parameters() ) @@ -1907,9 +1907,9 @@ def test_set_temporal_mode(self): out_keys=["intermediate", ("next", 
"hidden")], ) assert gru_module.set_recurrent_mode(False) is gru_module - assert not gru_module.set_recurrent_mode(False).temporal_mode + assert not gru_module.set_recurrent_mode(False).recurrent_mode assert gru_module.set_recurrent_mode(True) is not gru_module - assert gru_module.set_recurrent_mode(True).temporal_mode + assert gru_module.set_recurrent_mode(True).recurrent_mode assert set(gru_module.set_recurrent_mode(True).parameters()) == set( gru_module.parameters() ) diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index aeff7f83e24..22be1432edf 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import warnings from typing import Optional, Tuple import torch @@ -35,10 +36,10 @@ class LSTMModule(ModuleBase): multi-step. This class enables both usages. - After construction, the module is *not* set in temporal mode, ie. it will + After construction, the module is *not* set in recurrent mode, ie. it will expect single steps inputs. - If in temporal mode, it is expected that the last dimension of the tensordict + If in recurrent mode, it is expected that the last dimension of the tensordict marks the number of steps. There is no constrain on the dimensionality of the tensordict (except that it must be greater than one for temporal inputs). @@ -85,11 +86,11 @@ class LSTMModule(ModuleBase): Exclusive with other nn.LSTM arguments. Attributes: - temporal_mode: Returns the temporal mode of the module. + recurrent_mode: Returns the recurrent mode of the module. Methods: set_recurrent_mode: controls whether the module should be executed in - temporal mode. + recurrent mode. Examples: >>> from torchrl.envs import TransformedEnv, InitTracker @@ -206,7 +207,7 @@ def __init__( in_keys = in_keys + ["is_init"] self.in_keys = in_keys self.out_keys = out_keys - self._temporal_mode = False + self._recurrent_mode = False def make_tensordict_primer(self): from torchrl.envs.transforms.transforms import TensorDictPrimer @@ -238,15 +239,25 @@ def make_tuple(key): ) @property - def temporal_mode(self): - return self._temporal_mode + def recurrent_mode(self): + return self._recurrent_mode + + @recurrent_mode.setter + def recurrent_mode(self, value): + raise RuntimeError( + "recurrent_mode cannot be changed in-place. Call `module.set" + ) - @temporal_mode.setter - def temporal_mode(self, value): - raise RuntimeError("temporal_mode cannot be changed in-place. Call `module.set") + @property + def temporal_mode(self): + warnings.warn( + "temporal_mode is deprecated, use recurrent_mode instead.", + category=DeprecationWarning, + ) + return self.recurrent_mode def set_recurrent_mode(self, mode: bool = True): - """Returns a new copy of the module that shares the same lstm model but with a different ``temporal_mode`` attribute (if it differs). + """Returns a new copy of the module that shares the same lstm model but with a different ``recurrent_mode`` attribute (if it differs). A copy is created such that the module can be used with divergent behaviour in various parts of the code (inference vs training): @@ -276,10 +287,10 @@ def set_recurrent_mode(self, mode: bool = True): ... 
@@ -287,7 +298,7 @@ def forward(self, tensordict: TensorDictBase):
         defaults = [NO_DEFAULT, None, None]
         shape = tensordict.shape
         tensordict_shaped = tensordict
-        if self.temporal_mode:
+        if self.recurrent_mode:
             # if less than 2 dims, unsqueeze
             ndim = tensordict_shaped.get(self.in_keys[0]).ndim
             while ndim < 3:
@@ -306,7 +317,7 @@ def forward(self, tensordict: TensorDictBase):
         is_init = tensordict_shaped.get("is_init").squeeze(-1)
         splits = None
-        if self.temporal_mode and is_init[..., 1:].any():
+        if self.recurrent_mode and is_init[..., 1:].any():
             # if we have consecutive trajectories, things get a little more complicated
             # we have a tensordict of shape [B, T]
             # we will split / pad things such that we get a tensordict of shape
@@ -360,7 +371,7 @@ def _lstm(
         hidden1_in: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
-        if not self.temporal_mode and steps != 1:
+        if not self.recurrent_mode and steps != 1:
             raise ValueError("Expected a single step")
 
         if hidden1_in is None and hidden0_in is None:
@@ -415,10 +426,10 @@ class GRUModule(ModuleBase):
     multi-step. This class enables both usages.
 
-    After construction, the module is *not* set in temporal mode, ie. it will
+    After construction, the module is *not* set in recurrent mode, i.e. it will
     expect single steps inputs.
 
-    If in temporal mode, it is expected that the last dimension of the tensordict
+    If in recurrent mode, it is expected that the last dimension of the tensordict
     marks the number of steps. There is no constrain on the dimensionality of the
     tensordict (except that it must be greater than one for temporal inputs).
@@ -458,11 +469,11 @@ class GRUModule(ModuleBase):
             Exclusive with other nn.GRU arguments.
 
     Attributes:
-        temporal_mode: Returns the temporal mode of the module.
+        recurrent_mode: Returns the recurrent mode of the module.
 
     Methods:
         set_recurrent_mode: controls whether the module should be executed in
-            temporal mode.
+            recurrent mode.
 
     Examples:
         >>> from torchrl.envs import TransformedEnv, InitTracker
@@ -552,18 +563,18 @@ def __init__(
         super().__init__()
         if gru is not None:
             if not gru.batch_first:
-                raise ValueError("The input lstm must have batch_first=True.")
+                raise ValueError("The input gru must have batch_first=True.")
             if gru.bidirectional:
-                raise ValueError("The input lstm cannot be bidirectional.")
+                raise ValueError("The input gru cannot be bidirectional.")
             if input_size is not None or hidden_size is not None:
                 raise ValueError(
-                    "An LSTM instance cannot be passed along with class argument."
+                    "A GRU instance cannot be passed along with class arguments."
                 )
         else:
             if not batch_first:
-                raise ValueError("The input lstm must have batch_first=True.")
+                raise ValueError("The input gru must have batch_first=True.")
             if bidirectional:
-                raise ValueError("The input lstm cannot be bidirectional.")
+                raise ValueError("The input gru cannot be bidirectional.")
             gru = nn.GRU(
                 input_size=input_size,
                 hidden_size=hidden_size,
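# A small sketch (not part of the patch) of the two mutually exclusive
# construction paths validated above: either hand in a ready-made, batch-first,
# unidirectional nn.GRU, or let the module build one from input_size/hidden_size.
# Key names follow the test file earlier in this patch:
import torch.nn as nn
from torchrl.modules import GRUModule

# path 1: the module builds the nn.GRU from class arguments
gru_module = GRUModule(
    input_size=3,
    hidden_size=12,
    batch_first=True,
    in_keys=["observation", "hidden"],
    out_keys=["intermediate", ("next", "hidden")],
)

# path 2: a pre-built nn.GRU; the class arguments must then be omitted
gru = nn.GRU(input_size=3, hidden_size=12, batch_first=True)
gru_module_bis = GRUModule(
    gru=gru,
    in_keys=["observation", "hidden"],
    out_keys=["intermediate", ("next", "hidden")],
)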
@@ -594,18 +605,18 @@ def __init__(
             len(in_keys) != 2 and not (len(in_keys) == 3 and in_keys[-1] == "is_init")
         ):
             raise ValueError(
-                f"LSTMModule expects 3 inputs: a value, and two hidden states (and potentially an 'is_init' marker). Got in_keys {in_keys} instead."
+                f"GRUModule expects 2 inputs: a value, and a hidden state (and potentially an 'is_init' marker). Got in_keys {in_keys} instead."
             )
         if not isinstance(out_keys, (tuple, list)) or len(out_keys) != 2:
             raise ValueError(
-                f"LSTMModule expects 3 outputs: a value, and two hidden states. Got out_keys {out_keys} instead."
+                f"GRUModule expects 2 outputs: a value, and a hidden state. Got out_keys {out_keys} instead."
             )
         self.gru = gru
         if "is_init" not in in_keys:
             in_keys = in_keys + ["is_init"]
         self.in_keys = in_keys
         self.out_keys = out_keys
-        self._temporal_mode = False
+        self._recurrent_mode = False
 
     def make_tensordict_primer(self):
         from torchrl.envs import TensorDictPrimer
@@ -632,15 +643,25 @@ def make_tuple(key):
         )
 
     @property
-    def temporal_mode(self):
-        return self._temporal_mode
+    def recurrent_mode(self):
+        return self._recurrent_mode
+
+    @recurrent_mode.setter
+    def recurrent_mode(self, value):
+        raise RuntimeError(
+            "recurrent_mode cannot be changed in-place. Call `module.set_recurrent_mode(value)` instead."
+        )
 
-    @temporal_mode.setter
-    def temporal_mode(self, value):
-        raise RuntimeError("temporal_mode cannot be changed in-place. Call `module.set")
+    @property
+    def temporal_mode(self):
+        warnings.warn(
+            "temporal_mode is deprecated, use recurrent_mode instead.",
+            category=DeprecationWarning,
+        )
+        return self.recurrent_mode
 
     def set_recurrent_mode(self, mode: bool = True):
-        """Returns a new copy of the module that shares the same lstm model but with a different ``temporal_mode`` attribute (if it differs).
+        """Returns a new copy of the module that shares the same gru model but with a different ``recurrent_mode`` attribute (if it differs).
 
         A copy is created such that the module can be used with divergent behaviour
         in various parts of the code (inference vs training):
@@ -669,10 +690,10 @@ def set_recurrent_mode(self, mode: bool = True):
         ...
         >>> torch.testing.assert_close(td_inf["hidden"], traj_td[..., -1]["next", "hidden"])
         """
-        if mode is self._temporal_mode:
+        if mode is self._recurrent_mode:
             return self
         out = GRUModule(gru=self.gru, in_keys=self.in_keys, out_keys=self.out_keys)
-        out._temporal_mode = mode
+        out._recurrent_mode = mode
         return out
 
     def forward(self, tensordict: TensorDictBase):
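# As with the LSTM variant above, the copy returned by set_recurrent_mode shares
# its parameters with the original, so a single-step module can collect rollouts
# while a recurrent copy trains on whole trajectories. A hedged sketch (not part
# of the patch), reusing the construction shown earlier:
from torchrl.modules import GRUModule

gru_module = GRUModule(
    input_size=3,
    hidden_size=12,
    batch_first=True,
    in_keys=["observation", "hidden"],
    out_keys=["intermediate", ("next", "hidden")],
)
inference_module = gru_module  # recurrent_mode is False after construction
training_module = gru_module.set_recurrent_mode(True)
assert not inference_module.recurrent_mode and training_module.recurrent_mode
assert set(training_module.parameters()) == set(inference_module.parameters())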
@@ -680,7 +701,7 @@ def forward(self, tensordict: TensorDictBase):
         defaults = [NO_DEFAULT, None]
         shape = tensordict.shape
         tensordict_shaped = tensordict
-        if self.temporal_mode:
+        if self.recurrent_mode:
             # if less than 2 dims, unsqueeze
             ndim = tensordict_shaped.get(self.in_keys[0]).ndim
             while ndim < 3:
@@ -699,7 +720,7 @@ def forward(self, tensordict: TensorDictBase):
         is_init = tensordict_shaped.get("is_init").squeeze(-1)
         splits = None
-        if self.temporal_mode and is_init[..., 1:].any():
+        if self.recurrent_mode and is_init[..., 1:].any():
             # if we have consecutive trajectories, things get a little more complicated
             # we have a tensordict of shape [B, T]
             # we will split / pad things such that we get a tensordict of shape
@@ -725,7 +746,7 @@ def forward(self, tensordict: TensorDictBase):
         # value = torch.nn.utils.rnn.pack_padded_sequence(value, splits, batch_first=True)
         if is_init.any() and hidden is not None:
             hidden[is_init] = 0
-        val, hidden = self._lstm(value, batch, steps, device, dtype, hidden)
+        val, hidden = self._gru(value, batch, steps, device, dtype, hidden)
         tensordict_shaped.set(self.out_keys[0], val)
         tensordict_shaped.set(self.out_keys[1], hidden)
         if splits is not None:
@@ -738,7 +759,7 @@ def forward(self, tensordict: TensorDictBase):
         tensordict.update(tensordict_shaped.reshape(shape))
         return tensordict
 
-    def _lstm(
+    def _gru(
         self,
         input: torch.Tensor,
         batch,
@@ -748,7 +769,7 @@ def _lstm(
         hidden_in: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
-        if not self.temporal_mode and steps != 1:
+        if not self.recurrent_mode and steps != 1:
             raise ValueError("Expected a single step")
 
         if hidden_in is None:
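# A standalone sketch (not part of the patch) of the default initialisation
# _gru performs when no hidden state is provided, as in the hunk above; the
# module and shape names below are assumptions for illustration:
import torch
import torch.nn as nn

gru = nn.GRU(input_size=3, hidden_size=12, num_layers=2, batch_first=True)
batch, steps = 4, 1
inp = torch.randn(batch, steps, 3)

hidden_in = None
if hidden_in is None:
    # zero hidden state stored as [batch, steps, num_layers, hidden_size],
    # the tensordict-friendly layout used throughout this patch
    hidden_in = torch.zeros(batch, steps, gru.num_layers, gru.hidden_size)

# nn.GRU expects [num_layers, batch, hidden_size]: keep the first step only
hidden = hidden_in[:, 0].transpose(-3, -2).contiguous()
y, hidden = gru(inp, hidden)
assert y.shape == (batch, steps, 12)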