File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -51,15 +51,15 @@ def get_next_batch():
5151 for epoch in range (10 ):
5252 # create distributed autograd context
5353 for data , target in get_next_batch ():
54- with dist_autograd .context ():
54+ with dist_autograd .context () as context_id :
5555 hidden [0 ].detach_ ()
5656 hidden [1 ].detach_ ()
5757 output , hidden = model (data , hidden )
5858 loss = criterion (output , target )
5959 # run distributed backward pass
60- dist_autograd .backward ([loss ])
60+ dist_autograd .backward (context_id , [loss ])
6161 # run distributed optimizer
62- opt .step ()
62+ opt .step (context_id )
6363 # not necessary to zero grads as each iteration creates a different
6464 # distributed autograd context which hosts different grads
6565 print ("Training epoch {}" .format (epoch ))
You can’t perform that action at this time.
0 commit comments