Skip to content

Commit 160e6af

Browse files
author
Vincent Moens
committed
[Feature] Gymnasium 1.0 compatibility
ghstack-source-id: eed26cf Pull Request resolved: #2473
1 parent fac1f7a commit 160e6af

File tree

3 files changed

+79
-28
lines changed

3 files changed

+79
-28
lines changed

.github/unittest/linux_libs/scripts_gym/batch_scripts.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ do
118118
done
119119

120120
# For this version "gym[accept-rom-license]" is required.
121-
for GYM_VERSION in '0.27' '0.28'
121+
for GYM_VERSION in '0.27' '0.28' '0.29'
122122
do
123123
# Create a copy of the conda env and work with this
124124
conda deactivate
@@ -140,7 +140,7 @@ conda deactivate
140140
conda create --prefix ./cloned_env --clone ./env -y
141141
conda activate ./cloned_env
142142

143-
pip3 install 'gymnasium[accept-rom-license,ale-py,atari]<1.0' mo-gymnasium gymnasium-robotics -U
143+
pip3 install 'gymnasium[accept-rom-license,ale-py,atari]' mo-gymnasium gymnasium-robotics -U
144144

145145
$DIR/run_test.sh
146146

torchrl/envs/common.py

Lines changed: 72 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1622,9 +1622,9 @@ def register_gym(
16221622
nondeterministic: bool = False,
16231623
max_episode_steps: int | None = None,
16241624
order_enforce: bool = True,
1625-
autoreset: bool = False,
1625+
autoreset: bool | None = None,
16261626
disable_env_checker: bool = False,
1627-
apply_api_compatibility: bool = False,
1627+
apply_api_compatibility: bool | None = None,
16281628
**kwargs,
16291629
):
16301630
"""Registers an environment in gym(nasium).
@@ -1811,9 +1811,9 @@ def _register_gym(
18111811
nondeterministic: bool = False,
18121812
max_episode_steps: int | None = None,
18131813
order_enforce: bool = True,
1814-
autoreset: bool = False,
1814+
autoreset: bool | None = None,
18151815
disable_env_checker: bool = False,
1816-
apply_api_compatibility: bool = False,
1816+
apply_api_compatibility: bool | None = None,
18171817
**kwargs,
18181818
):
18191819
import gym
@@ -1836,9 +1836,9 @@ def _register_gym(
18361836
nondeterministic=nondeterministic,
18371837
max_episode_steps=max_episode_steps,
18381838
order_enforce=order_enforce,
1839-
autoreset=autoreset,
1839+
autoreset=bool(autoreset),
18401840
disable_env_checker=disable_env_checker,
1841-
apply_api_compatibility=apply_api_compatibility,
1841+
apply_api_compatibility=bool(apply_api_compatibility),
18421842
)
18431843

18441844
@implement_for("gym", "0.25", "0.26", class_method=True)
@@ -1853,14 +1853,14 @@ def _register_gym( # noqa: F811
18531853
nondeterministic: bool = False,
18541854
max_episode_steps: int | None = None,
18551855
order_enforce: bool = True,
1856-
autoreset: bool = False,
1856+
autoreset: bool | None = None,
18571857
disable_env_checker: bool = False,
1858-
apply_api_compatibility: bool = False,
1858+
apply_api_compatibility: bool | None = None,
18591859
**kwargs,
18601860
):
18611861
import gym
18621862

1863-
if apply_api_compatibility is not False:
1863+
if apply_api_compatibility is not None:
18641864
raise TypeError(
18651865
cls._GYM_UNRECOGNIZED_KWARG.format(
18661866
"apply_api_compatibility", gym.__version__
@@ -1901,14 +1901,14 @@ def _register_gym( # noqa: F811
19011901
nondeterministic: bool = False,
19021902
max_episode_steps: int | None = None,
19031903
order_enforce: bool = True,
1904-
autoreset: bool = False,
1904+
autoreset: bool | None = None,
19051905
disable_env_checker: bool = False,
1906-
apply_api_compatibility: bool = False,
1906+
apply_api_compatibility: bool | None = None,
19071907
**kwargs,
19081908
):
19091909
import gym
19101910

1911-
if apply_api_compatibility is not False:
1911+
if apply_api_compatibility is not None:
19121912
raise TypeError(
19131913
cls._GYM_UNRECOGNIZED_KWARG.format(
19141914
"apply_api_compatibility", gym.__version__
@@ -1954,14 +1954,14 @@ def _register_gym( # noqa: F811
19541954
nondeterministic: bool = False,
19551955
max_episode_steps: int | None = None,
19561956
order_enforce: bool = True,
1957-
autoreset: bool = False,
1957+
autoreset: bool | None = None,
19581958
disable_env_checker: bool = False,
1959-
apply_api_compatibility: bool = False,
1959+
apply_api_compatibility: bool | None = None,
19601960
**kwargs,
19611961
):
19621962
import gym
19631963

1964-
if apply_api_compatibility is not False:
1964+
if apply_api_compatibility is not None:
19651965
raise TypeError(
19661966
cls._GYM_UNRECOGNIZED_KWARG.format(
19671967
"apply_api_compatibility", gym.__version__
@@ -1973,7 +1973,7 @@ def _register_gym( # noqa: F811
19731973
"disable_env_checker", gym.__version__
19741974
)
19751975
)
1976-
if autoreset is not False:
1976+
if autoreset is not None:
19771977
raise TypeError(
19781978
cls._GYM_UNRECOGNIZED_KWARG.format("autoreset", gym.__version__)
19791979
)
@@ -2010,9 +2010,9 @@ def _register_gym( # noqa: F811
20102010
nondeterministic: bool = False,
20112011
max_episode_steps: int | None = None,
20122012
order_enforce: bool = True,
2013-
autoreset: bool = False,
2013+
autoreset: bool | None = None,
20142014
disable_env_checker: bool = False,
2015-
apply_api_compatibility: bool = False,
2015+
apply_api_compatibility: bool | None = None,
20162016
**kwargs,
20172017
):
20182018
import gym
@@ -2028,11 +2028,11 @@ def _register_gym( # noqa: F811
20282028
"disable_env_checker", gym.__version__
20292029
)
20302030
)
2031-
if autoreset is not False:
2031+
if autoreset is not None:
20322032
raise TypeError(
20332033
cls._GYM_UNRECOGNIZED_KWARG.format("autoreset", gym.__version__)
20342034
)
2035-
if apply_api_compatibility is not False:
2035+
if apply_api_compatibility is not None:
20362036
raise TypeError(
20372037
cls._GYM_UNRECOGNIZED_KWARG.format(
20382038
"apply_api_compatibility", gym.__version__
@@ -2056,7 +2056,7 @@ def _register_gym( # noqa: F811
20562056
max_episode_steps=max_episode_steps,
20572057
)
20582058

2059-
@implement_for("gymnasium", class_method=True)
2059+
@implement_for("gymnasium", None, "1.0", class_method=True)
20602060
def _register_gym( # noqa: F811
20612061
cls,
20622062
id,
@@ -2068,9 +2068,9 @@ def _register_gym( # noqa: F811
20682068
nondeterministic: bool = False,
20692069
max_episode_steps: int | None = None,
20702070
order_enforce: bool = True,
2071-
autoreset: bool = False,
2071+
autoreset: bool | None = None,
20722072
disable_env_checker: bool = False,
2073-
apply_api_compatibility: bool = False,
2073+
apply_api_compatibility: bool | None = None,
20742074
**kwargs,
20752075
):
20762076
import gymnasium
@@ -2094,9 +2094,56 @@ def _register_gym( # noqa: F811
20942094
nondeterministic=nondeterministic,
20952095
max_episode_steps=max_episode_steps,
20962096
order_enforce=order_enforce,
2097-
autoreset=autoreset,
2097+
autoreset=bool(autoreset),
2098+
disable_env_checker=disable_env_checker,
2099+
apply_api_compatibility=bool(apply_api_compatibility),
2100+
)
2101+
2102+
@implement_for("gymnasium", "1.0", class_method=True)
2103+
def _register_gym( # noqa: F811
2104+
cls,
2105+
id,
2106+
entry_point: Callable | None = None,
2107+
transform: "Transform" | None = None, # noqa: F821
2108+
info_keys: List[NestedKey] | None = None,
2109+
to_numpy: bool = False,
2110+
reward_threshold: float | None = None,
2111+
nondeterministic: bool = False,
2112+
max_episode_steps: int | None = None,
2113+
order_enforce: bool = True,
2114+
autoreset: bool | None = False,
2115+
disable_env_checker: bool = False,
2116+
apply_api_compatibility: bool | None = None,
2117+
**kwargs,
2118+
):
2119+
import gymnasium
2120+
from torchrl.envs.libs._gym_utils import _TorchRLGymnasiumWrapper
2121+
2122+
if entry_point is None:
2123+
entry_point = cls
2124+
2125+
entry_point = functools.partial(
2126+
_TorchRLGymnasiumWrapper,
2127+
entry_point=entry_point,
2128+
info_keys=info_keys,
2129+
to_numpy=to_numpy,
2130+
transform=transform,
2131+
**kwargs,
2132+
)
2133+
if autoreset is not None:
2134+
raise TypeError("autoreset is only compatible with gymnasium<1.0.")
2135+
if apply_api_compatibility is not None:
2136+
raise TypeError(
2137+
"apply_api_compatibility is only compatible with gymnasium<1.0."
2138+
)
2139+
return gymnasium.register(
2140+
id=id,
2141+
entry_point=entry_point,
2142+
reward_threshold=reward_threshold,
2143+
nondeterministic=nondeterministic,
2144+
max_episode_steps=max_episode_steps,
2145+
order_enforce=order_enforce,
20982146
disable_env_checker=disable_env_checker,
2099-
apply_api_compatibility=apply_api_compatibility,
21002147
)
21012148

21022149
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:

torchrl/envs/libs/_gym_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,11 @@ def step(self, action): # noqa: F811
154154
return out
155155

156156
@implement_for("gymnasium")
157-
def reset(self): # noqa: F811
157+
def reset(self, seed=None, options=None): # noqa: F811
158+
if seed is not None:
159+
self.torchrl_env.set_seed(seed)
160+
if options is not None:
161+
raise TypeError("options is not supported in torchrl envs.")
158162
self._tensordict = self.torchrl_env.reset()
159163
observation = self._tensordict
160164
if self.info_keys:

0 commit comments

Comments
 (0)