7 changes: 7 additions & 0 deletions docs/api/paddle/distributed/all_gather_cn.rst
@@ -7,6 +7,13 @@ all_gather
.. py:function:: paddle.distributed.all_gather(tensor_list, tensor, group=0)

Gathers the specified tensor from every process in the process group and returns the gathered result to all processes.
As shown in the figure below, 4 GPUs each launch one process and the data on each card is labeled with its card number; after the all_gather operator, every card holds the data of all cards.

.. image:: ./img/allgather.png
:width: 800
:alt: all_gather
:align: center
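
A minimal usage sketch (assuming a multi-GPU job launched with ``python -m paddle.distributed.launch``; the tensor values simply encode each card's rank, mirroring the figure):

.. code-block:: python

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    data = paddle.to_tensor([dist.get_rank()])   # each card contributes its own rank
    tensor_list = []
    dist.all_gather(tensor_list, data)
    # tensor_list on every card now contains the tensors of all cards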

Parameters
::::::::::
7 changes: 7 additions & 0 deletions docs/api/paddle/distributed/all_reduce_cn.rst
@@ -7,6 +7,13 @@ all_reduce
.. py:function:: paddle.distributed.all_reduce(tensor, op=ReduceOp.SUM, group=0)

Performs a reduction on the specified tensor across all processes in the process group and returns the reduced result to every process.
As shown in the figure below, 4 GPUs each launch one process, the data on each card is labeled with its card number, and the reduction operation is summation; after the all_reduce operator, every card holds the sum of the data from all cards.

.. image:: ./img/allreduce.png
:width: 800
:alt: all_reduce
:align: center
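
A minimal usage sketch (assuming a multi-GPU launch via ``paddle.distributed.launch``; the values are illustrative):

.. code-block:: python

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    data = paddle.to_tensor([dist.get_rank() + 1])
    dist.all_reduce(data)        # op defaults to ReduceOp.SUM
    # every card now holds the sum of the data from all cards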

Parameters
::::::::::
11 changes: 9 additions & 2 deletions docs/api/paddle/distributed/alltoall_cn.rst
@@ -6,8 +6,15 @@ alltoall

.. py:function:: paddle.distributed.alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True)

Evenly splits the tensors in in_tensor_list according to the number of cards, scatters them to all participating cards in card order, and gathers the resulting tensors into out_tensor_list.
As shown in the figure below, the in_tensor_list on GPU0 is split into 0_0 and 0_1 for the two cards, and the in_tensor_list on GPU1 is likewise split into 1_0 and 1_1. After the alltoall operator, 0_0 on GPU0 is sent to GPU0, 0_1 on GPU0 is sent to GPU1, 1_0 on GPU1 is sent to GPU0, and 1_1 on GPU1 is sent to GPU1; therefore the out_tensor_list on GPU0 contains 0_0 and 1_0, and the out_tensor_list on GPU1 contains 0_1 and 1_1.

.. image:: ./img/alltoall.png
:width: 800
:alt: alltoall
:align: center
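
A minimal two-card sketch mirroring the figure (assuming a two-GPU launch; each value only encodes which card and which slice it came from):

.. code-block:: python

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    rank = dist.get_rank()
    # in_tensor_list[j] on card r plays the role of r_j in the figure
    in_tensor_list = [paddle.to_tensor([rank, 0]), paddle.to_tensor([rank, 1])]
    out_tensor_list = []
    dist.alltoall(in_tensor_list, out_tensor_list)
    # GPU0 now holds [0_0, 1_0] and GPU1 holds [0_1, 1_1]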

Parameters
::::::::::
8 changes: 7 additions & 1 deletion docs/api/paddle/distributed/broadcast_cn.rst
@@ -6,7 +6,13 @@ broadcast

.. py:function:: paddle.distributed.broadcast(tensor, src, group=0)

Broadcasts a tensor to all other processes.
As shown in the figure below, 4 GPUs each launch one process; GPU0 holds the data, and after the broadcast operator that data is propagated to every card.

.. image:: ./img/broadcast.png
:width: 800
:alt: broadcast
:align: center
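
A minimal usage sketch (assuming a multi-GPU launch; the data values are illustrative):

.. code-block:: python

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    if dist.get_rank() == 0:
        data = paddle.to_tensor([7, 8, 9])        # GPU0 holds the data
    else:
        data = paddle.zeros([3], dtype='int64')   # placeholder on the other cards
    dist.broadcast(data, src=0)
    # every card now holds [7, 8, 9]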

Parameters
::::::::::
24 changes: 24 additions & 0 deletions docs/api/paddle/distributed/fleet/utils/recompute_cn.rst
@@ -0,0 +1,24 @@
.. _cn_api_distributed_fleet_utils_recompute:

recompute
-------------------------------


.. py:function:: paddle.distributed.fleet.utils.recompute(function, *args, **kwargs)

Saves GPU memory by recomputing intermediate activation values instead of storing them.

Parameters
::::::::::
- function (paddle.nn.Sequential) - A sequence of consecutive layers that form part of the model's forward pass.
Their intermediate activations are released during the forward pass to save GPU memory and are recomputed when gradients are computed in the backward pass.
- args (Tensor) - Inputs to function.
- kwargs (Dict) - kwargs should only contain the key preserve_rng_state, which indicates whether to save the RNG state of the forward pass; if True, the RNG state of the previous forward pass is restored when the forward is recomputed during backpropagation. preserve_rng_state defaults to True.

Returns
:::::::::
The output of function applied to the inputs.

Code Example
::::::::::::
COPY-FROM: paddle.distributed.fleet.utils.recompute
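
For orientation, a minimal hedged sketch of wrapping a ``paddle.nn.Sequential`` segment (the layer sizes and the dummy loss are made up; the authoritative example is the one pulled in by COPY-FROM above):

.. code-block:: python

    import paddle
    from paddle.distributed.fleet.utils import recompute

    block = paddle.nn.Sequential(
        paddle.nn.Linear(64, 64), paddle.nn.ReLU(), paddle.nn.Linear(64, 64))
    x = paddle.randn([8, 64])
    x.stop_gradient = False
    # activations inside `block` are freed in the forward pass and recomputed in backward
    y = recompute(block, x, preserve_rng_state=True)
    y.sum().backward()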
Binary file added docs/api/paddle/distributed/img/allgather.png
Binary file added docs/api/paddle/distributed/img/allreduce.png
Binary file added docs/api/paddle/distributed/img/alltoall.png
Binary file added docs/api/paddle/distributed/img/broadcast.png
Binary file added docs/api/paddle/distributed/img/reduce.png
Binary file added docs/api/paddle/distributed/img/scatter.png
Binary file added docs/api/paddle/distributed/img/split_col.png
Binary file added docs/api/paddle/distributed/img/split_row.png
7 changes: 7 additions & 0 deletions docs/api/paddle/distributed/reduce_cn.rst
@@ -7,6 +7,13 @@ reduce
.. py:function:: paddle.distributed.reduce(tensor, dst, op=ReduceOp.SUM, group=0)

Performs a reduction on the specified tensor across all processes in the process group and returns the reduced result to the destination process.
As shown in the figure below, 4 GPUs each launch one process and the data on each card is labeled with its card number; the reduce destination is card 0 and the reduction operation is summation, so after the reduce operation card 0 holds the sum of the data from all cards.

.. image:: ./img/reduce.png
:width: 800
:alt: reduce
:align: center
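
A minimal usage sketch (assuming a multi-GPU launch; the values are illustrative):

.. code-block:: python

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    data = paddle.to_tensor([dist.get_rank() + 1])
    dist.reduce(data, dst=0)     # op defaults to ReduceOp.SUM
    # card 0 now holds the sum over all cards; the buffers on other cards are not the final result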

Parameters
::::::::::
7 changes: 7 additions & 0 deletions docs/api/paddle/distributed/scatter_cn.rst
@@ -7,6 +7,13 @@ scatter
.. py:function:: paddle.distributed.scatter(tensor, tensor_list=None, src=0, group=0)

Scatters a list of tensors from the specified source process to all other processes in the process group.
As shown in the figure below, 4 GPUs each launch one process; card 0 is chosen as the scatter source, and after the scatter operator the data on card 0 is evenly distributed to all cards.

.. image:: ./img/scatter.png
:width: 800
:alt: scatter
:align: center
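
A minimal usage sketch (assuming a multi-GPU launch; every card builds a list of the right length, but only the list on the source card is actually consumed):

.. code-block:: python

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    data = paddle.zeros([1], dtype='int64')
    # one tensor per card; card 0 is the scatter source
    tensor_list = [paddle.to_tensor([i]) for i in range(dist.get_world_size())]
    dist.scatter(data, tensor_list, src=0)
    # card i now holds [i]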

Parameters
::::::::::
54 changes: 51 additions & 3 deletions docs/api/paddle/distributed/split_cn.rst
@@ -13,13 +13,61 @@ split
Case 1: parallel Embedding
The parameter of the Embedding operation is an NxM matrix with N rows and M columns. Under parallel Embedding, the parameter is split across num_partitions devices, and each device holds a matrix of (N/num_partitions + 1) rows and M columns, whose last row serves as the padding idx.

Suppose the NxM parameter matrix is split across two devices, device_0 and device_1. Each device then holds a parameter matrix with (N/2+1) rows and M columns. On device_0, a value in the input x stays unchanged if it lies in [0, N/2-1]; otherwise it is changed to N/2 and the embedding maps it to an all-zero vector. Similarly, on device_1, a value V in the input x is changed to (V-N/2) if it lies in [N/2, N-1]; otherwise it is changed to N/2 and the embedding maps it to an all-zero vector. Finally, an all_reduce_sum operation aggregates the results from all cards.

Single-card Embedding is shown in the figure below.

Reviewer (Collaborator): For single-card Embedding, would it be better to present it as a multiplication, in * [0, 0, 1, ... 0] -> out? The same for the parallel case; that would make it easier to explain where the 00....0 below comes from.

Author (Contributor): Embedding cannot simply be viewed as a multiplication; it is more like a table lookup. Split embedding stores pieces of the table in different places, and when an input is looked up, only the piece that holds that input returns the corresponding feature, while the pieces that do not hold it output 0.

.. image:: ./img/split_embedding_single.png
:width: 800
:height: 350
:alt: single_embedding
:align: center

Parallel Embedding is shown in the figure below.

.. image:: ./img/split_embedding_split.png
:width: 800
:alt: split_embedding
:align: center
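
A hedged sketch of the parallel Embedding case (the 8x8 table size and input shape are made up; it assumes a two-GPU collective launch and static graph mode):

.. code-block:: python

    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    paddle.set_device('gpu:%d' % paddle.distributed.ParallelEnv().dev_id)
    fleet.init(is_collective=True)
    data = paddle.randint(0, 8, shape=[10, 4])
    # the 8x8 table is split by rows across 2 cards, each slice keeping an extra padding row
    emb_out = paddle.distributed.split(
        data, (8, 8), operation='embedding', num_partitions=2)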

Case 2: row-parallel Linear
The Linear operation multiplies the input X (an NxN matrix here) by an NxM weight matrix W. Under row-parallel Linear, the parameter is split across num_partitions devices, and each device holds a matrix of N/num_partitions rows and M columns.

Single-card Linear is shown in the figure below: the input is denoted X, the weight matrix W, and the output O. Single-card Linear is just a plain matrix multiplication, O = X * W.


.. image:: ./img/split_single.png
:width: 800
:alt: single_linear
:align: center

Row-parallel Linear is shown in the figure below. As the name suggests, the weight matrix W is split by rows into
[[W_row1], [W_row2]], and the input X is correspondingly split by columns into [X_col1, X_col2]; each part is multiplied by its matching weight slice,
and an AllReduce over every card's output then produces the final result.

.. image:: ./img/split_row.png
:width: 800
:alt: split_row
:align: center

Case 3: column-parallel Linear
The Linear operation multiplies the input X (an NxN matrix here) by an NxM weight matrix W. Under column-parallel Linear, the parameter is split across num_partitions devices, and each device holds a matrix of N rows and M/num_partitions columns.

For single-card Linear, see the corresponding figure above; column-parallel Linear is shown in the figure below. The weight matrix W is split by columns into [W_col1, W_col2],
X is multiplied by each slice separately, and an AllGather then concatenates every card's output into the final result.

.. image:: ./img/split_col.png
:width: 800
:alt: split_col
:align: center
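
A hedged sketch of the column-parallel Linear case (the 8x16 weight shape is made up; same launch assumptions as the Embedding sketch above):

.. code-block:: python

    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    paddle.set_device('gpu:%d' % paddle.distributed.ParallelEnv().dev_id)
    fleet.init(is_collective=True)
    x = paddle.randn([4, 8])
    # axis=1 splits the 8x16 weight W by columns into [W_col1, W_col2]; each card
    # computes X * W_col_i and the per-card outputs are concatenated into the final result
    out = paddle.distributed.split(
        x, (8, 16), operation='linear', axis=1, num_partitions=2)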

Note that the column-split matrix multiplication above can be chained with the row-split one, which saves one AllGather communication, as shown in the figure below. Moreover, the Attention and MLP blocks of a Transformer each contain two matrix multiplications, so chaining the two multiplications inside Attention and inside MLP in this way further improves performance.

.. image:: ./img/split_col_row.png
:width: 800
:alt: split_col_row
:align: center


Parameters
::::::::::
25 changes: 24 additions & 1 deletion docs/api/paddle/distributed/utils/global_gather_cn.rst
@@ -6,9 +6,32 @@ global_gather

.. py:function:: paddle.distributed.utils.global_gather(x, local_count, global_count, group=None, use_calc_stream=True)

global_gather collects the data of x onto n_expert * world_size experts according to global_count, and then receives data according to local_count.
Here an expert is a user-defined expert network, n_expert is the number of expert networks on each card, and world_size is the number of GPUs running the network.

As shown in the figure below, world_size is 2, n_expert is 2, the batch_size of x is 4, local_count is [2, 0, 2, 0], global_count on card 0 is [2, 0, , ],
and global_count on card 1 is [2, 0, , ] (for reasons of space, only the data processed on card 0 is shown). In the global_gather operator,
global_count and local_count have exactly the opposite meaning to what they have in global_scatter:
global_count[i] means sending global_count[i] pieces of data to the (i % n_expert)-th expert on card (i // n_expert),
and local_count[i] means receiving local_count[i] pieces of data from card (i // n_expert) for the (i % n_expert)-th expert on the local card.
The data that is sent is arranged per expert of every card. In the figure, rank0 denotes card 0 and rank1 denotes card 1.

The flow in which global_gather sends data is as follows:

global_count[0] on card 0 means sending 2 pieces of data to expert 0 on card 0;

global_count[1] on card 0 means sending 0 pieces of data to expert 1 on card 0;

global_count[0] on card 1 means sending 2 pieces of data to expert 0 on card 0;

global_count[1] on card 1 means sending 0 pieces of data to expert 1 on card 0.


.. image:: ../img/global_scatter_gather.png
:width: 800
:alt: global_scatter_gather
:align: center
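
A combined hedged sketch of the scatter-then-gather round trip (2 cards with 2 experts each, as in the figure; the counts are chosen to be self-consistent, the elementwise multiplication merely stands in for the expert networks, and the int64 count tensors are an assumption):

.. code-block:: python

    import paddle
    import paddle.distributed as dist
    from paddle.distributed.utils import global_scatter, global_gather

    dist.init_parallel_env()
    x = paddle.rand([4, 2])                                    # batch_size 4, d_model 2
    local_count = paddle.to_tensor([2, 0, 2, 0], dtype='int64')
    global_count = paddle.to_tensor([2, 0, 2, 0], dtype='int64')
    expert_in = global_scatter(x, local_count, global_count)   # rows routed to the experts
    expert_out = expert_in * 2                                 # stand-in for the experts
    y = global_gather(expert_out, local_count, global_count)   # rows routed back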


Parameters
::::::::::
30 changes: 29 additions & 1 deletion docs/api/paddle/distributed/utils/global_scatter_cn.rst
@@ -6,9 +6,37 @@ global_scatter

.. py:function:: paddle.distributed.utils.global_scatter(x, local_count, global_count, group=None, use_calc_stream=True)

global_scatter distributes the data of x to n_expert * world_size experts according to local_count, and then receives data according to global_count.
Here an expert is a user-defined expert network, n_expert is the number of expert networks on each card, and world_size is the number of GPUs running the network.

As shown in the figure below, world_size is 2, n_expert is 2, the batch_size of x is 4, local_count is [2, 0, 2, 0], global_count on card 0 is [2, 0, , ],
and global_count on card 1 is [2, 0, , ] (for reasons of space, only the data processed on card 0 is shown). In the global_scatter operator,
local_count[i] means sending local_count[i] pieces of data to the (i % n_expert)-th expert on card (i // n_expert),
and global_count[i] means receiving global_count[i] pieces of data from card (i // n_expert) for the (i % n_expert)-th expert on the local card.
In the figure, rank0 denotes card 0 and rank1 denotes card 1.
The flow in which global_scatter sends data is as follows:

local_count[0] means taking 2 batches of data from x and sending them to expert 0 on card 0;

local_count[1] means taking 0 batches of data from x and sending them to expert 1 on card 0;

local_count[2] means taking 2 batches of data from x and sending them to expert 0 on card 1;

local_count[3] means taking 0 batches of data from x and sending them to expert 1 on card 1;

therefore global_count[0] on card 0 equals 2, meaning 2 batches of data are received from card 0 for expert 0;

global_count[1] on card 0 equals 0, meaning 0 batches of data are received from card 0 for expert 1;

global_count[0] on card 1 equals 2, meaning 2 batches of data are received from card 0 for expert 0;

global_count[1] on card 1 equals 0, meaning 0 batches of data are received from card 0 for expert 1.


.. image:: ../img/global_scatter_gather.png
:width: 800
:alt: global_scatter_gather
:align: center
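
A minimal hedged sketch mirroring the figure (2 cards, 2 experts per card, batch_size 4; the counts are chosen to be self-consistent across the two cards, and the int64 count tensors are an assumption):

.. code-block:: python

    import paddle
    import paddle.distributed as dist
    from paddle.distributed.utils import global_scatter

    dist.init_parallel_env()
    x = paddle.rand([4, 2])     # 4 rows of input on every card
    # send 2 rows to expert 0 of card 0 and 2 rows to expert 0 of card 1
    local_count = paddle.to_tensor([2, 0, 2, 0], dtype='int64')
    # with this symmetric layout, 2 rows arrive from each card for local expert 0
    global_count = paddle.to_tensor([2, 0, 2, 0], dtype='int64')
    expert_input = global_scatter(x, local_count, global_count)
    # expert_input holds the received rows, arranged per (card, expert) pair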

Parameters
::::::::::