diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 49cda6241d..516261c999 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -79,9 +79,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten
     """
     ndim = x.ndim
     assert 0 <= 1 < ndim
-    seqlen = x.shape[1]
-    freqs_cis = freqs_cis[0:seqlen]
-    assert freqs_cis.shape == (seqlen, x.shape[-1])
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)
 
@@ -187,7 +185,10 @@ def forward(
             torch.Tensor: Output tensor after attention.
 
         """
-        bs, seqlen, _ = x.shape
+        # dim 0 of x is a folded dimension of (bs, seqlen)
+        seqlen, _ = freqs_cis.shape
+        bs_seqlen, _ = x.shape
+        bs = bs_seqlen // seqlen
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
         xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
@@ -209,7 +210,8 @@ def forward(
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
-        output = output.view(bs, seqlen, -1)
+        # output stay folded with batch and sequence dimension
+        output = output.view(bs * seqlen, -1)
         return self.wo(output)
 
 
@@ -425,13 +427,20 @@ def forward(self, tokens: torch.Tensor):
             torch.Tensor: Output logits after applying the Transformer model.
 
         """
-        # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
+        # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stage
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
+        # fold batch dimension and sequence dimension for more efficient allgather/reduce_scatter
+        h = h.view(-1, self.model_args.dim)
 
+        seqlen = self.model_args.max_seq_len
+        freqs_cis = self.freqs_cis[0:seqlen]
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis)
+            h = layer(h, freqs_cis)
 
         h = self.norm(h) if self.norm else h
+        # unfold batch and sequence dimension
+        bs = tokens.shape[0]
+        h = h.view(bs, -1, self.model_args.dim)
         output = self.output(h).float() if self.output else h
         return output
diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py
index 4245fe41df..5e40b0750a 100644
--- a/torchtitan/models/norms.py
+++ b/torchtitan/models/norms.py
@@ -221,8 +221,8 @@ def _rms_norm_bwd_kernel_sm(
 class TritonFusedRMSNorm(torch.autograd.Function):
     @partial(
         local_map,
-        out_placements=[Shard(1)],
-        in_placements=(None, [Shard(1)], [Replicate()], None),
+        out_placements=[Shard(0)],
+        in_placements=(None, [Shard(0)], [Replicate()], None),
     )
     @staticmethod
     def forward(ctx, x, weight, eps):
@@ -268,8 +268,8 @@ def forward(ctx, x, weight, eps):
 
     @partial(
         local_map,
-        out_placements=([Shard(1)], [Partial()], None),
-        in_placements=(None, [Shard(1)]),
+        out_placements=([Shard(0)], [Partial()], None),
+        in_placements=(None, [Shard(0)]),
     )
     @staticmethod
     def backward(ctx, dy):
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index be627432a3..ee52859f5d 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -350,14 +350,18 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
         {
             "tok_embeddings": RowwiseParallel(
                 input_layouts=Replicate(),
-                output_layouts=Shard(1),
             ),
             "output": col_parallel_strategy(
-                input_layouts=Shard(1),
+                input_layouts=Shard(0),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
             ),
-            "norm": SequenceParallel(),
+            "norm": SequenceParallel(sequence_dim=0),
+            "layers.0": PrepareModuleInput(
+                input_layouts=(Replicate(), None),
+                desired_input_layouts=(Shard(0), None),
+                use_local_output=True,
+            ),
         },
     )
 
@@ -365,22 +369,22 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
             "attention": prepare_module_input(
-                input_layouts=(Shard(1), None),
+                input_layouts=(Shard(0), None),
                 desired_input_layouts=(Replicate(), None),
             ),
             "attention.wq": col_parallel_strategy(),
             "attention.wk": col_parallel_strategy(),
             "attention.wv": col_parallel_strategy(),
-            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
-            "attention_norm": SequenceParallel(),
+            "attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
+            "attention_norm": SequenceParallel(sequence_dim=0),
             "feed_forward": prepare_module_input(
-                input_layouts=(Shard(1),),
+                input_layouts=(Shard(0),),
                 desired_input_layouts=(Replicate(),),
             ),
             "feed_forward.w1": col_parallel_strategy(),
-            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
+            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(0)),
             "feed_forward.w3": col_parallel_strategy(),
-            "ffn_norm": SequenceParallel(),
+            "ffn_norm": SequenceParallel(sequence_dim=0),
         }
 
     # Adjust attention module to use the local number of heads
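
Not part of the patch: a minimal, self-contained Python sketch of the fold/unfold that this change applies around the transformer layers, using made-up toy sizes (bs, seqlen, dim) and a stand-in freqs_cis shape purely for illustration; each layer recovers bs from the folded leading dimension and seqlen = freqs_cis.shape[0], as in the Attention.forward hunk above.

import torch

# toy sizes for illustration only (hypothetical, not taken from the patch)
bs, seqlen, dim = 2, 8, 16
h = torch.randn(bs, seqlen, dim)

# fold (bs, seqlen) into one leading dimension, as Transformer.forward now does
h_folded = h.view(-1, dim)                 # shape: (bs * seqlen, dim)

# inside a layer, bs is recovered from the folded size and seqlen
freqs_cis = torch.randn(seqlen, dim // 2)  # stand-in shape for the real freqs_cis
assert h_folded.shape[0] // freqs_cis.shape[0] == bs

# unfold back to (bs, seqlen, dim) before the final output projection
h_unfolded = h_folded.view(bs, -1, dim)
assert torch.equal(h, h_unfolded)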