This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit e33e681

[WIP][Distributed] Add lanes to KV cache

1 parent 8d01d9b

File tree: 2 files changed (+47 -36 lines)

dist_run.py

Lines changed: 35 additions & 26 deletions
@@ -273,13 +273,11 @@ def main(args):
     pp_rank = pp_mesh.get_local_rank()
     tp_group = tp_mesh.get_group()
     pp_group = pp_mesh.get_group()
-    pp_group_size = pp_group.size()
-    tp_group_size = tp_group.size()
-    logger.info(f"{pp_group_size=}, {tp_group_size=}")
+    logger.info(f"{pp_degree=}, {tp_degree=}")

     # Convenience variables
     first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
+    last_pp_rank = pp_degree - 1

     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
@@ -292,23 +290,28 @@ def main(args):
     # TODO: we should create model instead of Transformer
     model = Transformer(config)

-    # Distribute model on TP mesh
-    model.distribute(tp_mesh)
     if rank == 0:
         logger.info(f"Model: {model}")

-    mbs = 1  # number of micro-batches
-    mb_size = 4  # micro-batch size
-    batch_size = mbs * mb_size  # total batch size
+    # Distribute model on TP mesh
+    model.distribute(tp_mesh)

+    # Batch size. Since we push batches dynamically through the pipeline rather
+    # than chunking them, this is effectively the micro-batch size in the
+    # pipeline sense. Thus it is interchangeable with the micro-batch size below.
+    batch_size = 4
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension

     # Setup KV caches (after model distribution)
-    # TODO: the setting below only works for 1 micro-batch case. To support
-    # multiple micro-batches, we need the KV cache in the model to be aware of
-    # the number of micro-batches and the current micro-batch index.
-    model.setup_caches(mb_size, seqlen_prefill)
+    # The number of cache lanes is the same as the maximum number of
+    # micro-batches that can be "in flight" in parallel -- imagine each
+    # micro-batch takes one "pipeline lane"; they need distinct KV cache spaces.
+    # When decoding is done for certain micro-batches, we can reuse the KV cache
+    # lanes.
+    # TODO: bump up the lane count
+    cache_lanes = 1
+    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=cache_lanes)

     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
@@ -317,7 +320,7 @@ def main(args):
     model.to(device)

     logger.info(
-        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )

     # info on stage size and params
@@ -335,12 +338,12 @@

     # Helper function to get example inputs and outputs for the stages.
     def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
         activation = torch.rand(
-            mb_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, dim, device=device, dtype=model_dtype
         )
         logits = torch.rand(
-            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+            batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
         )
         example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
         example_outputs = (logits if pp_rank == last_pp_rank else activation,)
@@ -358,8 +361,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         output_args=example_outputs,
         group=pp_group,
     )
+    # Number of micro-batches for the schedule is 1, because each step() call
+    # only pushes 1 micro-batch into the pipeline. But we can continuously push
+    # new micro-batches into the pipeline as they arrive, achieving the same
+    # pipelining effect.
+    mbs = 1
     # create schedule
-    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
+    prefiller = ScheduleGPipe(prefill_stage, mbs)

     prompt = [
         "What is a computer?",
@@ -401,14 +409,15 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     num_tokens = 40

     # Prefill phase
-    # Run context input through pipeline, in 1 step
+    # Run context input through pipeline
+    # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
     with torch.no_grad():
         if pp_rank == first_pp_rank:
-            output = prefill_schedule.step(padded_sequence)
+            output = prefiller.step(padded_sequence)
         elif pp_rank == last_pp_rank:
-            output = prefill_schedule.step()
+            output = prefiller.step()
         else:  # middle pp ranks
-            prefill_schedule.step()
+            prefiller.step()

     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:
@@ -445,7 +454,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+    decoder = ScheduleGPipe(decode_stage, mbs)

     # Decoding
     with torch.no_grad():
@@ -467,11 +476,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

         # Run data through pipeline
         if pp_rank == first_pp_rank:
-            output = decode_schedule.step(new_token)
+            output = decoder.step(new_token)
         elif pp_rank == last_pp_rank:
-            output = decode_schedule.step()
+            output = decoder.step()
         else:  # middle pp ranks
-            decode_schedule.step()
+            decoder.step()

         # Decode the output
         if pp_rank == last_pp_rank:

torchchat/model.py

Lines changed: 12 additions & 10 deletions
@@ -606,7 +606,7 @@ def __init__(self, config: TransformerArgs) -> None:
         self.max_batch_size = -1
         self.max_seq_length = -1

-    def setup_caches(self, max_batch_size, max_seq_length):
+    def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
         if (
             self.max_seq_length >= max_seq_length
             and self.max_batch_size >= max_batch_size
@@ -620,7 +620,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
             # parallelism may have been applied there and the `n_local_heads`
             # value being adjusted.
             b.attention.setup_cache(
-                max_batch_size, max_seq_length,
+                max_batch_size, max_seq_length, cache_lanes=cache_lanes
             )

         freqs_cis = precompute_freqs_cis(
@@ -658,7 +658,7 @@ def distribute(self, device_mesh: DeviceMesh):
     def setup_input_pos(self, input_pos: Tensor) -> None:
         self._input_pos = input_pos

-    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 0) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
         # TODO: find a better way to pass input_pos to non-0 pipeline stages
         input_pos = input_pos if input_pos is not None else self._input_pos
@@ -668,7 +668,7 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
             x = self.tok_embeddings(x)

         for _, layer in self.layers.items():
-            x = layer(x, input_pos, freqs_cis, mask)
+            x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane)

         if self.norm:
             x = self.norm(x)
@@ -691,7 +691,7 @@ def distribute(self, device_mesh: DeviceMesh):
         self.feed_forward.distribute(device_mesh)

     def forward(
-        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
+        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, cache_lane: int = 0
     ) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
         out = h + self.feed_forward(self.ffn_norm(h))
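The three hunks above only thread the new cache_lane keyword through the forward signatures; the attention call inside TransformerBlock.forward still does not pass it on, consistent with the TODO in dist_run.py. The toy modules below (stand-ins, not the torchchat classes) sketch what the fully wired call chain would look like, with the block-to-attention hop added here as an assumption.

import torch
from torch import nn

class ToyAttention(nn.Module):
    def forward(self, x, cache_lane: int = 0):
        print(f"attention uses KV-cache lane {cache_lane}")
        return x

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = ToyAttention()

    def forward(self, x, cache_lane: int = 0):
        # Hypothetical hop: forward the lane index to attention.
        return self.attention(x, cache_lane=cache_lane)

class ToyTransformer(nn.Module):
    def __init__(self, n_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(ToyBlock() for _ in range(n_layers))

    def forward(self, x, cache_lane: int = 0):
        for layer in self.layers:
            x = layer(x, cache_lane=cache_lane)
        return x

ToyTransformer()(torch.zeros(1, 4), cache_lane=1)  # every layer hits lane 1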
@@ -723,15 +723,16 @@ def __init__(self, config: TransformerArgs):
         self.dim = config.dim
         self._register_load_state_dict_pre_hook(self.load_hook)

-    def setup_cache(self, max_batch_size, max_seq_length):
+    def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
         n_local_heads = self.n_local_heads
         # If TP is enabled, the heads would be divided and assigned to different ranks
         if hasattr(self, "tp_degree"):
             n_local_heads = self.n_local_heads // self.tp_degree

-        self.kv_cache = KVCache(
-            max_batch_size, max_seq_length, n_local_heads, self.head_dim
-        )
+        self.kv_cache = nn.ModuleList([
+            KVCache(max_batch_size, max_seq_length, n_local_heads, self.head_dim)
+            for _ in range(cache_lanes)
+        ])

     def load_hook(self, state_dict, prefix, *args):
         # if prefix + "wq.weight" in state_dict:
@@ -784,6 +785,7 @@ def forward(
         freqs_cis: Tensor,
         mask: Tensor,
         input_pos: Optional[Tensor] = None,
+        cache_lane: int = 0,
     ) -> Tensor:
         bsz, seqlen, _ = x.shape

@@ -809,7 +811,7 @@ def forward(
         q, k, v = (x.transpose(1, 2) for x in (q, k, v))

         if self.kv_cache is not None:
-            k, v = self.kv_cache.update(input_pos, k, v)
+            k, v = self.kv_cache[cache_lane].update(input_pos, k, v)

         k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
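Putting the Attention changes together: setup_cache now allocates one KVCache per lane and forward indexes into whichever lane it is told to use, so concurrent micro-batches keep separate K/V state (at the cost of cache memory scaling linearly with cache_lanes). The self-contained toy below demonstrates that isolation; ToyKVCache is a simplified stand-in with the same constructor and update shape as shown in the diff, not torchchat's actual KVCache.

import torch
from torch import nn

# Simplified stand-in for torchchat's KVCache: same constructor/update shape
# as in the diff above, minus dtype handling and other details.
class ToyKVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, input_pos, k_val, v_val):
        # Write the new entries at input_pos and return the full caches.
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache

# One cache per lane, mirroring the new setup_cache above.
cache_lanes, bsz, heads, seqlen, hdim = 2, 1, 2, 16, 4
kv_cache = nn.ModuleList(
    ToyKVCache(bsz, seqlen, heads, hdim) for _ in range(cache_lanes)
)

# Two in-flight micro-batches at different decode positions write to their own
# lanes without clobbering each other.
k = v = torch.ones(bsz, heads, 1, hdim)
kv_cache[0].update(torch.tensor([3]), k, v)          # micro-batch A, lane 0, pos 3
kv_cache[1].update(torch.tensor([7]), 2 * k, 2 * v)  # micro-batch B, lane 1, pos 7
assert kv_cache[0].k_cache[0, 0, 7].sum() == 0       # lane 0 untouched at pos 7
assert kv_cache[1].k_cache[0, 0, 3].sum() == 0       # lane 1 untouched at pos 3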
