Commit c530a64

tianyu-l authored and wconstab committed

run sdpa with dtensor

ghstack-source-id: b8b2b58
Pull Request resolved: #180

1 parent 05f0802 · commit c530a64

File tree: 1 file changed (+12, -10 lines)

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 12 additions & 10 deletions
@@ -22,6 +22,7 @@
     ColwiseParallel,
     parallelize_module,
     PrepareModuleInput,
+    PrepareModuleOutput,
     RowwiseParallel,
     SequenceParallel,
 )
@@ -181,15 +182,21 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     loss_parallel = parallel_dims.loss_parallel_enabled
 
     # 1. Parallelize the first embedding and the last linear proj layer
-    # 2. Parallelize the root norm layer over the sequence dim
-    # 3. Shard the first transformer block's inputs
+    # 2. Prepare the freq_cis in rotary embedding as dtensor
+    # 3. Parallelize the root norm layer over the sequence dim
+    # 4. Shard the first transformer block's inputs
     model = parallelize_module(
         model,
         tp_mesh,
         {
             "tok_embeddings": RowwiseParallel(
                 input_layouts=Replicate(),
             ),
+            "embeddings": PrepareModuleOutput(
+                output_layouts=(None, Replicate()),
+                desired_output_layouts=(None, Replicate()),
+                use_local_output=False,
+            ),
             "output": col_parallel_strategy(
                 input_layouts=Shard(1),
                 output_layouts=(Shard(-1) if loss_parallel else Replicate()),
@@ -212,9 +219,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
-            "attention.wq": col_parallel_strategy(),
-            "attention.wk": col_parallel_strategy(),
-            "attention.wv": col_parallel_strategy(),
+            "attention.wq": col_parallel_strategy(use_local_output=False),
+            "attention.wk": col_parallel_strategy(use_local_output=False),
+            "attention.wv": col_parallel_strategy(use_local_output=False),
             "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
             "attention_norm": SequenceParallel(),
             "feed_forward": PrepareModuleInput(
@@ -227,11 +234,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             "ffn_norm": SequenceParallel(),
         }
 
-        # Adjust attention module to use the local number of heads
-        attn_layer = transformer_block.attention
-        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
         parallelize_module(
             module=transformer_block,
             device_mesh=tp_mesh,
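
For context, a minimal standalone sketch (not part of the commit) of what "run sdpa with dtensor" means in practice: with use_local_output=False, the column-parallel wq/wk/wv projections return DTensors sharded on the head dimension, and torch.nn.functional.scaled_dot_product_attention dispatches on those DTensors directly, which is why the manual n_heads / n_kv_heads adjustment is deleted. The shapes, launcher, and import path are illustrative assumptions (torch.distributed._tensor was the path at the time; newer releases expose torch.distributed.tensor), and it assumes a PyTorch build whose DTensor has sharding rules for SDPA.

# Hypothetical sketch, assuming launch via `torchrun --nproc_per_node=N`.
import os

import torch
import torch.nn.functional as F
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

# Toy shapes; n_heads must be divisible by the tensor-parallel degree.
bs, n_heads, seq, head_dim = 2, 8, 128, 64

def sharded(shape):
    # Mimic a column-parallel projection with use_local_output=False:
    # the activation stays a DTensor sharded on the head dimension (dim 1).
    return distribute_tensor(torch.randn(shape, device="cuda"), tp_mesh, [Shard(1)])

q, k, v = (sharded((bs, n_heads, seq, head_dim)) for _ in range(3))

# SDPA runs directly on the DTensor inputs; each rank attends over its local
# slice of heads, and the output remains a DTensor sharded on dim 1, so the
# model code no longer needs to divide n_heads by the TP size itself.
out = F.scaled_dot_product_attention(q, k, v)
print(out.placements, out.to_local().shape)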
