Skip to content

Commit c48975a

Browse files
committed
add @jwt_optional view decorator and corresponding unit tests
1 parent 9bfa900 commit c48975a

File tree

3 files changed

+215
-3
lines changed

3 files changed

+215
-3
lines changed

flask_jwt_extended/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .jwt_manager import JWTManager
22
from .view_decorators import (
3-
jwt_required, fresh_jwt_required, jwt_refresh_token_required
3+
jwt_required, fresh_jwt_required, jwt_refresh_token_required,
4+
jwt_optional
45
)
56
from .utils import (
67
create_refresh_token, create_access_token, get_jwt_identity,

flask_jwt_extended/view_decorators.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from functools import wraps
22

33
from flask import request
4+
from jwt.exceptions import InvalidTokenError
45
from werkzeug.security import safe_str_cmp
56
try:
67
from flask import _app_ctx_stack as ctx_stack
@@ -11,7 +12,7 @@
1112
from flask_jwt_extended.config import config
1213
from flask_jwt_extended.exceptions import (
1314
InvalidHeaderError, NoAuthorizationError, WrongTokenError,
14-
FreshTokenRequired, CSRFError
15+
FreshTokenRequired, CSRFError, JWTDecodeError
1516
)
1617
from flask_jwt_extended.tokens import decode_jwt
1718

@@ -36,6 +37,32 @@ def wrapper(*args, **kwargs):
3637
return wrapper
3738

3839

40+
def jwt_optional(fn):
41+
"""
42+
If you decorate a view with this, it will check the request for a valid
43+
JWT and put it into the Flask application context before calling the view.
44+
In case of an error authenticating the JWT, the view is still called and
45+
the application context is unchanged.
46+
47+
:param fn: The view function to decorate
48+
"""
49+
@wraps(fn)
50+
def wrapper(*args, **kwargs):
51+
try:
52+
# If an acceptable JWT is found in the request, put it into
53+
# the application context
54+
jwt_data = _decode_jwt_from_request(request_type='access')
55+
ctx_stack.top.jwt = jwt_data
56+
except (NoAuthorizationError, InvalidHeaderError, WrongTokenError,
57+
InvalidTokenError, JWTDecodeError):
58+
# Stop authorization-related exceptions from being caught
59+
# higher up in the stack
60+
pass
61+
# Return the decorated function in either case
62+
return fn(*args, **kwargs)
63+
return wrapper
64+
65+
3966
def fresh_jwt_required(fn):
4067
"""
4168
If you decorate a vew with this, it will ensure that the requester has a

tests/test_protected_endpoints.py

Lines changed: 185 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
get_jwt_identity, set_refresh_cookies, set_access_cookies, unset_jwt_cookies
1212
from flask_jwt_extended import JWTManager, create_refresh_token, \
1313
jwt_refresh_token_required, create_access_token, fresh_jwt_required, \
14-
jwt_required, get_raw_jwt
14+
jwt_optional, jwt_required, get_raw_jwt
1515

1616

1717
class TestEndpoints(unittest.TestCase):
@@ -55,6 +55,14 @@ def protected():
5555
def fresh_protected():
5656
return jsonify({'msg': "fresh hello world"})
5757

58+
@self.app.route('/partially-protected')
59+
@jwt_optional
60+
def partially_protected():
61+
if get_jwt_identity():
62+
return jsonify({'msg': "protected hello world"})
63+
return jsonify({'msg': "unprotected hello world"})
64+
65+
5866
def _jwt_post(self, url, jwt):
5967
response = self.client.post(url, content_type='application/json',
6068
headers={'Authorization': 'Bearer {}'.format(jwt)})
@@ -124,6 +132,32 @@ def test_jwt_required(self):
124132
self.assertEqual(status, 200)
125133
self.assertEqual(data, {'msg': 'hello world'})
126134

135+
def test_jwt_optional_no_jwt(self):
136+
response = self.client.get('/partially-protected')
137+
data = json.loads(response.get_data(as_text=True))
138+
status = response.status_code
139+
self.assertEqual(status, 200)
140+
self.assertEqual(data, {'msg': 'unprotected hello world'})
141+
142+
def test_jwt_optional_with_jwt(self):
143+
response = self.client.post('/auth/login')
144+
data = json.loads(response.get_data(as_text=True))
145+
fresh_access_token = data['access_token']
146+
refresh_token = data['refresh_token']
147+
148+
# Test it works with a fresh token
149+
status, data = self._jwt_get('/partially-protected',
150+
fresh_access_token)
151+
self.assertEqual(data, {'msg': 'protected hello world'})
152+
self.assertEqual(status, 200)
153+
154+
# Test it works with a non-fresh access token
155+
_, data = self._jwt_post('/auth/refresh', refresh_token)
156+
non_fresh_token = data['access_token']
157+
status, data = self._jwt_get('/partially-protected', non_fresh_token)
158+
self.assertEqual(status, 200)
159+
self.assertEqual(data, {'msg': 'protected hello world'})
160+
127161
def test_jwt_required_wrong_token(self):
128162
response = self.client.post('/auth/login')
129163
data = json.loads(response.get_data(as_text=True))
@@ -133,6 +167,17 @@ def test_jwt_required_wrong_token(self):
133167
status, text = self._jwt_get('/protected', refresh_token)
134168
self.assertEqual(status, 422)
135169

170+
def test_jwt_optional_wrong_token(self):
171+
response = self.client.post('/auth/login')
172+
data = json.loads(response.get_data(as_text=True))
173+
refresh_token = data['refresh_token']
174+
175+
# Should work with refresh token, but not return data
176+
# that has been manually protected in the view
177+
status, text = self._jwt_get('/partially-protected', refresh_token)
178+
self.assertEqual(status, 200)
179+
self.assertEqual(text, {'msg': 'unprotected hello world'})
180+
136181
def test_fresh_jwt_required(self):
137182
response = self.client.post('/auth/login')
138183
data = json.loads(response.get_data(as_text=True))
@@ -209,6 +254,45 @@ def test_bad_jwt_requests(self):
209254
self.assertEqual(status_code, 422)
210255
self.assertIn('msg', data)
211256

257+
def test_optional_bad_jwt_requests(self):
258+
response = self.client.post('/auth/login')
259+
data = json.loads(response.get_data(as_text=True))
260+
access_token = data['access_token']
261+
262+
# Test with no authorization header
263+
response = self.client.get('/partially-protected')
264+
data = json.loads(response.get_data(as_text=True))
265+
status_code = response.status_code
266+
self.assertEqual(status_code, 200)
267+
self.assertEqual(data, {'msg': 'unprotected hello world'})
268+
269+
# Test with missing type in authorization header
270+
auth_header = access_token
271+
response = self.client.get('/partially-protected',
272+
headers={'Authorization': auth_header})
273+
data = json.loads(response.get_data(as_text=True))
274+
status_code = response.status_code
275+
self.assertEqual(status_code, 200)
276+
self.assertEqual(data, {'msg': 'unprotected hello world'})
277+
278+
# Test with type not being Bearer in authorization header
279+
auth_header = "BANANA {}".format(access_token)
280+
response = self.client.get('/partially-protected',
281+
headers={'Authorization': auth_header})
282+
data = json.loads(response.get_data(as_text=True))
283+
status_code = response.status_code
284+
self.assertEqual(status_code, 200)
285+
self.assertEqual(data, {'msg': 'unprotected hello world'})
286+
287+
# Test with too many items in auth header
288+
auth_header = "Bearer {} BANANA".format(access_token)
289+
response = self.client.get('/partially-protected',
290+
headers={'Authorization': auth_header})
291+
data = json.loads(response.get_data(as_text=True))
292+
status_code = response.status_code
293+
self.assertEqual(status_code, 200)
294+
self.assertEqual(data, {'msg': 'unprotected hello world'})
295+
212296
def test_bad_tokens(self):
213297
# Test expired access token
214298
response = self.client.post('/auth/login')
@@ -267,6 +351,54 @@ def test_bad_tokens(self):
267351
self.assertEqual(status_code, 422)
268352
self.assertIn('msg', data)
269353

354+
def test_optional_jwt_bad_tokens(self):
355+
# Test expired access token
356+
response = self.client.post('/auth/login')
357+
data = json.loads(response.get_data(as_text=True))
358+
access_token = data['access_token']
359+
status_code, data = self._jwt_get('/partially-protected', access_token)
360+
self.assertEqual(status_code, 200)
361+
self.assertEqual(data, {'msg': 'protected hello world'})
362+
time.sleep(2)
363+
status_code, data = self._jwt_get('/partially-protected', access_token)
364+
self.assertEqual(status_code, 200)
365+
self.assertEqual(data, {'msg': 'unprotected hello world'})
366+
367+
# Test Bogus token
368+
auth_header = "Bearer {}".format('this_is_totally_an_access_token')
369+
response = self.client.get('/partially-protected',
370+
headers={'Authorization': auth_header})
371+
data = json.loads(response.get_data(as_text=True))
372+
status_code = response.status_code
373+
self.assertEqual(status_code, 200)
374+
self.assertEqual(data, {'msg': 'unprotected hello world'})
375+
376+
# Test token that was signed with a different key
377+
with self.app.test_request_context():
378+
token = encode_access_token('foo', 'newsecret', 'HS256',
379+
timedelta(minutes=5), True, {},
380+
csrf=False)
381+
auth_header = "Bearer {}".format(token)
382+
response = self.client.get('/partially-protected',
383+
headers={'Authorization': auth_header})
384+
data = json.loads(response.get_data(as_text=True))
385+
status_code = response.status_code
386+
self.assertEqual(status_code, 200)
387+
self.assertEqual(data, {'msg': 'unprotected hello world'})
388+
389+
# Test with valid token that is missing required claims
390+
now = datetime.utcnow()
391+
token_data = {'exp': now + timedelta(minutes=5)}
392+
encoded_token = jwt.encode(token_data, self.app.config['SECRET_KEY'],
393+
self.app.config['JWT_ALGORITHM']).decode('utf-8')
394+
auth_header = "Bearer {}".format(encoded_token)
395+
response = self.client.get('/partially-protected',
396+
headers={'Authorization': auth_header})
397+
data = json.loads(response.get_data(as_text=True))
398+
status_code = response.status_code
399+
self.assertEqual(status_code, 200)
400+
self.assertEqual(data, {'msg': 'unprotected hello world'})
401+
270402
def test_jwt_identity_claims(self):
271403
# Setup custom claims
272404
@self.jwt_manager.user_claims_loader
@@ -350,6 +482,42 @@ def test_different_headers(self):
350482
self.assertIn('msg', data)
351483
self.assertEqual(status, 401)
352484

485+
def test_different_headers_jwt_optional(self):
486+
response = self.client.post('/auth/login')
487+
data = json.loads(response.get_data(as_text=True))
488+
access_token = data['access_token']
489+
490+
self.app.config['JWT_HEADER_TYPE'] = 'JWT'
491+
status, data = self._jwt_get('/partially-protected', access_token,
492+
header_type='JWT')
493+
self.assertEqual(data, {'msg': 'protected hello world'})
494+
self.assertEqual(status, 200)
495+
496+
self.app.config['JWT_HEADER_TYPE'] = ''
497+
status, data = self._jwt_get('/partially-protected', access_token,
498+
header_type='')
499+
self.assertEqual(data, {'msg': 'protected hello world'})
500+
self.assertEqual(status, 200)
501+
502+
self.app.config['JWT_HEADER_TYPE'] = ''
503+
status, data = self._jwt_get('/partially-protected', access_token,
504+
header_type='Bearer')
505+
self.assertIn('msg', data)
506+
self.assertEqual(data, {'msg': 'unprotected hello world'})
507+
508+
self.app.config['JWT_HEADER_TYPE'] = 'Bearer'
509+
self.app.config['JWT_HEADER_NAME'] = 'Auth'
510+
status, data = self._jwt_get('/partially-protected', access_token,
511+
header_name='Auth', header_type='Bearer')
512+
self.assertEqual(data, {'msg': 'protected hello world'})
513+
self.assertEqual(status, 200)
514+
515+
status, data = self._jwt_get('/partially-protected', access_token,
516+
header_name='Authorization',
517+
header_type='Bearer')
518+
self.assertEqual(data, {'msg': 'unprotected hello world'})
519+
self.assertEqual(status, 200)
520+
353521
def test_cookie_methods_fail_with_headers_configured(self):
354522
app = Flask(__name__)
355523
app.config['JWT_TOKEN_LOCATION'] = ['headers']
@@ -401,6 +569,22 @@ def test_jwt_with_different_algorithm(self):
401569
self.assertEqual(status, 422)
402570
self.assertIn('msg', data)
403571

572+
def test_optional_jwt_with_different_algorithm(self):
573+
self.app.config['JWT_ALGORITHM'] = 'HS256'
574+
self.app.secret_key = 'test_secret'
575+
access_token = encode_access_token(
576+
identity='bobdobbs',
577+
secret='test_secret',
578+
algorithm='HS512',
579+
expires_delta=timedelta(minutes=5),
580+
fresh=True,
581+
user_claims={},
582+
csrf=False
583+
)
584+
status, data = self._jwt_get('/partially-protected', access_token)
585+
self.assertEqual(data, {'msg': 'unprotected hello world'})
586+
self.assertEqual(status, 200)
587+
404588

405589
class TestEndpointsWithCookies(unittest.TestCase):
406590

0 commit comments

Comments
 (0)