Skip to content

Commit 08acbc3

Browse files
committed
Address comments
1 parent 6c7d597 commit 08acbc3

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

intermediate_source/rpc_tutorial.rst

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -125,9 +125,21 @@ simple and the two steps explicit in this example.
125125

126126
.. code:: python
127127
128+
import argparse
128129
import gym
129130
import torch.distributed.rpc as rpc
130131
132+
parser = argparse.ArgumentParser(
133+
description="RPC Reinforcement Learning Example",
134+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
135+
)
136+
137+
parser.add_argument('--world_size', default=2, help='Number of workers')
138+
parser.add_argument('--log_interval', default=1, help='Log every log_interval episodes')
139+
parser.add_argument('--gamma', default=0.1, help='how much to value future rewards')
140+
parser.add_argument('--seed', default=1, help='random seed for reproducibility')
141+
args = parser.parse_args()
142+
131143
class Observer:
132144
133145
def __init__(self):
@@ -231,6 +243,7 @@ contain the recorded action probs and rewards.
231243
class Agent:
232244
...
233245
def run_episode(self, n_steps=0):
246+
futs = []
234247
for ob_rref in self.ob_rrefs:
235248
# make async RPC to kick off an episode on all observers
236249
futs.append(
@@ -310,6 +323,10 @@ available in the `API page <https://pytorch.org/docs/master/rpc.html>`__.
310323
311324
import torch.multiprocessing as mp
312325
326+
AGENT_NAME = "agent"
327+
OBSERVER_NAME="obs"
328+
TOTAL_EPISODE_STEP = 100
329+
313330
def run_worker(rank, world_size):
314331
os.environ['MASTER_ADDR'] = 'localhost'
315332
os.environ['MASTER_PORT'] = '29500'

0 commit comments

Comments (0)