@@ -219,6 +219,7 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
     layer_tp_plan = {
         # Now the input and output of SequenceParallel has Shard(1) layouts,
         # to represent the input/output tensors sharded on the sequence dimension
+        "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
@@ -227,15 +228,14 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
         "attention.wk": ColwiseParallel(),
         "attention.wv": ColwiseParallel(),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-        "attention_norm": SequenceParallel(),
+        "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
         ),
         "feed_forward.w1": ColwiseParallel(),
         "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
         "feed_forward.w3": ColwiseParallel(),
-        "ffn_norm": SequenceParallel(),
     }
 
 
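For context, here is a minimal sketch (not part of this commit) of how a per-block plan like the one above is typically applied with ``parallelize_module``. The names ``model``, ``model.layers``, and ``tp_size``, the ``"cuda"`` device type, and the already-initialized distributed environment are assumptions for illustration, not something this diff defines.

```python
# Sketch only: assumes torch.distributed is already initialized and that
# `model` is a Llama-style module whose TransformerBlocks live in `model.layers`.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard  # torch.distributed._tensor on older releases
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

# 1-D device mesh spanning the tensor-parallel ranks; `tp_size` is assumed.
tp_mesh = init_device_mesh("cuda", (tp_size,))

layer_tp_plan = {
    # Norms keep activations sharded on the sequence dimension (Shard(1));
    # attention / feed_forward re-replicate their inputs before the matmuls.
    "attention_norm": SequenceParallel(),
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "ffn_norm": SequenceParallel(),
    "feed_forward": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}

# Apply the same per-block plan to every TransformerBlock in the model.
for transformer_block in model.layers:
    parallelize_module(transformer_block, tp_mesh, layer_tp_plan)
```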