Skip to content

Commit 8b4c97f

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 5d316d8 commit 8b4c97f

File tree

4 files changed

+154
-30
lines changed

4 files changed

+154
-30
lines changed

test/test_libs.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,7 +1424,6 @@ def step(self, action):
14241424

14251425
return CustomEnv(**kwargs)
14261426

1427-
@pytest.fixture(scope="function")
14281427
def counting_env(self):
14291428
import gymnasium as gym
14301429
from gymnasium import Env
@@ -1482,9 +1481,10 @@ def test_gymnasium_autoreset(self, venv): # noqa
14821481

14831482
@implement_for("gymnasium", "1.1.0")
14841483
@pytest.mark.parametrize("venv", ["sync", "async"])
1485-
def test_gymnasium_autoreset(self, venv, counting_env): # noqa
1484+
def test_gymnasium_autoreset(self, venv): # noqa
14861485
import gymnasium as gym
14871486

1487+
counting_env = self.counting_env()
14881488
if venv == "sync":
14891489
venv = gym.vector.SyncVectorEnv
14901490
else:
@@ -1516,8 +1516,25 @@ def test_gymnasium_autoreset(self, venv, counting_env): # noqa
15161516
torch.testing.assert_close(r0["next", "observation"], r1["next", "observation"])
15171517
torch.testing.assert_close(r0["next", "done"], r1["next", "done"])
15181518

1519+
@implement_for("gym")
1520+
def test_resetting_strategies(self):
1521+
return
1522+
1523+
@implement_for("gymnasium", None, "1.0.0")
15191524
@pytest.mark.parametrize("heterogeneous", [False, True])
1520-
def test_resetting_strategies(self, heterogeneous):
1525+
def test_resetting_strategies(self, heterogeneous): # noqa
1526+
self._test_resetting_strategies(heterogeneous, {})
1527+
1528+
@implement_for("gymnasium", "1.1.0")
1529+
@pytest.mark.parametrize("heterogeneous", [False, True])
1530+
def test_resetting_strategies(self, heterogeneous): # noqa
1531+
import gymnasium as gym
1532+
1533+
self._test_resetting_strategies(
1534+
heterogeneous, {"autoreset_mode": gym.vector.AutoresetMode.SAME_STEP}
1535+
)
1536+
1537+
def _test_resetting_strategies(self, heterogeneous, kwargs):
15211538
if _has_gymnasium:
15221539
backend = "gymnasium"
15231540
else:
@@ -1533,7 +1550,8 @@ def test_resetting_strategies(self, heterogeneous):
15331550
env = GymWrapper(
15341551
gym_backend().vector.AsyncVectorEnv(
15351552
[functools.partial(self._get_dummy_gym_env, backend=backend)]
1536-
* 4
1553+
* 4,
1554+
**kwargs,
15371555
)
15381556
)
15391557
else:
@@ -1546,7 +1564,8 @@ def test_resetting_strategies(self, heterogeneous):
15461564
backend=backend,
15471565
)
15481566
for i in range(4)
1549-
]
1567+
],
1568+
**kwargs,
15501569
)
15511570
)
15521571
try:

torchrl/envs/common.py

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2214,7 +2214,7 @@ def register_gym(
22142214
nondeterministic: bool = False,
22152215
max_episode_steps: int | None = None,
22162216
order_enforce: bool = True,
2217-
autoreset: bool = False,
2217+
autoreset: bool | None = None,
22182218
disable_env_checker: bool = False,
22192219
apply_api_compatibility: bool = False,
22202220
**kwargs,
@@ -2276,12 +2276,12 @@ def register_gym(
22762276
enforcer wrapper should be applied to ensure users run functions
22772277
in the correct order.
22782278
Defaults to ``True``.
2279-
autoreset (bool, optional): [Gym >= 0.14] Whether the autoreset wrapper
2279+
autoreset (bool, optional): [Gym >= 0.14 and <1.0.0] Whether the autoreset wrapper
22802280
should be added such that reset does not need to be called.
22812281
Defaults to ``False``.
22822282
disable_env_checker: [Gym >= 0.14] Whether the environment
22832283
checker should be disabled for the environment. Defaults to ``False``.
2284-
apply_api_compatibility: [Gym >= 0.26] If to apply the `StepAPICompatibility` wrapper.
2284+
apply_api_compatibility: [Gym >= 0.26 and <1.0.0] If to apply the `StepAPICompatibility` wrapper.
22852285
Defaults to ``False``.
22862286
**kwargs: arbitrary keyword arguments which are passed to the environment constructor.
22872287
@@ -2403,7 +2403,7 @@ def _register_gym(
24032403
nondeterministic: bool = False,
24042404
max_episode_steps: int | None = None,
24052405
order_enforce: bool = True,
2406-
autoreset: bool = False,
2406+
autoreset: bool | None = None,
24072407
disable_env_checker: bool = False,
24082408
apply_api_compatibility: bool = False,
24092409
**kwargs,
@@ -2428,7 +2428,7 @@ def _register_gym(
24282428
nondeterministic=nondeterministic,
24292429
max_episode_steps=max_episode_steps,
24302430
order_enforce=order_enforce,
2431-
autoreset=autoreset,
2431+
autoreset=bool(autoreset),
24322432
disable_env_checker=disable_env_checker,
24332433
apply_api_compatibility=apply_api_compatibility,
24342434
)
@@ -2445,7 +2445,7 @@ def _register_gym( # noqa: F811
24452445
nondeterministic: bool = False,
24462446
max_episode_steps: int | None = None,
24472447
order_enforce: bool = True,
2448-
autoreset: bool = False,
2448+
autoreset: bool | None = None,
24492449
disable_env_checker: bool = False,
24502450
apply_api_compatibility: bool = False,
24512451
**kwargs,
@@ -2477,7 +2477,7 @@ def _register_gym( # noqa: F811
24772477
nondeterministic=nondeterministic,
24782478
max_episode_steps=max_episode_steps,
24792479
order_enforce=order_enforce,
2480-
autoreset=autoreset,
2480+
autoreset=bool(autoreset),
24812481
disable_env_checker=disable_env_checker,
24822482
)
24832483

@@ -2493,7 +2493,7 @@ def _register_gym( # noqa: F811
24932493
nondeterministic: bool = False,
24942494
max_episode_steps: int | None = None,
24952495
order_enforce: bool = True,
2496-
autoreset: bool = False,
2496+
autoreset: bool | None = None,
24972497
disable_env_checker: bool = False,
24982498
apply_api_compatibility: bool = False,
24992499
**kwargs,
@@ -2531,7 +2531,7 @@ def _register_gym( # noqa: F811
25312531
nondeterministic=nondeterministic,
25322532
max_episode_steps=max_episode_steps,
25332533
order_enforce=order_enforce,
2534-
autoreset=autoreset,
2534+
autoreset=bool(autoreset),
25352535
)
25362536

25372537
@implement_for("gym", "0.21", "0.24", class_method=True)
@@ -2546,7 +2546,7 @@ def _register_gym( # noqa: F811
25462546
nondeterministic: bool = False,
25472547
max_episode_steps: int | None = None,
25482548
order_enforce: bool = True,
2549-
autoreset: bool = False,
2549+
autoreset: bool | None = None,
25502550
disable_env_checker: bool = False,
25512551
apply_api_compatibility: bool = False,
25522552
**kwargs,
@@ -2565,7 +2565,7 @@ def _register_gym( # noqa: F811
25652565
"disable_env_checker", gym.__version__
25662566
)
25672567
)
2568-
if autoreset is not False:
2568+
if autoreset is not None:
25692569
raise TypeError(
25702570
cls._GYM_UNRECOGNIZED_KWARG.format("autoreset", gym.__version__)
25712571
)
@@ -2602,7 +2602,7 @@ def _register_gym( # noqa: F811
26022602
nondeterministic: bool = False,
26032603
max_episode_steps: int | None = None,
26042604
order_enforce: bool = True,
2605-
autoreset: bool = False,
2605+
autoreset: bool | None = None,
26062606
disable_env_checker: bool = False,
26072607
apply_api_compatibility: bool = False,
26082608
**kwargs,
@@ -2620,7 +2620,7 @@ def _register_gym( # noqa: F811
26202620
"disable_env_checker", gym.__version__
26212621
)
26222622
)
2623-
if autoreset is not False:
2623+
if autoreset is not None:
26242624
raise TypeError(
26252625
cls._GYM_UNRECOGNIZED_KWARG.format("autoreset", gym.__version__)
26262626
)
@@ -2648,7 +2648,7 @@ def _register_gym( # noqa: F811
26482648
max_episode_steps=max_episode_steps,
26492649
)
26502650

2651-
@implement_for("gymnasium", class_method=True)
2651+
@implement_for("gymnasium", None, "1.0.0", class_method=True)
26522652
def _register_gym( # noqa: F811
26532653
cls,
26542654
id,
@@ -2660,7 +2660,7 @@ def _register_gym( # noqa: F811
26602660
nondeterministic: bool = False,
26612661
max_episode_steps: int | None = None,
26622662
order_enforce: bool = True,
2663-
autoreset: bool = False,
2663+
autoreset: bool | None = None,
26642664
disable_env_checker: bool = False,
26652665
apply_api_compatibility: bool = False,
26662666
**kwargs,
@@ -2686,11 +2686,62 @@ def _register_gym( # noqa: F811
26862686
nondeterministic=nondeterministic,
26872687
max_episode_steps=max_episode_steps,
26882688
order_enforce=order_enforce,
2689-
autoreset=autoreset,
2689+
autoreset=bool(autoreset),
26902690
disable_env_checker=disable_env_checker,
26912691
apply_api_compatibility=apply_api_compatibility,
26922692
)
26932693

2694+
@implement_for("gymnasium", "1.1.0", class_method=True)
2695+
def _register_gym( # noqa: F811
2696+
cls,
2697+
id,
2698+
entry_point: Callable | None = None,
2699+
transform: Transform | None = None, # noqa: F821
2700+
info_keys: list[NestedKey] | None = None,
2701+
to_numpy: bool = False,
2702+
reward_threshold: float | None = None,
2703+
nondeterministic: bool = False,
2704+
max_episode_steps: int | None = None,
2705+
order_enforce: bool = True,
2706+
autoreset: bool | None = None,
2707+
disable_env_checker: bool = False,
2708+
apply_api_compatibility: bool = False,
2709+
**kwargs,
2710+
):
2711+
import gymnasium
2712+
from torchrl.envs.libs._gym_utils import _TorchRLGymnasiumWrapper
2713+
2714+
if autoreset is not None:
2715+
raise TypeError(
2716+
f"the autoreset argument is deprecated in gymnasium>=1.0. Got autoreset={autoreset}"
2717+
)
2718+
if entry_point is None:
2719+
entry_point = cls
2720+
2721+
entry_point = partial(
2722+
_TorchRLGymnasiumWrapper,
2723+
entry_point=entry_point,
2724+
info_keys=info_keys,
2725+
to_numpy=to_numpy,
2726+
transform=transform,
2727+
**kwargs,
2728+
)
2729+
if apply_api_compatibility is not False:
2730+
raise TypeError(
2731+
cls._GYM_UNRECOGNIZED_KWARG.format(
2732+
"apply_api_compatibility", gymnasium.__version__
2733+
)
2734+
)
2735+
return gymnasium.register(
2736+
id=id,
2737+
entry_point=entry_point,
2738+
reward_threshold=reward_threshold,
2739+
nondeterministic=nondeterministic,
2740+
max_episode_steps=max_episode_steps,
2741+
order_enforce=order_enforce,
2742+
disable_env_checker=disable_env_checker,
2743+
)
2744+
26942745
def forward(self, *args, **kwargs):
26952746
raise NotImplementedError(
26962747
"EnvBase.forward is not implemented. If you ended here during a call to `ParallelEnv(...)`, please use "

torchrl/envs/libs/_gym_utils.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _action_keys(self):
125125
import gymnasium
126126

127127
class _TorchRLGymnasiumWrapper(gymnasium.Env, _BaseGymWrapper):
128-
@implement_for("gymnasium", "1.0.0")
128+
@implement_for("gymnasium", "1.0.0", "1.1.0")
129129
def step(self, action): # noqa: F811
130130
raise ImportError(GYMNASIUM_1_ERROR)
131131

@@ -157,9 +157,43 @@ def step(self, action): # noqa: F811
157157
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
158158
return out
159159

160+
@implement_for("gymnasium", "1.1.0")
161+
def step(self, action): # noqa: F811
162+
action_keys = self._action_keys
163+
if len(action_keys) == 1:
164+
self._tensordict.set(action_keys[0], action)
165+
else:
166+
raise RuntimeError(
167+
"Wrapping environments with more than one action key is not supported yet."
168+
)
169+
self.torchrl_env.step(self._tensordict)
170+
_tensordict = step_mdp(self._tensordict)
171+
observation = self._tensordict.get("next")
172+
if self.info_keys:
173+
info = observation.select(*self.info_keys).to_dict()
174+
else:
175+
info = {}
176+
observation = observation.select(*self._observation_keys).to_dict()
177+
reward = self._tensordict.get(("next", "reward"))
178+
terminated = self._tensordict.get(("next", "terminated"))
179+
truncated = self._tensordict.get(
180+
("next", "truncated"), torch.zeros_like(terminated)
181+
)
182+
self._tensordict = _tensordict.select(*self._input_keys)
183+
out = (observation, reward, terminated, truncated, info)
184+
if self.to_numpy:
185+
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
186+
return out
187+
160188
@implement_for("gymnasium", None, "1.0.0")
161-
def reset(self): # noqa: F811
162-
self._tensordict = self.torchrl_env.reset()
189+
def reset(
190+
self, seed: int | None = None, options: dict | None = None
191+
): # noqa: F811
192+
if seed is not None:
193+
self.torchrl_env.set_seed(seed)
194+
if options is None:
195+
options = {}
196+
self._tensordict = self.torchrl_env.reset(**options)
163197
observation = self._tensordict
164198
if self.info_keys:
165199
info = observation.select(*self.info_keys).to_dict()
@@ -171,10 +205,30 @@ def reset(self): # noqa: F811
171205
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
172206
return out
173207

174-
@implement_for("gymnasium", "1.0.0")
208+
@implement_for("gymnasium", "1.0.0", "1.1.0")
175209
def reset(self): # noqa: F811
176210
raise ImportError(GYMNASIUM_1_ERROR)
177211

212+
@implement_for("gymnasium", "1.1.0")
213+
def reset( # noqa: F811
214+
self, seed: int | None = None, options: dict | None = None
215+
):
216+
if seed is not None:
217+
self.torchrl_env.set_seed(seed)
218+
if options is None:
219+
options = {}
220+
self._tensordict = self.torchrl_env.reset(**options)
221+
observation = self._tensordict
222+
if self.info_keys:
223+
info = observation.select(*self.info_keys).to_dict()
224+
else:
225+
info = {}
226+
observation = observation.select(*self._observation_keys).to_dict()
227+
out = observation, info
228+
if self.to_numpy:
229+
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
230+
return out
231+
178232
else:
179233

180234
class _TorchRLGymnasiumWrapper:

torchrl/envs/libs/gym.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,14 +1033,14 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs):
10331033

10341034
@implement_for("gymnasium", "1.1.0")
10351035
def _validate_env(self, env):
1036-
auto_reset_mode = getattr(env, "auto_reset_mode", None)
1037-
if auto_reset_mode is not None:
1038-
from gymnasium import AutoResetMode
1036+
autoreset_mode = getattr(env, "autoreset_mode", None)
1037+
if autoreset_mode is not None:
1038+
from gymnasium.vector import AutoresetMode
10391039

1040-
if auto_reset_mode not in (AutoResetMode.DISABLED, AutoResetMode.SAME_STEP):
1040+
if autoreset_mode not in (AutoresetMode.DISABLED, AutoresetMode.SAME_STEP):
10411041
raise RuntimeError(
10421042
"The auto-reset mode must be one of SAME_STEP or DISABLED (which is preferred). Got "
1043-
f"auto_reset_mode={auto_reset_mode}."
1043+
f"autoreset_mode={autoreset_mode}."
10441044
)
10451045

10461046
@implement_for("gym", None, "1.1.0")

0 commit comments

Comments
 (0)