Commit 83a1105

feat: support ascend qwen2 and qwen2_moe (#6)
* feat: support ascend qwen2 and qwen2_moe
* fix: fix ascend mixtral
1 parent: 455786b · commit: 83a1105

File tree: 4 files changed (+215, −0 lines)

lmdeploy/pytorch/models/mixtral.py

Lines changed: 1 addition & 0 deletions
@@ -292,6 +292,7 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
                Optional[Tuple[torch.Tensor]]]:
         """Rewrite of MistralAttention.forward."""

lmdeploy/pytorch/models/module_map.py

Lines changed: 22 additions & 0 deletions
@@ -394,3 +394,25 @@
     'transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock':
     f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mixtral.PatchedMixtralSparseMoeBlockAscend',  # noqa: E501
 })
+
+# ascend qwen1.5
+ASCEND_MODULE_MAP.update({
+    'transformers.models.qwen2.modeling_qwen2.Qwen2Attention':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend',
+    'transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend',
+    'transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend',
+})
+
+# ascend qwen2 moe
+ASCEND_MODULE_MAP.update({
+    'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeAttention':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend',
+    'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeFlashAttention2':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend',
+    'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSdpaAttention':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2.PatchedQwen2AttentionAscend',
+    'transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_moe.PatchedQwen2MoeSparseMoeBlockAscend',  # noqa: E501
+})
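
These entries route the three transformers attention variants (eager, FlashAttention2, SDPA) for Qwen2 and Qwen2-MoE to the same patched Ascend attention class, and the Qwen2-MoE sparse block to its own Ascend rewrite. A minimal sketch of how such a dotted-path map can be resolved, assuming a hypothetical resolver (lmdeploy's real patching machinery differs):

import importlib

def resolve_patched(module_map: dict, origin_qualname: str):
    """Look up the patched class registered for an original class path."""
    target = module_map[origin_qualname]
    mod_name, _, cls_name = target.rpartition('.')
    # import the module that holds the patched class and return the class itself
    return getattr(importlib.import_module(mod_name), cls_name)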

lmdeploy/pytorch/models/qwen2.py

Lines changed: 147 additions & 0 deletions
@@ -146,3 +146,150 @@ def forward(
             past_key_value,
             world_size=world_size,
         )
+
+
+class PatchedQwen2AttentionAscend(nn.Module):
+
+    def _load_weights(self, loader, rank: int, world_size: int,
+                      device: torch.device):
+        """load weights."""
+        for mod_name in ['q_proj', 'k_proj', 'v_proj']:
+            colwise_parallelize_linear(getattr(self, mod_name),
+                                       loader,
+                                       rank=rank,
+                                       world_size=world_size,
+                                       prefix=mod_name)
+        for mod_name in ['o_proj']:
+            rowwise_parallelize_linear(getattr(self, mod_name),
+                                       loader,
+                                       rank=rank,
+                                       world_size=world_size,
+                                       prefix=mod_name)
+
+    @classmethod
+    def _distribute_output_fn(cls, outputs, **kwargs):
+        """Distribution output hook."""
+        dist.all_reduce(outputs[0])
+        return outputs
+
+    def _contiguous_batching_forward_impl(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        world_size: int = 1,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[Tuple[torch.Tensor]]]:
+        """Rewrite implementation of forward.
+
+        Add continuous batching support. Add paged attention support. TP
+        support.
+        """
+        context = self.context.context
+        kv_seq_length = context.kv_seq_length
+        q_seq_length = context.q_seq_length
+        q_start_loc = context.q_start_loc
+        block_offsets = context.block_offsets
+        max_q_seq_length = context.max_q_seq_length
+        max_kv_seq_length = context.max_kv_seq_length
+
+        num_heads = self.num_heads // world_size
+        num_kv_heads = self.num_key_value_heads // world_size
+        head_dim = self.head_dim
+        hidden_size = num_heads * head_dim
+
+        def __qkv_proj(hidden_states):
+            """qkv proj."""
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+            return query_states, key_states, value_states
+
+        def __rotary_emb_fn(query_states, key_states, value_states):
+            if hasattr(self, 'rotary_emb'):
+                cos, sin = self.rotary_emb(value_states,
+                                           seq_len=max_kv_seq_length)
+                query_states, key_states = apply_rotary_pos_emb(
+                    query_states,
+                    key_states,
+                    cos,
+                    sin,
+                    position_ids,
+                    context.position_ids_1d,
+                    context=context)
+            return query_states, key_states, value_states
+
+        query_states, key_states, value_states = __qkv_proj(hidden_states)
+
+        query_states = query_states.view(-1, num_heads, head_dim)
+        key_states = key_states.view(-1, num_kv_heads, head_dim)
+        value_states = value_states.view(-1, num_kv_heads, head_dim)
+
+        query_states, key_states, value_states = __rotary_emb_fn(
+            query_states, key_states, value_states)
+
+        fill_kv_cache(
+            key_states,
+            value_states,
+            past_key_value[0],
+            past_key_value[1],
+            q_start_loc,
+            q_seq_length,
+            kv_seq_length=kv_seq_length,
+            max_q_seq_length=max_q_seq_length,
+            block_offsets=block_offsets,
+            context=context,
+        )
+
+        attn_output = query_states
+
+        use_sliding_windows = (getattr(self.config, 'sliding_window', None)
+                               is not None and self.config.use_sliding_window)
+        window_size = self.config.sliding_window
+        if not use_sliding_windows:
+            window_size = -1
+        paged_attention_fwd(
+            query_states,
+            key_states,
+            value_states,
+            past_key_value[0],
+            past_key_value[1],
+            attn_output,
+            block_offsets,
+            q_start_loc=q_start_loc,
+            q_seqlens=q_seq_length,
+            kv_seqlens=kv_seq_length,
+            max_seqlen=max_q_seq_length,
+            window_size=window_size,
+            context=context,
+        )
+
+        attn_output = attn_output.reshape(*hidden_states.shape[:-1],
+                                          hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[Tuple[torch.Tensor]]]:
+        """Rewrite of forward."""
+        world_size = 1
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+        return self._contiguous_batching_forward_impl(
+            hidden_states,
+            position_ids,
+            past_key_value,
+            world_size=world_size,
+        )
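
PatchedQwen2AttentionAscend flattens the batch into (num_tokens, num_heads, head_dim), applies rotary position embeddings, writes the new keys and values into the paged KV cache with fill_kv_cache, and then runs paged_attention_fwd (with the sliding window disabled unless config.use_sliding_window is set). As a hedged reference for the rotary step only, here is a plain-PyTorch sketch of the conventional rotate-half formula; the apply_rotary_pos_emb used above takes extra position and context arguments and may be fused on Ascend:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the last-dimension halves (x1, x2) to (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_reference(q: torch.Tensor, k: torch.Tensor,
                           cos: torch.Tensor, sin: torch.Tensor):
    """Rotary embedding for flattened q/k shaped (num_tokens, heads, head_dim);
    cos/sin are per-token tables shaped (num_tokens, head_dim)."""
    cos = cos.unsqueeze(1)  # broadcast over the head axis
    sin = sin.unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin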

lmdeploy/pytorch/models/qwen2_moe.py

Lines changed: 45 additions & 0 deletions
@@ -7,6 +7,8 @@
 from torch import nn
 
 from lmdeploy.pytorch.kernels.fused_moe import fused_moe
+from lmdeploy.pytorch.kernels.moe_gating_topk_softmax import \
+    moe_gating_topk_softmax
 
 
 class PatchedQwen2MoeSparseMoeBlock(nn.Module):
@@ -90,6 +92,49 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return out_states, router_logits
 
 
+class PatchedQwen2MoeSparseMoeBlockAscend(nn.Module):
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """"""
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights, selected_experts = moe_gating_topk_softmax(
+            router_logits, self.top_k)
+        if self.norm_topk_prob:
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        out_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device)
+
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(
+                current_state) * routing_weights[top_x, idx, None]
+
+            out_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype))
+
+        shared_expert_output = self.shared_expert(hidden_states)
+        shared_expert_output = F.sigmoid(
+            self.shared_expert_gate(hidden_states)) * shared_expert_output
+
+        out_states = out_states + shared_expert_output
+        out_states = out_states.unflatten(0, (-1, sequence_length))
+
+        return out_states, router_logits
+
+
 class PatchedQwen2MoeModel(nn.Module):
 
     def _continuous_batching_forward(
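
The Ascend MoE block swaps the fused_moe kernel for moe_gating_topk_softmax routing and then dispatches tokens to experts with a one-hot mask before adding the gated shared-expert branch. As a hedged reference for the gating step only, a plain-PyTorch sketch of softmax-then-top-k routing in the style of the upstream Qwen2-MoE block (the Ascend kernel may fuse or reorder this):

import torch

def moe_gating_topk_softmax_reference(router_logits: torch.Tensor, top_k: int):
    """Softmax over experts, then keep the top_k weights per token.
    Returns (routing_weights, selected_experts), each shaped (num_tokens, top_k)."""
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(probs, top_k, dim=-1)
    return routing_weights, selected_experts

When norm_topk_prob is set, the patched block then renormalizes these weights so they sum to one per token, as in the diff above.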
