Skip to content

Conversation

@Liu-xiandong
Copy link
Member

给paddle添加API:paddle.nn.functional.sparse_attention
cherry-pick #3958

示例:

import paddle
import numpy as np

query_data = np.array([[[[0, 1,], [2, 3], [ 0, 1], [2, 3]]]]).astype("float32")
key_data = np.array([[[[0, 1,], [2, 3], [ 0, 1], [2, 3]]]]).astype("float32")
value_data = np.array([[[[0, 1,], [2, 3], [ 0, 1], [2, 3]]]]).astype("float32")
sparse_csr_offset_data = np.array([[[0, 2, 4, 6, 8]]]).astype("int32")
sparse_csr_columns_data = np.array([[[0, 1, 0, 1, 2, 3, 2, 3]]]).astype("int32")
print(query_data.shape)
# (1, 1, 4, 2)
print(sparse_csr_offset_data.shape)
# (1, 1, 5)
print(sparse_csr_columns_data.shape)
# (1, 1, 8)
paddle.disable_static()
query = paddle.to_tensor(query_data, stop_gradient=False, place=paddle.CUDAPlace(0))
key = paddle.to_tensor(key_data, stop_gradient=False, place=paddle.CUDAPlace(0))
value = paddle.to_tensor(value_data, stop_gradient=False, place=paddle.CUDAPlace(0))
offset = paddle.to_tensor(sparse_csr_offset_data, stop_gradient=False, place=paddle.CUDAPlace(0))
columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False, place=paddle.CUDAPlace(0))
output = paddle.nn.functional.sparse_attention(query, key, value, offset, columns)
print(output)

# [[[[1.60885942, 2.60885954],
#       [1.99830270, 2.99830270],
#       [1.60885942, 2.60885954],
#       [1.99830270, 2.99830270]]]]

* add sparse_attention docs

* fix some bug

* modify the docs
@Liu-xiandong Liu-xiandong changed the title [cherry-pick]新增API:paddle.nn.functional.sparse_attention (#3958) [cherry-pick]新增API:paddle.nn.functional.sparse_attention Oct 26, 2021
@dingjiaweiww dingjiaweiww merged commit f14a295 into PaddlePaddle:release/2.2 Oct 26, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants