@@ -109,16 +109,17 @@
             input_layouts=Replicate(),
             output_layouts=Shard(1),
         ),
+        "norm": SequenceParallel(),
         "output": ColwiseParallel(
             input_layouts=Shard(1),
             output_layouts=Replicate()
         ),
-        "norm": SequenceParallel(),
     }
 )
 
 for layer_id, transformer_block in enumerate(model.layers):
     layer_tp_plan = {
+        "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
             input_layouts=(Shard(1), None),
             desired_input_layouts=(Replicate(), None),
@@ -127,15 +128,14 @@
         "attention.wk": ColwiseParallel(),
         "attention.wv": ColwiseParallel(),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-        "attention_norm": SequenceParallel(),
+        "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
             input_layouts=(Shard(1),),
             desired_input_layouts=(Replicate(),),
         ),
         "feed_forward.w1": ColwiseParallel(),
         "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
         "feed_forward.w3": ColwiseParallel(),
-        "ffn_norm": SequenceParallel(),
     }
 
     # Adjust attention module to use the local number of heads
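
For context, a minimal sketch of how plans like the ones above are applied with parallelize_module. This is not the exact code from this file: the 8-GPU mesh, the `model` / `model.layers` objects, and the truncated layer_tp_plan are assumptions based on the surrounding diff, and in older PyTorch releases Shard and Replicate are imported from torch.distributed._tensor rather than torch.distributed.tensor.

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

# Assumption: one 8-way tensor-parallel mesh over the local GPUs.
tp_mesh = init_device_mesh("cuda", (8,))

# Each per-block plan is applied to its transformer block; the top-level plan
# ("tok_embeddings" / "norm" / "output") is applied to `model` as a whole,
# as in the first hunk above.
for layer_id, transformer_block in enumerate(model.layers):
    layer_tp_plan = {
        "attention_norm": SequenceParallel(),
        "attention": PrepareModuleInput(
            input_layouts=(Shard(1), None),
            desired_input_layouts=(Replicate(), None),
        ),
        "attention.wk": ColwiseParallel(),
        "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
        # ... remaining entries as in the diff above ...
    }
    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan,
    )

SequenceParallel on the norm layers keeps their computation sharded along the sequence dimension, which is why the adjacent PrepareModuleInput entries redistribute the Shard(1) activations back to Replicate() before the attention and feed_forward modules.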
|