 # -*- coding: utf-8 -*-
 """
-Deploying a Seq2Seq Model with the Hybrid Frontend
+Deploying a Seq2Seq Model with TorchScript
 ==================================================
 **Author:** `Matthew Inkawhich <https://github.com/MatthewInkawhich>`_
 """
 
 
 ######################################################################
 # This tutorial will walk through the process of transitioning a
-# sequence-to-sequence model to Torch Script using PyTorch’s Hybrid
-# Frontend. The model that we will convert is the chatbot model from the
-# `Chatbot tutorial <https://pytorch.org/tutorials/beginner/chatbot_tutorial.html>`__.
+# sequence-to-sequence model to TorchScript using the TorchScript
+# API. The model that we will convert is the chatbot model from the
+# `Chatbot tutorial <https://pytorch.org/tutorials/beginner/chatbot_tutorial.html>`__.
 # You can either treat this tutorial as a “Part 2” to the Chatbot tutorial
 # and deploy your own pretrained model, or you can start with this
 # document and use a pretrained model that we host. In the latter case,
 # you can reference the original Chatbot tutorial for details
 # regarding data preprocessing, model theory and definition, and model
 # training.
 #
-# What is the Hybrid Frontend?
+# What is TorchScript?
 # ----------------------------
 #
 # During the research and development phase of a deep learning-based
 # to target highly optimized hardware architectures. Also, a graph-based
 # representation enables framework-agnostic model exportation. PyTorch
 # provides mechanisms for incrementally converting eager-mode code into
-# Torch Script, a statically analyzable and optimizable subset of Python
+# TorchScript, a statically analyzable and optimizable subset of Python
 # that Torch uses to represent deep learning programs independently from
 # the Python runtime.
 #
-# The API for converting eager-mode PyTorch programs into Torch Script is
+# The API for converting eager-mode PyTorch programs into TorchScript is
 # found in the torch.jit module. This module has two core modalities for
-# converting an eager-mode model to a Torch Script graph representation:
+# converting an eager-mode model to a TorchScript graph representation:
 # **tracing** and **scripting**. The ``torch.jit.trace`` function takes a
 # module or function and a set of example inputs. It then runs the example
 # input through the function or module while tracing the computational
 # operations called along the execution route taken by the example input
 # will be recorded. In other words, the control flow itself is not
 # captured. To convert modules and functions containing data-dependent
-# control flow, a **scripting** mechanism is provided. Scripting
-# explicitly converts the module or function code to Torch Script,
-# including all possible control flow routes. To use script mode, be sure
-# to inherit from the the ``torch.jit.ScriptModule`` base class (instead
-# of ``torch.nn.Module``) and add a ``torch.jit.script`` decorator to your
-# Python function or a ``torch.jit.script_method`` decorator to your
-# module’s methods. The one caveat with using scripting is that it only
-# supports a restricted subset of Python. For all details relating to the
-# supported features, see the Torch Script `language
-# reference <https://pytorch.org/docs/master/jit.html>`__. To provide the
-# maximum flexibility, the modes of Torch Script can be composed to
-# represent your whole program, and these techniques can be applied
-# incrementally.
+# control flow, a **scripting** mechanism is provided. The
+# ``torch.jit.script`` function/decorator takes a module or function and
+# does not require example inputs. Scripting then explicitly converts
+# the module or function code to TorchScript, including all possible
+# control flow paths. One caveat with using scripting is that it only
+# supports a subset of Python, so you might need to rewrite the code to
+# make it compatible with the TorchScript syntax.
+#
+# For all details relating to the supported features, see the `TorchScript
+# language reference <https://pytorch.org/docs/master/jit.html>`__.
+# To provide maximum flexibility, you can also mix tracing and scripting
+# modes together to represent your whole program, and these techniques can
+# be applied incrementally.
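+#
+# As a minimal sketch of the difference (a toy function, not part of the
+# chatbot model), tracing records only the branch taken by the example
+# input, while scripting preserves the data-dependent ``if``:
+
+import torch
+
+def toy(x):
+    # Data-dependent control flow: the branch depends on the input values
+    if bool(x.sum() > 0):
+        return x * 2
+    return x - 1
+
+traced_toy = torch.jit.trace(toy, torch.ones(2))   # records only the x * 2 branch
+scripted_toy = torch.jit.script(toy)               # compiles both branches
+print(scripted_toy(torch.full((2,), -3.0)))        # tensor([-4., -4.])
+print(traced_toy(torch.full((2,), -3.0)))          # still x * 2: tensor([-6., -6.])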
 #
 # .. figure:: /_static/img/chatbot/pytorch_workflow.png
 #    :align: center
@@ -273,7 +273,7 @@ def indexesFromSentence(voc, sentence):
 # used by the ``torch.nn.utils.rnn.pack_padded_sequence`` function when
 # padding.
 #
-# Hybrid Frontend Notes:
+# TorchScript Notes:
 # ~~~~~~~~~~~~~~~~~~~~~~
 #
 # Since the encoder’s ``forward`` function does not contain any
@@ -296,6 +296,7 @@ def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
                           dropout=(0 if n_layers == 1 else dropout), bidirectional=True)
 
     def forward(self, input_seq, input_lengths, hidden=None):
+        # type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
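+        # TorchScript reads the MyPy-style type comment above; without it,
+        # every argument would be assumed to be a Tensor. Here ``hidden``
+        # is Optional and the method returns a tuple of Tensors.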
         # Convert word indexes to embeddings
         embedded = self.embedding(input_seq)
         # Pack padded batch of sequences for RNN module
@@ -325,18 +326,18 @@ def forward(self, input_seq, input_lengths, hidden=None):
 #
 
 # Luong attention layer
-class Attn(torch.nn.Module):
+class Attn(nn.Module):
     def __init__(self, method, hidden_size):
         super(Attn, self).__init__()
         self.method = method
         if self.method not in ['dot', 'general', 'concat']:
             raise ValueError(self.method, "is not an appropriate attention method.")
         self.hidden_size = hidden_size
         if self.method == 'general':
-            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
+            self.attn = nn.Linear(self.hidden_size, hidden_size)
         elif self.method == 'concat':
-            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
-            self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))
+            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
+            self.v = nn.Parameter(torch.FloatTensor(hidden_size))
 
     def dot_score(self, hidden, encoder_output):
         return torch.sum(hidden * encoder_output, dim=2)
@@ -383,14 +384,14 @@ def forward(self, hidden, encoder_outputs):
 # weighted sum indicating what parts of the encoder’s output to pay
 # attention to. From here, we use a linear layer and softmax normalization
 # to select the next word in the output sequence.
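+#
+# As a quick shape check (toy tensors, not part of the tutorial pipeline;
+# assumes ``hidden_size=500``, a batch of 5, and ``max_length=10``, with
+# ``torch`` and ``torch.nn.functional as F`` imported as in the full file):
+
+toy_hidden = torch.randn(1, 5, 500)              # decoder hidden state
+toy_encoder_outputs = torch.randn(10, 5, 500)    # (max_length, batch, hidden_size)
+toy_energies = torch.sum(toy_hidden * toy_encoder_outputs, dim=2)   # (10, 5)
+toy_attn = F.softmax(toy_energies.t(), dim=1).unsqueeze(1)          # (5, 1, 10)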
-#
-# Hybrid Frontend Notes:
+
+# TorchScript Notes:
 # ~~~~~~~~~~~~~~~~~~~~~~
 #
 # Similarly to the ``EncoderRNN``, this module does not contain any
 # data-dependent control flow. Therefore, we can once again use
-# **tracing** to convert this model to Torch Script after it is
-# initialized and its parameters are loaded.
+# **tracing** to convert this model to TorchScript after it
+# is initialized and its parameters are loaded.
 #
 
 class LuongAttnDecoderRNN(nn.Module):
@@ -465,18 +466,18 @@ def forward(self, input_step, last_hidden, encoder_outputs):
 # terminates either if the ``decoded_words`` list has reached a length of
 # *MAX_LENGTH* or if the predicted word is the *EOS_token*.
 #
-# Hybrid Frontend Notes:
+# TorchScript Notes:
 # ~~~~~~~~~~~~~~~~~~~~~~
 #
 # The ``forward`` method of this module involves iterating over the range
 # of :math:`[0, max\_length)` when decoding an output sequence one word at
 # a time. Because of this, we should use **scripting** to convert this
-# module to Torch Script. Unlike with our encoder and decoder models,
+# module to TorchScript. Unlike with our encoder and decoder models,
 # which we can trace, we must make some necessary changes to the
 # ``GreedySearchDecoder`` module in order to initialize an object without
 # error. In other words, we must ensure that our module adheres to the
-# rules of the scripting mechanism, and does not utilize any language
-# features outside of the subset of Python that Torch Script includes.
+# rules of the TorchScript mechanism, and does not utilize any language
+# features outside of the subset of Python that TorchScript includes.
 #
 # To get an idea of some manipulations that may be required, we will go
 # over the diffs between the ``GreedySearchDecoder`` implementation from
@@ -491,12 +492,6 @@ def forward(self, input_step, last_hidden, encoder_outputs):
 # Changes:
 # ^^^^^^^^
 #
-# - ``nn.Module`` -> ``torch.jit.ScriptModule``
-#
-#    - In order to use PyTorch’s scripting mechanism on a module, that
-#      module must inherit from the ``torch.jit.ScriptModule``.
-#
-#
 # - Added ``decoder_n_layers`` to the constructor arguments
 #
 #    - This change stems from the fact that the encoder and decoder
@@ -523,16 +518,9 @@ def forward(self, input_step, last_hidden, encoder_outputs):
 # ``self._SOS_token``.
 #
 #
-# - Add the ``torch.jit.script_method`` decorator to the ``forward``
-#   method
-#
-#    - Adding this decorator lets the JIT compiler know that the function
-#      that it is decorating should be scripted.
-#
-#
 # - Enforce types of ``forward`` method arguments
 #
-#    - By default, all parameters to a Torch Script function are assumed
+#    - By default, all parameters to a TorchScript function are assumed
 #      to be Tensor. If we need to pass an argument of a different type,
 #      we can use function type annotations as introduced in `PEP
 #      3107 <https://www.python.org/dev/peps/pep-3107/>`__. In addition,
@@ -553,7 +541,7 @@ def forward(self, input_step, last_hidden, encoder_outputs):
 # ``self._SOS_token``.
 #
 
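+# As a brief aside on the constants mechanism used below (a sketch,
+# separate from the tutorial's class): attributes listed in
+# ``__constants__`` are folded into the graph as literals when the
+# module is compiled.
+#
+# ::
+#
+#    class Scale(nn.Module):
+#        __constants__ = ['factor']
+#
+#        def __init__(self):
+#            super(Scale, self).__init__()
+#            self.factor = 2
+#
+#        def forward(self, x):
+#            return x * self.factor
+#
+#    scaled = torch.jit.script(Scale())
+#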
-class GreedySearchDecoder(torch.jit.ScriptModule):
+class GreedySearchDecoder(nn.Module):
     def __init__(self, encoder, decoder, decoder_n_layers):
         super(GreedySearchDecoder, self).__init__()
         self.encoder = encoder
@@ -564,7 +552,6 @@ def __init__(self, encoder, decoder, decoder_n_layers):
 
     __constants__ = ['_device', '_SOS_token', '_decoder_n_layers']
 
-    @torch.jit.script_method
     def forward(self, input_seq : torch.Tensor, input_length : torch.Tensor, max_length : int):
         # Forward input through encoder model
         encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
@@ -613,7 +600,7 @@ def forward(self, input_seq : torch.Tensor, input_length : torch.Tensor, max_len
 # an argument, normalizes it, evaluates it, and prints the response.
 #
 
-def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
+def evaluate(searcher, voc, sentence, max_length=MAX_LENGTH):
     ### Format input sentence as a batch
     # words -> indexes
     indexes_batch = [indexesFromSentence(voc, sentence)]
@@ -632,7 +619,7 @@ def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
 
 
 # Evaluate inputs from user input (stdin)
-def evaluateInput(encoder, decoder, searcher, voc):
+def evaluateInput(searcher, voc):
     input_sentence = ''
     while(1):
         try:
@@ -643,7 +630,7 @@ def evaluateInput(encoder, decoder, searcher, voc):
             # Normalize sentence
             input_sentence = normalizeString(input_sentence)
             # Evaluate sentence
-            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
+            output_words = evaluate(searcher, voc, input_sentence)
             # Format and print response sentence
             output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
             print('Bot:', ' '.join(output_words))
@@ -652,12 +639,12 @@ def evaluateInput(encoder, decoder, searcher, voc):
             print("Error: Encountered unknown word.")
 
 # Normalize input sentence and call evaluate()
-def evaluateExample(sentence, encoder, decoder, searcher, voc):
+def evaluateExample(sentence, searcher, voc):
     print("> " + sentence)
     # Normalize sentence
     input_sentence = normalizeString(sentence)
     # Evaluate sentence
-    output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
+    output_words = evaluate(searcher, voc, input_sentence)
     output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
     print('Bot:', ' '.join(output_words))
 
@@ -700,14 +687,17 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
 # ``checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))``
 # line.
 #
-# Hybrid Frontend Notes:
+# TorchScript Notes:
 # ~~~~~~~~~~~~~~~~~~~~~~
 #
 # Notice that we initialize and load parameters into our encoder and
-# decoder models as usual. Also, we must call ``.to(device)`` to set the
-# device options of the models and ``.eval()`` to set the dropout layers
-# to test mode **before** we trace the models. ``TracedModule`` objects do
-# not inherit the ``to`` or ``eval`` methods.
+# decoder models as usual. If you are using tracing mode (``torch.jit.trace``)
+# for some part of your models, you must call ``.to(device)`` to set the device
+# options of the models and ``.eval()`` to set the dropout layers to test mode
+# **before** tracing the models. ``TracedModule`` objects do not inherit the
+# ``to`` or ``eval`` methods. Since in this tutorial the searcher is scripted
+# and the encoder and decoder are traced only after they are set up and
+# loaded, we only need to do this before conversion and evaluation (the
+# same as we normally would in eager mode).
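+#
+# A minimal sketch of the tracing caveat with a toy module (names are
+# illustrative, not part of the chatbot model):
+#
+# ::
+#
+#    m = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
+#    m.to(device).eval()        # set device and eval mode BEFORE tracing
+#    traced_m = torch.jit.trace(m, torch.randn(1, 4, device=device))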
 #
 
 save_dir = os.path.join("data", "save")
@@ -766,16 +756,14 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
 
 
 ######################################################################
-# Convert Model to Torch Script
+# Convert Model to TorchScript
 # -----------------------------
 #
 # Encoder
 # ~~~~~~~
 #
-# As previously mentioned, to convert the encoder model to Torch Script,
-# we use **tracing**. Tracing any module requires running an example input
-# through the model’s ``forward`` method and trace the computational graph
-# that the data encounters. The encoder model takes an input sequence and
+# As previously mentioned, to convert the encoder model to TorchScript,
+# we use **tracing**. The encoder model takes an input sequence and
 # a corresponding lengths tensor. Therefore, we create an example input
 # sequence tensor ``test_seq``, which is of appropriate size (MAX_LENGTH,
 # 1), contains numbers in the appropriate range
@@ -803,13 +791,13 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
803791# ~~~~~~~~~~~~~~~~~~~
804792#
805793# Recall that we scripted our searcher module due to the presence of
806- # data-dependent control flow. In the case of scripting, we do the
807- # conversion work up front by adding the decorator and making sure the
808- # implementation complies with scripting rules . We initialize the scripted
809- # searcher the same way that we would initialize an un-scripted variant.
794+ # data-dependent control flow. In the case of scripting, we do necessary
795+ # language changes to make sure the implementation complies with
796+ # TorchScript . We initialize the scripted searcher the same way that we
797+ # would initialize an un-scripted variant.
810798#
811799
812- ### Convert encoder model
800+ ### Compile the whole greedy search model to TorchScript model
813801# Create artificial inputs
814802test_seq = torch .LongTensor (MAX_LENGTH , 1 ).random_ (0 , voc .num_words ).to (device )
815803test_seq_length = torch .LongTensor ([test_seq .size ()[0 ]]).to (device )
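+# Note: for tracing, only the shapes and dtypes of these inputs matter;
+# the random values simply drive one execution through the graph.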
@@ -824,19 +812,21 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
 # Trace the model
 traced_decoder = torch.jit.trace(decoder, (test_decoder_input, test_decoder_hidden, test_encoder_outputs))
 
-### Initialize searcher module
-scripted_searcher = GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers)
+### Initialize searcher module by wrapping ``torch.jit.script`` call
+scripted_searcher = torch.jit.script(GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers))
 
 
 ######################################################################
 # Print Graphs
 # ------------
 #
-# Now that our models are in Torch Script form, we can print the graphs of
+# Now that our models are in TorchScript form, we can print the graphs of
 # each to ensure that we captured the computational graph appropriately.
-# Since our ``scripted_searcher`` contains our ``traced_encoder`` and
-# ``traced_decoder``, these graphs will print inline.
-#
+# Since TorchScript allows us to recursively compile the whole model
+# hierarchy and inline the ``encoder`` and ``decoder`` graphs into a single
+# graph, we only need to print the ``scripted_searcher`` graph.
 
 print('scripted_searcher graph:\n', scripted_searcher.graph)
 
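+# For a Python-like view of the compiled TorchScript (often easier to read
+# than the raw graph IR), you can also print the ``code`` property:
+#print('scripted_searcher code:\n', scripted_searcher.code)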
@@ -845,19 +835,25 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
 # Run Evaluation
 # --------------
 #
-# Finally, we will run evaluation of the chatbot model using the Torch
-# Script models. If converted correctly, the models will behave exactly as
-# they would in their eager-mode representation.
+# Finally, we will run evaluation of the chatbot model using the TorchScript
+# models. If converted correctly, the models will behave exactly as they
+# would in their eager-mode representation.
 #
 # By default, we evaluate a few common query sentences. If you want to
 # chat with the bot yourself, uncomment the ``evaluateInput`` line and
 # give it a spin.
 #
 
+
+# Use appropriate device
+scripted_searcher.to(device)
+# Set dropout layers to eval mode
+scripted_searcher.eval()
+
 # Evaluate examples
 sentences = ["hello", "what's up?", "who are you?", "where am I?", "where are you from?"]
 for s in sentences:
-    evaluateExample(s, traced_encoder, traced_decoder, scripted_searcher, voc)
+    evaluateExample(s, scripted_searcher, voc)
 
 # Evaluate your input
-#evaluateInput(traced_encoder, traced_decoder, scripted_searcher, voc)
+# evaluateInput(scripted_searcher, voc)
@@ -867,7 +863,7 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
 # Save Model
 # ----------
 #
-# Now that we have successfully converted our model to Torch Script, we
+# Now that we have successfully converted our model to TorchScript, we
 # will serialize it for use in a non-Python deployment environment. To do
 # this, we can simply save our ``scripted_searcher`` module, as this is
 # the user-facing interface for running inference against the chatbot
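 # model.
+#
+# A minimal sketch of the save/load round trip (the file name is
+# illustrative):
+#
+# ::
+#
+#    scripted_searcher.save("scripted_chatbot.pth")
+#    # Later, in Python or C++, reload without the class definitions:
+#    reloaded = torch.jit.load("scripted_chatbot.pth")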