@@ -230,6 +230,7 @@ class GeneratorArgs:
230 230      max_autotune: bool = False
231 231      # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
232 232      is_torchtune_model: bool = False
    233 +    accumulate_tokens: int = 8
233 234
234 235      def __post_init__(self):
235 236          if self.compile_prefill and self.sequential_prefill:
@@ -294,6 +295,7 @@ def from_args(cls, args):
294 295          sequential_prefill=sequential_prefill,
295 296          max_autotune=args.max_autotune,
296 297          is_torchtune_model=args.model and args.model.endswith("tune"),
    298 +        accumulate_tokens=getattr(args, "accumulate_tokens", 8),
297 299      )
298300
299301
@@ -530,6 +532,7 @@ def decode_n_tokens(
530 532      need_probs: bool,
531 533      batch=Optional[Dict[str, Any]],  # Inputs for multimodal models
532 534      callback=lambda _: _,
    535 +    accumulate_tokens: int = 8,
533 536      eos_token_id: int = 2,
534 537      eot_id: Optional[int] = None,
535 538      attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
@@ -582,10 +585,9 @@ def decode_n_tokens(
582 585              yield cur_token.clone(), next_prob.clone()
583 586              break
584 587          else:
585     -            CALLBACK_BATCH = 8
586     -            callback_pos = _i % CALLBACK_BATCH + 1
587     -            if done_generating or callback_pos == CALLBACK_BATCH:
588     -                callback_num = min(CALLBACK_BATCH, callback_pos)
    588 +            callback_pos = _i % accumulate_tokens + 1
    589 +            if done_generating or callback_pos == accumulate_tokens:
    590 +                callback_num = min(accumulate_tokens, callback_pos)
589 591              for i in range(callback_num, 0, -1):
590 592                  callback(new_tokens[-i], done_generating=done_generating)
591 593
@@ -706,6 +708,7 @@ def generate(
706 708      speculate_k: Optional[int] = 8,
707 709      sequential_prefill=True,
708 710      callback=lambda x: x,
    711 +    accumulate_tokens: int,
709 712      max_seq_length: int,
710 713      attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
711 714      seed: Optional[int] = None,
@@ -816,6 +819,7 @@ def generate(
816 819          max_new_tokens - 1,
817 820          batch=batch,
818 821          callback=callback,
    822 +        accumulate_tokens=accumulate_tokens,
819 823          need_probs=False,
820 824          eos_token_id=self.tokenizer.eos_id() if self.tokenizer else 2,
821 825          eot_id=(
@@ -1204,6 +1208,7 @@ def callback(x, *, done_generating=False):
1204 1208          chat_mode=generator_args.chat_mode,
1205 1209          batch=batch,
1206 1210          callback=callback,
     1211 +        accumulate_tokens=generator_args.accumulate_tokens,
1207 1212          temperature=generator_args.temperature,
1208 1213          top_k=generator_args.top_k,
1209 1214          sequential_prefill=generator_args.sequential_prefill,
0 commit comments