@@ -733,6 +733,8 @@ def __init__(
             elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
                 # Will break for older PT versions which don't have torch.mps
                 self._sync_storage = torch.mps.synchronize
+            elif torch.npu.is_available() and hasattr(torch, "npu"):
+                self._sync_storage = torch.npu.synchronize
             elif self.storing_device.type == "cpu":
                 self._sync_storage = _do_nothing
             else:
@@ -747,6 +749,8 @@ def __init__(
                 self._sync_env = torch.cuda.synchronize
             elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
                 self._sync_env = torch.mps.synchronize
+            elif torch.npu.is_available() and hasattr(torch, "npu"):
+                self._sync_env = torch.npu.synchronize
             elif self.env_device.type == "cpu":
                 self._sync_env = _do_nothing
             else:
@@ -760,6 +764,8 @@ def __init__(
                 self._sync_policy = torch.cuda.synchronize
             elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
                 self._sync_policy = torch.mps.synchronize
+            elif torch.npu.is_available() and hasattr(torch, "npu"):
+                self._sync_policy = torch.npu.synchronize
             elif self.policy_device.type == "cpu":
                 self._sync_policy = _do_nothing
             else:
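All three hunks repeat the same dispatch: pick the synchronize callable that matches the collector's storing/env/policy device, falling back to a no-op on CPU, now extended to Ascend NPUs. Below is a minimal standalone sketch of that pattern, not the library's code: _resolve_sync and _do_nothing are hypothetical helpers, and the torch.npu branch assumes the torch_npu (Ascend) plugin has been imported so that torch.npu.is_available()/torch.npu.synchronize exist; unlike the added lines above, the sketch checks hasattr(torch, "npu") first so the branch stays inert on stock PyTorch.

import torch


def _do_nothing() -> None:
    """No-op used when the device needs no explicit synchronization."""


def _resolve_sync(device: torch.device):
    """Hypothetical helper mirroring the dispatch in the hunks above."""
    if torch.cuda.is_available():
        return torch.cuda.synchronize
    # torch.mps only exists on sufficiently recent PyTorch builds.
    elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
        return torch.mps.synchronize
    # torch.npu is provided by the torch_npu plugin, not stock PyTorch,
    # hence the hasattr guard before touching it (an assumption of this sketch).
    elif hasattr(torch, "npu") and torch.npu.is_available():
        return torch.npu.synchronize
    elif device.type == "cpu":
        return _do_nothing
    else:
        raise RuntimeError(f"Unsupported device for synchronization: {device}")


# Usage: sync_fn = _resolve_sync(torch.device("cpu")); sync_fn()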