16 changes: 13 additions & 3 deletions torchtrain/models/llama/model.py
@@ -211,7 +211,10 @@ def forward(
torch.Tensor: Output tensor after attention.

"""
bsz, seqlen, _ = x.shape
seqlen, _ = freqs_cis.shape
bs_seqlen, _ = x.shape
bsz = bs_seqlen // seqlen

xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
@@ -237,7 +240,8 @@ def forward(
output = output.transpose(
1, 2
).contiguous() # (bs, seqlen, n_local_heads, head_dim)
output = output.view(bsz, seqlen, -1)
# output stays folded with batch and sequence dimensions
output = output.view(bsz * seqlen, -1)
return self.wo(output)


@@ -342,7 +346,6 @@ def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
@@ -422,10 +425,17 @@ def forward(self, tokens: torch.Tensor):

"""
h, freqs_cis = self.embeddings(tokens)
# fold batch and sequence dimension for more efficient allgather/reduce_scatter
h = h.view(-1, self.params.dim)

for layer in self.layers:
h = layer(h, freqs_cis)

h = self.norm(h)
# unfold batch and sequence dimension
bsz = tokens.shape[0]
bs_seqlen = h.shape[0]
h = h.view(bsz, bs_seqlen // bsz, self.params.dim)
output = self.output(h).float()
return output
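
For reference, a minimal shape walk-through of this fold/unfold (an editor's sketch, not part of the diff; the toy sizes bsz, seqlen, dim are arbitrary):

import torch

bsz, seqlen, dim = 2, 4, 8          # toy sizes for illustration
h = torch.randn(bsz, seqlen, dim)   # embedding output, shape (bsz, seqlen, dim)

# fold: collapse batch and sequence into one leading dimension so that
# sequence-parallel collectives (allgather/reduce_scatter) act on a single Shard(0) dim
h_folded = h.view(-1, dim)          # shape (bsz * seqlen, dim) == (8, 8)

# inside attention, bsz is recovered from the folded shape and the seqlen
# carried by freqs_cis, mirroring the change at the top of this diff
bs_seqlen = h_folded.shape[0]
recovered_bsz = bs_seqlen // seqlen  # == 2

# unfold: restore (bsz, seqlen, dim) before the final output projection
h_unfolded = h_folded.view(recovered_bsz, bs_seqlen // recovered_bsz, dim)
assert torch.equal(h, h_unfolded)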

114 changes: 112 additions & 2 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -7,6 +7,13 @@
import logging

import torch
from torch.distributed._tensor import (
distribute_module,
distribute_tensor,
DTensor,
Replicate,
Shard,
)

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
@@ -19,11 +26,46 @@
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
PrepareModuleInput,
RowwiseParallel,
)

from torchtrain.logging_utils import rank0_log

logger = logging.getLogger(__name__)


def distribute_rmsnorm(module, device_mesh):
# temp sharding API until PTD API is added
def prepare_input_fn(inputs, device_mesh):
if isinstance(inputs[0], DTensor):
return inputs
elif isinstance(inputs[0], torch.Tensor):
shard_tensor = DTensor.from_local(
inputs[0], device_mesh, [Shard(0)], run_check=False
)
return shard_tensor
else:
raise NotImplementedError(f"unsupported input type for sharded RMSNorm: {type(inputs[0])}")

def partition_fn(name, module, device_mesh):
for name, param in module.named_parameters():
dist_param = torch.nn.Parameter(
distribute_tensor(param, device_mesh, [Replicate()])
)
module.register_parameter(name, dist_param)

return distribute_module(
module,
device_mesh,
partition_fn,
input_fn=prepare_input_fn,
)
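
A hypothetical usage sketch for the helper above (an editor's addition, not part of the diff): it assumes a recent PyTorch exposing init_device_mesh, a process group launched via torchrun, and an illustrative RMSNorm stand-in for the model's norm layers.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

class RMSNorm(nn.Module):
    # minimal stand-in for the model's norm layers
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * self.weight

# e.g. launched with: torchrun --nproc_per_node=4 this_script.py
tp_mesh = init_device_mesh("cuda", (4,))   # 1-D mesh over the sequence-parallel ranks
norm = RMSNorm(dim=4096).cuda()
# weights are replicated as DTensors; plain-tensor inputs get wrapped as Shard(0)
norm = distribute_rmsnorm(norm, tp_mesh)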


# Uses PTD FSDP AC wrapper
def checkpoint_wrapper(module, config):
return ptd_checkpoint_wrapper(
@@ -43,7 +85,75 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
if parallel_dims.pp_enabled:
raise NotImplementedError("PP not implemented yet.")
if parallel_dims.sp_enabled:
raise NotImplementedError("SP not implemented yet.")
# First we apply Sequence Parallelism if it's enabled
tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
sp_degree = args.sp_degree
# First:
# 1. parallelize the first embedding and the last linear proj layer
# 2. shard the first layer of transformer block
# TODO: enable loss parallel once it's ready
model = parallelize_module(
model,
tp_mesh,
{
"embeddings.tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
),
"output": ColwiseParallel(
input_layouts=Shard(0),
output_layouts=Replicate(),
),
"layers.0": PrepareModuleInput(
input_layouts=(Replicate(), None),
desired_input_layouts=(Shard(0), None),
use_local_output=True,
),
},
)

# apply sequence parallelism to every transformer block
for layer_id, transformer_block in enumerate(model.layers):
layer_plan = {
"attention": PrepareModuleInput(
input_layouts=(Shard(0), None),
desired_input_layouts=(Replicate(), None),
),
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(output_layouts=Shard(0)),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(0),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
"feed_forward.w3": ColwiseParallel(),
}
# if layer_id == 0:
# # in first transformer block we need to shard the input
# layer_plan[""] = PrepareModuleInput(
# input_layouts=(Replicate(), None),
# desired_input_layouts=(Shard(0), None),
# )

# adjust num_heads in attention layer to local heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // sp_degree
attn_layer.n_kv_heads = attn_layer.n_kv_heads // sp_degree

# shard RMSNorm layers
distribute_rmsnorm(transformer_block.attention_norm, tp_mesh)
distribute_rmsnorm(transformer_block.ffn_norm, tp_mesh)
Comment on lines +146 to +147

Contributor:
shall we also apply it on the final norm after all transformer blocks?

Collaborator (Author):
not something currently enabled, but I think we can explore this in real training and see if sharding the final norm would give additional memory/perf benefits :)

parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_plan,
)

rank0_log("Applied Sequence Parallelism to the model...")

if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ["dp"], dp_mesh.mesh_dim_names
@@ -73,6 +183,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
# wrap the rest layers with FSDP
model = wrap(model.cuda())

rank0_log("Applied parallelisms to the model...")
rank0_log("Applied FSDP to the model...")

return model
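
To make the local-head adjustment inside the transformer-block loop above concrete, a small arithmetic sketch (an editor's addition; the numbers are illustrative, not from the PR):

n_heads, n_kv_heads, dim = 32, 32, 4096   # illustrative global model config
sp_degree = 4                             # size of the "sp" mesh dimension

head_dim = dim // n_heads                 # 128, unchanged by sequence parallelism
local_heads = n_heads // sp_degree        # 8 query heads per rank
local_kv_heads = n_kv_heads // sp_degree  # 8 kv heads per rank

# ColwiseParallel shards wq/wk/wv along the output feature dimension,
# so each rank materializes only local_heads * head_dim = 1024 query features
local_q_features = local_heads * head_dim
assert local_q_features == dim // sp_degree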