@@ -369,7 +369,7 @@ train_fleet_static.py的完整训练代码如下所示。
369369 二、ParameterServer训练快速开始
370370-------------------------
371371
372- 本节将采用推荐领域非常经典的模型wide_and_deep为例,介绍如何使用Fleet API(paddle.distributed.fleet)完成参数服务器训练任务,本次快速开始的完整示例代码位于 https://github.com/PaddlePaddle/FleetX/tree/develop/examples/wide_and_deep 。
372+ 本节将采用推荐领域非常经典的模型wide_and_deep为例,介绍如何使用Fleet API(paddle.distributed.fleet)完成参数服务器训练任务,本次快速开始的完整示例代码位于 https://github.com/PaddlePaddle/FleetX/tree/develop/examples/wide_and_deep_dataset 。
373373
3743742.1 版本要求
375375^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -383,9 +383,10 @@ train_fleet_static.py的完整训练代码如下所示。
383383
384384 1. 导入分布式训练需要的依赖包。
385385 2. 定义分布式模式并初始化分布式训练环境。
386- 3. 加载模型及数据。
387- 4. 定义参数更新策略及优化器。
388- 5. 开始训练。
386+ 3. 加载模型。
387+ 4. 构建dataset加载数据
388+ 5. 定义参数更新策略及优化器。
389+ 6. 开始训练。
389390
390391下面将逐一进行讲解。
391392
@@ -410,37 +411,38 @@ train_fleet_static.py的完整训练代码如下所示。
410411 paddle.enable_static()
411412 fleet.init(is_collective = False )
412413
413- 2.2.3 加载模型及数据
414+ 2.2.3 加载模型
414415""""""""""""
415416
416417.. code-block :: python
417418
418- # 模型定义参考 examples/wide_and_deep 中 model.py
419+ # 模型定义参考 examples/wide_and_deep_dataset 中 model.py
419420 from model import WideDeepModel
420- from reader import WideDeepDataset
421421
422422 model = WideDeepModel()
423423 model.net(is_train = True )
424424
425- def distributed_training (exe , train_model , train_data_path = " ./data" , batch_size = 10 , epoch_num = 1 ):
426- train_data = WideDeepDataset(data_path = train_data_path)
427- reader = train_model.loader.set_sample_generator(
428- train_data, batch_size = batch_size, drop_last = True , places = paddle.CPUPlace())
429-
430- for epoch_id in range (epoch_num):
431- reader.start()
432- try :
433- while True :
434- loss_val = exe.run(program = paddle.static.default_main_program(),
435- fetch_list = [train_model.cost.name])
436- loss_val = np.mean(loss_val)
437- print (" TRAIN ---> pass: {} loss: {} \n " .format(epoch_id, loss_val))
438- except paddle.common_ops_import.core.EOFException:
439- reader.reset()
425+ 2.2.4 构建dataset加载数据
426+ """"""""""""
440427
441-
442-
443- 2.2.4 定义同步训练 Strategy 及 Optimizer
428+ .. code-block :: python
429+
430+ # 具体数据处理参考examples/wide_and_deep_dataset中reader.py
431+ dataset = paddle.distributed.QueueDataset()
432+ thread_num = 1
433+ dataset.init(use_var = model.inputs,
434+ pipe_command = " python reader.py" ,
435+ batch_size = batch_size,
436+ thread_num = thread_num)
437+
438+ train_files_list = [os.path.join(train_data_path, x)
439+ for x in os.listdir(train_data_path)]
440+ dataset.set_filelist(train_files_list)
441+
442+ 备注:dataset具体用法参见\ `使用InMemoryDataset/QueueDataset进行训练 <https://fleet-x.readthedocs.io/en/latest/paddle_fleet_rst/parameter_server/performance/dataset.html >`_\。
443+
444+
445+ 2.2.5 定义同步训练 Strategy 及 Optimizer
444446""""""""""""
445447
446448在Fleet API中,用户可以使用 ``fleet.DistributedStrategy() `` 接口定义自己想要使用的分布式策略。
@@ -466,14 +468,14 @@ train_fleet_static.py的完整训练代码如下所示。
466468 optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
467469 optimizer.minimize(model.loss)
468470
469- 2.2.5 开始训练
471+ 2.2.6 开始训练
470472""""""""""""
471473
472474完成模型及训练策略以后,我们就可以开始训练模型了。因为在参数服务器模式下会有不同的角色,所以根据不同节点分配不同的任务。
473475
474476对于服务器节点,首先用 ``init_server() `` 接口对其进行初始化,然后启动服务并开始监听由训练节点传来的梯度。
475477
476- 同样对于训练节点,用 ``init_worker() `` 接口进行初始化后, 开始执行训练任务。运行 ``exe.run () `` 接口开始训练,并得到训练中每一步的损失值 。
478+ 同样对于训练节点,用 ``init_worker() `` 接口进行初始化后, 开始执行训练任务。运行 ``exe.train_from_dataset () `` 接口开始训练。
477479
478480.. code-block :: python
479481
@@ -486,18 +488,28 @@ train_fleet_static.py的完整训练代码如下所示。
486488
487489 fleet.init_worker()
488490
489- distributed_training(exe, model)
490-
491+ for epoch_id in range (1 ):
492+ exe.train_from_dataset(paddle.static.default_main_program(),
493+ dataset,
494+ paddle.static.global_scope(),
495+ debug = False ,
496+ fetch_list = [train_model.cost],
497+ fetch_info = [" loss" ],
498+ print_period = 1 )
499+
491500 fleet.stop_worker()
492501
502+ 备注:Paddle2.3版本及以后,ParameterServer训练将废弃掉dataloader + exe.run()方式,请切换到dataset + exe.train_from_dataset()方式。
503+
504+
4935052.3 运行训练脚本
494506^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
495507
496- 定义完训练脚本后,我们就可以用 ``python3 -m paddle.distributed.launch `` 指令运行分布式任务了。其中 ``server_num `` , ``worker_num `` 分别为服务节点和训练节点的数量。在本例中,服务节点有1个,训练节点有2个。
508+ 定义完训练脚本后,我们就可以用 ``fleetrun `` 指令运行分布式任务了。其中 ``server_num `` , ``worker_num `` 分别为服务节点和训练节点的数量。在本例中,服务节点有1个,训练节点有2个。
497509
498510.. code-block :: bash
499511
500- python3 -m paddle.distributed.launch -- server_num=1 --worker_num=2 --gpus=0,1 train.py
512+ fleetrun -- server_num=1 --worker_num=2 train.py
501513
502514 您将看到显示如下日志信息:
503515
0 commit comments