@@ -230,6 +230,7 @@ class GeneratorArgs:
     max_autotune: bool = False
     # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
     is_torchtune_model: bool = False
+    accumulate_tokens: int = 8
 
     def __post_init__(self):
         if self.compile_prefill and self.sequential_prefill:
@@ -294,6 +295,7 @@ def from_args(cls, args):
             sequential_prefill=sequential_prefill,
             max_autotune=args.max_autotune,
             is_torchtune_model=args.model and args.model.endswith("tune"),
+            accumulate_tokens=getattr(args, "accumulate_tokens", 8),
         )
 
 
@@ -530,12 +532,13 @@ def decode_n_tokens(
         need_probs: bool,
         batch=Optional[Dict[str, Any]],  # Inputs for multimodal models
         callback=lambda _: _,
+        accumulate_tokens: int = 8,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         **sampling_kwargs,
     ):
-        new_tokens, new_probs = [], []
+        new_tokens = []
         encountered_eos = False
         for _i in range(
             num_new_tokens - 1
@@ -554,38 +557,58 @@ def decode_n_tokens(
                 )
                 input_pos += 1
                 new_tokens.append(next_token.clone())
-                callback(new_tokens[-1], done_generating=_i == num_new_tokens - 2)
-                if need_probs or next_prob is None:
+
+                done_generating = _i == num_new_tokens - 2
+                if need_probs:
+                    callback(new_tokens[-1], done_generating=done_generating)
+                if not need_probs or next_prob is None:
                     yield out_token, None
                 else:
-                    new_probs.append(next_prob.clone())
                     yield out_token, next_prob.clone()
                 cur_token = next_token
 
-                # encountered eos
-                if next_token.item() == eos_token_id or (
-                    eot_id is not None and next_token.item() == eot_id
-                ):
-                    encountered_eos = True
-                    final_token, next_prob = self.decode_one_token(
-                        model,
-                        cur_token,
-                        input_pos,
-                        need_probs,
-                        batch=batch,
-                        **sampling_kwargs,
-                    )
-                    input_pos += 1
-                    yield cur_token.clone(), next_prob.clone()
-                    break
+                if need_probs:
+                    # encountered eos
+                    if next_token.item() == eos_token_id or (
+                        eot_id is not None and next_token.item() == eot_id
+                    ):
+                        encountered_eos = True
+                        final_token, next_prob = self.decode_one_token(
+                            model,
+                            cur_token,
+                            input_pos,
+                            need_probs,
+                            batch=batch,
+                            **sampling_kwargs,
+                        )
+                        input_pos += 1
+                        yield cur_token.clone(), next_prob.clone()
+                        break
+                else:
+                    callback_pos = _i % accumulate_tokens + 1
+                    if done_generating or callback_pos == accumulate_tokens:
+                        callback_num = min(accumulate_tokens, callback_pos)
+                        for i in range(callback_num, 0, -1):
+                            callback(new_tokens[-i], done_generating=done_generating)
+
+                            token_item = new_tokens[-i].item()
+                            # encountered eos
+                            if token_item == eos_token_id or (
+                                eot_id is not None and token_item == eot_id
+                            ):
+                                encountered_eos = True
+                                input_pos += 1
+                                yield new_tokens[-i].clone(), None
+                                break
+                        if encountered_eos:
+                            break
 
         if not encountered_eos:
             eos_token = torch.tensor(
                 [eos_token_id if eot_id is None else eot_id],
                 dtype=cur_token.dtype,
                 device=cur_token.device,
             )
-            new_tokens.append(eos_token.clone())
             eos_token, next_prob = self.decode_one_token(
                 model,
                 eos_token.view(1, -1),
@@ -685,6 +708,7 @@ def generate(
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
         callback=lambda x: x,
+        accumulate_tokens: int,
         max_seq_length: int,
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         seed: Optional[int] = None,
@@ -788,14 +812,14 @@ def generate(
             input_pos = input_pos + num_added
             next_token = next_tokens[-1]
         else:
-            generated_tokens = []
             for generated_token, _ in self.decode_n_tokens(
                 model,
                 next_token,
                 input_pos,
                 max_new_tokens - 1,
                 batch=batch,
                 callback=callback,
+                accumulate_tokens=accumulate_tokens,
                 need_probs=False,
                 eos_token_id=self.tokenizer.eos_id() if self.tokenizer else 2,
                 eot_id=(
@@ -806,7 +830,6 @@ def generate(
                 attention_backend=attention_backend,
                 **sampling_kwargs,
             ):
-                generated_tokens.append(generated_token.view(-1))
                 yield generated_token, None
 
         generate_stats = {
@@ -1185,6 +1208,7 @@ def callback(x, *, done_generating=False):
                 chat_mode=generator_args.chat_mode,
                 batch=batch,
                 callback=callback,
+                accumulate_tokens=generator_args.accumulate_tokens,
                 temperature=generator_args.temperature,
                 top_k=generator_args.top_k,
                 sequential_prefill=generator_args.sequential_prefill,
@@ -1213,8 +1237,10 @@ def callback(x, *, done_generating=False):
                     print(prof.key_averages().table(sort_by="self_cpu_time_total"))
                 elif self.builder_args.device == "cuda":
                     print(prof.key_averages().table(sort_by="self_cuda_time_total"))
-                else:
+                elif self.builder_args.device == "xpu":
                     print(prof.key_averages().table(sort_by="self_xpu_time_total"))
+                elif self.builder_args.device == "npu":
+                    print(prof.key_averages().table(sort_by="self_npu_time_total"))
                 prof.export_chrome_trace(f"{self.profile}.json")
 
             if start_pos >= max_seq_length:
@@ -1229,11 +1255,7 @@ def callback(x, *, done_generating=False):
                 t - aggregate_metrics.get("time_to_first_token", 0)
             )
 
-            if jit_compile:
-                print(
-                    f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds"
-                )
-            else:
+            if not jit_compile:
                 # aggregate_metrics will not append when is jit_compile, which will affect the average numbers.
                 aggregate_metrics["tokens_per_sec"].append(tokens_sec)
                 aggregate_metrics["first_token_per_sec"].append(first_token_sec)
@@ -1257,6 +1279,10 @@ def callback(x, *, done_generating=False):
                 logging.info(
                     f"*** This first iteration will include cold start effects for dynamic import, hardware caches{', JIT compilation' if jit_compile else ''}. ***"
                 )
+                if jit_compile:
+                    logging.info(
+                        f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds"
+                    )
             print("\n========================================\n")
             if start_pos >= max_seq_length:
                 if generator_args.chat_mode:
@@ -1299,8 +1325,10 @@ def callback(x, *, done_generating=False):
         )
         if torch.cuda.is_available():
             print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
-        if torch.xpu.is_available():
+        elif torch.xpu.is_available():
             print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
+        elif hasattr(torch, "npu") and torch.npu.is_available():
+            print(f"Memory used: {torch.npu.max_memory_reserved() / 1e9:.02f} GB")
 
 
 
@@ -1595,7 +1623,6 @@ def sample(
 
     return idx_next, probs
 
-
 def run_generator(
     args,
     rank: Optional[int] = None
@@ -1628,8 +1655,10 @@ def run_generator(
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
-    if torch.xpu.is_available():
+    elif torch.xpu.is_available():
         torch.xpu.reset_peak_memory_stats()
+    elif hasattr(torch, "npu") and torch.npu.is_available():
+        torch.npu.reset_peak_memory_stats()
 
     for _ in gen.chat(generator_args):
         pass
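
When `need_probs` is False, the rewritten `decode_n_tokens` above defers the per-token `callback` and flushes it in groups of `accumulate_tokens` (or at the end of generation, or when EOS is hit mid-batch). Below is a minimal, self-contained sketch of that batching pattern in plain Python; the function name, the integer "tokens", and the callback signature are illustrative stand-ins, not torchchat's actual API.

```python
from typing import Callable, Iterable, Iterator


def decode_with_batched_callback(
    tokens: Iterable[int],
    callback: Callable[[int, bool], None],
    accumulate_tokens: int = 8,
    eos_token_id: int = 2,
) -> Iterator[int]:
    """Yield tokens while invoking `callback` in batches of `accumulate_tokens`."""
    tokens = list(tokens)
    pending: list[int] = []
    for i, tok in enumerate(tokens):
        pending.append(tok)
        done = i == len(tokens) - 1
        # Flush the buffered callbacks when the buffer is full or generation ends.
        if done or len(pending) == accumulate_tokens:
            for j, buffered in enumerate(pending):
                callback(buffered, done and j == len(pending) - 1)
                yield buffered
                if buffered == eos_token_id:
                    return  # stop once EOS is seen, mirroring encountered_eos above
            pending.clear()


if __name__ == "__main__":
    out = list(
        decode_with_batched_callback(
            [5, 7, 9, 11, 2, 13],
            callback=lambda t, done: print(f"token={t} done={done}"),
            accumulate_tokens=3,
        )
    )
    print("emitted:", out)  # EOS (2) ends emission before 13
```

With `accumulate_tokens=3` and an EOS id of 2, the example flushes callbacks in groups of three and stops emitting as soon as the EOS token is reached, which is the same effect the `callback_pos` / `encountered_eos` logic in the diff achieves for the non-probability decode path.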