diff --git a/docs/guides/01_paddle2.0_introduction/basic_concept/layer_and_model_cn.md b/docs/guides/01_paddle2.0_introduction/basic_concept/layer_and_model_cn.md index 9872d55e8f9..0f89e8308c4 100644 --- a/docs/guides/01_paddle2.0_introduction/basic_concept/layer_and_model_cn.md +++ b/docs/guides/01_paddle2.0_introduction/basic_concept/layer_and_model_cn.md @@ -303,8 +303,9 @@ Tensor(shape=[10, 1], dtype=float32, place=CPUPlace, stop_gradient=True, 同样的也可以使用 ``register_forward_pre_hook()`` 来注册**pre_hook**: ```python -def forward_pre_hook(layer, input, output): - return 2*output +def forward_pre_hook(layer, input): + print(input) + return input x = paddle.ones([10, 1], 'float32') model = Model() @@ -313,10 +314,17 @@ out = model(x) ``` ```text -Tensor(shape=[10, 1], dtype=float32, place=CPUPlace, stop_gradient=True, - [[2.], - [2.], - ... +(Tensor(shape=[10, 1], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + [[1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.]]),) ``` ## 模型数据保存 diff --git a/docs/guides/01_paddle2.0_introduction/basic_concept/layer_and_model_en.md b/docs/guides/01_paddle2.0_introduction/basic_concept/layer_and_model_en.md index 3a667fc8c33..e96637cbb05 100644 --- a/docs/guides/01_paddle2.0_introduction/basic_concept/layer_and_model_en.md +++ b/docs/guides/01_paddle2.0_introduction/basic_concept/layer_and_model_en.md @@ -311,8 +311,9 @@ Tensor(shape=[10, 1], dtype=float32, place=CPUPlace, stop_gradient=True, Similarly, we can also register a **pre_hook** through ``register_forward_pre_hook()`` ```python -def forward_pre_hook(layer, input, output): - return 2*output +def forward_pre_hook(layer, input): + print(input) + return input x = paddle.ones([10, 1], 'float32') model = Model() @@ -321,10 +322,17 @@ out = model(x) ``` ```text -Tensor(shape=[10, 1], dtype=float32, place=CPUPlace, stop_gradient=True, - [[2.], - [2.], - ... +(Tensor(shape=[10, 1], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + [[1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.]]),) ``` ## Save a model's data