23 changes: 16 additions & 7 deletions torchtitan/models/llama/model.py
@@ -79,9 +79,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
ndim = x.ndim
assert 0 <= 1 < ndim
- seqlen = x.shape[1]
- freqs_cis = freqs_cis[0:seqlen]
- assert freqs_cis.shape == (seqlen, x.shape[-1])
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)

@@ -187,7 +185,10 @@ def forward(
torch.Tensor: Output tensor after attention.

"""
- bs, seqlen, _ = x.shape
+ # dim 0 of x is a folded dimension of (bs, seqlen)
+ seqlen, _ = freqs_cis.shape
+ bs_seqlen, _ = x.shape
+ bs = bs_seqlen // seqlen
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
@@ -209,7 +210,8 @@ def forward(
output = output.transpose(
1, 2
).contiguous() # (bs, seqlen, n_local_heads, head_dim)
- output = output.view(bs, seqlen, -1)
+ # output stays folded with the batch and sequence dimensions
+ output = output.view(bs * seqlen, -1)
return self.wo(output)
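The shape bookkeeping above can be sanity-checked in isolation. Below is a minimal sketch with made-up sizes (bs, seqlen, n_heads, head_dim stand in for the module's attributes, and only the leading dimension of the stand-in freqs_cis is used):

import torch

bs, seqlen, n_heads, head_dim = 2, 8, 4, 16     # made-up sizes
dim = n_heads * head_dim

x = torch.randn(bs * seqlen, dim)               # folded (bs*seqlen, dim) input
freqs_cis = torch.randn(seqlen, head_dim // 2)  # stand-in rotary table; only dim 0 is read here

# recover (bs, seqlen) the same way the new forward does
inferred_seqlen = freqs_cis.shape[0]
inferred_bs = x.shape[0] // inferred_seqlen
assert (inferred_bs, inferred_seqlen) == (bs, seqlen)

# unfold for the attention math, then fold back before the output projection
xq = x.view(inferred_bs, inferred_seqlen, n_heads, head_dim)
out = xq.transpose(1, 2)                        # (bs, n_heads, seqlen, head_dim)
out = out.transpose(1, 2).contiguous()          # (bs, seqlen, n_heads, head_dim)
out = out.view(inferred_bs * inferred_seqlen, -1)
assert out.shape == x.shape                     # output stays folded, as in the diff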


@@ -425,13 +427,20 @@ def forward(self, tokens: torch.Tensor):
torch.Tensor: Output logits after applying the Transformer model.

"""
- # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
+ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stage
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
+ # fold batch dimension and sequence dimension for more efficient allgather/reduce_scatter
+ h = h.view(-1, self.model_args.dim)

+ seqlen = self.model_args.max_seq_len
+ freqs_cis = self.freqs_cis[0:seqlen]
for layer in self.layers.values():
- h = layer(h, self.freqs_cis)
+ h = layer(h, freqs_cis)

h = self.norm(h) if self.norm else h
+ # unfold the batch and sequence dimensions
+ bs = tokens.shape[0]
+ h = h.view(bs, -1, self.model_args.dim)
output = self.output(h).float() if self.output else h
return output
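The new forward folds (bs, seqlen) into a single leading dimension before the transformer layers and unfolds it after the final norm. The rationale in the comment, more efficient allgather/reduce_scatter, is roughly that collectives lay gathered shards out back-to-back along dim 0, so a dim-0 shard of the folded activation can be all-gathered without the extra re-layout copy that a dim-1 (sequence) shard needs. A single-process sketch of the two layouts, with made-up sizes and tp_degree standing in for the tensor-parallel degree:

import torch

bs, seqlen, dim, tp_degree = 2, 8, 4, 2
h = torch.arange(bs * seqlen * dim, dtype=torch.float32).view(bs, seqlen, dim)

# classic sequence parallel: shard the sequence dim (dim 1); gathered per-rank
# buffers come back stacked along dim 0 and need an extra cat/copy along dim 1
seq_shards = h.chunk(tp_degree, dim=1)                       # each (bs, seqlen/2, dim)
assert torch.equal(torch.cat(seq_shards, dim=1), h)

# folded layout: shard dim 0 of (bs*seqlen, dim); the gathered buffers are
# already back-to-back rows, so no re-layout is needed
folded = h.view(-1, dim)                                     # what the new forward does
fold_shards = folded.chunk(tp_degree, dim=0)                 # each (bs*seqlen/2, dim)
assert torch.equal(torch.cat(fold_shards, dim=0).view(bs, seqlen, dim), h)

# note: the two shardings partition different rows (one halves every sequence,
# the other cuts across the folded batch*sequence dim); that is fine for the
# rowwise ops in between, and attention redistributes to Replicate before it unfolds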

8 changes: 4 additions & 4 deletions torchtitan/models/norms.py
@@ -221,8 +221,8 @@ def _rms_norm_bwd_kernel_sm(
class TritonFusedRMSNorm(torch.autograd.Function):
@partial(
local_map,
- out_placements=[Shard(1)],
- in_placements=(None, [Shard(1)], [Replicate()], None),
+ out_placements=[Shard(0)],
+ in_placements=(None, [Shard(0)], [Replicate()], None),
)
@staticmethod
def forward(ctx, x, weight, eps):
@@ -268,8 +268,8 @@ def forward(ctx, x, weight, eps):

@partial(
local_map,
- out_placements=([Shard(1)], [Partial()], None),
- in_placements=(None, [Shard(1)]),
+ out_placements=([Shard(0)], [Partial()], None),
+ in_placements=(None, [Shard(0)]),
)
@staticmethod
def backward(ctx, dy):
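The placement changes in this file follow directly from the folded layout: RMSNorm is a rowwise op, so it produces the same result on the full folded activation as on any dim-0 shard of it, which is why Shard(1) (the sequence dim of the unfolded tensor) can simply become Shard(0) (the leading dim of the folded tensor) in the local_map placements. A small self-contained check using a reference, non-Triton RMSNorm and made-up sizes:

import torch

def ref_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # reference RMSNorm: each row is normalized by its own root-mean-square
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

rows, dim, tp_degree = 16, 8, 2            # rows plays the role of bs*seqlen
x = torch.randn(rows, dim)                 # folded activation, Shard(0) across ranks
w = torch.randn(dim)                       # replicated weight

full = ref_rms_norm(x, w)
per_shard = torch.cat(
    [ref_rms_norm(shard, w) for shard in x.chunk(tp_degree, dim=0)], dim=0
)
assert torch.allclose(full, per_shard)     # a rowwise op commutes with a dim-0 shard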
22 changes: 13 additions & 9 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -350,37 +350,41 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
- output_layouts=Shard(1),
Review comment (Collaborator):

Curious: could this be output_layouts=Shard(0), so that we then do not need the PrepareModuleInput?

Reply (Contributor Author):

@awgu
Currently we are doing the folding after the embedding layer, so we can't do what you suggested.
But I just realized that maybe we can do the folding even before the embedding layer; then I think we can do this, just like in the non-folding case.

Reply (Contributor Author):

@awgu
OK, I tried out the change. Please see the comparison here.
Everything works, except the CI failure says:

RuntimeError: It seems that we cannot capture your model as a full graph. Typical reasons include graph breaks, data/shape-dependent control flow, or missing meta kernels for custom operators. You can use our manual pipeline interfaces, or try to fix the graph breaks

So I decided to change it back. (A sketch of this alternative appears after the top-level parallelize_module plan below.)

),
"output": col_parallel_strategy(
- input_layouts=Shard(1),
+ input_layouts=Shard(0),
output_layouts=Shard(-1) if loss_parallel else Replicate(),
use_local_output=not loss_parallel,
),
"norm": SequenceParallel(),
"norm": SequenceParallel(sequence_dim=0),
"layers.0": PrepareModuleInput(
input_layouts=(Replicate(), None),
desired_input_layouts=(Shard(0), None),
use_local_output=True,
),
},
)
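As referenced in the review thread above, here is a sketch of the alternative that was tried and then reverted: shard the embedding output on the folded dim directly and drop the "layers.0" PrepareModuleInput entry. It reuses names from this function (col_parallel_strategy, loss_parallel, and the parallel-style classes), assumes the token tensor is folded before the embedding, and is not the code that landed; per the thread, pipeline tracing failed on it in CI.

# Alternative (reverted) top-level plan, sketched from the review discussion:
{
    "tok_embeddings": RowwiseParallel(
        input_layouts=Replicate(),
        output_layouts=Shard(0),  # embedding output already sharded on the folded dim
    ),
    "output": col_parallel_strategy(
        input_layouts=Shard(0),
        output_layouts=Shard(-1) if loss_parallel else Replicate(),
        use_local_output=not loss_parallel,
    ),
    "norm": SequenceParallel(sequence_dim=0),
    # no "layers.0": PrepareModuleInput(...) entry: the first block would already
    # receive a Shard(0) DTensor from the embedding
}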

# Apply tensor + sequence parallelism to every transformer block
for layer_id, transformer_block in model.layers.items():
layer_plan = {
"attention": prepare_module_input(
- input_layouts=(Shard(1), None),
+ input_layouts=(Shard(0), None),
desired_input_layouts=(Replicate(), None),
),
"attention.wq": col_parallel_strategy(),
"attention.wk": col_parallel_strategy(),
"attention.wv": col_parallel_strategy(),
"attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
"attention_norm": SequenceParallel(),
"attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
"attention_norm": SequenceParallel(sequence_dim=0),
"feed_forward": prepare_module_input(
- input_layouts=(Shard(1),),
+ input_layouts=(Shard(0),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": col_parallel_strategy(),
"feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
"feed_forward.w2": row_parallel_strategy(output_layouts=Shard(0)),
"feed_forward.w3": col_parallel_strategy(),
"ffn_norm": SequenceParallel(),
"ffn_norm": SequenceParallel(sequence_dim=0),
}
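For context, a per-block plan like layer_plan is applied with parallelize_module from torch.distributed.tensor.parallel; the collapsed lines below presumably do exactly that, following the standard torchtitan pattern. A minimal usage sketch (tp_mesh is an assumed name for the tensor-parallel sub-mesh of world_mesh):

from torch.distributed.tensor.parallel import parallelize_module

# inside the loop over model.layers.items() shown above
parallelize_module(
    module=transformer_block,     # the current transformer block
    device_mesh=tp_mesh,          # tensor-parallel device mesh (assumed name)
    parallelize_plan=layer_plan,  # the Shard(0)-based plan defined above
)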

# Adjust attention module to use the local number of heads