Commit efb2845

Add Pipeline Parallel (and 2D PP+FSDP) support
- uses pipeline tracer frontend to extract a graph and partition it into chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose option to switch schedule and test other schedules)
- supports 2D parallelism currently, 3D (TP) is work in progress

ghstack-source-id: 821fa21
Pull Request resolved: #161
1 parent e92e3d7 commit efb2845
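
For orientation, the pieces added across this commit compose roughly as follows. This is a condensed sketch, not a drop-in script: it only uses the pippy calls that appear in the diffs below (annotate_split_points, pipeline, PipelineStage, ScheduleGPipe), and the model, pp_mesh, loss_fn, and device arguments stand in for objects torchtitan builds elsewhere.

import torch
from pippy import annotate_split_points, pipeline, SplitPoint
from pippy.PipelineSchedule import ScheduleGPipe
from pippy.PipelineStage import PipelineStage


def build_pp_schedule(model, pp_mesh, pp_degree, loss_fn, device):
    # Sketch of the PP flow this commit wires up; arguments are stand-ins for torchtitan objects.
    stage_idx = pp_mesh.get_local_rank()

    # 1) Mark split points so the tracer cuts the model into `pp_degree` chunks.
    layers_per_rank = len(model.layers) // pp_degree
    for i in range(1, pp_degree):
        annotate_split_points(
            model, {f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING}
        )

    # 2) Trace the model on meta-device example inputs and keep only this rank's stage module.
    example_input = torch.randint(
        model.vocab_size, (8, 2048), dtype=torch.int64, device="meta"
    )
    pipe = pipeline(model, pp_degree, example_args=(example_input,))
    stage_module = pipe.get_stage_module(stage_idx)

    # 3) Wrap the stage for runtime execution and attach a microbatch schedule (GPipe, as in train.py).
    stage = PipelineStage(
        pipe=pipe, stage_index=stage_idx, device=device, group=pp_mesh.get_group()
    )
    schedule = ScheduleGPipe(stage, n_microbatches=pp_degree, loss_fn=loss_fn)
    return stage_module, schedule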

File tree

5 files changed: +167 -25 lines changed

requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -5,3 +5,5 @@ tensorboard
 sentencepiece
 tiktoken
 blobfile
+# TODO remove pippy requirement after completing migration to pytorch
+git+https://github.com/pytorch/pippy

test_runner.py

Lines changed: 45 additions & 7 deletions
@@ -26,6 +26,7 @@ class OverrideDefinitions:
 
     override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
     test_descr: str = "default"
+    requires_seed_checkpoint: bool = False
 
 
 CONFIG_DIR = "./train_configs"
@@ -85,25 +86,62 @@ class OverrideDefinitions:
         ],
         "Checkpoint Integration Test - Save Model Weights Only bf16",
     ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                "--training.pipeline_parallel_degree 4",
+                "--training.data_parallel_degree 1",
+                "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+            ],
+        ],
+        "PP 1D test",
+        requires_seed_checkpoint=True,
+    ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                "--training.pipeline_parallel_degree 2",
+                "--training.data_parallel_degree 2",
+                "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+            ],
+        ],
+        "PP+DP 2D test",
+        requires_seed_checkpoint=True,
+    ),
 ]
 
 
+def _run_cmd(cmd):
+    return subprocess.run(
+        [cmd],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        shell=True,
+    )
+
+
 def run_test(test_flavor: OverrideDefinitions, full_path: str):
     # run_test supports sequence of tests.
     for override_arg in test_flavor.override_args:
+
         cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"
         if override_arg:
             cmd += " " + " ".join(override_arg)
         print(
             f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
         )
-        result = subprocess.run(
-            [cmd],
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            text=True,
-            shell=True,
-        )
+
+        if test_flavor.requires_seed_checkpoint:
+            print("Creating seed checkpoint")
+            result = _run_cmd(
+                f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh --checkpoint.folder {test_checkpoint_dir}"
+            )
+            print(result.stdout)
+
+        result = _run_cmd(cmd)
         print(result.stdout)
         if result.returncode != 0:
             raise Exception(
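
For anyone adding another pipeline-parallel flavor later, the new requires_seed_checkpoint field is used exactly as in the two entries above. A hypothetical extra entry (degree-2 PP, no DP; the flavor itself is not part of this commit) would look like this:

# Hypothetical additional test flavor illustrating the new requires_seed_checkpoint flag.
OverrideDefinitions(
    [
        [
            "--checkpoint.enable_checkpoint",
            "--training.pipeline_parallel_degree 2",
            "--training.data_parallel_degree 1",
            "--model.norm_type rmsnorm",  # fused_rmsnorm not yet supported with the PP tracer
        ],
    ],
    "PP 1D test, degree 2 (hypothetical)",
    requires_seed_checkpoint=True,
)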

torchtitan/models/llama/model.py

Lines changed: 1 addition & 1 deletion
@@ -423,7 +423,7 @@ def forward(self, tokens: torch.Tensor):
         """
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
-        self.freqs_cis = self.freqs_cis.to(h.device)
+        # self.freqs_cis = self.freqs_cis.to(h.device)
         freqs_cis = self.freqs_cis[0:seqlen]
 
         for layer in self.layers:

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 47 additions & 6 deletions
@@ -11,7 +11,7 @@
 from typing import Tuple
 
 import torch
-
+from pippy import annotate_split_points, pipeline, SplitPoint
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -137,7 +137,36 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     the model must fit on GPU or CPU memory.
     """
     if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")
+
+        if job_config.model.norm_type == "fused_rmsnorm":
+            # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
+            # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
+            raise NotImplementedError(
+                "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
+            )
+        pp_mesh = world_mesh["pp"]
+        stage_idx = pp_mesh.get_local_rank()
+        layers_per_rank = len(model.layers) // parallel_dims.pp
+        for i in range(1, parallel_dims.pp):
+            annotate_split_points(
+                model,
+                {f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING},
+            )
+
+        # Get example input
+        label_shape = input_shape = (8, 2048)  # TODO
+        input_ids = torch.randint(
+            model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+        )
+        labels = torch.randint(
+            model.vocab_size, label_shape, dtype=torch.int64, device="meta"
+        )
+        print("input_ids: ", input_ids.shape, input_ids.dtype)
+        print("labels: ", labels.shape, labels.dtype)
+
+        # Create a pipeline representation from the model
+        pipe = pipeline(model, parallel_dims.pp, example_args=(input_ids,))
+        model = pipe.get_stage_module(stage_idx)
 
     if parallel_dims.tp_enabled:
         if job_config.model.norm_type == "fused_rmsnorm":
@@ -215,27 +244,39 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
         # TODO: Expose `reduce_dtype` as a config option.
         mp_policy = MixedPrecisionPolicy(
-            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+            # TODO(whc) need to fix PP + FSDP-mixed-precision
+            # tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
+            # param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+            param_dtype=torch.float32,
+            reduce_dtype=torch.float32,
         )
         ac_mode = job_config.activation_checkpoint.mode
         fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
-        for layer_id, transformer_block in enumerate(model.layers):
+        for layer_name, transformer_block in model.layers.named_children():
             if job_config.activation_checkpoint.mode in ("full", "selective"):
                 transformer_block = checkpoint_wrapper(
                     transformer_block, job_config.activation_checkpoint
                 )
             # As an optimization, do not reshard after forward for the last
             # transformer block since FSDP would prefetch it immediately
-            reshard_after_forward = layer_id < len(model.layers) - 1
+            # reshard_after_forward = layer_id < len(model.layers) - 1
+            # TODO(whc) need to fix correctly handle layer-ids on pp-split module
+            reshard_after_forward = True
             fully_shard(
                 transformer_block,
                 **fsdp_config,
                 reshard_after_forward=reshard_after_forward,
            )
-            model.layers[layer_id] = transformer_block
+            # model.layers[layer_id] = transformer_block
+            # TODO(whc)
+            setattr(model.layers, layer_name, transformer_block)
         model = fully_shard(model, **fsdp_config)
         if ac_mode in ("full", "selective"):
             logger.info(f"Applied {ac_mode} activation checkpointing to the model")
         logger.info("Applied FSDP to the model")
 
+    if parallel_dims.pp_enabled:
+        setattr(pipe.split_gm, f"submod_{stage_idx}", model)
+        return pipe
+
     return model
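
Note the changed contract: when PP is enabled, parallelize_llama now returns the traced pipe rather than an nn.Module, so the caller has to pull out its own stage. A minimal sketch of consuming that return value, mirroring what train.py does below (the argument names here are just the torchtitan objects used in this diff):

def unpack_parallelized(parallelized, pp_enabled: bool, pp_rank: int):
    # Sketch: how a caller unpacks parallelize_llama's PP-aware return value.
    if pp_enabled:
        pipe_meta = parallelized  # pippy pipe holding every stage (still meta tensors)
        stage_module = pipe_meta.get_stage_module(pp_rank)  # this rank's stage, already FSDP-wrapped
        return pipe_meta, stage_module
    # Non-PP path: plain module, unchanged behavior.
    return None, parallelized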

train.py

Lines changed: 72 additions & 11 deletions
@@ -19,6 +19,8 @@
 
 import torch
 import torch.nn.functional as F
+from pippy.PipelineSchedule import ScheduleGPipe
+from pippy.PipelineStage import PipelineStage
 from torch.distributed import destroy_process_group
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -126,7 +128,8 @@ def main(job_config: JobConfig):
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    torch.cuda.set_device(device)
     init_distributed(job_config)
 
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -144,6 +147,15 @@ def main(job_config: JobConfig):
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+
+    else:
+        pp_degree, pp_rank = 1, 0
+
     data_loader = build_hf_data_loader(
         job_config.training.dataset,
         job_config.training.dataset_path,
@@ -203,9 +215,34 @@ def loss_fn(pred, labels):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
+    if parallel_dims.pp_enabled:
+        pipe_meta = model
+        model = pipe_meta.get_stage_module(pp_rank)
+
     model.to_empty(device="cuda")
-    model.init_weights()
+
+    # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
+    # there are virtual stages
+    if parallel_dims.pp_enabled:
+        stage = PipelineStage(
+            pipe=pipe_meta,
+            stage_index=pp_rank,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
+        pp_schedule = ScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
+    else:
+        # if PP is enabled, we can't use init_weights. instead, we have to rely on offline creating an initial checkpoint
+        # and loading it to get initialization values. This is because the init_weights functions are written assuming
+        # the whole model (all its weights, or FQNs) exist on one rank. In PP, the init_weights on stage1 might crash
+        # because it can't find "embedding" layer, for example.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.init_weights()
 
     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
@@ -217,7 +254,6 @@ def loss_fn(pred, labels):
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
-
     metric_logger = build_metric_logger(job_config)
 
     # torch.compile model for improved performance
@@ -253,7 +289,13 @@ def loss_fn(pred, labels):
         logger.info("Created seed checkpoint")
         return
 
-    checkpoint.load()
+    checkpoint_loaded = checkpoint.load()
+
+    if parallel_dims.pp_enabled and not checkpoint_loaded:
+        raise RuntimeError(
+            "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
+            "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
+        )
 
     # plot losses loaded from checkpoint (if any) to TensorBoard
     # NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
@@ -293,14 +335,33 @@ def loss_fn(pred, labels):
 
         input_ids = input_ids.cuda()
         labels = labels.cuda()
-
         optimizer.zero_grad()
 
-        # forward / backward
-        with loss_parallel_ctx():
-            pred = model(input_ids)
-            loss = loss_fn(pred, labels)
-            loss.backward()
+        if parallel_dims.pp_enabled:
+            # pipeline parallel forward / backward inside step() call
+            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+            with loss_parallel_ctx():
+                if pp_mesh.get_local_rank() == 0:
+                    pp_schedule.step(input_ids)
+                elif is_last_stage:
+                    losses = []
+                    pp_schedule.step(target=labels, losses=losses)
+                else:
+                    pp_schedule.step()
+
+            # accumulate losses across pipeline microbatches
+            loss = (
+                torch.mean(torch.stack(losses))
+                if is_last_stage
+                else torch.Tensor([-1.0])
+            )
+        else:
+            # Non-PP forward / backward
+            with loss_parallel_ctx():
+                pred = model(input_ids)
+                loss = loss_fn(pred, labels)
+                loss.backward()
 
         # clip gradients
         torch.nn.utils.clip_grad_norm_(
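
The branching inside the training loop encodes a per-rank contract for the schedule: the first stage feeds the real inputs, the last stage supplies targets and collects per-microbatch losses, and middle stages only relay activations. A small helper, sketched here purely for clarity (pp_schedule, pp_mesh, input_ids, and labels are the objects from the loop above; the helper itself is not part of this commit), makes that contract explicit:

import torch

def pp_forward_backward(pp_schedule, pp_mesh, input_ids, labels):
    # Sketch: the per-rank step contract used in the training loop above, factored out for clarity.
    is_first_stage = pp_mesh.get_local_rank() == 0
    is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

    if is_first_stage:
        pp_schedule.step(input_ids)  # owns the real inputs; sends activations downstream
    elif is_last_stage:
        losses = []  # the schedule appends one loss per microbatch
        pp_schedule.step(target=labels, losses=losses)
        return torch.mean(torch.stack(losses))
    else:
        pp_schedule.step()  # middle stages only relay activations and gradients

    # Non-last ranks have no loss to report; the loop above logs a placeholder instead.
    return torch.Tensor([-1.0])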
