1111 get_jwt_identity , set_refresh_cookies , set_access_cookies , unset_jwt_cookies
1212from flask_jwt_extended import JWTManager , create_refresh_token , \
1313 jwt_refresh_token_required , create_access_token , fresh_jwt_required , \
14- jwt_required , get_raw_jwt
14+ jwt_optional , jwt_required , get_raw_jwt
1515
1616
1717class TestEndpoints (unittest .TestCase ):
@@ -55,6 +55,14 @@ def protected():
5555 def fresh_protected ():
5656 return jsonify ({'msg' : "fresh hello world" })
5757
58+ @self .app .route ('/partially-protected' )
59+ @jwt_optional
60+ def partially_protected ():
61+ if get_jwt_identity ():
62+ return jsonify ({'msg' : "protected hello world" })
63+ return jsonify ({'msg' : "unprotected hello world" })
64+
65+
5866 def _jwt_post (self , url , jwt ):
5967 response = self .client .post (url , content_type = 'application/json' ,
6068 headers = {'Authorization' : 'Bearer {}' .format (jwt )})
@@ -124,6 +132,32 @@ def test_jwt_required(self):
124132 self .assertEqual (status , 200 )
125133 self .assertEqual (data , {'msg' : 'hello world' })
126134
135+ def test_jwt_optional_no_jwt (self ):
136+ response = self .client .get ('/partially-protected' )
137+ data = json .loads (response .get_data (as_text = True ))
138+ status = response .status_code
139+ self .assertEqual (status , 200 )
140+ self .assertEqual (data , {'msg' : 'unprotected hello world' })
141+
142+ def test_jwt_optional_with_jwt (self ):
143+ response = self .client .post ('/auth/login' )
144+ data = json .loads (response .get_data (as_text = True ))
145+ fresh_access_token = data ['access_token' ]
146+ refresh_token = data ['refresh_token' ]
147+
148+ # Test it works with a fresh token
149+ status , data = self ._jwt_get ('/partially-protected' ,
150+ fresh_access_token )
151+ self .assertEqual (data , {'msg' : 'protected hello world' })
152+ self .assertEqual (status , 200 )
153+
154+ # Test it works with a non-fresh access token
155+ _ , data = self ._jwt_post ('/auth/refresh' , refresh_token )
156+ non_fresh_token = data ['access_token' ]
157+ status , data = self ._jwt_get ('/partially-protected' , non_fresh_token )
158+ self .assertEqual (status , 200 )
159+ self .assertEqual (data , {'msg' : 'protected hello world' })
160+
127161 def test_jwt_required_wrong_token (self ):
128162 response = self .client .post ('/auth/login' )
129163 data = json .loads (response .get_data (as_text = True ))
@@ -133,6 +167,15 @@ def test_jwt_required_wrong_token(self):
133167 status , text = self ._jwt_get ('/protected' , refresh_token )
134168 self .assertEqual (status , 422 )
135169
170+ def test_jwt_optional_wrong_token (self ):
171+ response = self .client .post ('/auth/login' )
172+ data = json .loads (response .get_data (as_text = True ))
173+ refresh_token = data ['refresh_token' ]
174+
175+ # Shouldn't work with a refresh token
176+ status , text = self ._jwt_get ('/partially-protected' , refresh_token )
177+ self .assertEqual (status , 422 )
178+
136179 def test_fresh_jwt_required (self ):
137180 response = self .client .post ('/auth/login' )
138181 data = json .loads (response .get_data (as_text = True ))
@@ -209,6 +252,38 @@ def test_bad_jwt_requests(self):
209252 self .assertEqual (status_code , 422 )
210253 self .assertIn ('msg' , data )
211254
255+ def test_optional_bad_jwt_requests (self ):
256+ response = self .client .post ('/auth/login' )
257+ data = json .loads (response .get_data (as_text = True ))
258+ access_token = data ['access_token' ]
259+
260+ # Test with missing type in authorization header
261+ auth_header = access_token
262+ response = self .client .get ('/partially-protected' ,
263+ headers = {'Authorization' : auth_header })
264+ data = json .loads (response .get_data (as_text = True ))
265+ status_code = response .status_code
266+ self .assertEqual (status_code , 422 )
267+ self .assertIn ('msg' , data )
268+
269+ # Test with type not being Bearer in authorization header
270+ auth_header = "BANANA {}" .format (access_token )
271+ response = self .client .get ('/partially-protected' ,
272+ headers = {'Authorization' : auth_header })
273+ data = json .loads (response .get_data (as_text = True ))
274+ status_code = response .status_code
275+ self .assertEqual (status_code , 422 )
276+ self .assertIn ('msg' , data )
277+
278+ # Test with too many items in auth header
279+ auth_header = "Bearer {} BANANA" .format (access_token )
280+ response = self .client .get ('/partially-protected' ,
281+ headers = {'Authorization' : auth_header })
282+ data = json .loads (response .get_data (as_text = True ))
283+ status_code = response .status_code
284+ self .assertEqual (status_code , 422 )
285+ self .assertIn ('msg' , data )
286+
212287 def test_bad_tokens (self ):
213288 # Test expired access token
214289 response = self .client .post ('/auth/login' )
@@ -267,6 +342,54 @@ def test_bad_tokens(self):
267342 self .assertEqual (status_code , 422 )
268343 self .assertIn ('msg' , data )
269344
345+ def test_optional_jwt_bad_tokens (self ):
346+ # Test expired access token
347+ response = self .client .post ('/auth/login' )
348+ data = json .loads (response .get_data (as_text = True ))
349+ access_token = data ['access_token' ]
350+ status_code , data = self ._jwt_get ('/partially-protected' , access_token )
351+ self .assertEqual (status_code , 200 )
352+ self .assertEqual (data , {'msg' : 'protected hello world' })
353+ time .sleep (2 )
354+ status_code , data = self ._jwt_get ('/partially-protected' , access_token )
355+ self .assertEqual (status_code , 401 )
356+ self .assertIn ('msg' , data )
357+
358+ # Test Bogus token
359+ auth_header = "Bearer {}" .format ('this_is_totally_an_access_token' )
360+ response = self .client .get ('/partially-protected' ,
361+ headers = {'Authorization' : auth_header })
362+ data = json .loads (response .get_data (as_text = True ))
363+ status_code = response .status_code
364+ self .assertEqual (status_code , 422 )
365+ self .assertIn ('msg' , data )
366+
367+ # Test token that was signed with a different key
368+ with self .app .test_request_context ():
369+ token = encode_access_token ('foo' , 'newsecret' , 'HS256' ,
370+ timedelta (minutes = 5 ), True , {},
371+ csrf = False )
372+ auth_header = "Bearer {}" .format (token )
373+ response = self .client .get ('/partially-protected' ,
374+ headers = {'Authorization' : auth_header })
375+ data = json .loads (response .get_data (as_text = True ))
376+ status_code = response .status_code
377+ self .assertEqual (status_code , 422 )
378+ self .assertIn ('msg' , data )
379+
380+ # Test with valid token that is missing required claims
381+ now = datetime .utcnow ()
382+ token_data = {'exp' : now + timedelta (minutes = 5 )}
383+ encoded_token = jwt .encode (token_data , self .app .config ['SECRET_KEY' ],
384+ self .app .config ['JWT_ALGORITHM' ]).decode ('utf-8' )
385+ auth_header = "Bearer {}" .format (encoded_token )
386+ response = self .client .get ('/partially-protected' ,
387+ headers = {'Authorization' : auth_header })
388+ data = json .loads (response .get_data (as_text = True ))
389+ status_code = response .status_code
390+ self .assertEqual (status_code , 422 )
391+ self .assertIn ('msg' , data )
392+
270393 def test_jwt_identity_claims (self ):
271394 # Setup custom claims
272395 @self .jwt_manager .user_claims_loader
@@ -349,6 +472,43 @@ def test_different_headers(self):
349472 header_type = 'Bearer' )
350473 self .assertIn ('msg' , data )
351474 self .assertEqual (status , 401 )
475+ self .assertEqual (data , {'msg' : 'Missing Auth Header' })
476+
477+ def test_different_headers_jwt_optional (self ):
478+ response = self .client .post ('/auth/login' )
479+ data = json .loads (response .get_data (as_text = True ))
480+ access_token = data ['access_token' ]
481+
482+ self .app .config ['JWT_HEADER_TYPE' ] = 'JWT'
483+ status , data = self ._jwt_get ('/partially-protected' , access_token ,
484+ header_type = 'JWT' )
485+ self .assertEqual (data , {'msg' : 'protected hello world' })
486+ self .assertEqual (status , 200 )
487+
488+ self .app .config ['JWT_HEADER_TYPE' ] = ''
489+ status , data = self ._jwt_get ('/partially-protected' , access_token ,
490+ header_type = '' )
491+ self .assertEqual (data , {'msg' : 'protected hello world' })
492+ self .assertEqual (status , 200 )
493+
494+ self .app .config ['JWT_HEADER_TYPE' ] = ''
495+ status , data = self ._jwt_get ('/partially-protected' , access_token ,
496+ header_type = 'Bearer' )
497+ self .assertIn ('msg' , data )
498+ self .assertEqual (status , 422 )
499+
500+ self .app .config ['JWT_HEADER_TYPE' ] = 'Bearer'
501+ self .app .config ['JWT_HEADER_NAME' ] = 'Auth'
502+ status , data = self ._jwt_get ('/partially-protected' , access_token ,
503+ header_name = 'Auth' , header_type = 'Bearer' )
504+ self .assertEqual (data , {'msg' : 'protected hello world' })
505+ self .assertEqual (status , 200 )
506+
507+ status , data = self ._jwt_get ('/partially-protected' , access_token ,
508+ header_name = 'Authorization' ,
509+ header_type = 'Bearer' )
510+ self .assertEqual (status , 200 )
511+ self .assertEqual (data , {'msg' : 'unprotected hello world' })
352512
353513 def test_cookie_methods_fail_with_headers_configured (self ):
354514 app = Flask (__name__ )
@@ -401,6 +561,22 @@ def test_jwt_with_different_algorithm(self):
401561 self .assertEqual (status , 422 )
402562 self .assertIn ('msg' , data )
403563
564+ def test_optional_jwt_with_different_algorithm (self ):
565+ self .app .config ['JWT_ALGORITHM' ] = 'HS256'
566+ self .app .secret_key = 'test_secret'
567+ access_token = encode_access_token (
568+ identity = 'bobdobbs' ,
569+ secret = 'test_secret' ,
570+ algorithm = 'HS512' ,
571+ expires_delta = timedelta (minutes = 5 ),
572+ fresh = True ,
573+ user_claims = {},
574+ csrf = False
575+ )
576+ status , data = self ._jwt_get ('/partially-protected' , access_token )
577+ self .assertEqual (status , 422 )
578+ self .assertIn ('msg' , data )
579+
404580
405581class TestEndpointsWithCookies (unittest .TestCase ):
406582
0 commit comments