Skip to content

Commit b6a683f

Browse files
authored
Merge pull request #585 from pytorch/flask-fix
[WIP] Flask fix
2 parents 18cada0 + 2963370 commit b6a683f

File tree

4 files changed

+67
-63
lines changed

4 files changed

+67
-63
lines changed

_static/imagenet_class_index.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

_static/img/sample_file.jpeg

43.3 KB
Loading

intermediate_source/flask_rest_api_tutorial.py

Lines changed: 65 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def transform_image(image_bytes):
134134
# and returns a tensor. To test the above method, read an image file in
135135
# bytes mode and see if you get a tensor back:
136136

137-
with open('sample_file.jpeg', 'rb') as f:
137+
with open("../_static/img/sample_file.jpeg", 'rb') as f:
138138
image_bytes = f.read()
139139
tensor = transform_image(image_bytes=image_bytes)
140140
print(tensor)
@@ -176,7 +176,7 @@ def get_prediction(image_bytes):
176176

177177
import json
178178

179-
imagenet_class_index = json.load(open('imagenet_class_index.json'))
179+
imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))
180180

181181
def get_prediction(image_bytes):
182182
tensor = transform_image(image_bytes=image_bytes)
@@ -193,7 +193,7 @@ def get_prediction(image_bytes):
193193
# We will test our above method:
194194

195195

196-
with open('sample_file.jpeg', 'rb') as f:
196+
with open("../_static/img/sample_file.jpeg", 'rb') as f:
197197
image_bytes = f.read()
198198
print(get_prediction(image_bytes=image_bytes))
199199

@@ -207,7 +207,7 @@ def get_prediction(image_bytes):
207207
# readable name.
208208
#
209209
# .. Note ::
210-
# Did you notice that why ``model`` variable is not part of ``get_prediction``
210+
# Did you notice that ``model`` variable is not part of ``get_prediction``
211211
# method? Or why is model a global variable? Loading a model can be an
212212
# expensive operation in terms of memory and compute. If we loaded the model in the
213213
# ``get_prediction`` method, then it would get unnecessarily loaded every
@@ -226,69 +226,71 @@ def get_prediction(image_bytes):
226226
# In this final part we will add our model to our Flask API server. Since
227227
# our API server is supposed to take an image file, we will update our ``predict``
228228
# method to read files from the requests:
229-
230-
from flask import request
231-
232-
233-
@app.route('/predict', methods=['POST'])
234-
def predict():
235-
if request.method == 'POST':
236-
# we will get the file from the request
237-
file = request.files['file']
238-
# convert that to bytes
239-
img_bytes = file.read()
240-
class_id, class_name = get_prediction(image_bytes=img_bytes)
241-
return jsonify({'class_id': class_id, 'class_name': class_name})
229+
#
230+
# .. code-block:: python
231+
#
232+
# from flask import request
233+
#
234+
# @app.route('/predict', methods=['POST'])
235+
# def predict():
236+
# if request.method == 'POST':
237+
# # we will get the file from the request
238+
# file = request.files['file']
239+
# # convert that to bytes
240+
# img_bytes = file.read()
241+
# class_id, class_name = get_prediction(image_bytes=img_bytes)
242+
# return jsonify({'class_id': class_id, 'class_name': class_name})
242243

243244
######################################################################
244245
# The ``app.py`` file is now complete. Following is the full version:
245246
#
246-
247-
import io
248-
import json
249-
250-
from torchvision import models
251-
import torchvision.transforms as transforms
252-
from PIL import Image
253-
from flask import Flask, jsonify, request
254-
255-
256-
app = Flask(__name__)
257-
imagenet_class_index = json.load(open('imagenet_class_index.json'))
258-
model = models.densenet121(pretrained=True)
259-
model.eval()
260-
261-
262-
def transform_image(image_bytes):
263-
my_transforms = transforms.Compose([transforms.Resize(255),
264-
transforms.CenterCrop(224),
265-
transforms.ToTensor(),
266-
transforms.Normalize(
267-
[0.485, 0.456, 0.406],
268-
[0.229, 0.224, 0.225])])
269-
image = Image.open(io.BytesIO(image_bytes))
270-
return my_transforms(image).unsqueeze(0)
271-
272-
273-
def get_prediction(image_bytes):
274-
tensor = transform_image(image_bytes=image_bytes)
275-
outputs = model.forward(tensor)
276-
_, y_hat = outputs.max(1)
277-
predicted_idx = str(y_hat.item())
278-
return imagenet_class_index[predicted_idx]
279-
280-
281-
@app.route('/predict', methods=['POST'])
282-
def predict():
283-
if request.method == 'POST':
284-
file = request.files['file']
285-
img_bytes = file.read()
286-
class_id, class_name = get_prediction(image_bytes=img_bytes)
287-
return jsonify({'class_id': class_id, 'class_name': class_name})
288-
289-
290-
if __name__ == '__main__':
291-
app.run()
247+
# .. code-block:: python
248+
#
249+
# import io
250+
# import json
251+
#
252+
# from torchvision import models
253+
# import torchvision.transforms as transforms
254+
# from PIL import Image
255+
# from flask import Flask, jsonify, request
256+
#
257+
#
258+
# app = Flask(__name__)
259+
# imagenet_class_index = json.load(open('imagenet_class_index.json'))
260+
# model = models.densenet121(pretrained=True)
261+
# model.eval()
262+
#
263+
#
264+
# def transform_image(image_bytes):
265+
# my_transforms = transforms.Compose([transforms.Resize(255),
266+
# transforms.CenterCrop(224),
267+
# transforms.ToTensor(),
268+
# transforms.Normalize(
269+
# [0.485, 0.456, 0.406],
270+
# [0.229, 0.224, 0.225])])
271+
# image = Image.open(io.BytesIO(image_bytes))
272+
# return my_transforms(image).unsqueeze(0)
273+
#
274+
#
275+
# def get_prediction(image_bytes):
276+
# tensor = transform_image(image_bytes=image_bytes)
277+
# outputs = model.forward(tensor)
278+
# _, y_hat = outputs.max(1)
279+
# predicted_idx = str(y_hat.item())
280+
# return imagenet_class_index[predicted_idx]
281+
#
282+
#
283+
# @app.route('/predict', methods=['POST'])
284+
# def predict():
285+
# if request.method == 'POST':
286+
# file = request.files['file']
287+
# img_bytes = file.read()
288+
# class_id, class_name = get_prediction(image_bytes=img_bytes)
289+
# return jsonify({'class_id': class_id, 'class_name': class_name})
290+
#
291+
#
292+
# if __name__ == '__main__':
293+
# app.run()
292294

293295
######################################################################
294296
# Let's test our web server! Run:

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ torchvision
1010
PyHamcrest
1111
bs4
1212
awscli==1.16.35
13+
flask
1314

1415
# PyTorch Theme
1516
-e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme

0 commit comments

Comments
 (0)