
Commit 234bcff

Author: Jessica Lin
Merge pull request #733 from osalpekar/dist_autograd_update
[Dist Autograd - API Change] Updated dist_autograd and dist_optim to be functional
2 parents: 8a5b379 + 55506b5


1 file changed (+3, -3 lines)


distributed/rpc/rnn/main.py

Lines changed: 3 additions & 3 deletions
@@ -51,15 +51,15 @@ def get_next_batch():
     for epoch in range(10):
         # create distributed autograd context
         for data, target in get_next_batch():
-            with dist_autograd.context():
+            with dist_autograd.context() as context_id:
                 hidden[0].detach_()
                 hidden[1].detach_()
                 output, hidden = model(data, hidden)
                 loss = criterion(output, target)
                 # run distributed backward pass
-                dist_autograd.backward([loss])
+                dist_autograd.backward(context_id, [loss])
                 # run distributed optimizer
-                opt.step()
+                opt.step(context_id)
                 # not necessary to zero grads as each iteration creates a different
                 # distributed autograd context which hosts different grads
         print("Training epoch {}".format(epoch))
