@@ -230,6 +230,7 @@ class GeneratorArgs:
     max_autotune: bool = False
     # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
     is_torchtune_model: bool = False
+    accumulate_tokens: int = 8
 
     def __post_init__(self):
         if self.compile_prefill and self.sequential_prefill:
@@ -294,6 +295,7 @@ def from_args(cls, args):
             sequential_prefill=sequential_prefill,
             max_autotune=args.max_autotune,
             is_torchtune_model=args.model and args.model.endswith("tune"),
+            accumulate_tokens=getattr(args, "accumulate_tokens", 8),
         )
 
 
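
The new `accumulate_tokens` field defaults to 8 in both places: on the dataclass itself and via `getattr` in `from_args`, so an argparse namespace that predates the option still resolves to 8. A minimal standalone sketch of that pattern (the surrounding names are illustrative, not the full torchchat dataclass):

```python
import argparse
from dataclasses import dataclass


@dataclass
class GeneratorArgs:
    # Number of decoded tokens to buffer before flushing them to the callback.
    accumulate_tokens: int = 8


# A namespace with no `accumulate_tokens` attribute, e.g. from an older CLI.
args = argparse.Namespace()
ga = GeneratorArgs(accumulate_tokens=getattr(args, "accumulate_tokens", 8))
assert ga.accumulate_tokens == 8
```
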
@@ -530,11 +532,13 @@ def decode_n_tokens(
         need_probs: bool,
         batch=Optional[Dict[str, Any]],  # Inputs for multimodal models
         callback=lambda _: _,
+        accumulate_tokens: int = 8,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         **sampling_kwargs,
     ):
+        new_tokens = []
         encountered_eos = False
         for _i in range(
             num_new_tokens - 1
@@ -552,29 +556,52 @@ def decode_n_tokens(
                 **sampling_kwargs,
             )
             input_pos += 1
-            callback(next_token.clone(), done_generating=_i == num_new_tokens - 2)
+            new_tokens.append(next_token.clone())
+
+            done_generating = _i == num_new_tokens - 2
+            if need_probs:
+                callback(new_tokens[-1], done_generating=done_generating)
             if not need_probs or next_prob is None:
                 yield out_token, None
             else:
                 yield out_token, next_prob.clone()
             cur_token = next_token
 
-            # encountered eos
-            if next_token.item() == eos_token_id or (
-                eot_id is not None and next_token.item() == eot_id
-            ):
-                encountered_eos = True
-                final_token, next_prob = self.decode_one_token(
-                    model,
-                    cur_token,
-                    input_pos,
-                    need_probs,
-                    batch=batch,
-                    **sampling_kwargs,
-                )
-                input_pos += 1
-                yield cur_token.clone(), next_prob.clone()
-                break
+            if need_probs:
+                # encountered eos
+                if next_token.item() == eos_token_id or (
+                    eot_id is not None and next_token.item() == eot_id
+                ):
+                    encountered_eos = True
+                    final_token, next_prob = self.decode_one_token(
+                        model,
+                        cur_token,
+                        input_pos,
+                        need_probs,
+                        batch=batch,
+                        **sampling_kwargs,
+                    )
+                    input_pos += 1
+                    yield cur_token.clone(), next_prob.clone()
+                    break
+            else:
+                callback_pos = _i % accumulate_tokens + 1
+                if done_generating or callback_pos == accumulate_tokens:
+                    callback_num = min(accumulate_tokens, callback_pos)
+                    for i in range(callback_num, 0, -1):
+                        callback(new_tokens[-i], done_generating=done_generating)
+
+                        token_item = new_tokens[-i].item()
+                        # encountered eos
+                        if token_item == eos_token_id or (
+                            eot_id is not None and token_item == eot_id
+                        ):
+                            encountered_eos = True
+                            input_pos += 1
+                            yield new_tokens[-i].clone(), None
+                            break
+                    if encountered_eos:
+                        break
 
         if not encountered_eos:
             eos_token = torch.tensor(
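
This hunk is the core of the change: when `need_probs` is false, `decode_n_tokens` no longer invokes `callback` after every decode step. Tokens are appended to `new_tokens`, and the buffer is drained once every `accumulate_tokens` steps (or on the final step), with EOS/EOT detection moved into the drain loop. Because `.item()` now runs only at flush time on this path, per-step host-device synchronization is also batched. A standalone sketch of the flush arithmetic, with integers standing in for token tensors (not the torchchat implementation itself):

```python
def flush_positions(num_new_tokens: int, accumulate_tokens: int = 8):
    """Yield (token, done_generating) in the order the drain loop above would."""
    new_tokens = []
    for _i in range(num_new_tokens - 1):
        new_tokens.append(_i)  # stand-in for next_token.clone()
        done_generating = _i == num_new_tokens - 2
        callback_pos = _i % accumulate_tokens + 1
        if done_generating or callback_pos == accumulate_tokens:
            # callback_pos is always <= accumulate_tokens, so min() is a no-op guard.
            callback_num = min(accumulate_tokens, callback_pos)
            for i in range(callback_num, 0, -1):
                yield new_tokens[-i], done_generating


# With 10 new tokens and a window of 4: flushes fire after steps 3, 7, and 8,
# and every token is emitted exactly once, in order.
assert [t for t, _ in flush_positions(10, accumulate_tokens=4)] == list(range(9))
```
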
@@ -681,6 +708,7 @@ def generate(
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
         callback=lambda x: x,
+        accumulate_tokens: int,
         max_seq_length: int,
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         seed: Optional[int] = None,
@@ -791,6 +819,7 @@ def generate(
             max_new_tokens - 1,
             batch=batch,
             callback=callback,
+            accumulate_tokens=accumulate_tokens,
             need_probs=False,
             eos_token_id=self.tokenizer.eos_id() if self.tokenizer else 2,
             eot_id=(
@@ -1179,6 +1208,7 @@ def callback(x, *, done_generating=False):
             chat_mode=generator_args.chat_mode,
             batch=batch,
             callback=callback,
+            accumulate_tokens=generator_args.accumulate_tokens,
             temperature=generator_args.temperature,
             top_k=generator_args.top_k,
             sequential_prefill=generator_args.sequential_prefill,
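
On the caller side, nothing about the callback contract changes except timing: with `accumulate_tokens=8`, invocations arrive in bursts of up to eight rather than one per decode step, so the decode loop can run ahead between bursts. A hypothetical callback matching the `callback(x, *, done_generating=False)` signature used above:

```python
import torch


def make_callback():
    seen = []

    def callback(token: torch.Tensor, *, done_generating: bool = False):
        seen.append(token.item())  # host<->device sync; amortized across a burst
        if done_generating:
            print(f"generated {len(seen)} tokens")

    return callback, seen


cb, seen = make_callback()
for i in range(3):
    cb(torch.tensor(i), done_generating=(i == 2))
assert seen == [0, 1, 2]
```
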