From d7e33a972774e19c898ccdb07a99c1cc871d62cf Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 28 Oct 2021 09:47:05 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E4=B8=80=E4=B8=AA=E8=9E=8D?= =?UTF-8?q?=E5=90=88=E7=AE=97=E5=AD=90:fused=5Ffeedforward=20(#3999)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修改错别字 * 备注CPU不支持float16 * update example * add fused_feedforward * add fused_feedforward * add fused_feedforward * opt the description * update docs * update docs * update doc * move fused_feedforward docs position --- .../nn/functional/fused_feedforward_cn.rst | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 docs/api/paddle/incubate/nn/functional/fused_feedforward_cn.rst diff --git a/docs/api/paddle/incubate/nn/functional/fused_feedforward_cn.rst b/docs/api/paddle/incubate/nn/functional/fused_feedforward_cn.rst new file mode 100644 index 00000000000..c3a81b56490 --- /dev/null +++ b/docs/api/paddle/incubate/nn/functional/fused_feedforward_cn.rst @@ -0,0 +1,59 @@ +.. _cn_api_incubate_nn_functional_fused_feedforward: + +fused_feedforward +------------------------------- + +.. py:function:: paddle.incubate.nn.functional.fused_feedforward(x, linear1_weight, linear2_weight, linear1_bias=None, linear2_bias=None, ln1_scale=None, ln1_bias=None, ln2_scale=None, ln2_bias=None, dropout1_rate=0.5, dropout2_rate=0.5,activation="relu", ln1_epsilon=1e-5, ln2_epsilon=1e-5, pre_layer_norm=False, name=None): + +这是一个融合算子,该算子是对transformer模型中feed forward层的多个算子进行融合,该算子只支持在GPU下运行,该算子与如下伪代码表达一样的功能: + +.. code-block:: ipython + + residual = src; + if pre_layer_norm: + src = layer_norm(src) + src = linear(dropout(activation(dropout(linear(src))))) + if not pre_layer_norm: + src = layer_norm(out) + +参数 +::::::::: + - **x** (Tensor) - 输入Tensor,数据类型支持float16, float32 和float64, 输入的形状是`[batch_size, sequence_length, d_model]`。 + - **linear1_weight** (Tensor) - 第一个linear算子的权重数据,数据类型与`x`一样,形状是`[d_model, dim_feedforward]`。 + - **linear2_weight** (Tensor) - 第二个linear算子的权重数据,数据类型与`x`一样,形状是`[dim_feedforward, d_model]`。 + - **linear1_bias** (Tensor, 可选) - 第一个linear算子的偏置数据,数据类型与`x`一样,形状是`[dim_feedforward]`。默认值为None。 + - **linear2_bias** (Tensor, 可选) - 第二个linear算子的偏置数据,数据类型与`x`一样,形状是`[d_model]`。默认值为None。 + - **ln1_scale** (Tensor, 可选) - 第一个layer_norm算子的权重数据,数据类型可以是float32或者float64,形状和`x`一样。默认值为None。 + - **ln1_bias** (Tensor, 可选) - 第一个layer_norm算子的偏置数据,数据类型和`ln1_scale`一样, 形状是`[d_model]`。默认值为None。 + - **ln2_scale** (Tensor, 可选) - 第二个layer_norm算子的权重数据,数据类型可以是float32或者float64,形状和`x`一样。默认值为None。 + - **ln2_bias** (Tensor, 可选) - 第二个layer_norm算子的偏置数据,数据类型和`ln2_scale`一样, 形状是`[d\_model]`。默认值为None。 + - **dropout1_rate** (float, 可选) - 第一个dropout算子置零的概率。默认是0.5。 + - **dropout2_rate** (float, 可选) - 第二个dropout算子置零的概率。默认是0.5。 + - **activation** (string, 可选) - 激活函数。默认值是relu。 + - **ln1_epsilon** (float, 可选) - 一个很小的浮点数,被第一个layer_norm算子加到分母,避免出现除零的情况。默认值是1e-5。 + - **ln2_epsilon** (float, 可选) - 一个很小的浮点数,被第二个layer_norm算子加到分母,避免出现除零的情况。默认值是1e-5。 + - **pre_layer_norm** (bool, 可选) - 在预处理阶段加上layer_norm,或者在后处理阶段加上layer_norm。默认值是False。 + - **name** (string, 可选) – fused_feedforward的名称, 默认值为None。更多信息请参见 :ref:`api_guide_Name` 。 + +返回 +::::::::: + - Tensor, 输出Tensor,数据类型与`x`一样。 + +代码示例 +:::::::::: + +.. code-block:: python + + # required: gpu + import paddle + import numpy as np + x_data = np.random.random((1, 8, 8)).astype("float32") + linear1_weight_data = np.random.random((8, 8)).astype("float32") + linear2_weight_data = np.random.random((8, 8)).astype("float32") + x = paddle.to_tensor(x_data) + linear1_weight = paddle.to_tensor(linear1_weight_data) + linear2_weight = paddle.to_tensor(linear2_weight_data) + out = paddle.incubate.nn.functional.fused_feedforward(x, linear1_weight, linear2_weight) + print(out.numpy().shape) + # (1, 8, 8) +