
Commit df14507

run sdpa with dtensor
ghstack-source-id: 43941c1
Pull Request resolved: #180
1 parent dca7657
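The gist of the change: the attention projections now hand DTensors (rather than plain local tensors) to the rest of the attention code, so the head reshape and scaled_dot_product_attention dispatch through DTensor, and the manual n_heads // tp_mesh.size() adjustment below can be dropped. A minimal sketch of that behavior, assuming PyTorch's torch.distributed.tensor.parallel API and a 2-process torchrun launch; the toy module and file name are illustrative, not from this repo:

# sketch.py -- run with: torchrun --nproc_per_node=2 sketch.py
import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# 1-D tensor-parallel mesh over 2 ranks (CPU + gloo, just for illustration).
tp_mesh = init_device_mesh("cpu", (2,))

proj = nn.Linear(16, 16, bias=False)  # stands in for attention.wq / wk / wv
# use_local_output=False keeps the module output as a DTensor sharded on the
# last (feature) dim instead of converting it back to a plain local tensor.
parallelize_module(proj, tp_mesh, ColwiseParallel(use_local_output=False))

q = proj(torch.randn(4, 16))
assert isinstance(q, DTensor)
# Downstream ops (view into heads, scaled_dot_product_attention, ...) now
# dispatch through DTensor, so the head dim stays logically global and no
# manual n_heads // tp_mesh.size() adjustment is needed.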

File tree

1 file changed: +12 −10 lines

torchtrain/parallelisms/parallelize_llama.py

Lines changed: 12 additions & 10 deletions
@@ -19,6 +19,7 @@
     ColwiseParallel,
     parallelize_module,
     PrepareModuleInput,
+    PrepareModuleOutput,
     RowwiseParallel,
     SequenceParallel,
 )
@@ -143,15 +144,21 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         )

         # 1. Parallelize the first embedding and the last linear proj layer
-        # 2. Parallelize the root norm layer over the sequence dim
-        # 3. Shard the first transformer block's inputs
+        # 2. Prepare the freq_cis in rotary embedding as dtensor
+        # 3. Parallelize the root norm layer over the sequence dim
+        # 4. Shard the first transformer block's inputs
         model = parallelize_module(
             model,
             tp_mesh,
             {
                 "embeddings.tok_embeddings": RowwiseParallel(
                     input_layouts=Replicate(),
                 ),
+                "embeddings": PrepareModuleOutput(
+                    output_layouts=(None, Replicate()),
+                    desired_output_layouts=(None, Replicate()),
+                    use_local_output=False,
+                ),
                 "output": col_parallel_strategy(
                     input_layouts=Shard(0),
                     output_layouts=(
@@ -177,9 +184,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                     input_layouts=(Shard(0), None),
                     desired_input_layouts=(Replicate(), None),
                 ),
-                "attention.wq": col_parallel_strategy(),
-                "attention.wk": col_parallel_strategy(),
-                "attention.wv": col_parallel_strategy(),
+                "attention.wq": col_parallel_strategy(use_local_output=False),
+                "attention.wk": col_parallel_strategy(use_local_output=False),
+                "attention.wv": col_parallel_strategy(use_local_output=False),
                 "attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
                 "attention_norm": SequenceParallel(sequence_dim=0),
                 "feed_forward": PrepareModuleInput(
@@ -192,11 +199,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 "ffn_norm": SequenceParallel(sequence_dim=0),
             }

-            # Adjust attention module to use the local number of heads
-            attn_layer = transformer_block.attention
-            attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
             parallelize_module(
                 module=transformer_block,
                 device_mesh=tp_mesh,
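Similarly, the new "embeddings" entry in the plan uses PrepareModuleOutput to annotate the second output of the embedding module (the rotary freq_cis) as a replicated DTensor, with None skipping the first output and use_local_output=False keeping it a DTensor for the attention layers. A small sketch of that style on a made-up two-output module, under the same torchrun assumptions as above:

# sketch_freqs.py -- run with: torchrun --nproc_per_node=2 sketch_freqs.py
import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor, Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import PrepareModuleOutput, parallelize_module

class ToyEmbeddings(nn.Module):
    """Stands in for a module returning (hidden_states, freqs_cis)."""
    def forward(self, tokens):
        h = torch.zeros(tokens.shape[0], tokens.shape[1], 16)
        freqs_cis = torch.arange(tokens.shape[1] * 8, dtype=torch.float32).reshape(-1, 8)
        return h, freqs_cis

tp_mesh = init_device_mesh("cpu", (2,))
emb = parallelize_module(
    ToyEmbeddings(),
    tp_mesh,
    PrepareModuleOutput(
        # None leaves the first output untouched; the second (freqs_cis) is
        # declared replicated and, with use_local_output=False, returned as a
        # DTensor that the DTensor-based attention path can consume directly.
        output_layouts=(None, Replicate()),
        desired_output_layouts=(None, Replicate()),
        use_local_output=False,
    ),
)
h, freqs_cis = emb(torch.zeros(2, 4, dtype=torch.long))
assert isinstance(freqs_cis, DTensor)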
