This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 05129a2

Revert "[Distributed] Add lanes to KV cache (#1174)"
This reverts commit 2cf4016.
1 parent 2cf4016 commit 05129a2

3 files changed, 51 insertions(+), 61 deletions(-)


dist_run.py

Lines changed: 32 additions & 47 deletions
@@ -273,11 +273,13 @@ def main(args):
     pp_rank = pp_mesh.get_local_rank()
     tp_group = tp_mesh.get_group()
     pp_group = pp_mesh.get_group()
-    logger.info(f"{pp_degree=}, {tp_degree=}")
+    pp_group_size = pp_group.size()
+    tp_group_size = tp_group.size()
+    logger.info(f"{pp_group_size=}, {tp_group_size=}")

     # Convenience variables
     first_pp_rank = 0
-    last_pp_rank = pp_degree - 1
+    last_pp_rank = pp_group_size - 1

     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
@@ -295,22 +297,18 @@ def main(args):
     if rank == 0:
         logger.info(f"Model: {model}")

-    # Batch size. Since we push batches dynamically through the pipeline rather
-    # than chunking them, this is effectively micro-batch size in pipeline
-    # sense. Thus it is interchangeable with micro-batch size below.
-    batch_size = 4
+    mbs = 1  # number of micro-batches
+    mb_size = 4  # micro-batch size
+    batch_size = mbs * mb_size  # total batch size
+
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension

     # Setup KV caches (after model distribution)
-    # The number of cache lanes is the same as the maximum number of
-    # micro-batches that can be "in flight" in parallel -- imagine each
-    # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
-    # When decoding is done for certain micro-batches, we can reuse the KV cache
-    # lanes.
-    # TODO: bump up the lane count
-    pipeline_lanes = 1
-    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)
+    # TODO: the setting below only works for 1 micro-batch case. To support
+    # multiple micro-batches, we need the KV cache in the model to be aware of
+    # the number of micro-batches and the current micro-batch index.
+    model.setup_caches(mb_size, seqlen_prefill)

     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
@@ -319,7 +317,7 @@ def main(args):
     model.to(device)

     logger.info(
-        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
     )

     # info on stage size and params
@@ -332,16 +330,17 @@ def main(args):

     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
     input_pos = torch.arange(seqlen_prefill, device=device)
+    model.setup_input_pos(input_pos)
     model.eval()

     # Helper function to get example inputs and outputs for the stages.
     def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
+        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
         activation = torch.rand(
-            batch_size, seqlen, dim, device=device, dtype=model_dtype
+            mb_size, seqlen, dim, device=device, dtype=model_dtype
         )
         logits = torch.rand(
-            batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
         )
         example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
         example_outputs = (logits if pp_rank == last_pp_rank else activation,)
@@ -359,13 +358,8 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         output_args=example_outputs,
         group=pp_group,
     )
-
-    # Create schedule
-    # Number of micro-batches for the schedule is 1, because each step() call we
-    # only push 1 micro-batch into the pipeline. But we can continuously push
-    # new micro-batches into the pipeline as they arrive, achieving same
-    # pipelining effect.
-    prefiller = ScheduleGPipe(prefill_stage, 1)
+    # create schedule
+    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)

     prompt = [
         "What is a computer?",
@@ -394,6 +388,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     s = set(prompt_lengths)
     assert len(s) == 1, f"prompt_lengths should be the same, got {s}"

+    # with CUDATrackTime() as timer:
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
@@ -406,21 +401,14 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     num_tokens = 40

     # Prefill phase
-    # Run context input through pipeline
-    # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
-    lane = 0
-    kwargs = {"input_pos": input_pos, "cache_lane": lane}
-    with torch.no_grad(), CUDATrackTime() as timer:
+    # Run context input through pipeline, in 1 step
+    with torch.no_grad():
         if pp_rank == first_pp_rank:
-            output = prefiller.step(padded_sequence, **kwargs)
+            output = prefill_schedule.step(padded_sequence)
         elif pp_rank == last_pp_rank:
-            output = prefiller.step(**kwargs)
+            output = prefill_schedule.step()
         else:  # middle pp ranks
-            prefiller.step(**kwargs)
-
-    logger.info(
-        f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
-    )
+            prefill_schedule.step()

     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:
@@ -442,6 +430,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # seqlen = 1 now
     seqlen_decode = 1
     input_pos = torch.tensor([prompt_lengths[0]], device=device)
+    model.setup_input_pos(input_pos)

     # Create decode stage
     logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
@@ -456,12 +445,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decorder = ScheduleGPipe(decode_stage, 1)
+    decode_schedule = ScheduleGPipe(decode_stage, mbs)

     # Decoding
-    with torch.no_grad(), CUDATrackTime() as timer:
+    with torch.no_grad():
         for step in range(num_tokens - 1):
-            kwargs = {"input_pos": input_pos, "cache_lane": lane}
             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
             if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
@@ -479,11 +467,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decorder.step(new_token, **kwargs)
+                output = decode_schedule.step(new_token)
             elif pp_rank == last_pp_rank:
-                output = decorder.step(**kwargs)
+                output = decode_schedule.step()
             else:  # middle pp ranks
-                decorder.step(**kwargs)
+                decode_schedule.step()

             # Decode the output
             if pp_rank == last_pp_rank:
@@ -503,10 +491,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
             )  # decode_results[i][0]

             input_pos += 1
-
-    logger.info(
-        f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
-    )
+            model.setup_input_pos(input_pos)

     # Display the decoding results
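Both the prefill and decode loops in this file follow the same rank-dependent dispatch: the first pipeline stage feeds real inputs into step(), the last stage collects the output, and middle stages simply advance the pipeline. A minimal single-process sketch of that pattern; run_step and DummySchedule are illustrative stand-ins, not torchchat code (the real objects are the ScheduleGPipe instances created above):

from typing import Any, Optional

def run_step(schedule: Any, pp_rank: int, first_pp_rank: int, last_pp_rank: int,
             inputs: Optional[Any] = None) -> Optional[Any]:
    # First stage: push the (micro-)batch into the pipeline.
    if pp_rank == first_pp_rank:
        return schedule.step(inputs)
    # Last stage: nothing to feed, but it produces the logits.
    if pp_rank == last_pp_rank:
        return schedule.step()
    # Middle stages: just participate in the send/recv choreography.
    schedule.step()
    return None

class DummySchedule:
    # Stand-in with the same call shape as a pipeline schedule, for local testing.
    def step(self, *args):
        return args[0] if args else "logits"

print(run_step(DummySchedule(), pp_rank=0, first_pp_rank=0, last_pp_rank=1, inputs="tokens"))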

torchchat/export.py

Lines changed: 2 additions & 2 deletions
@@ -152,9 +152,9 @@ def __init__(self, attention: Attention):
         self.wo = attention.wo

         max_batch_size, n_heads, max_seq_length, head_dim = (
-            attention.kv_cache[0].k_cache.shape
+            attention.kv_cache.k_cache.shape
         )
-        cache_dtype = attention.kv_cache[0].k_cache.dtype
+        cache_dtype = attention.kv_cache.k_cache.dtype
         self.kv_cache = CustomKVCache(
             max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
         )
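Only the cache geometry matters to the export path: with lanes, the shape and dtype were read from lane 0 of an nn.ModuleList; after the revert they come straight off the single cache module. A tiny sketch, with ToyCache as a hypothetical stand-in for the real KVCache:

import torch
import torch.nn as nn

class ToyCache(nn.Module):
    # Stand-in for the real KVCache; only the k_cache buffer matters here.
    def __init__(self):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(4, 8, 1024, 64, dtype=torch.bfloat16))

single = ToyCache()                  # layout after the revert
lanes = nn.ModuleList([ToyCache()])  # lane layout being reverted

max_batch_size, n_heads, max_seq_length, head_dim = single.k_cache.shape
cache_dtype = single.k_cache.dtype
assert lanes[0].k_cache.shape == single.k_cache.shape  # same geometry, different indexing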

torchchat/model.py

Lines changed: 17 additions & 12 deletions
@@ -606,7 +606,7 @@ def __init__(self, config: TransformerArgs) -> None:
         self.max_batch_size = -1
         self.max_seq_length = -1

-    def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
+    def setup_caches(self, max_batch_size, max_seq_length):
         if (
             self.max_seq_length >= max_seq_length
             and self.max_batch_size >= max_batch_size
@@ -620,7 +620,7 @@ def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
             # parallelism may have been applied there and the `n_local_heads``
             # value being adjusted.
             b.attention.setup_cache(
-                max_batch_size, max_seq_length, cache_lanes=cache_lanes
+                max_batch_size, max_seq_length,
             )

         freqs_cis = precompute_freqs_cis(
@@ -653,15 +653,22 @@ def distribute(self, device_mesh: DeviceMesh):
             ColwiseParallel(output_layouts=Replicate()),
         )

-    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 1) -> Tensor:
+    # This is a temporary solution to pass input_pos to non-0 pipeline stages
+    # TODO: make `step()` function of dist.pipelining accept args for non-0 stages
+    def setup_input_pos(self, input_pos: Tensor) -> None:
+        self._input_pos = input_pos
+
+    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
+        # TODO: find a better way to pass input_pos to non-0 pipeline stages
+        input_pos = input_pos if input_pos is not None else self._input_pos
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
         if self.tok_embeddings:
             x = self.tok_embeddings(x)

         for _, layer in self.layers.items():
-            x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)
+            x = layer(x, input_pos, freqs_cis, mask)

         if self.norm:
             x = self.norm(x)
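The setup_input_pos workaround exists because the pipeline schedule's step() only accepts arguments for the first stage, so later stages would otherwise never see input_pos. A minimal sketch of the stash-and-fallback pattern on a toy module (ToyStage is illustrative, not the real Transformer):

import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional

class ToyStage(nn.Module):
    # Mimics the diff: stash input_pos on the module so stages that never see
    # step() arguments can still index their caches / position tables.
    def __init__(self):
        super().__init__()
        self._input_pos: Optional[Tensor] = None

    def setup_input_pos(self, input_pos: Tensor) -> None:
        self._input_pos = input_pos

    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        input_pos = input_pos if input_pos is not None else self._input_pos
        # Toy "positional" use of input_pos: gather those positions from x.
        return x[:, input_pos]

stage = ToyStage()
stage.setup_input_pos(torch.tensor([5]))
out = stage(torch.arange(32).reshape(2, 16))  # forward called without input_pos
print(out)  # tensor([[ 5], [21]])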
@@ -684,7 +691,7 @@ def distribute(self, device_mesh: DeviceMesh):
         self.feed_forward.distribute(device_mesh)

     def forward(
-        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, cache_lane: int = 0
+        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
     ) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
         out = h + self.feed_forward(self.ffn_norm(h))
@@ -716,16 +723,15 @@ def __init__(self, config: TransformerArgs):
         self.dim = config.dim
         self._register_load_state_dict_pre_hook(self.load_hook)

-    def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
+    def setup_cache(self, max_batch_size, max_seq_length):
         n_local_heads = self.n_local_heads
         # If TP is enabled, the heads would be divided and assigned to different ranks
         if hasattr(self, "tp_degree"):
             n_local_heads = self.n_local_heads // self.tp_degree

-        self.kv_cache = nn.ModuleList([
-            KVCache(max_batch_size, max_seq_length, n_local_heads, self.head_dim)
-            for _ in range(cache_lanes)
-        ])
+        self.kv_cache = KVCache(
+            max_batch_size, max_seq_length, n_local_heads, self.head_dim
+        )

     def load_hook(self, state_dict, prefix, *args):
         # if prefix + "wq.weight" in state_dict:
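For context on what setup_cache switches between: the KV cache holds fixed-size key/value buffers and writes new entries at input_pos, and the lane variant being reverted wraps several such caches in an nn.ModuleList indexed by cache_lane. A condensed sketch under those assumptions (KVCacheSketch simplifies away dtype and device handling):

import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple

class KVCacheSketch(nn.Module):
    # Simplified decoder KV cache: preallocated buffers updated in place at
    # the positions given by input_pos.
    def __init__(self, max_batch_size: int, max_seq_length: int, n_heads: int, head_dim: int):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, input_pos: Tensor, k: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
        # k, v: (batch, n_heads, seqlen, head_dim); write them at input_pos.
        k_out = self.k_cache.index_copy_(2, input_pos, k)
        v_out = self.v_cache.index_copy_(2, input_pos, v)
        return k_out, v_out

# Lane layout removed by this revert: one cache per in-flight micro-batch,
# selected with kv_cache[cache_lane].update(...).
cache_lanes = 2
kv_lanes = nn.ModuleList(
    [KVCacheSketch(max_batch_size=4, max_seq_length=64, n_heads=8, head_dim=16)
     for _ in range(cache_lanes)]
)

# Single-cache layout after the revert: kv_cache.update(...).
kv_cache = KVCacheSketch(max_batch_size=4, max_seq_length=64, n_heads=8, head_dim=16)

k = torch.randn(4, 8, 1, 16)
v = torch.randn(4, 8, 1, 16)
pos = torch.tensor([0])
kv_lanes[1].update(pos, k, v)  # lane-addressed update (pre-revert style)
kv_cache.update(pos, k, v)     # direct update (post-revert style)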
@@ -778,7 +784,6 @@ def forward(
         freqs_cis: Tensor,
         mask: Tensor,
         input_pos: Optional[Tensor] = None,
-        cache_lane: int = 0,
     ) -> Tensor:
         bsz, seqlen, _ = x.shape

@@ -804,7 +809,7 @@ def forward(
         q, k, v = (x.transpose(1, 2) for x in (q, k, v))

         if self.kv_cache is not None:
-            k, v = self.kv_cache[cache_lane].update(input_pos, k, v)
+            k, v = self.kv_cache.update(input_pos, k, v)

         k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
