@@ -733,7 +733,7 @@ def __init__(
733
733
elif torch .backends .mps .is_available () and hasattr (torch , "mps" ):
734
734
# Will break for older PT versions which don't have torch.mps
735
735
self ._sync_storage = torch .mps .synchronize
736
- elif torch . npu . is_available () and hasattr (torch , "npu" ):
736
+ elif hasattr (torch , "npu" ) and torch . npu . is_available ( ):
737
737
self ._sync_storage = torch .npu .synchronize
738
738
elif self .storing_device .type == "cpu" :
739
739
self ._sync_storage = _do_nothing
@@ -749,7 +749,7 @@ def __init__(
749
749
self ._sync_env = torch .cuda .synchronize
750
750
elif torch .backends .mps .is_available () and hasattr (torch , "mps" ):
751
751
self ._sync_env = torch .mps .synchronize
752
- elif torch . npu . is_available () and hasattr (torch , "npu" ):
752
+ elif hasattr (torch , "npu" ) and torch . npu . is_available ( ):
753
753
self ._sync_env = torch .npu .synchronize
754
754
elif self .env_device .type == "cpu" :
755
755
self ._sync_env = _do_nothing
@@ -764,7 +764,7 @@ def __init__(
764
764
self ._sync_policy = torch .cuda .synchronize
765
765
elif torch .backends .mps .is_available () and hasattr (torch , "mps" ):
766
766
self ._sync_policy = torch .mps .synchronize
767
- elif torch . npu . is_available () and hasattr (torch , "npu" ):
767
+ elif hasattr (torch , "npu" ) and torch . npu . is_available ( ):
768
768
self ._sync_policy = torch .npu .synchronize
769
769
elif self .policy_device .type == "cpu" :
770
770
self ._sync_policy = _do_nothing
0 commit comments