
Commit facb0b7

command line argument
1 parent: 0f66ad0

File tree: 2 files changed (+15, -4 lines)


torchchat/cli/cli.py

Lines changed: 6 additions & 0 deletions
@@ -359,6 +359,12 @@ def _add_generation_args(parser, verb: str) -> None:
         default=1,
         help="Number of samples",
     )
+    generator_parser.add_argument(
+        "--accumulate-tokens",
+        type=int,
+        default=8,
+        help="Number of generated tokens to accumulate before calling the callback on each one of them.",
+    )
 
     generator_parser.add_argument(
         "--image-prompts",

torchchat/generate.py

Lines changed: 9 additions & 4 deletions
@@ -230,6 +230,7 @@ class GeneratorArgs:
     max_autotune: bool = False
     # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
     is_torchtune_model: bool = False
+    accumulate_tokens: int = 8
 
     def __post_init__(self):
         if self.compile_prefill and self.sequential_prefill:
@@ -294,6 +295,7 @@ def from_args(cls, args):
             sequential_prefill=sequential_prefill,
             max_autotune=args.max_autotune,
             is_torchtune_model=args.model and args.model.endswith("tune"),
+            accumulate_tokens=getattr(args, "accumulate_tokens", 8),
         )
 
 
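The getattr lookup above is what keeps from_args backward compatible: an args namespace built without the new CLI flag still yields the old hard-coded batch size of 8. A minimal sketch of that fallback, with SimpleNamespace standing in for the parsed arguments (the model names here are made up):

from types import SimpleNamespace

old_args = SimpleNamespace(model="llama3-tune")  # flag never defined
new_args = SimpleNamespace(model="llama3-tune", accumulate_tokens=4)

# Mirrors GeneratorArgs.from_args: a missing attribute falls back to 8.
assert getattr(old_args, "accumulate_tokens", 8) == 8
assert getattr(new_args, "accumulate_tokens", 8) == 4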

@@ -530,6 +532,7 @@ def decode_n_tokens(
         need_probs: bool,
         batch=Optional[Dict[str, Any]],  # Inputs for multimodal models
         callback=lambda _: _,
+        accumulate_tokens: int = 8,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
@@ -582,10 +585,9 @@ def decode_n_tokens(
                 yield cur_token.clone(), next_prob.clone()
                 break
             else:
-                CALLBACK_BATCH = 8
-                callback_pos = _i % CALLBACK_BATCH + 1
-                if done_generating or callback_pos == CALLBACK_BATCH:
-                    callback_num = min(CALLBACK_BATCH, callback_pos)
+                callback_pos = _i % accumulate_tokens + 1
+                if done_generating or callback_pos == accumulate_tokens:
+                    callback_num = min(accumulate_tokens, callback_pos)
                     for i in range(callback_num, 0, -1):
                         callback(new_tokens[-i], done_generating=done_generating)
 
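This hunk replaces the hard-coded CALLBACK_BATCH = 8 with the new accumulate_tokens parameter: generated tokens are buffered and the callback is invoked once per token in bursts, every accumulate_tokens steps, with a final partial flush when generation ends. A self-contained sketch of that flush pattern (the driver loop and token values are invented for illustration; only the modular-arithmetic core comes from the diff):

def drain_tokens(token_ids, accumulate_tokens, callback):
    # Buffer tokens as decode_n_tokens does, flushing the buffered tail
    # to the callback every `accumulate_tokens` steps and once at the end.
    new_tokens = []
    for _i, tok in enumerate(token_ids):
        new_tokens.append(tok)
        done_generating = _i == len(token_ids) - 1
        callback_pos = _i % accumulate_tokens + 1
        if done_generating or callback_pos == accumulate_tokens:
            callback_num = min(accumulate_tokens, callback_pos)
            for i in range(callback_num, 0, -1):  # oldest buffered token first
                callback(new_tokens[-i], done_generating=done_generating)

flushed = []
drain_tokens(list(range(10)), 4, lambda t, done_generating: flushed.append(t))
assert flushed == list(range(10))  # every token reaches the callback exactly once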

@@ -706,6 +708,7 @@ def generate(
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
         callback=lambda x: x,
+        accumulate_tokens: int,
         max_seq_length: int,
         attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         seed: Optional[int] = None,
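One detail worth noting in the signature above: accumulate_tokens: int takes no default even though it follows defaulted parameters like callback. That is only valid for keyword-only parameters, so presumably (an inference from this excerpt, not shown in the hunk) a bare * appears earlier in generate()'s parameter list, as it already must for max_seq_length: int. A tiny sketch of the rule:

# After a bare "*", a keyword-only parameter may omit its default even
# when an earlier keyword-only parameter has one.
def generate_stub(*, callback=lambda x: x, accumulate_tokens: int, max_seq_length: int):
    return accumulate_tokens, max_seq_length

assert generate_stub(accumulate_tokens=8, max_seq_length=2048) == (8, 2048)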
@@ -816,6 +819,7 @@ def generate(
             max_new_tokens - 1,
             batch=batch,
             callback=callback,
+            accumulate_tokens=accumulate_tokens,
             need_probs=False,
             eos_token_id=self.tokenizer.eos_id() if self.tokenizer else 2,
             eot_id=(
@@ -1204,6 +1208,7 @@ def callback(x, *, done_generating=False):
             chat_mode=generator_args.chat_mode,
             batch=batch,
             callback=callback,
+            accumulate_tokens=generator_args.accumulate_tokens,
             temperature=generator_args.temperature,
             top_k=generator_args.top_k,
             sequential_prefill=generator_args.sequential_prefill,
