7 changes: 5 additions & 2 deletions lmdeploy/pytorch/engine/devices/ascend.py
@@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from .dipu import DIPUDeviceUtils
from .base_device_utils import BaseDeviceUtils


class ASCENDDeviceUtils(DIPUDeviceUtils):
class ASCENDDeviceUtils(BaseDeviceUtils):

device = 'ascend'

@@ -38,4 +38,7 @@ def update_step_context(cls, step_context):
kv_start_indices, device=step_context.block_offsets.device)
setattr(step_context, 'kv_start_indices', kv_start_indices)
setattr(step_context, 'attention_mask', attention_mask)
is_unpaged_prefill = (not step_context.is_decoding) and all(
(step_context.q_seq_length == step_context.kv_seq_length).tolist())
setattr(step_context, 'is_unpaged_prefill', is_unpaged_prefill)
return step_context
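
The new flag marks prefill steps in which every request's query length equals its key/value length, i.e. no request carries a cached prefix in the paged KV cache. A minimal sketch of the check with illustrative tensors (not repo code; values are placeholders):

# Minimal sketch: how is_unpaged_prefill is derived for a prefill batch.
# Tensor values are illustrative only.
import torch

q_seq_length = torch.tensor([5, 7, 3])   # new tokens per request
kv_seq_length = torch.tensor([5, 7, 3])  # total KV length per request
is_decoding = False

# True only when q_len == kv_len for every sequence in the batch,
# i.e. there is no cached history to gather from the paged KV cache.
is_unpaged_prefill = (not is_decoding) and all(
    (q_seq_length == kv_seq_length).tolist())
print(is_unpaged_prefill)  # True; any cached prefix would make it False
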
13 changes: 0 additions & 13 deletions lmdeploy/pytorch/engine/devices/dipu.py

This file was deleted.

2 changes: 1 addition & 1 deletion lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py
@@ -27,7 +27,7 @@ def apply_rotary_pos_emb(
cached_cos = context.cos if context else cos
cached_sin = context.sin if context else sin
ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
cached_cos, cached_sin, None, None, None)
cached_cos, cached_sin, None, None)
if q_embed is None:
q_embed = query_states
else:
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py
@@ -33,7 +33,7 @@ def fused_rotary_emb(
cached_cos = context.cos if context else cos
cached_sin = context.sin if context else sin
ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
cached_cos, cached_sin, None, None, None)
cached_cos, cached_sin, None, None)
if out_q is None:
out_q = query_states
else:
66 changes: 33 additions & 33 deletions lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py
@@ -21,43 +21,41 @@ def flash_context_attention(
):
num_q_heads, dim = query_states.shape[1:3]
num_kv_heads = value_states.shape[1]
batch = q_start_loc.shape[0]

for i in range(batch):
if torch.equal(q_seq_len[i], kv_seq_len[i]):
ext_ops.context_attention(
query_states,
key_states,
value_states,
q_start_loc[i:i + 1],
q_seq_len[i:i + 1],
num_q_heads,
num_kv_heads,
attn_mask=context.attention_mask[i:i + 1],
attn_output=attn_output,
)
else:
key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
value_cache = value_cache.reshape(1, kv_cache_len,
num_kv_heads * dim)
ext_ops.paged_prefill_attention(
query_states,
key_cache,
value_cache,
block_offsets,
block_size,
q_start_loc[i:i + 1],
q_seq_len[i:i + 1],
kv_seq_len[i:i + 1],
num_q_heads,
num_kv_heads,
attn_mask=context.attention_mask[i:i + 1],
attn_output=attn_output,
)
if context.is_unpaged_prefill:
ext_ops.prefill_attention(
query_states,
key_states,
value_states,
q_start_loc,
q_seq_len,
context.max_q_seq_length,
num_q_heads,
num_kv_heads,
attn_mask=context.attention_mask,
attn_output=attn_output,
)
else:
key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
ext_ops.paged_prefill_attention(
query_states,
key_cache,
value_cache,
block_offsets,
block_size,
q_start_loc,
q_seq_len,
kv_seq_len,
num_q_heads,
num_kv_heads,
attn_mask=context.attention_mask,
attn_output=attn_output,
)


def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
block_offsets, block_size):
max_kv_seq_len, block_offsets, block_size):
num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
ext_ops.paged_decode_attention(
q,
@@ -66,6 +64,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
block_offsets,
block_size,
kv_seq_len,
max_kv_seq_len,
num_q_heads,
num_kv_heads,
attn_output=attn_output.view(q.shape),
@@ -115,6 +114,7 @@ def paged_attention_fwd(
v,
attn_output,
kv_seqlens,
context.max_kv_seq_length,
block_offsets,
block_size,
)
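
With the flag available on the step context, the prefill path branches once per batch instead of deciding per request inside a Python loop, and the decode path now forwards the pre-computed maximum KV length. A rough, stubbed illustration of that control flow (toy values; the *_stub functions merely stand in for the ext_ops calls and are not repo code):

# Illustration only: batch-level dispatch replacing the per-request loop.
import torch

def prefill_attention_stub(n):        # stands in for ext_ops.prefill_attention
    return f'unpaged prefill, {n} requests, single call'

def paged_prefill_attention_stub(n):  # stands in for ext_ops.paged_prefill_attention
    return f'paged prefill, {n} requests, single call'

q_seq_len = torch.tensor([4, 6, 5])
kv_seq_len = torch.tensor([4, 9, 5])   # middle request has cached history
is_unpaged_prefill = all((q_seq_len == kv_seq_len).tolist())

if is_unpaged_prefill:
    print(prefill_attention_stub(q_seq_len.numel()))
else:
    print(paged_prefill_attention_stub(q_seq_len.numel()))

# Decode side: the kernel now receives the max KV length explicitly
# (passed as context.max_kv_seq_length in this PR) rather than deriving it.
max_kv_seq_len = int(kv_seq_len.max())
print('max_kv_seq_len forwarded to the decode kernel:', max_kv_seq_len)
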
16 changes: 0 additions & 16 deletions lmdeploy/pytorch/kernels/dipu/__init__.py

This file was deleted.

29 changes: 0 additions & 29 deletions lmdeploy/pytorch/kernels/dipu/apply_rotary_pos_emb.py

This file was deleted.

27 changes: 0 additions & 27 deletions lmdeploy/pytorch/kernels/dipu/fill_kv_cache.py

This file was deleted.

36 changes: 0 additions & 36 deletions lmdeploy/pytorch/kernels/dipu/fused_rotary_emb.py

This file was deleted.
