Skip to content

Commit 79874e0

Browse files
authored
Add paddle.incubate.graph_send_recv API docs (#4104)
* add paddle.incubate.send_recv API doc * add default value of pool_type * fix default value of pool_type * change import * change intro * fix bug in api * fix display bug * modify wording * mv send_recv to graph_send_recv
1 parent dbede61 commit 79874e0

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
.. _cn_api_incubate_graph_send_recv:
2+
3+
graph_send_recv
4+
-------------------------------
5+
6+
.. py:function:: paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None)
7+
8+
此API主要应用于图学习领域,目的是为了减少在消息传递过程中带来的中间变量显存或内存的损耗。其中, ``x`` 作为输入Tensor,首先利用 ``src_index`` 作为索引来gather出在 ``x`` 中相应位置的数据,随后再将gather出的结果利用 ``dst_index`` 来更新到对应的输出结果中,其中 ``pool_type`` 表示不同的更新方式,包括sum、mean、max、min共计4种处理模式。
9+
10+
.. code-block:: text
11+
12+
X = [[0, 2, 3],
13+
[1, 4, 5],
14+
[2, 6, 7]]
15+
16+
src_index = [0, 1, 2, 0]
17+
18+
dst_index = [1, 2, 1, 0]
19+
20+
pool_type = "sum"
21+
22+
Then:
23+
24+
Out = [[0, 2, 3],
25+
[2, 8, 10],
26+
[1, 4, 5]]
27+
28+
参数
29+
:::::::::
30+
- x (Tensor) - 输入的 Tensor,数据类型为:float32、float64、int32、int64。
31+
- src_index (Tensor) - 1-D Tensor,数据类型为:int32、int64。
32+
- dst_index (Tensor) - 1-D Tensor,数据类型为:int32、int64。注意: ``dst_index`` 的形状应当与 ``src_index`` 一致。
33+
- pool_type (str) - scatter结果的不同处理方式,包括sum、mean、max、min。 默认值为 sum。
34+
- name (str,可选) - 操作的名称(可选,默认值为None)。更多信息请参见 :ref:`api_guide_Name` 。
35+
36+
返回
37+
:::::::::
38+
``Tensor`` ,维度和数据类型都与 ``x`` 相同,存储运算后的结果。
39+
40+
41+
代码示例
42+
::::::::::
43+
44+
.. code-block:: python
45+
46+
import paddle
47+
48+
x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
49+
indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32")
50+
src_index = indexes[:, 0]
51+
dst_index = indexes[:, 1]
52+
out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum")
53+
# Outputs: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]]

0 commit comments

Comments
 (0)