@@ -14,8 +14,9 @@ package which is first introduced as an experimental feature in PyTorch v1.4.
Source code of the two examples can be found in
`PyTorch examples <https://github.com/pytorch/examples>`__

-`Previous <https://deploy-preview-807--pytorch-tutorials-preview.netlify.com/intermediate/ddp_tutorial.html>`__
-`tutorials <https://deploy-preview-807--pytorch-tutorials-preview.netlify.com/intermediate/dist_tuto.html>`__
+Previous tutorials,
+`Getting Started With Distributed Data Parallel <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
+and `Writing Distributed Applications With PyTorch <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`__,
described `DistributedDataParallel <https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html>`__
which supports a specific training paradigm where the model is replicated across
multiple processes and each process handles a split of the input data.
@@ -86,7 +87,7 @@ usages.
        return F.softmax(action_scores, dim=1)

Let's first prepare a helper to run functions remotely on the owner worker of an
-``RRef``. You will find this function been used in several places this
+``RRef``. You will find this function being used in several places in this
tutorial's examples. Ideally, the `torch.distributed.rpc` package should provide
these helper functions out of the box. For example, it will be easier if
applications can directly call ``RRef.some_func(*arg)`` which will then
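
For reference, below is a minimal sketch of such a helper. The names
``_call_method`` and ``_remote_method`` are illustrative assumptions; the full
example may define them slightly differently:

.. code:: python

    import torch.distributed.rpc as rpc

    def _call_method(method, rref, *args, **kwargs):
        # runs on the owner of rref: unwrap the local object and call the method
        return method(rref.local_value(), *args, **kwargs)

    def _remote_method(method, rref, *args, **kwargs):
        # run the given method on the owner of rref and block until the result
        # is fetched back to the caller
        args = [method, rref] + list(args)
        return rpc.rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)
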
@@ -159,7 +160,7 @@ constructor where most lines are initializing various components. The loop at
the end initializes observers remotely on other workers, and holds ``RRefs`` to
those observers locally. The agent will use those observer ``RRefs`` later to
send commands. Applications don't need to worry about the lifetime of ``RRefs``.
-The owner of each ``RRef`` maintains a reference counting map to track it's
+The owner of each ``RRef`` maintains a reference counting map to track its
lifetime, and guarantees the remote data object will not be deleted as long as
there is any live user of that ``RRef``. Please refer to the ``RRef``
`design doc <https://pytorch.org/docs/master/notes/rref.html>`__ for details.
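
As a hedged sketch of that pattern (the ``Observer`` class and the worker
naming scheme below are assumptions for illustration, not necessarily the exact
code in the example), the constructor loop could look like this:

.. code:: python

    import torch.distributed.rpc as rpc

    class Agent:
        def __init__(self, world_size):
            self.ob_rrefs = []
            for ob_rank in range(1, world_size):
                # assumes observer processes are named "observer1", "observer2", ...
                ob_info = rpc.get_worker_info(f"observer{ob_rank}")
                # rpc.remote returns immediately with an RRef to a remote Observer
                # instance; the owner keeps the instance alive while RRefs exist
                self.ob_rrefs.append(rpc.remote(ob_info, Observer))
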
@@ -408,10 +409,10 @@ The RNN model design is borrowed from the word language model in PyTorch
repository, which contains three main components, an embedding table, an
``LSTM`` layer, and a decoder. The code below wraps the embedding table and the
decoder into sub-modules, so that their constructors can be passed to the RPC
-API. In the `EmbeddingTable` sub-module, we intentionally put the `Embedding`
-layer on GPU to cover the use case. In v1.4, RPC always creates CPU tensor
-arguments or return values on the destination worker. If the function takes a
-GPU tensor, you need to move it to the proper device explicitly.
+API. In the ``EmbeddingTable`` sub-module, we intentionally put the
+``Embedding`` layer on GPU to cover the use case. In v1.4, RPC always creates
+CPU tensor arguments or return values on the destination worker. If the function
+takes a GPU tensor, you need to move it to the proper device explicitly.


.. code:: python
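
    # A hedged sketch of the sub-modules described above; the class names follow
    # the prose, but constructor arguments, dropout handling, and weight
    # initialization in the full example may differ.
    import torch.nn as nn

    class EmbeddingTable(nn.Module):
        r"""Embedding table for the RNN model, intentionally placed on GPU."""
        def __init__(self, ntoken, ninp, dropout):
            super(EmbeddingTable, self).__init__()
            self.drop = nn.Dropout(dropout)
            self.encoder = nn.Embedding(ntoken, ninp).cuda()

        def forward(self, input):
            # RPC delivers `input` as a CPU tensor, so move it to GPU explicitly,
            # and move the result back to CPU before it is returned over RPC
            return self.drop(self.encoder(input.cuda())).cpu()

    class Decoder(nn.Module):
        r"""Decoder that maps LSTM hidden states back to the vocabulary."""
        def __init__(self, ntoken, nhid, dropout):
            super(Decoder, self).__init__()
            self.drop = nn.Dropout(dropout)
            self.decoder = nn.Linear(nhid, ntoken)

        def forward(self, output):
            return self.decoder(self.drop(output))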
@@ -446,17 +447,18 @@ With the above sub-modules, we can now piece them together using RPC to
create an RNN model. In the code below ``ps`` represents a parameter server,
which hosts parameters of the embedding table and the decoder. The constructor
uses the `remote <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.remote>`__
-API to create an `EmbeddingTable` object and a `Decoder` object on the parameter
-server, and locally creates the ``LSTM`` sub-module. During the forward pass,
-the trainer uses the ``EmbeddingTable`` ``RRef`` to find the remote sub-module
-and passes the input data to the ``EmbeddingTable`` using RPC and fetches the
-lookup results. Then, it runs the embedding through the local ``LSTM`` layer,
-and finally uses another RPC to send the output to the ``Decoder`` sub-module.
-In general, to implement distributed model parallel training, developers can
-divide the model into sub-modules, invoke RPC to create sub-module instances
-remotely, and use on ``RRef`` to find them when necessary. As you can see in the
-code below, it looks very similar to single-machine model parallel training. The
-main difference is replacing ``Tensor.to(device)`` with RPC functions.
+API to create an ``EmbeddingTable`` object and a ``Decoder`` object on the
+parameter server, and locally creates the ``LSTM`` sub-module. During the
+forward pass, the trainer uses the ``EmbeddingTable`` ``RRef`` to find the
+remote sub-module and passes the input data to the ``EmbeddingTable`` using RPC
+and fetches the lookup results. Then, it runs the embedding through the local
+``LSTM`` layer, and finally uses another RPC to send the output to the
+``Decoder`` sub-module. In general, to implement distributed model parallel
+training, developers can divide the model into sub-modules, invoke RPC to create
+sub-module instances remotely, and use the ``RRef`` to find them when necessary.
+As you can see in the code below, it looks very similar to single-machine model
+parallel training. The main difference is replacing ``Tensor.to(device)`` with
+RPC functions.


.. code:: python
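
    # A hedged sketch of the distributed RNN model described above; it assumes
    # the `_remote_method` helper and the `EmbeddingTable` / `Decoder` sub-modules
    # sketched earlier, and details may differ from the full example.
    import torch.nn as nn
    import torch.distributed.rpc as rpc

    class RNNModel(nn.Module):
        def __init__(self, ps, ntoken, ninp, nhid, nlayers, dropout=0.5):
            super(RNNModel, self).__init__()
            # create the embedding table and the decoder on the parameter server
            self.emb_table_rref = rpc.remote(ps, EmbeddingTable, args=(ntoken, ninp, dropout))
            self.decoder_rref = rpc.remote(ps, Decoder, args=(ntoken, nhid, dropout))
            # keep the LSTM sub-module local to the trainer
            self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)

        def forward(self, input, hidden):
            # embedding lookup runs remotely on the parameter server
            emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
            # the LSTM runs locally on the trainer
            output, hidden = self.rnn(emb, hidden)
            # decoding runs remotely on the parameter server
            decoded = _remote_method(Decoder.forward, self.decoder_rref, output)
            return decoded, hidden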