Skip to content

Commit 762bedd

Browse files
LiyulingyueLigoml
andauthored
fix docs of auto_cast, static.save (#4668)
* Update auto_cast_cn.rst * Update save_cn.rst * Update auto_cast_cn.rst * Update auto_cast_cn.rst * Update save_cn.rst * Update auto_cast_cn.rst Co-authored-by: Ligoml <[email protected]>
1 parent f54f467 commit 762bedd

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

docs/api/paddle/amp/auto_cast_cn.rst

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,22 @@
33
auto_cast
44
-------------------------------
55

6-
.. py:function:: paddle.amp.auto_cast(enable=True, custom_white_list=None, custom_black_list=None, level='O1')
6+
.. py:function:: paddle.amp.auto_cast(enable=True, custom_white_list=None, custom_black_list=None, level='O1', dtype='float16')
77
88
99
创建一个上下文环境,来支持动态图模式下执行的算子的自动混合精度策略(AMP)。
10-
如果启用AMP,使用autocast算法确定每个算子的输入数据类型(float32或float16),以获得更好的性能。
10+
如果启用 AMP,使用 autocast 算法确定每个算子的输入数据类型(float32 或 float16),以获得更好的性能。
1111
通常,它与 ``decorate`` 和 ``GradScaler`` 一起使用,来实现动态图模式下的自动混合精度。
12-
混合精度训练提供两种模式:``O1``代表采用黑名名单策略的混合精度训练;``O2``代表纯float16训练,除自定义黑名单和不支持float16的算子之外,全部使用float16计算
12+
混合精度训练提供两种模式: ``O1`` 代表采用黑名名单策略的混合精度训练; ``O2`` 代表纯 float16 训练,除自定义黑名单和不支持 float16 的算子之外,全部使用 float16 计算
1313

1414

1515
参数
1616
:::::::::
17-
- **enable** (bool, 可选) - 是否开启自动混合精度。默认值为True。
18-
- **custom_white_list** (set|list, 可选) - 自定义算子白名单。这个名单中的算子在支持float16计算时会被认为是数值安全的,并且对性能至关重要。如果设置了白名单,该名单中的算子会使用float16计算。
19-
- **custom_black_list** (set|list, 可选) - 自定义算子黑名单。这个名单中的算子在支持float16计算时会被认为是数值危险的,它们的影响也可能会在下游操作中观察到。这些算子通常不会转为float16计算。
20-
- **level** (str, 可选) - 混合精度训练模式,可为``O1``或``O2``模式,默认``O1``模式。
17+
- **enable** (bool,可选) - 是否开启自动混合精度。默认值为True。
18+
- **custom_white_list** (set|list,可选) - 自定义算子白名单。这个名单中的算子在支持 float16 计算时会被认为是数值安全的,并且对性能至关重要。如果设置了白名单,该名单中的算子会使用 float16 计算。
19+
- **custom_black_list** (set|list,可选) - 自定义算子黑名单。这个名单中的算子在支持 float16 计算时会被认为是数值危险的,它们的影响也可能会在下游操作中观察到。这些算子通常不会转为 float16 计算。
20+
- **level** (str,可选) - 混合精度训练模式,可为 ``O1`` 或 ``O2`` 模式,默认 ``O1`` 模式。
21+
- **dtype** (str,可选) - 使用的数据类型,可以是float16 或 bfloat16。默认为 float16。
2122

2223

2324
代码示例

docs/api/paddle/static/save_cn.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
save
44
-------------------------------
55

6-
.. py:function:: paddle.static.save(program, model_path, protocol=4)
6+
.. py:function:: paddle.static.save(program, model_path, protocol=4, **configs)
77
88
99
该接口将传入的参数、优化器信息和网络描述保存到 ``model_path`` 。
@@ -21,6 +21,7 @@ save
2121
- **program** ( :ref:`cn_api_fluid_Program` ) – 要保存的Program。
2222
- **model_path** (str) – 保存program的文件前缀。格式为 ``目录名称/文件前缀``。如果文件前缀为空字符串,会引发异常。
2323
- **protocol** (int,可选) – pickle模块的协议版本,默认值为4,取值范围是[2,4]。
24+
- **\*\*configs** (dict,可选) - 可选的关键字参数。
2425

2526
返回
2627
::::::::::::

0 commit comments

Comments
 (0)