diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index be627432a3..6cafa4abe4 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -362,6 +362,9 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     )

     # Apply tensor + sequence parallelism to every transformer block
+    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
+    # by folding (and unfolding) the batch dimension and the sequence dimension.
+    # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
             "attention": prepare_module_input(
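
For context, the folding mentioned in the NOTE amounts to collapsing the batch and sequence dimensions of the activations into a single leading dimension, so that sequence-parallel shard/gather collectives operate over one contiguous dimension instead of per-sample slices. The sketch below only illustrates that reshape idea under assumed tensor shapes; the fold/unfold helper names are hypothetical and are not the actual code from the referenced PR.

    import torch

    def fold(x: torch.Tensor) -> torch.Tensor:
        # Fold batch and sequence into one leading dimension:
        # (bs, seq_len, dim) -> (bs * seq_len, dim).
        bs, seq_len, dim = x.shape
        return x.view(bs * seq_len, dim)

    def unfold(x: torch.Tensor, bs: int) -> torch.Tensor:
        # Inverse of fold: (bs * seq_len, dim) -> (bs, seq_len, dim).
        total, dim = x.shape
        return x.view(bs, total // bs, dim)

    # Tiny usage example.
    x = torch.randn(2, 8, 16)        # (batch, sequence, hidden)
    folded = fold(x)                 # shape (16, 16)
    restored = unfold(folded, bs=2)  # shape (2, 8, 16)
    assert torch.equal(x, restored)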