     ColwiseParallel,
     parallelize_module,
     PrepareModuleInput,
+    PrepareModuleOutput,
     RowwiseParallel,
     SequenceParallel,
 )
@@ -181,15 +182,21 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         loss_parallel = parallel_dims.loss_parallel_enabled
 
         # 1. Parallelize the first embedding and the last linear proj layer
-        # 2. Parallelize the root norm layer over the sequence dim
-        # 3. Shard the first transformer block's inputs
+        # 2. Prepare the freq_cis in rotary embedding as dtensor
+        # 3. Parallelize the root norm layer over the sequence dim
+        # 4. Shard the first transformer block's inputs
         model = parallelize_module(
             model,
             tp_mesh,
             {
                 "tok_embeddings": RowwiseParallel(
                     input_layouts=Replicate(),
                 ),
+                "embeddings": PrepareModuleOutput(
+                    output_layouts=(None, Replicate()),
+                    desired_output_layouts=(None, Replicate()),
+                    use_local_output=False,
+                ),
                 "output": col_parallel_strategy(
                     input_layouts=Shard(1),
                     output_layouts=(Shard(-1) if loss_parallel else Replicate()),
@@ -212,9 +219,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                     input_layouts=(Shard(1), None),
                     desired_input_layouts=(Replicate(), None),
                 ),
-                "attention.wq": col_parallel_strategy(),
-                "attention.wk": col_parallel_strategy(),
-                "attention.wv": col_parallel_strategy(),
+                "attention.wq": col_parallel_strategy(use_local_output=False),
+                "attention.wk": col_parallel_strategy(use_local_output=False),
+                "attention.wv": col_parallel_strategy(use_local_output=False),
                 "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
                 "attention_norm": SequenceParallel(),
                 "feed_forward": PrepareModuleInput(
@@ -227,11 +234,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
227234 "ffn_norm" : SequenceParallel (),
228235 }
229236
230- # Adjust attention module to use the local number of heads
231- attn_layer = transformer_block .attention
232- attn_layer .n_heads = attn_layer .n_heads // tp_mesh .size ()
233- attn_layer .n_kv_heads = attn_layer .n_kv_heads // tp_mesh .size ()
234-
235237 parallelize_module (
236238 module = transformer_block ,
237239 device_mesh = tp_mesh ,
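Why the manual head-count patching can be dropped: with use_local_output=False, the wq/wk/wv projections return DTensors whose logical shape is the full, unsharded hidden dimension, so the attention's view(bsz, seqlen, n_heads, head_dim) keeps using the global head counts and DTensor propagates the sharding on its own. Below is a minimal sketch of that behavior, assuming a 2-process CPU/gloo mesh and a hypothetical TinyAttentionProj module; it is an illustration, not code from this PR.

# Minimal sketch (hypothetical module, not from this PR); launch with
#   torchrun --nproc_per_node=2 sketch.py
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module


class TinyAttentionProj(nn.Module):
    def __init__(self, dim: int = 16, n_heads: int = 4):
        super().__init__()
        self.dim, self.n_heads = dim, n_heads
        self.wq = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        bsz, seqlen, _ = x.shape
        q = self.wq(x)  # stays a DTensor because use_local_output=False below
        # No n_heads // tp_mesh.size() adjustment: the DTensor's logical last dim
        # is the full `dim`, so the global head count is the right one to use.
        return q.view(bsz, seqlen, self.n_heads, self.dim // self.n_heads)


torch.manual_seed(0)                  # same weights/inputs on every rank
mesh = init_device_mesh("cpu", (2,))  # CPU/gloo mesh, purely for illustration
model = TinyAttentionProj()
parallelize_module(model, mesh, {"wq": ColwiseParallel(use_local_output=False)})
out = model(torch.randn(2, 8, 16))
print(out.shape)  # global shape torch.Size([2, 8, 4, 4]); each rank holds a shard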