Skip to content

Commit 8816366

Browse files
authored
Merge pull request #577 from jiangzhonglian/master
chatbot_tutorial.py: Solve the optimizer cuda call problem
2 parents 879845f + a473681 commit 8816366

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

beginner_source/chatbot_tutorial.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,6 +1330,17 @@ def evaluateInput(encoder, decoder, searcher, voc):
13301330
encoder_optimizer.load_state_dict(encoder_optimizer_sd)
13311331
decoder_optimizer.load_state_dict(decoder_optimizer_sd)
13321332

1333+
# If you have cuda, configure cuda to call
1334+
for state in encoder_optimizer.state.values():
1335+
for k, v in state.items():
1336+
if isinstance(v, torch.Tensor):
1337+
state[k] = v.cuda()
1338+
1339+
for state in decoder_optimizer.state.values():
1340+
for k, v in state.items():
1341+
if isinstance(v, torch.Tensor):
1342+
state[k] = v.cuda()
1343+
13331344
# Run training iterations
13341345
print("Starting Training!")
13351346
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,

0 commit comments

Comments
 (0)