@@ -273,13 +273,11 @@ def main(args):
     pp_rank = pp_mesh.get_local_rank()
     tp_group = tp_mesh.get_group()
     pp_group = pp_mesh.get_group()
-    pp_group_size = pp_group.size()
-    tp_group_size = tp_group.size()
-    logger.info(f"{pp_group_size=}, {tp_group_size=}")
+    logger.info(f"{pp_degree=}, {tp_degree=}")

     # Convenience variables
     first_pp_rank = 0
-    last_pp_rank = pp_group_size - 1
+    last_pp_rank = pp_degree - 1

     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
@@ -292,23 +290,28 @@ def main(args):
     # TODO: we should create model instead of Transformer
     model = Transformer(config)

-    # Distribute model on TP mesh
-    model.distribute(tp_mesh)
     if rank == 0:
         logger.info(f"Model: {model}")

-    mbs = 1  # number of micro-batches
-    mb_size = 4  # micro-batch size
-    batch_size = mbs * mb_size  # total batch size
+    # Distribute model on TP mesh
+    model.distribute(tp_mesh)

+    # Batch size. Since we push batches dynamically through the pipeline rather
+    # than chunking them, this is effectively the micro-batch size in the
+    # pipeline sense, so it is interchangeable with the micro-batch size below.
+    batch_size = 4
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension

     # Setup KV caches (after model distribution)
-    # TODO: the setting below only works for 1 micro-batch case. To support
-    # multiple micro-batches, we need the KV cache in the model to be aware of
-    # the number of micro-batches and the current micro-batch index.
-    model.setup_caches(mb_size, seqlen_prefill)
+    # The number of cache lanes is the same as the maximum number of
+    # micro-batches that can be "in flight" in parallel -- imagine each
+    # micro-batch occupying one "pipeline lane"; they need distinct KV cache
+    # spaces. When decoding is done for certain micro-batches, their KV cache
+    # lanes can be reused.
+    # TODO: bump up the lane count
+    cache_lanes = 1
+    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=cache_lanes)

     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
@@ -317,7 +320,7 @@ def main(args):
     model.to(device)

     logger.info(
-        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )

     # info on stage size and params
@@ -335,12 +338,12 @@ def main(args):

     # Helper function to get example inputs and outputs for the stages.
     def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
+        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
         activation = torch.rand(
-            mb_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, dim, device=device, dtype=model_dtype
         )
         logits = torch.rand(
-            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+            batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
         )
         example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
         example_outputs = (logits if pp_rank == last_pp_rank else activation,)
@@ -358,8 +361,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         output_args=example_outputs,
         group=pp_group,
     )
+    # The number of micro-batches for the schedule is 1, because we push only
+    # 1 micro-batch into the pipeline per step() call. But we can continuously
+    # push new micro-batches into the pipeline as they arrive, achieving the
+    # same pipelining effect.
+    mbs = 1
     # create schedule
-    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
+    prefiller = ScheduleGPipe(prefill_stage, mbs)

     prompt = [
         "What is a computer?",
@@ -401,14 +409,15 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     num_tokens = 40

     # Prefill phase
-    # Run context input through pipeline, in 1 step
+    # Run context input through pipeline
+    # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
     with torch.no_grad():
         if pp_rank == first_pp_rank:
-            output = prefill_schedule.step(padded_sequence)
+            output = prefiller.step(padded_sequence)
         elif pp_rank == last_pp_rank:
-            output = prefill_schedule.step()
+            output = prefiller.step()
         else:  # middle pp ranks
-            prefill_schedule.step()
+            prefiller.step()

     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:
@@ -445,7 +454,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decode_schedule = ScheduleGPipe(decode_stage, mbs)
+    decoder = ScheduleGPipe(decode_stage, mbs)

     # Decoding
     with torch.no_grad():
@@ -467,11 +476,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decode_schedule.step(new_token)
+                output = decoder.step(new_token)
             elif pp_rank == last_pp_rank:
-                output = decode_schedule.step()
+                output = decoder.step()
             else:  # middle pp ranks
-                decode_schedule.step()
+                decoder.step()

             # Decode the output
             if pp_rank == last_pp_rank:
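To make the "cache lane" comment in the setup_caches hunk concrete, here is a minimal sketch (not part of the patch) of a KV cache with a leading lane dimension, so that micro-batches that are in flight at the same time write to disjoint cache storage. The class and method names (LanedKVCache, update) and the tensor layout are hypothetical illustrations, not the model's actual cache implementation.

```python
import torch


class LanedKVCache:
    """Illustrative only: a KV cache with a leading "lane" dimension so that
    concurrently in-flight micro-batches use disjoint cache slots."""

    def __init__(self, cache_lanes, batch_size, n_heads, max_seq_len, head_dim,
                 dtype=torch.float16):
        shape = (cache_lanes, batch_size, n_heads, max_seq_len, head_dim)
        self.k_cache = torch.zeros(shape, dtype=dtype)
        self.v_cache = torch.zeros(shape, dtype=dtype)

    def update(self, lane, input_pos, k_val, v_val):
        # k_val / v_val: [batch_size, n_heads, seq_len, head_dim]
        # input_pos: positions being written, shape [seq_len]
        k_lane, v_lane = self.k_cache[lane], self.v_cache[lane]  # views
        k_lane[:, :, input_pos] = k_val
        v_lane[:, :, input_pos] = v_val
        # Return the full per-lane cache; attention reads all cached positions.
        return k_lane, v_lane


# Tiny usage example: two lanes updated independently, e.g. for two
# micro-batches occupying different pipeline lanes.
cache = LanedKVCache(cache_lanes=2, batch_size=4, n_heads=8,
                     max_seq_len=16, head_dim=32)
k = torch.randn(4, 8, 3, 32, dtype=torch.float16)
v = torch.randn(4, 8, 3, 32, dtype=torch.float16)
cache.update(lane=0, input_pos=torch.arange(3), k_val=k, v_val=v)
cache.update(lane=1, input_pos=torch.arange(3), k_val=k, v_val=v)
```

With cache_lanes = 1 (as in this commit), this degenerates to an ordinary KV cache; bumping the lane count is what would let multiple micro-batches be prefillled and decoded concurrently, and a lane can be recycled once its micro-batch finishes decoding.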