diff --git a/beginner_source/chatbot_tutorial.py b/beginner_source/chatbot_tutorial.py index 103efa93dcb..de5a97e3492 100644 --- a/beginner_source/chatbot_tutorial.py +++ b/beginner_source/chatbot_tutorial.py @@ -1207,7 +1207,7 @@ def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH): input_batch = torch.LongTensor(indexes_batch).transpose(0, 1) # Use appropriate device input_batch = input_batch.to(device) - lengths = lengths.to(device) + lengths = lengths.to("cpu") # Decode sentence with searcher tokens, scores = searcher(input_batch, lengths, max_length) # indexes -> words