
Commit eb64c69

[PP] Bypass seed checkpoint by init-ing model parts separately (#516)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #473
* #517
* __->__ #516

Allows PP to be used without a seed checkpoint by calling `init_weights` on each model part. This is the solution in step 1 of #514 proposed by @wconstab.
1 parent cfc0f4e commit eb64c69
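The idea, roughly: instead of meta-initializing the full model and then loading a seed checkpoint, each pipeline model part is materialized on device and asked to initialize its own weights. A minimal sketch of that pattern (not the torchtitan code itself; it assumes each non-traced part defines an `init_weights()` method, as torchtitan's Transformer does):

import torch.nn as nn
from torch.fx import GraphModule


def init_model_parts(model_parts):
    """Materialize and initialize each pipeline model part in place."""
    for part in model_parts:
        # parameters were constructed on the meta device; allocate real storage
        part.to_empty(device="cuda")
        # traced modules do not define init_weights, so skip them
        if isinstance(part, GraphModule):
            continue
        part.init_weights()  # assumed: each non-traced part provides this
        part.train()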

File tree

2 files changed, +26 -22 lines

torchtitan/models/llama/model.py

Lines changed: 14 additions & 10 deletions
@@ -394,19 +394,23 @@ def init_weights(self):
         """
         with torch.device(self.freqs_cis.device):
             self.freqs_cis = self._precompute_freqs_cis()
-        nn.init.normal_(self.tok_embeddings.weight)
+        if self.tok_embeddings is not None:
+            nn.init.normal_(self.tok_embeddings.weight)
         for layer in self.layers.values():
-            layer.init_weights()
-        self.norm.reset_parameters()
+            if layer is not None:
+                layer.init_weights()
+        if self.norm is not None:
+            self.norm.reset_parameters()
         final_out_std = self.model_args.dim**-0.5
         cutoff_factor = 3
-        nn.init.trunc_normal_(
-            self.output.weight,
-            mean=0.0,
-            std=final_out_std,
-            a=-cutoff_factor * final_out_std,
-            b=cutoff_factor * final_out_std,
-        )
+        if self.output is not None:
+            nn.init.trunc_normal_(
+                self.output.weight,
+                mean=0.0,
+                std=final_out_std,
+                a=-cutoff_factor * final_out_std,
+                b=cutoff_factor * final_out_std,
+            )
 
     def _precompute_freqs_cis(self) -> torch.Tensor:
         return precompute_freqs_cis(
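Context for the None checks above: when the model is split for pipeline parallelism, a stage's copy keeps only the modules it owns and the rest are set to None, so `init_weights` has to skip whatever is missing. A small, self-contained illustration of the guarded pattern (hypothetical toy module, not the torchtitan splitting code):

import torch.nn as nn


class TinyStagePart(nn.Module):
    # stands in for one pipeline stage's chunk of a Transformer
    def __init__(self, has_embeddings: bool, has_output: bool, dim: int = 16):
        super().__init__()
        self.tok_embeddings = nn.Embedding(100, dim) if has_embeddings else None
        self.norm = nn.LayerNorm(dim) if has_output else None
        self.output = nn.Linear(dim, 100, bias=False) if has_output else None

    def init_weights(self):
        # mirror the guarded initialization from the diff above
        if self.tok_embeddings is not None:
            nn.init.normal_(self.tok_embeddings.weight)
        if self.norm is not None:
            self.norm.reset_parameters()
        if self.output is not None:
            nn.init.trunc_normal_(self.output.weight, mean=0.0, std=0.02)


first_stage = TinyStagePart(has_embeddings=True, has_output=False)
last_stage = TinyStagePart(has_embeddings=False, has_output=True)
first_stage.init_weights()  # norm/output are None here and are skipped
last_stage.init_weights()   # tok_embeddings is None here and is skipped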

train.py

Lines changed: 12 additions & 12 deletions
@@ -11,6 +11,7 @@
 
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.fx import GraphModule
 
 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
@@ -152,25 +153,23 @@ def loss_fn(pred, labels):
         for m in model_parts:
             # apply SPMD-style PT-D techniques
             models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-
-            # In PP, we cannot call init_weights directly because some layers are missing.
-            # In the future, we may make init_weights handle missing layers, but also have
-            # to consider RNG seed propagation. For now, we rely on a seed checkpoint to
-            # initialize the model.
             m.to_empty(device="cuda")
-            m.train()
     else:
         # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
         models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
 
         # move sharded model to CPU/GPU and initialize weights via DTensor
         init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
         model.to_empty(device=init_device)
-        model.init_weights()
-        model.train()
-
         model_parts = [model]
 
+    for mod in model_parts:
+        # skip traced modules since we do not define init_weights in the traced module
+        if isinstance(mod, GraphModule):
+            continue
+        mod.init_weights()
+        mod.train()
+
     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
         f"GPU memory usage for model: "
@@ -205,9 +204,10 @@ def loss_fn(pred, labels):
     checkpoint_loaded = checkpoint.load()
 
     if parallel_dims.pp_enabled and not checkpoint_loaded:
-        raise RuntimeError(
-            "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
-            "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
+        # TODO: fix this by allowing each rank to set their own seed
+        logger.warning(
+            "Pipeline Parallelism is being used without a seed checkpoint. "
+            "All the substages will be initialized with random weights with same RNG state which can affect convergence."
         )
 
     metric_logger = build_metric_logger(job_config, parallel_dims)
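One caveat carried by the new warning: every stage now initializes from the same RNG state, so different substages draw identical random streams, which can affect convergence. The TODO points toward per-rank seeding; a hedged sketch of that direction (not implemented in this commit; `base_seed` and `pp_rank` are hypothetical names):

import torch


def seed_for_pp_stage(base_seed: int, pp_rank: int) -> None:
    # offset the seed by pipeline-stage rank so stages draw different weights;
    # data-parallel replicas of the same stage would still share a seed
    torch.manual_seed(base_seed + pp_rank)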
