
Commit d60f2b3

WIP integrate pippy's tracer frontend
ghstack-source-id: 1490240
Pull Request resolved: #161
1 parent 2b82d50 commit d60f2b3

6 files changed (+57, -16 lines)


run_llama_train.sh

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh

-NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"2"}

 # by default log just rank 0 output,
 LOG_RANK=${LOG_RANK:-0}

torchtrain/meta_init.py

Lines changed: 6 additions & 0 deletions
@@ -46,3 +46,9 @@ def meta_to_real_init_fn(module: nn.Module):
                     torch.randn_like(param, device=torch.device("cuda"))
                 )
                 setattr(submodule, param_name, materialized_param)
+        for param_name, param in submodule.named_buffers(recurse=False):
+            if param.is_meta:
+                materialized_param = nn.Parameter(
+                    torch.randn_like(param, device=torch.device("cuda"))
+                )
+                setattr(submodule, param_name, materialized_param)
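The added loop mirrors the existing parameter path for buffers that are still on the meta device. Below is a minimal standalone sketch of the same idea; it is an illustration, not the commit's code. It uses register_buffer instead of wrapping the materialized tensor in nn.Parameter so the tensor stays a buffer, and the Toy module and CUDA device are assumptions.

    import torch
    import torch.nn as nn

    def materialize_meta_buffers(module: nn.Module) -> None:
        # Walk every submodule and swap meta-device buffers for real CUDA tensors.
        for submodule in module.modules():
            for buf_name, buf in submodule.named_buffers(recurse=False):
                if buf.is_meta:
                    # register_buffer keeps the tensor a buffer: it is excluded
                    # from parameters() and never receives gradients.
                    submodule.register_buffer(
                        buf_name, torch.randn_like(buf, device=torch.device("cuda"))
                    )

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("scale", torch.ones(8))

    # Build on the meta device (no memory allocated), then materialize on GPU.
    with torch.device("meta"):
        m = Toy()
    materialize_meta_buffers(m)
    print(m.scale.device)  # cuda:0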

torchtrain/models/llama/model.py

Lines changed: 11 additions & 8 deletions
@@ -334,13 +334,16 @@ def __init__(self, model_args: ModelArgs):
         self.model_args = model_args
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

-        self.freqs_cis = precompute_freqs_cis(
-            # Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
-            # of models is 4096.
-            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
-            # or fine-tuning.
-            self.model_args.dim // self.model_args.n_heads,
-            self.model_args.max_seq_len * 2,
+        self.register_buffer(
+            "freqs_cis",
+            precompute_freqs_cis(
+                # Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
+                # of models is 4096.
+                # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
+                # or fine-tuning.
+                self.model_args.dim // self.model_args.n_heads,
+                self.model_args.max_seq_len * 2,
+            ),
         )

     def forward(self, tokens: torch.Tensor):
@@ -355,7 +358,7 @@ def forward(self, tokens: torch.Tensor):
         """
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
-        self.freqs_cis = self.freqs_cis.to(h.device)
+        # self.freqs_cis = self.freqs_cis.to(h.device)
         freqs_cis = self.freqs_cis[0:seqlen]
         return h, freqs_cis

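Registering freqs_cis via register_buffer, instead of assigning a plain attribute and re-assigning it with .to(h.device) in forward, makes the tensor part of the module's state, so it follows model.cuda()/meta materialization and is visible to a tracing frontend. A small illustrative sketch of that behavior, with a hypothetical RoPECache module standing in for the Transformer:

    import torch
    import torch.nn as nn

    class RoPECache(nn.Module):
        def __init__(self, head_dim: int, max_seq_len: int):
            super().__init__()
            # A registered buffer is module state: it moves with .to()/.cuda(),
            # appears in state_dict(), and is captured by tracing frontends.
            self.register_buffer(
                "freqs_cis",
                torch.randn(max_seq_len, head_dim // 2, dtype=torch.complex64),
            )

        def forward(self, seqlen: int) -> torch.Tensor:
            # No per-call `.to(h.device)` is needed; the buffer already lives on
            # whatever device the module was moved to.
            return self.freqs_cis[0:seqlen]

    m = RoPECache(head_dim=64, max_seq_len=128)
    if torch.cuda.is_available():
        m = m.cuda()
    print(m.freqs_cis.device)  # cuda:0 when a GPU is available, else cpu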

torchtrain/parallelisms/parallelize_llama.py

Lines changed: 35 additions & 5 deletions
@@ -7,6 +7,7 @@
 from collections import defaultdict

 import torch
+from pippy import annotate_split_points, Pipe, PipeSplitWrapper
 from torch.distributed._tensor import Replicate, Shard

 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -125,7 +126,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     # apply PTD parallelisms
     if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")
+        pp_mesh = world_mesh["pp"]
+        stage_idx = pp_mesh.get_local_rank()
+        layers_per_rank = len(model.layers) // parallel_dims.pp
+        for i in range(1, parallel_dims.pp):
+            annotate_split_points(
+                model,
+                {
+                    f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING
+                },
+            )
+
+        # Get example input
+        label_shape = input_shape = (8, 2048) # TODO
+        input_ids = torch.randint(
+            model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+        )
+        labels = torch.randint(
+            model.vocab_size, label_shape, dtype=torch.int64, device="meta"
+        )
+        print("input_ids: ", input_ids.shape, input_ids.dtype)
+        print("labels: ", labels.shape, labels.dtype)
+
+        # Create a pipeline representation from the model
+        pipe = Pipe.from_tracing(model, parallel_dims.pp, example_args=(input_ids,))
+        model = pipe.get_stage_module(stage_idx)

     # First we apply Sequence Parallelism if it's enabled
     if parallel_dims.sp_enabled:
@@ -230,9 +255,14 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     meta_to_real_init_fn(model)
     model.cuda()

-    # we have now moved from meta to device,
-    # reset parameters for proper initialization
-    model.reset_parameters()
-    logger.info("Model fully initialized via reset_parameters")
+    if parallel_dims.pp_enabled:
+        setattr(pipe.split_gm, f"submod_{stage_idx}", model)
+        return pipe
+    else:
+        # TODO figure out PP compatible deferred initialization
+        # we have now moved from meta to device,
+        # reset parameters for proper initialization
+        model.reset_parameters()
+        logger.info("Model fully initialized via reset_parameters")

     return model
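For context, a self-contained sketch of the PiPPy tracer-frontend flow used above: annotate_split_points marks stage boundaries, Pipe.from_tracing traces the model once with example inputs and splits it, and get_stage_module extracts the submodule a given rank will run. The ToyModel, its sizes, and the hard-coded stage_idx are illustrative; only the pippy calls mirror the diff.

    import torch
    import torch.nn as nn
    from pippy import annotate_split_points, Pipe, PipeSplitWrapper

    class ToyModel(nn.Module):
        def __init__(self, n_layers: int = 4, dim: int = 16):
            super().__init__()
            self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            for layer in self.layers:
                x = layer(x)
            return x

    pp_degree = 2
    model = ToyModel()

    # Mark the beginning of every stage except the first as a split point.
    layers_per_rank = len(model.layers) // pp_degree
    for i in range(1, pp_degree):
        annotate_split_points(
            model,
            {f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING},
        )

    # Trace once with example inputs and split the graph into pp_degree stages.
    example_input = torch.randn(8, 16)
    pipe = Pipe.from_tracing(model, pp_degree, example_args=(example_input,))

    # Each pipeline rank would keep only its own stage (stage_idx from the pp mesh).
    stage_idx = 0
    stage_module = pipe.get_stage_module(stage_idx)
    print(stage_module)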

train.py

Lines changed: 3 additions & 1 deletion
@@ -241,10 +241,12 @@ def main(job_config: JobConfig):

         input_ids = input_ids.cuda()
         labels = labels.cuda()
-
+        print("i", input_ids.shape)
+        print("l", labels.shape)
         optimizer.zero_grad()

         # forward
+        # TODO - integrate pp batch splitter
         pred = model(input_ids)

         with loss_parallel() if parallel_dims.loss_parallel_enabled else contextlib.nullcontext():
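The "# TODO - integrate pp batch splitter" refers to splitting each global batch into microbatches before feeding the pipeline stages. A minimal sketch of that idea, not part of this commit; the helper name and microbatch count are made up.

    import torch

    def split_into_microbatches(
        input_ids: torch.Tensor, labels: torch.Tensor, n_microbatches: int
    ):
        # tensor_split keeps every row even when the batch size is not divisible.
        return list(
            zip(
                torch.tensor_split(input_ids, n_microbatches, dim=0),
                torch.tensor_split(labels, n_microbatches, dim=0),
            )
        )

    # Example with the (8, 2048) shapes printed by the debug statements above.
    batch_input_ids = torch.randint(0, 32000, (8, 2048))
    batch_labels = torch.randint(0, 32000, (8, 2048))
    for micro_inputs, micro_labels in split_into_microbatches(batch_input_ids, batch_labels, 4):
        print(micro_inputs.shape, micro_labels.shape)  # torch.Size([2, 2048]) each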

train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 sequence_parallel_degree = 1
-pipeline_parallel_degree = 1
+pipeline_parallel_degree = 2
 fp8_linear = ""
 compile = false
 checkpoint_interval = 3600
