# -*- coding: utf-8 -*-
"""
Deploying a Seq2Seq Model with TorchScript
==================================================
**Author:** `Matthew Inkawhich <https://github.com/MatthewInkawhich>`_
"""


######################################################################
# This tutorial will walk through the process of transitioning a
# sequence-to-sequence model to TorchScript using the TorchScript
# API. The model that we will convert is the chatbot model from the
# `Chatbot tutorial <https://pytorch.org/tutorials/beginner/chatbot_tutorial.html>`__.
# You can either treat this tutorial as a “Part 2” to the Chatbot tutorial
# and deploy your own pretrained model, or you can start with this
# document and use a pretrained model that we host. In the latter case,
# you can reference the original Chatbot tutorial for details
# regarding data preprocessing, model theory and definition, and model
# training.
#
# What is TorchScript?
# ----------------------------
#
# During the research and development phase of a deep learning-based
# project, it is advantageous to interact with an **eager**, imperative
# interface like PyTorch’s, which makes it easy to write and debug
# familiar, idiomatic Python. However, when it comes time to deploy the
# model in a production environment, a **graph**-based model
# representation is very beneficial: a deferred graph representation
# allows for optimizations such as out-of-order execution, and the ability
# to target highly optimized hardware architectures. Also, a graph-based
# representation enables framework-agnostic model exportation. PyTorch
# provides mechanisms for incrementally converting eager-mode code into
# TorchScript, a statically analyzable and optimizable subset of Python
# that Torch uses to represent deep learning programs independently from
# the Python runtime.
#
# The API for converting eager-mode PyTorch programs into TorchScript is
# found in the ``torch.jit`` module. This module has two core modalities for
# converting an eager-mode model to a TorchScript graph representation:
# **tracing** and **scripting**. The ``torch.jit.trace`` function takes a
# module or function and a set of example inputs. It then runs the example
# input through the function or module while tracing the computational
# graph that it encounters. One limitation of tracing is that if a module
# or function contains data-dependent control flow, such as ``if``
# statements or loops, then only the
# operations called along the execution route taken by the example input
# will be recorded. In other words, the control flow itself is not
# captured. To convert modules and functions containing data-dependent
# control flow, a **scripting** mechanism is provided. The
# ``torch.jit.script`` function takes a module or function and does not
# require example inputs. Scripting then explicitly converts the module
# or function code to TorchScript, including all possible control flow
# routes. The one caveat with using scripting is that it only supports
# a subset of Python, so you might need to rewrite the code to make it
# compatible with the TorchScript syntax.
#
# For all details relating to the supported features, see the TorchScript
# `language reference <https://pytorch.org/docs/master/jit.html>`__. To
# provide the maximum flexibility, you can also mix tracing and scripting
# modes together to represent your whole program, and these techniques can
# be applied incrementally.
#
#


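######################################################################
# As a quick illustration of the tracing caveat (a minimal, hypothetical
# sketch, not part of the chatbot model), consider a function with
# data-dependent control flow. Tracing bakes in the branch taken by the
# example input, while scripting preserves the ``if`` statement:

import torch

def clamp_positive(x):
    if bool(x.sum() > 0):
        return x
    return -x

# Tracing records only the route taken by the (positive) example input
traced_fn = torch.jit.trace(clamp_positive, (torch.ones(3),))
# Scripting compiles both branches
scripted_fn = torch.jit.script(clamp_positive)

negative = -torch.ones(3)
print(traced_fn(negative))    # follows the traced branch: returns the -1s unchanged
print(scripted_fn(negative))  # control flow preserved: returns 1s

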
######################################################################
# Encoder
# ~~~~~~~
#
# The encoder RNN iterates through the input sentence one token at a
# time, producing an “output” vector and a “hidden state” vector at each
# step. Note that the input sentences are zero-padded to a fixed length,
# and that a tensor of the true sequence lengths is
# used by the ``torch.nn.utils.rnn.pack_padded_sequence`` function when
# padding.
#
# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~
#
# Since the encoder’s ``forward`` function does not contain any
# data-dependent control flow, we can simply use **tracing** to convert
# it to TorchScript.
#

class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        # Initialize GRU; the input_size and hidden_size params are both
        # set to 'hidden_size' because our input size is a word embedding
        # with number of features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
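        # (The Mypy-style type comment above declares the argument and
        # return types for TorchScript, since ``hidden`` may be ``None``
        # rather than a Tensor)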
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack padded batch of sequences for RNN module
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        # Forward pass through GRU
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        # Return output and final hidden state
        return outputs, hidden


######################################################################
# Next, we define our attention module (``Attn``). This module computes
# attention weights, or “energies”, between the decoder’s current hidden
# state and the encoder outputs, following Luong et al.’s “Global
# attention” formulation.
#
#

# Luong attention layer
class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)
    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1),
                                      encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        else:
            attn_energies = self.dot_score(hidden, encoder_outputs)
        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()
        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)


######################################################################
# Next, we define the decoder module (``LuongAttnDecoderRNN``). At each
# time step, the decoder uses the attention module to produce an attention-
# weighted sum indicating what parts of the encoder’s output to pay
# attention to. From here, we use a linear layer and softmax normalization
# to select the next word in the output sequence.
#
# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~
#
# Similarly to the ``EncoderRNN``, this module does not contain any
# data-dependent control flow. Therefore, we can once again use
# **tracing** to convert this model to TorchScript after it is
# initialized and its parameters are loaded.
#

class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()
        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout
        # Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        # Note: we run this one step (word) at a time
        # Get embedding of current input word
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        # Forward through unidirectional GRU
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Calculate attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Multiply attention weights to encoder outputs to get new "weighted sum" context vector
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # Concatenate weighted context vector and GRU output
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Predict next word
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        # Return output and final hidden state
        return output, hidden


######################################################################
# GreedySearchDecoder
# ~~~~~~~~~~~~~~~~~~~
#
# The ``GreedySearchDecoder`` module facilitates the greedy decoding
# process: at each time step, it feeds the decoder the previous word and
# greedily chooses the word with the highest softmax score. Decoding
# terminates either if the ``decoded_words`` list has reached a length of
# *MAX_LENGTH* or if the predicted word is the *EOS_token*.
#
# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~
#
# The ``forward`` method of this module involves iterating over the range
# of :math:`[0, max\_length)` when decoding an output sequence one word at
# a time. Because of this, we should use **scripting** to convert this
# module to TorchScript. Unlike with our encoder and decoder models,
# which we can trace, we must make some necessary changes to the
# ``GreedySearchDecoder`` module in order to initialize an object without
# error. In other words, we must ensure that our module adheres to the
# rules of the scripting mechanism, and does not use any language
# features outside of the subset of Python that TorchScript includes.
#
# To get an idea of some manipulations that may be required, we will go
# over the diffs between the ``GreedySearchDecoder`` implementation from
# the chatbot tutorial and the implementation that we use in the cell
# below.
#
# Changes:
# ^^^^^^^^
#
# - Added ``decoder_n_layers`` to the constructor arguments
#
#   - This change stems from the fact that the encoder and decoder
#     models that we pass to this module will be children of
#     ``TracedModule`` (not ``Module``). Therefore, we cannot access the
#     decoder’s number of layers with ``decoder.n_layers``. Instead, we
#     plan for this, and pass this value in during module construction.
#
#
# - Store away new attributes as constants
#
#   - In the original implementation, we were free to use variables from
#     the surrounding (global) scope in our ``GreedySearchDecoder``’s
#     ``forward`` method. However, now that we are using scripting, we do
#     not have this freedom, as the assumption with scripting is that we
#     cannot necessarily hold on to Python objects, especially when
#     exporting. An easy solution to this is to store these values from
#     the global scope as attributes to the module in the constructor,
#     and add them to a special list called ``__constants__``. Notice how
#     we use these constants in the ``forward`` method through
#     ``self._device`` and ``self._SOS_token``.
#
#
# - Enforce types of ``forward`` method arguments
#
#   - By default, all parameters to a TorchScript function are assumed
#     to be Tensor. If we need to pass an argument of a different type,
#     we can use function type annotations as introduced in `PEP
#     3107 <https://www.python.org/dev/peps/pep-3107/>`__ (see the short
#     sketch after this list). In addition,
#     it is possible to declare arguments of different types using
#     Mypy-style type annotations.
#
#
# - Change initialization of ``decoder_input``
#
#   - In the original implementation, we initialized our ``decoder_input``
#     tensor with ``torch.LongTensor([[SOS_token]])``. When scripting, we
#     are not allowed to initialize tensors in a literal fashion like this.
#     Instead, we can initialize our tensor with an explicit torch function
#     such as ``torch.ones``. In this case, we can easily replicate the
#     scalar ``decoder_input`` tensor by multiplying 1 by our SOS_token
#     value stored in the constant ``self._SOS_token``.
#

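######################################################################
# Before applying these changes, here is a minimal, hypothetical sketch
# (``RepeatToken`` is not part of the tutorial) of a scriptable module
# that stores a value in ``__constants__`` and annotates a non-Tensor
# ``forward`` argument:

class RepeatToken(nn.Module):
    __constants__ = ['_token']

    def __init__(self, token):
        super(RepeatToken, self).__init__()
        self._token = token

    def forward(self, length: int) -> torch.Tensor:
        # ``length`` must be annotated as int; TorchScript otherwise assumes
        # every argument is a Tensor. Note also that the tensor is built
        # with ``torch.ones`` rather than a tensor literal.
        return torch.ones(1, length, dtype=torch.long) * self._token

scripted_repeat = torch.jit.script(RepeatToken(SOS_token))
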
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder, decoder_n_layers):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self._device = device
        self._SOS_token = SOS_token
        self._decoder_n_layers = decoder_n_layers

    __constants__ = ['_device', '_SOS_token', '_decoder_n_layers']

    def forward(self, input_seq: torch.Tensor, input_length: torch.Tensor, max_length: int):
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the decoder
        decoder_hidden = encoder_hidden[:self._decoder_n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=self._device, dtype=torch.long) * self._SOS_token
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=self._device, dtype=torch.long)
        all_scores = torch.zeros([0], device=self._device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores


######################################################################
# With the models defined, we can define helpers for evaluating an input
# sentence. The ``evaluate`` function formats a sentence as a batch of
# word indexes, runs it through the ``searcher``, and returns the decoded
# words. ``evaluateInput`` prompts the user for sentences from stdin, and
# ``evaluateExample`` simply takes a string input sentence as
# an argument, normalizes it, evaluates it, and prints the response.
#

def evaluate(searcher, voc, sentence, max_length=MAX_LENGTH):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)]
    # Create lengths tensor
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words


# Evaluate inputs from user input (stdin)
def evaluateInput(searcher, voc):
    input_sentence = ''
    while(1):
        try:
            # Get input sentence
            input_sentence = input('> ')
            # Check if it is quit case
            if input_sentence == 'q' or input_sentence == 'quit': break
            # Normalize sentence
            input_sentence = normalizeString(input_sentence)
            # Evaluate sentence
            output_words = evaluate(searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))
        except KeyError:
            print("Error: Encountered unknown word.")

# Normalize input sentence and call evaluate()
def evaluateExample(sentence, searcher, voc):
    print("> " + sentence)
    # Normalize sentence
    input_sentence = normalizeString(sentence)
    # Evaluate sentence
    output_words = evaluate(searcher, voc, input_sentence)
    output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
    print('Bot:', ' '.join(output_words))

######################################################################
# Load Pretrained Parameters
# --------------------------
#
# No matter whether you load parameters from your own trained model or
# from the checkpoint that we host, if you are loading a model that was
# trained on GPU to a CPU-only machine, be sure to uncomment the
# ``checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))``
# line.
#
# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~
#
# Notice that we initialize and load parameters into our encoder and
# decoder models as usual. If you are using tracing mode (``torch.jit.trace``)
# for some part of your models, you must call ``.to(device)`` to set the
# device options of the models and ``.eval()`` to set the dropout layers to
# test mode **before** tracing the models. ``TracedModule`` objects do not
# inherit the ``to`` or ``eval`` methods. Since we trace the encoder and
# decoder below, we take care of this while loading; for the scripted
# searcher, we simply set these options right before running evaluation,
# the same as we normally do in eager mode.
#
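#
# A minimal sketch of the required ordering when tracing (``encoder`` and
# the example inputs here stand in for whatever module you trace):
#
# ::
#
#    encoder = encoder.to(device)
#    encoder.eval()  # set dropout layers to eval mode BEFORE tracing
#    traced_encoder = torch.jit.trace(encoder, (test_seq, test_seq_length))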

save_dir = os.path.join("data", "save")


######################################################################
# Convert Model to TorchScript
# -----------------------------
#
# Encoder
# ~~~~~~~
#
# As previously mentioned, to convert the encoder model to TorchScript,
# we use **tracing**. The encoder model takes an input sequence and
# a corresponding lengths tensor. Therefore, we create an example input
# sequence tensor ``test_seq``, which is of appropriate size (MAX_LENGTH,
# 1), contains numbers in the appropriate range
# :math:`[0, voc.num\_words)`, and is of the appropriate device.
#
# Decoder
# ~~~~~~~
#
# We perform the same process for tracing the decoder as we did for the
# encoder. Notice that we call ``forward`` on a set of random inputs to
# the ``traced_encoder`` to get the output that we need for the decoder.
#
# GreedySearchDecoder
# ~~~~~~~~~~~~~~~~~~~
#
# Recall that we scripted our searcher module due to the presence of
# data-dependent control flow. In the case of scripting, we made the
# necessary language changes up front to ensure that the implementation
# complies with TorchScript. We initialize the scripted searcher the same
# way that we would initialize an un-scripted variant.
#

### Compile the whole greedy search model to TorchScript model
# Create artificial inputs
test_seq = torch.LongTensor(MAX_LENGTH, 1).random_(0, voc.num_words).to(device)
test_seq_length = torch.LongTensor([test_seq.size()[0]]).to(device)
# Trace the model
traced_encoder = torch.jit.trace(encoder, (test_seq, test_seq_length))

### Convert decoder model
# Create and generate artificial inputs
test_encoder_outputs, test_encoder_hidden = traced_encoder(test_seq, test_seq_length)
test_decoder_hidden = test_encoder_hidden[:decoder.n_layers]
test_decoder_input = torch.LongTensor(1, 1).random_(0, voc.num_words)
# Trace the model
traced_decoder = torch.jit.trace(decoder, (test_decoder_input, test_decoder_hidden, test_encoder_outputs))

### Initialize searcher module by wrapping ``torch.jit.script`` call
scripted_searcher = torch.jit.script(GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers))
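
# Optional sanity check (not part of the original tutorial): run the
# scripted searcher on the artificial inputs to confirm it executes
# end to end
sanity_tokens, sanity_scores = scripted_searcher(test_seq, test_seq_length, MAX_LENGTH)
print('sanity check output tokens:', sanity_tokens.shape[0])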


######################################################################
# Print Graphs
# ------------
#
# Now that our models are in TorchScript form, we can print the graphs of
# each to ensure that we captured the computational graph appropriately.
# Since TorchScript allows us to recursively compile the whole model
# hierarchy and inline the ``encoder`` and ``decoder`` graphs into a
# single graph, we just need to print the ``scripted_searcher`` graph.

print('scripted_searcher graph:\n', scripted_searcher.graph)

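# We can also print a Python-like high-level representation of the
# compiled ``forward`` method through the ``code`` attribute that
# TorchScript modules expose (an optional extra, not in the original
# tutorial)
print('scripted_searcher code:\n', scripted_searcher.code)
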
######################################################################
845832# Run Evaluation
846833# --------------
847834#
848- # Finally, we will run evaluation of the chatbot model using the Torch
849- # Script models. If converted correctly, the models will behave exactly as
850- # they would in their eager-mode representation.
835+ # Finally, we will run evaluation of the chatbot model using the TorchScript
836+ # models. If converted correctly, the models will behave exactly as they
837+ # would in their eager-mode representation.
851838#
852839# By default, we evaluate a few common query sentences. If you want to
853840# chat with the bot yourself, uncomment the ``evaluateInput`` line and
854841# give it a spin.
855842#
856843

# Use appropriate device
scripted_searcher.to(device)
# Set dropout layers to eval mode
scripted_searcher.eval()

# Evaluate examples
sentences = ["hello", "what's up?", "who are you?", "where am I?", "where are you from?"]
for s in sentences:
    evaluateExample(s, scripted_searcher, voc)

# Evaluate your input
#evaluateInput(scripted_searcher, voc)

######################################################################
# Save Model
# ----------
#
# Now that we have successfully converted our model to TorchScript, we
# will serialize it for use in a non-Python deployment environment. To do
# this, we can simply save our ``scripted_searcher`` module, as this is
# the user-facing interface for running inference against the chatbot
# model.
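#
# When saving a scripted module, use the module's own ``save`` method
# rather than ``torch.save``. A minimal sketch (the filename here is
# illustrative):

scripted_searcher.save("scripted_chatbot.pth")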