Skip to content

Commit a6592da

Browse files
committed
ps quick start
1 parent a21215e commit a6592da

File tree

1 file changed

+43
-31
lines changed

1 file changed

+43
-31
lines changed

docs/guides/06_distributed_training/cluster_quick_start_cn.rst

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ train_fleet_static.py的完整训练代码如下所示。
369369
二、ParameterServer训练快速开始
370370
-------------------------
371371

372-
本节将采用推荐领域非常经典的模型wide_and_deep为例,介绍如何使用Fleet API(paddle.distributed.fleet)完成参数服务器训练任务,本次快速开始的完整示例代码位于 https://github.com/PaddlePaddle/FleetX/tree/develop/examples/wide_and_deep
372+
本节将采用推荐领域非常经典的模型wide_and_deep为例,介绍如何使用Fleet API(paddle.distributed.fleet)完成参数服务器训练任务,本次快速开始的完整示例代码位于 https://github.com/PaddlePaddle/FleetX/tree/develop/examples/wide_and_deep_dataset
373373

374374
2.1 版本要求
375375
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -383,9 +383,10 @@ train_fleet_static.py的完整训练代码如下所示。
383383

384384
1. 导入分布式训练需要的依赖包。
385385
2. 定义分布式模式并初始化分布式训练环境。
386-
3. 加载模型及数据。
387-
4. 定义参数更新策略及优化器。
388-
5. 开始训练。
386+
3. 加载模型。
387+
4. 构建dataset加载数据
388+
5. 定义参数更新策略及优化器。
389+
6. 开始训练。
389390

390391
下面将逐一进行讲解。
391392

@@ -410,37 +411,38 @@ train_fleet_static.py的完整训练代码如下所示。
410411
paddle.enable_static()
411412
fleet.init(is_collective=False)
412413
413-
2.2.3 加载模型及数据
414+
2.2.3 加载模型
414415
""""""""""""
415416

416417
.. code-block:: python
417418
418-
# 模型定义参考 examples/wide_and_deep 中 model.py
419+
# 模型定义参考 examples/wide_and_deep_dataset 中 model.py
419420
from model import WideDeepModel
420-
from reader import WideDeepDataset
421421
422422
model = WideDeepModel()
423423
model.net(is_train=True)
424424
425-
def distributed_training(exe, train_model, train_data_path="./data", batch_size=10, epoch_num=1):
426-
train_data = WideDeepDataset(data_path=train_data_path)
427-
reader = train_model.loader.set_sample_generator(
428-
train_data, batch_size=batch_size, drop_last=True, places=paddle.CPUPlace())
429-
430-
for epoch_id in range(epoch_num):
431-
reader.start()
432-
try:
433-
while True:
434-
loss_val = exe.run(program=paddle.static.default_main_program(),
435-
fetch_list=[train_model.cost.name])
436-
loss_val = np.mean(loss_val)
437-
print("TRAIN ---> pass: {} loss: {}\n".format(epoch_id, loss_val))
438-
except paddle.common_ops_import.core.EOFException:
439-
reader.reset()
425+
2.2.4 构建dataset加载数据
426+
""""""""""""
440427

441-
442-
443-
2.2.4 定义同步训练 Strategy 及 Optimizer
428+
.. code-block:: python
429+
430+
# 具体数据处理参考examples/wide_and_deep_dataset中reader.py
431+
dataset = paddle.distributed.QueueDataset()
432+
thread_num = 1
433+
dataset.init(use_var=model.inputs,
434+
pipe_command="python reader.py",
435+
batch_size=batch_size,
436+
thread_num=thread_num)
437+
438+
train_files_list = [os.path.join(train_data_path, x)
439+
for x in os.listdir(train_data_path)]
440+
dataset.set_filelist(train_files_list)
441+
442+
备注:dataset具体用法参见\ `使用InMemoryDataset/QueueDataset进行训练 <https://fleet-x.readthedocs.io/en/latest/paddle_fleet_rst/parameter_server/performance/dataset.html>`_\。
443+
444+
445+
2.2.5 定义同步训练 Strategy 及 Optimizer
444446
""""""""""""
445447

446448
在Fleet API中,用户可以使用 ``fleet.DistributedStrategy()`` 接口定义自己想要使用的分布式策略。
@@ -466,14 +468,14 @@ train_fleet_static.py的完整训练代码如下所示。
466468
optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
467469
optimizer.minimize(model.loss)
468470
469-
2.2.5 开始训练
471+
2.2.6 开始训练
470472
""""""""""""
471473

472474
完成模型及训练策略以后,我们就可以开始训练模型了。因为在参数服务器模式下会有不同的角色,所以根据不同节点分配不同的任务。
473475

474476
对于服务器节点,首先用 ``init_server()`` 接口对其进行初始化,然后启动服务并开始监听由训练节点传来的梯度。
475477

476-
同样对于训练节点,用 ``init_worker()`` 接口进行初始化后, 开始执行训练任务。运行 ``exe.run()`` 接口开始训练,并得到训练中每一步的损失值
478+
同样对于训练节点,用 ``init_worker()`` 接口进行初始化后, 开始执行训练任务。运行 ``exe.train_from_dataset()`` 接口开始训练。
477479

478480
.. code-block:: python
479481
@@ -486,18 +488,28 @@ train_fleet_static.py的完整训练代码如下所示。
486488
487489
fleet.init_worker()
488490
489-
distributed_training(exe, model)
490-
491+
for epoch_id in range(1):
492+
exe.train_from_dataset(paddle.static.default_main_program(),
493+
dataset,
494+
paddle.static.global_scope(),
495+
debug=False,
496+
fetch_list=[train_model.cost],
497+
fetch_info=["loss"],
498+
print_period=1)
499+
491500
fleet.stop_worker()
492501
502+
备注:Paddle2.3版本及以后,ParameterServer训练将废弃掉dataloader + exe.run()方式,请切换到dataset + exe.train_from_dataset()方式。
503+
504+
493505
2.3 运行训练脚本
494506
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
495507

496-
定义完训练脚本后,我们就可以用 ``python3 -m paddle.distributed.launch`` 指令运行分布式任务了。其中 ``server_num`` , ``worker_num`` 分别为服务节点和训练节点的数量。在本例中,服务节点有1个,训练节点有2个。
508+
定义完训练脚本后,我们就可以用 ``fleetrun`` 指令运行分布式任务了。其中 ``server_num`` , ``worker_num`` 分别为服务节点和训练节点的数量。在本例中,服务节点有1个,训练节点有2个。
497509

498510
.. code-block:: bash
499511
500-
python3 -m paddle.distributed.launch --server_num=1 --worker_num=2 --gpus=0,1 train.py
512+
fleetrun --server_num=1 --worker_num=2 train.py
501513
502514
您将看到显示如下日志信息:
503515

0 commit comments

Comments
 (0)