
import torch
import torch.nn.functional as F
+from pippy.PipelineSchedule import PipelineScheduleGPipe
+from pippy.PipelineStage import PipelineStage
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.tensor.parallel import loss_parallel
@@ -129,7 +131,9 @@ def main(job_config: JobConfig):
        world_size=world_size,
        enable_loss_parallel=job_config.training.enable_loss_parallel,
    )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    # torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    torch.cuda.set_device(device)
    init_distributed(job_config)

    world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -148,6 +152,14 @@ def main(job_config: JobConfig):
        dp_rank = dp_mesh.get_local_rank()
    else:
        dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+    else:
+        pp_degree, pp_rank = 1, 0
+
    data_loader = build_dataloader_fn(
        job_config.training.dataset,
        job_config.training.dataset_path,
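The `world_mesh["pp"]` lookup above assumes the device mesh was built with a named "pp" dimension. A minimal sketch of that assumption using the stock `init_device_mesh` API (the 2x4 shape and dimension ordering are illustrative, not the repo's `ParallelDims.build_mesh`, and it needs to run under torchrun with 8 ranks):

```python
from torch.distributed.device_mesh import init_device_mesh

# Illustrative shape only (8 ranks as pp=2 x dp=4); in this repo the mesh
# actually comes from ParallelDims.build_mesh.
world_mesh = init_device_mesh("cuda", mesh_shape=(2, 4), mesh_dim_names=("pp", "dp"))

pp_mesh = world_mesh["pp"]          # 1-D sub-mesh along the "pp" dimension
pp_degree = pp_mesh.size()          # pipeline depth, 2 in this example
pp_rank = pp_mesh.get_local_rank()  # this rank's position along "pp"
```

Under this naming, each rank belongs to exactly one "pp" group, which is what `PipelineStage` is later handed via `pp_mesh.get_group()`.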
@@ -197,18 +209,54 @@ def main(job_config: JobConfig):
    model = models_parallelize_fns[model_name](
        model, world_mesh, parallel_dims, job_config
    )
-    # allocate sharded model on GPU and initialize weights via DTensor
-    model.to_empty(device="cuda")
-    model.init_weights()
-
-    # build optimizer after applying parallelisms to the model
-    optimizer = build_optimizer(model, job_config)
-    scheduler = get_lr_scheduler(optimizer, job_config)
+    if parallel_dims.pp_enabled:
+        pipe_meta = model
+        model = pipe_meta.get_stage_module(pp_rank)

    # build grad scaler which is effective only when mixed precision training
    # is enabled with fp16 param dtype under FSDP
    scaler = build_grad_scaler(model)

+    def loss_fn(pred, labels):
+        with (
+            loss_parallel()
+            if parallel_dims.loss_parallel_enabled
+            else contextlib.nullcontext()
+        ):
+            loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
+
+        # scaler.scale() returns a scaled copy (it is not in-place); return it so
+        # the caller's backward produces scaled gradients
+        return scaler.scale(loss)
+
+    # TODO(whc): everything below needs to become a function that can be applied to
+    # each 'virtual stage' of PP, if there are virtual stages
+    if parallel_dims.pp_enabled:
+        stage = PipelineStage(
+            pipe=pipe_meta,
+            stage_index=pp_rank,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
+        pp_schedule = PipelineScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
+        # With PP we can't use init_weights; instead we rely on creating an initial
+        # checkpoint offline and loading it to get initialization values. This is
+        # because the init_weights functions are written assuming the whole model
+        # (all of its weights, or FQNs) exists on one rank; with PP, init_weights on
+        # stage 1 might crash because it can't find the "embedding" layer, for
+        # example. (A sketch of this offline-checkpoint flow follows this hunk.)
+        model.to_empty(device="cuda")
+    else:
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.to_empty(device="cuda")
+        model.init_weights()
+
+    # build optimizer after applying parallelisms to the model
+    optimizer = build_optimizer(model, job_config)
+    scheduler = get_lr_scheduler(optimizer, job_config)
    metric_logger = build_metric_logger(job_config)

    # torch.compile model for improved performance
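On the initialization comment in the PP branch above: since `init_weights()` assumes the full set of FQNs lives on one rank, the PR's plan is to create an initial checkpoint offline and have each stage load only the keys it owns. A minimal sketch of that flow under stated assumptions (the tiny stand-in model, file name, and `strict=False` loading are illustrative, and this ignores the DTensor/FSDP resharding a real checkpoint loader would also need to handle):

```python
import torch
import torch.nn as nn

# Stand-in for the real Transformer; the point is the save/load flow, not the model.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(1000, 64)
        self.output = nn.Linear(64, 1000)

    def init_weights(self):
        nn.init.normal_(self.embedding.weight)
        nn.init.normal_(self.output.weight, std=0.02)
        nn.init.zeros_(self.output.bias)

# Offline, on a single process: build the full model, run init, save a full state dict.
full_model = TinyModel()
full_model.init_weights()
torch.save(full_model.state_dict(), "seed_checkpoint.pt")

# At training time, each pipeline rank holds only a slice of the model (faked here
# by keeping just the embedding) and loads only the keys it owns via strict=False.
stage_module = nn.Module()
stage_module.embedding = nn.Embedding(1000, 64)
full_sd = torch.load("seed_checkpoint.pt", map_location="cpu")
missing, unexpected = stage_module.load_state_dict(full_sd, strict=False)
# `unexpected` lists FQNs owned by other stages; `missing` should be empty as long
# as the stage module keeps the original fully qualified names.
```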
@@ -278,21 +326,32 @@ def main(job_config: JobConfig):

            input_ids = input_ids.cuda()
            labels = labels.cuda()
-
            optimizer.zero_grad()

-            # forward
-            pred = model(input_ids)
+            if parallel_dims.pp_enabled:
+                # pipeline forward / loss / backward
+                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

-            with (
-                loss_parallel()
-                if parallel_dims.loss_parallel_enabled
-                else contextlib.nullcontext()
-            ):
-                loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
+                if pp_mesh.get_local_rank() == 0:
+                    pp_schedule.step(input_ids)
+                elif is_last_stage:
+                    losses = []
+                    pp_schedule.step(target=labels, losses=losses)
+                else:
+                    pp_schedule.step()
+
+                # accumulate losses across pipeline microbatches
+                current_loss = (
+                    torch.mean(torch.stack(losses)).item() if is_last_stage else -1.0
+                )
+            else:
+                # non-pipeline forward / loss / backward
+                pred = model(input_ids)
+
+                loss = loss_fn(pred, labels)
+                loss.backward()

-            # backward on scaled loss to create scaled gradients
-            scaler.scale(loss).backward()
+                current_loss = loss.item()

            # clip gradients (after unscaling gradients of the optimizer's params)
            scaler.unscale_(optimizer)
@@ -309,7 +368,6 @@ def main(job_config: JobConfig):
            # updates the scale for next iteration
            scaler.update()

-            current_loss = loss.item()
            losses_since_last_log.append(current_loss)

            # log metrics
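For reference, the per-rank stepping pattern the loop above implements can be read in isolation. A minimal sketch, factored into a hypothetical helper (`run_pp_step` is not part of this PR; it only restates the branch logic for one training step under the GPipe schedule):

```python
import torch

def run_pp_step(pp_schedule, pp_mesh, input_ids, labels):
    """One training step on one pipeline rank (hypothetical helper, mirrors the diff)."""
    pp_rank = pp_mesh.get_local_rank()
    is_last_stage = pp_rank == pp_mesh.size() - 1

    if pp_rank == 0:
        # only the first stage sees the raw inputs
        pp_schedule.step(input_ids)
    elif is_last_stage:
        # only the last stage sees the targets and produces per-microbatch losses
        losses = []
        pp_schedule.step(target=labels, losses=losses)
        return torch.mean(torch.stack(losses)).item()
    else:
        # middle stages just run forward/backward on received activations
        pp_schedule.step()
    # non-last stages have no loss to report (the diff logs -1.0 for them)
    return -1.0
```

One consequence visible here: unless the loss is later broadcast from the last stage across `pp_mesh`, only the last pipeline rank logs a real loss value.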