File tree Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -3093,16 +3093,23 @@ def test_dynamic_sync_collector(self):
3093
3093
assert isinstance (data , LazyStackedTensorDict )
3094
3094
assert data .names [- 1 ] == "time"
3095
3095
3096
- def test_dynamic_multisync_collector (self ):
3096
+ @pytest .mark .parametrize ("policy_device" , [None , * get_default_devices ()])
3097
+ def test_dynamic_multisync_collector (self , policy_device ):
3097
3098
env = EnvWithDynamicSpec
3098
- policy = RandomPolicy (env ().action_spec )
3099
+ spec = env ().action_spec
3100
+ if policy_device is not None :
3101
+ spec = spec .to (policy_device )
3102
+ policy = RandomPolicy (spec )
3099
3103
collector = MultiSyncDataCollector (
3100
3104
[env ],
3101
3105
policy ,
3102
3106
frames_per_batch = 20 ,
3103
3107
total_frames = 100 ,
3104
3108
use_buffers = False ,
3105
3109
cat_results = "stack" ,
3110
+ policy_device = policy_device ,
3111
+ env_device = "cpu" ,
3112
+ storing_device = "cpu" ,
3106
3113
)
3107
3114
for data in collector :
3108
3115
assert isinstance (data , LazyStackedTensorDict )
You can’t perform that action at this time.
0 commit comments