diff --git a/beginner_source/chatbot_tutorial.py b/beginner_source/chatbot_tutorial.py index 7c1a7d1c6fc..991087b6c56 100644 --- a/beginner_source/chatbot_tutorial.py +++ b/beginner_source/chatbot_tutorial.py @@ -1330,6 +1330,17 @@ def evaluateInput(encoder, decoder, searcher, voc): encoder_optimizer.load_state_dict(encoder_optimizer_sd) decoder_optimizer.load_state_dict(decoder_optimizer_sd) +# If you have cuda, configure cuda to call +for state in encoder_optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.cuda() + +for state in decoder_optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.cuda() + # Run training iterations print("Starting Training!") trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,