From b4e9f214a89e8cc6aef8e37692f6d4ddfcea33b3 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Wed, 1 Jan 2020 19:50:07 -0800 Subject: [PATCH 01/14] Adding an RPC Tutorial --- index.rst | 5 + intermediate_source/rpc_tutorial.rst | 529 +++++++++++++++++++++++++++ 2 files changed, 534 insertions(+) create mode 100644 intermediate_source/rpc_tutorial.rst diff --git a/index.rst b/index.rst index 5817cadae94..884ec2148e1 100644 --- a/index.rst +++ b/index.rst @@ -203,6 +203,11 @@ Parallel and Distributed Training :description: :doc:`/intermediate/dist_tuto` :figure: _static/img/distributed/DistPyTorch.jpg +.. customgalleryitem:: + :tooltip: Getting Started with Distributed RPC Framework + :description: :doc:`/intermediate/rpc_tutorial` + :figure: _static/img/distributed/DistPyTorch.jpg + .. customgalleryitem:: :tooltip: PyTorch distributed trainer with Amazon AWS :description: :doc:`/beginner/aws_distributed_training_tutorial` diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst new file mode 100644 index 00000000000..9a4a219887e --- /dev/null +++ b/intermediate_source/rpc_tutorial.rst @@ -0,0 +1,529 @@ +Getting Started with Distributed RPC Framework +================================================= +**Author**: `Shen Li `_ + + +This tutorial uses two simple examples to demonstrate how to build distributed +applications with the `torch.distributed.rpc` package. Source code of the two +examples can be found in `PyTorch examples `__ + +Previous tutorials described `DistributedDataParallel `__ +which supports a specific training paradigm where the model is replicated across +multiple processes and each process handles a split of the input data. +Sometimes, you might run into scenarios that require different training +paradigms: + +1) In reinforcement learning, it might be relatively expensive to acquire + training data from environments while the model itself can be quite small. In + this case, it might be useful to spawn multiple observers running in parallel + and share a single agent. In this case, the agent takes care of the training + locally, but the application would still need libraries to send and receive + data between observers and the trainer +2) Your model might be too large to fit in GPUs on a single machine, and hence + would need a library to help split a model onto multiple machines. Or you + might be implementing a parameter server training framework, where model + parameters and trainers live on different machines. + + +The `torch.distributed.rpc `__ package +can help with the above scenarios. In case 1, `RPC `__ +and `RRef `__ can help send data +from one worker to another and also easily referencing remote data objects. In +case 2, `distributed autograd `__ +and `distributed optimizer `__ +allows executing backward and optimizer step as if it is local training. In the +next two sections, we will demonstrate APIs of +`torch.distributed.rpc `__ using a +reinforcement learning example and a language model example. Please note, this +tutorial is not aiming at building the most accurate or efficient models to +solve given problems, instead the main goal is to show how to use the +`torch.distributed.rpc `__ package to +build distributed training applications. + + + +Distributed Reinforcement Learning using RPC and RRef +----------------------------------------------------- + +This section describes steps to build a toy distributed reinforcement learning +model using RPC to solve CartPole-v1 from `OpenAI Gym `__. +The policy code is mostly borrowed from the existing single-thread +`example `__ +as shown below. We will skip details of the ``Policy`` design, and focus on RPC +usages. + +.. code:: python + + class Policy(nn.Module): + + def __init__(self): + super(Policy, self).__init__() + self.affine1 = nn.Linear(4, 128) + self.dropout = nn.Dropout(p=0.6) + self.affine2 = nn.Linear(128, 2) + + self.saved_log_probs = [] + self.rewards = [] + + def forward(self, x): + x = self.affine1(x) + x = self.dropout(x) + x = F.relu(x) + action_scores = self.affine2(x) + return F.softmax(action_scores, dim=1) + +Let's first prepare a helper function to call a function on a local ``RRef``. It +might look unnecessary at the first glance, as you could simply do +``rref.local_value().some_func(args)`` to run the target function. The reason +for adding this helper function is because there is no way to get a reference +of a remote value, and ``local_value`` is only available on the owner of the +``RRef``. + +.. code:: python + + def _call_method(method, rref, *args, **kwargs): + return method(rref.local_value(), *args, **kwargs) + + + def _remote_method(method, rref, *args, async_call=False, **kwargs): + args = [method, rref] + list(args) + func = rpc_async if async_call else rpc_sync + return func(rref.owner(), _call_method, args=args, kwargs=kwargs) + + # to call a function on an rref, we could do the following + _remote_method(some_func, rref, *args) + + +We are ready to present the observer. In this example, each observer creates its +own environment, and waits for the agent's command to run an episode. In each +episode, one observer loops at most ``n_steps`` iterations, and in each +iteration, it uses RPC to pass its environment state to the agent and gets an +action back. Then it applies that action to its environment, and gets the reward +and the next state from the environment. After that, the observer uses another +RPC to report the reward to the agent. Again, please note that, this is +obviously not the most efficient observer implementation. For example, one +simple optimization could be packing current state and last reward in one RPC to +reduce the communication overhead. However, the goal is to demonstrate RPC API +instead of building the best solver for CartPole. + +.. code:: python + + class Observer: + + def __init__(self): + self.id = rpc.get_worker_info().id + self.env = gym.make('CartPole-v1') + self.env.seed(args.seed) + + def run_episode(self, agent_rref, n_steps): + state, ep_reward = self.env.reset(), 0 + for step in range(n_steps): + # send the state to the agent to get an action + action = _remote_method(Agent.select_action, agent_rref, self.id, state) + + # apply the action to the environment, and get the reward + state, reward, done, _ = self.env.step(action) + + # report the reward to the agent for training purpose + _remote_method(Agent.report_reward, agent_rref, self.id, reward) + + if done: + break + + +The code for agent is a little more complex, and we will break it into multiple +pieces. In this example, the agent serves as both the trainer and the master, +such that it sends command to multiple distributed observers to run episodes, +and it also records all actions and rewards locally which will be used during +the training phase after each episode. The code below shows ``Agent`` +constructor where most lines are initializing various components. The loop at +the end initializes observers on other workers, and holds ``RRefs`` to those +observers locally. The agent will use those observer ``RRefs`` later to send +commands. + + +.. code:: python + + class Agent: + def __init__(self, world_size): + self.ob_rrefs = [] + self.agent_rref = RRef(self) + self.rewards = {} + self.saved_log_probs = {} + self.policy = Policy() + self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2) + self.eps = np.finfo(np.float32).eps.item() + self.running_reward = 0 + self.reward_threshold = gym.make('CartPole-v1').spec.reward_threshold + for ob_rank in range(1, world_size): + ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank)) + self.ob_rrefs.append(remote(ob_info, Observer)) + self.rewards[ob_info.id] = [] + self.saved_log_probs[ob_info.id] = [] + + +Next, the agent exposes two APIs to allow observers to select actions and report +rewards. Those functions are only run locally on the agent, but will be +triggered by observers through RPC. + + +.. code:: python + + class Agent: + ... + def select_action(self, ob_id, state): + state = torch.from_numpy(state).float().unsqueeze(0) + probs = self.policy(state) + m = Categorical(probs) + action = m.sample() + self.saved_log_probs[ob_id].append(m.log_prob(action)) + return action.item() + + def report_reward(self, ob_id, reward): + self.rewards[ob_id].append(reward) + + +Let's add a ``run_episode`` function on agent which tells all observers +to execute an episode. In this function, it first creates a list to collect +futures from asynchronous RPCs, and then loop over all observer ``RRefs`` to +make asynchronous RPCs. In these RPCs, the agent also passes an ``RRef`` of +itself to the observer, so that the observer can call functions on the agent as +well. As shown above, each observer will make RPCs back to the agent, which is +actually nested RPCs. After each episode, the ``saved_log_probs`` and +``rewards`` will contain the recorded action probs and rewards. + + +.. code:: python + + class Agent: + ... + def run_episode(self, n_steps=0): + futs = [] + for ob_rref in self.ob_rrefs: + # make async RPC to kick off an episode on all observers + futs.append( + _remote_method( + Observer.run_episode, + ob_rref, + self.agent_rref, + n_steps, + async_call=True + ) + ) + + # wait until all obervers have finished this episode + for fut in futs: + fut.wait() + + +Finally, after one episode, the agent needs to train the model, which +is implemented in the ``finish_episode`` function below. It is also a local +function and mostly borrowed from the single-thread +`example `__. + + + +.. code:: python + + class Agent: + ... + def finish_episode(self): + # joins probs and rewards from different observers into lists + R, probs, rewards = 0, [], [] + for ob_id in self.rewards: + probs.extend(self.saved_log_probs[ob_id]) + rewards.extend(self.rewards[ob_id]) + + # use the minimum observer reward to calculate the running reward + min_reward = min([sum(self.rewards[ob_id]) for ob_id in self.rewards]) + self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward + + # clear saved probs and rewards + for ob_id in self.rewards: + self.rewards[ob_id] = [] + self.saved_log_probs[ob_id] = [] + + policy_loss, returns = [], [] + for r in rewards[::-1]: + R = r + args.gamma * R + returns.insert(0, R) + returns = torch.tensor(returns) + returns = (returns - returns.mean()) / (returns.std() + self.eps) + for log_prob, R in zip(probs, returns): + policy_loss.append(-log_prob * R) + self.optimizer.zero_grad() + policy_loss = torch.cat(policy_loss).sum() + policy_loss.backward() + self.optimizer.step() + return min_reward + + +With ``Policy``, ``Observer``, and ``Agent`` classes, we are ready to launch +multiple processes to perform the distributed training. In this example, all +processes run the same ``run_worker`` function, and they use the rank to +distinguish their role. Rank 0 is always the agent, and all other ranks are +observers. As agent as server as master, repeatedly call ``run_episode`` and +``finish_episode`` until the running reward surpasses the reward threshold +specified by the environment. All observers just passively waiting for commands +from the agent. + + +.. code:: python + + def run_worker(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '29500' + if rank == 0: + # rank0 is the agent + rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size) + + agent = Agent(world_size) + for i_episode in count(1): + n_steps = int(TOTAL_EPISODE_STEP / (args.world_size - 1)) + agent.run_episode(n_steps=n_steps) + last_reward = agent.finish_episode() + + if i_episode % args.log_interval == 0: + print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format( + i_episode, last_reward, agent.running_reward)) + + if agent.running_reward > agent.reward_threshold: + print("Solved! Running reward is now {}!".format(agent.running_reward)) + break + else: + # other ranks are the observer + rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size) + # observers passively waiting for instructions from agents + rpc.shutdown() + + + mp.spawn( + run_worker, + args=(args.world_size, ), + nprocs=args.world_size, + join=True + ) + + +In this example, we show how to use RPC as the communication vehicle to pass +date across workers, and how to use RRef to reference remote objects. It is true +that you could build the entire structure directly on top of ``ProcessGroup`` +``send`` and ``recv`` APIs or use other communication/RPC libraries. However, +by using `torch.dstributed.rpc`, you can get the native support plus +continuously optimized performance under the hood. + +Next, we will show how to combine RPC and RRef with distributed autograd and +distributed optimizer to perform distributed model parallel training. + + + + +Distributed RNN using Distributed Autograd and Distributed Optimizer +-------------------------------------------------------------------- + +In this section, we use an RNN model to show how to build distributed model +parallel training using the RPC API. The example RNN model is very small and +easily fit into a single GPU, but developer can apply the similar techniques to +much larger models that need to span multiple devices. The RNN model design is +borrowed from the word language model in PyTorch +`example `__ +repository, which contains three main components, an embedding table, an +``LSTM`` layer, and a decoder, as shown below. + + +.. code:: python + + class EmbeddingTable(nn.Module): + def __init__(self, ntoken, ninp, dropout): + super(EmbeddingTable, self).__init__() + self.drop = nn.Dropout(dropout) + self.encoder = nn.Embedding(ntoken, ninp) + self.encoder.weight.data.uniform_(-0.1, 0.1) + + def forward(self, input): + return self.drop(self.encoder(input)) + + + class RNN(nn.Module): + def __init__(self, ninp, nhid, nlayers, dropout): + super(RNN, self).__init__() + self.lstm = nn.LSTM(ninp, nhid, nlayers, dropout=dropout) + + def forward(self, emb, hidden): + return self.lstm(emb, hidden) + + + class Decoder(nn.Module): + def __init__(self, ntoken, nhid, dropout): + super(Decoder, self).__init__() + self.drop = nn.Dropout(dropout) + self.decoder = nn.Linear(nhid, ntoken) + self.decoder.bias.data.zero_() + self.decoder.weight.data.uniform_(-0.1, 0.1) + + def forward(self, output): + return self.decoder(self.drop(output)) + + +With the above three sub-modules, we can now piece them together using RPC to +create an RNN model. In the code below ``ps`` represents a parameter server, +which hosts paremeters of the embedding table and the decoder. The constructor +uses the `remote https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.remote`__ +API to create an `EmbeddingTable` and a `Decoder` object on the parameter +server, and locally creates the ``LSTM`` sub-module. During the forward pass, +the trainer uses the ``EmbeddingTable`` ``RRef`` to find the remote sub-module +and passes the input data to the ``EmbeddingTable`` using RPC and fetches the +lookup results. Then, it runs the embedding through the local ``LSTM`` layer, +and finally uses another RPC to send the output to the ``Decoder`` sub-module. +In general, to implement distributed model parallel training, developers can +divide the model into sub-modules, invoke RPC to create sub-module instances +remotely, and use on ``RRef`` to find them when necessary. As you can see in the +code below, it looks very similar to single-machine model parallel training. The +main difference is replacing ``Tensor.to(device)`` with RPC functions. + + +.. code:: python + + class RNNModel(nn.Module): + def __init__(self, ps, ntoken, ninp, nhid, nlayers, dropout=0.5): + super(RNNModel, self).__init__() + + # setup embedding table remotely + self.emb_table_rref = rpc.remote(ps, EmbeddingTable, args=(ntoken, ninp, dropout)) + # setup LSTM locally + self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout) + # setup decoder remotely + self.decoder_rref = rpc.remote(ps, Decoder, args=(ntoken, nhid, dropout)) + + def forward(self, input, hidden): + # pass input to the remote embedding table and fetch emb tensor back + emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input) + output, hidden = self.rnn(emb, hidden) + # pass output to the rremote decoder and get the decoded output back + decoded = _remote_method(Decoder.forward, self.decoder_rref, output) + return decoded, hidden + +Before introducing the distributed optimizer, let's add a helper function to +generate a list of RRefs of model parameters, which will be consumed by the +distributed optimizer. In local training, applications could call +``Module.parameters()`` to grab references to all parameter tensors, and pass it +to the local optimizer to update. However, the same API does not work in +the distributed training scenarios as some parameters live on remote machines. +Therefore, instead of taking a list of parameter ``Tensors``, the distributed +optimizer takes a list of ``RRefs``, one ``RRef`` per model parameter for both +local and remote parameters. The helper function is pretty simple, just call +``Module.parameters()`` and creates a local ``RRef`` on each of the parameters. + +.. code:: python + + def _parameter_rrefs(module): + param_rrefs = [] + for param in module.parameters(): + param_rrefs.append(RRef(param)) + return param_rrefs + +Then, as the ``RNNModel`` contains three sub-modules, we need to call +``_parameter_rrefs`` three times, and wrap that into another helper function. + +.. code:: python + + class RNNModel(nn.Module): + ... + def parameter_rrefs(self): + remote_params = [] + # get RRefs of embedding table + remote_params.extend(_remote_method(_parameter_rrefs, self.emb_table_rref)) + # create RRefs for local parameters + remote_params.extend(_parameter_rrefs(self.rnn)) + # get RRefs of decoder + remote_params.extend(_remote_method(_parameter_rrefs, self.decoder_rref)) + return remote_params + +Now, we are ready to implement the training loop. After initializing the model +arguments, we create the ``RNNModel`` and the ``DistributedOptimizer``. The +distributed optimizer will take a list of parameter ``RRefs``, find all distinct +owner workers, and create the given local optimizer (i.e., ``SGD`` in this case) +on each of the owner worker using the given arguments (i.e., ``lr=0.05``). + +In the training loop, it first creates a distributed autograd context, which +will help the distributed autograd engine to find gradients and involved RPC +send/recv functions. Then, it kicks off the forward pass as if it is a local +model, and run the distributed backward pass. For the distributed backward, you +only need to specify a list of roots, in this case, it is the loss ``Tensor``. +The distributed autograd engine will traverse the distributed graph +automatically and write gradients properly. Next, it runs the ``step`` +API on the distributed optimizer, which will reach out to all involved local +optimizers to update model parameters. Compared to local training, one minor +difference is that you don't need to run ``zero_grad()`` because each autograd +context has dedicated space to store gradients, and as we create a context +per iteration, those gradients from different iterations will not accumulate to +the same set of ``Tensors``. + +.. code:: python + + def run_trainer(): + batch = 5 + ntoken = 10 + ninp = 2 + + nhid = 3 + nindices = 3 + nlayers = 4 + hidden = ( + torch.randn(nlayers, nindices, nhid), + torch.randn(nlayers, nindices, nhid) + ) + + model = rnn.RNNModel('ps', ntoken, ninp, nhid, nlayers) + + # setup distributed optimizer + opt = DistributedOptimizer( + optim.SGD, + model.parameter_rrefs(), + lr=0.05, + ) + + # train for 10 iterations + for epoch in range(10): + # create distributed autograd context + with dist_autograd.context(): + inp = torch.LongTensor(batch, nindices) % ntoken + hidden[0].detach_() + hidden[1].detach_() + output, hidden = model(inp, hidden) + # run distributed backward pass + dist_autograd.backward([output.sum()]) + # run distributed optimizer + opt.step() + # not necessary to zero grads as each iteration creates a different + # distributed autograd context which hosts different grads + print("Training epoch {}".format(epoch)) + + +Finally, let's add some glue code to launch the parameter server and the trainer +processes. + + +.. code:: python + + def run_ps(): + pass + + def run_worker(name, rank, func, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '29500' + rpc.init_rpc(name, rank=rank, world_size=world_size) + + func() + + # block until all rpcs finish + rpc.shutdown() + + mp.set_start_method('spawn') + ps = mp.Process(target=run_worker, args=("ps", 0, run_ps, 2)) + ps.start() + + trainer = mp.Process(target=run_worker, args=("trainer", 1, run_trainer, 2)) + trainer.start() + ps.join() + trainer.join() From 494f02f1c937df1afae029262dac99c2c1b9c913 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 3 Jan 2020 14:33:02 -0800 Subject: [PATCH 02/14] Address comments --- index.rst | 1 + intermediate_source/rpc_tutorial.rst | 73 ++++++++++++++++++---------- 2 files changed, 49 insertions(+), 25 deletions(-) diff --git a/index.rst b/index.rst index 884ec2148e1..67b7082cc0e 100644 --- a/index.rst +++ b/index.rst @@ -382,6 +382,7 @@ PyTorch Fundamentals In-Depth intermediate/model_parallel_tutorial intermediate/ddp_tutorial intermediate/dist_tuto + intermediate/rpc_tutorial beginner/aws_distributed_training_tutorial .. toctree:: diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst index 9a4a219887e..c36c06bb7e2 100644 --- a/intermediate_source/rpc_tutorial.rst +++ b/intermediate_source/rpc_tutorial.rst @@ -54,6 +54,9 @@ usages. .. code:: python + import torch.nn as nn + import torch.nn.functional as F + class Policy(nn.Module): def __init__(self): @@ -81,17 +84,25 @@ of a remote value, and ``local_value`` is only available on the owner of the .. code:: python + from torch.distributed.rpc import rpc_sync + def _call_method(method, rref, *args, **kwargs): return method(rref.local_value(), *args, **kwargs) - def _remote_method(method, rref, *args, async_call=False, **kwargs): + def _remote_method(method, rref, *args, **kwargs): args = [method, rref] + list(args) - func = rpc_async if async_call else rpc_sync - return func(rref.owner(), _call_method, args=args, kwargs=kwargs) + return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs) # to call a function on an rref, we could do the following - _remote_method(some_func, rref, *args) + # _remote_method(some_func, rref, *args) + + +Ideally, the `torch.distributed.rpc` package should provide these helper +functions out of box. For example, it will be easier if applications can +directly call ``RRef.some_func(*arg)`` which will then translate to RPC to the +``RRef`` owner. The progress on this API is tracked in in +`pytorch/pytorch#31743 `__. We are ready to present the observer. In this example, each observer creates its @@ -108,6 +119,9 @@ instead of building the best solver for CartPole. .. code:: python + import gym + import torch.distributed.rpc as rpc + class Observer: def __init__(self): @@ -144,6 +158,15 @@ commands. .. code:: python + import gym + import numpy as np + + import torch + import torch.distributed.rpc as rpc + import torch.optim as optim + from torch.distributed.rpc import RRef, rpc_async, remote + from torch.distributions import Categorical + class Agent: def __init__(self, world_size): self.ob_rrefs = [] @@ -198,16 +221,13 @@ actually nested RPCs. After each episode, the ``saved_log_probs`` and class Agent: ... def run_episode(self, n_steps=0): - futs = [] for ob_rref in self.ob_rrefs: # make async RPC to kick off an episode on all observers futs.append( - _remote_method( - Observer.run_episode, - ob_rref, - self.agent_rref, - n_steps, - async_call=True + rpc_async( + ob_rref.owner(), + _call_method, + args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps) ) ) @@ -270,6 +290,11 @@ from the agent. .. code:: python + import os + from itertools import count + + import torch.multiprocessing as mp + def run_worker(rank, world_size): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '29500' @@ -293,7 +318,7 @@ from the agent. else: # other ranks are the observer rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size) - # observers passively waiting for instructions from agents + # observers passively waiting for instructions from the agent rpc.shutdown() @@ -328,7 +353,9 @@ much larger models that need to span multiple devices. The RNN model design is borrowed from the word language model in PyTorch `example `__ repository, which contains three main components, an embedding table, an -``LSTM`` layer, and a decoder, as shown below. +``LSTM`` layer, and a decoder. The code below wraps the embedding table and the +decode into sub-modules, so that their constructors can be passed to the RPC +API. .. code:: python @@ -344,15 +371,6 @@ repository, which contains three main components, an embedding table, an return self.drop(self.encoder(input)) - class RNN(nn.Module): - def __init__(self, ninp, nhid, nlayers, dropout): - super(RNN, self).__init__() - self.lstm = nn.LSTM(ninp, nhid, nlayers, dropout=dropout) - - def forward(self, emb, hidden): - return self.lstm(emb, hidden) - - class Decoder(nn.Module): def __init__(self, ntoken, nhid, dropout): super(Decoder, self).__init__() @@ -365,10 +383,10 @@ repository, which contains three main components, an embedding table, an return self.decoder(self.drop(output)) -With the above three sub-modules, we can now piece them together using RPC to +With the above sub-modules, we can now piece them together using RPC to create an RNN model. In the code below ``ps`` represents a parameter server, -which hosts paremeters of the embedding table and the decoder. The constructor -uses the `remote https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.remote`__ +which hosts parameters of the embedding table and the decoder. The constructor +uses the `remote `__ API to create an `EmbeddingTable` and a `Decoder` object on the parameter server, and locally creates the ``LSTM`` sub-module. During the forward pass, the trainer uses the ``EmbeddingTable`` ``RRef`` to find the remote sub-module @@ -414,6 +432,7 @@ optimizer takes a list of ``RRefs``, one ``RRef`` per model parameter for both local and remote parameters. The helper function is pretty simple, just call ``Module.parameters()`` and creates a local ``RRef`` on each of the parameters. + .. code:: python def _parameter_rrefs(module): @@ -422,9 +441,11 @@ local and remote parameters. The helper function is pretty simple, just call param_rrefs.append(RRef(param)) return param_rrefs + Then, as the ``RNNModel`` contains three sub-modules, we need to call ``_parameter_rrefs`` three times, and wrap that into another helper function. + .. code:: python class RNNModel(nn.Module): @@ -439,6 +460,7 @@ Then, as the ``RNNModel`` contains three sub-modules, we need to call remote_params.extend(_remote_method(_parameter_rrefs, self.decoder_rref)) return remote_params + Now, we are ready to implement the training loop. After initializing the model arguments, we create the ``RNNModel`` and the ``DistributedOptimizer``. The distributed optimizer will take a list of parameter ``RRefs``, find all distinct @@ -459,6 +481,7 @@ context has dedicated space to store gradients, and as we create a context per iteration, those gradients from different iterations will not accumulate to the same set of ``Tensors``. + .. code:: python def run_trainer(): From 12925975281f60dc8a036747a9898f71e9ab89dd Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 3 Jan 2020 14:40:32 -0800 Subject: [PATCH 03/14] address comments --- intermediate_source/rpc_tutorial.rst | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst index c36c06bb7e2..ab4ce3ebb82 100644 --- a/intermediate_source/rpc_tutorial.rst +++ b/intermediate_source/rpc_tutorial.rst @@ -285,7 +285,11 @@ distinguish their role. Rank 0 is always the agent, and all other ranks are observers. As agent as server as master, repeatedly call ``run_episode`` and ``finish_episode`` until the running reward surpasses the reward threshold specified by the environment. All observers just passively waiting for commands -from the agent. +from the agent. The code is wrapped by +`rpc.init_rpc `__ and +`rpc.shutdown `__, +which initializes and terminates RPC instances respectively. More details are +available in the API page. .. code:: python @@ -319,6 +323,8 @@ from the agent. # other ranks are the observer rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size) # observers passively waiting for instructions from the agent + + # block until all rpcs finish, and shutdown the RPC instance rpc.shutdown() @@ -334,7 +340,7 @@ In this example, we show how to use RPC as the communication vehicle to pass date across workers, and how to use RRef to reference remote objects. It is true that you could build the entire structure directly on top of ``ProcessGroup`` ``send`` and ``recv`` APIs or use other communication/RPC libraries. However, -by using `torch.dstributed.rpc`, you can get the native support plus +by using `torch.distributed.rpc`, you can get the native support plus continuously optimized performance under the hood. Next, we will show how to combine RPC and RRef with distributed autograd and @@ -539,7 +545,7 @@ processes. func() - # block until all rpcs finish + # block until all rpcs finish, and shutdown the RPC instance rpc.shutdown() mp.set_start_method('spawn') From 7c0b10ec19726c6e995938b796c080404395078f Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 3 Jan 2020 15:00:44 -0800 Subject: [PATCH 04/14] address comments --- intermediate_source/rpc_tutorial.rst | 35 +++++++++++++--------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst index ab4ce3ebb82..24a424ecdfd 100644 --- a/intermediate_source/rpc_tutorial.rst +++ b/intermediate_source/rpc_tutorial.rst @@ -442,10 +442,10 @@ local and remote parameters. The helper function is pretty simple, just call .. code:: python def _parameter_rrefs(module): - param_rrefs = [] - for param in module.parameters(): - param_rrefs.append(RRef(param)) - return param_rrefs + param_rrefs = [] + for param in module.parameters(): + param_rrefs.append(RRef(param)) + return param_rrefs Then, as the ``RNNModel`` contains three sub-modules, we need to call @@ -535,24 +535,21 @@ processes. .. code:: python - def run_ps(): - pass - - def run_worker(name, rank, func, world_size): + def run_worker(rank, world_size): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '29500' - rpc.init_rpc(name, rank=rank, world_size=world_size) - - func() + if rank == 1: + rpc.init_rpc("trainer", rank=rank, world_size=world_size) + _run_trainer() + else: + rpc.init_rpc("ps", rank=rank, world_size=world_size) + # parameter server do nothing + pass - # block until all rpcs finish, and shutdown the RPC instance + # block until all rpcs finish rpc.shutdown() - mp.set_start_method('spawn') - ps = mp.Process(target=run_worker, args=("ps", 0, run_ps, 2)) - ps.start() - trainer = mp.Process(target=run_worker, args=("trainer", 1, run_trainer, 2)) - trainer.start() - ps.join() - trainer.join() + if __name__=="__main__": + world_size = 2 + mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True) From 6b1d31049ca7b97ecdefeea7a5ee606bbbb5e3de Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 9 Jan 2020 13:56:09 -0800 Subject: [PATCH 05/14] address comments --- intermediate_source/rpc_tutorial.rst | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst index 24a424ecdfd..5a8f21ca80f 100644 --- a/intermediate_source/rpc_tutorial.rst +++ b/intermediate_source/rpc_tutorial.rst @@ -21,8 +21,9 @@ paradigms: data between observers and the trainer 2) Your model might be too large to fit in GPUs on a single machine, and hence would need a library to help split a model onto multiple machines. Or you - might be implementing a parameter server training framework, where model - parameters and trainers live on different machines. + might be implementing a `parameter server `__ + training framework, where model parameters and trainers live on different + machines. The `torch.distributed.rpc `__ package @@ -360,21 +361,27 @@ borrowed from the word language model in PyTorch `example `__ repository, which contains three main components, an embedding table, an ``LSTM`` layer, and a decoder. The code below wraps the embedding table and the -decode into sub-modules, so that their constructors can be passed to the RPC -API. +decoder into sub-modules, so that their constructors can be passed to the RPC +API. In the `EmbeddingTable` sub-module, we intentionally put the `Embedding` +layer on GPU to demonstrate the use case. In v1.4, RPC always creates CPU tensor +arguments or return values on the destination server. If the function takes a +GPU tensor, you need to move it to the proper device explicitly. .. code:: python class EmbeddingTable(nn.Module): + r""" + Encoding layers of the RNNModel + """ def __init__(self, ntoken, ninp, dropout): super(EmbeddingTable, self).__init__() self.drop = nn.Dropout(dropout) - self.encoder = nn.Embedding(ntoken, ninp) + self.encoder = nn.Embedding(ntoken, ninp).cuda() self.encoder.weight.data.uniform_(-0.1, 0.1) def forward(self, input): - return self.drop(self.encoder(input)) + return self.drop(self.encoder(input.cuda()).cpu() class Decoder(nn.Module): @@ -470,8 +477,9 @@ Then, as the ``RNNModel`` contains three sub-modules, we need to call Now, we are ready to implement the training loop. After initializing the model arguments, we create the ``RNNModel`` and the ``DistributedOptimizer``. The distributed optimizer will take a list of parameter ``RRefs``, find all distinct -owner workers, and create the given local optimizer (i.e., ``SGD`` in this case) -on each of the owner worker using the given arguments (i.e., ``lr=0.05``). +owner workers, and create the given local optimizer (i.e., ``SGD`` in this case, +you can use other local optimizers as well) on each of the owner worker using +the given arguments (i.e., ``lr=0.05``). In the training loop, it first creates a distributed autograd context, which will help the distributed autograd engine to find gradients and involved RPC From 2ac179473ca0831a9f49dff38691280e686d6fbd Mon Sep 17 00:00:00 2001 From: Shen Li Date: Thu, 9 Jan 2020 14:09:12 -0800 Subject: [PATCH 06/14] Add sample training outputs for RL example --- intermediate_source/rpc_tutorial.rst | 35 ++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst index 5a8f21ca80f..e0a26c8de8d 100644 --- a/intermediate_source/rpc_tutorial.rst +++ b/intermediate_source/rpc_tutorial.rst @@ -336,6 +336,41 @@ available in the API page. join=True ) +Below are some sample outputs when training with `world_size=2`. + +:: + + Episode 10 Last reward: 26.00 Average reward: 10.01 + Episode 20 Last reward: 16.00 Average reward: 11.27 + Episode 30 Last reward: 49.00 Average reward: 18.62 + Episode 40 Last reward: 45.00 Average reward: 26.09 + Episode 50 Last reward: 44.00 Average reward: 30.03 + Episode 60 Last reward: 111.00 Average reward: 42.23 + Episode 70 Last reward: 131.00 Average reward: 70.11 + Episode 80 Last reward: 87.00 Average reward: 76.51 + Episode 90 Last reward: 86.00 Average reward: 95.93 + Episode 100 Last reward: 13.00 Average reward: 123.93 + Episode 110 Last reward: 33.00 Average reward: 91.39 + Episode 120 Last reward: 73.00 Average reward: 76.38 + Episode 130 Last reward: 137.00 Average reward: 88.08 + Episode 140 Last reward: 89.00 Average reward: 104.96 + Episode 150 Last reward: 97.00 Average reward: 98.74 + Episode 160 Last reward: 150.00 Average reward: 100.87 + Episode 170 Last reward: 126.00 Average reward: 104.38 + Episode 180 Last reward: 500.00 Average reward: 213.74 + Episode 190 Last reward: 322.00 Average reward: 300.22 + Episode 200 Last reward: 165.00 Average reward: 272.71 + Episode 210 Last reward: 168.00 Average reward: 233.11 + Episode 220 Last reward: 184.00 Average reward: 195.02 + Episode 230 Last reward: 284.00 Average reward: 208.32 + Episode 240 Last reward: 395.00 Average reward: 247.37 + Episode 250 Last reward: 500.00 Average reward: 335.42 + Episode 260 Last reward: 500.00 Average reward: 386.30 + Episode 270 Last reward: 500.00 Average reward: 405.29 + Episode 280 Last reward: 500.00 Average reward: 443.29 + Episode 290 Last reward: 500.00 Average reward: 464.65 + Solved! Running reward is now 475.3163778435275! + In this example, we show how to use RPC as the communication vehicle to pass date across workers, and how to use RRef to reference remote objects. It is true From 62f6fd2cc494ad1955ec81b764298f908c0fdd72 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 10 Jan 2020 08:18:46 -0800 Subject: [PATCH 07/14] Address comments --- intermediate_source/rpc_tutorial.rst | 33 ++++++++++++++++++---------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst index e0a26c8de8d..d614fd3a92b 100644 --- a/intermediate_source/rpc_tutorial.rst +++ b/intermediate_source/rpc_tutorial.rst @@ -555,21 +555,30 @@ the same set of ``Tensors``. lr=0.05, ) + criterion = torch.nn.CrossEntropyLoss() + + def get_next_batch(): + for _ in range(5): + data = torch.LongTensor(batch, nindices) % ntoken + target = torch.LongTensor(batch, ntoken) % nindices + yield data, target + # train for 10 iterations for epoch in range(10): # create distributed autograd context - with dist_autograd.context(): - inp = torch.LongTensor(batch, nindices) % ntoken - hidden[0].detach_() - hidden[1].detach_() - output, hidden = model(inp, hidden) - # run distributed backward pass - dist_autograd.backward([output.sum()]) - # run distributed optimizer - opt.step() - # not necessary to zero grads as each iteration creates a different - # distributed autograd context which hosts different grads - print("Training epoch {}".format(epoch)) + for data, target in get_next_batch(): + with dist_autograd.context(): + hidden[0].detach_() + hidden[1].detach_() + output, hidden = model(data, hidden) + loss = criterion(output, target) + # run distributed backward pass + dist_autograd.backward([loss]) + # run distributed optimizer + opt.step() + # not necessary to zero grads as each iteration creates a different + # distributed autograd context which hosts different grads + print("Training epoch {}".format(epoch)) Finally, let's add some glue code to launch the parameter server and the trainer From 755bd2fa838ce31d109b03fee0778fe81b685d51 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 10 Jan 2020 08:44:35 -0800 Subject: [PATCH 08/14] Fix typos --- intermediate_source/rpc_tutorial.rst | 72 ++++++++++++++-------------- 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst index d614fd3a92b..b93a5331f63 100644 --- a/intermediate_source/rpc_tutorial.rst +++ b/intermediate_source/rpc_tutorial.rst @@ -4,23 +4,26 @@ Getting Started with Distributed RPC Framework This tutorial uses two simple examples to demonstrate how to build distributed -applications with the `torch.distributed.rpc` package. Source code of the two -examples can be found in `PyTorch examples `__ +training with the `torch.distributed.rpc `__ +package. Source code of the two examples can be found in +`PyTorch examples `__ -Previous tutorials described `DistributedDataParallel `__ +`Previous `__ +`tutorials `__ +described `DistributedDataParallel `__ which supports a specific training paradigm where the model is replicated across multiple processes and each process handles a split of the input data. Sometimes, you might run into scenarios that require different training -paradigms: +paradigms. For example: 1) In reinforcement learning, it might be relatively expensive to acquire training data from environments while the model itself can be quite small. In this case, it might be useful to spawn multiple observers running in parallel and share a single agent. In this case, the agent takes care of the training locally, but the application would still need libraries to send and receive - data between observers and the trainer + data between observers and the trainer. 2) Your model might be too large to fit in GPUs on a single machine, and hence - would need a library to help split a model onto multiple machines. Or you + would need a library to help split the model onto multiple machines. Or you might be implementing a `parameter server `__ training framework, where model parameters and trainers live on different machines. @@ -28,16 +31,16 @@ paradigms: The `torch.distributed.rpc `__ package can help with the above scenarios. In case 1, `RPC `__ -and `RRef `__ can help send data -from one worker to another and also easily referencing remote data objects. In +and `RRef `__ allow sending data +from one worker to another while easily referencing remote data objects. In case 2, `distributed autograd `__ and `distributed optimizer `__ -allows executing backward and optimizer step as if it is local training. In the -next two sections, we will demonstrate APIs of +make executing backward pass and optimizer step as if it is local training. In +the next two sections, we will demonstrate APIs of `torch.distributed.rpc `__ using a reinforcement learning example and a language model example. Please note, this -tutorial is not aiming at building the most accurate or efficient models to -solve given problems, instead the main goal is to show how to use the +tutorial does not aim at building the most accurate or efficient models to +solve given problems, instead, the main goal here is to show how to use the `torch.distributed.rpc `__ package to build distributed training applications. @@ -76,12 +79,13 @@ usages. action_scores = self.affine2(x) return F.softmax(action_scores, dim=1) -Let's first prepare a helper function to call a function on a local ``RRef``. It -might look unnecessary at the first glance, as you could simply do -``rref.local_value().some_func(args)`` to run the target function. The reason -for adding this helper function is because there is no way to get a reference -of a remote value, and ``local_value`` is only available on the owner of the -``RRef``. +Let's first prepare a helper to run functions remotely on the owner worker of an +``RRef``. You will find this function been used in several places this +tutorial's examples. Ideally, the `torch.distributed.rpc` package should provide +these helper functions out of box. For example, it will be easier if +applications can directly call ``RRef.some_func(*arg)`` which will then +translate to RPC to the ``RRef`` owner. The progress on this API is tracked in +`pytorch/pytorch#31743 `__. .. code:: python @@ -99,13 +103,6 @@ of a remote value, and ``local_value`` is only available on the owner of the # _remote_method(some_func, rref, *args) -Ideally, the `torch.distributed.rpc` package should provide these helper -functions out of box. For example, it will be easier if applications can -directly call ``RRef.some_func(*arg)`` which will then translate to RPC to the -``RRef`` owner. The progress on this API is tracked in in -`pytorch/pytorch#31743 `__. - - We are ready to present the observer. In this example, each observer creates its own environment, and waits for the agent's command to run an episode. In each episode, one observer loops at most ``n_steps`` iterations, and in each @@ -116,7 +113,8 @@ RPC to report the reward to the agent. Again, please note that, this is obviously not the most efficient observer implementation. For example, one simple optimization could be packing current state and last reward in one RPC to reduce the communication overhead. However, the goal is to demonstrate RPC API -instead of building the best solver for CartPole. +instead of building the best solver for CartPole. So, let's keep the logic +simple and the two steps explicit in this example. .. code:: python @@ -152,9 +150,13 @@ such that it sends command to multiple distributed observers to run episodes, and it also records all actions and rewards locally which will be used during the training phase after each episode. The code below shows ``Agent`` constructor where most lines are initializing various components. The loop at -the end initializes observers on other workers, and holds ``RRefs`` to those -observers locally. The agent will use those observer ``RRefs`` later to send -commands. +the end initializes observers remotely on other workers, and holds ``RRefs`` to +those observers locally. The agent will use those observer ``RRefs`` later to +send commands. Applications don't need to worry about the lifetime of ``RRefs``. +The owner of each ``RRef`` maintains a reference counting map to track it's +lifetime, and guarantees the remote data object will not be deleted as long as +there is any live user of that ``RRef``. Please refer to the ``RRef`` +`design doc `__ for details. .. code:: python @@ -186,9 +188,9 @@ commands. self.saved_log_probs[ob_info.id] = [] -Next, the agent exposes two APIs to allow observers to select actions and report -rewards. Those functions are only run locally on the agent, but will be -triggered by observers through RPC. +Next, the agent exposes two APIs to observers for selecting actions and +reporting rewards. Those functions are only run locally on the agent, but will +be triggered by observers through RPC. .. code:: python @@ -212,9 +214,9 @@ to execute an episode. In this function, it first creates a list to collect futures from asynchronous RPCs, and then loop over all observer ``RRefs`` to make asynchronous RPCs. In these RPCs, the agent also passes an ``RRef`` of itself to the observer, so that the observer can call functions on the agent as -well. As shown above, each observer will make RPCs back to the agent, which is -actually nested RPCs. After each episode, the ``saved_log_probs`` and -``rewards`` will contain the recorded action probs and rewards. +well. As shown above, each observer will make RPCs back to the agent, which are +nested RPCs. After each episode, the ``saved_log_probs`` and ``rewards`` will +contain the recorded action probs and rewards. .. code:: python From 110920f7c84638b5bf3c2c9ad2be8e25a47d71b9 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 10 Jan 2020 09:03:42 -0800 Subject: [PATCH 09/14] fix typos --- intermediate_source/rpc_tutorial.rst | 69 +++++++++++++++------------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst index b93a5331f63..dd86be6579e 100644 --- a/intermediate_source/rpc_tutorial.rst +++ b/intermediate_source/rpc_tutorial.rst @@ -5,7 +5,8 @@ Getting Started with Distributed RPC Framework This tutorial uses two simple examples to demonstrate how to build distributed training with the `torch.distributed.rpc `__ -package. Source code of the two examples can be found in +package which is first introduced as an experimental feature in PyTorch v1.4. +Source code of the two examples can be found in `PyTorch examples `__ `Previous `__ @@ -189,7 +190,7 @@ there is any live user of that ``RRef``. Please refer to the ``RRef`` Next, the agent exposes two APIs to observers for selecting actions and -reporting rewards. Those functions are only run locally on the agent, but will +reporting rewards. Those functions only run locally on the agent, but will be triggered by observers through RPC. @@ -240,9 +241,10 @@ contain the recorded action probs and rewards. Finally, after one episode, the agent needs to train the model, which -is implemented in the ``finish_episode`` function below. It is also a local -function and mostly borrowed from the single-thread +is implemented in the ``finish_episode`` function below. There is no RPCs in +this function and it is mostly borrowed from the single-thread `example `__. +Hence, we skip describing its contents. @@ -285,14 +287,14 @@ With ``Policy``, ``Observer``, and ``Agent`` classes, we are ready to launch multiple processes to perform the distributed training. In this example, all processes run the same ``run_worker`` function, and they use the rank to distinguish their role. Rank 0 is always the agent, and all other ranks are -observers. As agent as server as master, repeatedly call ``run_episode`` and +observers. The agent serves as master by repeatedly calling ``run_episode`` and ``finish_episode`` until the running reward surpasses the reward threshold -specified by the environment. All observers just passively waiting for commands +specified by the environment. All observers passively waiting for commands from the agent. The code is wrapped by `rpc.init_rpc `__ and `rpc.shutdown `__, which initializes and terminates RPC instances respectively. More details are -available in the API page. +available in the `API page `__. .. code:: python @@ -375,10 +377,10 @@ Below are some sample outputs when training with `world_size=2`. In this example, we show how to use RPC as the communication vehicle to pass -date across workers, and how to use RRef to reference remote objects. It is true +data across workers, and how to use RRef to reference remote objects. It is true that you could build the entire structure directly on top of ``ProcessGroup`` ``send`` and ``recv`` APIs or use other communication/RPC libraries. However, -by using `torch.distributed.rpc`, you can get the native support plus +by using `torch.distributed.rpc`, you can get the native support and continuously optimized performance under the hood. Next, we will show how to combine RPC and RRef with distributed autograd and @@ -386,22 +388,24 @@ distributed optimizer to perform distributed model parallel training. - Distributed RNN using Distributed Autograd and Distributed Optimizer -------------------------------------------------------------------- In this section, we use an RNN model to show how to build distributed model -parallel training using the RPC API. The example RNN model is very small and -easily fit into a single GPU, but developer can apply the similar techniques to -much larger models that need to span multiple devices. The RNN model design is -borrowed from the word language model in PyTorch +parallel training with the RPC API. The example RNN model is very small and +can easily fit into a single GPU, but we still divide its layers onto two +different workers to demonstrate the idea. Developer can apply the similar +techniques to distribute much larger models across multiple devices and +machines. + +The RNN model design is borrowed from the word language model in PyTorch `example `__ repository, which contains three main components, an embedding table, an ``LSTM`` layer, and a decoder. The code below wraps the embedding table and the decoder into sub-modules, so that their constructors can be passed to the RPC API. In the `EmbeddingTable` sub-module, we intentionally put the `Embedding` -layer on GPU to demonstrate the use case. In v1.4, RPC always creates CPU tensor -arguments or return values on the destination server. If the function takes a +layer on GPU to cover the use case. In v1.4, RPC always creates CPU tensor +arguments or return values on the destination worker. If the function takes a GPU tensor, you need to move it to the proper device explicitly. @@ -437,7 +441,7 @@ With the above sub-modules, we can now piece them together using RPC to create an RNN model. In the code below ``ps`` represents a parameter server, which hosts parameters of the embedding table and the decoder. The constructor uses the `remote `__ -API to create an `EmbeddingTable` and a `Decoder` object on the parameter +API to create an `EmbeddingTable` object and a `Decoder` object on the parameter server, and locally creates the ``LSTM`` sub-module. During the forward pass, the trainer uses the ``EmbeddingTable`` ``RRef`` to find the remote sub-module and passes the input data to the ``EmbeddingTable`` using RPC and fetches the @@ -475,12 +479,13 @@ Before introducing the distributed optimizer, let's add a helper function to generate a list of RRefs of model parameters, which will be consumed by the distributed optimizer. In local training, applications could call ``Module.parameters()`` to grab references to all parameter tensors, and pass it -to the local optimizer to update. However, the same API does not work in -the distributed training scenarios as some parameters live on remote machines. -Therefore, instead of taking a list of parameter ``Tensors``, the distributed -optimizer takes a list of ``RRefs``, one ``RRef`` per model parameter for both -local and remote parameters. The helper function is pretty simple, just call -``Module.parameters()`` and creates a local ``RRef`` on each of the parameters. +to the local optimizer for subsequent updates. However, the same API does not +work in distributed training scenarios as some parameters live on remote +machines. Therefore, instead of taking a list of parameter ``Tensors``, the +distributed optimizer takes a list of ``RRefs``, one ``RRef`` per model +parameter for both local and remote model parameters. The helper function is +pretty simple, just call ``Module.parameters()`` and creates a local ``RRef`` on +each of the parameters. .. code:: python @@ -511,7 +516,7 @@ Then, as the ``RNNModel`` contains three sub-modules, we need to call return remote_params -Now, we are ready to implement the training loop. After initializing the model +Now, we are ready to implement the training loop. After initializing model arguments, we create the ``RNNModel`` and the ``DistributedOptimizer``. The distributed optimizer will take a list of parameter ``RRefs``, find all distinct owner workers, and create the given local optimizer (i.e., ``SGD`` in this case, @@ -520,17 +525,19 @@ the given arguments (i.e., ``lr=0.05``). In the training loop, it first creates a distributed autograd context, which will help the distributed autograd engine to find gradients and involved RPC -send/recv functions. Then, it kicks off the forward pass as if it is a local +send/recv functions. The design details of the distributed autograd engine can +be found in its `design note `__. +Then, it kicks off the forward pass as if it is a local model, and run the distributed backward pass. For the distributed backward, you only need to specify a list of roots, in this case, it is the loss ``Tensor``. The distributed autograd engine will traverse the distributed graph automatically and write gradients properly. Next, it runs the ``step`` -API on the distributed optimizer, which will reach out to all involved local -optimizers to update model parameters. Compared to local training, one minor -difference is that you don't need to run ``zero_grad()`` because each autograd -context has dedicated space to store gradients, and as we create a context -per iteration, those gradients from different iterations will not accumulate to -the same set of ``Tensors``. +function on the distributed optimizer, which will reach out to all involved +local optimizers to update model parameters. Compared to local training, one +minor difference is that you don't need to run ``zero_grad()`` because each +autograd context has dedicated space to store gradients, and as we create a +context per iteration, those gradients from different iterations will not +accumulate to the same set of ``Tensors``. .. code:: python From 2a5725672c3437f2f52b8b5b94416ffef79351f6 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 10 Jan 2020 11:02:52 -0800 Subject: [PATCH 10/14] Add warning --- intermediate_source/rpc_tutorial.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst index dd86be6579e..4cfcc127669 100644 --- a/intermediate_source/rpc_tutorial.rst +++ b/intermediate_source/rpc_tutorial.rst @@ -3,6 +3,11 @@ Getting Started with Distributed RPC Framework **Author**: `Shen Li `_ +.. warning:: + The `torch.distributed.rpc `__ package + is experimental and subject to change. + + This tutorial uses two simple examples to demonstrate how to build distributed training with the `torch.distributed.rpc `__ package which is first introduced as an experimental feature in PyTorch v1.4. From c66913a690faf8e45bf30bfb9693b3185f39603d Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 10 Jan 2020 13:25:49 -0800 Subject: [PATCH 11/14] Address Comments --- intermediate_source/rpc_tutorial.rst | 40 +++++++++++++++------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst index 4cfcc127669..6808cb0f3eb 100644 --- a/intermediate_source/rpc_tutorial.rst +++ b/intermediate_source/rpc_tutorial.rst @@ -14,8 +14,9 @@ package which is first introduced as an experimental feature in PyTorch v1.4. Source code of the two examples can be found in `PyTorch examples `__ -`Previous `__ -`tutorials `__ +Previous tutorials, +`Getting Started With Distributed Data Parallel `__ + and `Writing Distributed Applications With PyTorch `__, described `DistributedDataParallel `__ which supports a specific training paradigm where the model is replicated across multiple processes and each process handles a split of the input data. @@ -86,7 +87,7 @@ usages. return F.softmax(action_scores, dim=1) Let's first prepare a helper to run functions remotely on the owner worker of an -``RRef``. You will find this function been used in several places this +``RRef``. You will find this function being used in several places this tutorial's examples. Ideally, the `torch.distributed.rpc` package should provide these helper functions out of box. For example, it will be easier if applications can directly call ``RRef.some_func(*arg)`` which will then @@ -159,7 +160,7 @@ constructor where most lines are initializing various components. The loop at the end initializes observers remotely on other workers, and holds ``RRefs`` to those observers locally. The agent will use those observer ``RRefs`` later to send commands. Applications don't need to worry about the lifetime of ``RRefs``. -The owner of each ``RRef`` maintains a reference counting map to track it's +The owner of each ``RRef`` maintains a reference counting map to track its lifetime, and guarantees the remote data object will not be deleted as long as there is any live user of that ``RRef``. Please refer to the ``RRef`` `design doc `__ for details. @@ -408,10 +409,10 @@ The RNN model design is borrowed from the word language model in PyTorch repository, which contains three main components, an embedding table, an ``LSTM`` layer, and a decoder. The code below wraps the embedding table and the decoder into sub-modules, so that their constructors can be passed to the RPC -API. In the `EmbeddingTable` sub-module, we intentionally put the `Embedding` -layer on GPU to cover the use case. In v1.4, RPC always creates CPU tensor -arguments or return values on the destination worker. If the function takes a -GPU tensor, you need to move it to the proper device explicitly. +API. In the ``EmbeddingTable`` sub-module, we intentionally put the +``Embedding`` layer on GPU to cover the use case. In v1.4, RPC always creates +CPU tensor arguments or return values on the destination worker. If the function +takes a GPU tensor, you need to move it to the proper device explicitly. .. code:: python @@ -446,17 +447,18 @@ With the above sub-modules, we can now piece them together using RPC to create an RNN model. In the code below ``ps`` represents a parameter server, which hosts parameters of the embedding table and the decoder. The constructor uses the `remote `__ -API to create an `EmbeddingTable` object and a `Decoder` object on the parameter -server, and locally creates the ``LSTM`` sub-module. During the forward pass, -the trainer uses the ``EmbeddingTable`` ``RRef`` to find the remote sub-module -and passes the input data to the ``EmbeddingTable`` using RPC and fetches the -lookup results. Then, it runs the embedding through the local ``LSTM`` layer, -and finally uses another RPC to send the output to the ``Decoder`` sub-module. -In general, to implement distributed model parallel training, developers can -divide the model into sub-modules, invoke RPC to create sub-module instances -remotely, and use on ``RRef`` to find them when necessary. As you can see in the -code below, it looks very similar to single-machine model parallel training. The -main difference is replacing ``Tensor.to(device)`` with RPC functions. +API to create an ``EmbeddingTable`` object and a ``Decoder`` object on the +parameter server, and locally creates the ``LSTM`` sub-module. During the +forward pass, the trainer uses the ``EmbeddingTable`` ``RRef`` to find the +remote sub-module and passes the input data to the ``EmbeddingTable`` using RPC +and fetches the lookup results. Then, it runs the embedding through the local +``LSTM`` layer, and finally uses another RPC to send the output to the +``Decoder`` sub-module. In general, to implement distributed model parallel +training, developers can divide the model into sub-modules, invoke RPC to create +sub-module instances remotely, and use on ``RRef`` to find them when necessary. +As you can see in the code below, it looks very similar to single-machine model +parallel training. The main difference is replacing ``Tensor.to(device)`` with +RPC functions. .. code:: python From 6a62f96e5ce3637b8a3ac0425e1a90d56e5278b3 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 10 Jan 2020 14:44:50 -0800 Subject: [PATCH 12/14] Distinguish single-machine and distributed model parallel --- intermediate_source/model_parallel_tutorial.py | 9 ++++++++- intermediate_source/rpc_tutorial.rst | 4 ++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/intermediate_source/model_parallel_tutorial.py b/intermediate_source/model_parallel_tutorial.py index e35f443a469..4cc812c8127 100644 --- a/intermediate_source/model_parallel_tutorial.py +++ b/intermediate_source/model_parallel_tutorial.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Model Parallel Best Practices +Single-Machine Model Parallel Best Practices ================================ **Author**: `Shen Li `_ @@ -27,6 +27,13 @@ of model parallel. It is up to the readers to apply the ideas to real-world applications. +.. note:: + + For distributed model parallel training where a model spans multiple + servers, please refer to + `Getting Started With Distributed RPC Framework __ + for examples and details. + Basic Usage ----------- """ diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst index 6808cb0f3eb..5844c4c7027 100644 --- a/intermediate_source/rpc_tutorial.rst +++ b/intermediate_source/rpc_tutorial.rst @@ -12,10 +12,10 @@ This tutorial uses two simple examples to demonstrate how to build distributed training with the `torch.distributed.rpc `__ package which is first introduced as an experimental feature in PyTorch v1.4. Source code of the two examples can be found in -`PyTorch examples `__ +`PyTorch examples `__. Previous tutorials, -`Getting Started With Distributed Data Parallel `__ +`Getting Started With Distributed Data Parallel `__ and `Writing Distributed Applications With PyTorch `__, described `DistributedDataParallel `__ which supports a specific training paradigm where the model is replicated across From 6c7d597580d59507d32218a5be786968e68e4432 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 10 Jan 2020 14:59:13 -0800 Subject: [PATCH 13/14] User relative URLs --- intermediate_source/model_parallel_tutorial.py | 2 +- intermediate_source/rpc_tutorial.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/intermediate_source/model_parallel_tutorial.py b/intermediate_source/model_parallel_tutorial.py index 4cc812c8127..515b689301a 100644 --- a/intermediate_source/model_parallel_tutorial.py +++ b/intermediate_source/model_parallel_tutorial.py @@ -31,7 +31,7 @@ For distributed model parallel training where a model spans multiple servers, please refer to - `Getting Started With Distributed RPC Framework __ + `Getting Started With Distributed RPC Framework `__ for examples and details. Basic Usage diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst index 5844c4c7027..60158c47da1 100644 --- a/intermediate_source/rpc_tutorial.rst +++ b/intermediate_source/rpc_tutorial.rst @@ -16,7 +16,7 @@ Source code of the two examples can be found in Previous tutorials, `Getting Started With Distributed Data Parallel `__ - and `Writing Distributed Applications With PyTorch `__, +and `Writing Distributed Applications With PyTorch `__, described `DistributedDataParallel `__ which supports a specific training paradigm where the model is replicated across multiple processes and each process handles a split of the input data. From 08acbc3be15dd40c777b7bfad8be76f822462023 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 10 Jan 2020 15:04:20 -0800 Subject: [PATCH 14/14] Address comments --- intermediate_source/rpc_tutorial.rst | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst index 60158c47da1..d2deb9a5d9e 100644 --- a/intermediate_source/rpc_tutorial.rst +++ b/intermediate_source/rpc_tutorial.rst @@ -125,9 +125,21 @@ simple and the two steps explicit in this example. .. code:: python + import argparse import gym import torch.distributed.rpc as rpc + parser = argparse.ArgumentParser( + description="RPC Reinforcement Learning Example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument('--world_size', default=2, help='Number of workers') + parser.add_argument('--log_interval', default=1, help='Log every log_interval episodes') + parser.add_argument('--gamma', default=0.1, help='how much to value future rewards') + parser.add_argument('--seed', default=1, help='random seed for reproducibility') + args = parser.parse_args() + class Observer: def __init__(self): @@ -231,6 +243,7 @@ contain the recorded action probs and rewards. class Agent: ... def run_episode(self, n_steps=0): + futs = [] for ob_rref in self.ob_rrefs: # make async RPC to kick off an episode on all observers futs.append( @@ -310,6 +323,10 @@ available in the `API page `__. import torch.multiprocessing as mp + AGENT_NAME = "agent" + OBSERVER_NAME="obs" + TOTAL_EPISODE_STEP = 100 + def run_worker(rank, world_size): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '29500'