
Commit 26bed42

Update Deploy Seq2Seq Tutorial with New TorchScript API
1 parent: eb7960c

File tree

4 files changed (+83 −90 lines)


_static/img/chatbot/diff.png

Mode changed 100755 → 100644 · 162 KB → 34.5 KB
Binary file not shown.

beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py

Lines changed: 78 additions & 85 deletions
@@ -1,24 +1,24 @@
 # -*- coding: utf-8 -*-
 """
-Deploying a Seq2Seq Model with the Hybrid Frontend
+Deploying a Seq2Seq Model with TorchScript
 ==================================================
 **Author:** `Matthew Inkawhich <https://github.com/MatthewInkawhich>`_
 """


 ######################################################################
 # This tutorial will walk through the process of transitioning a
-# sequence-to-sequence model to Torch Script using PyTorch’s Hybrid
-# Frontend. The model that we will convert is the chatbot model from the
-# `Chatbot tutorial <https://pytorch.org/tutorials/beginner/chatbot_tutorial.html>`__.
+# sequence-to-sequence model to TorchScript using the TorchScript
+# API. The model that we will convert is the chatbot model from the
+# `Chatbot tutorial <https://pytorch.org/tutorials/beginner/chatbot_tutorial.html>`__.
 # You can either treat this tutorial as a “Part 2” to the Chatbot tutorial
 # and deploy your own pretrained model, or you can start with this
 # document and use a pretrained model that we host. In the latter case,
 # you can reference the original Chatbot tutorial for details
 # regarding data preprocessing, model theory and definition, and model
 # training.
 #
-# What is the Hybrid Frontend?
+# What is TorchScript?
 # ----------------------------
 #
 # During the research and development phase of a deep learning-based
@@ -34,13 +34,13 @@
 # to target highly optimized hardware architectures. Also, a graph-based
 # representation enables framework-agnostic model exportation. PyTorch
 # provides mechanisms for incrementally converting eager-mode code into
-# Torch Script, a statically analyzable and optimizable subset of Python
+# TorchScript, a statically analyzable and optimizable subset of Python
 # that Torch uses to represent deep learning programs independently from
 # the Python runtime.
 #
-# The API for converting eager-mode PyTorch programs into Torch Script is
+# The API for converting eager-mode PyTorch programs into TorchScript is
 # found in the torch.jit module. This module has two core modalities for
-# converting an eager-mode model to a Torch Script graph representation:
+# converting an eager-mode model to a TorchScript graph representation:
 # **tracing** and **scripting**. The ``torch.jit.trace`` function takes a
 # module or function and a set of example inputs. It then runs the example
 # input through the function or module while tracing the computational
@@ -52,23 +52,20 @@
 # operations called along the execution route taken by the example input
 # will be recorded. In other words, the control flow itself is not
 # captured. To convert modules and functions containing data-dependent
-# control flow, a **scripting** mechanism is provided. Scripting
-# explicitly converts the module or function code to Torch Script,
-# including all possible control flow routes. To use script mode, be sure
-# to inherit from the the ``torch.jit.ScriptModule`` base class (instead
-# of ``torch.nn.Module``) and add a ``torch.jit.script`` decorator to your
-# Python function or a ``torch.jit.script_method`` decorator to your
-# module’s methods. The one caveat with using scripting is that it only
-# supports a restricted subset of Python. For all details relating to the
-# supported features, see the Torch Script `language
-# reference <https://pytorch.org/docs/master/jit.html>`__. To provide the
-# maximum flexibility, the modes of Torch Script can be composed to
-# represent your whole program, and these techniques can be applied
-# incrementally.
-#
-# .. figure:: /_static/img/chatbot/pytorch_workflow.png
-#    :align: center
-#    :alt: workflow
+# control flow, a **scripting** mechanism is provided. The
+# ``torch.jit.script`` function takes a module or function and does not
+# require example inputs. Scripting then explicitly converts the module
+# or function code to TorchScript, including all possible control flow
+# routes. The one caveat with using scripting is that it only supports
+# a subset of Python, so you might need to rewrite the code to make it
+# compatible with TorchScript syntax.
+#
+# For all details relating to the supported features, see the TorchScript
+# `language reference <https://pytorch.org/docs/master/jit.html>`__. To
+# provide the maximum flexibility, you can also mix tracing and scripting
+# modes together to represent your whole program, and these techniques can
+# be applied incrementally.
+#
 #

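A minimal sketch of the two modes described above — tracing with example inputs versus scripting for data-dependent control flow (toy modules, not part of the tutorial's code):

import torch
import torch.nn as nn

class Doubler(nn.Module):
    # No data-dependent control flow, so tracing is safe.
    def forward(self, x):
        return 2 * x

class Flip(nn.Module):
    # The branch taken depends on the input values, so scripting is
    # needed to capture both control-flow routes.
    def forward(self, x):
        if bool(x.sum() > 0):
            return x
        return -x

traced = torch.jit.trace(Doubler(), torch.randn(3))  # needs example inputs
scripted = torch.jit.script(Flip())                  # no example inputs
print(traced(torch.ones(3)), scripted(-torch.ones(3)))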
@@ -273,7 +270,7 @@ def indexesFromSentence(voc, sentence):
 # used by the ``torch.nn.utils.rnn.pack_padded_sequence`` function when
 # padding.
 #
-# Hybrid Frontend Notes:
+# TorchScript Notes:
 # ~~~~~~~~~~~~~~~~~~~~~~
 #
 # Since the encoder’s ``forward`` function does not contain any
@@ -296,6 +293,7 @@ def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
                           dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

     def forward(self, input_seq, input_lengths, hidden=None):
+        # type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
         # Convert word indexes to embeddings
         embedded = self.embedding(input_seq)
         # Pack padded batch of sequences for RNN module
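The added MyPy-style type comment is how TorchScript learns non-default argument types. A standalone sketch of the same pattern (hypothetical module, not from the tutorial):

from typing import Optional, Tuple
import torch

class AddState(torch.nn.Module):
    def forward(self, x, scale, state=None):
        # type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
        # Without the comment, TorchScript assumes every argument is a
        # Tensor and would reject the Optional default.
        if state is None:
            state = torch.zeros_like(x)
        return x * scale + state, state

scripted = torch.jit.script(AddState())
out, new_state = scripted(torch.ones(2), torch.tensor(2.0))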
@@ -325,18 +323,18 @@ def forward(self, input_seq, input_lengths, hidden=None):
 #

 # Luong attention layer
-class Attn(torch.nn.Module):
+class Attn(nn.Module):
     def __init__(self, method, hidden_size):
         super(Attn, self).__init__()
         self.method = method
         if self.method not in ['dot', 'general', 'concat']:
             raise ValueError(self.method, "is not an appropriate attention method.")
         self.hidden_size = hidden_size
         if self.method == 'general':
-            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
+            self.attn = nn.Linear(self.hidden_size, hidden_size)
         elif self.method == 'concat':
-            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
-            self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))
+            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
+            self.v = nn.Parameter(torch.FloatTensor(hidden_size))

     def dot_score(self, hidden, encoder_output):
         return torch.sum(hidden * encoder_output, dim=2)
@@ -383,14 +381,14 @@ def forward(self, hidden, encoder_outputs):
 # weighted sum indicating what parts of the encoder’s output to pay
 # attention to. From here, we use a linear layer and softmax normalization
 # to select the next word in the output sequence.
 #
-# Hybrid Frontend Notes:
+# TorchScript Notes:
 # ~~~~~~~~~~~~~~~~~~~~~~
 #
 # Similarly to the ``EncoderRNN``, this module does not contain any
 # data-dependent control flow. Therefore, we can once again use
-# **tracing** to convert this model to Torch Script after it is
-# initialized and its parameters are loaded.
+# **tracing** to convert this model to TorchScript after it
+# is initialized and its parameters are loaded.
 #

 class LuongAttnDecoderRNN(nn.Module):
@@ -465,18 +463,18 @@ def forward(self, input_step, last_hidden, encoder_outputs):
 # terminates either if the ``decoded_words`` list has reached a length of
 # *MAX_LENGTH* or if the predicted word is the *EOS_token*.
 #
-# Hybrid Frontend Notes:
+# TorchScript Notes:
 # ~~~~~~~~~~~~~~~~~~~~~~
 #
 # The ``forward`` method of this module involves iterating over the range
 # of :math:`[0, max\_length)` when decoding an output sequence one word at
 # a time. Because of this, we should use **scripting** to convert this
-# module to Torch Script. Unlike with our encoder and decoder models,
+# module to TorchScript. Unlike with our encoder and decoder models,
 # which we can trace, we must make some necessary changes to the
 # ``GreedySearchDecoder`` module in order to initialize an object without
 # error. In other words, we must ensure that our module adheres to the
-# rules of the scripting mechanism, and does not utilize any language
-# features outside of the subset of Python that Torch Script includes.
+# rules of the TorchScript mechanism, and does not utilize any language
+# features outside of the subset of Python that TorchScript includes.
 #
 # To get an idea of some manipulations that may be required, we will go
 # over the diffs between the ``GreedySearchDecoder`` implementation from
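A minimal sketch of scripting a module whose loop count depends on the data, the situation described above (toy example, not from the tutorial):

import torch
import torch.nn as nn

class GrowUntil(nn.Module):
    # The number of iterations depends on values inside the input
    # tensor, so tracing would freeze one particular route; scripting
    # captures the data-dependent loop itself.
    def forward(self, x: torch.Tensor, max_steps: int) -> torch.Tensor:
        steps = 0
        while bool(x.norm() < 1.0) and steps < max_steps:
            x = x * 2
            steps += 1
        return x

scripted = torch.jit.script(GrowUntil())
print(scripted(torch.tensor([0.05]), 10))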
@@ -491,12 +489,6 @@ def forward(self, input_step, last_hidden, encoder_outputs):
 # Changes:
 # ^^^^^^^^
 #
-# - ``nn.Module`` -> ``torch.jit.ScriptModule``
-#
-#    - In order to use PyTorch’s scripting mechanism on a module, that
-#      module must inherit from the ``torch.jit.ScriptModule``.
-#
-#
 # - Added ``decoder_n_layers`` to the constructor arguments
 #
 #    - This change stems from the fact that the encoder and decoder
@@ -523,16 +515,9 @@ def forward(self, input_step, last_hidden, encoder_outputs):
 # ``self._SOS_token``.
 #
 #
-# - Add the ``torch.jit.script_method`` decorator to the ``forward``
-#   method
-#
-#    - Adding this decorator lets the JIT compiler know that the function
-#      that it is decorating should be scripted.
-#
-#
 # - Enforce types of ``forward`` method arguments
 #
-#    - By default, all parameters to a Torch Script function are assumed
+#    - By default, all parameters to a TorchScript function are assumed
 #      to be Tensor. If we need to pass an argument of a different type,
 #      we can use function type annotations as introduced in `PEP
 #      3107 <https://www.python.org/dev/peps/pep-3107/>`__. In addition,
@@ -553,7 +538,7 @@ def forward(self, input_step, last_hidden, encoder_outputs):
 # ``self._SOS_token``.
 #

-class GreedySearchDecoder(torch.jit.ScriptModule):
+class GreedySearchDecoder(nn.Module):
     def __init__(self, encoder, decoder, decoder_n_layers):
         super(GreedySearchDecoder, self).__init__()
         self.encoder = encoder
@@ -564,7 +549,6 @@ def __init__(self, encoder, decoder, decoder_n_layers):

     __constants__ = ['_device', '_SOS_token', '_decoder_n_layers']

-    @torch.jit.script_method
     def forward(self, input_seq : torch.Tensor, input_length : torch.Tensor, max_length : int):
         # Forward input through encoder model
         encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
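A condensed sketch of the two conventions used here — Python 3 type annotations on ``forward`` and ``__constants__`` for attributes the compiler should treat as constants (hypothetical module, not the tutorial's):

import torch
import torch.nn as nn

class SosExpander(nn.Module):
    __constants__ = ['_sos_token']  # baked in at compile time

    def __init__(self, sos_token: int):
        super(SosExpander, self).__init__()
        self._sos_token = sos_token

    def forward(self, input_seq: torch.Tensor, max_length: int) -> torch.Tensor:
        # max_length is annotated as int; unannotated arguments would be
        # assumed to be Tensor.
        out = torch.full((1, 1), self._sos_token, dtype=torch.long)
        for _ in range(max_length):
            out = torch.cat([out, input_seq[:1, :1].long()], dim=1)
        return out

scripted = torch.jit.script(SosExpander(1))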
@@ -613,7 +597,7 @@ def forward(self, input_seq : torch.Tensor, input_length : torch.Tensor, max_len
 # an argument, normalizes it, evaluates it, and prints the response.
 #

-def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
+def evaluate(searcher, voc, sentence, max_length=MAX_LENGTH):
     ### Format input sentence as a batch
     # words -> indexes
     indexes_batch = [indexesFromSentence(voc, sentence)]
@@ -632,7 +616,7 @@ def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):


 # Evaluate inputs from user input (stdin)
-def evaluateInput(encoder, decoder, searcher, voc):
+def evaluateInput(searcher, voc):
     input_sentence = ''
     while(1):
         try:
@@ -643,7 +627,7 @@ def evaluateInput(encoder, decoder, searcher, voc):
             # Normalize sentence
             input_sentence = normalizeString(input_sentence)
             # Evaluate sentence
-            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
+            output_words = evaluate(searcher, voc, input_sentence)
             # Format and print response sentence
             output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
             print('Bot:', ' '.join(output_words))
@@ -652,12 +636,12 @@ def evaluateInput(encoder, decoder, searcher, voc):
             print("Error: Encountered unknown word.")

 # Normalize input sentence and call evaluate()
-def evaluateExample(sentence, encoder, decoder, searcher, voc):
+def evaluateExample(sentence, searcher, voc):
     print("> " + sentence)
     # Normalize sentence
     input_sentence = normalizeString(sentence)
     # Evaluate sentence
-    output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
+    output_words = evaluate(searcher, voc, input_sentence)
     output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
     print('Bot:', ' '.join(output_words))

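With these signature changes, the searcher object carries the encoder and decoder, so a hypothetical call (assuming the ``scripted_searcher`` and ``voc`` objects built later in the tutorial) shrinks to:

# The searcher now owns the encoder and decoder, so they are no longer
# passed to the evaluation helpers separately.
output_words = evaluate(scripted_searcher, voc, "where am I?")
evaluateExample("hello", scripted_searcher, voc)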
@@ -700,14 +684,17 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
 # ``checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))``
 # line.
 #
-# Hybrid Frontend Notes:
+# TorchScript Notes:
 # ~~~~~~~~~~~~~~~~~~~~~~
 #
 # Notice that we initialize and load parameters into our encoder and
-# decoder models as usual. Also, we must call ``.to(device)`` to set the
-# device options of the models and ``.eval()`` to set the dropout layers
-# to test mode **before** we trace the models. ``TracedModule`` objects do
-# not inherit the ``to`` or ``eval`` methods.
+# decoder models as usual. If you are using tracing mode (``torch.jit.trace``)
+# for some part of your models, you must call ``.to(device)`` to set the
+# device options of the models and ``.eval()`` to set the dropout layers to
+# test mode **before** tracing the models. ``TracedModule`` objects do not
+# inherit the ``to`` or ``eval`` methods. Since in this tutorial we are only
+# using scripting instead of tracing, we only need to do this before we do
+# evaluation (which is the same as we normally do in eager mode).
 #

 save_dir = os.path.join("data", "save")
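A minimal sketch of that ordering constraint (illustrative module and shapes, not the tutorial's models):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))
model.to(torch.device('cpu'))  # set device options first
model.eval()                   # put dropout layers in test mode first
# Per the note above, the traced result does not inherit .to() or
# .eval(), so both must be set before tracing.
traced = torch.jit.trace(model, torch.randn(1, 8))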
@@ -766,16 +753,14 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):


 ######################################################################
-# Convert Model to Torch Script
+# Convert Model to TorchScript
 # -----------------------------
 #
 # Encoder
 # ~~~~~~~
 #
 # As previously mentioned, to convert the encoder model to Torch Script,
-# we use **tracing**. Tracing any module requires running an example input
-# through the model’s ``forward`` method and trace the computational graph
-# that the data encounters. The encoder model takes an input sequence and
+# we use **tracing**. The encoder model takes an input sequence and
 # a corresponding lengths tensor. Therefore, we create an example input
 # sequence tensor ``test_seq``, which is of appropriate size (MAX_LENGTH,
 # 1), contains numbers in the appropriate range
@@ -803,13 +788,13 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
 # ~~~~~~~~~~~~~~~~~~~
 #
 # Recall that we scripted our searcher module due to the presence of
-# data-dependent control flow. In the case of scripting, we do the
-# conversion work up front by adding the decorator and making sure the
-# implementation complies with scripting rules. We initialize the scripted
-# searcher the same way that we would initialize an un-scripted variant.
+# data-dependent control flow. In the case of scripting, we make the
+# necessary language changes to ensure the implementation complies with
+# TorchScript. We initialize the scripted searcher the same way that we
+# would initialize an un-scripted variant.
 #

-### Convert encoder model
+### Compile the whole greedy search model to a TorchScript model
 # Create artificial inputs
 test_seq = torch.LongTensor(MAX_LENGTH, 1).random_(0, voc.num_words).to(device)
 test_seq_length = torch.LongTensor([test_seq.size()[0]]).to(device)
@@ -824,19 +809,21 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
 # Trace the model
 traced_decoder = torch.jit.trace(decoder, (test_decoder_input, test_decoder_hidden, test_encoder_outputs))

-### Initialize searcher module
-scripted_searcher = GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers)
+### Initialize searcher module by wrapping it in a ``torch.jit.script`` call
+scripted_searcher = torch.jit.script(GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers))
+
+


 ######################################################################
 # Print Graphs
 # ------------
 #
-# Now that our models are in Torch Script form, we can print the graphs of
+# Now that our models are in TorchScript form, we can print the graphs of
 # each to ensure that we captured the computational graph appropriately.
-# Since our ``scripted_searcher`` contains our ``traced_encoder`` and
-# ``traced_decoder``, these graphs will print inline.
-#
+# Since TorchScript allows us to recursively compile the whole model
+# hierarchy and inline the ``encoder`` and ``decoder`` graphs into a single
+# graph, we just need to print the ``scripted_searcher`` graph.

 print('scripted_searcher graph:\n', scripted_searcher.graph)

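Besides ``.graph``, compiled modules and functions also expose ``.code``, a more readable Python-like rendering of the same program — a small sketch, separate from the tutorial's code:

import torch

@torch.jit.script
def rescale(x: torch.Tensor) -> torch.Tensor:
    # Tiny function just to show the two inspection views.
    if bool(x.max() > 1.0):
        x = x / x.max()
    return x

print(rescale.graph)  # low-level IR, like the scripted_searcher printout
print(rescale.code)   # higher-level, Python-like view of the same program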
@@ -845,19 +832,25 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
 # Run Evaluation
 # --------------
 #
-# Finally, we will run evaluation of the chatbot model using the Torch
-# Script models. If converted correctly, the models will behave exactly as
-# they would in their eager-mode representation.
+# Finally, we will run evaluation of the chatbot model using the TorchScript
+# models. If converted correctly, the models will behave exactly as they
+# would in their eager-mode representation.
 #
 # By default, we evaluate a few common query sentences. If you want to
 # chat with the bot yourself, uncomment the ``evaluateInput`` line and
 # give it a spin.
 #

+
+# Use appropriate device
+scripted_searcher.to(device)
+# Set dropout layers to eval mode
+scripted_searcher.eval()
+
 # Evaluate examples
 sentences = ["hello", "what's up?", "who are you?", "where am I?", "where are you from?"]
 for s in sentences:
-    evaluateExample(s, traced_encoder, traced_decoder, scripted_searcher, voc)
+    evaluateExample(s, scripted_searcher, voc)

 # Evaluate your input
 #evaluateInput(traced_encoder, traced_decoder, scripted_searcher, voc)
@@ -867,7 +860,7 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
 # Save Model
 # ----------
 #
-# Now that we have successfully converted our model to Torch Script, we
+# Now that we have successfully converted our model to TorchScript, we
 # will serialize it for use in a non-Python deployment environment. To do
 # this, we can simply save our ``scripted_searcher`` module, as this is
 # the user-facing interface for running inference against the chatbot
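A sketch of the serialization step this paragraph describes (the file name is illustrative):

# Save the compiled module as a self-contained archive.
scripted_searcher.save("scripted_chatbot.pth")

# It can later be reloaded without the original Python class
# definitions, from C++ via torch::jit::load, or back in Python:
loaded = torch.jit.load("scripted_chatbot.pth")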
