1111 get_jwt_identity , set_refresh_cookies , set_access_cookies , unset_jwt_cookies
1212from flask_jwt_extended import JWTManager , create_refresh_token , \
1313 jwt_refresh_token_required , create_access_token , fresh_jwt_required , \
14- jwt_required , get_raw_jwt
14+ jwt_optional , jwt_required , get_raw_jwt
1515
1616
1717class TestEndpoints (unittest .TestCase ):
@@ -55,6 +55,14 @@ def protected():
5555 def fresh_protected ():
5656 return jsonify ({'msg' : "fresh hello world" })
5757
58+ @self .app .route ('/partially-protected' )
59+ @jwt_optional
60+ def partially_protected ():
61+ if get_jwt_identity ():
62+ return jsonify ({'msg' : "protected hello world" })
63+ return jsonify ({'msg' : "unprotected hello world" })
64+
65+
5866 def _jwt_post (self , url , jwt ):
5967 response = self .client .post (url , content_type = 'application/json' ,
6068 headers = {'Authorization' : 'Bearer {}' .format (jwt )})
@@ -124,6 +132,32 @@ def test_jwt_required(self):
124132 self .assertEqual (status , 200 )
125133 self .assertEqual (data , {'msg' : 'hello world' })
126134
135+ def test_jwt_optional_no_jwt (self ):
136+ response = self .client .get ('/partially-protected' )
137+ data = json .loads (response .get_data (as_text = True ))
138+ status = response .status_code
139+ self .assertEqual (status , 200 )
140+ self .assertEqual (data , {'msg' : 'unprotected hello world' })
141+
142+ def test_jwt_optional_with_jwt (self ):
143+ response = self .client .post ('/auth/login' )
144+ data = json .loads (response .get_data (as_text = True ))
145+ fresh_access_token = data ['access_token' ]
146+ refresh_token = data ['refresh_token' ]
147+
148+ # Test it works with a fresh token
149+ status , data = self ._jwt_get ('/partially-protected' ,
150+ fresh_access_token )
151+ self .assertEqual (data , {'msg' : 'protected hello world' })
152+ self .assertEqual (status , 200 )
153+
154+ # Test it works with a non-fresh access token
155+ _ , data = self ._jwt_post ('/auth/refresh' , refresh_token )
156+ non_fresh_token = data ['access_token' ]
157+ status , data = self ._jwt_get ('/partially-protected' , non_fresh_token )
158+ self .assertEqual (status , 200 )
159+ self .assertEqual (data , {'msg' : 'protected hello world' })
160+
127161 def test_jwt_required_wrong_token (self ):
128162 response = self .client .post ('/auth/login' )
129163 data = json .loads (response .get_data (as_text = True ))
@@ -133,6 +167,17 @@ def test_jwt_required_wrong_token(self):
133167 status , text = self ._jwt_get ('/protected' , refresh_token )
134168 self .assertEqual (status , 422 )
135169
170+ def test_jwt_optional_wrong_token (self ):
171+ response = self .client .post ('/auth/login' )
172+ data = json .loads (response .get_data (as_text = True ))
173+ refresh_token = data ['refresh_token' ]
174+
175+ # Should work with refresh token, but not return data
176+ # that has been manually protected in the view
177+ status , text = self ._jwt_get ('/partially-protected' , refresh_token )
178+ self .assertEqual (status , 200 )
179+ self .assertEqual (text , {'msg' : 'unprotected hello world' })
180+
136181 def test_fresh_jwt_required (self ):
137182 response = self .client .post ('/auth/login' )
138183 data = json .loads (response .get_data (as_text = True ))
@@ -209,6 +254,45 @@ def test_bad_jwt_requests(self):
209254 self .assertEqual (status_code , 422 )
210255 self .assertIn ('msg' , data )
211256
257+ def test_optional_bad_jwt_requests (self ):
258+ response = self .client .post ('/auth/login' )
259+ data = json .loads (response .get_data (as_text = True ))
260+ access_token = data ['access_token' ]
261+
262+ # Test with no authorization header
263+ response = self .client .get ('/partially-protected' )
264+ data = json .loads (response .get_data (as_text = True ))
265+ status_code = response .status_code
266+ self .assertEqual (status_code , 200 )
267+ self .assertEqual (data , {'msg' : 'unprotected hello world' })
268+
269+ # Test with missing type in authorization header
270+ auth_header = access_token
271+ response = self .client .get ('/partially-protected' ,
272+ headers = {'Authorization' : auth_header })
273+ data = json .loads (response .get_data (as_text = True ))
274+ status_code = response .status_code
275+ self .assertEqual (status_code , 200 )
276+ self .assertEqual (data , {'msg' : 'unprotected hello world' })
277+
278+ # Test with type not being Bearer in authorization header
279+ auth_header = "BANANA {}" .format (access_token )
280+ response = self .client .get ('/partially-protected' ,
281+ headers = {'Authorization' : auth_header })
282+ data = json .loads (response .get_data (as_text = True ))
283+ status_code = response .status_code
284+ self .assertEqual (status_code , 200 )
285+ self .assertEqual (data , {'msg' : 'unprotected hello world' })
286+
287+ # Test with too many items in auth header
288+ auth_header = "Bearer {} BANANA" .format (access_token )
289+ response = self .client .get ('/partially-protected' ,
290+ headers = {'Authorization' : auth_header })
291+ data = json .loads (response .get_data (as_text = True ))
292+ status_code = response .status_code
293+ self .assertEqual (status_code , 200 )
294+ self .assertEqual (data , {'msg' : 'unprotected hello world' })
295+
212296 def test_bad_tokens (self ):
213297 # Test expired access token
214298 response = self .client .post ('/auth/login' )
@@ -267,6 +351,54 @@ def test_bad_tokens(self):
267351 self .assertEqual (status_code , 422 )
268352 self .assertIn ('msg' , data )
269353
354+ def test_optional_jwt_bad_tokens (self ):
355+ # Test expired access token
356+ response = self .client .post ('/auth/login' )
357+ data = json .loads (response .get_data (as_text = True ))
358+ access_token = data ['access_token' ]
359+ status_code , data = self ._jwt_get ('/partially-protected' , access_token )
360+ self .assertEqual (status_code , 200 )
361+ self .assertEqual (data , {'msg' : 'protected hello world' })
362+ time .sleep (2 )
363+ status_code , data = self ._jwt_get ('/partially-protected' , access_token )
364+ self .assertEqual (status_code , 200 )
365+ self .assertEqual (data , {'msg' : 'unprotected hello world' })
366+
367+ # Test Bogus token
368+ auth_header = "Bearer {}" .format ('this_is_totally_an_access_token' )
369+ response = self .client .get ('/partially-protected' ,
370+ headers = {'Authorization' : auth_header })
371+ data = json .loads (response .get_data (as_text = True ))
372+ status_code = response .status_code
373+ self .assertEqual (status_code , 200 )
374+ self .assertEqual (data , {'msg' : 'unprotected hello world' })
375+
376+ # Test token that was signed with a different key
377+ with self .app .test_request_context ():
378+ token = encode_access_token ('foo' , 'newsecret' , 'HS256' ,
379+ timedelta (minutes = 5 ), True , {},
380+ csrf = False )
381+ auth_header = "Bearer {}" .format (token )
382+ response = self .client .get ('/partially-protected' ,
383+ headers = {'Authorization' : auth_header })
384+ data = json .loads (response .get_data (as_text = True ))
385+ status_code = response .status_code
386+ self .assertEqual (status_code , 200 )
387+ self .assertEqual (data , {'msg' : 'unprotected hello world' })
388+
389+ # Test with valid token that is missing required claims
390+ now = datetime .utcnow ()
391+ token_data = {'exp' : now + timedelta (minutes = 5 )}
392+ encoded_token = jwt .encode (token_data , self .app .config ['SECRET_KEY' ],
393+ self .app .config ['JWT_ALGORITHM' ]).decode ('utf-8' )
394+ auth_header = "Bearer {}" .format (encoded_token )
395+ response = self .client .get ('/partially-protected' ,
396+ headers = {'Authorization' : auth_header })
397+ data = json .loads (response .get_data (as_text = True ))
398+ status_code = response .status_code
399+ self .assertEqual (status_code , 200 )
400+ self .assertEqual (data , {'msg' : 'unprotected hello world' })
401+
270402 def test_jwt_identity_claims (self ):
271403 # Setup custom claims
272404 @self .jwt_manager .user_claims_loader
@@ -350,6 +482,42 @@ def test_different_headers(self):
350482 self .assertIn ('msg' , data )
351483 self .assertEqual (status , 401 )
352484
485+ def test_different_headers_jwt_optional (self ):
486+ response = self .client .post ('/auth/login' )
487+ data = json .loads (response .get_data (as_text = True ))
488+ access_token = data ['access_token' ]
489+
490+ self .app .config ['JWT_HEADER_TYPE' ] = 'JWT'
491+ status , data = self ._jwt_get ('/partially-protected' , access_token ,
492+ header_type = 'JWT' )
493+ self .assertEqual (data , {'msg' : 'protected hello world' })
494+ self .assertEqual (status , 200 )
495+
496+ self .app .config ['JWT_HEADER_TYPE' ] = ''
497+ status , data = self ._jwt_get ('/partially-protected' , access_token ,
498+ header_type = '' )
499+ self .assertEqual (data , {'msg' : 'protected hello world' })
500+ self .assertEqual (status , 200 )
501+
502+ self .app .config ['JWT_HEADER_TYPE' ] = ''
503+ status , data = self ._jwt_get ('/partially-protected' , access_token ,
504+ header_type = 'Bearer' )
505+ self .assertIn ('msg' , data )
506+ self .assertEqual (data , {'msg' : 'unprotected hello world' })
507+
508+ self .app .config ['JWT_HEADER_TYPE' ] = 'Bearer'
509+ self .app .config ['JWT_HEADER_NAME' ] = 'Auth'
510+ status , data = self ._jwt_get ('/partially-protected' , access_token ,
511+ header_name = 'Auth' , header_type = 'Bearer' )
512+ self .assertEqual (data , {'msg' : 'protected hello world' })
513+ self .assertEqual (status , 200 )
514+
515+ status , data = self ._jwt_get ('/partially-protected' , access_token ,
516+ header_name = 'Authorization' ,
517+ header_type = 'Bearer' )
518+ self .assertEqual (data , {'msg' : 'unprotected hello world' })
519+ self .assertEqual (status , 200 )
520+
353521 def test_cookie_methods_fail_with_headers_configured (self ):
354522 app = Flask (__name__ )
355523 app .config ['JWT_TOKEN_LOCATION' ] = ['headers' ]
@@ -401,6 +569,22 @@ def test_jwt_with_different_algorithm(self):
401569 self .assertEqual (status , 422 )
402570 self .assertIn ('msg' , data )
403571
572+ def test_optional_jwt_with_different_algorithm (self ):
573+ self .app .config ['JWT_ALGORITHM' ] = 'HS256'
574+ self .app .secret_key = 'test_secret'
575+ access_token = encode_access_token (
576+ identity = 'bobdobbs' ,
577+ secret = 'test_secret' ,
578+ algorithm = 'HS512' ,
579+ expires_delta = timedelta (minutes = 5 ),
580+ fresh = True ,
581+ user_claims = {},
582+ csrf = False
583+ )
584+ status , data = self ._jwt_get ('/partially-protected' , access_token )
585+ self .assertEqual (data , {'msg' : 'unprotected hello world' })
586+ self .assertEqual (status , 200 )
587+
404588
405589class TestEndpointsWithCookies (unittest .TestCase ):
406590
0 commit comments