    import numpy as np
    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F
    import paddle.distributed as dist
    import random
    from paddle.nn import Layer
    from paddle.distributed import fleet
    from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer
    from paddle.io import Dataset, BatchSampler, DataLoader


Create the dataset

.. code-block:: python

    BATCH_NUM = 20
    BATCH_SIZE = 16
    EPOCH_NUM = 4

    IMAGE_SIZE = 784
    CLASS_NUM = 10
    MICRO_BATCH_SIZE = 2

    class RandomDataset(Dataset):
        def __init__(self, num_samples):
            self.num_samples = num_samples

        def __getitem__(self, idx):
            # a random 1x28x28 "image" and a class label in [0, CLASS_NUM)
            image = np.random.random([1, 28, 28]).astype('float32')
            label = np.random.randint(0, CLASS_NUM, (1,)).astype('int64')
            return image, label

        def __len__(self):
            return self.num_samples

    dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
    train_reader = DataLoader(dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              drop_last=True,
                              num_workers=2)
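
A quick way to check what the reader yields per step (an illustrative snippet; the shapes follow from BATCH_SIZE and the sample shapes above):

.. code-block:: python

    for image, label in train_reader():
        # one batch: image is [16, 1, 28, 28], label is [16, 1]
        print(image.shape, label.shape)
        break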


Next, build a model that can run as a pipeline. Every layer of the model must be wrapped with LayerDesc, or with SharedLayerDesc, which inherits from LayerDesc; since no parameters need to be shared here, plain LayerDesc is used (a minimal sketch of the wrapping pattern follows the class definition below).

.. code-block:: python

    class ReshapeHelp(Layer):
        def __init__(self, shape):
            super(ReshapeHelp, self).__init__()
            self.shape = shape

        def forward(self, x):
            return x.reshape(shape=self.shape)

    class AlexNetPipeDesc(PipelineLayer):
        def __init__(self, num_classes=CLASS_NUM, **kwargs):
            self.num_classes = num_classes
            decs = [
                LayerDesc(
                # ... (the argument list and the remaining LayerDesc
                #      entries of AlexNet are elided in this excerpt) ...
            ]
            super(AlexNetPipeDesc, self).__init__(
                layers=decs, loss_fn=nn.CrossEntropyLoss(), **kwargs)
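
To make the wrapping pattern concrete, here is a minimal self-contained sketch of a pipelined two-layer MLP (the class name MLPPipeDesc and the layer sizes are illustrative assumptions, not part of the original example):

.. code-block:: python

    class MLPPipeDesc(PipelineLayer):
        def __init__(self, **kwargs):
            decs = [
                # each LayerDesc records a layer class plus its constructor
                # arguments; instantiation is deferred so that every pipeline
                # stage only builds the layers assigned to it
                LayerDesc(nn.Flatten),
                LayerDesc(nn.Linear, IMAGE_SIZE, 256),
                LayerDesc(nn.ReLU),
                LayerDesc(nn.Linear, 256, CLASS_NUM),
            ]
            super(MLPPipeDesc, self).__init__(
                layers=decs, loss_fn=nn.CrossEntropyLoss(), **kwargs)
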
Then initialize the distributed environment; this step mainly builds the topology of the pipeline communication groups.

.. code-block:: python

    strategy = fleet.DistributedStrategy()
    model_parallel_size = 1
    data_parallel_size = 1
    pipeline_parallel_size = 2
    strategy.hybrid_configs = {
        "dp_degree": data_parallel_size,
        "mp_degree": model_parallel_size,
        "pp_degree": pipeline_parallel_size
    }
    strategy.pipeline_configs = {
        "accumulate_steps": BATCH_SIZE // MICRO_BATCH_SIZE,
        "micro_batch_size": MICRO_BATCH_SIZE
    }

    fleet.init(is_collective=True, strategy=strategy)

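With the constants defined earlier, each global batch is split into BATCH_SIZE // MICRO_BATCH_SIZE micro-batches for gradient accumulation; a quick check of the arithmetic:

.. code-block:: python

    # a 16-sample batch with 2-sample micro-batches -> 8 accumulation steps
    assert BATCH_SIZE // MICRO_BATCH_SIZE == 8
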
To keep pipeline-parallel parameter initialization consistent with ordinary model initialization, a different seed needs to be set on each card.

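The seed-setting code itself is elided from this excerpt; the sketch below shows a common pattern from Paddle's hybrid-parallel examples (the helper set_random_seed and the exact seed-mixing scheme are assumptions, not the original code):

.. code-block:: python

    def set_random_seed(seed, dp_id, rank_id):
        # mix the base seed with the data-parallel id and the global rank,
        # so replicas that should match are seeded identically and the rest differ
        random.seed(seed)
        np.random.seed(seed + dp_id)
        paddle.seed(seed + rank_id)

    hcg = fleet.get_hybrid_communicate_group()
    dp_id = hcg.get_data_parallel_rank()
    rank_id = dist.get_rank()
    set_random_seed(1024, dp_id, rank_id)
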
fleet.distributed_optimizer(...): this step adds distributed attributes to the optimizer. (Several intermediate steps are elided in this excerpt; the complete example code is listed below.)

.. code-block:: python

    class ReshapeHelp(Layer):
        def __init__(self, shape):
            super(ReshapeHelp, self).__init__()
            self.shape = shape

    # ... (the middle of the complete example is elided in this excerpt) ...

    optimizer = fleet.distributed_optimizer(optimizer)

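The construction elided just above typically follows the standard fleet pattern sketched below; the optimizer choice, the PiecewiseDecay values, and hcg.topology() as the topology source are illustrative assumptions rather than a verbatim quote of the elided lines:

.. code-block:: python

    hcg = fleet.get_hybrid_communicate_group()
    model = AlexNetPipeDesc(num_stages=pipeline_parallel_size,
                            topology=hcg.topology())
    scheduler = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=[2], values=[0.001, 0.002], verbose=True)
    optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
                                     parameters=model.parameters())
    model = fleet.distributed_model(model)              # wrap the model for pipeline execution
    optimizer = fleet.distributed_optimizer(optimizer)  # add distributed attributes
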
Start training

model.train_batch(...): this step executes the pipeline in the 1F1B (one-forward-one-backward) schedule.

.. code-block:: python

    for i, (image, label) in enumerate(train_reader()):
        if i >= 5:   # train only a few steps for this demo
            break
        loss = model.train_batch([image, label], optimizer, scheduler)
        print("pp_loss: ", loss.numpy())

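Under 1F1B, train_batch does not push the whole batch through at once: it splits the global batch into accumulate_steps micro-batches and interleaves their forward and backward passes across the pipeline stages. Purely as bookkeeping, the split is equivalent to the following (train_batch performs this internally; the snippet is illustrative):

.. code-block:: python

    num_micro_batches = BATCH_SIZE // MICRO_BATCH_SIZE  # 16 // 2 = 8
    micro_batches = [
        (image[i * MICRO_BATCH_SIZE:(i + 1) * MICRO_BATCH_SIZE],
         label[i * MICRO_BATCH_SIZE:(i + 1) * MICRO_BATCH_SIZE])
        for i in range(num_micro_batches)
    ]
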
How to run (the current machine must have two GPUs):

.. code-block:: bash

    export CUDA_VISIBLE_DEVICES=0,1
    # alexnet_dygraph_pipeline.py is the user's script that runs the dynamic-graph pipeline
    python -m paddle.distributed.launch alexnet_dygraph_pipeline.py

The complete AlexNet-based pipeline-parallel dynamic-graph code: `alex <https://github.com/PaddlePaddle/FleetX/tree/develop/examples/pipeline>`_.

The console output is as follows:
