From cce58c12b0784e675a7510888524ff45a0f75b8f Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Tue, 7 May 2024 14:15:17 -0700
Subject: [PATCH 1/2] simplify embedding + first transformer block TP

As titled, we can directly specify the rowwise parallel embedding output
layouts to be sharded on the sequence dim, so we no longer need the
prepare-input step on the first layer.

Switching to output_layouts = Shard(1) would also trigger a reduce_scatter
instead of an allreduce for the embedding layer, which could give some small
perf wins.
---
 torchtitan/parallelisms/parallelize_llama.py | 6 +-----
 train_configs/llama3_8b.toml                 | 2 +-
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index fca776a7b5..779be60e8e 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -160,6 +160,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             {
                 "tok_embeddings": RowwiseParallel(
                     input_layouts=Replicate(),
+                    output_layouts=Shard(1),
                 ),
                 "output": col_parallel_strategy(
                     input_layouts=Shard(1),
@@ -167,11 +168,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                     use_local_output=not loss_parallel,
                 ),
                 "norm": SequenceParallel(),
-                "layers.0": PrepareModuleInput(
-                    input_layouts=(Replicate(), None),
-                    desired_input_layouts=(Shard(1), None),
-                    use_local_output=True,
-                ),
             },
         )
 
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index aaba99a215..434a74fda4 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -36,7 +36,7 @@ tensor_parallel_degree = 1
 pipeline_parallel_degree = 1
 fp8_linear = ""
 compile = false
-dataset = "c4"
+dataset = "c4_mini"
 
 [checkpoint]
 enable_checkpoint = false

From 0b9a3a7b13dceb8720138328d0aba5bf3bd32da9 Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Tue, 7 May 2024 14:17:55 -0700
Subject: [PATCH 2/2] revert toml changes

---
 train_configs/llama3_8b.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index 434a74fda4..aaba99a215 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -36,7 +36,7 @@ tensor_parallel_degree = 1
 pipeline_parallel_degree = 1
 fp8_linear = ""
 compile = false
-dataset = "c4_mini"
+dataset = "c4"
 
 [checkpoint]
 enable_checkpoint = false
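
Note: for readers following the TP plan change, below is a minimal standalone sketch of what the patch converges on. The helper name apply_embedding_and_output_tp and the loss_parallel flag are illustrative only, and ColwiseParallel stands in for the patch's col_parallel_strategy; only the plan entries themselves mirror the hunk above.

# Sketch only: simplified embedding/output TP plan after this patch.
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


def apply_embedding_and_output_tp(model, tp_mesh, loss_parallel=False):
    """Apply the simplified TP plan to a Llama-style model.

    Assumes `model` has `tok_embeddings`, `norm`, and `output` submodules
    and `tp_mesh` is a 1-D tensor-parallel DeviceMesh (both assumptions,
    matching torchtitan's model structure).
    """
    plan = {
        # Rowwise-parallel embedding: replicated token ids in,
        # sequence-sharded (Shard(1)) activations out. Emitting Shard(1)
        # directly lets the first transformer block consume the output
        # as-is, so no PrepareModuleInput entry for "layers.0" is needed,
        # and the embedding's output collective becomes a reduce_scatter
        # rather than an allreduce.
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        # Final projection: sequence-sharded input; keep the output sharded
        # on the last (vocab) dim when loss parallelism is enabled.
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Shard(-1) if loss_parallel else Replicate(),
            use_local_output=not loss_parallel,
        ),
        # Final norm runs sequence-parallel over the Shard(1) activations.
        "norm": SequenceParallel(),
    }
    return parallelize_module(model, tp_mesh, plan)

With this plan the embedding output is already sharded on the sequence dim, so the per-layer TP plans (applied elsewhere in parallelize_llama) see Shard(1) inputs without any extra redistribution.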