@@ -153,18 +153,18 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                     input_layouts=Replicate(),
                 ),
                 "output": col_parallel_strategy(
-                    input_layouts=Shard(0),
+                    input_layouts=Shard(1),
                     output_layouts=(
                         Shard(-1)
                         if parallel_dims.loss_parallel_enabled
                         else Replicate()
                     ),
                     use_local_output=not parallel_dims.loss_parallel_enabled,
                 ),
-                "norm": SequenceParallel(sequence_dim=0),
+                "norm": SequenceParallel(),
                 "layers.0": PrepareModuleInput(
                     input_layouts=(Replicate(), None),
-                    desired_input_layouts=(Shard(0), None),
+                    desired_input_layouts=(Shard(1), None),
                     use_local_output=True,
                 ),
             },
@@ -174,22 +174,22 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         for layer_id, transformer_block in enumerate(model.layers):
             layer_plan = {
                 "attention": PrepareModuleInput(
-                    input_layouts=(Shard(0), None),
+                    input_layouts=(Shard(1), None),
                     desired_input_layouts=(Replicate(), None),
                 ),
                 "attention.wq": col_parallel_strategy(),
                 "attention.wk": col_parallel_strategy(),
                 "attention.wv": col_parallel_strategy(),
-                "attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
-                "attention_norm": SequenceParallel(sequence_dim=0),
+                "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
+                "attention_norm": SequenceParallel(),
                 "feed_forward": PrepareModuleInput(
-                    input_layouts=(Shard(0),),
+                    input_layouts=(Shard(1),),
                     desired_input_layouts=(Replicate(),),
                 ),
                 "feed_forward.w1": col_parallel_strategy(),
-                "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(0)),
+                "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
                 "feed_forward.w3": col_parallel_strategy(),
-                "ffn_norm": SequenceParallel(sequence_dim=0),
+                "ffn_norm": SequenceParallel(),
             }

             # Adjust attention module to use the local number of heads
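For context, a minimal sketch (not part of the commit, shapes made up) of what the dim 0 to dim 1 change means: the old placements treated dim 0 as the sequence dimension, which only holds for sequence-first [seq, batch, dim] activations. With batch-first [batch, seq, dim] activations the sequence dimension is dim 1, which lines up with Shard(1) and with SequenceParallel()'s default sequence_dim=1, so the explicit sequence_dim=0 arguments can be dropped. The plain chunk() below stands in for what a Shard(1) placement does to a local shard.

```python
import torch

# Hypothetical shapes, chosen only for illustration.
tp_degree = 4
x = torch.randn(8, 2048, 4096)          # batch-first activations: [batch, seq, dim]

# Shard(1): split along the sequence dimension (dim 1), as in the updated plan.
seq_shards = x.chunk(tp_degree, dim=1)  # each rank would hold [8, 512, 4096]

# Shard(0) would split along dim 0 instead, which is the sequence dimension
# only if activations were kept sequence-first ([seq, batch, dim]).
print(seq_shards[0].shape)              # torch.Size([8, 512, 4096])
```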