4 files changed, +20 −3 lines changed

@@ -114,12 +114,16 @@ def __init__(
         else:
             self.random_erasing = None
         self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
+        self.is_npu = torch.npu.is_available() and device.type == 'npu'

     def __iter__(self):
         first = True
         if self.is_cuda:
             stream = torch.cuda.Stream()
             stream_context = partial(torch.cuda.stream, stream=stream)
+        elif self.is_npu:
+            stream = torch.npu.Stream()
+            stream_context = partial(torch.npu.stream, stream=stream)
         else:
             stream = None
             stream_context = suppress
@@ -139,7 +143,10 @@ def __iter__(self):
                 first = False

             if stream is not None:
-                torch.cuda.current_stream().wait_stream(stream)
+                if self.is_cuda:
+                    torch.cuda.current_stream().wait_stream(stream)
+                elif self.is_npu:
+                    torch.npu.current_stream().wait_stream(stream)

             input = next_input
             target = next_target
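The loader change mirrors the existing CUDA side-stream prefetch logic for Ascend NPUs. Below is a minimal sketch of the selection pattern, assuming the torch_npu plugin is installed (importing it registers a torch.npu namespace that mirrors torch.cuda); the helper name is illustrative, not part of the PR:

# Sketch: pick a side stream and a stream context manager for the active
# accelerator, falling back to a no-op context on CPU.
from contextlib import suppress
from functools import partial

import torch

def make_stream_context(device: torch.device):
    if device.type == 'cuda' and torch.cuda.is_available():
        stream = torch.cuda.Stream()
        return stream, partial(torch.cuda.stream, stream=stream)
    if device.type == 'npu' and torch.npu.is_available():  # needs torch_npu
        stream = torch.npu.Stream()
        return stream, partial(torch.npu.stream, stream=stream)
    return None, suppress  # CPU: no side stream, suppress() is a no-op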
@@ -116,6 +116,7 @@ def init_distributed_device_so(
         "xpu": "ccl",
         "hpu": "hccl",
         "cuda": "nccl",
+        "npu": "hccl",
     }
     dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
@@ -159,6 +160,8 @@ def init_distributed_device_so(

     if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
+    if device_type == 'npu':
+        assert torch.npu.is_available(), f'Ascend NPU is not available but {device} was specified.'

     if distributed and device != 'cpu':
         # Ignore manually specified device index in distributed mode and
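The backend table routes Ascend NPUs to HCCL, Huawei's collective-communications library and the NPU analogue of NCCL. A hedged sketch of how such a mapping feeds torch.distributed (the helper is hypothetical, not the PR's code):

# Hypothetical helper: select a torch.distributed backend by device type,
# defaulting to 'gloo' for anything without a native collective library.
import torch.distributed as dist

dist_backends = {"xpu": "ccl", "hpu": "hccl", "cuda": "nccl", "npu": "hccl"}

def init_distributed(device_type: str, dist_url: str = 'env://') -> str:
    backend = dist_backends.get(device_type, 'gloo')
    dist.init_process_group(backend=backend, init_method=dist_url)
    return backend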
@@ -1054,8 +1054,11 @@ def _backward(_loss):
         if model_ema is not None:
             model_ema.update(model, step=num_updates)

-        if args.synchronize_step and device.type == 'cuda':
-            torch.cuda.synchronize()
+        if args.synchronize_step:
+            if device.type == 'cuda':
+                torch.cuda.synchronize()
+            elif device.type == 'npu':
+                torch.npu.synchronize()
         time_now = time.time()
         update_time_m.update(time.time() - update_start_time)
         update_start_time = time_now
@@ -1155,6 +1158,8 @@ def validate(

             if device.type == 'cuda':
                 torch.cuda.synchronize()
+            elif device.type == "npu":
+                torch.npu.synchronize()

             losses_m.update(reduced_loss.item(), input.size(0))
             top1_m.update(acc1.item(), output.size(0))
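Both the train and validate hunks repeat the same synchronize dispatch before reading timings or reduced metrics. A small hypothetical helper (not in the PR) captures the pattern once:

# Hypothetical helper: block the host until all queued work on the given
# accelerator has finished, so wall-clock timings measure completed work.
import torch

def device_synchronize(device: torch.device) -> None:
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elif device.type == 'npu':  # assumes the torch_npu plugin is loaded
        torch.npu.synchronize()
    # CPU and other device types need no explicit synchronization here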
@@ -397,6 +397,8 @@ def _try_run(args, initial_batch_size):
     try:
         if torch.cuda.is_available() and 'cuda' in args.device:
             torch.cuda.empty_cache()
+        elif torch.npu.is_available() and "npu" in args.device:
+            torch.npu.empty_cache()
         results = validate(args)
         return results
     except RuntimeError as e:
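This hunk releases cached allocator blocks on the active accelerator before each attempt, presumably so a retry after an out-of-memory failure starts from a clean slate. An illustrative sketch of that retry shape (the function names and halving policy are assumptions, not the PR's code):

# Illustrative OOM-retry loop: empty the device cache, run, and on an
# out-of-memory RuntimeError retry with a smaller batch size.
import torch

def empty_device_cache(device_str: str) -> None:
    if torch.cuda.is_available() and 'cuda' in device_str:
        torch.cuda.empty_cache()
    elif 'npu' in device_str:  # assumes torch_npu registers torch.npu
        torch.npu.empty_cache()

def run_with_retry(run_fn, device_str: str, batch_size: int):
    while batch_size >= 1:
        try:
            empty_device_cache(device_str)
            return run_fn(batch_size)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            batch_size //= 2  # assumed decay policy
    raise RuntimeError('no batch size small enough to fit in memory')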