Skip to content

Commit 0adbb6c

Browse files
author
Vincent Moens
committed
[BugFix] Fix collector with no buffers and devices
ghstack-source-id: 89bc0df Pull Request resolved: #2809
1 parent 3acf491 commit 0adbb6c

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

test/test_collector.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3093,16 +3093,23 @@ def test_dynamic_sync_collector(self):
30933093
assert isinstance(data, LazyStackedTensorDict)
30943094
assert data.names[-1] == "time"
30953095

3096-
def test_dynamic_multisync_collector(self):
3096+
@pytest.mark.parametrize("policy_device", [None, *get_default_devices()])
3097+
def test_dynamic_multisync_collector(self, policy_device):
30973098
env = EnvWithDynamicSpec
3098-
policy = RandomPolicy(env().action_spec)
3099+
spec = env().action_spec
3100+
if policy_device is not None:
3101+
spec = spec.to(policy_device)
3102+
policy = RandomPolicy(spec)
30993103
collector = MultiSyncDataCollector(
31003104
[env],
31013105
policy,
31023106
frames_per_batch=20,
31033107
total_frames=100,
31043108
use_buffers=False,
31053109
cat_results="stack",
3110+
policy_device=policy_device,
3111+
env_device="cpu",
3112+
storing_device="cpu",
31063113
)
31073114
for data in collector:
31083115
assert isinstance(data, LazyStackedTensorDict)

0 commit comments

Comments
 (0)