Skip to content

Commit d58723e

Browse files
authored
[BugFix] Fix 'npu' attribute (#3159)
1 parent a2bd989 commit d58723e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torchrl/collectors/collectors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,7 @@ def __init__(
733733
elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
734734
# Will break for older PT versions which don't have torch.mps
735735
self._sync_storage = torch.mps.synchronize
736-
elif torch.npu.is_available() and hasattr(torch, "npu"):
736+
elif hasattr(torch, "npu") and torch.npu.is_available():
737737
self._sync_storage = torch.npu.synchronize
738738
elif self.storing_device.type == "cpu":
739739
self._sync_storage = _do_nothing
@@ -749,7 +749,7 @@ def __init__(
749749
self._sync_env = torch.cuda.synchronize
750750
elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
751751
self._sync_env = torch.mps.synchronize
752-
elif torch.npu.is_available() and hasattr(torch, "npu"):
752+
elif hasattr(torch, "npu") and torch.npu.is_available():
753753
self._sync_env = torch.npu.synchronize
754754
elif self.env_device.type == "cpu":
755755
self._sync_env = _do_nothing
@@ -764,7 +764,7 @@ def __init__(
764764
self._sync_policy = torch.cuda.synchronize
765765
elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
766766
self._sync_policy = torch.mps.synchronize
767-
elif torch.npu.is_available() and hasattr(torch, "npu"):
767+
elif hasattr(torch, "npu") and torch.npu.is_available():
768768
self._sync_policy = torch.npu.synchronize
769769
elif self.policy_device.type == "cpu":
770770
self._sync_policy = _do_nothing

0 commit comments

Comments
 (0)