Binary file modified _static/img/chatbot/diff.png
100755 → 100644
Binary file modified _static/img/chatbot/pytorch_workflow.png
100755 → 100644
Binary file removed _static/img/hybrid.png
Binary file not shown.
158 changes: 77 additions & 81 deletions beginner_source/deploy_seq2seq_hybrid_frontend_tutorial.py
@@ -1,24 +1,24 @@
# -*- coding: utf-8 -*-
"""
Deploying a Seq2Seq Model with the Hybrid Frontend
Deploying a Seq2Seq Model with TorchScript
==================================================
**Author:** `Matthew Inkawhich <https://github.com/MatthewInkawhich>`_
"""


######################################################################
# This tutorial will walk through the process of transitioning a
# sequence-to-sequence model to Torch Script using PyTorch’s Hybrid
# Frontend. The model that we will convert is the chatbot model from the
# `Chatbot tutorial <https://pytorch.org/tutorials/beginner/chatbot_tutorial.html>`__.
# sequence-to-sequence model to TorchScript using the TorchScript
# API. The model that we will convert is the chatbot model from the
# `Chatbot tutorial <https://pytorch.org/tutorials/beginner/chatbot_tutorial.html>`__.
# You can either treat this tutorial as a “Part 2” to the Chatbot tutorial
# and deploy your own pretrained model, or you can start with this
# document and use a pretrained model that we host. In the latter case,
# you can reference the original Chatbot tutorial for details
# regarding data preprocessing, model theory and definition, and model
# training.
#
# What is the Hybrid Frontend?
# What is TorchScript?
# ----------------------------
#
# During the research and development phase of a deep learning-based
@@ -34,13 +34,13 @@
# to target highly optimized hardware architectures. Also, a graph-based
# representation enables framework-agnostic model exportation. PyTorch
# provides mechanisms for incrementally converting eager-mode code into
# Torch Script, a statically analyzable and optimizable subset of Python
# TorchScript, a statically analyzable and optimizable subset of Python
# that Torch uses to represent deep learning programs independently from
# the Python runtime.
#
# The API for converting eager-mode PyTorch programs into Torch Script is
# The API for converting eager-mode PyTorch programs into TorchScript is
# found in the torch.jit module. This module has two core modalities for
# converting an eager-mode model to a Torch Script graph representation:
# converting an eager-mode model to a TorchScript graph representation:
# **tracing** and **scripting**. The ``torch.jit.trace`` function takes a
# module or function and a set of example inputs. It then runs the example
# input through the function or module while tracing the computational
@@ -52,19 +52,19 @@
# operations called along the execution route taken by the example input
# will be recorded. In other words, the control flow itself is not
# captured. To convert modules and functions containing data-dependent
# control flow, a **scripting** mechanism is provided. Scripting
# explicitly converts the module or function code to Torch Script,
# including all possible control flow routes. To use script mode, be sure
# to inherit from the ``torch.jit.ScriptModule`` base class (instead
# of ``torch.nn.Module``) and add a ``torch.jit.script`` decorator to your
# Python function or a ``torch.jit.script_method`` decorator to your
# module’s methods. The one caveat with using scripting is that it only
# supports a restricted subset of Python. For all details relating to the
# supported features, see the Torch Script `language
# reference <https://pytorch.org/docs/master/jit.html>`__. To provide the
# maximum flexibility, the modes of Torch Script can be composed to
# represent your whole program, and these techniques can be applied
# incrementally.
# control flow, a **scripting** mechanism is provided. The
# ``torch.jit.script`` function/decorator takes a module or function and
# does not require example inputs. Scripting then explicitly converts
# the module or function code to TorchScript, including all control flow
# paths. One caveat with using scripting is that it only supports a
# subset of Python, so you might need to rewrite the code to make it
# compatible with the TorchScript syntax.
#
# For all details relating to the supported features, see the `TorchScript
# language reference <https://pytorch.org/docs/master/jit.html>`__.
# To provide the maximum flexibility, you can also mix tracing and scripting
# modes together to represent your whole program, and these techniques can
# be applied incrementally.
#
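# As a quick illustration (a minimal sketch, not part of the chatbot
# model; ``ToyCell`` is a hypothetical module), tracing records only the
# branch taken by the example input, while scripting preserves the
# data-dependent control flow itself:
#
# ::
#
#    import torch
#    import torch.nn as nn
#
#    class ToyCell(nn.Module):
#        def forward(self, x):
#            # data-dependent branch: lost by tracing, kept by scripting
#            if x.sum() > 0:
#                return x * 2
#            return -x
#
#    toy = ToyCell()
#    traced = torch.jit.trace(toy, torch.ones(2, 2))  # emits a tracer warning
#    scripted = torch.jit.script(toy)                 # captures both branches
#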
# .. figure:: /_static/img/chatbot/pytorch_workflow.png
# :align: center
@@ -273,7 +273,7 @@ def indexesFromSentence(voc, sentence):
# used by the ``torch.nn.utils.rnn.pack_padded_sequence`` function when
# padding.
#
# Hybrid Frontend Notes:
# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Since the encoder’s ``forward`` function does not contain any
@@ -296,6 +296,7 @@ def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

def forward(self, input_seq, input_lengths, hidden=None):
# type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
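# NOTE: the Mypy-style type comment above is read by the TorchScript
# compiler; unannotated arguments are assumed to be Tensor, so ``hidden``
# must be declared Optional[Tensor] for its ``None`` default to type-check.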
# Convert word indexes to embeddings
embedded = self.embedding(input_seq)
# Pack padded batch of sequences for RNN module
@@ -325,18 +326,18 @@ def forward(self, input_seq, input_lengths, hidden=None):
#

# Luong attention layer
class Attn(torch.nn.Module):
class Attn(nn.Module):
def __init__(self, method, hidden_size):
super(Attn, self).__init__()
self.method = method
if self.method not in ['dot', 'general', 'concat']:
raise ValueError(self.method, "is not an appropriate attention method.")
self.hidden_size = hidden_size
if self.method == 'general':
self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
self.attn = nn.Linear(self.hidden_size, hidden_size)
elif self.method == 'concat':
self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))
self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
self.v = nn.Parameter(torch.FloatTensor(hidden_size))

def dot_score(self, hidden, encoder_output):
return torch.sum(hidden * encoder_output, dim=2)
@@ -383,14 +384,14 @@ def forward(self, hidden, encoder_outputs):
# weighted sum indicating what parts of the encoder’s output to pay
# attention to. From here, we use a linear layer and softmax normalization
# to select the next word in the output sequence.
#
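# As a rough shape check (an illustrative sketch, not part of the
# original tutorial; the sizes are arbitrary), the attention module maps
# a decoder hidden state and the encoder outputs to normalized weights:
#
# ::
#
#    attn = Attn('dot', 500)
#    hidden = torch.rand(1, 64, 500)             # (1, batch, hidden_size)
#    encoder_outputs = torch.rand(10, 64, 500)   # (max_length, batch, hidden_size)
#    attn_weights = attn(hidden, encoder_outputs)
#    print(attn_weights.shape)                   # torch.Size([64, 1, 10])
#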
# Hybrid Frontend Notes:

# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Similarly to the ``EncoderRNN``, this module does not contain any
# data-dependent control flow. Therefore, we can once again use
# **tracing** to convert this model to Torch Script after it is
# initialized and its parameters are loaded.
# **tracing** to convert this model to TorchScript after it
# is initialized and its parameters are loaded.
#

class LuongAttnDecoderRNN(nn.Module):
@@ -465,18 +466,18 @@ def forward(self, input_step, last_hidden, encoder_outputs):
# terminates either if the ``decoded_words`` list has reached a length of
# *MAX_LENGTH* or if the predicted word is the *EOS_token*.
#
# Hybrid Frontend Notes:
# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# The ``forward`` method of this module involves iterating over the range
# of :math:`[0, max\_length)` when decoding an output sequence one word at
# a time. Because of this, we should use **scripting** to convert this
# module to Torch Script. Unlike with our encoder and decoder models,
# module to TorchScript. Unlike with our encoder and decoder models,
# which we can trace, we must make some necessary changes to the
# ``GreedySearchDecoder`` module in order to initialize an object without
# error. In other words, we must ensure that our module adheres to the
# rules of the scripting mechanism, and does not utilize any language
# features outside of the subset of Python that Torch Script includes.
# rules of the TorchScript mechanism, and does not utilize any language
# features outside of the subset of Python that TorchScript includes.
#
# To get an idea of some manipulations that may be required, we will go
# over the diffs between the ``GreedySearchDecoder`` implementation from
@@ -491,12 +492,6 @@ def forward(self, input_step, last_hidden, encoder_outputs):
# Changes:
# ^^^^^^^^
#
# - ``nn.Module`` -> ``torch.jit.ScriptModule``
#
# - In order to use PyTorch’s scripting mechanism on a module, that
# module must inherit from the ``torch.jit.ScriptModule``.
#
#
# - Added ``decoder_n_layers`` to the constructor arguments
#
# - This change stems from the fact that the encoder and decoder
@@ -523,16 +518,9 @@ def forward(self, input_step, last_hidden, encoder_outputs):
# ``self._SOS_token``.
#
#
# - Add the ``torch.jit.script_method`` decorator to the ``forward``
# method
#
# - Adding this decorator lets the JIT compiler know that the function
# that it is decorating should be scripted.
#
#
# - Enforce types of ``forward`` method arguments
#
# - By default, all parameters to a Torch Script function are assumed
# - By default, all parameters to a TorchScript function are assumed
# to be Tensor. If we need to pass an argument of a different type,
# we can use function type annotations as introduced in `PEP
# 3107 <https://www.python.org/dev/peps/pep-3107/>`__. In addition,
@@ -553,7 +541,7 @@ def forward(self, input_step, last_hidden, encoder_outputs):
# ``self._SOS_token``.
#

class GreedySearchDecoder(torch.jit.ScriptModule):
class GreedySearchDecoder(nn.Module):
def __init__(self, encoder, decoder, decoder_n_layers):
super(GreedySearchDecoder, self).__init__()
self.encoder = encoder
@@ -564,7 +552,6 @@ def __init__(self, encoder, decoder, decoder_n_layers):

__constants__ = ['_device', '_SOS_token', '_decoder_n_layers']
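# Marking these attributes as constants lets the TorchScript compiler
# treat their values as fixed, rather than as mutable module state.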

@torch.jit.script_method
def forward(self, input_seq : torch.Tensor, input_length : torch.Tensor, max_length : int):
# Forward input through encoder model
encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
@@ -613,7 +600,7 @@ def forward(self, input_seq : torch.Tensor, input_length : torch.Tensor, max_len
# an argument, normalizes it, evaluates it, and prints the response.
#

def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
def evaluate(searcher, voc, sentence, max_length=MAX_LENGTH):
### Format input sentence as a batch
# words -> indexes
indexes_batch = [indexesFromSentence(voc, sentence)]
@@ -632,7 +619,7 @@ def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):


# Evaluate inputs from user input (stdin)
def evaluateInput(encoder, decoder, searcher, voc):
def evaluateInput(searcher, voc):
input_sentence = ''
while(1):
try:
Expand All @@ -643,7 +630,7 @@ def evaluateInput(encoder, decoder, searcher, voc):
# Normalize sentence
input_sentence = normalizeString(input_sentence)
# Evaluate sentence
output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
output_words = evaluate(searcher, voc, input_sentence)
# Format and print response sentence
output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
print('Bot:', ' '.join(output_words))
@@ -652,12 +639,12 @@ def evaluateInput(encoder, decoder, searcher, voc):
print("Error: Encountered unknown word.")

# Normalize input sentence and call evaluate()
def evaluateExample(sentence, encoder, decoder, searcher, voc):
def evaluateExample(sentence, searcher, voc):
print("> " + sentence)
# Normalize sentence
input_sentence = normalizeString(sentence)
# Evaluate sentence
output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
output_words = evaluate(searcher, voc, input_sentence)
output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
print('Bot:', ' '.join(output_words))

@@ -700,14 +687,17 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
# ``checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))``
# line.
#
# Hybrid Frontend Notes:
# TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Notice that we initialize and load parameters into our encoder and
# decoder models as usual. Also, we must call ``.to(device)`` to set the
# device options of the models and ``.eval()`` to set the dropout layers
# to test mode **before** we trace the models. ``TracedModule`` objects do
# not inherit the ``to`` or ``eval`` methods.
# decoder models as usual. If you are using tracing mode (``torch.jit.trace``)
# for some part of your models, you must call ``.to(device)`` to set the device
# options of the models and ``.eval()`` to set the dropout layers to test mode
# **before** tracing the models. ``TracedModule`` objects do not inherit the
# ``to`` or ``eval`` methods. Since in this tutorial we are only using
# scripting instead of tracing, we only need to do this before we run
# evaluation (which is the same as we would normally do in eager mode).
#
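# For example (an illustrative sketch using a toy stand-in model, not
# part of the chatbot pipeline), the required ordering when tracing is:
#
# ::
#
#    toy_model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1))
#    toy_model.to(device)   # set device options BEFORE tracing
#    toy_model.eval()       # put dropout layers in eval mode BEFORE tracing
#    traced_toy = torch.jit.trace(toy_model, torch.rand(2, 4, device=device))
#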

save_dir = os.path.join("data", "save")
@@ -766,16 +756,14 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):


######################################################################
# Convert Model to Torch Script
# Convert Model to TorchScript
# -----------------------------
#
# Encoder
# ~~~~~~~
#
# As previously mentioned, to convert the encoder model to Torch Script,
# we use **tracing**. Tracing any module requires running an example input
# through the model’s ``forward`` method and trace the computational graph
# that the data encounters. The encoder model takes an input sequence and
# As previously mentioned, to convert the encoder model to TorchScript,
# we use **scripting**. The encoder model takes an input sequence and
# a corresponding lengths tensor. Therefore, we create an example input
# sequence tensor ``test_seq``, which is of appropriate size (MAX_LENGTH,
# 1), contains numbers in the appropriate range
@@ -803,13 +791,13 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
# ~~~~~~~~~~~~~~~~~~~
#
# Recall that we scripted our searcher module due to the presence of
# data-dependent control flow. In the case of scripting, we do the
# conversion work up front by adding the decorator and making sure the
# implementation complies with scripting rules. We initialize the scripted
# searcher the same way that we would initialize an un-scripted variant.
# data-dependent control flow. In the case of scripting, we make the
# necessary language changes up front to ensure that the implementation
# complies with TorchScript. We initialize the scripted searcher the same
# way that we would initialize an un-scripted variant.
#

### Convert encoder model
### Compile the whole greedy search model to TorchScript
# Create artificial inputs
test_seq = torch.LongTensor(MAX_LENGTH, 1).random_(0, voc.num_words).to(device)
test_seq_length = torch.LongTensor([test_seq.size()[0]]).to(device)
@@ -824,19 +812,21 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
# Trace the model
traced_decoder = torch.jit.trace(decoder, (test_decoder_input, test_decoder_hidden, test_encoder_outputs))

### Initialize searcher module
scripted_searcher = GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers)
### Initialize searcher module by wrapping it in a ``torch.jit.script`` call
scripted_searcher = torch.jit.script(GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers))




######################################################################
# Print Graphs
# ------------
#
# Now that our models are in Torch Script form, we can print the graphs of
# Now that our models are in TorchScript form, we can print the graphs of
# each to ensure that we captured the computational graph appropriately.
# Since our ``scripted_searcher`` contains our ``traced_encoder`` and
# ``traced_decoder``, these graphs will print inline.
#
# Since TorchScript allows us to recursively compile the whole model
# hierarchy and inline the ``encoder`` and ``decoder`` graphs into a
# single graph, we just need to print the ``scripted_searcher`` graph.

print('scripted_searcher graph:\n', scripted_searcher.graph)
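# If the raw graph IR is hard to read, scripted modules also expose a
# ``code`` property that prints a Python-like rendering of the same
# program (an optional extra, not part of the original tutorial):

#print(scripted_searcher.code)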

Expand All @@ -845,19 +835,25 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
# Run Evaluation
# --------------
#
# Finally, we will run evaluation of the chatbot model using the Torch
# Script models. If converted correctly, the models will behave exactly as
# they would in their eager-mode representation.
# Finally, we will run evaluation of the chatbot model using the TorchScript
# models. If converted correctly, the models will behave exactly as they
# would in their eager-mode representation.
#
# By default, we evaluate a few common query sentences. If you want to
# chat with the bot yourself, uncomment the ``evaluateInput`` line and
# give it a spin.
#


# Use appropriate device
scripted_searcher.to(device)
# Set dropout layers to eval mode
scripted_searcher.eval()

# Evaluate examples
sentences = ["hello", "what's up?", "who are you?", "where am I?", "where are you from?"]
for s in sentences:
evaluateExample(s, traced_encoder, traced_decoder, scripted_searcher, voc)
evaluateExample(s, scripted_searcher, voc)

# Evaluate your input
#evaluateInput(scripted_searcher, voc)
@@ -867,7 +863,7 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
# Save Model
# ----------
#
# Now that we have successfully converted our model to Torch Script, we
# Now that we have successfully converted our model to TorchScript, we
# will serialize it for use in a non-Python deployment environment. To do
# this, we can simply save our ``scripted_searcher`` module, as this is
# the user-facing interface for running inference against the chatbot