
Commit 02aef1d

Author: Jessica Lin (committed)
Minor syntax edits to mobile perf recipe
1 parent 6285b8f commit 02aef1d

File tree

62 files changed, +9879 -13 lines changed


recipes/deployment_with_flask.rst

Lines changed: 284 additions & 0 deletions
@@ -0,0 +1,284 @@
Deploying with Flask
====================

In this recipe, you will learn:

- How to wrap your trained PyTorch model in a Flask container to expose
  it via a web API
- How to translate incoming web requests into PyTorch tensors for your
  model
- How to package your model’s output for an HTTP response

Requirements
------------

You will need a Python 3 environment with the following packages (and
their dependencies) installed:

- PyTorch 1.5
- TorchVision 0.6.0
- Flask 1.1

Optionally, to get some of the supporting files, you'll need git.

The instructions for installing PyTorch and TorchVision are available at
`pytorch.org`_. Instructions for installing Flask are available on `the
Flask site`_.

What is Flask?
--------------

Flask is a lightweight web framework written in Python. It provides a
convenient way for you to quickly set up a web API for predictions from
your trained PyTorch model, either for direct use, or as a web service
within a larger system.

Setup and Supporting Files
--------------------------

We're going to create a web service that takes in images, and maps them
to one of the 1000 classes of the ImageNet dataset. To do this, you'll
need an image file for testing. Optionally, you can also get a file that
will map the class index output by the model to a human-readable class
name.

Option 1: To Get Both Files Quickly
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can pull both of the supporting files quickly by checking out the
TorchServe repository and copying them to your working folder. *(NB:
There is no dependency on TorchServe for this recipe - it's just a
quick way to get the files.)* Issue the following commands from your
shell prompt:

::

    git clone https://github.com/pytorch/serve
    cp serve/examples/image_classifier/kitten.jpg .
    cp serve/examples/image_classifier/index_to_name.json .

And you've got them!

Option 2: Bring Your Own Image
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The ``index_to_name.json`` file is optional in the Flask service below.
You can test your service with your own image - just make sure it's a
3-color JPEG.

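If your image isn't already a 3-channel JPEG, Pillow (which this recipe
already depends on) can convert it. A minimal sketch, assuming a
hypothetical input file named ``my_image.png``:

::

    from PIL import Image

    # Convert any Pillow-readable image to a 3-channel RGB JPEG.
    # 'my_image.png' and 'my_image.jpg' are placeholder filenames.
    Image.open('my_image.png').convert('RGB').save('my_image.jpg')
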
Building Your Flask Service
---------------------------

The full Python script for the Flask service is shown at the end of this
recipe; you can copy and paste that into your own ``app.py`` file. Below
we'll look at individual sections to make their functions clear.

Imports
~~~~~~~

::

    import torchvision.models as models
    import torchvision.transforms as transforms
    from PIL import Image
    from flask import Flask, jsonify, request

In order:

- We'll be using a pre-trained DenseNet model from
  ``torchvision.models``
- ``torchvision.transforms`` contains tools for manipulating your image
  data
- Pillow (``PIL``) is what we'll use to load the image file initially
- And of course we'll need classes from ``flask``

Pre-Processing
~~~~~~~~~~~~~~

::

    def transform_image(infile):
        input_transforms = [transforms.Resize(255),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                [0.229, 0.224, 0.225])]
        my_transforms = transforms.Compose(input_transforms)
        image = Image.open(infile)
        timg = my_transforms(image)
        timg.unsqueeze_(0)
        return timg

The web request gave us an image file, but our model expects a PyTorch
tensor of shape (N, 3, 224, 224) where *N* is the number of items in the
input batch. (We will just have a batch size of 1.) The first thing we
do is compose a set of TorchVision transforms that resize and crop the
image, convert it to a tensor, then normalize the values in the tensor.
(For more information on this normalization, see the documentation for
`torchvision.models`_.)

After that, we open the file and apply the transforms. The transforms
return a tensor of shape (3, 224, 224) - the 3 color channels of a
224x224 image. Because we need to make this single image a batch, we use
the ``unsqueeze_(0)`` call to modify the tensor in place by adding a new
first dimension. The tensor contains the same data, but now has shape
(1, 3, 224, 224).

In general, even if you're not working with image data, you will need to
transform the input from your HTTP request into a tensor that PyTorch
can consume.

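To sanity-check the pre-processing on its own, outside of Flask, you can
run something like the following in a Python session; this sketch
assumes you downloaded ``kitten.jpg`` as described above:

::

    # Quick local check of transform_image(); not part of the service itself.
    tensor = transform_image('kitten.jpg')
    print(tensor.shape)   # Expected: torch.Size([1, 3, 224, 224])
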
Inference
~~~~~~~~~

::

    def get_prediction(input_tensor):
        outputs = model.forward(input_tensor)
        _, y_hat = outputs.max(1)
        prediction = y_hat.item()
        return prediction

The inference itself is the simplest part: When we pass the input tensor
to the model, we get back a tensor of values that represent the model's
estimated likelihood that the image belongs to a particular class. The
``max()`` call finds the class with the maximum likelihood value, and
returns that value along with the ImageNet class index. Finally, we
extract that class index from the tensor containing it with the
``item()`` call, and return it.

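The values returned by the model are unnormalized scores. If you also
wanted a confidence value for the prediction, one option is to run the
scores through softmax first. The sketch below is only an illustration
and is not used by the service in this recipe;
``get_prediction_with_confidence`` is a hypothetical name:

::

    import torch.nn.functional as F

    # Hypothetical variant of get_prediction() that also reports how
    # confident the model is; not part of the recipe's Flask service.
    def get_prediction_with_confidence(input_tensor):
        outputs = model(input_tensor)          # Raw class scores, shape (1, 1000)
        probs = F.softmax(outputs, dim=1)      # Normalize scores into probabilities
        confidence, y_hat = probs.max(1)       # Highest probability and its class index
        return y_hat.item(), confidence.item()
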
Post-Processing
~~~~~~~~~~~~~~~

::

    def render_prediction(prediction_idx):
        stridx = str(prediction_idx)
        class_name = 'Unknown'
        if img_class_map is not None:
            if stridx in img_class_map:
                class_name = img_class_map[stridx][1]

        return prediction_idx, class_name

The ``render_prediction()`` function maps the predicted class index to a
human-readable class label. It's typical, after getting the prediction
from your model, to perform post-processing to make the prediction ready
for either human consumption, or for another piece of software.

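For reference, ``render_prediction()`` expects ``img_class_map`` (loaded
from ``index_to_name.json``) to map class indices, as strings, to lists
whose second element is the class name. A minimal sketch of that lookup,
with illustrative values:

::

    # Illustrative entry only - the real index_to_name.json covers all
    # 1000 ImageNet classes.
    img_class_map = {
        "285": ["n02124075", "Egyptian_cat"],
    }

    print(render_prediction(285))   # -> (285, 'Egyptian_cat')
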
Running The Full Flask App
--------------------------

Paste the following into a file called ``app.py``:

::

    import io
    import json
    import os

    import torchvision.models as models
    import torchvision.transforms as transforms
    from PIL import Image
    from flask import Flask, jsonify, request


    app = Flask(__name__)
    model = models.densenet121(pretrained=True)   # Trained on 1000 classes from ImageNet
    model.eval()                                  # Turns off training-time behaviors such as dropout


    img_class_map = None
    mapping_file_path = 'index_to_name.json'      # Human-readable names for ImageNet classes
    if os.path.isfile(mapping_file_path):
        with open(mapping_file_path) as f:
            img_class_map = json.load(f)


    # Transform input into the form our model expects
    def transform_image(infile):
        input_transforms = [transforms.Resize(255),        # We use multiple TorchVision transforms to ready the image
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],    # Standard normalization for ImageNet model input
                [0.229, 0.224, 0.225])]
        my_transforms = transforms.Compose(input_transforms)
        image = Image.open(infile)                          # Open the image file
        timg = my_transforms(image)                         # Transform PIL image to appropriately-shaped PyTorch tensor
        timg.unsqueeze_(0)                                  # PyTorch models expect batched input; create a batch of 1
        return timg


    # Get a prediction
    def get_prediction(input_tensor):
        outputs = model.forward(input_tensor)               # Get likelihoods for all ImageNet classes
        _, y_hat = outputs.max(1)                           # Extract the most likely class
        prediction = y_hat.item()                           # Extract the int value from the PyTorch tensor
        return prediction


    # Make the prediction human-readable
    def render_prediction(prediction_idx):
        stridx = str(prediction_idx)
        class_name = 'Unknown'
        if img_class_map is not None:
            if stridx in img_class_map:
                class_name = img_class_map[stridx][1]

        return prediction_idx, class_name


    @app.route('/', methods=['GET'])
    def root():
        return jsonify({'msg' : 'Try POSTing to the /predict endpoint with an RGB image attachment'})


    @app.route('/predict', methods=['POST'])
    def predict():
        if request.method == 'POST':
            file = request.files.get('file')                # The uploaded image, if one was attached
            if file is not None:
                input_tensor = transform_image(file)
                prediction_idx = get_prediction(input_tensor)
                class_id, class_name = render_prediction(prediction_idx)
                return jsonify({'class_id': class_id, 'class_name': class_name})


    if __name__ == '__main__':
        app.run()

To start the server from your shell prompt, issue the following command:

::

    FLASK_APP=app.py flask run

By default, your Flask server is listening on port 5000. Once the server
is running, open another terminal window, and test your new inference
server:

::

    curl -X POST -H "Content-Type: multipart/form-data" http://localhost:5000/predict -F "[email protected]"

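If you'd rather exercise the endpoint from Python than from the shell, a
minimal sketch using the third-party ``requests`` package (which is not
otherwise required by this recipe) could look like this:

::

    import requests   # Assumes the 'requests' package is installed separately

    # POST kitten.jpg to the running service and print the JSON response.
    with open('kitten.jpg', 'rb') as f:
        resp = requests.post('http://localhost:5000/predict', files={'file': f})
    print(resp.json())
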
If everything is set up correctly, you should receive a response similar
to the following:

::

    {"class_id":285,"class_name":"Egyptian_cat"}

Important Resources
-------------------

- `pytorch.org`_ for installation instructions, and more documentation
  and tutorials
- The `Flask site`_ has a `Quick Start guide`_ that goes into more
  detail on setting up a simple Flask service

.. _pytorch.org: https://pytorch.org
.. _Flask site: https://flask.palletsprojects.com/en/1.1.x/
.. _Quick Start guide: https://flask.palletsprojects.com/en/1.1.x/quickstart/
.. _torchvision.models: https://pytorch.org/docs/stable/torchvision/models.html
.. _the Flask site: https://flask.palletsprojects.com/en/1.1.x/installation/
