From fd3ad6338919f5cf95f3f9a3a4fb7ca033457f5e Mon Sep 17 00:00:00 2001 From: zhoushenglong Date: Wed, 26 Jun 2024 22:14:27 +0000 Subject: [PATCH 1/2] add paged_attention impl. --- .../dicp/vendor/AscendGraph/codegen/ascend.py | 9 ++++-- dicp/dicp/vendor/AscendGraph/conversion.py | 23 +++++++++++++++ dicp/dicp/vendor/AscendGraph/ext_ops.py | 29 +++++++++++++++++++ 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 6dc8f107e..9bc032b2a 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -1609,15 +1609,20 @@ def PromptFlashAttention(name, q, k, v, head_num, seqlen, mask, head_dim, num_ke return op.to_node() @staticmethod - def IncreFlashAttention(name, q, k_list, v_list, kv_input_num, kv_head_num, head_num, dim, input_layout="BSH"): + def IncreFlashAttention(name, q, k_list, v_list, kv_input_num, head_num, kv_head_num, dim, input_layout="BSH", block_table=None, seq_lengths=None, block_size=128): op = OP(name, "IncreFlashAttention") op.set_input("query", q) op.set_dynamic_input("key", kv_input_num, k_list) op.set_dynamic_input("value", kv_input_num, v_list) op.set_attr_int("num_heads", head_num) - op.set_attr_float("scale_value", float(1 / math.sqrt(dim))) + if not block_table: + op.set_attr_float("scale_value", float(1 / math.sqrt(dim))) op.set_attr_int("num_key_value_heads", kv_head_num) op.set_attr_str("input_layout", input_layout) + if block_table: + op.set_input("block_table", block_table) + op.set_input("actual_seq_lengths", seq_lengths) + op.set_attr_int("block_size", block_size) return op.to_node() @staticmethod diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 5ab6f3236..81e0a6f84 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1721,6 +1721,8 @@ def lightllm_rotary_emb(self, x, cos, sin): seq_len = x_shape[0] dim = x_shape[2] + if isinstance(dim, torch.fx.proxy.Proxy): + dim = int(sympy.N(dim.node.meta['val'])) cos_sin_shape = self.get_shape_proxy([seq_len, 1, dim // 2]) cos = self.get_proxy(ascend_op.Reshape, (cos, cos_sin_shape)) @@ -1813,6 +1815,11 @@ def copy_with_offset(self, x, src, start_dim, end_dim): src = self.get_proxy(ascend_op.Cast, (src, get_ascend_dtype(src_dtype))) return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src)) + @register_conversion(torch.ops.lightllm.copy_with_index.default) + def copy_with_index(self, x, src, dims): + dims = self.get_proxy(ascend_op.Unsqueeze, (dims, [-1])) + return self.get_proxy(ascend_op.ScatterNdUpdate, (x, dims, src)) + @register_conversion(torch.ops.lightllm.flash_attention_inference.default) def flash_attention_inference(self, q, all_k, all_v, current_lens, max_len, kvhead=-1, head=-1, dim=-1): q_shape = list(q.node.meta['val'].shape) @@ -1876,3 +1883,19 @@ def flash_attention_inference(self, q, all_k, all_v, current_lens, max_len, kvhe res = self.get_proxy(ascend_op.ConcatD, (res, 0)) return res + + @register_conversion(torch.ops.lightllm.paged_attention_inference.default) + def paged_attention_inference(self, q, all_k, all_v, q_head_num, dim, kv_head_num, block_table=None, seq_lengths=None, block_size=128): + if isinstance(q_head_num, torch.fx.proxy.Proxy): + q_head_num = int(sympy.N(q_head_num.node.meta['val'])) + if isinstance(dim, torch.fx.proxy.Proxy): + dim = int(sympy.N(dim.node.meta['val'])) + if isinstance(kv_head_num, 
torch.fx.proxy.Proxy): + kv_head_num = int(sympy.N(kv_head_num.node.meta['val'])) + q = self.get_proxy(ascend_op.Unsqueeze, (q, [1])) + all_k = self.get_proxy(ascend_op.Unsqueeze, (all_k, [1])) + all_v = self.get_proxy(ascend_op.Unsqueeze, (all_v, [1])) + out = self.get_proxy(ascend_op.IncreFlashAttention, (q, [all_k], [all_v], 1, + q_head_num, kv_head_num, dim, "BSH", block_table, seq_lengths, block_size)) + out = self.get_proxy(ascend_op.Squeeze, (out, [1])) + return out diff --git a/dicp/dicp/vendor/AscendGraph/ext_ops.py b/dicp/dicp/vendor/AscendGraph/ext_ops.py index 95ac8e26c..503c6ca45 100644 --- a/dicp/dicp/vendor/AscendGraph/ext_ops.py +++ b/dicp/dicp/vendor/AscendGraph/ext_ops.py @@ -132,6 +132,20 @@ def lightllm_flash_attention_inference_impl(q, all_k, all_v, current_lens, max_l res = torch.cat(res) return res +@torch._custom_op.impl.custom_op('lightllm::paged_attention_inference') +def paged_attention_inference(q: Tensor, all_k: Tensor, all_v: Tensor, q_head_num: int, dim: int, kv_head_num: int, block_table: Tensor, seq_lengths: Tensor, block_size: int) -> Tensor: + ... + + +@paged_attention_inference.impl_abstract() +def lightllm_paged_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: Tensor, q_head_num: int, dim: int, kv_head_num: int, block_table: Tensor, seq_lengths: Tensor, block_size: int): + return torch.empty_like(q) + +@paged_attention_inference.impl(['cpu', 'cuda']) +def lightllm_paged_attention_inference_impl(q, all_k, all_v, q_head_num, dim, kv_head_num, block_table, seq_lengths, block_size): + # fake impl + return q + @torch._custom_op.impl.custom_op('lightllm::copy_with_offset') def copy_with_offset(x: Tensor, src: Tensor, start_dim: int, end_dim: int) -> Tensor: @@ -147,3 +161,18 @@ def lightllm_copy_with_offset_abstract(x: Tensor, src: Tensor, start_dim: int, e def lightllm_copy_with_offset_impl(x, src, start_dim, end_dim) -> Tensor: x[start_dim:end_dim] = src return x + +@torch._custom_op.impl.custom_op('lightllm::copy_with_index') +def copy_with_index(x: Tensor, src: Tensor, index: Tensor) -> Tensor: + ... + + +@copy_with_index.impl_abstract() +def lightllm_copy_with_index_abstract(x: Tensor, src: Tensor, index: Tensor) -> Tensor: + return x + + +@copy_with_index.impl(['cpu', 'cuda']) +def lightllm_copy_with_index_impl(x, src, index) -> Tensor: + x[index] = src + return x From 050f397ef3b2bff866d5f5cc882661d106088c69 Mon Sep 17 00:00:00 2001 From: zhoushenglong Date: Mon, 1 Jul 2024 19:01:19 +0000 Subject: [PATCH 2/2] add paged_attention test. 
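
This patch replaces the placeholder CPU/CUDA implementation of
lightllm::paged_attention_inference with a reference computation
(softmax(q @ k^T / sqrt(dim)) @ v per batch element) and adds an op-level test
against the Ascend graph path. Below is a minimal invocation sketch mirroring
the first test case; the tensor contents are random and purely illustrative,
seq_lengths is passed as a tensor (matching the declared schema), and it assumes
ext_ops can be imported without an Ascend device, as the CPU test path does:

    import torch
    from dicp.vendor.AscendGraph import ext_ops  # registers the lightllm custom ops

    q_head_num, kv_head_num, dim, block_size = 8, 8, 16, 128
    q = torch.randn(1, q_head_num, dim)                     # (batch, q_head_num, dim)
    all_k = torch.randn(10, kv_head_num, dim)               # (kv_len, kv_head_num, dim)
    all_v = torch.randn(10, kv_head_num, dim)
    block_table = torch.tensor([[0]], dtype=torch.int32)    # one KV block per batch entry
    seq_lengths = torch.tensor([10], dtype=torch.int64)     # current KV length per batch entry

    out = torch.ops.lightllm.paged_attention_inference.default(
        q, all_k, all_v, q_head_num, dim, kv_head_num, block_table, seq_lengths, block_size)
    # out: (batch, q_head_num, dim) == (1, 8, 16)

Note that the reference implementation ignores block_table and block_size and
attends over the first seq_lengths[i] rows of all_k/all_v directly; it also
assumes q_head_num == kv_head_num, which is why the test repeats k and v along
the head dimension in the grouped-query case.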
--- dicp/dicp/vendor/AscendGraph/conversion.py | 6 +- dicp/dicp/vendor/AscendGraph/ext_ops.py | 33 +++++++++- dicp/test/op/test_lightllm_paged_attention.py | 61 +++++++++++++++++++ 3 files changed, 95 insertions(+), 5 deletions(-) create mode 100644 dicp/test/op/test_lightllm_paged_attention.py diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 81e0a6f84..56efe828b 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1766,7 +1766,7 @@ def prompt_attention_inference(self, q, k, v, seqlen, num_head, head_dim, num_ke return self.get_proxy(ascend_op.Cast, (fa, get_ascend_dtype(q_dtype))) return fa - def incre_flash_attention(self, q, k, v, kv_head_num, head_num, dim): + def incre_flash_attention(self, q, k, v, head_num, kv_head_num, dim): k_list = [] v_list = [] if not isinstance(k, list): @@ -1779,7 +1779,7 @@ def incre_flash_attention(self, q, k, v, kv_head_num, head_num, dim): v_list = v assert len(k_list) == len(v_list) kv_input_num = len(k_list) - out = self.get_proxy(ascend_op.IncreFlashAttention, (q, k_list, v_list, kv_input_num, kv_head_num, head_num, dim, "BSH")) + out = self.get_proxy(ascend_op.IncreFlashAttention, (q, k_list, v_list, kv_input_num, head_num, kv_head_num, dim, "BSH")) return out @register_conversion(aten.select_scatter.default) @@ -1874,7 +1874,7 @@ def flash_attention_inference(self, q, all_k, all_v, current_lens, max_len, kvhe xq = self.get_proxy(ascend_op.Reshape, (xq, q_shape)) xq = self.get_proxy(ascend_op.Reshape, (xq, q_compute_shape)) - out = self.incre_flash_attention(xq, k, v, kvhead, head, dim) # q shape is BSH + out = self.incre_flash_attention(xq, k, v, head, kvhead, dim) # q shape is BSH out_shape = self.get_shape_proxy([compute_batch, 1, head, dim]) out_shape2 = self.get_shape_proxy([compute_batch, head, dim]) out = self.get_proxy(ascend_op.Reshape, (out, out_shape)) diff --git a/dicp/dicp/vendor/AscendGraph/ext_ops.py b/dicp/dicp/vendor/AscendGraph/ext_ops.py index 503c6ca45..da874a60b 100644 --- a/dicp/dicp/vendor/AscendGraph/ext_ops.py +++ b/dicp/dicp/vendor/AscendGraph/ext_ops.py @@ -143,8 +143,37 @@ def lightllm_paged_attention_inference_abstract(q: Tensor, all_k: Tensor, all_v: @paged_attention_inference.impl(['cpu', 'cuda']) def lightllm_paged_attention_inference_impl(q, all_k, all_v, q_head_num, dim, kv_head_num, block_table, seq_lengths, block_size): - # fake impl - return q + # q: batch, head, dim + batch = q.shape[0] + head = q_head_num + current_lens = seq_lengths + + res = [] + compute_batch = 1 + for i in range(batch): + current_len = current_lens[i] + kv_seq_len = current_len + + k = all_k[:current_len].reshape(compute_batch, kv_seq_len, head, dim) + v = all_v[:current_len].reshape(compute_batch, kv_seq_len, head, dim) + + xq = q[i].view(compute_batch, 1, head, dim).transpose(1, 2).transpose(0, 1) # shape: head, batch, 1, dim + bmm_xq = xq.reshape(head * compute_batch, 1, dim).float() + bmm_xk = k.transpose(1, 2).transpose(0, 1).transpose(2, 3).reshape(head * compute_batch, dim, kv_seq_len).float() + + # q @ k + out = torch.bmm(bmm_xq, bmm_xk) / math.sqrt(dim) + out = out.reshape(head, compute_batch, 1, -1).reshape(head, compute_batch, -1) + + # softmax + out = out.softmax(-1).reshape(head, compute_batch, 1, kv_seq_len).transpose(0, 1) # shape: batch head 1 seq_len + xv = v.transpose(1, 2).float() # shape: batch head, seq_len, dim + out = torch.bmm(out.reshape(compute_batch * head, 1, kv_seq_len), xv.reshape(compute_batch * 
head, kv_seq_len, dim))
+
+        out = out.reshape(compute_batch, head, 1, dim).view(compute_batch, head, dim)
+        res.append(out)
+    res = torch.cat(res)
+    return res
 
 
 @torch._custom_op.impl.custom_op('lightllm::copy_with_offset')
diff --git a/dicp/test/op/test_lightllm_paged_attention.py b/dicp/test/op/test_lightllm_paged_attention.py
new file mode 100644
index 000000000..78145630c
--- /dev/null
+++ b/dicp/test/op/test_lightllm_paged_attention.py
@@ -0,0 +1,61 @@
+import pytest
+
+from dicp.vendor.AscendGraph import ext_ops
+from ..common.utils import (
+    torch,
+    dynamo,
+    parse_args,
+    compile_model,
+    get_device,
+    Size,
+    update_dynamo_config,
+)
+
+
+class OpModule(torch.nn.Module):
+    def forward(self, q, all_k, all_v, q_head_num, dim, kv_head_num, block_table, seq_lengths, block_size):
+        res = torch.ops.lightllm.paged_attention_inference.default(q, all_k, all_v, q_head_num, dim, kv_head_num, block_table, seq_lengths, block_size)
+        return res
+
+
+model = OpModule()
+args = parse_args()
+compiled_model = compile_model(model, args.backend, args.dynamic)
+
+
+class TestLightllmPagedAttention():
+    @pytest.mark.parametrize("dtype", [torch.float32])
+    @pytest.mark.parametrize("sizes", [Size(((10,), (8, 16), (8, 16)), ((10,), (8, 16), (8, 16))), Size(((10,), (16, 32), (2, 32)), ((10,), (16, 32), (2, 32)))])
+    @pytest.mark.parametrize("compiled_model", compiled_model)
+    def test_lightllm_paged_attention(self, sizes, dtype, compiled_model):
+        device = get_device()
+        size = sizes.dynamic if compiled_model.dynamic else sizes.static
+
+        q = torch.randn((1,) + size[1], dtype=dtype)
+        k = torch.randn(size[0] + size[2], dtype=dtype)
+        v = torch.randn(size[0] + size[2], dtype=dtype)
+
+        q_head_num = size[1][0]
+        dim = size[1][1]
+        kv_head_num = size[2][0]
+        block_table = torch.tensor([[0]], dtype=torch.int32)
+        seq_lengths = list(size[0])
+        block_size = 128
+
+        dicp_q = q.to(device)
+        dicp_k = k.to(device)
+        dicp_v = v.to(device)
+        dicp_block_table = block_table.to(device)
+        dicp_seq_lengths = torch.tensor([seq_lengths], device=device, dtype=torch.int64)
+
+        if q_head_num != kv_head_num:
+            repeat = q_head_num // kv_head_num  # integer head ratio; Tensor.repeat requires ints
+            k = k.repeat(1, repeat, 1)
+            v = v.repeat(1, repeat, 1)
+
+        output = model(q, k, v, q_head_num, dim, kv_head_num, block_table, seq_lengths, block_size).half().reshape(-1, q_head_num, dim)
+        dynamo.reset()
+        update_dynamo_config(compiled_model.dynamic)
+        dicp_output = compiled_model.model(dicp_q, dicp_k, dicp_v, q_head_num, dim, kv_head_num, dicp_block_table, dicp_seq_lengths, block_size).reshape(-1, q_head_num, dim)
+
+        assert torch.allclose(output, dicp_output.cpu(), rtol=1e-02, atol=1e-02, equal_nan=True)
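
For reference, a minimal CPU sketch of the lightllm::copy_with_index op added
in the first patch (the values are illustrative; the registered CPU/CUDA
implementation is simply x[index] = src, and on the Ascend graph path the op
lowers to ScatterNdUpdate with the index unsqueezed to shape (n, 1)):

    import torch
    from dicp.vendor.AscendGraph import ext_ops  # registers the lightllm custom ops

    x = torch.zeros(6, 4)
    src = torch.randn(2, 4)
    index = torch.tensor([1, 4])
    out = torch.ops.lightllm.copy_with_index.default(x, src, index)
    # rows 1 and 4 of x now hold the two rows of src
    assert torch.equal(out[1], src[0]) and torch.equal(out[4], src[1])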