Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions torchtrain/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,7 @@ def forward(
torch.Tensor: Output tensor after attention.

"""
seqlen, _ = freqs_cis.shape
bs_seqlen, _ = x.shape
bsz = bs_seqlen // seqlen

bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
Expand All @@ -255,8 +252,7 @@ def forward(
output = output.transpose(
1, 2
).contiguous() # (bs, seqlen, n_local_heads, head_dim)
# output stay folded with batch and sequence dimension
output = output.view(bsz * seqlen, -1)
output = output.view(bsz, seqlen, -1)
return self.wo(output)


Expand Down Expand Up @@ -487,17 +483,9 @@ def forward(self, tokens: torch.Tensor):

"""
h, freqs_cis = self.embeddings(tokens)
# fold batch and sequence dimension for more efficient allgather/reduce_scatter
h = h.view(-1, self.model_args.dim)

for layer in self.layers:
h = layer(h, freqs_cis)

h = self.norm(h)
# unfold batch and sequence dimension
bsz = tokens.shape[0]
bs_seqlen = h.shape[0]
h = h.view(bsz, bs_seqlen // bsz, self.model_args.dim)
output = self.output(h).float()
return output

Expand Down
18 changes: 9 additions & 9 deletions torchtrain/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,18 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
input_layouts=Replicate(),
),
"output": col_parallel_strategy(
input_layouts=Shard(0),
input_layouts=Shard(1),
output_layouts=(
Shard(-1)
if parallel_dims.loss_parallel_enabled
else Replicate()
),
use_local_output=not parallel_dims.loss_parallel_enabled,
),
"norm": SequenceParallel(sequence_dim=0),
"norm": SequenceParallel(),
"layers.0": PrepareModuleInput(
input_layouts=(Replicate(), None),
desired_input_layouts=(Shard(0), None),
desired_input_layouts=(Shard(1), None),
use_local_output=True,
),
},
Expand All @@ -174,22 +174,22 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
for layer_id, transformer_block in enumerate(model.layers):
layer_plan = {
"attention": PrepareModuleInput(
input_layouts=(Shard(0), None),
input_layouts=(Shard(1), None),
desired_input_layouts=(Replicate(), None),
),
"attention.wq": col_parallel_strategy(),
"attention.wk": col_parallel_strategy(),
"attention.wv": col_parallel_strategy(),
"attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
"attention_norm": SequenceParallel(sequence_dim=0),
"attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
"attention_norm": SequenceParallel(),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(0),),
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": col_parallel_strategy(),
"feed_forward.w2": row_parallel_strategy(output_layouts=Shard(0)),
"feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
"feed_forward.w3": col_parallel_strategy(),
"ffn_norm": SequenceParallel(sequence_dim=0),
"ffn_norm": SequenceParallel(),
}

# Adjust attention module to use the local number of heads
Expand Down