Commit 05f0802

Enable TP+PP support

ghstack-source-id: 8177304
Pull Request resolved: #285

1 parent efb2845 commit 05f0802

File tree

1 file changed: +2 −1 lines


torchtitan/parallelisms/parallelize_llama.py

Lines changed: 2 additions & 1 deletion
@@ -205,7 +205,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         )

         # Apply tensor + sequence parallelism to every transformer block
-        for layer_id, transformer_block in enumerate(model.layers):
+        for layer_name, transformer_block in model.layers.named_children():
+            # for layer_id, transformer_block in enumerate(model.layers):
             layer_plan = {
                 "attention": PrepareModuleInput(
                     input_layouts=(Shard(1), None),
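The switch from `enumerate(model.layers)` to `model.layers.named_children()` is what makes tensor parallelism compose with pipeline parallelism: once PP splits the transformer blocks across ranks, each rank holds only a subset of the layers, and `enumerate` would renumber that subset from zero while `named_children` preserves each block's original name. A minimal sketch of the difference, using a hypothetical `Block` module and an `nn.ModuleDict` to stand in for the model's layer container (names and the 8-layer/2-stage split are illustrative assumptions, not torchtitan's actual setup):

```python
import torch.nn as nn


class Block(nn.Module):
    """Stand-in for a transformer block (hypothetical, for illustration)."""

    def forward(self, x):
        return x


# A model with 8 blocks, keyed by their original index as a string.
layers = nn.ModuleDict({str(i) : Block() for i in range(8)})

# Simulate a pipeline stage that owns only layers 4..7 by deleting the rest.
for name in ["0", "1", "2", "3"]:
    del layers[name]

# enumerate() would relabel the surviving blocks 0..3, losing their identity;
# named_children() yields the original names, so per-layer logic stays correct.
renumbered = [i for i, _ in enumerate(layers.named_children())]
original = [name for name, _ in layers.named_children()]
print(renumbered)  # -> [0, 1, 2, 3]
print(original)    # -> ['4', '5', '6', '7']
```

Applying the TP plan per `(name, module)` pair rather than per index means the same loop works unchanged whether the rank holds the full model or only one pipeline stage's slice of it.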

0 commit comments