@@ -273,11 +273,13 @@ def main(args):
     pp_rank = pp_mesh.get_local_rank()
     tp_group = tp_mesh.get_group()
     pp_group = pp_mesh.get_group()
-    logger.info(f"{pp_degree=}, {tp_degree=}")
+    pp_group_size = pp_group.size()
+    tp_group_size = tp_group.size()
+    logger.info(f"{pp_group_size=}, {tp_group_size=}")
 
     # Convenience variables
     first_pp_rank = 0
-    last_pp_rank = pp_degree - 1
+    last_pp_rank = pp_group_size - 1
 
     # Assuming same number of GPUs per node
     device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
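For reference, `pp_group` and `tp_group` come from slicing a 2-D `DeviceMesh`, so `pp_group.size()` is simply the length of the pipeline dimension. Below is a minimal sketch of how such a mesh and its groups are typically built; it assumes a `torchrun` launch where `world_size == pp * tp` and a hypothetical `build_groups` helper, not this file's exact setup code:

```python
# Sketch only: illustrates where pp_group / tp_group and their sizes come from.
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

def build_groups(pp: int, tp: int):
    rank = int(os.environ["RANK"])
    torch.cuda.set_device(rank % torch.cuda.device_count())
    dist.init_process_group("nccl")
    # Named dims let each rank slice out its pipeline- and tensor-parallel sub-mesh.
    mesh = init_device_mesh("cuda", (pp, tp), mesh_dim_names=("pp", "tp"))
    pp_group = mesh["pp"].get_group()
    tp_group = mesh["tp"].get_group()
    # pp_group.size() == pp and tp_group.size() == tp, which is what the
    # logging change above prints instead of the old pp_degree / tp_degree.
    return pp_group, tp_group
```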
@@ -295,22 +297,18 @@ def main(args):
     if rank == 0:
         logger.info(f"Model: {model}")
 
-    # Batch size. Since we push batches dynamically through the pipeline rather
-    # than chunking them, this is effectively micro-batch size in pipeline
-    # sense. Thus it is interchangeable with micro-batch size below.
-    batch_size = 4
+    mbs = 1  # number of micro-batches
+    mb_size = 4  # micro-batch size
+    batch_size = mbs * mb_size  # total batch size
+
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension
 
     # Setup KV caches (after model distribution)
-    # The number of cache lanes is the same as the maximum number of
-    # micro-batches that can be "in flight" in parallel -- imagine each
-    # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
-    # When decoding is done for certain micro-batches, we can reuse the KV cache
-    # lanes.
-    # TODO: bump up the lane count
-    pipeline_lanes = 1
-    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)
+    # TODO: the setting below only works for 1 micro-batch case. To support
+    # multiple micro-batches, we need the KV cache in the model to be aware of
+    # the number of micro-batches and the current micro-batch index.
+    model.setup_caches(mb_size, seqlen_prefill)
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
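The TODO above is about cache capacity: a KV cache allocated for one micro-batch has a batch dimension of `mb_size`, so more than one in-flight micro-batch would need either a larger batch dimension or an explicit per-micro-batch slot to write into. A hedged sketch of that sizing follows; the layout is hypothetical and the dimensions illustrative, not this model's actual `setup_caches` internals:

```python
# Hypothetical KV-cache layout to illustrate the sizing in setup_caches(mb_size, ...).
import torch

mb_size, n_heads, max_seqlen, head_dim = 4, 32, 1024, 128  # illustrative values
k_cache = torch.zeros(mb_size, n_heads, max_seqlen, head_dim)
v_cache = torch.zeros(mb_size, n_heads, max_seqlen, head_dim)
# With >1 micro-batches in flight, each would need its own slice (e.g. an extra
# leading dimension) plus the index of the current micro-batch when reading and
# writing, which is what the TODO refers to.
```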
@@ -319,7 +317,7 @@ def main(args):
     model.to(device)
 
     logger.info(
-        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
     )
 
     # info on stage size and params
@@ -332,16 +330,17 @@ def main(args):
 
     # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen
     input_pos = torch.arange(seqlen_prefill, device=device)
+    model.setup_input_pos(input_pos)
     model.eval()
 
     # Helper function to get example inputs and outputs for the stages.
     def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
+        mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device)
         activation = torch.rand(
-            batch_size, seqlen, dim, device=device, dtype=model_dtype
+            mb_size, seqlen, dim, device=device, dtype=model_dtype
         )
         logits = torch.rand(
-            batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
+            mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
         )
         example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,)
         example_outputs = (logits if pp_rank == last_pp_rank else activation,)
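Two things happen in this hunk: positions are now handed to the model up front via `model.setup_input_pos(input_pos)` rather than being passed as a kwarg to every `step()` call (see the prefill and decode hunks below), and the example tensors are sized per micro-batch (`mb_size`) because that is what each pipeline stage actually receives. A minimal sketch of the setup-style position plumbing follows, using a hypothetical `ToyStage` module that stands in for a real model chunk; the real `setup_input_pos` may differ:

```python
# Sketch with a hypothetical module; shows why forward() needs no extra kwargs.
from typing import Optional
import torch
import torch.nn as nn

class ToyStage(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.input_pos: Optional[torch.Tensor] = None

    def setup_input_pos(self, input_pos: torch.Tensor) -> None:
        # Store positions as module state so the pipeline stage's call
        # signature stays tensors-only.
        self.input_pos = input_pos

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert self.input_pos is not None, "call setup_input_pos() first"
        return self.proj(x)  # a real stage would index its KV cache at input_pos

stage = ToyStage()
stage.setup_input_pos(torch.arange(8))
out = stage(torch.rand(4, 8, 16))
```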
@@ -359,13 +358,8 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         output_args=example_outputs,
         group=pp_group,
     )
-
-    # Create schedule
-    # Number of micro-batches for the schedule is 1, because each step() call we
-    # only push 1 micro-batch into the pipeline. But we can continuously push
-    # new micro-batches into the pipeline as they arrive, achieving same
-    # pipelining effect.
-    prefiller = ScheduleGPipe(prefill_stage, 1)
+    # create schedule
+    prefill_schedule = ScheduleGPipe(prefill_stage, mbs)
 
     prompt = [
         "What is a computer?",
@@ -394,6 +388,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     s = set(prompt_lengths)
     assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
 
+    # with CUDATrackTime() as timer:
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
     last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank)
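`first_pp_rank` and `last_pp_rank` are ranks *within* `pp_group`, while `dist.send`/`dist.recv` address peers by their rank in the global (default) process group, hence the translation via `dist.get_global_rank`. A minimal sketch of that mapping, with a hypothetical `send_to_group_rank` helper; it assumes an already-initialized process group (run under `torchrun`, not standalone):

```python
# Sketch: translate a group-local rank to a global rank before point-to-point ops.
import torch
import torch.distributed as dist

def send_to_group_rank(t: torch.Tensor, group, dst_local_rank: int) -> None:
    # dist.send/recv expect global ranks regardless of the group argument,
    # so map the group-local index first.
    dst_global = dist.get_global_rank(group, dst_local_rank)
    dist.send(t, dst=dst_global)
```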
@@ -406,21 +401,14 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     num_tokens = 40
 
     # Prefill phase
-    # Run context input through pipeline
-    # TODO: we need to pass `input_pos` and `cache_lane` to each stage.
-    lane = 0
-    kwargs = {"input_pos": input_pos, "cache_lane": lane}
-    with torch.no_grad(), CUDATrackTime() as timer:
+    # Run context input through pipeline, in 1 step
+    with torch.no_grad():
         if pp_rank == first_pp_rank:
-            output = prefiller.step(padded_sequence, **kwargs)
+            output = prefill_schedule.step(padded_sequence)
         elif pp_rank == last_pp_rank:
-            output = prefiller.step(**kwargs)
+            output = prefill_schedule.step()
         else:  # middle pp ranks
-            prefiller.step(**kwargs)
-
-    logger.info(
-        f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
-    )
+            prefill_schedule.step()
 
     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:
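On the last stage, `output` holds prefill logits of shape `(mb_size, seqlen, vocab_size)`, and the first generated token comes from the logits at the last real prompt position. A hedged sketch of that step; `first_token_greedy` is a hypothetical helper and greedy argmax is shown only for simplicity, whereas the file's actual decoding helper may sample with a temperature:

```python
# Sketch of turning prefill logits into the first generated token (greedy).
import torch

def first_token_greedy(logits: torch.Tensor, prompt_len: int) -> torch.Tensor:
    # logits: (mb_size, seqlen, vocab_size). Take the position of the last
    # prompt token and pick the highest-scoring next token.
    return torch.argmax(logits[:, prompt_len - 1, :], dim=-1)  # (mb_size,)
```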
@@ -442,6 +430,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # seqlen = 1 now
     seqlen_decode = 1
     input_pos = torch.tensor([prompt_lengths[0]], device=device)
+    model.setup_input_pos(input_pos)
 
     # Create decode stage
     logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}")
@@ -456,12 +445,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decorder = ScheduleGPipe(decode_stage, 1)
+    decode_schedule = ScheduleGPipe(decode_stage, mbs)
 
     # Decoding
-    with torch.no_grad(), CUDATrackTime() as timer:
+    with torch.no_grad():
         for step in range(num_tokens - 1):
-            kwargs = {"input_pos": input_pos, "cache_lane": lane}
             # sendrecv between last and first ranks, only if:
             # first_pp_rank != last_pp_rank.
             if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
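Between decode steps, the newly sampled token lives on the last stage but must be fed back into the first stage, so the two ranks exchange it with point-to-point ops whenever they are different processes. A hedged sketch of that handoff; `pass_token` is a hypothetical helper and it assumes the first rank already holds a correctly shaped `new_token` buffer for `dist.recv` to fill:

```python
# Sketch of the last->first token handoff between decode steps.
import torch
import torch.distributed as dist

def pass_token(new_token: torch.Tensor, pp_rank: int,
               first_pp_rank: int, last_pp_rank: int,
               first_pp_rank_global_id: int, last_pp_rank_global_id: int) -> torch.Tensor:
    # Only needed when the first and last stages live on different ranks.
    if pp_rank == last_pp_rank and pp_rank != first_pp_rank:
        dist.send(new_token, dst=first_pp_rank_global_id)
    elif pp_rank == first_pp_rank and pp_rank != last_pp_rank:
        dist.recv(new_token, src=last_pp_rank_global_id)  # fills new_token in place
    return new_token
```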
@@ -479,11 +467,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decorder.step(new_token, **kwargs)
+                output = decode_schedule.step(new_token)
             elif pp_rank == last_pp_rank:
-                output = decorder.step(**kwargs)
+                output = decode_schedule.step()
             else:  # middle pp ranks
-                decorder.step(**kwargs)
+                decode_schedule.step()
 
             # Decode the output
             if pp_rank == last_pp_rank:
@@ -503,10 +491,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
                 )  # decode_results[i][0]
 
             input_pos += 1
-
-    logger.info(
-        f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
-    )
+            model.setup_input_pos(input_pos)
 
     # Display the decoding results
 
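Since positions are no longer passed per `step()` call, the loop advances `input_pos` and re-registers it with the model on every iteration so the next decode step writes to the correct KV-cache slot. A small illustrative sketch of that per-step bookkeeping; the cache write below is a toy stand-in, not the model's attention implementation:

```python
# Sketch: why the decode loop must re-register input_pos each step.
# Writing new K/V entries at the wrong position would corrupt the cache.
import torch

max_seqlen, dim = 32, 8
k_cache = torch.zeros(1, max_seqlen, dim)
input_pos = torch.tensor([13])                 # next free slot after the prompt
for _ in range(3):                             # a few decode steps
    new_k = torch.rand(1, 1, dim)              # this step's key for the new token
    k_cache[:, input_pos] = new_k              # lands in the slot input_pos points at
    input_pos = input_pos + 1                  # advance for the next step
```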