@@ -1622,9 +1622,9 @@ def register_gym(
16221622 nondeterministic : bool = False ,
16231623 max_episode_steps : int | None = None ,
16241624 order_enforce : bool = True ,
1625- autoreset : bool = False ,
1625+ autoreset : bool | None = None ,
16261626 disable_env_checker : bool = False ,
1627- apply_api_compatibility : bool = False ,
1627+ apply_api_compatibility : bool | None = None ,
16281628 ** kwargs ,
16291629 ):
16301630 """Registers an environment in gym(nasium).
@@ -1811,9 +1811,9 @@ def _register_gym(
18111811 nondeterministic : bool = False ,
18121812 max_episode_steps : int | None = None ,
18131813 order_enforce : bool = True ,
1814- autoreset : bool = False ,
1814+ autoreset : bool | None = None ,
18151815 disable_env_checker : bool = False ,
1816- apply_api_compatibility : bool = False ,
1816+ apply_api_compatibility : bool | None = None ,
18171817 ** kwargs ,
18181818 ):
18191819 import gym
@@ -1836,9 +1836,9 @@ def _register_gym(
18361836 nondeterministic = nondeterministic ,
18371837 max_episode_steps = max_episode_steps ,
18381838 order_enforce = order_enforce ,
1839- autoreset = autoreset ,
1839+ autoreset = bool ( autoreset ) ,
18401840 disable_env_checker = disable_env_checker ,
1841- apply_api_compatibility = apply_api_compatibility ,
1841+ apply_api_compatibility = bool ( apply_api_compatibility ) ,
18421842 )
18431843
18441844 @implement_for ("gym" , "0.25" , "0.26" , class_method = True )
@@ -1853,14 +1853,14 @@ def _register_gym( # noqa: F811
18531853 nondeterministic : bool = False ,
18541854 max_episode_steps : int | None = None ,
18551855 order_enforce : bool = True ,
1856- autoreset : bool = False ,
1856+ autoreset : bool | None = None ,
18571857 disable_env_checker : bool = False ,
1858- apply_api_compatibility : bool = False ,
1858+ apply_api_compatibility : bool | None = None ,
18591859 ** kwargs ,
18601860 ):
18611861 import gym
18621862
1863- if apply_api_compatibility is not False :
1863+ if apply_api_compatibility is not None :
18641864 raise TypeError (
18651865 cls ._GYM_UNRECOGNIZED_KWARG .format (
18661866 "apply_api_compatibility" , gym .__version__
@@ -1901,14 +1901,14 @@ def _register_gym( # noqa: F811
19011901 nondeterministic : bool = False ,
19021902 max_episode_steps : int | None = None ,
19031903 order_enforce : bool = True ,
1904- autoreset : bool = False ,
1904+ autoreset : bool | None = None ,
19051905 disable_env_checker : bool = False ,
1906- apply_api_compatibility : bool = False ,
1906+ apply_api_compatibility : bool | None = None ,
19071907 ** kwargs ,
19081908 ):
19091909 import gym
19101910
1911- if apply_api_compatibility is not False :
1911+ if apply_api_compatibility is not None :
19121912 raise TypeError (
19131913 cls ._GYM_UNRECOGNIZED_KWARG .format (
19141914 "apply_api_compatibility" , gym .__version__
@@ -1954,14 +1954,14 @@ def _register_gym( # noqa: F811
19541954 nondeterministic : bool = False ,
19551955 max_episode_steps : int | None = None ,
19561956 order_enforce : bool = True ,
1957- autoreset : bool = False ,
1957+ autoreset : bool | None = None ,
19581958 disable_env_checker : bool = False ,
1959- apply_api_compatibility : bool = False ,
1959+ apply_api_compatibility : bool | None = None ,
19601960 ** kwargs ,
19611961 ):
19621962 import gym
19631963
1964- if apply_api_compatibility is not False :
1964+ if apply_api_compatibility is not None :
19651965 raise TypeError (
19661966 cls ._GYM_UNRECOGNIZED_KWARG .format (
19671967 "apply_api_compatibility" , gym .__version__
@@ -1973,7 +1973,7 @@ def _register_gym( # noqa: F811
19731973 "disable_env_checker" , gym .__version__
19741974 )
19751975 )
1976- if autoreset is not False :
1976+ if autoreset is not None :
19771977 raise TypeError (
19781978 cls ._GYM_UNRECOGNIZED_KWARG .format ("autoreset" , gym .__version__ )
19791979 )
@@ -2010,9 +2010,9 @@ def _register_gym( # noqa: F811
20102010 nondeterministic : bool = False ,
20112011 max_episode_steps : int | None = None ,
20122012 order_enforce : bool = True ,
2013- autoreset : bool = False ,
2013+ autoreset : bool | None = None ,
20142014 disable_env_checker : bool = False ,
2015- apply_api_compatibility : bool = False ,
2015+ apply_api_compatibility : bool | None = None ,
20162016 ** kwargs ,
20172017 ):
20182018 import gym
@@ -2028,11 +2028,11 @@ def _register_gym( # noqa: F811
20282028 "disable_env_checker" , gym .__version__
20292029 )
20302030 )
2031- if autoreset is not False :
2031+ if autoreset is not None :
20322032 raise TypeError (
20332033 cls ._GYM_UNRECOGNIZED_KWARG .format ("autoreset" , gym .__version__ )
20342034 )
2035- if apply_api_compatibility is not False :
2035+ if apply_api_compatibility is not None :
20362036 raise TypeError (
20372037 cls ._GYM_UNRECOGNIZED_KWARG .format (
20382038 "apply_api_compatibility" , gym .__version__
@@ -2056,7 +2056,7 @@ def _register_gym( # noqa: F811
20562056 max_episode_steps = max_episode_steps ,
20572057 )
20582058
2059- @implement_for ("gymnasium" , class_method = True )
2059+ @implement_for ("gymnasium" , None , "1.0" , class_method = True )
20602060 def _register_gym ( # noqa: F811
20612061 cls ,
20622062 id ,
@@ -2068,9 +2068,9 @@ def _register_gym( # noqa: F811
20682068 nondeterministic : bool = False ,
20692069 max_episode_steps : int | None = None ,
20702070 order_enforce : bool = True ,
2071- autoreset : bool = False ,
2071+ autoreset : bool | None = None ,
20722072 disable_env_checker : bool = False ,
2073- apply_api_compatibility : bool = False ,
2073+ apply_api_compatibility : bool | None = None ,
20742074 ** kwargs ,
20752075 ):
20762076 import gymnasium
@@ -2094,9 +2094,56 @@ def _register_gym( # noqa: F811
20942094 nondeterministic = nondeterministic ,
20952095 max_episode_steps = max_episode_steps ,
20962096 order_enforce = order_enforce ,
2097- autoreset = autoreset ,
2097+ autoreset = bool (autoreset ),
2098+ disable_env_checker = disable_env_checker ,
2099+ apply_api_compatibility = bool (apply_api_compatibility ),
2100+ )
2101+
2102+ @implement_for ("gymnasium" , "1.0" , class_method = True )
2103+ def _register_gym ( # noqa: F811
2104+ cls ,
2105+ id ,
2106+ entry_point : Callable | None = None ,
2107+ transform : "Transform" | None = None , # noqa: F821
2108+ info_keys : List [NestedKey ] | None = None ,
2109+ to_numpy : bool = False ,
2110+ reward_threshold : float | None = None ,
2111+ nondeterministic : bool = False ,
2112+ max_episode_steps : int | None = None ,
2113+ order_enforce : bool = True ,
2114+ autoreset : bool | None = False ,
2115+ disable_env_checker : bool = False ,
2116+ apply_api_compatibility : bool | None = None ,
2117+ ** kwargs ,
2118+ ):
2119+ import gymnasium
2120+ from torchrl .envs .libs ._gym_utils import _TorchRLGymnasiumWrapper
2121+
2122+ if entry_point is None :
2123+ entry_point = cls
2124+
2125+ entry_point = functools .partial (
2126+ _TorchRLGymnasiumWrapper ,
2127+ entry_point = entry_point ,
2128+ info_keys = info_keys ,
2129+ to_numpy = to_numpy ,
2130+ transform = transform ,
2131+ ** kwargs ,
2132+ )
2133+ if autoreset is not None :
2134+ raise TypeError ("autoreset is only compatible with gymnasium<1.0." )
2135+ if apply_api_compatibility is not None :
2136+ raise TypeError (
2137+ "apply_api_compatibility is only compatible with gymnasium<1.0."
2138+ )
2139+ return gymnasium .register (
2140+ id = id ,
2141+ entry_point = entry_point ,
2142+ reward_threshold = reward_threshold ,
2143+ nondeterministic = nondeterministic ,
2144+ max_episode_steps = max_episode_steps ,
2145+ order_enforce = order_enforce ,
20982146 disable_env_checker = disable_env_checker ,
2099- apply_api_compatibility = apply_api_compatibility ,
21002147 )
21012148
21022149 def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
0 commit comments