Distributed PyTorch Training
============================

SageMaker supports the `PyTorch DistributedDataParallel (DDP)
<https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`_
package. You simply need to check the variables in your training script,
such as the world size and the rank of the current host, when initializing
process groups for distributed training.
Then, launch the training job using the
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class
with the ``pytorchddp`` option as the distribution strategy.

.. note::

    This PyTorch DDP support is available
    in the SageMaker PyTorch Deep Learning Containers v1.12 and later.

Adapt Your Training Script
--------------------------

To initialize distributed training in your script, call
`torch.distributed.init_process_group
<https://pytorch.org/docs/master/distributed.html#torch.distributed.init_process_group>`_
with the desired backend and the rank of the current host.

.. code:: python

    import os

    import torch.distributed as dist

    if args.distributed:
        # Initialize the distributed environment.
        world_size = len(args.hosts)
        os.environ['WORLD_SIZE'] = str(world_size)
        host_rank = args.hosts.index(args.current_host)
        dist.init_process_group(backend=args.backend, rank=host_rank)

SageMaker sets the ``'MASTER_ADDR'`` and ``'MASTER_PORT'`` environment variables for you,
but you can also overwrite them.

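For illustration, here is a minimal sketch of how a training script could inspect or
override these variables before initializing the process group. The port value below is
an arbitrary assumption for the example, not a SageMaker default.

.. code:: python

    import os

    # Rendezvous address and port injected by SageMaker; useful to log for debugging.
    master_addr = os.environ.get('MASTER_ADDR')
    master_port = os.environ.get('MASTER_PORT')
    print(f"rendezvous at {master_addr}:{master_port}")

    # Optionally override the port before calling dist.init_process_group
    # (hypothetical value, shown only to illustrate that overwriting is allowed).
    os.environ['MASTER_PORT'] = '23456'
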
**Supported backends:**

- ``gloo`` and ``tcp`` for CPU instances
- ``gloo`` and ``nccl`` for GPU instances

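The choice can also be made at runtime. As a minimal sketch (not part of the SDK), a
script might default to ``nccl`` when GPUs are visible and fall back to ``gloo``
otherwise:

.. code:: python

    import torch
    import torch.distributed as dist

    # Prefer NCCL when CUDA devices are present, otherwise fall back to Gloo.
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'

    # host_rank is computed as in the earlier snippet.
    dist.init_process_group(backend=backend, rank=host_rank)
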
Launching a Distributed Training Job
------------------------------------

You can run multi-node distributed PyTorch training jobs using the
:class:`sagemaker.pytorch.estimator.PyTorch` estimator class.
With ``instance_count=1``, the estimator submits a
single-node training job to SageMaker; with ``instance_count`` greater
than one, a multi-node training job is launched.

To run a distributed training script that adopts
the `PyTorch DistributedDataParallel (DDP) package
<https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html>`_,
choose ``pytorchddp`` as the distributed training option in the ``PyTorch`` estimator.

With the ``pytorchddp`` option, the SageMaker PyTorch estimator runs a SageMaker
training container for PyTorch, sets up the environment for MPI, and launches
the training job with the ``mpirun`` command on each worker, passing along the
information needed for PyTorch DDP initialization.

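In practice this means each worker process can read its rank and world size when it
initializes DDP. The environment variable names below (``RANK``, ``WORLD_SIZE``,
``LOCAL_RANK``) are the standard PyTorch ones and are shown here as an assumption for
illustration, not as documented SageMaker behavior:

.. code:: python

    import os

    # Standard PyTorch process-group variables (assumed to be set by the launcher).
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
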
.. note::

    The SageMaker PyTorch estimator doesn't use ``torchrun`` for distributed training.

For more information about setting up PyTorch DDP in your training script,
see `Getting Started with Distributed Data Parallel
<https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_ in the
PyTorch documentation.

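As a rough, framework-level sketch of what that tutorial covers (generic PyTorch, not
SageMaker-specific code), a DDP training script typically wraps the model in
``DistributedDataParallel`` and shards the data with a ``DistributedSampler``; ``model``
and ``train_dataset`` are placeholders here:

.. code:: python

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    # Assumes dist.init_process_group has already been called as shown earlier,
    # and that `model` and `train_dataset` are defined elsewhere in the script.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ddp_model = DDP(model.to(device))

    # Give each process a distinct shard of the training data.
    sampler = DistributedSampler(
        train_dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
    )
    loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
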
The following example shows how to run a PyTorch DDP training job in SageMaker
using two ``ml.p4d.24xlarge`` instances:

.. code:: python

    from sagemaker.pytorch import PyTorch

    pt_estimator = PyTorch(
        entry_point="train_ptddp.py",
        role="SageMakerRole",
        framework_version="1.12.0",
        py_version="py38",
        instance_count=2,
        instance_type="ml.p4d.24xlarge",
        distribution={
            "pytorchddp": {
                "enabled": True
            }
        }
    )

    pt_estimator.fit("s3://bucket/path/to/training/data")

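If the job consumes more than one input channel, ``fit`` also accepts a dictionary
mapping channel names to S3 locations; the channel names and paths below are
placeholders:

.. code:: python

    pt_estimator.fit({
        "train": "s3://bucket/path/to/training/data",
        "test": "s3://bucket/path/to/test/data",
    })
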

*********************