diff --git a/distributed/rpc/rnn/main.py b/distributed/rpc/rnn/main.py
index e4d3309a98..6c14e421c5 100644
--- a/distributed/rpc/rnn/main.py
+++ b/distributed/rpc/rnn/main.py
@@ -51,15 +51,15 @@ def get_next_batch():
     for epoch in range(10):
         # create distributed autograd context
         for data, target in get_next_batch():
-            with dist_autograd.context():
+            with dist_autograd.context() as context_id:
                 hidden[0].detach_()
                 hidden[1].detach_()
                 output, hidden = model(data, hidden)
                 loss = criterion(output, target)
                 # run distributed backward pass
-                dist_autograd.backward([loss])
+                dist_autograd.backward(context_id, [loss])
                 # run distributed optimizer
-                opt.step()
+                opt.step(context_id)
                 # not necessary to zero grads as each iteration creates a different
                 # distributed autograd context which hosts different grads
         print("Training epoch {}".format(epoch))