 
 import torch
 import torch.nn.functional as F
+from pippy.PipelineSchedule import ScheduleGPipe
+from pippy.PipelineStage import PipelineStage
 from torch.distributed import destroy_process_group
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -126,7 +128,8 @@ def main(job_config: JobConfig):
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    torch.cuda.set_device(device)
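+    # NOTE: keep the full torch.device (not just the local rank index); it is also passed to
+    # PipelineStage below when pipeline parallelism is enabled.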
     init_distributed(job_config)
 
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -144,6 +147,15 @@ def main(job_config: JobConfig):
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+
+    else:
+        pp_degree, pp_rank = 1, 0
+
     data_loader = build_hf_data_loader(
         job_config.training.dataset,
         job_config.training.dataset_path,
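For reference, a minimal sketch (not part of this PR) of the mesh layout the pp block above assumes: parallel_dims.build_mesh is expected to return a DeviceMesh with a named "pp" dimension, which torch.distributed.device_mesh.init_device_mesh can produce. Launched via torchrun on 8 GPUs with pp=2, dp=4:

    from torch.distributed.device_mesh import init_device_mesh

    # 2 pipeline stages x 4 data-parallel ranks = 8 GPUs
    world_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("pp", "dp"))
    pp_mesh = world_mesh["pp"]          # 1-D sub-mesh along the pipeline dimension
    pp_rank = pp_mesh.get_local_rank()  # index of the stage this rank owns
    pp_degree = pp_mesh.size()          # total number of pipeline stages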
@@ -203,9 +215,34 @@ def loss_fn(pred, labels):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
+    if parallel_dims.pp_enabled:
+        pipe_meta = model
+        model = pipe_meta.get_stage_module(pp_rank)
+
     model.to_empty(device="cuda")
-    model.init_weights()
+
+    # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
+    # there are virtual stages
+    if parallel_dims.pp_enabled:
+        stage = PipelineStage(
+            pipe=pipe_meta,
+            stage_index=pp_rank,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
+        pp_schedule = ScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
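+        # NOTE: ScheduleGPipe splits each rank's batch into n_microbatches chunks and runs them
+        # through the pipeline; loss_fn is evaluated per microbatch on the last stage, and the
+        # per-microbatch losses are collected in the training loop below.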
+    else:
+        # If PP is enabled, we can't use init_weights. Instead, we have to rely on creating an initial
+        # checkpoint offline and loading it to get initialization values. This is because the init_weights
+        # functions are written assuming the whole model (all of its weights, or FQNs) exists on one rank.
+        # With PP, init_weights on stage 1 might crash because it can't find the "embedding" layer, for example.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.init_weights()
 
     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
@@ -217,7 +254,6 @@ def loss_fn(pred, labels):
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
-
     metric_logger = build_metric_logger(job_config)
 
     # torch.compile model for improved performance
@@ -253,7 +289,13 @@ def loss_fn(pred, labels):
         logger.info("Created seed checkpoint")
         return
 
-    checkpoint.load()
+    checkpoint_loaded = checkpoint.load()
+
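+    # NOTE: the seed checkpoint is created offline (see create_seed_checkpoint.sh) with the full model
+    # on a single rank, so that each pipeline stage can load just the parameter FQNs it owns.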
+    if parallel_dims.pp_enabled and not checkpoint_loaded:
+        raise RuntimeError(
+            "Pipeline Parallelism requires meta-initialization and loading of a seed checkpoint. "
+            "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`."
+        )
 
     # plot losses loaded from checkpoint (if any) to TensorBoard
     # NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
@@ -293,14 +335,33 @@ def loss_fn(pred, labels):
 
         input_ids = input_ids.cuda()
         labels = labels.cuda()
-
        optimizer.zero_grad()
 
-        # forward / backward
-        with loss_parallel_ctx():
-            pred = model(input_ids)
-            loss = loss_fn(pred, labels)
-            loss.backward()
+        if parallel_dims.pp_enabled:
+            # pipeline parallel forward / backward inside step() call
+            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+            with loss_parallel_ctx():
+                if pp_mesh.get_local_rank() == 0:
+                    pp_schedule.step(input_ids)
+                elif is_last_stage:
+                    losses = []
+                    pp_schedule.step(target=labels, losses=losses)
+                else:
+                    pp_schedule.step()
+
+            # accumulate losses across pipeline microbatches
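+            # (only the last stage receives the microbatch losses; other ranks report a -1.0
+            # placeholder so downstream logging still has a value)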
+            loss = (
+                torch.mean(torch.stack(losses))
+                if is_last_stage
+                else torch.Tensor([-1.0])
+            )
+        else:
+            # Non-PP forward / backward
+            with loss_parallel_ctx():
+                pred = model(input_ids)
+                loss = loss_fn(pred, labels)
+                loss.backward()
 
         # clip gradients
         torch.nn.utils.clip_grad_norm_(