|
1 | | -This section uses wide_and_deep, a classic model in the recommendation domain, as an example to show how to complete a parameter-server training job with the Fleet API (paddle.distributed.fleet). The complete example code for this quick start is at https://github.com/PaddlePaddle/FleetX/tree/develop/examples/wide_and_deep.
2 | 1 |
|
3 | | -2.1 Version requirements
| 2 | +.. _cluster_quick_start_ps: |
| 3 | + |
| 4 | +Quick Start: Parameter Server
| 5 | +-------------------------------
| 6 | + |
| 7 | +Search and recommendation scenarios often face two problems:
| 8 | + |
| 9 | +1. Massive training data: single-machine training is too slow, so the number of training nodes needs to be increased.
| 10 | +2. High-dimensional, sparse features: the model has too many sparse parameters to fit in the memory of a single machine, so distributed storage is required.
| 11 | + |
| 12 | +The parameter-server (ParameterServer) mode manages the model parameters centrally to achieve distributed storage and updating of the parameters. Nodes/processes in this mode play one of two roles:
| 13 | + |
| 14 | +1. Training node (Trainer/Worker): reads the data, pulls parameters from the server nodes, performs the forward pass and backward gradient computation, and uploads the computed gradients to the server nodes.
| 15 | +2. Server node (Server): after receiving the gradients from all training nodes, aggregates them and updates the parameters, which the training nodes then pull for the next round of training.
| 16 | + |
| 17 | +The parameter-server mode is therefore a good fit for training scenarios that have to store extremely large model parameters, and it is widely used to train search and recommendation models with massive sparse parameters.
| 18 | + |
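As a rough illustration of this division of labor, one synchronous update cycle can be pictured as below (a framework-free toy sketch for intuition only; the function names and the plain-SGD update are assumptions, not Paddle APIs):

.. code-block:: python

    import numpy as np

    def worker_step(params, batch_x, batch_y):
        # a trainer pulls the latest parameters, runs forward/backward on its
        # own mini-batch (here: linear regression), and returns the gradient
        # that it would upload to the server
        pred = batch_x @ params
        return batch_x.T @ (pred - batch_y) / len(batch_x)

    def server_step(params, worker_grads, lr=0.01):
        # the server aggregates the gradients from all trainers and applies one
        # SGD update; trainers then pull the new parameters for the next round
        return params - lr * np.mean(worker_grads, axis=0)
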
| 19 | +2.1 Task overview
4 | 20 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
5 | 21 |
|
6 | | -Before writing the distributed training program, make sure paddlepaddle-2.0.0-rc-cpu or paddlepaddle-2.0.0-rc-gpu (or a later version) of the PaddlePaddle open-source framework is installed.
| 22 | +This section uses wide_and_deep, a classic model in the recommendation domain, as an example to show how to complete a parameter-server training job with PaddlePaddle's distributed training. The complete example code for this quick start is at https://github.com/PaddlePaddle/FleetX/tree/develop/examples/wide_and_deep_dataset.
| 23 | +Before writing the distributed training program, make sure PaddlePaddle 2.3 or a later version of the open-source framework is installed.
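
A quick way to confirm the installed version is to print it (a minimal check; it should report 2.3.0 or later):

.. code-block:: python

    import paddle

    # this tutorial assumes PaddlePaddle 2.3 or newer
    print(paddle.__version__)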
7 | 24 |
|
8 | 25 | 2.2 Procedure
9 | 26 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
|
12 | 29 |
|
13 | 30 | 1. Import the dependencies required for distributed training.
14 | 31 | 2. Define the distributed mode and initialize the distributed training environment.
15 | | - 3. Load the model and data.
16 | | - 4. Define the parameter update strategy and optimizer.
17 | | - 5. Start training.
| 32 | + 3. Load the model.
| 33 | + 4. Build a dataset to load the data.
| 34 | + 5. Define the parameter update strategy and optimizer.
| 35 | + 6. Start training.
| 36 | + |
18 | 37 |
|
19 | 38 | Each of these steps is explained below.
20 | 39 |
|
|
39 | 58 | paddle.enable_static() |
40 | 59 | fleet.init(is_collective=False) |
41 | 60 |
|
42 | | -2.2.3 Load the model and data
43 | | -""""""""""""
| 61 | +2.2.3 Load the model
| | +""""""""""""
44 | 62 |
|
45 | 63 | .. code-block:: python |
46 | 64 |
|
47 | | - # For the model definition, see model.py in examples/wide_and_deep
| 65 | + # For the model definition, see model.py in examples/wide_and_deep_dataset
48 | 66 | from model import WideDeepModel |
49 | | - from reader import WideDeepDataset |
50 | | -
|
51 | 67 | model = WideDeepModel() |
52 | 68 | model.net(is_train=True) |
53 | 69 |
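The full wide&deep network is defined in model.py of the linked example; for this walkthrough the only contract that matters is that ``model.net()`` builds the static-graph input variables and a scalar loss, which later steps consume as ``model.inputs`` and ``model.cost``/``model.loss``. A hypothetical, heavily simplified stand-in could look like this (the class name, feature sizes, and layer choices are illustrative assumptions, not the real example code):

.. code-block:: python

    import paddle

    class TinyWideDeepSketch:
        """Illustrative stand-in for model.py; not the actual example."""

        def net(self, is_train=True):
            # static-graph placeholders (the real model uses many sparse slots)
            dense = paddle.static.data(name="dense_input", shape=[None, 13], dtype="float32")
            label = paddle.static.data(name="label", shape=[None, 1], dtype="int64")

            logits = paddle.static.nn.fc(dense, size=2)
            loss = paddle.nn.functional.cross_entropy(logits, label, reduction="mean")

            # attributes the rest of this tutorial relies on
            self.inputs = [dense, label]   # fed to dataset.init(use_var=...)
            self.loss = self.cost = loss   # minimized by the optimizer / fetched during training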
|
54 | | - def distributed_training(exe, train_model, train_data_path="./data", batch_size=10, epoch_num=1): |
55 | | - train_data = WideDeepDataset(data_path=train_data_path) |
56 | | - reader = train_model.loader.set_sample_generator( |
57 | | - train_data, batch_size=batch_size, drop_last=True, places=paddle.CPUPlace()) |
58 | | -
|
59 | | - for epoch_id in range(epoch_num): |
60 | | - reader.start() |
61 | | - try: |
62 | | - while True: |
63 | | - loss_val = exe.run(program=paddle.static.default_main_program(), |
64 | | - fetch_list=[train_model.cost.name]) |
65 | | - loss_val = np.mean(loss_val) |
66 | | - print("TRAIN ---> pass: {} loss: {}\n".format(epoch_id, loss_val)) |
67 | | - except paddle.common_ops_import.core.EOFException: |
68 | | - reader.reset() |
| 70 | +2.2.4 Build the dataset and load the data
| 71 | +"""""""""""" |
69 | 72 |
|
70 | | - |
71 | | - |
72 | | -2.2.4 Define the synchronous training Strategy and Optimizer
| 73 | +.. code-block:: python |
| 74 | +
|
| 75 | + # For the data processing details, see reader.py in examples/wide_and_deep_dataset
| | + import os  # used below to list the training files (skip if already imported in step 2.2.1)
| | + batch_size = 10             # batch size fed to the dataset, as in the full example
| | + train_data_path = "./data"  # directory with the raw training text files; adjust to your setup
| 76 | + dataset = paddle.distributed.QueueDataset() |
| 77 | + thread_num = 1 |
| 78 | + dataset.init(use_var=model.inputs, |
| 79 | + pipe_command="python reader.py", |
| 80 | + batch_size=batch_size, |
| 81 | + thread_num=thread_num) |
| 82 | +
|
| 83 | + train_files_list = [os.path.join(train_data_path, x) |
| 84 | + for x in os.listdir(train_data_path)] |
| 85 | + dataset.set_filelist(train_files_list) |
| 86 | +
|
| 87 | +Note: for details on how to use dataset, see `Training with InMemoryDataset/QueueDataset <https://fleet-x.readthedocs.io/en/latest/paddle_fleet_rst/parameter_server/performance/dataset.html>`_.
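
On the trainer side, the ``pipe_command`` above streams every raw line of the input files through ``reader.py``, which converts it into the slots declared in ``model.inputs``. A minimal sketch of such a script is shown below, assuming the ``fleet.MultiSlotDataGenerator`` interface; the slot names and parsing logic are illustrative assumptions rather than the actual reader.py:

.. code-block:: python

    # reader.py (illustrative sketch, not the real example code)
    from paddle.distributed import fleet

    class WideDeepReader(fleet.MultiSlotDataGenerator):
        def generate_sample(self, line):
            def parse_one_line():
                fields = line.strip().split(",")
                label = [int(fields[0])]
                dense_feature = [float(x) for x in fields[1:14]]
                # yield (slot_name, value_list) pairs matching model.inputs
                yield [("dense_input", dense_feature), ("label", label)]
            return parse_one_line

    if __name__ == "__main__":
        WideDeepReader().run_from_stdin()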
| 88 | + |
| 89 | + |
| 90 | +2.2.5 Define the distributed training Strategy and Optimizer
73 | 91 | """""""""""" |
74 | 92 |
|
75 | 93 | With the Fleet API, users can define the distributed strategy they want to use through the ``fleet.DistributedStrategy()`` interface. In this example, ``a_sync = True`` selects asynchronous parameter-server training.
|
82 | 100 | dist_strategy = fleet.DistributedStrategy() |
83 | 101 | dist_strategy.a_sync = True |
84 | 102 |
|
85 | | - # Define synchronous training
86 | | - dist_strategy = fleet.DistributedStrategy() |
87 | | - dist_strategy.a_sync = False |
88 | | -
|
89 | | - # Define Geo asynchronous training; Geo async currently supports only the SGD optimizer
90 | | - dist_strategy = fleet.DistributedStrategy() |
91 | | - dist_strategy.a_sync = True |
92 | | - dist_strategy.a_sync_configs = {"k_steps": 100} |
93 | | -
|
94 | 103 | optimizer = paddle.optimizer.SGD(learning_rate=0.0001) |
95 | 104 | optimizer = fleet.distributed_optimizer(optimizer, dist_strategy) |
96 | 105 | optimizer.minimize(model.loss) |
97 | 106 |
|
98 | | -2.2.5 Start training
| 107 | +2.2.6 Start training
99 | 108 | """""""""""" |
100 | 109 |
|
101 | 110 | Once the model and training strategy are defined, we can start training. Because the parameter-server mode involves different roles, different tasks are assigned to different kinds of nodes.
102 | 111 |
|
103 | 112 | For server nodes, first initialize them with the ``init_server()`` interface, then start the service and listen for the gradients sent by the training nodes.
104 | 113 |
|
105 | | -Similarly, for training nodes, initialize them with the ``init_worker()`` interface and then start the training task. Call ``exe.run()`` to train and obtain the loss value of each step.
| 114 | +Similarly, for training nodes, initialize them with the ``init_worker()`` interface and then start the training task. Call ``exe.train_from_dataset()`` to start training.
106 | 115 |
|
107 | 116 | .. code-block:: python |
108 | 117 |
|
|
115 | 124 |
|
116 | 125 | fleet.init_worker() |
117 | 126 |
|
118 | | - distributed_training(exe, model) |
119 | | -
|
| 127 | + for epoch_id in range(1): |
| 128 | + exe.train_from_dataset(paddle.static.default_main_program(), |
| 129 | + dataset, |
| 130 | + paddle.static.global_scope(), |
| 131 | + debug=False, |
| 132 | + fetch_list=[model.cost],
| 133 | + fetch_info=["loss"], |
| 134 | + print_period=1) |
| 135 | + |
120 | 136 | fleet.stop_worker() |
121 | 137 |
|
| 138 | +Note: since Paddle 2.3, ParameterServer training deprecates the dataloader + exe.run() approach; please switch to the dataset + exe.train_from_dataset() approach.
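
Putting the two roles together, the overall control flow of train.py typically splits on the node's role as sketched below (a condensed outline of the pattern described above; the exact code is in the linked example):

.. code-block:: python

    import paddle
    from paddle.distributed import fleet

    if fleet.is_server():
        # server nodes initialize their share of the parameters and then block,
        # serving pull/push requests coming from the trainers
        fleet.init_server()
        fleet.run_server()
    elif fleet.is_worker():
        exe = paddle.static.Executor(paddle.CPUPlace())
        exe.run(paddle.static.default_startup_program())
        fleet.init_worker()
        # ... the exe.train_from_dataset(...) loop shown above goes here ...
        fleet.stop_worker()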
| 139 | + |
| 140 | + |
122 | 141 | 2.3 Run the training script
123 | 142 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
124 | 143 |
|
125 | | -Once the training script is defined, we can launch the distributed job with the ``python3 -m paddle.distributed.launch`` command. ``server_num`` and ``worker_num`` are the numbers of server nodes and training nodes, respectively. In this example there is 1 server node and 2 training nodes.
| 144 | +Once the training script is defined, we can launch the distributed job with the ``fleetrun`` command. ``server_num`` and ``trainer_num`` are the numbers of server nodes and training nodes, respectively. In this example there is 1 server node and 2 training nodes.
126 | 145 |
|
127 | 146 | .. code-block:: bash |
128 | 147 |
|
129 | | - python3 -m paddle.distributed.launch --server_num=1 --worker_num=2 --gpus=0,1 train.py |
| 148 | + fleetrun --server_num=1 --trainer_num=2 train.py |
130 | 149 |
|
131 | | -You will see log output like the following:
| 150 | +You will see log output like the following in the terminal from which the job was launched:
132 | 151 |
|
133 | 152 | .. code-block:: bash |
134 | 153 | |
135 | | - ----------- Configuration Arguments ----------- |
136 | | - gpus: 0,1 |
137 | | - heter_worker_num: None |
138 | | - heter_workers: |
139 | | - http_port: None |
140 | | - ips: 127.0.0.1 |
141 | | - log_dir: log |
142 | | - nproc_per_node: None |
143 | | - server_num: 1 |
144 | | - servers: |
145 | | - training_script: train.py |
146 | | - training_script_args: [] |
147 | | - worker_num: 2 |
148 | | - workers: |
149 | | - ------------------------------------------------ |
150 | | - INFO 2021-05-06 12:14:26,890 launch.py:298] Run parameter-sever mode. pserver arguments:['--worker_num', '--server_num'], cuda count:8 |
151 | | - INFO 2021-05-06 12:14:26,892 launch_utils.py:973] Local server start 1 processes. First process distributed environment info (Only For Debug): |
152 | | - +=======================================================================================+ |
153 | | - | Distributed Envs Value | |
154 | | - +---------------------------------------------------------------------------------------+ |
155 | | - | PADDLE_TRAINERS_NUM 2 | |
156 | | - | TRAINING_ROLE PSERVER | |
157 | | - | POD_IP 127.0.0.1 | |
158 | | - | PADDLE_GLOO_RENDEZVOUS 3 | |
159 | | - | PADDLE_PSERVERS_IP_PORT_LIST 127.0.0.1:34008 | |
160 | | - | PADDLE_PORT 34008 | |
161 | | - | PADDLE_WITH_GLOO 0 | |
162 | | - | PADDLE_HETER_TRAINER_IP_PORT_LIST | |
163 | | - | PADDLE_TRAINER_ENDPOINTS 127.0.0.1:18913,127.0.0.1:10025 | |
164 | | - | PADDLE_GLOO_HTTP_ENDPOINT 127.0.0.1:23053 | |
165 | | - | PADDLE_GLOO_FS_PATH /tmp/tmp8vqb8arq | |
166 | | - +=======================================================================================+ |
167 | | - |
168 | | - INFO 2021-05-06 12:14:26,902 launch_utils.py:1041] Local worker start 2 processes. First process distributed environment info (Only For Debug): |
169 | | - +=======================================================================================+ |
170 | | - | Distributed Envs Value | |
171 | | - +---------------------------------------------------------------------------------------+ |
172 | | - | PADDLE_GLOO_HTTP_ENDPOINT 127.0.0.1:23053 | |
173 | | - | PADDLE_GLOO_RENDEZVOUS 3 | |
174 | | - | PADDLE_PSERVERS_IP_PORT_LIST 127.0.0.1:34008 | |
175 | | - | PADDLE_WITH_GLOO 0 | |
176 | | - | PADDLE_TRAINER_ENDPOINTS 127.0.0.1:18913,127.0.0.1:10025 | |
177 | | - | FLAGS_selected_gpus 0 | |
178 | | - | PADDLE_GLOO_FS_PATH /tmp/tmp8vqb8arq | |
179 | | - | PADDLE_TRAINERS_NUM 2 | |
180 | | - | TRAINING_ROLE TRAINER | |
181 | | - | XPU_VISIBLE_DEVICES 0 | |
182 | | - | PADDLE_HETER_TRAINER_IP_PORT_LIST | |
183 | | - | PADDLE_TRAINER_ID 0 | |
184 | | - | CUDA_VISIBLE_DEVICES 0 | |
185 | | - | FLAGS_selected_xpus 0 | |
186 | | - +=======================================================================================+ |
187 | | - |
188 | | - INFO 2021-05-06 12:14:26,921 launch_utils.py:903] Please check servers, workers and heter_worker logs in log/workerlog.*, log/serverlog.* and log/heterlog.* |
189 | | - INFO 2021-05-06 12:14:33,446 launch_utils.py:914] all workers exit, going to finish parameter server and heter_worker. |
190 | | - INFO 2021-05-06 12:14:33,446 launch_utils.py:926] all parameter server are killed |
| 154 | + LAUNCH INFO 2022-05-18 11:27:17,761 ----------- Configuration ---------------------- |
| 155 | + LAUNCH INFO 2022-05-18 11:27:17,761 devices: None |
| 156 | + LAUNCH INFO 2022-05-18 11:27:17,761 elastic_level: -1 |
| 157 | + LAUNCH INFO 2022-05-18 11:27:17,761 elastic_timeout: 30 |
| 158 | + LAUNCH INFO 2022-05-18 11:27:17,761 gloo_port: 6767 |
| 159 | + LAUNCH INFO 2022-05-18 11:27:17,761 host: None
| 160 | + LAUNCH INFO 2022-05-18 11:27:17,761 job_id: default |
| 161 | + LAUNCH INFO 2022-05-18 11:27:17,761 legacy: False |
| 162 | + LAUNCH INFO 2022-05-18 11:27:17,761 log_dir: log |
| 163 | + LAUNCH INFO 2022-05-18 11:27:17,761 log_level: INFO |
| 164 | + LAUNCH INFO 2022-05-18 11:27:17,762 master: None |
| 165 | + LAUNCH INFO 2022-05-18 11:27:17,762 max_restart: 3 |
| 166 | + LAUNCH INFO 2022-05-18 11:27:17,762 nnodes: 1 |
| 167 | + LAUNCH INFO 2022-05-18 11:27:17,762 nproc_per_node: None |
| 168 | + LAUNCH INFO 2022-05-18 11:27:17,762 rank: -1 |
| 169 | + LAUNCH INFO 2022-05-18 11:27:17,762 run_mode: collective |
| 170 | + LAUNCH INFO 2022-05-18 11:27:17,762 server_num: 1 |
| 171 | + LAUNCH INFO 2022-05-18 11:27:17,762 servers: |
| 172 | + LAUNCH INFO 2022-05-18 11:27:17,762 trainer_num: 2 |
| 173 | + LAUNCH INFO 2022-05-18 11:27:17,762 trainers: |
| 174 | + LAUNCH INFO 2022-05-18 11:27:17,762 training_script: train.py |
| 175 | + LAUNCH INFO 2022-05-18 11:27:17,762 training_script_args: [] |
| 176 | + LAUNCH INFO 2022-05-18 11:27:17,762 with_gloo: 0 |
| 177 | + LAUNCH INFO 2022-05-18 11:27:17,762 -------------------------------------------------- |
| 178 | + LAUNCH INFO 2022-05-18 11:27:17,772 Job: default, mode ps, replicas 1[1:1], elastic False |
| 179 | + LAUNCH INFO 2022-05-18 11:27:17,775 Run Pod: evjsyn, replicas 3, status ready |
| 180 | + LAUNCH INFO 2022-05-18 11:27:17,795 Watching Pod: evjsyn, replicas 3, status running |
| 181 | +
|
| 182 | +Meanwhile, log files for the server node and the training nodes are generated in the log directory.
| 183 | +Server node log: default.evjsyn.ps.0.log. It should contain the following line, which shows that the server node started successfully and is ready to serve.
| 184 | + |
| 185 | +.. code-block:: bash |
| 186 | +
|
| 187 | + I0518 11:27:20.730531 177420 brpc_ps_server.cc:73] running server with rank id: 0, endpoint: IP:PORT |
| 188 | +
|
| 189 | +Training node log: default.evjsyn.trainer.0.log. It prints some of the variable values produced during training.
| 190 | + |
| 191 | +.. code-block:: bash |
| 192 | +
|
| 193 | + time: [2022-05-18 11:27:27], batch: [1], loss[1]:[0.666739] |
| 194 | + time: [2022-05-18 11:27:27], batch: [2], loss[1]:[0.690405] |
| 195 | + time: [2022-05-18 11:27:27], batch: [3], loss[1]:[0.681693] |
| 196 | + time: [2022-05-18 11:27:27], batch: [4], loss[1]:[0.703863] |
| 197 | + time: [2022-05-18 11:27:27], batch: [5], loss[1]:[0.670717] |
| 198 | +
|
| 199 | +Note: for questions about launching jobs, please refer to `launch <https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/distributed/launch_cn.html>`_.