From 3c86efe4a7515b25a09c60b3e92a0a40dba78eed Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 3 Jul 2024 22:16:34 -0700
Subject: [PATCH] add comment pointing to Sequence Parallel optimization
 example

[ghstack-poisoned]
---
 torchtitan/parallelisms/parallelize_llama.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index be627432a3..6cafa4abe4 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -362,6 +362,9 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     )
 
     # Apply tensor + sequence parallelism to every transformer block
+    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
+    # by folding (and unfolding) the batch dimension and the sequence dimension.
+    # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
             "attention": prepare_module_input(
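
For reference, the sketch below illustrates what "folding (and unfolding) the batch dimension and the sequence dimension" can look like in model code. It is a minimal single-device illustration, not the torchtitan implementation (the actual examples are in the linked PR #437); the FoldedFeedForward module, its dimensions, and the usage at the end are assumptions made for this sketch, and the tensor/sequence parallel plans themselves are out of scope here.

import torch
import torch.nn as nn


class FoldedFeedForward(nn.Module):
    # Hypothetical module, NOT the torchtitan code: it only demonstrates the
    # fold/unfold reshapes around the linears that would be parallelized.
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # column-parallel under TP
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # row-parallel under TP

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bs, seq_len, dim = x.shape
        # Fold: merge batch and sequence into one leading dimension, so a
        # sequence-parallel shard becomes a contiguous slice along dim 0.
        x = x.reshape(bs * seq_len, dim)
        x = self.w2(torch.relu(self.w1(x)))
        # Unfold: restore the (batch, sequence, dim) layout for later layers.
        return x.reshape(bs, seq_len, dim)


# Usage sketch: shapes are preserved end to end.
ff = FoldedFeedForward(dim=256, hidden_dim=1024)
out = ff(torch.randn(4, 128, 256))  # (bs=4, seq=128, dim=256) -> (4, 128, 256)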