Skip to content

Commit a2bd989

Browse files
authored
[Feature] Add NPU support for SyncDataCollector (#3155)
1 parent 3e6de39 commit a2bd989

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

torchrl/collectors/collectors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,8 @@ def __init__(
733733
elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
734734
# Will break for older PT versions which don't have torch.mps
735735
self._sync_storage = torch.mps.synchronize
736+
elif torch.npu.is_available() and hasattr(torch, "npu"):
737+
self._sync_storage = torch.npu.synchronize
736738
elif self.storing_device.type == "cpu":
737739
self._sync_storage = _do_nothing
738740
else:
@@ -747,6 +749,8 @@ def __init__(
747749
self._sync_env = torch.cuda.synchronize
748750
elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
749751
self._sync_env = torch.mps.synchronize
752+
elif torch.npu.is_available() and hasattr(torch, "npu"):
753+
self._sync_env = torch.npu.synchronize
750754
elif self.env_device.type == "cpu":
751755
self._sync_env = _do_nothing
752756
else:
@@ -760,6 +764,8 @@ def __init__(
760764
self._sync_policy = torch.cuda.synchronize
761765
elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
762766
self._sync_policy = torch.mps.synchronize
767+
elif torch.npu.is_available() and hasattr(torch, "npu"):
768+
self._sync_policy = torch.npu.synchronize
763769
elif self.policy_device.type == "cpu":
764770
self._sync_policy = _do_nothing
765771
else:

0 commit comments

Comments
 (0)