diff --git a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh index ee1145d3f93..3fadedb0ffd 100755 --- a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh +++ b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh @@ -155,6 +155,7 @@ do pip install gymnasium[atari] fi pip install mo-gymnasium + pip install gymnasium-robotics $DIR/run_test.sh diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 212a566aac3..b987934039f 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -494,6 +494,7 @@ to be able to create this other composition: TimeMaxPool ToTensorImage UnsqueezeTransform + VecGymEnvTransform VecNorm VC1Transform VIPRewardTransform diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 84dc6f33626..1af5a588e74 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -321,3 +321,29 @@ class MyClass: for key in td.keys(): MyClass.__annotations__[key] = torch.Tensor return tensorclass(MyClass) + + +def rollout_consistency_assertion( + rollout, *, done_key="done", observation_key="observation" +): + """Tests that observations in "next" match observations in the next root tensordict when done is False, and don't match otherwise.""" + + done = rollout[:, :-1]["next", done_key].squeeze(-1) + # data resulting from step, when it's not done + r_not_done = rollout[:, :-1]["next"][~done] + # data resulting from step, when it's not done, after step_mdp + r_not_done_tp1 = rollout[:, 1:][~done] + torch.testing.assert_close( + r_not_done[observation_key], r_not_done_tp1[observation_key] + ) + + if not done.any(): + return + + # data resulting from step, when it's done + r_done = rollout[:, :-1]["next"][done] + # data resulting from step, when it's done, after step_mdp and reset + r_done_tp1 = rollout[:, 1:][done] + assert ( + (r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) > 1e-1 + ).all(), (r_done[observation_key] - r_done_tp1[observation_key]).norm(dim=-1) diff --git a/test/test_libs.py b/test/test_libs.py index bb8444e22b4..ef2be615c2e 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -34,6 +34,7 @@ HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED, + rollout_consistency_assertion, ) from packaging import version from tensordict import LazyStackedTensorDict @@ -67,12 +68,14 @@ GymWrapper, MOGymEnv, MOGymWrapper, + set_gym_backend, ) from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv from torchrl.envs.libs.openml import OpenMLEnv from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv from torchrl.envs.libs.robohive import RoboHiveEnv +from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator @@ -83,7 +86,7 @@ _has_sklearn = importlib.util.find_spec("sklearn") is not None -from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env +_has_gym_robotics = importlib.util.find_spec("gymnasium_robotics") is not None if _has_gym: try: @@ -323,6 +326,113 @@ def test_one_hot_and_categorical(self): # noqa: F811 # versions. return + @implement_for("gymnasium", "0.27.0", None) + # this env has Dict-based observation which is a nice thing to test + @pytest.mark.parametrize( + "envname", + ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"] + + (["FetchReach-v2"] if _has_gym_robotics else []), + ) + def test_vecenvs_wrapper(self, envname): + import gymnasium + + # we can't use parametrize with implement_for + env = GymWrapper( + gymnasium.vector.SyncVectorEnv( + 2 * [lambda envname=envname: gymnasium.make(envname)] + ) + ) + assert env.batch_size == torch.Size([2]) + check_env_specs(env) + env = GymWrapper( + gymnasium.vector.AsyncVectorEnv( + 2 * [lambda envname=envname: gymnasium.make(envname)] + ) + ) + assert env.batch_size == torch.Size([2]) + check_env_specs(env) + + @implement_for("gymnasium", "0.27.0", None) + # this env has Dict-based observation which is a nice thing to test + @pytest.mark.parametrize( + "envname", + ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"] + + (["FetchReach-v2"] if _has_gym_robotics else []), + ) + def test_vecenvs_env(self, envname): + from _utils_internal import rollout_consistency_assertion + + with set_gym_backend("gymnasium"): + env = GymEnv(envname, num_envs=2, from_pixels=False) + check_env_specs(env) + rollout = env.rollout(100, break_when_any_done=False) + for obs_key in env.observation_spec.keys(True, True): + rollout_consistency_assertion( + rollout, done_key="done", observation_key=obs_key + ) + + @implement_for("gym", "0.18", "0.27.0") + @pytest.mark.parametrize( + "envname", + ["CartPole-v1", "HalfCheetah-v4"], + ) + def test_vecenvs_wrapper(self, envname): # noqa: F811 + import gym + + # we can't use parametrize with implement_for + for envname in ["CartPole-v1", "HalfCheetah-v4"]: + env = GymWrapper( + gym.vector.SyncVectorEnv( + 2 * [lambda envname=envname: gym.make(envname)] + ) + ) + assert env.batch_size == torch.Size([2]) + check_env_specs(env) + env = GymWrapper( + gym.vector.AsyncVectorEnv( + 2 * [lambda envname=envname: gym.make(envname)] + ) + ) + assert env.batch_size == torch.Size([2]) + check_env_specs(env) + + @implement_for("gym", "0.18", "0.27.0") + @pytest.mark.parametrize( + "envname", + ["CartPole-v1", "HalfCheetah-v4"], + ) + def test_vecenvs_env(self, envname): # noqa: F811 + with set_gym_backend("gym"): + env = GymEnv(envname, num_envs=2, from_pixels=False) + check_env_specs(env) + rollout = env.rollout(100, break_when_any_done=False) + for obs_key in env.observation_spec.keys(True, True): + rollout_consistency_assertion( + rollout, done_key="done", observation_key=obs_key + ) + if envname != "CartPole-v1": + with set_gym_backend("gym"): + env = GymEnv(envname, num_envs=2, from_pixels=True) + check_env_specs(env) + + @implement_for("gym", None, "0.18") + @pytest.mark.parametrize( + "envname", + ["CartPole-v1", "HalfCheetah-v4"], + ) + def test_vecenvs_wrapper(self, envname): # noqa: F811 + # skipping tests for older versions of gym + ... + + @implement_for("gym", None, "0.18") + @pytest.mark.parametrize( + "envname", + ["CartPole-v1", "HalfCheetah-v4"], + ) + def test_vecenvs_env(self, envname): # noqa: F811 + # skipping tests for older versions of gym + ... + @implement_for("gym", None, "0.26") def _make_gym_environment(env_name): # noqa: F811 diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 8d590b05210..7f5b80175be 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -285,8 +285,8 @@ def module_set(self): cls = inspect.getmodule(self.fn) setattr(cls, self.fn.__name__, self.fn) - @staticmethod - def import_module(module_name: Union[Callable, str]) -> str: + @classmethod + def import_module(cls, module_name: Union[Callable, str]) -> str: """Imports module and returns its version.""" if not callable(module_name): module = import_module(module_name) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index e50621144cf..771c529933c 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -617,7 +617,7 @@ def action_spec(self, value: TensorSpec) -> None: ) if value.shape[: len(self.batch_size)] != self.batch_size: raise ValueError( - "The value of spec.shape must match the env batch size." + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) if isinstance(value, CompositeSpec): @@ -791,7 +791,7 @@ def reward_spec(self, value: TensorSpec) -> None: ) if value.shape[: len(self.batch_size)] != self.batch_size: raise ValueError( - "The value of spec.shape must match the env batch size." + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) if isinstance(value, CompositeSpec): for _ in value.values(True, True): # noqa: B007 @@ -820,6 +820,15 @@ def reward_spec(self, value: TensorSpec) -> None: # done spec def _get_done_keys(self): + if "full_done_spec" not in self.output_spec.keys(): + # populate the "done" entry + # this will be raised if there is not full_done_spec (unlikely) or no done_key + # Since output_spec is lazily populated with an empty composite spec for + # done_spec, the second case is much more likely to occur. + self.done_spec = DiscreteTensorSpec( + n=2, shape=(*self.batch_size, 1), dtype=torch.bool, device=self.device + ) + keys = self.output_spec["full_done_spec"].keys(True, True) if not len(keys): raise AttributeError("Could not find done spec") @@ -967,7 +976,7 @@ def done_spec(self, value: TensorSpec) -> None: ) if value.shape[: len(self.batch_size)] != self.batch_size: raise ValueError( - "The value of spec.shape must match the env batch size." + f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) if isinstance(value, CompositeSpec): for _ in value.values(True, True): # noqa: B007 diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 289bb731278..c5b5ef1ad55 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -115,11 +115,11 @@ class GymLikeEnv(_EnvWrapper): It is also expected that env.reset() returns an observation similar to the one observed after a step is completed. """ - _info_dict_reader: BaseInfoDictReader + _info_dict_reader: List[BaseInfoDictReader] @classmethod def __new__(cls, *args, **kwargs): - cls._info_dict_reader = None + cls._info_dict_reader = [] return super().__new__(cls, *args, _batch_locked=True, **kwargs) def read_action(self, action): @@ -144,7 +144,7 @@ def read_done(self, done): done (np.ndarray, boolean or other format): done state obtained from the environment """ - return done, done + return done, done.any() if not isinstance(done, bool) else done def read_reward(self, reward): """Reads the reward and maps it to the reward space. @@ -231,8 +231,11 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict_out = TensorDict(obs_dict, batch_size=tensordict.batch_size) - if self.info_dict_reader is not None and info is not None: - self.info_dict_reader(info, tensordict_out) + if self.info_dict_reader and info is not None: + for info_dict_reader in self.info_dict_reader: + out = info_dict_reader(info, tensordict_out) + if out is not None: + tensordict_out = out tensordict_out = tensordict_out.to(self.device, non_blocking=True) return tensordict_out @@ -255,9 +258,12 @@ def _reset( source=source, batch_size=self.batch_size, ) - if self.info_dict_reader is not None and info is not None: - self.info_dict_reader(info, tensordict_out) - elif info is None and self.info_dict_reader is not None: + if self.info_dict_reader and info is not None: + for info_dict_reader in self.info_dict_reader: + out = info_dict_reader(info, tensordict_out) + if out is not None: + tensordict_out = out + elif info is None and self.info_dict_reader: # populate the reset with the items we have not seen from info for key, item in self.observation_spec.items(True, True): if key not in tensordict_out.keys(True, True): @@ -298,9 +304,12 @@ def set_info_dict_reader(self, info_dict_reader: BaseInfoDictReader) -> GymLikeE >>> assert "my_info_key" in tensordict.keys() """ - self.info_dict_reader = info_dict_reader - for info_key, spec in info_dict_reader.info_spec.items(): - self.observation_spec[info_key] = spec.to(self.device) + self.info_dict_reader.append(info_dict_reader) + if isinstance(info_dict_reader, BaseInfoDictReader): + # if we have a BaseInfoDictReader, we know what the specs will be + # In other cases (eg, RoboHive) we will need to figure it out empirically. + for info_key, spec in info_dict_reader.info_spec.items(): + self.observation_spec[info_key] = spec.to(self.device) return self def __repr__(self) -> str: @@ -314,4 +323,10 @@ def info_dict_reader(self): @info_dict_reader.setter def info_dict_reader(self, value: callable): - self._info_dict_reader = value + warnings.warn( + f"Please use {type(self)}.set_info_dict_reader method to set a new info reader. " + f"This method will append a reader to the list of existing readers (if any). " + f"Setting info_dict_reader directly will be soon deprecated.", + category=DeprecationWarning, + ) + self._info_dict_reader.append(value) diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 51fc6588217..597579a09cd 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -2,21 +2,18 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import abc import importlib import warnings from copy import copy from types import ModuleType -from typing import Dict, List +from typing import Dict, List, Optional from warnings import warn +import numpy as np import torch -from torchrl.envs.utils import _classproperty - -try: - from torch.utils._contextlib import _DecoratorContextManager -except ModuleNotFoundError: - from torchrl._utils import _DecoratorContextManager +from tensordict import TensorDictBase from torchrl._utils import implement_for from torchrl.data.tensor_specs import ( @@ -31,8 +28,20 @@ UnboundedContinuousTensorSpec, ) from torchrl.data.utils import numpy_to_torch_dtype_dict +from torchrl.envs.batched_envs import CloudpickleWrapper -from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv +from torchrl.envs.gym_like import ( + BaseInfoDictReader, + default_info_dict_reader, + GymLikeEnv, +) + +from torchrl.envs.utils import _classproperty + +try: + from torch.utils._contextlib import _DecoratorContextManager +except ModuleNotFoundError: + from torchrl._utils import _DecoratorContextManager DEFAULT_GYM = None IMPORT_ERROR = None @@ -42,6 +51,7 @@ _has_gym = importlib.util.find_spec("gymnasium") is not None _has_mo = importlib.util.find_spec("mo_gymnasium") is not None +_has_sb3 = importlib.util.find_spec("stable_baselines3") is not None class set_gym_backend(_DecoratorContextManager): @@ -199,7 +209,18 @@ def _gym_to_torchrl_spec_transform( """ gym = gym_backend() if isinstance(spec, gym.spaces.tuple.Tuple): - raise NotImplementedError("gym.spaces.tuple.Tuple mapping not yet implemented") + return torch.stack( + [ + _gym_to_torchrl_spec_transform( + s, + device=device, + categorical_action_encoding=categorical_action_encoding, + remap_state_to_observation=remap_state_to_observation, + ) + for s in spec + ], + 0, + ) if isinstance(spec, gym.spaces.discrete.Discrete): action_space_cls = ( DiscreteTensorSpec @@ -217,15 +238,32 @@ def _gym_to_torchrl_spec_transform( spec.n, device=device, dtype=numpy_to_torch_dtype_dict[spec.dtype] ) elif isinstance(spec, gym.spaces.multi_discrete.MultiDiscrete): - dtype = ( - numpy_to_torch_dtype_dict[spec.dtype] - if categorical_action_encoding - else torch.long - ) - return ( - MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) - if categorical_action_encoding - else MultiOneHotDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + if len(spec.nvec.shape) == 1 and len(np.unique(spec.nvec)) > 1: + dtype = ( + numpy_to_torch_dtype_dict[spec.dtype] + if categorical_action_encoding + else torch.long + ) + + return ( + MultiDiscreteTensorSpec(spec.nvec, device=device, dtype=dtype) + if categorical_action_encoding + else MultiOneHotDiscreteTensorSpec( + spec.nvec, device=device, dtype=dtype + ) + ) + + return torch.stack( + [ + _gym_to_torchrl_spec_transform( + spec[i], + device=device, + categorical_action_encoding=categorical_action_encoding, + remap_state_to_observation=remap_state_to_observation, + ) + for i in range(len(spec.nvec)) + ], + 0, ) elif isinstance(spec, gym.spaces.Box): shape = spec.shape @@ -343,7 +381,32 @@ class PixelObservationWrapper: return False -class GymWrapper(GymLikeEnv): +class _AsyncMeta(abc.ABCMeta): + def __call__(cls, *args, **kwargs): + instance: GymWrapper = super().__call__(*args, **kwargs) + if instance._is_batched: + from torchrl.envs.transforms.transforms import ( + TransformedEnv, + VecGymEnvTransform, + ) + + if _has_sb3: + from stable_baselines3.common.vec_env.base_vec_env import VecEnv + + if isinstance(instance._env, VecEnv): + backend = "sb3" + else: + backend = "gym" + else: + backend = "gym" + instance.set_info_dict_reader( + terminal_obs_reader(instance.observation_spec, backend=backend) + ) + return TransformedEnv(instance, VecGymEnvTransform()) + return instance + + +class GymWrapper(GymLikeEnv, metaclass=_AsyncMeta): """OpenAI Gym environment wrapper. Examples: @@ -390,6 +453,25 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs): else: super().__init__(**kwargs) + @property + def _is_batched(self): + if _has_sb3: + from stable_baselines3.common.vec_env.base_vec_env import VecEnv + + tuple_of_classes = (VecEnv,) + else: + tuple_of_classes = () + return isinstance( + self._env, tuple_of_classes + (gym_backend("vector").VectorEnv,) + ) + + def _get_batch_size(self, env): + if hasattr(env, "num_envs"): + batch_size = torch.Size([env.num_envs, *self.batch_size]) + else: + batch_size = self.batch_size + return batch_size + def _check_kwargs(self, kwargs: Dict): if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") @@ -403,6 +485,8 @@ def _build_env( from_pixels: bool = False, pixels_only: bool = False, ) -> "gym.core.Env": # noqa: F821 + self.batch_size = self._get_batch_size(env) + env_from_pixels = _is_from_pixels(env) from_pixels = from_pixels or env_from_pixels self.from_pixels = from_pixels @@ -556,9 +640,16 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 ) if not isinstance(observation_spec, CompositeSpec): if self.from_pixels: - observation_spec = CompositeSpec(pixels=observation_spec) + observation_spec = CompositeSpec( + pixels=observation_spec, shape=self.batch_size + ) else: - observation_spec = CompositeSpec(observation=observation_spec) + observation_spec = CompositeSpec( + observation=observation_spec, shape=self.batch_size + ) + elif observation_spec.shape[: len(self.batch_size)] != self.batch_size: + observation_spec.shape = self.batch_size + if hasattr(env, "reward_space") and env.reward_space is not None: reward_spec = _gym_to_torchrl_spec_transform( env.reward_space, @@ -577,7 +668,10 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 *batch_size, *observation_spec.shape ) self.action_spec = action_spec - self.reward_spec = reward_spec + if reward_spec.shape[: len(self.batch_size)] != self.batch_size: + self.reward_spec = reward_spec.expand(*self.batch_size, *reward_spec.shape) + else: + self.reward_spec = reward_spec self.observation_spec = observation_spec def _init_env(self): @@ -595,14 +689,30 @@ def rebuild_with_kwargs(self, **new_kwargs): @property def info_dict_reader(self): - if self._info_dict_reader is None: - self._info_dict_reader = default_info_dict_reader() + if not self._info_dict_reader: + self._info_dict_reader.append(default_info_dict_reader()) return self._info_dict_reader @info_dict_reader.setter def info_dict_reader(self, value: callable): self._info_dict_reader = value + def _reset( + self, tensordict: Optional[TensorDictBase] = None, **kwargs + ) -> TensorDictBase: + if self._is_batched: + # batched (aka 'vectorized') env reset is a bit special: envs are + # automatically reset. What we do here is just to check if _reset + # is present. If it is not, we just reset. Otherwise we just skip. + if tensordict is None: + return super()._reset(tensordict) + reset = tensordict.get("_reset", None) + if reset is None: + return super()._reset(tensordict) + elif reset is not None: + return tensordict.clone(False) + return super()._reset(tensordict, **kwargs) + ACCEPTED_TYPE_ERRORS = { "render_mode": "__init__() got an unexpected keyword argument 'render_mode'", @@ -648,6 +758,9 @@ def _set_gym_args( # noqa: F811 ) -> None: kwargs.setdefault("disable_env_checker", True) + def _async_env(self, *args, **kwargs): + return gym_backend("vector").AsyncVectorEnv(*args, **kwargs) + def _build_env( self, env_name: str, @@ -659,13 +772,10 @@ def _build_env( f"Consider downloading and installing gym from" f" {self.git_url}" ) - from_pixels = kwargs.get("from_pixels", False) + from_pixels = kwargs.pop("from_pixels", False) self._set_gym_default(kwargs, from_pixels) - if "from_pixels" in kwargs: - del kwargs["from_pixels"] - pixels_only = kwargs.get("pixels_only", True) - if "pixels_only" in kwargs: - del kwargs["pixels_only"] + pixels_only = kwargs.pop("pixels_only", True) + num_envs = kwargs.pop("num_envs", 0) made_env = False kwargs["frameskip"] = self.frame_skip self.wrapper_frame_skip = 1 @@ -692,7 +802,11 @@ def _build_env( kwargs.pop("render_mode") else: raise err - return super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels) + env = super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels) + if num_envs > 0: + env = self._async_env([CloudpickleWrapper(lambda: env)] * num_envs) + self.batch_size = torch.Size([num_envs, *self.batch_size]) + return env @implement_for("gym", None, "0.25.1") def _set_gym_default(self, kwargs, from_pixels: bool) -> None: # noqa: F811 @@ -766,3 +880,88 @@ def lib(self) -> ModuleType: raise ImportError("MO-gymnasium not found, check installation") from err _make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs) + + +class terminal_obs_reader(BaseInfoDictReader): + """Terminal observation reader for 'vectorized' gym environments. + + When running envs in parallel, Gym(nasium) writes the result of the true call + to `step` in `"final_observation"` entry within the `info` dictionary. + + This breaks the natural flow and makes single-processed and multiprocessed envs + incompatible. + + This class reads the info obs, removes the `"final_observation"` from + the env and writes its content in the data. + + Next, a :class:`torchrl.envs.VecGymEnvTransform` transform will reorganise the + data by caching the result of the (implicit) reset and swap the true next + observation with the reset one. At reset time, the true reset data will be + replaced. + + Args: + observation_spec (CompositeSpec): The observation spec of the gym env. + backend (str, optional): the backend of the env. One of `"sb3"` for + stable-baselines3 or `"gym"` for gym/gymnasium. + + .. note:: In general, this class should not be handled directly. It is + created whenever a vectorized environment is placed within a :class:`GymWrapper`. + + """ + + backend_key = { + "sb3": "terminal_observation", + "gym": "final_observation", + } + + def __init__(self, observation_spec: CompositeSpec, backend, name="final"): + self.name = name + self._info_spec = CompositeSpec( + {(self.name, key): item.clone() for key, item in observation_spec.items()}, + shape=observation_spec.shape, + ) + self.backend = backend + + @property + def info_spec(self): + return self._info_spec + + def _read_obs(self, obs, key, tensor, index): + if obs is None: + return + if isinstance(obs, np.ndarray): + # Simplest case: there is one observation, + # presented as a np.ndarray. The key should be pixels or observation. + # We just write that value at its location in the tensor + tensor[index] = torch.as_tensor(obs, device=tensor.device) + elif isinstance(obs, dict): + if key not in obs: + raise KeyError( + f"The observation {key} could not be found in the final observation dict." + ) + subobs = obs[key] + if subobs is not None: + # if the obs is a dict, we expect that the key points also to + # a value in the obs. We retrieve this value and write it in the + # tensor + tensor[index] = torch.as_tensor(subobs, device=tensor.device) + + elif isinstance(obs, (list, tuple)): + # tuples are stacked along the first dimension when passing gym spaces + # to torchrl specs. As such, we can simply stack the tuple and set it + # at the relevant index (assuming stacking can be achieved) + tensor[index] = torch.as_tensor(obs, device=tensor.device) + else: + raise NotImplementedError( + f"Observations of type {type(obs)} are not supported yet." + ) + + def __call__(self, info_dict, tensordict): + terminal_obs = info_dict.get(self.backend_key[self.backend], None) + for key, item in self.info_spec.items(True, True): + final_obs = item.zero() + if terminal_obs is not None: + for i, obs in enumerate(terminal_obs): + self._read_obs(obs, key[-1], final_obs, index=i) + tensordict.set(key, final_obs) + return tensordict diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index 655a4eee16e..a1fc1d6d1ba 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -18,7 +18,8 @@ from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, GymEnv from torchrl.envs.utils import _classproperty, make_composite_from_td -_has_robohive = importlib.util.find_spec("robohive") is not None +_has_gym = importlib.util.find_spec("gym") is not None +_has_robohive = importlib.util.find_spec("robohive") is not None and _has_gym if _has_robohive: os.environ.setdefault("sim_backend", "MUJOCO") @@ -164,7 +165,7 @@ def _build_env( # noqa: F811 self.from_pixels = from_pixels self.render_device = render_device if kwargs.get("read_info", True): - self.info_dict_reader = self.read_info + self.set_info_dict_reader(self.read_info) return env @classmethod diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index e37e079e63b..30182870d7e 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -655,7 +655,9 @@ def _set_seed(self, seed: Optional[int]): def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): if tensordict is not None: tensordict = tensordict.clone(recurse=False) - out_tensordict = self.base_env.reset(tensordict=tensordict, **kwargs) + out_tensordict = self.base_env._reset(tensordict=tensordict, **kwargs) + if tensordict is not None: + out_tensordict = tensordict.update(out_tensordict) out_tensordict = self.transform.reset(out_tensordict) mt_mode = self.transform.missing_tolerance @@ -5030,3 +5032,84 @@ def reset(self, tensordict: TensorDictBase) -> TensorDictBase: ) action_spec.update_mask(tensordict.get(self.in_keys[1], None)) return tensordict + + +class VecGymEnvTransform(Transform): + """A transform for GymWrapper subclasses that handles the auto-reset in a consistent way. + + Gym, gymnasium and SB3 provide vectorized (read, parallel or batched) environments + that are automatically reset. When this occurs, the actual observation resulting + from the action is saved within a key in the info. + The class :class:`torchrl.envs.libs.gym.terminal_obs_reader` reads that observation + and stores it in a ``"final"`` key within the output tensordict. + In turn, this transform reads that final data, swaps it with the observation + written in its place that results from the actual reset, and saves the + reset output in a private container. The resulting data truly reflects + the output of the step. + + Then, when calling `env.reset`, the saved data is written back where it belongs + (and the `reset` is a no-op). + + This transform is automatically appended to the gym env whenever the wrapper + is created with an async env. + + Args: + final_name (str, optional): the name of the final observation in the dict. + Defaults to `"final"`. + + .. note:: In general, this class should not be handled directly. It is + created whenever a vectorized environment is placed within a :class:`GymWrapper`. + + """ + + def __init__(self, final_name="final"): + self.final_name = final_name + super().__init__(in_keys=[]) + self._memo = {} + + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: + # save the final info + done = self._memo["done"] = next_tensordict.get("done") + final = next_tensordict.pop("final") + # if anything's done, we need to swap the final obs + if done.any(): + done = done.squeeze(-1) + saved_next = next_tensordict.select(*final.keys(True, True)).clone() + next_tensordict[done] = final[done] + self._memo["saved_done"] = saved_next + else: + self._memo["saved_done"] = None + return next_tensordict + + def reset(self, tensordict: TensorDictBase) -> TensorDictBase: + done = self._memo.get("done", None) + reset = tensordict.get("_reset", done) + if done is not None: + done = done.view_as(reset) + if ( + reset is not done + and (reset != done).any() + and (not reset.all() or not reset.any()) + ): + raise RuntimeError( + "Cannot partially reset a gym(nasium) async env with a reset mask that does not match the done mask. " + f"Got reset={reset}\nand done={done}" + ) + # if not reset.any(), we don't need to do anything. + # if reset.all(), we don't either (bc GymWrapper will call a plain reset). + if reset is not None and reset.any() and not reset.all(): + saved_done = self._memo["saved_done"] + reset = reset.view(tensordict.shape) + updated_td = torch.where( + ~reset, tensordict.select(*saved_done.keys(True, True)), saved_done + ) + tensordict.update(updated_td) + tensordict.set("done", tensordict.get("done").clone().fill_(0)) + tensordict.pop("final", None) + return tensordict + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + del observation_spec[self.final_name] + return observation_spec