@@ -7,6 +7,7 @@
 from collections import defaultdict
 
 import torch
+from pippy import annotate_split_points, Pipe, PipeSplitWrapper
 from torch.distributed._tensor import Replicate, Shard
 
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -125,7 +126,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     # apply PTD parallelisms
     if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")
+        pp_mesh = world_mesh["pp"]
+        stage_idx = pp_mesh.get_local_rank()
+        layers_per_rank = len(model.layers) // parallel_dims.pp
+        for i in range(1, parallel_dims.pp):
+            annotate_split_points(
+                model,
+                {
+                    f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING
+                },
+            )
+
+        # Get example input
+        label_shape = input_shape = (8, 2048)  # TODO
+        input_ids = torch.randint(
+            model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+        )
+        labels = torch.randint(
+            model.vocab_size, label_shape, dtype=torch.int64, device="meta"
+        )
+        print("input_ids: ", input_ids.shape, input_ids.dtype)
+        print("labels: ", labels.shape, labels.dtype)
+
+        # Create a pipeline representation from the model
+        pipe = Pipe.from_tracing(model, parallel_dims.pp, example_args=(input_ids,))
+        model = pipe.get_stage_module(stage_idx)
 
     # First we apply Sequence Parallelism if it's enabled
     if parallel_dims.sp_enabled:
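
For context on the hunk above: split points are annotated at the first transformer block owned by each subsequent pipeline rank, the whole model is traced once on meta tensors with example inputs, and each rank then keeps only its own stage module. Below is a minimal, self-contained sketch of that flow on a hypothetical toy module; the `Blocks` class, layer count, input shape, and `pp_degree = 2` are illustrative assumptions, and the sketch presumes the pippy release this commit builds against (it uses only the calls that appear in the diff).

# Minimal sketch (not part of this commit): split a hypothetical toy stack of
# layers into pipeline stages with the same pippy calls as the diff above.
import torch
import torch.nn as nn
from pippy import annotate_split_points, Pipe, PipeSplitWrapper


class Blocks(nn.Module):  # hypothetical stand-in for the llama model
    def __init__(self, n_layers=4, dim=16):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:  # unrolled at trace time
            x = layer(x)
        return x


pp_degree = 2  # assumed pipeline-parallel degree
model = Blocks()
layers_per_rank = len(model.layers) // pp_degree

# Mark a split at the BEGINNING of each boundary layer, as the diff does for
# model.layers; pippy wraps that submodule so tracing yields one graph
# partition per stage.
for i in range(1, pp_degree):
    annotate_split_points(
        model,
        {f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING},
    )

# Trace once with example inputs, then each rank pulls only its own stage.
example = torch.randn(8, 16)
pipe = Pipe.from_tracing(model, pp_degree, example_args=(example,))
stage_module = pipe.get_stage_module(0)  # e.g. rank 0 keeps stage 0
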
@@ -230,9 +255,14 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     meta_to_real_init_fn(model)
     model.cuda()
 
-    # we have now moved from meta to device,
-    # reset parameters for proper initialization
-    model.reset_parameters()
-    logger.info("Model fully initialized via reset_parameters")
+    if parallel_dims.pp_enabled:
+        setattr(pipe.split_gm, f"submod_{stage_idx}", model)
+        return pipe
+    else:
+        # TODO figure out PP compatible deferred initialization
+        # we have now moved from meta to device,
+        # reset parameters for proper initialization
+        model.reset_parameters()
+        logger.info("Model fully initialized via reset_parameters")
 
     return model
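
The final hunk's design choice: under pipeline parallelism each rank materializes and moves only its own stage, writes that module back into the traced split graph (whose child modules are named `submod_0`, `submod_1`, ...), and returns the whole `pipe` object rather than a bare `nn.Module`; full `reset_parameters()` deferred initialization is left as a TODO for PP. Continuing the illustrative sketch above (the fixed stage index and the explicit device move are assumptions, not taken from this commit):

# Patch this rank's (re-)initialized stage back into the traced pipeline
# before handing `pipe` to the caller.
stage_idx = 0  # real code uses pp_mesh.get_local_rank()
stage_module = pipe.get_stage_module(stage_idx)
stage_module.cuda()  # move only this rank's stage parameters to the local GPU
setattr(pipe.split_gm, f"submod_{stage_idx}", stage_module)  # swap into split graph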