Skip to content

Commit e8cf0c5

Browse files
authored
Merge pull request #588 from SethHWeidman/flask_tutorial
Flask tutorial
2 parents 039209e + 1fec09f commit e8cf0c5

File tree

1 file changed

+36
-29
lines changed

1 file changed

+36
-29
lines changed

intermediate_source/flask_rest_api_tutorial.py

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,10 @@ def transform_image(image_bytes):
130130

131131

132132
######################################################################
133-
# Above method takes image data in bytes, applies the series of transforms
133+
# The above method takes image data in bytes, applies the series of transforms
134134
# and returns a tensor. To test the above method, read an image file in
135-
# bytes mode and see if you get a tensor back:
135+
# bytes mode (first replacing `../_static/img/sample_file.jpeg` with the actual
136+
# path to the file on your computer) and see if you get a tensor back:
136137

137138
with open("../_static/img/sample_file.jpeg", 'rb') as f:
138139
image_bytes = f.read()
@@ -168,11 +169,12 @@ def get_prediction(image_bytes):
168169
# The tensor ``y_hat`` will contain the index of the predicted class id.
169170
# However, we need a human readable class name. For that we need a class id
170171
# to name mapping. Download
171-
# `this file <https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json>`_
172-
# and place it in current directory as file ``imagenet_class_index.json``.
173-
# This file contains the mapping of ImageNet class id to ImageNet class
174-
# name. We will load this JSON file and get the class name of the
175-
# predicted index.
172+
# `this file <https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json>`_
173+
# as ``imagenet_class_index.json`` and remember where you saved it (or, if you
174+
# are following the exact steps in this tutorial, save it in
175+
# `tutorials/_static`). This file contains the mapping of ImageNet class id to
176+
# ImageNet class name. We will load this JSON file and get the class name of
177+
# the predicted index.
176178

177179
import json
178180

@@ -228,7 +230,7 @@ def get_prediction(image_bytes):
228230
# method to read files from the requests:
229231
#
230232
# .. code-block:: python
231-
#
233+
#
232234
# from flask import request
233235
#
234236
# @app.route('/predict', methods=['POST'])
@@ -242,25 +244,26 @@ def get_prediction(image_bytes):
242244
# return jsonify({'class_id': class_id, 'class_name': class_name})
243245

244246
######################################################################
245-
# The ``app.py`` file is now complete. Following is the full version:
247+
# The ``app.py`` file is now complete. Following is the full version; replace
248+
# the paths with the paths where you saved your files and it should run:
246249
#
247250
# .. code-block:: python
248-
#
251+
#
249252
# import io
250253
# import json
251-
#
254+
#
252255
# from torchvision import models
253256
# import torchvision.transforms as transforms
254257
# from PIL import Image
255258
# from flask import Flask, jsonify, request
256-
#
257-
#
259+
#
260+
#
258261
# app = Flask(__name__)
259-
# imagenet_class_index = json.load(open('imagenet_class_index.json'))
262+
# imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
260263
# model = models.densenet121(pretrained=True)
261264
# model.eval()
262-
#
263-
#
265+
#
266+
#
264267
# def transform_image(image_bytes):
265268
# my_transforms = transforms.Compose([transforms.Resize(255),
266269
# transforms.CenterCrop(224),
@@ -270,25 +273,25 @@ def get_prediction(image_bytes):
270273
# [0.229, 0.224, 0.225])])
271274
# image = Image.open(io.BytesIO(image_bytes))
272275
# return my_transforms(image).unsqueeze(0)
273-
#
274-
#
276+
#
277+
#
275278
# def get_prediction(image_bytes):
276279
# tensor = transform_image(image_bytes=image_bytes)
277280
# outputs = model.forward(tensor)
278281
# _, y_hat = outputs.max(1)
279282
# predicted_idx = str(y_hat.item())
280283
# return imagenet_class_index[predicted_idx]
281-
#
282-
#
284+
#
285+
#
283286
# @app.route('/predict', methods=['POST'])
284287
# def predict():
285288
# if request.method == 'POST':
286289
# file = request.files['file']
287290
# img_bytes = file.read()
288291
# class_id, class_name = get_prediction(image_bytes=img_bytes)
289292
# return jsonify({'class_id': class_id, 'class_name': class_name})
290-
#
291-
#
293+
#
294+
#
292295
# if __name__ == '__main__':
293296
# app.run()
294297

@@ -300,20 +303,24 @@ def get_prediction(image_bytes):
300303
# $ FLASK_ENV=development FLASK_APP=app.py flask run
301304

302305
#######################################################################
303-
# We can use a command line tool like curl or `Postman <https://www.getpostman.com/>`_ to send requests to
304-
# this webserver:
306+
# We can use the
307+
# `requests <https://pypi.org/project/requests/>`_
308+
# library to send a POST request to our app:
305309
#
306-
# ::
310+
# .. code-block:: python
307311
#
308-
# $ curl -X POST -F file=@cat_pic.jpeg http://localhost:5000/predict
312+
# import requests
309313
#
310-
# You will get a response in the form:
314+
# resp = requests.post("http://localhost:5000/predict",
315+
# files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})
316+
317+
#######################################################################
318+
# Printing `resp.json()` will now show the following:
311319
#
312320
# ::
313321
#
314322
# {"class_id": "n02124075", "class_name": "Egyptian_cat"}
315323
#
316-
#
317324

318325
######################################################################
319326
# Next steps
@@ -342,4 +349,4 @@ def get_prediction(image_bytes):
342349
#
343350
# - You can also add a UI by creating a page with a form which takes the image and
344351
# displays the prediction. Check out the `demo <https://pytorch-imagenet.herokuapp.com/>`_
345-
# of a similar project and its `source code <https://github.com/avinassh/pytorch-flask-api-heroku>`_.
352+
# of a similar project and its `source code <https://github.com/avinassh/pytorch-flask-api-heroku>`_.

0 commit comments

Comments
 (0)