@@ -146,3 +146,150 @@ def forward(
             past_key_value,
             world_size=world_size,
         )
+
+
+class PatchedQwen2AttentionAscend(nn.Module):
+
+    def _load_weights(self, loader, rank: int, world_size: int,
+                      device: torch.device):
+        """load weights."""
+        for mod_name in ['q_proj', 'k_proj', 'v_proj']:
+            colwise_parallelize_linear(getattr(self, mod_name),
+                                       loader,
+                                       rank=rank,
+                                       world_size=world_size,
+                                       prefix=mod_name)
+        for mod_name in ['o_proj']:
+            rowwise_parallelize_linear(getattr(self, mod_name),
+                                       loader,
+                                       rank=rank,
+                                       world_size=world_size,
+                                       prefix=mod_name)
+
+    @classmethod
+    def _distribute_output_fn(cls, outputs, **kwargs):
+        """Distribution output hook."""
+        dist.all_reduce(outputs[0])
+        return outputs
+
+    def _contiguous_batching_forward_impl(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        world_size: int = 1,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[Tuple[torch.Tensor]]]:
+        """Rewrite implementation of forward.
+
+        Add continuous batching support. Add paged attention support. TP
+        support.
+        """
+        context = self.context.context
+        kv_seq_length = context.kv_seq_length
+        q_seq_length = context.q_seq_length
+        q_start_loc = context.q_start_loc
+        block_offsets = context.block_offsets
+        max_q_seq_length = context.max_q_seq_length
+        max_kv_seq_length = context.max_kv_seq_length
+
+        num_heads = self.num_heads // world_size
+        num_kv_heads = self.num_key_value_heads // world_size
+        head_dim = self.head_dim
+        hidden_size = num_heads * head_dim
+
+        def __qkv_proj(hidden_states):
+            """qkv proj."""
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+            return query_states, key_states, value_states
+
+        def __rotary_emb_fn(query_states, key_states, value_states):
+            if hasattr(self, 'rotary_emb'):
+                cos, sin = self.rotary_emb(value_states,
+                                           seq_len=max_kv_seq_length)
+                query_states, key_states = apply_rotary_pos_emb(
+                    query_states,
+                    key_states,
+                    cos,
+                    sin,
+                    position_ids,
+                    context.position_ids_1d,
+                    context=context)
+            return query_states, key_states, value_states
+
+        query_states, key_states, value_states = __qkv_proj(hidden_states)
+
+        query_states = query_states.view(-1, num_heads, head_dim)
+        key_states = key_states.view(-1, num_kv_heads, head_dim)
+        value_states = value_states.view(-1, num_kv_heads, head_dim)
+
+        query_states, key_states, value_states = __rotary_emb_fn(
+            query_states, key_states, value_states)
+
+        fill_kv_cache(
+            key_states,
+            value_states,
+            past_key_value[0],
+            past_key_value[1],
+            q_start_loc,
+            q_seq_length,
+            kv_seq_length=kv_seq_length,
+            max_q_seq_length=max_q_seq_length,
+            block_offsets=block_offsets,
+            context=context,
+        )
+
+        attn_output = query_states
+
+        use_sliding_windows = (getattr(self.config, 'sliding_window', None)
+                               is not None and self.config.use_sliding_window)
+        window_size = self.config.sliding_window
+        if not use_sliding_windows:
+            window_size = -1
+        paged_attention_fwd(
+            query_states,
+            key_states,
+            value_states,
+            past_key_value[0],
+            past_key_value[1],
+            attn_output,
+            block_offsets,
+            q_start_loc=q_start_loc,
+            q_seqlens=q_seq_length,
+            kv_seqlens=kv_seq_length,
+            max_seqlen=max_q_seq_length,
+            window_size=window_size,
+            context=context,
+        )
+
+        attn_output = attn_output.reshape(*hidden_states.shape[:-1],
+                                          hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[Tuple[torch.Tensor]]]:
+        """Rewrite of forward."""
+        world_size = 1
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+        return self._contiguous_batching_forward_impl(
+            hidden_states,
+            position_ids,
+            past_key_value,
+            world_size=world_size,
+        )
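A note on the tensor-parallel layout this patch relies on: `q_proj`/`k_proj`/`v_proj` are partitioned column-wise so each rank keeps `num_heads // world_size` heads, `o_proj` is partitioned row-wise, and summing the partial `o_proj` outputs reconstructs the full result, which is what the `all_reduce` in `_distribute_output_fn` performs across real ranks. The sketch below is a plain-PyTorch illustration of that arithmetic in a single process, not lmdeploy code; `num_heads`, `head_dim`, and `world_size` are toy values.

import torch
import torch.nn as nn

torch.manual_seed(0)

num_heads, head_dim, world_size = 8, 16, 2    # toy sizes, not Qwen2's real config
hidden_size = num_heads * head_dim
heads_per_rank = num_heads // world_size      # as in the patched forward
cols = heads_per_rank * head_dim

q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
x = torch.randn(4, hidden_size)               # [num_tokens, hidden_size]

with torch.no_grad():
    ref = o_proj(q_proj(x))                   # reference: no tensor parallelism
    partial_sum = torch.zeros_like(ref)
    for rank in range(world_size):
        # column-parallel slice: this rank's output features of q_proj
        w_q = q_proj.weight[rank * cols:(rank + 1) * cols, :]
        # row-parallel slice: the matching input features of o_proj
        w_o = o_proj.weight[:, rank * cols:(rank + 1) * cols]
        q_local = x @ w_q.t()                 # this rank's heads only
        partial_sum += q_local @ w_o.t()      # the sum stands in for all_reduce

print(torch.allclose(ref, partial_sum, atol=1e-5))  # True

In the patch itself, `hidden_size = num_heads * head_dim` is already the per-rank width, so `attn_output` is reshaped to the local slice before the row-parallel `o_proj`, and the output hook's `all_reduce` plays the role of the summation above.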