@@ -11,6 +11,7 @@

 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.fx import GraphModule

 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
@@ -152,25 +153,23 @@ def loss_fn(pred, labels):
         for m in model_parts:
             # apply SPMD-style PT-D techniques
             models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-
-            # In PP, we cannot call init_weights directly because some layers are missing.
-            # In the future, we may make init_weights handle missing layers, but also have
-            # to consider RNG seed propagation. For now, we rely on a seed checkpoint to
-            # initialize the model.
             m.to_empty(device="cuda")
-            m.train()
     else:
         # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
         models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)

         # move sharded model to CPU/GPU and initialize weights via DTensor
         init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
         model.to_empty(device=init_device)
-        model.init_weights()
-        model.train()
-
         model_parts = [model]

+    for mod in model_parts:
+        # skip traced modules since we do not define init_weights in the traced module
+        if isinstance(mod, GraphModule):
+            continue
+        mod.init_weights()
+        mod.train()
+
     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
         f"GPU memory usage for model: "
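The new loop above calls init_weights() and train() on every model part but skips torch.fx.GraphModule instances: when pipeline stages come from tracing, the resulting GraphModule only carries the traced forward(), not custom methods defined on the original model class. A minimal sketch (not part of this commit; TinyModel is a made-up module) illustrating why the isinstance check is needed:

```python
import torch.nn as nn
from torch.fx import GraphModule, symbolic_trace


class TinyModel(nn.Module):
    """Stand-in for a model that defines its own weight-init method."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def init_weights(self):
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        return self.proj(x)


model = TinyModel()
traced = symbolic_trace(model)  # returns a GraphModule wrapping the traced forward()

print(isinstance(traced, GraphModule))  # True
print(hasattr(model, "init_weights"))   # True
print(hasattr(traced, "init_weights"))  # False: custom methods are not carried over
```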
@@ -205,9 +204,10 @@ def loss_fn(pred, labels):
     checkpoint_loaded = checkpoint.load()

     if parallel_dims.pp_enabled and not checkpoint_loaded:
-        raise RuntimeError(
-            "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
-            "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
+        # TODO: fix this by allowing each rank to set their own seed
+        logger.warning(
+            "Pipeline Parallelism is being used without a seed checkpoint. "
+            "All the substages will be initialized with random weights with the same RNG state, which can affect convergence."
         )

     metric_logger = build_metric_logger(job_config, parallel_dims)
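The TODO in the hunk above points at the follow-up: instead of every pipeline stage sharing the default RNG state at init time, each stage could fold its pipeline rank into the seed before init_weights() runs. A hedged sketch of that idea, assuming the device mesh has a dimension named "pp" and that DeviceMesh.get_local_rank accepts a dimension name (the function and parameter names here are illustrative, not this commit's API):

```python
import torch


def seed_per_pp_rank(base_seed: int, world_mesh, pp_enabled: bool) -> None:
    # Offset the seed by the pipeline rank so each stage draws different random
    # weights; ranks within a stage keep the same seed so that DTensor-based
    # initialization stays consistent across the DP/TP dimensions.
    pp_rank = world_mesh.get_local_rank("pp") if pp_enabled else 0
    torch.manual_seed(base_seed + pp_rank)
    torch.cuda.manual_seed_all(base_seed + pp_rank)
```

With something along these lines applied before the init_weights() loop, the warning path could eventually be removed.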