@@ -125,9 +125,21 @@ simple and the two steps explicit in this example.
 
 .. code:: python
 
+    import argparse
     import gym
     import torch.distributed.rpc as rpc
 
+    parser = argparse.ArgumentParser(
+        description="RPC Reinforcement Learning Example",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument('--world_size', default=2, type=int, help='Number of workers')
+    parser.add_argument('--log_interval', default=1, type=int, help='Log every log_interval episodes')
+    parser.add_argument('--gamma', default=0.1, type=float, help='how much to value future rewards')
+    parser.add_argument('--seed', default=1, type=int, help='random seed for reproducibility')
+    args = parser.parse_args()
+
     class Observer:
 
         def __init__(self):
@@ -231,6 +243,7 @@ contain the recorded action probs and rewards.
     class Agent:
         ...
         def run_episode(self, n_steps=0):
+            futs = []
             for ob_rref in self.ob_rrefs:
                 # make async RPC to kick off an episode on all observers
                 futs.append(
@@ -310,6 +323,10 @@ available in the `API page <https://pytorch.org/docs/master/rpc.html>`__.
 
     import torch.multiprocessing as mp
 
+    AGENT_NAME = "agent"
+    OBSERVER_NAME = "obs"
+    TOTAL_EPISODE_STEP = 100
+
     def run_worker(rank, world_size):
         os.environ['MASTER_ADDR'] = 'localhost'
         os.environ['MASTER_PORT'] = '29500'
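
The hunks above only show where the new names (``AGENT_NAME``, ``OBSERVER_NAME``, ``TOTAL_EPISODE_STEP``, ``args``) are introduced; the body of ``run_worker`` is elided. As a rough sketch of how they are likely wired together (not part of this commit), rank 0 presumably initializes RPC as the agent while every other rank becomes an observer, and ``mp.spawn`` launches ``args.world_size`` processes. The observer name suffix, the rank-0 branching, and the literal world size below are assumptions for illustration.

.. code:: python

    import os
    import torch.distributed.rpc as rpc
    import torch.multiprocessing as mp

    AGENT_NAME = "agent"
    OBSERVER_NAME = "obs"   # assumption: observer processes are named "obs<rank>"

    def run_worker(rank, world_size):
        # every process must agree on the rendezvous address and port
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        if rank == 0:
            # assumption: rank 0 plays the agent and drives the episodes
            rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)
            # ... construct the Agent and run training episodes here ...
        else:
            # remaining ranks are passive observers that only answer RPCs
            rpc.init_rpc(OBSERVER_NAME + str(rank), rank=rank, world_size=world_size)
        # block until all outstanding RPC work finishes, then tear down
        rpc.shutdown()

    if __name__ == '__main__':
        # 2 processes here stands in for args.world_size from the snippet above
        mp.spawn(run_worker, args=(2,), nprocs=2, join=True)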