docs/composability.md (3 changes: 1 addition & 2 deletions)

@@ -10,7 +10,7 @@ When applying Pipeline Parallelism, you will have to construct nn.Module objects
Most likely, you can write your model in such a way that the top-level nn.Module owns a sequence of child modules that it calls during forward, delegating most of the complexity to the child module forwards. If you can reduce your top-level forward to mostly a for-loop over child module calls, then you'll simplify the pipeline-partitioning task to choosing the set of submodules to keep per stage. If you have non-trivial logic in the top-level forward, you'll have to find a way to patch that logic back onto the resulting pipeline stage model, which can be annoying.
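
For illustration, here is a minimal, hypothetical sketch (not torchtitan's actual model code) of a top-level module whose forward is essentially a loop over child blocks, which is the shape that makes PP partitioning easy:

```python
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical example: the top-level forward is just a loop over child
    blocks, so pipeline partitioning reduces to choosing which blocks to keep."""

    def __init__(self, num_layers: int, dim: int):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, x):
        # Delegating all real work to the children keeps this loop trivially
        # splittable into pipeline stages.
        for block in self.blocks:
            x = block(x)
        return x
```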

example ([PR #321](https://github.com/pytorch/torchtitan/pull/321)):
we used to slice the `freqs_cis` buffer by `seq_len` in the top-level forward, pass that into child modules, and expect that inside the child modules the `seq_len` would match up with the size of other local tensors. But when we consider PP splitting, we don't know whether TP was applied or not, which could create a mismatch. It's just as easy to perform the `freqs_cis` slicing inside the child submodule, using the runtime-accurate local `seq_len`, and this sidesteps the issue at PP slicing time.
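
A rough sketch of the idea, using a hypothetical block rather than the real `TransformerBlock`:

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical block; the real TransformerBlock also runs attention and an MLP."""

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # Slice by the runtime-local sequence length of *this* block's input,
        # instead of trusting a seq_len computed in the top-level forward that
        # may not account for TP/PP sharding.
        seqlen = x.size(1)
        local_freqs = freqs_cis[:seqlen]
        # ... attention would consume x together with local_freqs here ...
        return x
```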

example ([PR #322](https://github.com/pytorch/torchtitan/pull/322)): We decided to actually reuse the top-level model object on every PP stage, just delete the layers we don't want, and make sure that the top-level forward still does the right thing. This means we don't have to make a separate runtime `pp_forward` that glues together child modules per stage. The first change was using a `ModuleDict` instead of a `ModuleList` to store layers. This preserves layer Fully Qualified Names (FQNs) even when deleting some layers, e.g. `layers.1` stays `layers.1` even if you remove `layers.0`, which isn't true for a list; this matters for checkpoint save/load. Preserving FQNs is a requirement for using Distributed Checkpointing (DCP), since it uses FQNs as globally unique IDs for sharding metadata. The second change was making the input and output layers optional: if the layer exists, we run it, otherwise we feed the input through to bypass it. With these two changes, we can just (meta-)initialize the whole model, delete the unused parts per stage, then materialize the remaining part on GPU before loading a checkpoint.
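
The two mechanisms can be illustrated with a small, self-contained sketch; `stage_forward` below is a simplified, hypothetical stand-in that only loosely mirrors the real model's forward:

```python
import torch.nn as nn

# 1) ModuleDict keeps FQNs stable when layers are deleted per pipeline stage.
dim = 16
layers = nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(4)})
del layers["0"]
print(list(layers.keys()))             # ['1', '2', '3'] -- '1' is still '1'
print(sorted(layers.state_dict())[0])  # '1.bias' -- checkpoint keys unchanged


# 2) Optional input/output layers: run them if this stage owns them, else pass through.
def stage_forward(model, tokens):
    h = model.tok_embeddings(tokens) if model.tok_embeddings is not None else tokens
    for layer in model.layers.values():
        h = layer(h, model.freqs_cis)
    h = model.norm(h) if model.norm is not None else h
    return model.output(h) if model.output is not None else h
```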

@@ -20,4 +20,3 @@ Initializing the pipeline-parallel model is challenging because we assume the mo
For now, we sidestep all these problems with a simple but brutal solution: initialize the whole model on some CPU instance, save a checkpoint file, and then lean on Distributed Checkpointing's "load" functionality to initialize the FQNs that are present on a given PP stage after stage creation. As future work, we may add a more elaborate initialization scheme to `torch.pipelining`.
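
A minimal sketch of that workflow, assuming a recent PyTorch where `torch.distributed.checkpoint` exposes `save`/`load` with a `checkpoint_id` argument; `build_full_model_on_cpu`, `stage_model`, and the checkpoint path are placeholders, and this is not the exact torchtitan checkpoint layout:

```python
import torch.distributed.checkpoint as dcp

CKPT_DIR = "outputs/seed_checkpoint"  # hypothetical path

# One-time, on a CPU host: materialize and initialize the full model, then
# write a DCP checkpoint keyed by FQN.
full_model = build_full_model_on_cpu()  # placeholder helper
dcp.save(full_model.state_dict(), checkpoint_id=CKPT_DIR)

# On each pipeline stage: the stage model only holds a subset of the FQNs, so
# DCP only reads the entries this stage asks for.
stage_model.to_empty(device="cuda")      # stage_model was built on the meta device
state = stage_model.state_dict()
dcp.load(state, checkpoint_id=CKPT_DIR)  # loads in place into the tensors in `state`
stage_model.load_state_dict(state)
```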

One issue with seed checkpoints is that we rely on initializing _every_ model state from the checkpoint, which means the model can't have any non-persistent buffers, or else we have to specially initialize those in `train.py` after pipeline splitting. `freqs_cis` was originally a non-persistent buffer, and we changed this to persistent in order to load it from the seed checkpoint.
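
For reference, a hedged sketch of the difference (hypothetical module; the real `freqs_cis` values come from `precompute_freqs_cis`):

```python
import torch
import torch.nn as nn


class ToyRope(nn.Module):
    """Hypothetical module showing the persistence flag on register_buffer."""

    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        freqs_cis = torch.randn(max_seq_len, dim // 2, dtype=torch.complex64)  # stand-in values
        # persistent=True includes freqs_cis in state_dict(), so the seed
        # checkpoint can initialize it; persistent=False would leave it out and
        # require special-case initialization after pipeline splitting.
        self.register_buffer("freqs_cis", freqs_cis, persistent=True)


print("freqs_cis" in ToyRope(64, 128).state_dict())  # True
```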

torchtitan/models/llama/model.py (24 changes: 14 additions & 10 deletions)

@@ -394,19 +394,23 @@ def init_weights(self):
         """
         with torch.device(self.freqs_cis.device):
             self.freqs_cis = self._precompute_freqs_cis()
-        nn.init.normal_(self.tok_embeddings.weight)
+        if self.tok_embeddings is not None:
+            nn.init.normal_(self.tok_embeddings.weight)
         for layer in self.layers.values():
-            layer.init_weights()
-        self.norm.reset_parameters()
+            if layer is not None:
+                layer.init_weights()
+        if self.norm is not None:
+            self.norm.reset_parameters()
         final_out_std = self.model_args.dim**-0.5
         cutoff_factor = 3
-        nn.init.trunc_normal_(
-            self.output.weight,
-            mean=0.0,
-            std=final_out_std,
-            a=-cutoff_factor * final_out_std,
-            b=cutoff_factor * final_out_std,
-        )
+        if self.output is not None:
+            nn.init.trunc_normal_(
+                self.output.weight,
+                mean=0.0,
+                std=final_out_std,
+                a=-cutoff_factor * final_out_std,
+                b=cutoff_factor * final_out_std,
+            )

     def _precompute_freqs_cis(self) -> torch.Tensor:
         return precompute_freqs_cis(
train.py (24 changes: 12 additions & 12 deletions)

@@ -11,6 +11,7 @@

 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.fx import GraphModule

 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
@@ -152,25 +153,23 @@ def loss_fn(pred, labels):
         for m in model_parts:
             # apply SPMD-style PT-D techniques
             models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-
-            # In PP, we cannot call init_weights directly because some layers are missing.
-            # In the future, we may make init_weights handle missing layers, but also have
-            # to consider RNG seed propagation. For now, we rely on a seed checkpoint to
-            # initialize the model.
             m.to_empty(device="cuda")
-            m.train()
     else:
         # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
         models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)

         # move sharded model to CPU/GPU and initialize weights via DTensor
         init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
         model.to_empty(device=init_device)
-        model.init_weights()
-        model.train()
-
         model_parts = [model]
+
+    for mod in model_parts:
+        # skip traced modules since we do not define init_weights in the traced module
+        if isinstance(mod, GraphModule):
+            continue
+        mod.init_weights()
+        mod.train()

     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
         f"GPU memory usage for model: "
@@ -205,9 +204,10 @@ def loss_fn(pred, labels):
     checkpoint_loaded = checkpoint.load()

     if parallel_dims.pp_enabled and not checkpoint_loaded:
-        raise RuntimeError(
-            "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
-            "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
+        # TODO: fix this by allowing each rank to set their own seed
+        logger.warning(
+            "Pipeline Parallelism is being used without a seed checkpoint. "
+            "All the substages will be initialized with random weights with same RNG state which can affect convergence."
         )

     metric_logger = build_metric_logger(job_config, parallel_dims)