From 20d09bc564695057d70327c32fb0064407a4b7f0 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 12 Sep 2023 09:56:42 -0400 Subject: [PATCH 01/16] init --- test/test_libs.py | 51 +++++++++++++++++++++++ torchrl/envs/common.py | 17 ++++++-- torchrl/envs/gym_like.py | 2 +- torchrl/envs/libs/gym.py | 89 ++++++++++++++++++++++++++++++---------- 4 files changed, 133 insertions(+), 26 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 2a44e5a70bc..d37c0887653 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -58,6 +58,7 @@ GymWrapper, MOGymEnv, MOGymWrapper, + set_gym_backend, ) from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv @@ -312,6 +313,56 @@ def test_one_hot_and_categorical(self): # noqa: F811 # versions. return + @implement_for("gymnasium", "0.27.0", None) + def test_vecenvs(self): + import gymnasium + + # we can't use parametrize with implement_for + for envname in ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"]: + env = GymWrapper( + gymnasium.vector.SyncVectorEnv(2 * [lambda envname=envname: gymnasium.make(envname)]) + ) + assert env.batch_size == torch.Size([2]) + check_env_specs(env) + env = GymWrapper( + gymnasium.vector.AsyncVectorEnv(2 * [lambda envname=envname: gymnasium.make(envname)]) + ) + assert env.batch_size == torch.Size([2]) + check_env_specs(env) + with set_gym_backend("gymnasium"): + env = GymEnv(envname, num_envs=2, from_pixels=False) + check_env_specs(env) + # with set_gym_backend("gymnasium"): + # env = GymEnv(envname, num_envs=2, from_pixels=True) + # check_env_specs(env) + + @implement_for("gym", "0.24", "0.27.0") + def test_vecenvs(self): # noqa: F811 + import gymnasium + + # we can't use parametrize with implement_for + for envname in ["CartPole-v1", "HalfCheetah-v4"]: + env = GymWrapper( + gymnasium.vector.SyncVectorEnv( + 2 * [lambda envname=envname: gymnasium.make(envname)] + ) + ) + assert env.batch_size == torch.Size([2]) + check_env_specs(env) + env = GymWrapper( + gymnasium.vector.aSyncVectorEnv( + 2 * [lambda envname=envname: gymnasium.make(envname)] + ) + ) + assert env.batch_size == torch.Size([2]) + check_env_specs(env) + with set_gym_backend("gym"): + env = GymEnv(envname, num_envs=2, from_pixels=False) + check_env_specs(env) + with set_gym_backend("gym"): + env = GymEnv(envname, num_envs=2, from_pixels=True) + check_env_specs(env) + @implement_for("gym", None, "0.26") def _make_gym_environment(env_name): # noqa: F811 diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 5ecdf148238..02fe161799a 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -615,7 +615,7 @@ def action_spec(self, value: TensorSpec) -> None: ) if value.shape[: len(self.batch_size)] != self.batch_size: raise ValueError( - "The value of spec.shape must match the env batch size." + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) if isinstance(value, CompositeSpec): @@ -789,7 +789,7 @@ def reward_spec(self, value: TensorSpec) -> None: ) if value.shape[: len(self.batch_size)] != self.batch_size: raise ValueError( - "The value of spec.shape must match the env batch size." + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) if isinstance(value, CompositeSpec): for _ in value.values(True, True): # noqa: B007 @@ -965,7 +965,7 @@ def done_spec(self, value: TensorSpec) -> None: ) if value.shape[: len(self.batch_size)] != self.batch_size: raise ValueError( - "The value of spec.shape must match the env batch size." + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) if isinstance(value, CompositeSpec): for _ in value.values(True, True): # noqa: B007 @@ -1736,10 +1736,21 @@ def __init__( self._constructor_kwargs = kwargs self._check_kwargs(kwargs) self._env = self._build_env(**kwargs) # writes the self._env attribute + if self.batch_size in (None, torch.Size([])): + self.__dict__["_batch_size"] = self._get_batch_size(self._env) self._make_specs(self._env) # writes the self._env attribute self.is_closed = False self._init_env() # runs all the steps to have a ready-to-use env + def _get_batch_size(self, env): + """Batch-size adjustment. + + This is executed after super().__init__(), ie. when the batch-size has been set. + By default, it is a no-op. For some envs (batched envs) we adapt the batch-size + according to the number of sub-envs. See GymWrapper._get_batch_size for an example. + """ + return self.batch_size + @abc.abstractmethod def _check_kwargs(self, kwargs: Dict): raise NotImplementedError diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 289bb731278..df774ab9fec 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -144,7 +144,7 @@ def read_done(self, done): done (np.ndarray, boolean or other format): done state obtained from the environment """ - return done, done + return done, done.any() if not isinstance(done, bool) else done def read_reward(self, reward): """Reads the reward and maps it to the reward space. diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 8e58c0915a6..27a549f08e4 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -10,6 +10,7 @@ from warnings import warn import torch +from torchrl.envs.vec_env import CloudpickleWrapper try: from torch.utils._contextlib import _DecoratorContextManager @@ -22,8 +23,6 @@ BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, - MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, @@ -198,7 +197,18 @@ def _gym_to_torchrl_spec_transform( """ gym = gym_backend() if isinstance(spec, gym.spaces.tuple.Tuple): - raise NotImplementedError("gym.spaces.tuple.Tuple mapping not yet implemented") + return torch.stack( + [ + _gym_to_torchrl_spec_transform( + s, + device=device, + categorical_action_encoding=categorical_action_encoding, + remap_state_to_observation=remap_state_to_observation, + ) + for s in spec + ], + 0, + ) if isinstance(spec, gym.spaces.discrete.Discrete): action_space_cls = ( DiscreteTensorSpec @@ -216,16 +226,28 @@ def _gym_to_torchrl_spec_transform( spec.n, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype] ) elif isinstance(spec, gym.spaces.multi_discrete.MultiDiscrete): - dtype = ( - numpy_to_torch_dtype_dict[spec.dtype] - if categorical_action_encoding - else torch.long - ) - return ( - MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) - if categorical_action_encoding - else MultiOneHotDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + # dtype = ( + # numpy_to_torch_dtype_dict[spec.dtype] + # if categorical_action_encoding + # else torch.long + # ) + return torch.stack( + [ + _gym_to_torchrl_spec_transform( + spec[i], + device=device, + categorical_action_encoding=categorical_action_encoding, + remap_state_to_observation=remap_state_to_observation, + ) + for i in range(len(spec.nvec)) + ], + 0, ) + # return ( + # MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + # if categorical_action_encoding + # else MultiOneHotDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + # ) elif isinstance(spec, gym.spaces.Box): shape = spec.shape if not len(shape): @@ -387,6 +409,13 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs): else: super().__init__(**kwargs) + def _get_batch_size(self, env): + if hasattr(env, "num_envs"): + batch_size = torch.Size([env.num_envs, *self.batch_size]) + else: + batch_size = self.batch_size + return batch_size + def _check_kwargs(self, kwargs: Dict): if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") @@ -551,9 +580,13 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 ) if not isinstance(observation_spec, CompositeSpec): if self.from_pixels: - observation_spec = CompositeSpec(pixels=observation_spec) + observation_spec = CompositeSpec( + pixels=observation_spec, shape=self.batch_size + ) else: - observation_spec = CompositeSpec(observation=observation_spec) + observation_spec = CompositeSpec( + observation=observation_spec, shape=self.batch_size + ) if hasattr(env, "reward_space") and env.reward_space is not None: reward_spec = _gym_to_torchrl_spec_transform( env.reward_space, @@ -572,7 +605,10 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 *batch_size, *observation_spec.shape ) self.action_spec = action_spec - self.reward_spec = reward_spec + if reward_spec.shape[: len(self.batch_size)] != self.batch_size: + self.reward_spec = reward_spec.expand(*self.batch_size, *reward_spec.shape) + else: + self.reward_spec = reward_spec self.observation_spec = observation_spec def _init_env(self): @@ -643,6 +679,14 @@ def _set_gym_args( # noqa: F811 ) -> None: kwargs.setdefault("disable_env_checker", True) + @implement_for("gym", None, "0.27") + def _async_env(self, *args, **kwargs): + return gym_backend("vector").aSyncVectorEnv(*args, **kwargs) + + @implement_for("gymnasium", "0.27", None) + def _async_env(self, *args, **kwargs): # noqa: F811 + return gym_backend("vector").AsyncVectorEnv(*args, **kwargs) + def _build_env( self, env_name: str, @@ -654,13 +698,10 @@ def _build_env( f"Consider downloading and installing gym from" f" {self.git_url}" ) - from_pixels = kwargs.get("from_pixels", False) + from_pixels = kwargs.pop("from_pixels", False) self._set_gym_default(kwargs, from_pixels) - if "from_pixels" in kwargs: - del kwargs["from_pixels"] - pixels_only = kwargs.get("pixels_only", True) - if "pixels_only" in kwargs: - del kwargs["pixels_only"] + pixels_only = kwargs.pop("pixels_only", True) + num_envs = kwargs.pop("num_envs", 0) made_env = False kwargs["frameskip"] = self.frame_skip self.wrapper_frame_skip = 1 @@ -687,7 +728,11 @@ def _build_env( kwargs.pop("render_mode") else: raise err - return super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels) + env = super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels) + if num_envs > 0: + return self._async_env([CloudpickleWrapper(lambda: env)] * num_envs) + else: + return env @implement_for("gym", None, "0.25.1") def _set_gym_default(self, kwargs, from_pixels: bool) -> None: # noqa: F811 From 0b0d8df0a3c9634c6b82a26193de129e20d787b8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 12 Sep 2023 10:51:27 -0400 Subject: [PATCH 02/16] amend --- torchrl/envs/libs/gym.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 27a549f08e4..65afab82b7d 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -9,6 +9,7 @@ from typing import Dict, List from warnings import warn +import numpy as np import torch from torchrl.envs.vec_env import CloudpickleWrapper @@ -25,7 +26,8 @@ DiscreteTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, - UnboundedContinuousTensorSpec, + UnboundedContinuousTensorSpec, MultiDiscreteTensorSpec, + MultiOneHotDiscreteTensorSpec, ) from torchrl.data.utils import numpy_to_torch_dtype_dict @@ -226,11 +228,18 @@ def _gym_to_torchrl_spec_transform( spec.n, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype] ) elif isinstance(spec, gym.spaces.multi_discrete.MultiDiscrete): - # dtype = ( - # numpy_to_torch_dtype_dict[spec.dtype] - # if categorical_action_encoding - # else torch.long - # ) + if len(spec.nvec.shape) == 1 and len(np.unique(spec.nvec)) > 1: + dtype = ( + numpy_to_torch_dtype_dict[spec.dtype] + if categorical_action_encoding + else torch.long + ) + return ( + MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + if categorical_action_encoding + else MultiOneHotDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + ) + return torch.stack( [ _gym_to_torchrl_spec_transform( @@ -243,11 +252,6 @@ def _gym_to_torchrl_spec_transform( ], 0, ) - # return ( - # MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) - # if categorical_action_encoding - # else MultiOneHotDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) - # ) elif isinstance(spec, gym.spaces.Box): shape = spec.shape if not len(shape): From c79830792445757cf7b7d2f7c2813d44d4b2dafc Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 13 Sep 2023 13:05:10 -0400 Subject: [PATCH 03/16] amend --- torchrl/envs/gym_like.py | 26 +++++--- torchrl/envs/libs/gym.py | 85 ++++++++++++++++++++++++--- torchrl/envs/transforms/transforms.py | 61 ++++++++++++++++++- 3 files changed, 153 insertions(+), 19 deletions(-) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index df774ab9fec..08acbfe2514 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -115,11 +115,11 @@ class GymLikeEnv(_EnvWrapper): It is also expected that env.reset() returns an observation similar to the one observed after a step is completed. """ - _info_dict_reader: BaseInfoDictReader + _info_dict_reader: List[BaseInfoDictReader] @classmethod def __new__(cls, *args, **kwargs): - cls._info_dict_reader = None + cls._info_dict_reader = [] return super().__new__(cls, *args, _batch_locked=True, **kwargs) def read_action(self, action): @@ -231,8 +231,11 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_out = TensorDict(obs_dict, batch_size=tensordict.batch_size) - if self.info_dict_reader is not None and info is not None: - self.info_dict_reader(info, tensordict_out) + if self.info_dict_reader and info is not None: + for info_dict_reader in self.info_dict_reader: + out = info_dict_reader(info, tensordict_out) + if out is not None: + tensordict_out = out tensordict_out = tensordict_out.to(self.device, non_blocking=True) return tensordict_out @@ -255,9 +258,12 @@ def _reset( source=source, batch_size=self.batch_size, ) - if self.info_dict_reader is not None and info is not None: - self.info_dict_reader(info, tensordict_out) - elif info is None and self.info_dict_reader is not None: + if self.info_dict_reader and info is not None: + for info_dict_reader in self.info_dict_reader: + out = info_dict_reader(info, tensordict_out) + if out is not None: + tensordict_out = out + elif info is None and self.info_dict_reader: # populate the reset with the items we have not seen from info for key, item in self.observation_spec.items(True, True): if key not in tensordict_out.keys(True, True): @@ -298,7 +304,7 @@ def set_info_dict_reader(self, info_dict_reader: BaseInfoDictReader) -> GymLikeE >>> assert "my_info_key" in tensordict.keys() """ - self.info_dict_reader = info_dict_reader + self.info_dict_reader.append(info_dict_reader) for info_key, spec in info_dict_reader.info_spec.items(): self.observation_spec[info_key] = spec.to(self.device) return self @@ -314,4 +320,6 @@ def info_dict_reader(self): @info_dict_reader.setter def info_dict_reader(self, value: callable): - self._info_dict_reader = value + warnings.warn(f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. Setting info_dict_reader directly will be soon deprecated.") + self._info_dict_reader = [value] + diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 65afab82b7d..0c85a93bf0c 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -2,15 +2,18 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import abc import importlib import warnings from copy import copy from types import ModuleType -from typing import Dict, List +from typing import Dict, List, Optional from warnings import warn import numpy as np import torch + +from tensordict import TensorDictBase from torchrl.envs.vec_env import CloudpickleWrapper try: @@ -31,7 +34,8 @@ ) from torchrl.data.utils import numpy_to_torch_dtype_dict -from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv +from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv, \ + BaseInfoDictReader from torchrl.envs.utils import _classproperty DEFAULT_GYM = None @@ -42,7 +46,7 @@ _has_gym = importlib.util.find_spec("gymnasium") is not None _has_mo = importlib.util.find_spec("mo_gymnasium") is not None - +_has_sb3 = importlib.util.find_spec("stable_baselines3") is not None class set_gym_backend(_DecoratorContextManager): """Sets the gym-backend to a certain value. @@ -366,7 +370,24 @@ class PixelObservationWrapper: return False -class GymWrapper(GymLikeEnv): +class _AsyncMeta(abc.ABCMeta): + def __call__(cls, *args, **kwargs): + instance: GymWrapper = super().__call__(*args, **kwargs) + if instance._is_batched: + from torchrl.envs.transforms.transforms import TransformedEnv, VecGymEnvTransform + if _has_sb3: + from stable_baselines3.common.vec_env.base_vec_env import VecEnv + if isinstance(instance._env, VecEnv): + backend = "sb3" + else: + backend = "gym" + else: + backend = "gym" + instance.set_info_dict_reader(terminal_obs_reader(instance.observation_spec, backend=backend)) + return TransformedEnv(instance, VecGymEnvTransform()) + return instance + +class GymWrapper(GymLikeEnv, metaclass=_AsyncMeta): """OpenAI Gym environment wrapper. Examples: @@ -413,6 +434,15 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs): else: super().__init__(**kwargs) + @property + def _is_batched(self): + if _has_sb3: + from stable_baselines3.common.vec_env.base_vec_env import VecEnv + tuple_of_classes = (VecEnv,) + else: + tuple_of_classes = () + return isinstance(self._env, tuple_of_classes + (gym_backend("vector").VectorEnv,)) + def _get_batch_size(self, env): if hasattr(env, "num_envs"): batch_size = torch.Size([env.num_envs, *self.batch_size]) @@ -638,6 +668,18 @@ def info_dict_reader(self): def info_dict_reader(self, value: callable): self._info_dict_reader = value + def _reset( + self, tensordict: Optional[TensorDictBase] = None, **kwargs + ) -> TensorDictBase: + if self._is_batched: + if tensordict is None: + return super()._reset(tensordict) + reset = tensordict.get("_reset", None) + if reset is None or reset.all(): + return super()._reset(tensordict) + elif reset is not None: + return tensordict.clone(False) + return super()._reset(tensordict) ACCEPTED_TYPE_ERRORS = { "render_mode": "__init__() got an unexpected keyword argument 'render_mode'", @@ -683,12 +725,7 @@ def _set_gym_args( # noqa: F811 ) -> None: kwargs.setdefault("disable_env_checker", True) - @implement_for("gym", None, "0.27") def _async_env(self, *args, **kwargs): - return gym_backend("vector").aSyncVectorEnv(*args, **kwargs) - - @implement_for("gymnasium", "0.27", None) - def _async_env(self, *args, **kwargs): # noqa: F811 return gym_backend("vector").AsyncVectorEnv(*args, **kwargs) def _build_env( @@ -810,3 +847,33 @@ def lib(self) -> ModuleType: raise ImportError("MO-gymnasium not found, check installation") from err _make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs) + + +class terminal_obs_reader(BaseInfoDictReader): + backend_key = { + "sb3": "terminal_observation", + "gym": "final_observation", + } + def __init__(self, observation_spec: CompositeSpec, backend): + self._info_spec = CompositeSpec( + {("final", key): item.clone() for key, item in observation_spec.items()}, shape=observation_spec.shape + ) + self.backend = backend + + @property + def info_spec(self): + return self._info_spec + + def __call__(self, info_dict, tensordict): + terminal_obs = info_dict.get(self.backend_key[self.backend], None) + for key, item in self.info_spec.items(True, True): + final_obs = item.zero() + break + else: + raise RuntimeError("The info spec cannot be empty.") + if terminal_obs is not None: + for i, obs in enumerate(terminal_obs): + if obs is not None: + final_obs[i] = torch.as_tensor(obs, device=final_obs.device) + tensordict.set(key, final_obs) + return tensordict diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index e37e079e63b..0a6d8ff7f3a 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -655,7 +655,7 @@ def _set_seed(self, seed: Optional[int]): def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): if tensordict is not None: tensordict = tensordict.clone(recurse=False) - out_tensordict = self.base_env.reset(tensordict=tensordict, **kwargs) + out_tensordict = self.base_env._reset(tensordict=tensordict, **kwargs) out_tensordict = self.transform.reset(out_tensordict) mt_mode = self.transform.missing_tolerance @@ -5030,3 +5030,62 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: ) action_spec.update_mask(tensordict.get(self.in_keys[1], None)) return tensordict + + +class VecGymEnvTransform(Transform): + """A transform for GymWrapper subclasses that handles the auto-reset in a consistent way. + + Gym, gymnasium and SB3 provide vectorized (read, parallel or batched) environments + that are automatically reset. When this occur, the actual observation resulting + from the action is saved within a key in the info. + The class :class:`torchrl.envs.libs.gym.terminal_obs_reader` reads that observation + and stores it in a ``"final"`` key within the output tensordict. + In turn, this transform reads that final data, swaps it with the observation + written in its place that results from the actual reset, and saves the + reset output in a private container. The resulting data truly reflects + the output of the step. + + Then, when calling `env.reset`, the saved data is written back where it belongs + (and the `reset` is a no-op). + + This transform is automatically appended to the gym env whenever the wrapper + is created with an async env. + """ + def __init__(self): + super().__init__(in_keys=[]) + self._memo = {} + + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: + # save the final info + done = self._memo['done'] = next_tensordict.get("done") + final = next_tensordict.pop("final") + # if anything's done, we need to swap the final obs + if done.any(): + done = done.squeeze(-1) + saved_next = next_tensordict.select(*final.keys(True, True))[done].clone() + next_tensordict[done] = final[done] + self._memo['saved_done'] = saved_next + else: + self._memo['saved_done'] = None + return next_tensordict + + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: + done = self._memo.get("done", None) + reset = tensordict.get("_reset", done) + if done is not None: + done = done.view_as(reset) + if reset is not done and (reset != done).any() and (not reset.all() or not reset.any()): + raise RuntimeError("Cannot partially reset a gym(nasium) async env with a reset mask that does not match the done mask. " + f"Got reset={reset}\nand done={done}") + # if not reset.any(), we don't need to do anything. + # if reset.all(), we don't either (bc GymWrapper will call a plain reset). + if reset is not None and reset.any() and not reset.all(): + saved_done = self._memo['saved_done'] + reset = reset.view(tensordict.shape) + updated_td = torch.where(~reset, tensordict.select(*saved_done.keys(True, True)), saved_done) + tensordict.update(updated_td) + tensordict.set("done", tensordict.get("done").clone().fill_(0)) + tensordict.pop("final", None) + return tensordict From 2c178a52b9307943025125477d50618a5e13ae05 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 14 Sep 2023 05:58:39 -0400 Subject: [PATCH 04/16] amend --- .../linux_libs/scripts_gym/batch_scripts.sh | 1 + docs/source/reference/envs.rst | 2 + test/_utils_internal.py | 26 ++++ test/test_libs.py | 62 ++++++---- torchrl/envs/gym_like.py | 6 +- torchrl/envs/libs/gym.py | 115 +++++++++++++++--- torchrl/envs/libs/robohive.py | 3 +- torchrl/envs/transforms/transforms.py | 46 +++++-- 8 files changed, 207 insertions(+), 54 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh index ee1145d3f93..3fadedb0ffd 100755 --- a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh +++ b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh @@ -155,6 +155,7 @@ do pip install gymnasium[atari] fi pip install mo-gymnasium + pip install gymnasium-robotics $DIR/run_test.sh diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index f6e6c536ce5..c48e8638d11 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -484,6 +484,7 @@ to be able to create this other composition: TimeMaxPool ToTensorImage UnsqueezeTransform + VecGymEnvTransform VecNorm VC1Transform VIPRewardTransform @@ -618,6 +619,7 @@ the following function will return ``1`` when queried: dm_control.DMControlWrapper gym.GymEnv gym.GymWrapper + gym.terminal_obs_reader gym.MOGymEnv gym.MOGymWrapper gym.set_gym_backend diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 85e76790c26..3a4d3eaf333 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -320,3 +320,29 @@ class MyClass: for key in td.keys(): MyClass.__annotations__[key] = torch.Tensor return tensorclass(MyClass) + + +def rollout_consistency_assertion( + rollout, *, done_key="done", observation_key="observation" +): + """Tests that observations in "next" match observations in the next root tensordict when done is False, and don't match otherwise.""" + + done = rollout[:, :-1]["next", done_key].squeeze(-1) + # data resulting from step, when it's not done + r_not_done = rollout[:, :-1]["next"][~done] + # data resulting from step, when it's not done, after step_mdp + r_not_done_tp1 = rollout[:, 1:][~done] + torch.testing.assert_close( + r_not_done[observation_key], r_not_done_tp1[observation_key] + ) + + if not done.any(): + return + + # data resulting from step, when it's done + r_done = rollout[:, :-1]["next"][done] + # data resulting from step, when it's done, after step_mdp and reset + r_done_tp1 = rollout[:, 1:][done] + assert ( + (r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) > 1e-1 + ).all(), (r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) diff --git a/test/test_libs.py b/test/test_libs.py index d37c0887653..273dfc0b825 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -75,6 +75,7 @@ _has_sklearn = importlib.util.find_spec("sklearn") is not None +_has_gym_robotics = importlib.util.find_spec("gymnasium_robotics") is not None if _has_gym: try: @@ -314,44 +315,57 @@ def test_one_hot_and_categorical(self): # noqa: F811 return @implement_for("gymnasium", "0.27.0", None) - def test_vecenvs(self): + # this env has Dict-based observation which is a nice thing to test + @pytest.mark.parametrize( + "envname", + ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"] + + (["FetchReach-v2"] if _has_gym_robotics else []), + ) + def test_vecenvs(self, envname): import gymnasium + from _utils_internal import rollout_consistency_assertion # we can't use parametrize with implement_for - for envname in ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"]: - env = GymWrapper( - gymnasium.vector.SyncVectorEnv(2 * [lambda envname=envname: gymnasium.make(envname)]) + env = GymWrapper( + gymnasium.vector.SyncVectorEnv( + 2 * [lambda envname=envname: gymnasium.make(envname)] ) - assert env.batch_size == torch.Size([2]) - check_env_specs(env) - env = GymWrapper( - gymnasium.vector.AsyncVectorEnv(2 * [lambda envname=envname: gymnasium.make(envname)]) + ) + assert env.batch_size == torch.Size([2]) + check_env_specs(env) + env = GymWrapper( + gymnasium.vector.AsyncVectorEnv( + 2 * [lambda envname=envname: gymnasium.make(envname)] ) - assert env.batch_size == torch.Size([2]) + ) + assert env.batch_size == torch.Size([2]) + check_env_specs(env) + with set_gym_backend("gymnasium"): + env = GymEnv(envname, num_envs=2, from_pixels=False) check_env_specs(env) - with set_gym_backend("gymnasium"): - env = GymEnv(envname, num_envs=2, from_pixels=False) - check_env_specs(env) - # with set_gym_backend("gymnasium"): - # env = GymEnv(envname, num_envs=2, from_pixels=True) - # check_env_specs(env) + rollout = env.rollout(100, break_when_any_done=False) + for obs_key in env.observation_spec.keys(True, True): + rollout_consistency_assertion( + rollout, done_key="done", observation_key=obs_key + ) - @implement_for("gym", "0.24", "0.27.0") + @implement_for("gym", "0.18", "0.27.0") def test_vecenvs(self): # noqa: F811 - import gymnasium + import gym + from _utils_internal import rollout_consistency_assertion # we can't use parametrize with implement_for for envname in ["CartPole-v1", "HalfCheetah-v4"]: env = GymWrapper( - gymnasium.vector.SyncVectorEnv( - 2 * [lambda envname=envname: gymnasium.make(envname)] + gym.vector.SyncVectorEnv( + 2 * [lambda envname=envname: gym.make(envname)] ) ) assert env.batch_size == torch.Size([2]) check_env_specs(env) env = GymWrapper( - gymnasium.vector.aSyncVectorEnv( - 2 * [lambda envname=envname: gymnasium.make(envname)] + gym.vector.AsyncVectorEnv( + 2 * [lambda envname=envname: gym.make(envname)] ) ) assert env.batch_size == torch.Size([2]) @@ -359,6 +373,12 @@ def test_vecenvs(self): # noqa: F811 with set_gym_backend("gym"): env = GymEnv(envname, num_envs=2, from_pixels=False) check_env_specs(env) + rollout = env.rollout(100, break_when_any_done=False) + for obs_key in env.observation_spec.keys(True, True): + rollout_consistency_assertion( + rollout, done_key="done", observation_key=obs_key + ) + with set_gym_backend("gym"): env = GymEnv(envname, num_envs=2, from_pixels=True) check_env_specs(env) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 08acbfe2514..9f6e179f709 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -320,6 +320,8 @@ def info_dict_reader(self): @info_dict_reader.setter def info_dict_reader(self, value: callable): - warnings.warn(f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. Setting info_dict_reader directly will be soon deprecated.") + warnings.warn( + f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. Setting info_dict_reader directly will be soon deprecated.", + category=DeprecationWarning, + ) self._info_dict_reader = [value] - diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 0c85a93bf0c..e41e2a1a42f 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -27,15 +27,19 @@ BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, + MultiDiscreteTensorSpec, + MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, - UnboundedContinuousTensorSpec, MultiDiscreteTensorSpec, - MultiOneHotDiscreteTensorSpec, + UnboundedContinuousTensorSpec, ) from torchrl.data.utils import numpy_to_torch_dtype_dict -from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv, \ - BaseInfoDictReader +from torchrl.envs.gym_like import ( + BaseInfoDictReader, + default_info_dict_reader, + GymLikeEnv, +) from torchrl.envs.utils import _classproperty DEFAULT_GYM = None @@ -48,6 +52,7 @@ _has_mo = importlib.util.find_spec("mo_gymnasium") is not None _has_sb3 = importlib.util.find_spec("stable_baselines3") is not None + class set_gym_backend(_DecoratorContextManager): """Sets the gym-backend to a certain value. @@ -241,7 +246,9 @@ def _gym_to_torchrl_spec_transform( return ( MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) if categorical_action_encoding - else MultiOneHotDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + else MultiOneHotDiscreteTensorSpec( + spec.nvec, device=device, dtype=dtype + ) ) return torch.stack( @@ -374,19 +381,27 @@ class _AsyncMeta(abc.ABCMeta): def __call__(cls, *args, **kwargs): instance: GymWrapper = super().__call__(*args, **kwargs) if instance._is_batched: - from torchrl.envs.transforms.transforms import TransformedEnv, VecGymEnvTransform + from torchrl.envs.transforms.transforms import ( + TransformedEnv, + VecGymEnvTransform, + ) + if _has_sb3: from stable_baselines3.common.vec_env.base_vec_env import VecEnv + if isinstance(instance._env, VecEnv): backend = "sb3" else: backend = "gym" else: backend = "gym" - instance.set_info_dict_reader(terminal_obs_reader(instance.observation_spec, backend=backend)) + instance.set_info_dict_reader( + terminal_obs_reader(instance.observation_spec, backend=backend) + ) return TransformedEnv(instance, VecGymEnvTransform()) return instance + class GymWrapper(GymLikeEnv, metaclass=_AsyncMeta): """OpenAI Gym environment wrapper. @@ -438,10 +453,13 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs): def _is_batched(self): if _has_sb3: from stable_baselines3.common.vec_env.base_vec_env import VecEnv + tuple_of_classes = (VecEnv,) else: tuple_of_classes = () - return isinstance(self._env, tuple_of_classes + (gym_backend("vector").VectorEnv,)) + return isinstance( + self._env, tuple_of_classes + (gym_backend("vector").VectorEnv,) + ) def _get_batch_size(self, env): if hasattr(env, "num_envs"): @@ -621,6 +639,9 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 observation_spec = CompositeSpec( observation=observation_spec, shape=self.batch_size ) + elif observation_spec.shape[: len(self.batch_size)] != self.batch_size: + observation_spec.shape = self.batch_size + if hasattr(env, "reward_space") and env.reward_space is not None: reward_spec = _gym_to_torchrl_spec_transform( env.reward_space, @@ -679,7 +700,8 @@ def _reset( return super()._reset(tensordict) elif reset is not None: return tensordict.clone(False) - return super()._reset(tensordict) + return super()._reset(tensordict, **kwargs) + ACCEPTED_TYPE_ERRORS = { "render_mode": "__init__() got an unexpected keyword argument 'render_mode'", @@ -850,13 +872,42 @@ def lib(self) -> ModuleType: class terminal_obs_reader(BaseInfoDictReader): + """Terminal observation reader for 'vectorized' gym environments. + + When running envs in parallel, Gym(nasium) writes the result of the true call + to `step` in `"final_observation"` entry within the `info` dictionary. + + This breaks the natural flow and makes single-processed and multiprocessed envs + incompatible. + + This class reads the info obs, removes the `"final_observation"` from + the env and writes its content in the data. + + Next, a :class:`torchrl.envs.VecGymEnvTransform` transform will reorganise the + data by caching the result of the (implicit) reset and swap the true next + observation with the reset one. At reset time, the true reset data will be + replaced. + + Args: + observation_spec (CompositeSpec): The observation spec of the gym env. + backend (str, optional): the backend of the env. One of `"sb3"` for + stable-baselines3 or `"gym"` for gym/gymnasium. + + .. note:: In general, this class should not be handled directly. It is + created whenever a vectorized environment is placed within a :class:`GymWrapper`. + + """ + backend_key = { "sb3": "terminal_observation", "gym": "final_observation", } - def __init__(self, observation_spec: CompositeSpec, backend): + + def __init__(self, observation_spec: CompositeSpec, backend, name="final"): + self.name = name self._info_spec = CompositeSpec( - {("final", key): item.clone() for key, item in observation_spec.items()}, shape=observation_spec.shape + {(self.name, key): item.clone() for key, item in observation_spec.items()}, + shape=observation_spec.shape, ) self.backend = backend @@ -864,16 +915,42 @@ def __init__(self, observation_spec: CompositeSpec, backend): def info_spec(self): return self._info_spec + def _read_obs(self, obs, key, tensor, index): + if obs is None: + return + if isinstance(obs, np.ndarray): + # Simplest case: there is one observation, + # presented as a np.ndarray. The key should be pixels or observation. + # We just write that value at its location in the tensor + tensor[index] = torch.as_tensor(obs, device=tensor.device) + elif isinstance(obs, dict): + if key not in obs: + raise KeyError( + f"The observation {key} could not be found in the final observation dict." + ) + subobs = obs[key] + if subobs is not None: + # if the obs is a dict, we expect that the key points also to + # a value in the obs. We retrieve this value and write it in the + # tensor + tensor[index] = torch.as_tensor(subobs, device=tensor.device) + + elif isinstance(obs, (list, tuple)): + # tuples are stacked along the first dimension when passing gym spaces + # to torchrl specs. As such, we can simply stack the tuple and set it + # at the relevant index (assuming stacking can be achieved) + tensor[index] = torch.as_tensor(obs, device=tensor.device) + else: + raise NotImplementedError( + f"Observations of type {type(obs)} are not supported yet." + ) + def __call__(self, info_dict, tensordict): terminal_obs = info_dict.get(self.backend_key[self.backend], None) for key, item in self.info_spec.items(True, True): final_obs = item.zero() - break - else: - raise RuntimeError("The info spec cannot be empty.") - if terminal_obs is not None: - for i, obs in enumerate(terminal_obs): - if obs is not None: - final_obs[i] = torch.as_tensor(obs, device=final_obs.device) - tensordict.set(key, final_obs) + if terminal_obs is not None: + for i, obs in enumerate(terminal_obs): + self._read_obs(obs, key[-1], final_obs, index=i) + tensordict.set(key, final_obs) return tensordict diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index 78eefa6d443..0f637f3140a 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -22,7 +22,8 @@ ) from torchrl.envs.utils import make_composite_from_td -_has_robohive = importlib.util.find_spec("robohive") is not None +_has_gym = importlib.util.find_spec("gym") is not None +_has_robohive = importlib.util.find_spec("robohive") is not None and _has_gym if _has_robohive: os.environ.setdefault("sim_backend", "MUJOCO") diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 0a6d8ff7f3a..30182870d7e 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -656,6 +656,8 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): if tensordict is not None: tensordict = tensordict.clone(recurse=False) out_tensordict = self.base_env._reset(tensordict=tensordict, **kwargs) + if tensordict is not None: + out_tensordict = tensordict.update(out_tensordict) out_tensordict = self.transform.reset(out_tensordict) mt_mode = self.transform.missing_tolerance @@ -5036,7 +5038,7 @@ class VecGymEnvTransform(Transform): """A transform for GymWrapper subclasses that handles the auto-reset in a consistent way. Gym, gymnasium and SB3 provide vectorized (read, parallel or batched) environments - that are automatically reset. When this occur, the actual observation resulting + that are automatically reset. When this occurs, the actual observation resulting from the action is saved within a key in the info. The class :class:`torchrl.envs.libs.gym.terminal_obs_reader` reads that observation and stores it in a ``"final"`` key within the output tensordict. @@ -5050,8 +5052,18 @@ class VecGymEnvTransform(Transform): This transform is automatically appended to the gym env whenever the wrapper is created with an async env. + + Args: + final_name (str, optional): the name of the final observation in the dict. + Defaults to `"final"`. + + .. note:: In general, this class should not be handled directly. It is + created whenever a vectorized environment is placed within a :class:`GymWrapper`. + """ - def __init__(self): + + def __init__(self, final_name="final"): + self.final_name = final_name super().__init__(in_keys=[]) self._memo = {} @@ -5059,16 +5071,16 @@ def _step( self, tensordict: TensorDictBase, next_tensordict: TensorDictBase ) -> TensorDictBase: # save the final info - done = self._memo['done'] = next_tensordict.get("done") + done = self._memo["done"] = next_tensordict.get("done") final = next_tensordict.pop("final") # if anything's done, we need to swap the final obs if done.any(): done = done.squeeze(-1) - saved_next = next_tensordict.select(*final.keys(True, True))[done].clone() + saved_next = next_tensordict.select(*final.keys(True, True)).clone() next_tensordict[done] = final[done] - self._memo['saved_done'] = saved_next + self._memo["saved_done"] = saved_next else: - self._memo['saved_done'] = None + self._memo["saved_done"] = None return next_tensordict def reset(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -5076,16 +5088,28 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: reset = tensordict.get("_reset", done) if done is not None: done = done.view_as(reset) - if reset is not done and (reset != done).any() and (not reset.all() or not reset.any()): - raise RuntimeError("Cannot partially reset a gym(nasium) async env with a reset mask that does not match the done mask. " - f"Got reset={reset}\nand done={done}") + if ( + reset is not done + and (reset != done).any() + and (not reset.all() or not reset.any()) + ): + raise RuntimeError( + "Cannot partially reset a gym(nasium) async env with a reset mask that does not match the done mask. " + f"Got reset={reset}\nand done={done}" + ) # if not reset.any(), we don't need to do anything. # if reset.all(), we don't either (bc GymWrapper will call a plain reset). if reset is not None and reset.any() and not reset.all(): - saved_done = self._memo['saved_done'] + saved_done = self._memo["saved_done"] reset = reset.view(tensordict.shape) - updated_td = torch.where(~reset, tensordict.select(*saved_done.keys(True, True)), saved_done) + updated_td = torch.where( + ~reset, tensordict.select(*saved_done.keys(True, True)), saved_done + ) tensordict.update(updated_td) tensordict.set("done", tensordict.get("done").clone().fill_(0)) tensordict.pop("final", None) return tensordict + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + del observation_spec[self.final_name] + return observation_spec From 9335a990b55ed9f348d45c8d34e5509618991e61 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 14 Sep 2023 06:19:55 -0400 Subject: [PATCH 05/16] amend --- torchrl/envs/common.py | 9 +++++++++ torchrl/envs/libs/gym.py | 1 + 2 files changed, 10 insertions(+) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 02fe161799a..0d991b9a377 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -818,6 +818,15 @@ def reward_spec(self, value: TensorSpec) -> None: # done spec def _get_done_keys(self): + if "full_done_spec" not in self.output_spec.keys(): + # populate the "done" entry + # this will be raised if there is not full_done_spec (unlikely) or no done_key + # Since output_spec is lazily populated with an empty composite spec for + # done_spec, the second case is much more likely to occur. + self.done_spec = DiscreteTensorSpec( + n=2, shape=(*self.batch_size, 1), dtype=torch.bool, device=self.device + ) + keys = self.output_spec["full_done_spec"].keys(True, True) if not len(keys): raise AttributeError("Could not find done spec") diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index e41e2a1a42f..f31d3b62fec 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -243,6 +243,7 @@ def _gym_to_torchrl_spec_transform( if categorical_action_encoding else torch.long ) + return ( MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) if categorical_action_encoding From cca67fa07e33bd1cfe4b9d9249681eac7a6b49c9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 14 Sep 2023 12:06:10 -0400 Subject: [PATCH 06/16] fix robohive --- torchrl/_utils.py | 9 ++++++--- torchrl/envs/gym_like.py | 13 +++++++++---- torchrl/envs/libs/gym.py | 4 ++-- torchrl/envs/libs/robohive.py | 2 +- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 8d590b05210..e013023320a 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -248,6 +248,7 @@ class implement_for: # Stores pointers to fitting implementations: dict[func_name] = func_pointer _implementations = {} _setters = [] + _memo = {} def __init__( self, @@ -285,11 +286,13 @@ def module_set(self): cls = inspect.getmodule(self.fn) setattr(cls, self.fn.__name__, self.fn) - @staticmethod - def import_module(module_name: Union[Callable, str]) -> str: + @classmethod + def import_module(cls, module_name: Union[Callable, str]) -> str: """Imports module and returns its version.""" if not callable(module_name): - module = import_module(module_name) + module = cls._memo.get(module_name, None) + if module is None: + module = cls._memo[module_name] = import_module(module_name) else: module = module_name() return module.__version__ diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 9f6e179f709..c5b5ef1ad55 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -305,8 +305,11 @@ def set_info_dict_reader(self, info_dict_reader: BaseInfoDictReader) -> GymLikeE """ self.info_dict_reader.append(info_dict_reader) - for info_key, spec in info_dict_reader.info_spec.items(): - self.observation_spec[info_key] = spec.to(self.device) + if isinstance(info_dict_reader, BaseInfoDictReader): + # if we have a BaseInfoDictReader, we know what the specs will be + # In other cases (eg, RoboHive) we will need to figure it out empirically. + for info_key, spec in info_dict_reader.info_spec.items(): + self.observation_spec[info_key] = spec.to(self.device) return self def __repr__(self) -> str: @@ -321,7 +324,9 @@ def info_dict_reader(self): @info_dict_reader.setter def info_dict_reader(self, value: callable): warnings.warn( - f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. Setting info_dict_reader directly will be soon deprecated.", + f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. " + f"This method will append a reader to the list of existing readers (if any). " + f"Setting info_dict_reader directly will be soon deprecated.", category=DeprecationWarning, ) - self._info_dict_reader = [value] + self._info_dict_reader.append(value) diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index f31d3b62fec..d626ff9f5c5 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -682,8 +682,8 @@ def rebuild_with_kwargs(self, **new_kwargs): @property def info_dict_reader(self): - if self._info_dict_reader is None: - self._info_dict_reader = default_info_dict_reader() + if not self._info_dict_reader: + self._info_dict_reader.append(default_info_dict_reader()) return self._info_dict_reader @info_dict_reader.setter diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index 0f637f3140a..c54015bd039 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -171,7 +171,7 @@ def _build_env( # noqa: F811 self.from_pixels = from_pixels self.render_device = render_device if kwargs.get("read_info", True): - self.info_dict_reader = self.read_info + self.set_info_dict_reader(self.read_info) return env @classmethod From 65b46053704cb1a30dc90c077fa5e9a94e70bd8c Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 14 Sep 2023 16:24:16 -0400 Subject: [PATCH 07/16] amend --- torchrl/_utils.py | 5 +---- torchrl/envs/libs/gym.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index e013023320a..7f5b80175be 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -248,7 +248,6 @@ class implement_for: # Stores pointers to fitting implementations: dict[func_name] = func_pointer _implementations = {} _setters = [] - _memo = {} def __init__( self, @@ -290,9 +289,7 @@ def module_set(self): def import_module(cls, module_name: Union[Callable, str]) -> str: """Imports module and returns its version.""" if not callable(module_name): - module = cls._memo.get(module_name, None) - if module is None: - module = cls._memo[module_name] = import_module(module_name) + module = import_module(module_name) else: module = module_name() return module.__version__ diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index d626ff9f5c5..935c876479b 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -14,7 +14,7 @@ import torch from tensordict import TensorDictBase -from torchrl.envs.vec_env import CloudpickleWrapper +from torchrl.envs.batched_envs import CloudpickleWrapper try: from torch.utils._contextlib import _DecoratorContextManager From a9a0ebeef3b37e224a97861682a97cb571673779 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 15 Sep 2023 09:49:25 -0400 Subject: [PATCH 08/16] amend --- test/test_libs.py | 7 ++++++- torchrl/envs/libs/gym.py | 6 ++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 372361c9ade..fcdfc118b8b 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -75,10 +75,10 @@ from torchrl.envs.libs.openml import OpenMLEnv from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv from torchrl.envs.libs.robohive import RoboHiveEnv +from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator -from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env _has_d4rl = importlib.util.find_spec("d4rl") is not None @@ -395,6 +395,11 @@ def test_vecenvs(self): # noqa: F811 env = GymEnv(envname, num_envs=2, from_pixels=True) check_env_specs(env) + @implement_for("gym", None, "0.18") + def test_vecenvs(self): # noqa: F811 + # skipping tests for older versions of gym + return + @implement_for("gym", None, "0.26") def _make_gym_environment(env_name): # noqa: F811 diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 31ed11f6039..7cdc9c270dd 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -14,9 +14,6 @@ import torch from tensordict import TensorDictBase -from torchrl.envs.batched_envs import CloudpickleWrapper - -from torchrl.envs.utils import _classproperty from torchrl._utils import implement_for from torchrl.data.tensor_specs import ( @@ -31,14 +28,15 @@ UnboundedContinuousTensorSpec, ) from torchrl.data.utils import numpy_to_torch_dtype_dict +from torchrl.envs.batched_envs import CloudpickleWrapper from torchrl.envs.gym_like import ( BaseInfoDictReader, default_info_dict_reader, GymLikeEnv, ) + from torchrl.envs.utils import _classproperty -from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv try: from torch.utils._contextlib import _DecoratorContextManager From 32c0cb2e7222cbd3ee23bf5b91126e2959a81369 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 15 Sep 2023 11:27:23 -0400 Subject: [PATCH 09/16] amend --- torchrl/envs/libs/gym.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 4be7c5efc50..dcb10c9c362 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -699,10 +699,13 @@ def _reset( self, tensordict: Optional[TensorDictBase] = None, **kwargs ) -> TensorDictBase: if self._is_batched: + # batched (aka 'vectorized') env reset is a bit special: envs are + # automatically reset. What we do here is just to check if _reset + # is present. If it is not, we just reset. Otherwise we just skip. if tensordict is None: return super()._reset(tensordict) reset = tensordict.get("_reset", None) - if reset is None or reset.all(): + if reset is None: return super()._reset(tensordict) elif reset is not None: return tensordict.clone(False) From 8aa2660aa645b9a785b490c5e87d987cea06a843 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 15 Sep 2023 13:19:30 -0400 Subject: [PATCH 10/16] init --- torchrl/envs/common.py | 11 ----------- torchrl/envs/libs/gym.py | 1 + 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 89c5ab68e80..771c529933c 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1750,21 +1750,10 @@ def __init__( self._constructor_kwargs = kwargs self._check_kwargs(kwargs) self._env = self._build_env(**kwargs) # writes the self._env attribute - if self.batch_size in (None, torch.Size([])): - self.__dict__["_batch_size"] = self._get_batch_size(self._env) self._make_specs(self._env) # writes the self._env attribute self.is_closed = False self._init_env() # runs all the steps to have a ready-to-use env - def _get_batch_size(self, env): - """Batch-size adjustment. - - This is executed after super().__init__(), ie. when the batch-size has been set. - By default, it is a no-op. For some envs (batched envs) we adapt the batch-size - according to the number of sub-envs. See GymWrapper._get_batch_size for an example. - """ - return self.batch_size - @abc.abstractmethod def _check_kwargs(self, kwargs: Dict): raise NotImplementedError diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index dcb10c9c362..c05d6c02134 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -485,6 +485,7 @@ def _build_env( from_pixels: bool = False, pixels_only: bool = False, ) -> "gym.core.Env": # noqa: F821 + self.batch_size = self._get_batch_size(env) env_from_pixels = _is_from_pixels(env) from_pixels = from_pixels or env_from_pixels self.from_pixels = from_pixels From 49caa04950eb584811f3a6290129c67271c0a846 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 15 Sep 2023 15:32:31 -0400 Subject: [PATCH 11/16] fix --- test/test_libs.py | 14 ++++++++++++-- torchrl/envs/libs/gym.py | 7 ++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index ee12b1f5c59..a659cf5ae98 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -332,9 +332,8 @@ def test_one_hot_and_categorical(self): # noqa: F811 ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"] + (["FetchReach-v2"] if _has_gym_robotics else []), ) - def test_vecenvs(self, envname): + def test_vecenvs_wrapper(self, envname): import gymnasium - from _utils_internal import rollout_consistency_assertion # we can't use parametrize with implement_for env = GymWrapper( @@ -351,6 +350,17 @@ def test_vecenvs(self, envname): ) assert env.batch_size == torch.Size([2]) check_env_specs(env) + + @implement_for("gymnasium", "0.27.0", None) + # this env has Dict-based observation which is a nice thing to test + @pytest.mark.parametrize( + "envname", + ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"] + + (["FetchReach-v2"] if _has_gym_robotics else []), + ) + def test_vecenvs_env(self, envname): + from _utils_internal import rollout_consistency_assertion + with set_gym_backend("gymnasium"): env = GymEnv(envname, num_envs=2, from_pixels=False) check_env_specs(env) diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index c05d6c02134..597579a09cd 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -486,6 +486,7 @@ def _build_env( pixels_only: bool = False, ) -> "gym.core.Env": # noqa: F821 self.batch_size = self._get_batch_size(env) + env_from_pixels = _is_from_pixels(env) from_pixels = from_pixels or env_from_pixels self.from_pixels = from_pixels @@ -803,9 +804,9 @@ def _build_env( raise err env = super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels) if num_envs > 0: - return self._async_env([CloudpickleWrapper(lambda: env)] * num_envs) - else: - return env + env = self._async_env([CloudpickleWrapper(lambda: env)] * num_envs) + self.batch_size = torch.Size([num_envs, *self.batch_size]) + return env @implement_for("gym", None, "0.25.1") def _set_gym_default(self, kwargs, from_pixels: bool) -> None: # noqa: F811 From 0a4ea8e6d937617359edf68db85ebee18ced0665 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 15 Sep 2023 15:34:28 -0400 Subject: [PATCH 12/16] fix --- test/test_libs.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index a659cf5ae98..e13a3d79c0e 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -371,7 +371,11 @@ def test_vecenvs_env(self, envname): ) @implement_for("gym", "0.18", "0.27.0") - def test_vecenvs(self): # noqa: F811 + @pytest.mark.parametrize( + "envname", + ["CartPole-v1", "HalfCheetah-v4"], + ) + def test_vecenvs_wrapper(self): # noqa: F811 import gym from _utils_internal import rollout_consistency_assertion @@ -391,6 +395,13 @@ def test_vecenvs(self): # noqa: F811 ) assert env.batch_size == torch.Size([2]) check_env_specs(env) + + @implement_for("gym", "0.18", "0.27.0") + @pytest.mark.parametrize( + "envname", + ["CartPole-v1", "HalfCheetah-v4"], + ) + def test_vecenvs_env(self): # noqa: F811 with set_gym_backend("gym"): env = GymEnv(envname, num_envs=2, from_pixels=False) check_env_specs(env) @@ -405,7 +416,12 @@ def test_vecenvs(self): # noqa: F811 check_env_specs(env) @implement_for("gym", None, "0.18") - def test_vecenvs(self): # noqa: F811 + def test_vecenvs_wrapper(self): # noqa: F811 + # skipping tests for older versions of gym + return + + @implement_for("gym", None, "0.18") + def test_vecenvs_env(self): # noqa: F811 # skipping tests for older versions of gym return From c188b88787a2e9777a3ae07b5a5daba17ad8b0b8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 17 Sep 2023 12:20:16 -0400 Subject: [PATCH 13/16] amend --- test/test_libs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index e13a3d79c0e..fe75b2e8b48 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -375,7 +375,7 @@ def test_vecenvs_env(self, envname): "envname", ["CartPole-v1", "HalfCheetah-v4"], ) - def test_vecenvs_wrapper(self): # noqa: F811 + def test_vecenvs_wrapper(self, envname): # noqa: F811 import gym from _utils_internal import rollout_consistency_assertion @@ -401,7 +401,7 @@ def test_vecenvs_wrapper(self): # noqa: F811 "envname", ["CartPole-v1", "HalfCheetah-v4"], ) - def test_vecenvs_env(self): # noqa: F811 + def test_vecenvs_env(self, envname): # noqa: F811 with set_gym_backend("gym"): env = GymEnv(envname, num_envs=2, from_pixels=False) check_env_specs(env) From a329b80c8220b4770446feaa6442e92d9e9e0984 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 17 Sep 2023 12:22:04 -0400 Subject: [PATCH 14/16] amend --- test/test_libs.py | 54 +++++++++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index fe75b2e8b48..459a360d07e 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -33,7 +33,7 @@ get_default_devices, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, - PONG_VERSIONED, + PONG_VERSIONED, rollout_consistency_assertion, ) from packaging import version from tensordict import LazyStackedTensorDict @@ -396,34 +396,42 @@ def test_vecenvs_wrapper(self, envname): # noqa: F811 assert env.batch_size == torch.Size([2]) check_env_specs(env) - @implement_for("gym", "0.18", "0.27.0") - @pytest.mark.parametrize( - "envname", - ["CartPole-v1", "HalfCheetah-v4"], - ) - def test_vecenvs_env(self, envname): # noqa: F811 - with set_gym_backend("gym"): - env = GymEnv(envname, num_envs=2, from_pixels=False) - check_env_specs(env) - rollout = env.rollout(100, break_when_any_done=False) - for obs_key in env.observation_spec.keys(True, True): - rollout_consistency_assertion( - rollout, done_key="done", observation_key=obs_key - ) - - with set_gym_backend("gym"): - env = GymEnv(envname, num_envs=2, from_pixels=True) - check_env_specs(env) + @implement_for("gym", "0.18", "0.27.0") + @pytest.mark.parametrize( + "envname", + ["CartPole-v1", "HalfCheetah-v4"], + ) + def test_vecenvs_env(self, envname): # noqa: F811 + with set_gym_backend("gym"): + env = GymEnv(envname, num_envs=2, from_pixels=False) + check_env_specs(env) + rollout = env.rollout(100, break_when_any_done=False) + for obs_key in env.observation_spec.keys(True, True): + rollout_consistency_assertion( + rollout, done_key="done", observation_key=obs_key + ) + + with set_gym_backend("gym"): + env = GymEnv(envname, num_envs=2, from_pixels=True) + check_env_specs(env) @implement_for("gym", None, "0.18") - def test_vecenvs_wrapper(self): # noqa: F811 + @pytest.mark.parametrize( + "envname", + ["CartPole-v1", "HalfCheetah-v4"], + ) + def test_vecenvs_wrapper(self, envname): # noqa: F811 # skipping tests for older versions of gym - return + ... @implement_for("gym", None, "0.18") - def test_vecenvs_env(self): # noqa: F811 + @pytest.mark.parametrize( + "envname", + ["CartPole-v1", "HalfCheetah-v4"], + ) + def test_vecenvs_env(self, envname): # noqa: F811 # skipping tests for older versions of gym - return + ... @implement_for("gym", None, "0.26") From 8d6e9c73b74ab76692b19d5458559ab62ed57fda Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 17 Sep 2023 12:24:46 -0400 Subject: [PATCH 15/16] amend --- test/test_libs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 459a360d07e..49c670841ec 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -410,10 +410,10 @@ def test_vecenvs_env(self, envname): # noqa: F811 rollout_consistency_assertion( rollout, done_key="done", observation_key=obs_key ) - - with set_gym_backend("gym"): - env = GymEnv(envname, num_envs=2, from_pixels=True) - check_env_specs(env) + if envname != "CartPole-v1": + with set_gym_backend("gym"): + env = GymEnv(envname, num_envs=2, from_pixels=True) + check_env_specs(env) @implement_for("gym", None, "0.18") @pytest.mark.parametrize( From f2f574484cdb50f0ae94e4e089c6acbdd9a2b2f3 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 17 Sep 2023 15:33:28 -0400 Subject: [PATCH 16/16] amend --- test/test_libs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 49c670841ec..ef2be615c2e 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -33,7 +33,8 @@ get_default_devices, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, - PONG_VERSIONED, rollout_consistency_assertion, + PONG_VERSIONED, + rollout_consistency_assertion, ) from packaging import version from tensordict import LazyStackedTensorDict @@ -377,7 +378,6 @@ def test_vecenvs_env(self, envname): ) def test_vecenvs_wrapper(self, envname): # noqa: F811 import gym - from _utils_internal import rollout_consistency_assertion # we can't use parametrize with implement_for for envname in ["CartPole-v1", "HalfCheetah-v4"]: