@@ -318,6 +318,21 @@ def _make_spec( # noqa: F811
318318 shape = batch_size ,
319319 )
320320
321+ @implement_for ("gymnasium" , "1.1.0" )
322+ def _make_spec ( # noqa: F811
323+ self , batch_size , cat , cat_shape , multicat , multicat_shape
324+ ):
325+ return Composite (
326+ a = Unbounded (shape = (* batch_size , 1 )),
327+ b = Composite (c = cat (5 , shape = cat_shape , dtype = torch .int64 ), shape = batch_size ),
328+ d = cat (5 , shape = cat_shape , dtype = torch .int64 ),
329+ e = multicat ([2 , 3 ], shape = (* batch_size , multicat_shape ), dtype = torch .int64 ),
330+ f = Bounded (- 3 , 4 , shape = (* batch_size , 1 )),
331+ g = UnboundedDiscreteTensorSpec (shape = (* batch_size , 1 ), dtype = torch .long ),
332+ h = Binary (n = 5 , shape = (* batch_size , 5 )),
333+ shape = batch_size ,
334+ )
335+
321336 @pytest .mark .parametrize ("categorical" , [True , False ])
322337 def test_gym_spec_cast (self , categorical ):
323338 batch_size = [3 , 4 ]
@@ -379,10 +394,17 @@ def test_gym_spec_cast_tuple_sequential(self, order):
379394 torchrl_logger .info ("Sequence not available in gym" )
380395 return
381396
382- # @pytest.mark.parametrize("order", ["seq_tuple", "tuple_seq"])
397+ @pytest .mark .parametrize ("order" , ["tuple_seq" ])
398+ @implement_for ("gymnasium" , "1.1.0" )
399+ def test_gym_spec_cast_tuple_sequential (self , order ): # noqa: F811
400+ self ._test_gym_spec_cast_tuple_sequential (order )
401+
383402 @pytest .mark .parametrize ("order" , ["tuple_seq" ])
384403 @implement_for ("gymnasium" , None , "1.0.0" )
385404 def test_gym_spec_cast_tuple_sequential (self , order ): # noqa: F811
405+ self ._test_gym_spec_cast_tuple_sequential (order )
406+
407+ def _test_gym_spec_cast_tuple_sequential (self , order ): # noqa: F811
386408 with set_gym_backend ("gymnasium" ):
387409 if order == "seq_tuple" :
388410 # Requires nested tensors to be created along dim=1, disabling
@@ -974,8 +996,15 @@ def info_reader(info, tensordict):
974996 finally :
975997 set_gym_backend (gb ).set ()
976998
977- @implement_for ("gymnasium" , None , "1.0.0" )
999+ @implement_for ("gymnasium" , "1.1.0" )
9781000 def test_one_hot_and_categorical (self ):
1001+ self ._test_one_hot_and_categorical ()
1002+
1003+ @implement_for ("gymnasium" , None , "1.0.0" )
1004+ def test_one_hot_and_categorical (self ): # noqa
1005+ self ._test_one_hot_and_categorical ()
1006+
1007+ def _test_one_hot_and_categorical (self ):
9791008 # tests that one-hot and categorical work ok when an integer is expected as action
9801009 cliff_walking = GymEnv ("CliffWalking-v0" , categorical_action_encoding = True )
9811010 cliff_walking .rollout (10 )
@@ -993,14 +1022,27 @@ def test_one_hot_and_categorical(self): # noqa: F811
9931022 # versions.
9941023 return
9951024
996- @implement_for ("gymnasium" , None , "1.0.0" )
1025+ @implement_for ("gymnasium" , "1.1.0" )
9971026 @pytest .mark .parametrize (
9981027 "envname" ,
9991028 ["HalfCheetah-v4" , "CartPole-v1" , "ALE/Pong-v5" ]
10001029 + (["FetchReach-v2" ] if _has_gym_robotics else []),
10011030 )
10021031 @pytest .mark .flaky (reruns = 5 , reruns_delay = 1 )
10031032 def test_vecenvs_wrapper (self , envname ):
1033+ self ._test_vecenvs_wrapper (envname )
1034+
1035+ @implement_for ("gymnasium" , None , "1.0.0" )
1036+ @pytest .mark .parametrize (
1037+ "envname" ,
1038+ ["HalfCheetah-v4" , "CartPole-v1" , "ALE/Pong-v5" ]
1039+ + (["FetchReach-v2" ] if _has_gym_robotics else []),
1040+ )
1041+ @pytest .mark .flaky (reruns = 5 , reruns_delay = 1 )
1042+ def test_vecenvs_wrapper (self , envname ): # noqa
1043+ self ._test_vecenvs_wrapper (envname )
1044+
1045+ def _test_vecenvs_wrapper (self , envname ):
10041046 import gymnasium
10051047
10061048 # we can't use parametrize with implement_for
@@ -1019,7 +1061,7 @@ def test_vecenvs_wrapper(self, envname):
10191061 assert env .batch_size == torch .Size ([2 ])
10201062 check_env_specs (env )
10211063
1022- @implement_for ("gymnasium" , None , "1.0.0" )
1064+ @implement_for ("gymnasium" , "1.1.0" )
10231065 # this env has Dict-based observation which is a nice thing to test
10241066 @pytest .mark .parametrize (
10251067 "envname" ,
@@ -1028,6 +1070,21 @@ def test_vecenvs_wrapper(self, envname):
10281070 )
10291071 @pytest .mark .flaky (reruns = 5 , reruns_delay = 1 )
10301072 def test_vecenvs_env (self , envname ):
1073+ self ._test_vecenvs_env (envname )
1074+
1075+ @implement_for ("gymnasium" , None , "1.0.0" )
1076+ # this env has Dict-based observation which is a nice thing to test
1077+ @pytest .mark .parametrize (
1078+ "envname" ,
1079+ ["HalfCheetah-v4" , "CartPole-v1" , "ALE/Pong-v5" ]
1080+ + (["FetchReach-v2" ] if _has_gym_robotics else []),
1081+ )
1082+ @pytest .mark .flaky (reruns = 5 , reruns_delay = 1 )
1083+ def test_vecenvs_env (self , envname ): # noqa
1084+ self ._test_vecenvs_env (envname )
1085+
1086+ def _test_vecenvs_env (self , envname ):
1087+
10311088 gb = gym_backend ()
10321089 try :
10331090 with set_gym_backend ("gymnasium" ):
@@ -1181,9 +1238,17 @@ def test_gym_output_num(self, wrapper): # noqa: F811
11811238 finally :
11821239 set_gym_backend (gym ).set ()
11831240
1241+ @implement_for ("gymnasium" , "1.1.0" )
1242+ @pytest .mark .parametrize ("wrapper" , [True , False ])
1243+ def test_gym_output_num (self , wrapper ): # noqa: F811
1244+ self ._test_gym_output_num (wrapper )
1245+
11841246 @implement_for ("gymnasium" , None , "1.0.0" )
11851247 @pytest .mark .parametrize ("wrapper" , [True , False ])
11861248 def test_gym_output_num (self , wrapper ): # noqa: F811
1249+ self ._test_gym_output_num (wrapper )
1250+
1251+ def _test_gym_output_num (self , wrapper ): # noqa: F811
11871252 # gym has 5 outputs, with truncation
11881253 gym = gym_backend ()
11891254 try :
@@ -1284,8 +1349,15 @@ def test_vecenvs_nan(self): # noqa: F811
12841349 del c
12851350 return
12861351
1352+ @implement_for ("gymnasium" , "1.1.0" )
1353+ def test_vecenvs_nan (self ): # noqa: F811
1354+ self ._test_vecenvs_nan ()
1355+
12871356 @implement_for ("gymnasium" , None , "1.0.0" )
12881357 def test_vecenvs_nan (self ): # noqa: F811
1358+ self ._test_vecenvs_nan ()
1359+
1360+ def _test_vecenvs_nan (self ): # noqa: F811
12891361 # new versions of gym must never return nan for next values when there is a done state
12901362 torch .manual_seed (0 )
12911363 env = GymEnv ("CartPole-v1" , num_envs = 2 )
@@ -1352,6 +1424,98 @@ def step(self, action):
13521424
13531425 return CustomEnv (** kwargs )
13541426
1427+ @pytest .fixture (scope = "function" )
1428+ def counting_env (self ):
1429+ import gymnasium as gym
1430+ from gymnasium import Env
1431+
1432+ class CountingEnvRandomReset (Env ):
1433+ def __init__ (self , i = 0 ):
1434+ self .counter = 1
1435+ self .i = i
1436+ self .observation_space = gym .spaces .Box (- np .inf , np .inf , shape = (1 ,))
1437+ self .action_space = gym .spaces .Box (- np .inf , np .inf , shape = (1 ,))
1438+ self .rng = np .random .RandomState (0 )
1439+
1440+ def step (self , action ):
1441+ self .counter += 1
1442+ done = bool (self .rng .random () < 0.05 )
1443+ return (
1444+ np .asarray (
1445+ [
1446+ self .counter ,
1447+ ]
1448+ ),
1449+ 0 ,
1450+ done ,
1451+ done ,
1452+ {},
1453+ )
1454+
1455+ def reset (
1456+ self ,
1457+ * ,
1458+ seed : int | None = None ,
1459+ options = None ,
1460+ ):
1461+ self .counter = 1
1462+ if seed is not None :
1463+ self .rng = np .random .RandomState (seed )
1464+ return (
1465+ np .asarray (
1466+ [
1467+ self .counter ,
1468+ ]
1469+ ),
1470+ {},
1471+ )
1472+
1473+ yield CountingEnvRandomReset
1474+
1475+ @implement_for ("gym" )
1476+ def test_gymnasium_autoreset (self , venv ):
1477+ return
1478+
1479+ @implement_for ("gymnasium" , None , "1.1.0" )
1480+ def test_gymnasium_autoreset (self , venv ): # noqa
1481+ return
1482+
1483+ @implement_for ("gymnasium" , "1.1.0" )
1484+ @pytest .mark .parametrize ("venv" , ["sync" , "async" ])
1485+ def test_gymnasium_autoreset (self , venv , counting_env ): # noqa
1486+ import gymnasium as gym
1487+
1488+ if venv == "sync" :
1489+ venv = gym .vector .SyncVectorEnv
1490+ else :
1491+ venv = gym .vector .AsyncVectorEnv
1492+ envs0 = venv (
1493+ [lambda i = i : counting_env (i ) for i in range (2 )],
1494+ autoreset_mode = gym .vector .AutoresetMode .DISABLED ,
1495+ )
1496+ env = GymWrapper (envs0 )
1497+ envs0 .reset (seed = 0 )
1498+ torch .manual_seed (0 )
1499+ r0 = env .rollout (20 , break_when_any_done = False )
1500+ envs1 = venv (
1501+ [lambda i = i : counting_env (i ) for i in range (2 )],
1502+ autoreset_mode = gym .vector .AutoresetMode .SAME_STEP ,
1503+ )
1504+ env = GymWrapper (envs1 )
1505+ envs1 .reset (seed = 0 )
1506+ # env.set_seed(0)
1507+ torch .manual_seed (0 )
1508+ r1 = []
1509+ t_ = env .reset ()
1510+ for s in r0 .unbind (- 1 ):
1511+ t_ .set ("action" , s ["action" ])
1512+ t , t_ = env .step_and_maybe_reset (t_ )
1513+ r1 .append (t )
1514+ r1 = torch .stack (r1 , - 1 )
1515+ torch .testing .assert_close (r0 ["observation" ], r1 ["observation" ])
1516+ torch .testing .assert_close (r0 ["next" , "observation" ], r1 ["next" , "observation" ])
1517+ torch .testing .assert_close (r0 ["next" , "done" ], r1 ["next" , "done" ])
1518+
13551519 @pytest .mark .parametrize ("heterogeneous" , [False , True ])
13561520 def test_resetting_strategies (self , heterogeneous ):
13571521 if _has_gymnasium :
@@ -1461,6 +1625,12 @@ def _make_gym_environment(env_name): # noqa: F811
14611625 return gym .make (env_name , render_mode = "rgb_array" )
14621626
14631627
1628+ @implement_for ("gymnasium" , "1.1.0" )
1629+ def _make_gym_environment (env_name ): # noqa: F811
1630+ gym = gym_backend ()
1631+ return gym .make (env_name , render_mode = "rgb_array" )
1632+
1633+
14641634@pytest .mark .skipif (not _has_dmc , reason = "no dm_control library found" )
14651635class TestDMControl :
14661636 @pytest .mark .parametrize ("env_name,task" , [["cheetah" , "run" ]])
0 commit comments