import logging

import torch
+from torch.distributed._tensor import (
+    distribute_module,
+    distribute_tensor,
+    DTensor,
+    Replicate,
+    Shard,
+)

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)
from torch.distributed.fsdp import (
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import enable_wrap, wrap
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+)

from torchtrain.logging_utils import rank0_log

logger = logging.getLogger(__name__)

+
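+# Shard an RMSNorm module for sequence parallelism: its weight is replicated
+# across the device mesh, while plain-tensor inputs are wrapped as DTensors
+# sharded on dim 0.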
+def distribute_rmsnorm(module, device_mesh):
+    # temp sharding API until PTD API is added
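+    # input_fn: convert the module's first input to a DTensor sharded on dim 0
+    # (no-op if it is already a DTensor)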
+    def prepare_input_fn(inputs, device_mesh):
+        if isinstance(inputs[0], DTensor):
+            return inputs
+        elif isinstance(inputs[0], torch.Tensor):
+            shard_tensor = DTensor.from_local(
+                inputs[0], device_mesh, [Shard(0)], run_check=False
+            )
+            return shard_tensor
+        else:
+            raise NotImplementedError("!!")
+
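+    # partition_fn: replace every parameter with a replicated DTensor so each
+    # rank keeps a full copy of the norm weight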
+    def partition_fn(name, module, device_mesh):
+        for name, param in module.named_parameters():
+            dist_param = torch.nn.Parameter(
+                distribute_tensor(param, device_mesh, [Replicate()])
+            )
+            module.register_parameter(name, dist_param)
+
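+    # distribute_module applies partition_fn to the parameters and installs
+    # prepare_input_fn as a pre-forward hook on the module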
+    return distribute_module(
+        module,
+        device_mesh,
+        partition_fn,
+        input_fn=prepare_input_fn,
+    )
+
+
# Uses PTD FSDP AC wrapper
def checkpoint_wrapper(module, config):
    return ptd_checkpoint_wrapper(
@@ -43,7 +85,75 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
    if parallel_dims.pp_enabled:
        raise NotImplementedError("PP not implemented yet.")
    if parallel_dims.sp_enabled:
-        raise NotImplementedError("SP not implemented yet.")
+        # First we apply Sequence Parallelism if it's enabled
+        tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
+        sp_degree = args.sp_degree
+        # First:
+        # 1. parallelize the first embedding and the last linear proj layer
+        # 2. shard the inputs of the first transformer block
+        # TODO: enable loss parallel once it's ready
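+        # top-level plan: shard the token embedding rowwise (token ids stay
+        # replicated), shard the output projection colwise and gather its
+        # result back to a replicated tensor, and re-shard the first block's
+        # input from Replicate to Shard(0)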
+        model = parallelize_module(
+            model,
+            tp_mesh,
+            {
+                "embeddings.tok_embeddings": RowwiseParallel(
+                    input_layouts=Replicate(),
+                ),
+                "output": ColwiseParallel(
+                    input_layouts=Shard(0),
+                    output_layouts=Replicate(),
+                ),
+                "layers.0": PrepareModuleInput(
+                    input_layouts=(Replicate(), None),
+                    desired_input_layouts=(Shard(0), None),
+                    use_local_output=True,
+                ),
+            },
+        )
+
+        # apply sequence parallelism to every transformer block
+        for layer_id, transformer_block in enumerate(model.layers):
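+            # per-block plan: attention and feed_forward inputs are
+            # redistributed from Shard(0) to Replicate; wq/wk/wv/w1/w3 are
+            # column-sharded and wo/w2 are row-sharded with their outputs
+            # placed back on Shard(0)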
+            layer_plan = {
+                "attention": PrepareModuleInput(
+                    input_layouts=(Shard(0), None),
+                    desired_input_layouts=(Replicate(), None),
+                ),
+                "attention.wq": ColwiseParallel(),
+                "attention.wk": ColwiseParallel(),
+                "attention.wv": ColwiseParallel(),
+                "attention.wo": RowwiseParallel(output_layouts=Shard(0)),
+                "feed_forward": PrepareModuleInput(
+                    input_layouts=(Shard(0),),
+                    desired_input_layouts=(Replicate(),),
+                ),
+                "feed_forward.w1": ColwiseParallel(),
+                "feed_forward.w2": RowwiseParallel(output_layouts=Shard(0)),
+                "feed_forward.w3": ColwiseParallel(),
+            }
+            # if layer_id == 0:
+            #     # in first transformer block we need to shard the input
+            #     layer_plan[""] = PrepareModuleInput(
+            #         input_layouts=(Replicate(), None),
+            #         desired_input_layouts=(Shard(0), None),
+            #     )
+
+            # adjust num_heads in attention layer to local heads
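+            # (wq/wk/wv are column-sharded across sp_degree ranks, so each
+            # rank only computes n_heads // sp_degree heads locally)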
+            attn_layer = transformer_block.attention
+            attn_layer.n_heads = attn_layer.n_heads // sp_degree
+            attn_layer.n_kv_heads = attn_layer.n_kv_heads // sp_degree
+
+            # shard RMSNorm layers
+            distribute_rmsnorm(transformer_block.attention_norm, tp_mesh)
+            distribute_rmsnorm(transformer_block.ffn_norm, tp_mesh)
+
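+            # apply the sharding plan above to this transformer block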
+            parallelize_module(
+                module=transformer_block,
+                device_mesh=tp_mesh,
+                parallelize_plan=layer_plan,
+            )
+
+        rank0_log("Applied Sequence Parallelism to the model...")
+
    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
        assert dp_mesh.mesh_dim_names == ["dp"], dp_mesh.mesh_dim_names
@@ -73,6 +183,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
            # wrap the rest layers with FSDP
            model = wrap(model.cuda())

-    rank0_log("Applied parallelisms to the model...")
+    rank0_log("Applied FSDP to the model...")

    return model