@@ -2214,7 +2214,7 @@ def register_gym(
22142214 nondeterministic : bool = False ,
22152215 max_episode_steps : int | None = None ,
22162216 order_enforce : bool = True ,
2217- autoreset : bool = False ,
2217+ autoreset : bool | None = None ,
22182218 disable_env_checker : bool = False ,
22192219 apply_api_compatibility : bool = False ,
22202220 ** kwargs ,
@@ -2276,12 +2276,12 @@ def register_gym(
22762276 enforcer wrapper should be applied to ensure users run functions
22772277 in the correct order.
22782278 Defaults to ``True``.
2279- autoreset (bool, optional): [Gym >= 0.14] Whether the autoreset wrapper
2279+ autoreset (bool, optional): [Gym >= 0.14 and <1.0.0 ] Whether the autoreset wrapper
22802280 should be added such that reset does not need to be called.
22812281 Defaults to ``None``, which is treated as ``False`` where the argument is supported.
22822282 disable_env_checker: [Gym >= 0.14] Whether the environment
22832283 checker should be disabled for the environment. Defaults to ``False``.
2284- apply_api_compatibility: [Gym >= 0.26] If to apply the `StepAPICompatibility` wrapper.
2284+ apply_api_compatibility: [Gym >= 0.26 and <1.0.0 ] Whether to apply the `StepAPICompatibility` wrapper.
22852285 Defaults to ``False``.
22862286 **kwargs: arbitrary keyword arguments which are passed to the environment constructor.
22872287
@@ -2403,7 +2403,7 @@ def _register_gym(
24032403 nondeterministic : bool = False ,
24042404 max_episode_steps : int | None = None ,
24052405 order_enforce : bool = True ,
2406- autoreset : bool = False ,
2406+ autoreset : bool | None = None ,
24072407 disable_env_checker : bool = False ,
24082408 apply_api_compatibility : bool = False ,
24092409 ** kwargs ,
@@ -2428,7 +2428,7 @@ def _register_gym(
24282428 nondeterministic = nondeterministic ,
24292429 max_episode_steps = max_episode_steps ,
24302430 order_enforce = order_enforce ,
2431- autoreset = autoreset ,
2431+ autoreset = bool ( autoreset ) ,
24322432 disable_env_checker = disable_env_checker ,
24332433 apply_api_compatibility = apply_api_compatibility ,
24342434 )
@@ -2445,7 +2445,7 @@ def _register_gym( # noqa: F811
24452445 nondeterministic : bool = False ,
24462446 max_episode_steps : int | None = None ,
24472447 order_enforce : bool = True ,
2448- autoreset : bool = False ,
2448+ autoreset : bool | None = None ,
24492449 disable_env_checker : bool = False ,
24502450 apply_api_compatibility : bool = False ,
24512451 ** kwargs ,
@@ -2477,7 +2477,7 @@ def _register_gym( # noqa: F811
24772477 nondeterministic = nondeterministic ,
24782478 max_episode_steps = max_episode_steps ,
24792479 order_enforce = order_enforce ,
2480- autoreset = autoreset ,
2480+ autoreset = bool ( autoreset ) ,
24812481 disable_env_checker = disable_env_checker ,
24822482 )
24832483
@@ -2493,7 +2493,7 @@ def _register_gym( # noqa: F811
24932493 nondeterministic : bool = False ,
24942494 max_episode_steps : int | None = None ,
24952495 order_enforce : bool = True ,
2496- autoreset : bool = False ,
2496+ autoreset : bool | None = None ,
24972497 disable_env_checker : bool = False ,
24982498 apply_api_compatibility : bool = False ,
24992499 ** kwargs ,
@@ -2531,7 +2531,7 @@ def _register_gym( # noqa: F811
25312531 nondeterministic = nondeterministic ,
25322532 max_episode_steps = max_episode_steps ,
25332533 order_enforce = order_enforce ,
2534- autoreset = autoreset ,
2534+ autoreset = bool ( autoreset ) ,
25352535 )
25362536
25372537 @implement_for ("gym" , "0.21" , "0.24" , class_method = True )
@@ -2546,7 +2546,7 @@ def _register_gym( # noqa: F811
25462546 nondeterministic : bool = False ,
25472547 max_episode_steps : int | None = None ,
25482548 order_enforce : bool = True ,
2549- autoreset : bool = False ,
2549+ autoreset : bool | None = None ,
25502550 disable_env_checker : bool = False ,
25512551 apply_api_compatibility : bool = False ,
25522552 ** kwargs ,
@@ -2565,7 +2565,7 @@ def _register_gym( # noqa: F811
25652565 "disable_env_checker" , gym .__version__
25662566 )
25672567 )
2568- if autoreset is not False :
2568+ if autoreset is not None :
25692569 raise TypeError (
25702570 cls ._GYM_UNRECOGNIZED_KWARG .format ("autoreset" , gym .__version__ )
25712571 )
@@ -2602,7 +2602,7 @@ def _register_gym( # noqa: F811
26022602 nondeterministic : bool = False ,
26032603 max_episode_steps : int | None = None ,
26042604 order_enforce : bool = True ,
2605- autoreset : bool = False ,
2605+ autoreset : bool | None = None ,
26062606 disable_env_checker : bool = False ,
26072607 apply_api_compatibility : bool = False ,
26082608 ** kwargs ,
@@ -2620,7 +2620,7 @@ def _register_gym( # noqa: F811
26202620 "disable_env_checker" , gym .__version__
26212621 )
26222622 )
2623- if autoreset is not False :
2623+ if autoreset is not None :
26242624 raise TypeError (
26252625 cls ._GYM_UNRECOGNIZED_KWARG .format ("autoreset" , gym .__version__ )
26262626 )
@@ -2648,7 +2648,7 @@ def _register_gym( # noqa: F811
26482648 max_episode_steps = max_episode_steps ,
26492649 )
26502650
2651- @implement_for ("gymnasium" , class_method = True )
2651+ @implement_for ("gymnasium" , None , "1.0.0" , class_method = True )
26522652 def _register_gym ( # noqa: F811
26532653 cls ,
26542654 id ,
@@ -2660,7 +2660,7 @@ def _register_gym( # noqa: F811
26602660 nondeterministic : bool = False ,
26612661 max_episode_steps : int | None = None ,
26622662 order_enforce : bool = True ,
2663- autoreset : bool = False ,
2663+ autoreset : bool | None = None ,
26642664 disable_env_checker : bool = False ,
26652665 apply_api_compatibility : bool = False ,
26662666 ** kwargs ,
@@ -2686,11 +2686,62 @@ def _register_gym( # noqa: F811
26862686 nondeterministic = nondeterministic ,
26872687 max_episode_steps = max_episode_steps ,
26882688 order_enforce = order_enforce ,
2689- autoreset = autoreset ,
2689+ autoreset = bool ( autoreset ) ,
26902690 disable_env_checker = disable_env_checker ,
26912691 apply_api_compatibility = apply_api_compatibility ,
26922692 )
26932693
# Overload of `_register_gym` decorated for the "gymnasium" backend with "1.1.0"
# as its first version argument (presumably the inclusive lower bound - confirm
# against `implement_for`).
# NOTE(review): the sibling gymnasium overload in this file is declared with
# (None, "1.0.0"), so no decorator visible here covers the versions between
# 1.0.0 and 1.1.0 - confirm that gap is intentional.
#
# Behaviour shown by the code below:
#   * raises TypeError if `autoreset` is anything other than None;
#   * raises TypeError if `apply_api_compatibility` is anything other than the
#     literal False;
#   * otherwise wraps the entry point in `_TorchRLGymnasiumWrapper` and returns
#     the result of `gymnasium.register(...)`.
2694+ @implement_for ("gymnasium" , "1.1.0" , class_method = True )
2695+ def _register_gym ( # noqa: F811
2696+ cls ,
2697+ id ,
2698+ entry_point : Callable | None = None ,
2699+ transform : Transform | None = None , # noqa: F821
2700+ info_keys : list [NestedKey ] | None = None ,
2701+ to_numpy : bool = False ,
2702+ reward_threshold : float | None = None ,
2703+ nondeterministic : bool = False ,
2704+ max_episode_steps : int | None = None ,
2705+ order_enforce : bool = True ,
2706+ autoreset : bool | None = None ,
2707+ disable_env_checker : bool = False ,
2708+ apply_api_compatibility : bool = False ,
2709+ ** kwargs ,
2710+ ):
# Imports are function-local, so gymnasium and the wrapper are only loaded when
# this overload actually runs.
2711+ import gymnasium
2712+ from torchrl .envs .libs ._gym_utils import _TorchRLGymnasiumWrapper
2713+
# `autoreset` is kept in the signature but rejected outright: any explicit value
# (including False) raises, since only the None default passes this check.
# NOTE(review): the message says "deprecated" although the call is refused, and
# it does not use `cls._GYM_UNRECOGNIZED_KWARG` as the older-gym overloads do for
# the same argument - consider aligning the wording.
2714+ if autoreset is not None :
2715+ raise TypeError (
2716+ f"the autoreset argument is deprecated in gymnasium>=1.0. Got autoreset={ autoreset } "
2717+ )
# With no explicit entry point, the class this method is called on is registered.
2718+ if entry_point is None :
2719+ entry_point = cls
2720+
# Rebind `entry_point` to a partial of the wrapper that carries the original
# entry point, the TorchRL-specific options, and every extra keyword argument.
2721+ entry_point = partial (
2722+ _TorchRLGymnasiumWrapper ,
2723+ entry_point = entry_point ,
2724+ info_keys = info_keys ,
2725+ to_numpy = to_numpy ,
2726+ transform = transform ,
2727+ ** kwargs ,
2728+ )
# Identity check against False: any other value (True, None, 0, ...) is rejected.
# This validation runs after the partial is built; nothing has been registered
# yet at this point, so the ordering looks harmless.
2729+ if apply_api_compatibility is not False :
2730+ raise TypeError (
2731+ cls ._GYM_UNRECOGNIZED_KWARG .format (
2732+ "apply_api_compatibility" , gymnasium .__version__
2733+ )
2734+ )
# Neither `autoreset` nor `apply_api_compatibility` is forwarded to
# `gymnasium.register` in this overload.
2735+ return gymnasium .register (
2736+ id = id ,
2737+ entry_point = entry_point ,
2738+ reward_threshold = reward_threshold ,
2739+ nondeterministic = nondeterministic ,
2740+ max_episode_steps = max_episode_steps ,
2741+ order_enforce = order_enforce ,
2742+ disable_env_checker = disable_env_checker ,
2743+ )
2744+
26942745 def forward (self , * args , ** kwargs ):
26952746 raise NotImplementedError (
26962747 "EnvBase.forward is not implemented. If you ended here during a call to `ParallelEnv(...)`, please use "
0 commit comments