Skip to content

Commit f14a295

Browse files
authored
新增API:paddle.nn.functional.sparse_attention (#3958) (#4012)
* add sparse_attention docs * fix some bug * modify the docs
1 parent 5a2e342 commit f14a295

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
.. _cn_api_sparse_attention:
2+
sparse_attention
3+
-------------------------------
4+
5+
.. py:function:: paddle.nn.functional.sparse_attention(query, key, value, sparse_csr_offset, sparse_csr_columns, name=None)
6+
7+
8+
该OP对Transformer模块中的Attention矩阵进行了稀疏化,从而减少内存消耗和计算量。
9+
10+
其稀疏数据排布通过CSR格式表示,CSR格式包含两个参数, ``offset`` 和 ``colunms`` 。计算公式为:
11+
12+
.. math::
13+
result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
14+
15+
其中,``Q``,``K``,``V`` 表示注意力模块的三个输入参数。这三个参数的维度是一样的。 ``d`` 代表这三个参数的最后一个维度的大小。
16+
17+
参数:
18+
:::::::::
19+
- query (Tensor) - 输入的Tensor,代表注意力模块中的 ``query`` ,这是一个4维Tensor,形状为 :[batch_size, num_heads, seq_len, head_dim],数据类型为float32或float64。
20+
- key (Tensor) - 输入的Tensor,代表注意力模块中的 ``key`` ,这是一个4维Tensor,形状为 :[batch_size, num_heads, seq_len, head_dim],数据类型为float32或float64。
21+
- value (Tensor) - 输入的Tensor,代表注意力模块中的 ``value`` ,这是一个4维Tensor,形状为 :[batch_size, num_heads, seq_len, head_dim],数据类型为float32或float64。
22+
- sparse_csr_offset (Tensor) - 输入的Tensor,注意力模块中的稀疏特性,稀疏特性使用CSR格式表示, ``offset`` 代表矩阵中每一行非零元的数量。这是一个3维Tensor,形状为 :[batch_size, num_heads, seq_len + 1],数据类型为int32。
23+
- sparse_csr_columns (Tensor) - 输入的Tensor,注意力模块中的稀疏特性,稀疏特性使用CSR格式表示, ``colunms`` 代表矩阵中每一行非零元的列索引值。这是一个3维Tensor,形状为 :[batch_size, num_heads, sparse_nnz],数据类型为int32。
24+
25+
返回:
26+
:::::::::
27+
``Tensor`` ,代表注意力模块的结果。这是一个4维Tensor,形状为 :[batch_size, num_heads, seq_len, head_dim],数据类型为float32或float64。
28+
29+
代码示例
30+
::::::::::
31+
32+
.. code-block:: python
33+
34+
import paddle
35+
import numpy as np
36+
37+
query_data = np.array([[[[0, 1,], [2, 3],
38+
[ 0, 1], [2, 3]]]]).astype("float32")
39+
key_data = np.array([[[[0, 1,], [2, 3],
40+
[ 0, 1], [2, 3]]]]).astype("float32")
41+
value_data = np.array([[[[0, 1,], [2, 3],
42+
[ 0, 1], [2, 3]]]]).astype("float32")
43+
sparse_csr_offset_data = np.array([[[0, 2,
44+
4, 6, 8]]]).astype("int32")
45+
sparse_csr_columns_data = np.array([[[0, 1,
46+
0, 1, 2, 3, 2, 3]]]).astype("int32")
47+
print(query_data.shape)
48+
# (1, 1, 4, 2)
49+
print(sparse_csr_offset_data.shape)
50+
# (1, 1, 5)
51+
print(sparse_csr_columns_data.shape)
52+
# (1, 1, 8)
53+
paddle.disable_static()
54+
query = paddle.to_tensor(query_data, stop_gradient=False,
55+
place=paddle.CUDAPlace(0))
56+
key = paddle.to_tensor(key_data, stop_gradient=False,
57+
place=paddle.CUDAPlace(0))
58+
value = paddle.to_tensor(value_data, stop_gradient=False,
59+
place=paddle.CUDAPlace(0))
60+
offset = paddle.to_tensor(sparse_csr_offset_data, stop_gradient=False,
61+
place=paddle.CUDAPlace(0))
62+
columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False,
63+
place=paddle.CUDAPlace(0))
64+
output = paddle.nn.functional.sparse_attention(query, key,
65+
value, offset, columns)
66+
print(output)
67+
68+
# [[[[1.60885942, 2.60885954],
69+
# [1.99830270, 2.99830270],
70+
# [1.60885942, 2.60885954],
71+
# [1.99830270, 2.99830270]]]]

0 commit comments

Comments
 (0)