     ColwiseParallel,
     parallelize_module,
     PrepareModuleInput,
+    PrepareModuleOutput,
     RowwiseParallel,
     SequenceParallel,
 )
@@ -143,15 +144,21 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     )

     # 1. Parallelize the first embedding and the last linear proj layer
-    # 2. Parallelize the root norm layer over the sequence dim
-    # 3. Shard the first transformer block's inputs
+    # 2. Prepare the freq_cis in rotary embedding as dtensor
+    # 3. Parallelize the root norm layer over the sequence dim
+    # 4. Shard the first transformer block's inputs
     model = parallelize_module(
         model,
         tp_mesh,
         {
             "embeddings.tok_embeddings": RowwiseParallel(
                 input_layouts=Replicate(),
             ),
+            "embeddings": PrepareModuleOutput(
+                output_layouts=(None, Replicate()),
+                desired_output_layouts=(None, Replicate()),
+                use_local_output=False,
+            ),
             "output": col_parallel_strategy(
                 input_layouts=Shard(0),
                 output_layouts=(
@@ -177,9 +184,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 input_layouts=(Shard(0), None),
                 desired_input_layouts=(Replicate(), None),
             ),
-            "attention.wq": col_parallel_strategy(),
-            "attention.wk": col_parallel_strategy(),
-            "attention.wv": col_parallel_strategy(),
+            "attention.wq": col_parallel_strategy(use_local_output=False),
+            "attention.wk": col_parallel_strategy(use_local_output=False),
+            "attention.wv": col_parallel_strategy(use_local_output=False),
             "attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
             "attention_norm": SequenceParallel(sequence_dim=0),
             "feed_forward": PrepareModuleInput(
@@ -192,11 +199,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
192199 "ffn_norm" : SequenceParallel (sequence_dim = 0 ),
193200 }
194201
195- # Adjust attention module to use the local number of heads
196- attn_layer = transformer_block .attention
197- attn_layer .n_heads = attn_layer .n_heads // tp_mesh .size ()
198- attn_layer .n_kv_heads = attn_layer .n_kv_heads // tp_mesh .size ()
199-
200202 parallelize_module (
201203 module = transformer_block ,
202204 device_mesh = tp_mesh ,
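
The gist of the diff: the embeddings module's second output (the rotary freqs_cis) is annotated as a replicated DTensor via PrepareModuleOutput, and wq/wk/wv keep their outputs as DTensors (use_local_output=False), so the attention math works with the global head counts and the manual n_heads // tp_mesh.size() patch in the last hunk can be dropped. The sketch below illustrates only the PrepareModuleOutput pattern, not the patched model: ToyEmbeddings is a hypothetical stand-in that returns a (hidden states, freqs_cis) tuple, and a torchrun launch on GPUs with a 1-D TP mesh is assumed.

# Minimal sketch, assuming `torchrun --nproc-per-node=N this_file.py` on GPUs.
# ToyEmbeddings is a made-up stand-in, not the module this commit touches.
import os

import torch
import torch.nn as nn
from torch.distributed._tensor import Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import PrepareModuleOutput, parallelize_module


class ToyEmbeddings(nn.Module):
    def __init__(self, vocab=128, dim=16, seq=8):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.register_buffer("freqs_cis", torch.randn(seq, dim // 2))

    def forward(self, tokens):
        # Returns a tuple, mirroring the (hidden states, freqs_cis) shape of
        # the embeddings module in the diff above.
        return self.tok_embeddings(tokens), self.freqs_cis


tp_size = int(os.environ.get("WORLD_SIZE", "1"))
tp_mesh = init_device_mesh("cuda", (tp_size,))

emb = parallelize_module(
    ToyEmbeddings().cuda(),
    tp_mesh,
    PrepareModuleOutput(
        # Leave the first output untouched (None); wrap the second output
        # (freqs_cis) as a DTensor replicated over the TP mesh.
        output_layouts=(None, Replicate()),
        desired_output_layouts=(None, Replicate()),
        # Keep it as a DTensor instead of unwrapping to a plain local tensor,
        # so downstream DTensor ops (e.g. the DTensor outputs of wq/wk/wv)
        # can consume it directly.
        use_local_output=False,
    ),
)

h, freqs = emb(torch.randint(0, 128, (2, 8), device="cuda"))
print(type(freqs))  # expected: a replicated DTensor on the TP mesh

With freqs_cis already a replicated DTensor and the q/k/v projections returning DTensors, reshapes inside attention can keep using the global n_heads/n_kv_heads, which is why the per-rank head-count adjustment is deleted in the last hunk.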