Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _static/imagenet_class_index.json

Large diffs are not rendered by default.

Binary file added _static/img/sample_file.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
128 changes: 65 additions & 63 deletions intermediate_source/flask_rest_api_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def transform_image(image_bytes):
# and returns a tensor. To test the above method, read an image file in
# bytes mode and see if you get a tensor back:

with open('sample_file.jpeg', 'rb') as f:
with open("../_static/img/sample_file.jpeg", 'rb') as f:
image_bytes = f.read()
tensor = transform_image(image_bytes=image_bytes)
print(tensor)
Expand Down Expand Up @@ -176,7 +176,7 @@ def get_prediction(image_bytes):

import json

imagenet_class_index = json.load(open('imagenet_class_index.json'))
imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
Expand All @@ -193,7 +193,7 @@ def get_prediction(image_bytes):
# We will test our above method:


with open('sample_file.jpeg', 'rb') as f:
with open("../_static/img/sample_file.jpeg", 'rb') as f:
image_bytes = f.read()
print(get_prediction(image_bytes=image_bytes))

Expand All @@ -207,7 +207,7 @@ def get_prediction(image_bytes):
# readable name.
#
# .. Note ::
# Did you notice that why ``model`` variable is not part of ``get_prediction``
# Did you notice that ``model`` variable is not part of ``get_prediction``
# method? Or why is model a global variable? Loading a model can be an
# expensive operation in terms of memory and compute. If we loaded the model in the
# ``get_prediction`` method, then it would get unnecessarily loaded every
Expand All @@ -226,69 +226,71 @@ def get_prediction(image_bytes):
# In this final part we will add our model to our Flask API server. Since
# our API server is supposed to take an image file, we will update our ``predict``
# method to read files from the requests:

from flask import request


@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
# we will get the file from the request
file = request.files['file']
# convert that to bytes
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})
#
# .. code-block:: python
#
# from flask import request
#
# @app.route('/predict', methods=['POST'])
# def predict():
# if request.method == 'POST':
# # we will get the file from the request
# file = request.files['file']
# # convert that to bytes
# img_bytes = file.read()
# class_id, class_name = get_prediction(image_bytes=img_bytes)
# return jsonify({'class_id': class_id, 'class_name': class_name})

######################################################################
# The ``app.py`` file is now complete. Following is the full version:
#

import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)
imagenet_class_index = json.load(open('imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()


def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
app.run()
# .. code-block:: python
#
# import io
# import json
#
# from torchvision import models
# import torchvision.transforms as transforms
# from PIL import Image
# from flask import Flask, jsonify, request
#
#
# app = Flask(__name__)
# imagenet_class_index = json.load(open('imagenet_class_index.json'))
# model = models.densenet121(pretrained=True)
# model.eval()
#
#
# def transform_image(image_bytes):
# my_transforms = transforms.Compose([transforms.Resize(255),
# transforms.CenterCrop(224),
# transforms.ToTensor(),
# transforms.Normalize(
# [0.485, 0.456, 0.406],
# [0.229, 0.224, 0.225])])
# image = Image.open(io.BytesIO(image_bytes))
# return my_transforms(image).unsqueeze(0)
#
#
# def get_prediction(image_bytes):
# tensor = transform_image(image_bytes=image_bytes)
# outputs = model.forward(tensor)
# _, y_hat = outputs.max(1)
# predicted_idx = str(y_hat.item())
# return imagenet_class_index[predicted_idx]
#
#
# @app.route('/predict', methods=['POST'])
# def predict():
# if request.method == 'POST':
# file = request.files['file']
# img_bytes = file.read()
# class_id, class_name = get_prediction(image_bytes=img_bytes)
# return jsonify({'class_id': class_id, 'class_name': class_name})
#
#
# if __name__ == '__main__':
# app.run()

######################################################################
# Let's test our web server! Run:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ torchvision
PyHamcrest
bs4
awscli==1.16.35
flask

# PyTorch Theme
-e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
Expand Down