@@ -15,6 +15,8 @@
 
 import torch
 import torch.nn.functional as F
+from pippy.PipelineSchedule import PipelineScheduleGPipe
+from pippy.PipelineStage import PipelineStage
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
 
@@ -120,7 +122,9 @@ def main(job_config: JobConfig):
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
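+    # keep a device object; it is reused below when constructing the pipeline stage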
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    # torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    torch.cuda.set_device(device)
     init_distributed(job_config)
 
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -139,6 +143,14 @@ def main(job_config: JobConfig):
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
+
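+    # mirror the DP handling above: pipeline rank/size come from the "pp" mesh dimension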
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+    else:
+        pp_degree, pp_rank = 1, 0
+
     data_loader = build_dataloader_fn(
         job_config.training.dataset,
         job_config.training.dataset_path,
@@ -197,14 +209,38 @@ def loss_fn(pred, labels):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
+    if parallel_dims.pp_enabled:
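+        # with PP enabled, the parallelize fn appears to return a pippy pipe object;
+        # keep it, and swap model for just this rank's stage submodule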
+        pipe_meta = model
+        model = pipe_meta.get_stage_module(pp_rank)
+
     model.to_empty(device="cuda")
-    model.init_weights()
+
+    # TODO(whc): everything below needs to become a function that can be applied to
+    # each 'virtual stage' of PP, if there are virtual stages
+    if parallel_dims.pp_enabled:
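+        # build the runtime stage for this rank; it exchanges activations with
+        # neighboring stages over the pp process group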
+        stage = PipelineStage(
+            pipe=pipe_meta,
+            stage_index=pp_rank,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
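+        # GPipe schedule: all microbatch forwards, then all backwards, within one step();
+        # note that the number of microbatches is tied to the PP degree here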
+        pp_schedule = PipelineScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
+    else:
+        # If PP is enabled, we can't use init_weights. Instead, we have to rely on
+        # creating an initial checkpoint offline and loading it to get initialization
+        # values. This is because the init_weights functions are written assuming the
+        # whole model (all its weights, or FQNs) exists on one rank. Under PP, calling
+        # init_weights on stage 1 might crash because it can't find the "embedding"
+        # layer, for example.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.init_weights()
 
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
-
     metric_logger = build_metric_logger(job_config)
 
     # torch.compile model for improved performance
@@ -274,13 +310,30 @@ def loss_fn(pred, labels):
 
             input_ids = input_ids.cuda()
             labels = labels.cuda()
-
             optimizer.zero_grad()
 
-            # forward / backward
-            pred = model(input_ids)
-            loss = loss_fn(pred, labels)
-            loss.backward()
+            if parallel_dims.pp_enabled:
+                # pipeline parallel forward / backward inside step() call
+                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
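+                # the first stage feeds in the batch, the last stage feeds in labels
+                # and collects per-microbatch losses; middle stages only relay activations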
+                if pp_mesh.get_local_rank() == 0:
+                    pp_schedule.step(input_ids)
+                elif is_last_stage:
+                    losses = []
+                    pp_schedule.step(target=labels, losses=losses)
+                else:
+                    pp_schedule.step()
+
+                # accumulate losses across pipeline microbatches
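+                # (non-last stages compute no loss and log -1.0 as a placeholder)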
+                current_loss = (
+                    torch.mean(torch.stack(losses)).item() if is_last_stage else -1.0
+                )
+            else:
+                # forward / backward
+                pred = model(input_ids)
+                loss = loss_fn(pred, labels)
+                loss.backward()
+                current_loss = loss.item()
 
             # clip gradients
             torch.nn.utils.clip_grad_norm_(
@@ -291,7 +344,6 @@ def loss_fn(pred, labels):
             optimizer.step()
             scheduler.step()
 
-            current_loss = loss.item()
             losses_since_last_log.append(current_loss)
 
             # log metrics