From ec9c07ae63a6bbe0a9efbb485721d754781aaf0a Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 3 Apr 2024 16:03:46 -0700
Subject: [PATCH] remove folding and unfolding of sequence dim in model.py

[ghstack-poisoned]
---
 torchtrain/models/llama/model.py             | 16 ++--------------
 torchtrain/parallelisms/parallelize_llama.py | 18 +++++++++---------
 2 files changed, 11 insertions(+), 23 deletions(-)

diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py
index 77c5e4006f..1cc7494cfe 100644
--- a/torchtrain/models/llama/model.py
+++ b/torchtrain/models/llama/model.py
@@ -226,10 +226,7 @@ def forward(
             torch.Tensor: Output tensor after attention.
 
         """
-        seqlen, _ = freqs_cis.shape
-        bs_seqlen, _ = x.shape
-        bsz = bs_seqlen // seqlen
-
+        bsz, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
         xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
@@ -255,8 +252,7 @@ def forward(
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
-        # output stay folded with batch and sequence dimension
-        output = output.view(bsz * seqlen, -1)
+        output = output.view(bsz, seqlen, -1)
         return self.wo(output)
@@ -487,17 +483,9 @@ def forward(self, tokens: torch.Tensor):
 
         """
         h, freqs_cis = self.embeddings(tokens)
-        # fold batch and sequence dimension for more efficient allgather/reduce_scatter
-        h = h.view(-1, self.model_args.dim)
-
         for layer in self.layers:
             h = layer(h, freqs_cis)
-
         h = self.norm(h)
-        # unfold batch and sequence dimension
-        bsz = tokens.shape[0]
-        bs_seqlen = h.shape[0]
-        h = h.view(bsz, bs_seqlen // bsz, self.model_args.dim)
         output = self.output(h).float()
         return output
 
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index e64267c506..1e7a9ed57b 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -153,7 +153,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                     input_layouts=Replicate(),
                 ),
                 "output": col_parallel_strategy(
-                    input_layouts=Shard(0),
+                    input_layouts=Shard(1),
                     output_layouts=(
                         Shard(-1)
                         if parallel_dims.loss_parallel_enabled
@@ -161,10 +161,10 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                     ),
                     use_local_output=not parallel_dims.loss_parallel_enabled,
                 ),
-                "norm": SequenceParallel(sequence_dim=0),
+                "norm": SequenceParallel(),
                 "layers.0": PrepareModuleInput(
                     input_layouts=(Replicate(), None),
-                    desired_input_layouts=(Shard(0), None),
+                    desired_input_layouts=(Shard(1), None),
                     use_local_output=True,
                 ),
             },
@@ -174,22 +174,22 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     for layer_id, transformer_block in enumerate(model.layers):
         layer_plan = {
             "attention": PrepareModuleInput(
-                input_layouts=(Shard(0), None),
+                input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
             "attention.wq": col_parallel_strategy(),
             "attention.wk": col_parallel_strategy(),
             "attention.wv": col_parallel_strategy(),
-            "attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
-            "attention_norm": SequenceParallel(sequence_dim=0),
+            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
+            "attention_norm": SequenceParallel(),
             "feed_forward": PrepareModuleInput(
-                input_layouts=(Shard(0),),
+                input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
             ),
             "feed_forward.w1": col_parallel_strategy(),
-            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(0)),
+            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
             "feed_forward.w3": col_parallel_strategy(),
-            "ffn_norm": SequenceParallel(sequence_dim=0),
+            "ffn_norm": SequenceParallel(),
         }
 
         # Adjust attention module to use the local number of heads
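Note (not part of the patch): the deleted view() calls folded batch and sequence into one leading dimension so that tensor-parallel collectives could shard dim 0; with activations kept as (bsz, seqlen, dim), the sequence dimension is dim 1, which is why the plan above switches Shard(0) to Shard(1) and drops the explicit sequence_dim=0, relying on SequenceParallel's default sequence dimension. Below is a minimal, self-contained sketch of the shape bookkeeping this removes; the bsz/seqlen/dim values are illustrative and the snippet is not taken from the repository.

import torch

bsz, seqlen, dim = 4, 8, 16
h = torch.randn(bsz, seqlen, dim)

# Before this patch: batch and sequence were folded into one leading dim, so
# sequence/tensor parallelism sharded dim 0 and bsz had to be re-derived later.
folded = h.view(-1, dim)                    # (bsz * seqlen, dim)
recovered_bsz = folded.shape[0] // seqlen   # bsz recovered from the folded shape
unfolded = folded.view(recovered_bsz, folded.shape[0] // recovered_bsz, dim)

# After this patch: activations stay (bsz, seqlen, dim) end to end, so the
# sequence dimension is 1 and the parallelize plan shards dim 1 instead.
assert recovered_bsz == bsz
assert torch.equal(unfolded, h)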