|
| 1 | +.. _cn_api_sparse_attention: |
| 2 | +sparse_attention |
| 3 | +------------------------------- |
| 4 | + |
| 5 | +.. py:function:: paddle.nn.functional.sparse_attention(query, key, value, sparse_csr_offset, sparse_csr_columns, name=None) |
| 6 | +
|
| 7 | +
|
| 8 | +该OP对Transformer模块中的Attention矩阵进行了稀疏化,从而减少内存消耗和计算量。 |
| 9 | + |
| 10 | +其稀疏数据排布通过CSR格式表示,CSR格式包含两个参数, ``offset`` 和 ``colunms`` 。计算公式为: |
| 11 | + |
| 12 | +.. math:: |
| 13 | + result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V |
| 14 | +
|
| 15 | +其中,``Q``,``K``,``V`` 表示注意力模块的三个输入参数。这三个参数的维度是一样的。 ``d`` 代表这三个参数的最后一个维度的大小。 |
| 16 | + |
| 17 | +参数: |
| 18 | +::::::::: |
| 19 | + - query (Tensor) - 输入的Tensor,代表注意力模块中的 ``query`` ,这是一个4维Tensor,形状为 :[batch_size, num_heads, seq_len, head_dim],数据类型为float32或float64。 |
| 20 | + - key (Tensor) - 输入的Tensor,代表注意力模块中的 ``key`` ,这是一个4维Tensor,形状为 :[batch_size, num_heads, seq_len, head_dim],数据类型为float32或float64。 |
| 21 | + - value (Tensor) - 输入的Tensor,代表注意力模块中的 ``value`` ,这是一个4维Tensor,形状为 :[batch_size, num_heads, seq_len, head_dim],数据类型为float32或float64。 |
| 22 | + - sparse_csr_offset (Tensor) - 输入的Tensor,注意力模块中的稀疏特性,稀疏特性使用CSR格式表示, ``offset`` 代表矩阵中每一行非零元的数量。这是一个3维Tensor,形状为 :[batch_size, num_heads, seq_len + 1],数据类型为int32。 |
| 23 | + - sparse_csr_columns (Tensor) - 输入的Tensor,注意力模块中的稀疏特性,稀疏特性使用CSR格式表示, ``colunms`` 代表矩阵中每一行非零元的列索引值。这是一个3维Tensor,形状为 :[batch_size, num_heads, sparse_nnz],数据类型为int32。 |
| 24 | + |
| 25 | +返回: |
| 26 | +::::::::: |
| 27 | + ``Tensor`` ,代表注意力模块的结果。这是一个4维Tensor,形状为 :[batch_size, num_heads, seq_len, head_dim],数据类型为float32或float64。 |
| 28 | + |
| 29 | +代码示例 |
| 30 | +:::::::::: |
| 31 | + |
| 32 | +.. code-block:: python |
| 33 | +
|
| 34 | + import paddle |
| 35 | + import numpy as np |
| 36 | + |
| 37 | + query_data = np.array([[[[0, 1,], [2, 3], |
| 38 | + [ 0, 1], [2, 3]]]]).astype("float32") |
| 39 | + key_data = np.array([[[[0, 1,], [2, 3], |
| 40 | + [ 0, 1], [2, 3]]]]).astype("float32") |
| 41 | + value_data = np.array([[[[0, 1,], [2, 3], |
| 42 | + [ 0, 1], [2, 3]]]]).astype("float32") |
| 43 | + sparse_csr_offset_data = np.array([[[0, 2, |
| 44 | + 4, 6, 8]]]).astype("int32") |
| 45 | + sparse_csr_columns_data = np.array([[[0, 1, |
| 46 | + 0, 1, 2, 3, 2, 3]]]).astype("int32") |
| 47 | + print(query_data.shape) |
| 48 | + # (1, 1, 4, 2) |
| 49 | + print(sparse_csr_offset_data.shape) |
| 50 | + # (1, 1, 5) |
| 51 | + print(sparse_csr_columns_data.shape) |
| 52 | + # (1, 1, 8) |
| 53 | + paddle.disable_static() |
| 54 | + query = paddle.to_tensor(query_data, stop_gradient=False, |
| 55 | + place=paddle.CUDAPlace(0)) |
| 56 | + key = paddle.to_tensor(key_data, stop_gradient=False, |
| 57 | + place=paddle.CUDAPlace(0)) |
| 58 | + value = paddle.to_tensor(value_data, stop_gradient=False, |
| 59 | + place=paddle.CUDAPlace(0)) |
| 60 | + offset = paddle.to_tensor(sparse_csr_offset_data, stop_gradient=False, |
| 61 | + place=paddle.CUDAPlace(0)) |
| 62 | + columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False, |
| 63 | + place=paddle.CUDAPlace(0)) |
| 64 | + output = paddle.nn.functional.sparse_attention(query, key, |
| 65 | + value, offset, columns) |
| 66 | + print(output) |
| 67 | + |
| 68 | + # [[[[1.60885942, 2.60885954], |
| 69 | + # [1.99830270, 2.99830270], |
| 70 | + # [1.60885942, 2.60885954], |
| 71 | + # [1.99830270, 2.99830270]]]] |
0 commit comments