
Commit 15a001f

Add Sequence Parallelism to llama

torch.compile somehow does not work with this change even though eager-mode sequence parallelism works, so sequence parallelism is not turned on by default for now.

ghstack-source-id: 0d251f2
Pull Request resolved: #32
1 parent: 3d27c70

2 files changed: +125 -5 lines changed

torchtrain/models/llama/model.py

Lines changed: 13 additions & 3 deletions
@@ -211,7 +211,10 @@ def forward(
             torch.Tensor: Output tensor after attention.
 
         """
-        bsz, seqlen, _ = x.shape
+        seqlen, _ = freqs_cis.shape
+        bs_seqlen, _ = x.shape
+        bsz = bs_seqlen // seqlen
+
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
         xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)

@@ -237,7 +240,8 @@ def forward(
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
-        output = output.view(bsz, seqlen, -1)
+        # output stay folded with batch and sequence dimension
+        output = output.view(bsz * seqlen, -1)
         return self.wo(output)

@@ -342,7 +346,6 @@ def __init__(self, layer_id: int, args: ModelArgs):
         super().__init__()
         self.n_heads = args.n_heads
         self.dim = args.dim
-        self.head_dim = args.dim // args.n_heads
         self.attention = Attention(args)
         self.feed_forward = FeedForward(
             dim=args.dim,

@@ -422,10 +425,17 @@ def forward(self, tokens: torch.Tensor):
 
         """
         h, freqs_cis = self.embeddings(tokens)
+        # fold batch and sequence dimension for more efficient allgather/reduce_scatter
+        h = h.view(-1, self.params.dim)
 
         for layer in self.layers:
             h = layer(h, freqs_cis)
+
         h = self.norm(h)
+        # unfold batch and sequence dimension
+        bsz = tokens.shape[0]
+        bs_seqlen = h.shape[0]
+        h = h.view(bsz, bs_seqlen // bsz, self.params.dim)
         output = self.output(h).float()
         return output
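
For reference, a minimal standalone sketch (not part of the commit) of the shape bookkeeping above: activations travel through the transformer blocks folded to (bsz * seqlen, dim), Attention.forward recovers bsz from the sequence length carried by freqs_cis, and Transformer.forward unfolds the result before the output projection. The toy sizes are arbitrary, and freqs_cis matters only for its shape here.

import torch

bsz, seqlen, dim, n_heads = 2, 8, 32, 4

# embedding output and rotary frequencies, shaped as in the model
h = torch.randn(bsz, seqlen, dim)
freqs_cis = torch.randn(seqlen, dim // n_heads // 2, dtype=torch.complex64)

# Transformer.forward: fold batch and sequence dims before the blocks
h_folded = h.view(-1, dim)                     # (bsz * seqlen, dim)

# Attention.forward: recover bsz from freqs_cis and the folded input
seqlen_attn, _ = freqs_cis.shape
bs_seqlen, _ = h_folded.shape
bsz_attn = bs_seqlen // seqlen_attn
assert bsz_attn == bsz

# Transformer.forward: unfold before the final output projection
h_unfolded = h_folded.view(bsz_attn, bs_seqlen // bsz_attn, dim)
assert torch.equal(h_unfolded, h)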

torchtrain/parallelisms/parallelize_llama.py

Lines changed: 112 additions & 2 deletions
@@ -7,6 +7,13 @@
 import logging
 
 import torch
+from torch.distributed._tensor import (
+    distribute_module,
+    distribute_tensor,
+    DTensor,
+    Replicate,
+    Shard,
+)
 
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,

@@ -19,11 +26,46 @@
     ShardingStrategy,
 )
 from torch.distributed.fsdp.wrap import enable_wrap, wrap
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+)
 
 from torchtrain.logging_utils import rank0_log
 
 logger = logging.getLogger(__name__)
 
+
+def distribute_rmsnorm(module, device_mesh):
+    # temp sharding API until PTD API is added
+    def prepare_input_fn(inputs, device_mesh):
+        if isinstance(inputs[0], DTensor):
+            return inputs
+        elif isinstance(inputs[0], torch.Tensor):
+            shard_tensor = DTensor.from_local(
+                inputs[0], device_mesh, [Shard(0)], run_check=False
+            )
+            return shard_tensor
+        else:
+            raise NotImplementedError("!!")
+
+    def partition_fn(name, module, device_mesh):
+        for name, param in module.named_parameters():
+            dist_param = torch.nn.Parameter(
+                distribute_tensor(param, device_mesh, [Replicate()])
+            )
+            module.register_parameter(name, dist_param)
+
+    return distribute_module(
+        module,
+        device_mesh,
+        partition_fn,
+        input_fn=prepare_input_fn,
+    )
+
+
 # Uses PTD FSDP AC wrapper
 def checkpoint_wrapper(module, config):
     return ptd_checkpoint_wrapper(
@@ -43,7 +85,75 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
     if parallel_dims.pp_enabled:
         raise NotImplementedError("PP not implemented yet.")
     if parallel_dims.sp_enabled:
-        raise NotImplementedError("SP not implemented yet.")
+        # First we apply Sequence Parallelism if it's enabled
+        tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
+        sp_degree = args.sp_degree
+        # First:
+        # 1. parallelize the first embedding and the last linear proj layer
+        # 2. shard the first layer of transformer block
+        # TODO: enable loss parallel once it's ready
+        model = parallelize_module(
+            model,
+            tp_mesh,
+            {
+                "embeddings.tok_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(),
+                ),
+                "output": ColwiseParallel(
+                    input_layouts=Shard(0),
+                    output_layouts=Replicate(),
+                ),
+                "layers.0": PrepareModuleInput(
+                    input_layouts=(Replicate(), None),
+                    desired_input_layouts=(Shard(0), None),
+                    use_local_output=True,
+                ),
+            },
+        )
+
+        # apply sequence parallelism to every transformer block
+        for layer_id, transformer_block in enumerate(model.layers):
+            layer_plan = {
+                "attention": PrepareModuleInput(
+                    input_layouts=(Shard(0), None),
+                    desired_input_layouts=(Replicate(), None),
+                ),
+                "attention.wq": ColwiseParallel(),
+                "attention.wk": ColwiseParallel(),
+                "attention.wv": ColwiseParallel(),
+                "attention.wo": RowwiseParallel(output_layouts=Shard(0)),
+                "feed_forward": PrepareModuleInput(
+                    input_layouts=(Shard(0),),
+                    desired_input_layouts=(Replicate(),),
+                ),
+                "feed_forward.w1": ColwiseParallel(),
+                "feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
+                "feed_forward.w3": ColwiseParallel(),
+            }
+            # if layer_id == 0:
+            #     # in first transformer block we need to shard the input
+            #     layer_plan[""] = PrepareModuleInput(
+            #         input_layouts=(Replicate(), None),
+            #         desired_input_layouts=(Shard(0), None),
+            #     )
+
+            # adjust num_heads in attention layer to local heads
+            attn_layer = transformer_block.attention
+            attn_layer.n_heads = attn_layer.n_heads // sp_degree
+            attn_layer.n_kv_heads = attn_layer.n_kv_heads // sp_degree
+
+            # shard RMSNorm layers
+            distribute_rmsnorm(transformer_block.attention_norm, tp_mesh)
+            distribute_rmsnorm(transformer_block.ffn_norm, tp_mesh)
+
+            parallelize_module(
+                module=transformer_block,
+                device_mesh=tp_mesh,
+                parallelize_plan=layer_plan,
+            )
+
+        rank0_log("Applied Sequence Parallelism to the model...")
+
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
         assert dp_mesh.mesh_dim_names == ["dp"], dp_mesh.mesh_dim_names

@@ -73,6 +183,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
         # wrap the rest layers with FSDP
         model = wrap(model.cuda())
 
-    rank0_log("Applied parallelisms to the model...")
+    rank0_log("Applied FSDP to the model...")
 
     return model
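
Below, a hedged usage sketch for the distribute_rmsnorm helper added above; it is not part of the commit. TinyNorm is an illustrative stand-in for the model's RMSNorm, the single-process gloo/CPU setup exists only so the sketch runs with plain python (launch with torchrun for more ranks, where the input becomes a genuine per-rank shard), and it assumes the torchtrain package is importable and a PyTorch build, like the one the commit targets, in which distribute_module accepts a two-argument input_fn.

import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor

from torchtrain.parallelisms.parallelize_llama import distribute_rmsnorm


class TinyNorm(torch.nn.Module):
    # stand-in for the model's RMSNorm: a single learnable scale vector
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


# single-rank defaults so the sketch runs with plain `python`; torchrun overrides them
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
dist.init_process_group("gloo")
mesh = DeviceMesh("cpu", torch.arange(dist.get_world_size()))

# weight gets replicated across the mesh; plain tensor inputs are wrapped as Shard(0) DTensors
norm = distribute_rmsnorm(TinyNorm(16), mesh)

local_rows = torch.randn(4, 16)   # this rank's slice of the folded (bsz*seqlen, dim) activations
out = norm(local_rows)
assert isinstance(out, DTensor)   # output stays a DTensor sharded along dim 0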

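Finally, a hedged sketch of how one slice of the per-block plan composes; again, this is not the commit's code. ToyBlock and ToyFFN are illustrative stand-ins for a transformer block's feed_forward path, and the single-process gloo/CPU setup only keeps the sketch runnable anywhere (the real run targets GPUs with nccl). PrepareModuleInput moves the Shard(0) activations to Replicate before the block, ColwiseParallel and RowwiseParallel shard the two projections, and output_layouts=Shard(0) on the rowwise projection returns the result to the sequence-sharded layout (a reduce-scatter when there is more than one rank).

import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
)


class ToyFFN(torch.nn.Module):
    # stand-in for FeedForward: two sharded projections around a pointwise activation
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w1 = torch.nn.Linear(dim, hidden, bias=False)
        self.w2 = torch.nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.nn.functional.silu(self.w1(x)))


class ToyBlock(torch.nn.Module):
    # stand-in for a transformer block's feed_forward branch with its residual
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.feed_forward = ToyFFN(dim, hidden)

    def forward(self, x):
        return x + self.feed_forward(x)


os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
dist.init_process_group("gloo")
mesh = DeviceMesh("cpu", torch.arange(dist.get_world_size()))

block = parallelize_module(
    ToyBlock(16, 64),
    mesh,
    {
        # mirrors the commit's "feed_forward*" plan entries
        "feed_forward": PrepareModuleInput(
            input_layouts=(Shard(0),),
            desired_input_layouts=(Replicate(),),
        ),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
    },
)

local_rows = torch.randn(4, 16)   # this rank's slice of the folded (bsz*seqlen, dim) activations
out = block(local_rows)           # returned as a local tensor, still sequence-sharded globally
print(out.shape)                  # torch.Size([4, 16])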