Commit 3216f3e

Manual PP
runs PP+DP and PP+TP without issue; runs PP+TP+DP with decreasing loss, but fails DCP save

TODOs:
- clean up manual stage creation
- config options for configuring the stage split
- a way to switch between tracer/manual

ghstack-source-id: 952f364
Pull Request resolved: #308
1 parent 9778897 commit 3216f3e

3 files changed: 180 additions & 18 deletions

3 files changed

+180
-18
lines changed

torchtitan/models/llama/model.py

Lines changed: 2 additions & 1 deletion
@@ -76,7 +76,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     """
     ndim = x.ndim
     assert 0 <= 1 < ndim
-    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (freqs_cis.shape, x.shape)
     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)

@@ -182,6 +182,7 @@ def forward(
             torch.Tensor: Output tensor after attention.

         """
+        print(f"transformer layer got input shape {x.shape}")
         bs, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
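
The only functional change above is debug-oriented: the shape assert in reshape_for_broadcast now reports both shapes when it fails, which matters once PP/TP can change the sequence length a stage actually sees. A minimal, self-contained illustration of the failure mode (toy shapes and a hypothetical check_freqs_cis helper, not torchtitan code):

import torch


def check_freqs_cis(freqs_cis: torch.Tensor, x: torch.Tensor) -> None:
    # Hypothetical helper mirroring the check in reshape_for_broadcast:
    # freqs_cis must match (seqlen, head_dim) of the activations x.
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (freqs_cis.shape, x.shape)


x = torch.randn(2, 1024, 8, 64)    # (bs, seqlen, n_heads, head_dim)
freqs_cis = torch.randn(2048, 64)  # precomputed for a full 2048-token context

try:
    check_freqs_cis(freqs_cis, x)
except AssertionError as err:
    # With the new assert message you see the offending shapes instead of a bare error:
    # (torch.Size([2048, 64]), torch.Size([2, 1024, 8, 64]))
    print(err)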

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 112 additions & 2 deletions
@@ -8,7 +8,7 @@
 # llama model, i.e. activation checkpointing, etc.

 from collections import defaultdict
-from typing import Tuple
+from typing import List, Tuple

 import torch

@@ -138,7 +138,112 @@ def get_tp_parallel_strategy(
     return RowwiseParallel, ColwiseParallel


-def apply_pipeline_parallelism(model, world_mesh, parallel_dims, job_config: JobConfig):
+class DummyTransformerLayer(torch.nn.Module):
+    def forward(self, input, freqs_cis):
+        return input
+
+
+class TransformerChunk(torch.nn.Module):
+    def __init__(
+        self,
+        orig_model,  # : Transformer,
+        this_stage_layer_names: List[str],
+        device,
+        input_seqlen: int,
+    ):
+        super().__init__()
+        self.tok_embeddings = None
+
+        # inferring seqlen from forward(input) only works on stage0, because on later stages
+        # the hidden-state input may have reduced seqlen due to TP. We need to use the
+        # original (full) seqlen for freqs_cis to be correct.
+        self.input_seqlen = input_seqlen
+
+        if "tok_embeddings" in this_stage_layer_names:
+            self.tok_embeddings = orig_model.tok_embeddings
+
+        with torch.device(device):
+            self.freqs_cis = orig_model._precompute_freqs_cis()
+
+        # preserve FQNs of the original model by preserving its structure
+        # (including each layer's position in the layers[] list) - use a dummy module
+        self.layers = orig_model.layers
+        for i in range(len(self.layers)):
+            if f"layers.{i}" not in this_stage_layer_names:
+                self.layers[i] = DummyTransformerLayer()
+        self.norm = None
+        if "norm" in this_stage_layer_names:
+            self.norm = orig_model.norm
+        self.output = None
+        if "output" in this_stage_layer_names:
+            self.output = orig_model.output
+
+    def forward(self, input):
+        """
+        Copy-paste of the original Transformer.forward, with conditionals added
+        so that we handle the cases where this rank doesn't have the embedding
+        or the output layers.
+        """
+        if self.tok_embeddings:
+            h = self.tok_embeddings(input)
+        else:
+            h = input
+
+        freqs_cis = self.freqs_cis[0 : self.input_seqlen]
+
+        for layer in self.layers:
+            h = layer(h, freqs_cis)
+        output = h
+
+        if self.norm:
+            h = self.norm(h)
+            output = h
+
+        if self.output:
+            output = self.output(h).float()
+        return output
+
+
+def apply_pipeline_parallelism_manual(
+    model, world_mesh, parallel_dims, job_config: JobConfig, device
+):
+    """
+    This API returns individual torch.nn.Module objects for each pipeline stage (including virtual stages).
+
+    The SPMD parallelisms should be applied separately to each returned stage module.
+    """
+    pp_mesh = world_mesh["pp"]
+    pp_rank = pp_mesh.get_local_rank()
+    pp_size = pp_mesh.size()
+    stage_idx = pp_rank  # TODO support virtual stages
+    layers_per_rank = len(model.layers) // parallel_dims.pp
+    layer_offset = layers_per_rank * pp_rank
+    this_stage_layer_names = [
+        f"layers.{i + layer_offset}" for i in range(layers_per_rank)
+    ]
+    if pp_rank == 0:
+        this_stage_layer_names.insert(0, "tok_embeddings")
+        assert "layers.0" in this_stage_layer_names
+    elif pp_rank == pp_size - 1:
+        this_stage_layer_names.append("norm")
+        this_stage_layer_names.append("output")
+        assert "layers.1" in this_stage_layer_names
+
+    input_seqlen = 2048  # TODO hack
+
+    # Create a pipeline stage representation from the model
+    stage_model = TransformerChunk(model, this_stage_layer_names, device, input_seqlen)
+
+    # note for the PiPPy API:
+    # it would be nice if we could get fx.graph out of PipeInfo and make it possible to manually
+    # construct PipeInfo, so the same _PipelineStage ctor could be used in both tracer and manual cases.
+
+    return (stage_model,)
+
+
+def apply_pipeline_parallelism_tracer(
+    model, world_mesh, parallel_dims, job_config: JobConfig
+):
     assert (
         parallel_dims.pp_enabled
     ), "can't apply pipeline parallelism if it is not enabled"
@@ -212,6 +317,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

         # Apply tensor + sequence parallelism to every transformer block
         for layer_id, transformer_block in enumerate(model.layers):
+            if isinstance(transformer_block, DummyTransformerLayer):
+                continue
             layer_plan = {
                 "attention": PrepareModuleInput(
                     input_layouts=(Shard(1), None),
@@ -259,6 +366,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         ac_mode = job_config.activation_checkpoint.mode
         fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
         for layer_name, transformer_block in model.layers.named_children():
+            if isinstance(transformer_block, DummyTransformerLayer):
+                continue
             if job_config.activation_checkpoint.mode in ("full", "selective"):
                 transformer_block = checkpoint_wrapper(
                     transformer_block, job_config.activation_checkpoint
@@ -275,6 +384,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             )
             model.layers.add_module(layer_name, transformer_block)

+        # TODO(whc) do we need a reshard_after_forward setting here too?
         model = fully_shard(model, **fsdp_config)
         if ac_mode in ("full", "selective"):
            logger.info(f"Applied {ac_mode} activation checkpointing to the model")

train.py

Lines changed: 66 additions & 15 deletions
@@ -22,15 +22,17 @@

 # TODO(whc) this can be removed after pippy migration into pytorch core is complete.
 try:
-    from pippy import ScheduleGPipe
-    from pippy.PipelineStage import _PipelineStage
+    from pippy import ManualPipelineStage, ScheduleGPipe
+
+    # from pippy.PipelineStage import _PipelineStage
 except ImportError as exc:
     raise ImportError(
         "pippy is not installed. Please install it to use pipeline parallelism. "
         "`pip install git+https://github.com/pytorch/pippy`"
     ) from exc

 from torch.distributed import destroy_process_group
+from torch.distributed._composable.fsdp.fully_shard import FSDPModule
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
@@ -224,28 +226,70 @@ def loss_fn(pred, labels):

     if parallel_dims.pp_enabled:
         # TODO(whc) now i need to figure out how to align this with the `models_parallelize_fns[model_name]` pattern
-        from torchtitan.parallelisms.parallelize_llama import apply_pipeline_parallelism
+        from torchtitan.parallelisms.parallelize_llama import (
+            apply_pipeline_parallelism_manual,
+        )

-        model, pipe_info = apply_pipeline_parallelism(
-            model, world_mesh, parallel_dims, job_config
+        stage_models = apply_pipeline_parallelism_manual(
+            model, world_mesh, parallel_dims, job_config, device
         )
+        stage_models = [
+            models_parallelize_fns[model_name](
+                model, world_mesh, parallel_dims, job_config
+            )
+            for model in stage_models
+        ]
+        # TODO virtual stages NYI
+        model = stage_models[0]

-    # apply PT-D DP/TP parallelisms and activation checkpointing
-    model = models_parallelize_fns[model_name](
-        model, world_mesh, parallel_dims, job_config
-    )
+    else:
+        # apply PT-D DP/TP parallelisms and activation checkpointing
+        model = models_parallelize_fns[model_name](
+            model, world_mesh, parallel_dims, job_config
+        )

     model.to_empty(device="cuda")

     # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP,
     # if there are virtual stages
     if parallel_dims.pp_enabled:
-        stage = _PipelineStage(
-            stage_module=model,
-            stage_index=pp_rank,
-            pipe_info=pipe_info,
-            device=device,
-            group=pp_mesh.get_group(),
+        # stage = _PipelineStage(
+        #     stage_module=model,
+        #     stage_index=pp_rank,
+        #     pipe_info=pipe_info,
+        #     device=device,
+        #     group=pp_mesh.get_group(),
+        # )
+        assert len(stage_models) == 1, "virtual stages NYI"
+        chunks = parallel_dims.pp
+        pp_mesh = world_mesh["pp"]
+        pp_rank = pp_mesh.get_local_rank()
+        pp_size = pp_mesh.size()
+        stage_idx = pp_rank  # TODO support virtual stages
+        # Get an example input for stage shape inference
+        if pp_rank == 0:
+            input_shape = (job_config.training.batch_size, job_config.training.seq_len)
+            input_ids = torch.randint(
+                model_config.vocab_size, input_shape, dtype=torch.int64, device="meta"
+            )
+        else:
+            # TODO(whc) can we rely on shape inference so that the user doesn't have to compute the TP impact on seq_len?
+            input_shape = (
+                job_config.training.batch_size,
+                int(job_config.training.seq_len // parallel_dims.tp),
+                model_config.dim,
+            )
+            input_ids = torch.randint(
+                model_config.vocab_size, input_shape, dtype=torch.float32, device="meta"
+            )
+        stage = ManualPipelineStage(
+            model,
+            pp_rank,
+            pp_size,
+            device,
+            chunks,
+            input_args=input_ids.chunk(chunks)[0],
+            group=pp_mesh.get_group("pp"),
         )
         pp_schedule = ScheduleGPipe(
             stage,
@@ -259,6 +303,9 @@ def loss_fn(pred, labels):
     # because it can't find the "embedding" layer, for example.

     # allocate sharded model on GPU and initialize weights via DTensor
+
+    # if we were to rewrite init_weights to work on the pp-model, we could call it unconditionally here,
+    # but that would not free us from needing seed-checkpoint init
     model.init_weights()

     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
@@ -268,6 +315,10 @@ def loss_fn(pred, labels):
         f"({gpu_mem_stats.max_reserved_pct:.2f}%)"
     )

+    if isinstance(model, FSDPModule) and parallel_dims.pp_enabled:
+        # reshard now to counteract an issue where FSDP's states got advanced during PP stage shape inference
+        model.reshard()
+
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
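
ManualPipelineStage above is given a meta-device example input for shape inference, and its shape depends on the stage: stage 0 takes token ids at the full (batch, seq_len), while later stages take hidden states whose sequence dimension has already been divided by the TP degree (hence the TODO about computing the TP impact on seq_len). A short sketch of that shape arithmetic with made-up config values (illustrative numbers, not torchtitan defaults):

import torch

# Illustrative values standing in for job_config.training, parallel_dims and model_config.
batch_size, seq_len = 8, 2048
tp, pp = 2, 2                  # tensor-parallel and pipeline-parallel degrees
dim, vocab_size = 4096, 32000
chunks = pp                    # number of pipeline microbatches, as in the diff above

# Stage 0 consumes token ids at the full sequence length.
stage0_input = torch.randint(
    vocab_size, (batch_size, seq_len), dtype=torch.int64, device="meta"
)

# Later stages consume hidden states; with sequence parallelism the sequence
# dimension between blocks is already sharded across the TP group.
later_input = torch.empty(
    batch_size, seq_len // tp, dim, dtype=torch.float32, device="meta"
)

# The stage is handed one microbatch worth of input for shape inference,
# so the batch dimension is divided by the number of chunks.
print(stage0_input.chunk(chunks)[0].shape)  # torch.Size([4, 2048])
print(later_input.chunk(chunks)[0].shape)   # torch.Size([4, 1024, 4096])

Because chunks equals the pp degree in this commit, each microbatch carries batch_size / pp samples.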
