From fdf46290549df3667143f7a16b098b76adb98e69 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 3 Jul 2024 21:57:41 -0700 Subject: [PATCH 1/7] fold batch and sequence dimensions to accelerate Sequence Parallel [ghstack-poisoned] --- torchtitan/models/llama/model.py | 22 ++++++++++++++------ torchtitan/models/norms.py | 8 +++---- torchtitan/parallelisms/parallelize_llama.py | 22 ++++++++++++-------- train_configs/debug_model.toml | 6 +++--- 4 files changed, 36 insertions(+), 22 deletions(-) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 49cda6241d..672f1e22e9 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -79,9 +79,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten """ ndim = x.ndim assert 0 <= 1 < ndim - seqlen = x.shape[1] - freqs_cis = freqs_cis[0:seqlen] - assert freqs_cis.shape == (seqlen, x.shape[-1]) + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) @@ -187,7 +185,10 @@ def forward( torch.Tensor: Output tensor after attention. """ - bs, seqlen, _ = x.shape + # dim 0 of x is a folded dimension of [bs, seqlen] + seqlen, _ = freqs_cis.shape + bs_seqlen, _ = x.shape + bs = bs_seqlen // seqlen xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) xq = xq.view(bs, seqlen, self.n_heads, self.head_dim) @@ -209,7 +210,8 @@ def forward( output = output.transpose( 1, 2 ).contiguous() # (bs, seqlen, n_local_heads, head_dim) - output = output.view(bs, seqlen, -1) + # output stay folded with batch and sequence dimension + output = output.view(bs * seqlen, -1) return self.wo(output) @@ -427,11 +429,19 @@ def forward(self, tokens: torch.Tensor): """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + # fold batch dimension and sequence dimension + # for more efficient allgather/reduce_scatter + h = h.view(-1, self.model_args.dim) + freqs_cis = self.freqs_cis[0 : self.model_args.max_seq_len] for layer in self.layers.values(): - h = layer(h, self.freqs_cis) + h = layer(h, freqs_cis) h = self.norm(h) if self.norm else h + # unfold batch and sequence dimension + bs = tokens.shape[0] + bs_seqlen = h.shape[0] + h = h.view(bs, bs_seqlen // bs, self.model_args.dim) output = self.output(h).float() if self.output else h return output diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py index 4245fe41df..5e40b0750a 100644 --- a/torchtitan/models/norms.py +++ b/torchtitan/models/norms.py @@ -221,8 +221,8 @@ def _rms_norm_bwd_kernel_sm( class TritonFusedRMSNorm(torch.autograd.Function): @partial( local_map, - out_placements=[Shard(1)], - in_placements=(None, [Shard(1)], [Replicate()], None), + out_placements=[Shard(0)], + in_placements=(None, [Shard(0)], [Replicate()], None), ) @staticmethod def forward(ctx, x, weight, eps): @@ -268,8 +268,8 @@ def forward(ctx, x, weight, eps): @partial( local_map, - out_placements=([Shard(1)], [Partial()], None), - in_placements=(None, [Shard(1)]), + out_placements=([Shard(0)], [Partial()], None), + in_placements=(None, [Shard(0)]), ) @staticmethod def backward(ctx, dy): diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index be627432a3..ee52859f5d 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -350,14 +350,18 @@ def 
apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): { "tok_embeddings": RowwiseParallel( input_layouts=Replicate(), - output_layouts=Shard(1), ), "output": col_parallel_strategy( - input_layouts=Shard(1), + input_layouts=Shard(0), output_layouts=Shard(-1) if loss_parallel else Replicate(), use_local_output=not loss_parallel, ), - "norm": SequenceParallel(), + "norm": SequenceParallel(sequence_dim=0), + "layers.0": PrepareModuleInput( + input_layouts=(Replicate(), None), + desired_input_layouts=(Shard(0), None), + use_local_output=True, + ), }, ) @@ -365,22 +369,22 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): for layer_id, transformer_block in model.layers.items(): layer_plan = { "attention": prepare_module_input( - input_layouts=(Shard(1), None), + input_layouts=(Shard(0), None), desired_input_layouts=(Replicate(), None), ), "attention.wq": col_parallel_strategy(), "attention.wk": col_parallel_strategy(), "attention.wv": col_parallel_strategy(), - "attention.wo": row_parallel_strategy(output_layouts=Shard(1)), - "attention_norm": SequenceParallel(), + "attention.wo": row_parallel_strategy(output_layouts=Shard(0)), + "attention_norm": SequenceParallel(sequence_dim=0), "feed_forward": prepare_module_input( - input_layouts=(Shard(1),), + input_layouts=(Shard(0),), desired_input_layouts=(Replicate(),), ), "feed_forward.w1": col_parallel_strategy(), - "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)), + "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(0)), "feed_forward.w3": col_parallel_strategy(), - "ffn_norm": SequenceParallel(), + "ffn_norm": SequenceParallel(sequence_dim=0), } # Adjust attention module to use the local number of heads diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index da634031a3..b267ddce52 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -6,7 +6,7 @@ description = "Llama 3 debug training" use_for_integration_test = true [profiling] -enable_profiling = true +enable_profiling = false save_traces_folder = "profile_trace" profile_freq = 10 enable_memory_snapshot = false @@ -15,7 +15,7 @@ save_memory_snapshot_folder = "memory_snapshot" [metrics] log_freq = 1 enable_color_printing = true -enable_tensorboard = true +enable_tensorboard = false save_tb_folder = "tb" [model] @@ -36,7 +36,7 @@ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 -tensor_parallel_degree = 1 +tensor_parallel_degree = 4 fp8_linear = "" compile = false dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) From e023e12287c2e744a2067ef7bd5fba7fb8550d4e Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 3 Jul 2024 22:08:07 -0700 Subject: [PATCH 2/7] Update on "[DO NOT MERGE][example] fold batch and sequence dimensions to accelerate Sequence Parallel" At the cost of model code change, we can obtain better Sequence Parallel performance. Without folding and unfolding, all-gather and reduce-scatter are performed on dim 1 (sequence dim) instead of dim 0 (folded dim), which incurs an extra `aten.cat` after each collective. Stats from awgu: > for 8k seq len, batch size 1 on H100, these two cats take about 0.18 ms out of 3 ms of FFN compute (6%) This is almost a reverse of #190. 
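As a rough illustration of the layout change this stack makes (toy sizes, plain tensors only; the real sharding is handled by DTensor):

```python
import torch

# Illustrative sizes only; the real values come from ModelArgs / the job config.
bs, seqlen, dim, tp_degree = 2, 8, 16, 4

h = torch.randn(bs, seqlen, dim)

# Fold batch and sequence into one leading dimension right after the token
# embeddings: (bs, seqlen, dim) -> (bs * seqlen, dim).
h_folded = h.view(-1, dim)
assert h_folded.shape == (bs * seqlen, dim)

# Sequence Parallel now shards the folded tensor on dim 0 instead of sharding
# (bs, seqlen, dim) on dim 1; each TP rank holds one contiguous slice.
local_shard = h_folded.chunk(tp_degree, dim=0)[0]
assert local_shard.shape == (bs * seqlen // tp_degree, dim)

# Unfold before the final output projection.
assert torch.equal(h_folded.view(bs, -1, dim), h)
```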
[ghstack-poisoned] --- train_configs/debug_model.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index b267ddce52..da634031a3 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -6,7 +6,7 @@ description = "Llama 3 debug training" use_for_integration_test = true [profiling] -enable_profiling = false +enable_profiling = true save_traces_folder = "profile_trace" profile_freq = 10 enable_memory_snapshot = false @@ -15,7 +15,7 @@ save_memory_snapshot_folder = "memory_snapshot" [metrics] log_freq = 1 enable_color_printing = true -enable_tensorboard = false +enable_tensorboard = true save_tb_folder = "tb" [model] @@ -36,7 +36,7 @@ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 -tensor_parallel_degree = 4 +tensor_parallel_degree = 1 fp8_linear = "" compile = false dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) From 59773bb91ac42d570f5027e252c7a8451f4115ff Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 3 Jul 2024 23:34:33 -0700 Subject: [PATCH 3/7] Update on "[DO NOT MERGE][example] fold batch and sequence dimensions to accelerate Sequence Parallel" Note: This PR is for showcasing purpose only and is almost a reverse of #190. At the cost of model code change, we can obtain better Sequence Parallel performance. Without folding and unfolding, all-gather and reduce-scatter are performed on dim 1 (sequence dim) instead of dim 0 (folded dim), which incurs an extra `aten.cat` after each collective. Stats from awgu: > for 8k seq len, batch size 1 on H100, these two cats take about 0.18 ms out of 3 ms of FFN compute (6%) Experiment on 8-layer `debug_model` before: image after: image [ghstack-poisoned] --- torchtitan/models/llama/model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 672f1e22e9..d9ba768534 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -439,9 +439,8 @@ def forward(self, tokens: torch.Tensor): h = self.norm(h) if self.norm else h # unfold batch and sequence dimension - bs = tokens.shape[0] - bs_seqlen = h.shape[0] - h = h.view(bs, bs_seqlen // bs, self.model_args.dim) + bs, seqlen = tokens.shape + h = h.view(bs, seqlen, self.model_args.dim) output = self.output(h).float() if self.output else h return output From 67dffdbaae8847ba77fa2be4e3d2bcd6d0b036ef Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 8 Jul 2024 18:26:46 -0700 Subject: [PATCH 4/7] Update on "[DO NOT MERGE][example] fold batch and sequence dimensions to accelerate Sequence Parallel" Note: This PR is for showcasing purpose only and is almost a reverse of #190. At the cost of model code change, we can obtain better Sequence Parallel performance. Without folding and unfolding, all-gather and reduce-scatter are performed on dim 1 (sequence dim) instead of dim 0 (folded dim), which incurs an extra `aten.cat` after each collective. 
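To see where that extra cat comes from, here is a single-process sketch that simulates the collective with `torch.cat` (the real all-gather stacks the per-rank shards along dim 0):

```python
import torch

world_size, bs, seqlen, dim = 4, 2, 8, 16
full = torch.randn(bs, seqlen, dim)

def simulated_all_gather(shards):
    # The gathered buffer is always the per-rank shards stacked along dim 0.
    return torch.cat(shards, dim=0)

# Shard(1) on the unfolded (bs, seqlen, dim) tensor: the gathered buffer has
# shape (world_size * bs, seqlen // world_size, dim), so reassembling the
# logical tensor needs an extra cat along the sequence dim.
shards_dim1 = list(full.chunk(world_size, dim=1))
gathered = simulated_all_gather(shards_dim1)
restored = torch.cat(gathered.chunk(world_size, dim=0), dim=1)
assert torch.equal(restored, full)

# Shard(0) on the folded (bs * seqlen, dim) tensor: the gathered buffer
# already is the logical tensor, with no extra cat.
folded = full.view(-1, dim)
shards_dim0 = list(folded.chunk(world_size, dim=0))
assert torch.equal(simulated_all_gather(shards_dim0), folded)
```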
Stats from awgu: > for 8k seq len, batch size 1 on H100, these two cats take about 0.18 ms out of 3 ms of FFN compute (6%) Experiment on 8-layer `debug_model` before: image after: image [ghstack-poisoned] --- torchtitan/models/llama/model.py | 28 ++++++++++++-------- torchtitan/parallelisms/parallelize_llama.py | 6 +---- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index d9ba768534..44fda2686c 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -185,7 +185,7 @@ def forward( torch.Tensor: Output tensor after attention. """ - # dim 0 of x is a folded dimension of [bs, seqlen] + # dim 0 of x is a folded dimension of (bs, seqlen) seqlen, _ = freqs_cis.shape bs_seqlen, _ = x.shape bs = bs_seqlen // seqlen @@ -427,21 +427,27 @@ def forward(self, tokens: torch.Tensor): torch.Tensor: Output logits after applying the Transformer model. """ - # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages - h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens - # fold batch dimension and sequence dimension - # for more efficient allgather/reduce_scatter - h = h.view(-1, self.model_args.dim) + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stage + if self.tok_embeddings: + # fold batch dimension and sequence dimension + # for more efficient allgather/reduce_scatter + tokens = tokens.view(-1) + h = self.tok_embeddings(tokens) + else: + h = tokens - freqs_cis = self.freqs_cis[0 : self.model_args.max_seq_len] + seqlen = self.model_args.max_seq_len + freqs_cis = self.freqs_cis[0:seqlen] for layer in self.layers.values(): h = layer(h, freqs_cis) h = self.norm(h) if self.norm else h - # unfold batch and sequence dimension - bs, seqlen = tokens.shape - h = h.view(bs, seqlen, self.model_args.dim) - output = self.output(h).float() if self.output else h + if self.output: + # unfold batch and sequence dimension + h = h.view(-1, seqlen, self.model_args.dim) + output = self.output(h).float() + else: + output = h return output @classmethod diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index ee52859f5d..4ff9aea215 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -350,6 +350,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): { "tok_embeddings": RowwiseParallel( input_layouts=Replicate(), + output_layouts=Shard(0), ), "output": col_parallel_strategy( input_layouts=Shard(0), @@ -357,11 +358,6 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): use_local_output=not loss_parallel, ), "norm": SequenceParallel(sequence_dim=0), - "layers.0": PrepareModuleInput( - input_layouts=(Replicate(), None), - desired_input_layouts=(Shard(0), None), - use_local_output=True, - ), }, ) From 7ac41b123c7a80a3e7243bfde2be4def0583996b Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 8 Jul 2024 18:44:46 -0700 Subject: [PATCH 5/7] Update on "[DO NOT MERGE][example] fold batch and sequence dimensions to accelerate Sequence Parallel" Note: This PR is for showcasing purpose only and is almost a reverse of #190. At the cost of model code change, we can obtain better Sequence Parallel performance. 
Without folding and unfolding, all-gather and reduce-scatter are performed on dim 1 (sequence dim) instead of dim 0 (folded dim), which incurs an extra `aten.cat` after each collective. Stats from awgu: > for 8k seq len, batch size 1 on H100, these two cats take about 0.18 ms out of 3 ms of FFN compute (6%) Experiment on 8-layer `debug_model` before: image after: image [ghstack-poisoned] --- torchtitan/models/llama/model.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 44fda2686c..84609148a4 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -428,13 +428,13 @@ def forward(self, tokens: torch.Tensor): """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stage + # fold batch dimension and sequence dimension for more efficient allgather/reduce_scatter if self.tok_embeddings: - # fold batch dimension and sequence dimension - # for more efficient allgather/reduce_scatter tokens = tokens.view(-1) h = self.tok_embeddings(tokens) else: h = tokens + h = h.view(-1, self.model_args.dim) seqlen = self.model_args.max_seq_len freqs_cis = self.freqs_cis[0:seqlen] @@ -442,12 +442,9 @@ def forward(self, tokens: torch.Tensor): h = layer(h, freqs_cis) h = self.norm(h) if self.norm else h - if self.output: - # unfold batch and sequence dimension - h = h.view(-1, seqlen, self.model_args.dim) - output = self.output(h).float() - else: - output = h + # unfold batch and sequence dimension + h = h.view(-1, seqlen, self.model_args.dim) + output = self.output(h).float() if self.output else h return output @classmethod From b579e87db82ec4ff89aaa773272987708fcef93b Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 8 Jul 2024 20:17:22 -0700 Subject: [PATCH 6/7] Update on "[DO NOT MERGE][example] fold batch and sequence dimensions to accelerate Sequence Parallel" Note: This PR is for showcasing purpose only and is almost a reverse of #190. At the cost of model code change, we can obtain better Sequence Parallel performance. Without folding and unfolding, all-gather and reduce-scatter are performed on dim 1 (sequence dim) instead of dim 0 (folded dim), which incurs an extra `aten.cat` after each collective. Stats from awgu: > for 8k seq len, batch size 1 on H100, these two cats take about 0.18 ms out of 3 ms of FFN compute (6%) Experiment on 8-layer `debug_model` before: image after: image [ghstack-poisoned] --- torchtitan/models/llama/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 84609148a4..5a23626c44 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -427,6 +427,7 @@ def forward(self, tokens: torch.Tensor): torch.Tensor: Output logits after applying the Transformer model. 
""" + bs = tokens.shape[0] # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stage # fold batch dimension and sequence dimension for more efficient allgather/reduce_scatter if self.tok_embeddings: @@ -443,7 +444,7 @@ def forward(self, tokens: torch.Tensor): h = self.norm(h) if self.norm else h # unfold batch and sequence dimension - h = h.view(-1, seqlen, self.model_args.dim) + h = h.view(bs, -1, self.model_args.dim) output = self.output(h).float() if self.output else h return output From 43c08cd3d501bec14586bd6752e3d533dc076cee Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Tue, 9 Jul 2024 17:50:39 -0700 Subject: [PATCH 7/7] Update on "[DO NOT MERGE][example] fold batch and sequence dimensions to accelerate Sequence Parallel" Note: This PR is for showcasing purpose only and is almost a reverse of #190. At the cost of model code change, we can obtain better Sequence Parallel performance. Without folding and unfolding, all-gather and reduce-scatter are performed on dim 1 (sequence dim) instead of dim 0 (folded dim), which incurs an extra `aten.cat` after each collective. Stats from awgu: > for 8k seq len, batch size 1 on H100, these two cats take about 0.18 ms out of 3 ms of FFN compute (6%) Experiment on 8-layer `debug_model` before: image after: image [ghstack-poisoned] --- torchtitan/models/llama/model.py | 10 +++------- torchtitan/parallelisms/parallelize_llama.py | 6 +++++- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py index 5a23626c44..516261c999 100644 --- a/torchtitan/models/llama/model.py +++ b/torchtitan/models/llama/model.py @@ -427,15 +427,10 @@ def forward(self, tokens: torch.Tensor): torch.Tensor: Output logits after applying the Transformer model. """ - bs = tokens.shape[0] # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stage + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens # fold batch dimension and sequence dimension for more efficient allgather/reduce_scatter - if self.tok_embeddings: - tokens = tokens.view(-1) - h = self.tok_embeddings(tokens) - else: - h = tokens - h = h.view(-1, self.model_args.dim) + h = h.view(-1, self.model_args.dim) seqlen = self.model_args.max_seq_len freqs_cis = self.freqs_cis[0:seqlen] @@ -444,6 +439,7 @@ def forward(self, tokens: torch.Tensor): h = self.norm(h) if self.norm else h # unfold batch and sequence dimension + bs = tokens.shape[0] h = h.view(bs, -1, self.model_args.dim) output = self.output(h).float() if self.output else h return output diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 4ff9aea215..ee52859f5d 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -350,7 +350,6 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): { "tok_embeddings": RowwiseParallel( input_layouts=Replicate(), - output_layouts=Shard(0), ), "output": col_parallel_strategy( input_layouts=Shard(0), @@ -358,6 +357,11 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig): use_local_output=not loss_parallel, ), "norm": SequenceParallel(sequence_dim=0), + "layers.0": PrepareModuleInput( + input_layouts=(Replicate(), None), + desired_input_layouts=(Shard(0), None), + use_local_output=True, + ), }, )