11import jwt
22import pytest
3- from datetime import timedelta
3+ from datetime import datetime , timedelta
4+ import warnings
45
56from flask import Flask
6- from jwt import ExpiredSignatureError , InvalidSignatureError
7+ from jwt import ExpiredSignatureError , InvalidSignatureError , InvalidAudienceError
78
89from flask_jwt_extended import (
910 JWTManager , create_access_token , decode_token , create_refresh_token ,
@@ -54,16 +55,29 @@ def empty_user_loader_return(identity):
5455 assert config .user_claims_key in extension_decoded
5556
5657
57- @pytest .mark .parametrize ("missing_claim " , ['jti' , 'type' , ' identity' , 'fresh ' , 'csrf' ])
58- def test_missing_jti_claim (app , default_access_token , missing_claim ):
59- del default_access_token [missing_claim ]
58+ @pytest .mark .parametrize ("missing_claims " , ['identity' , 'csrf' ])
59+ def test_missing_claims (app , default_access_token , missing_claims ):
60+ del default_access_token [missing_claims ]
6061 missing_jwt_token = encode_token (app , default_access_token )
6162
6263 with pytest .raises (JWTDecodeError ):
6364 with app .test_request_context ():
6465 decode_token (missing_jwt_token , csrf_value = 'abcd' )
6566
6667
68+ def test_default_decode_token_values (app , default_access_token ):
69+ del default_access_token ['type' ]
70+ del default_access_token ['jti' ]
71+ del default_access_token ['fresh' ]
72+ token = encode_token (app , default_access_token )
73+
74+ with app .test_request_context ():
75+ decoded = decode_token (token )
76+ assert decoded ['type' ] == 'access'
77+ assert decoded ['jti' ] is None
78+ assert decoded ['fresh' ] is False
79+
80+
6781def test_bad_token_type (app , default_access_token ):
6882 default_access_token ['type' ] = 'banana'
6983 bad_type_token = encode_token (app , default_access_token )
@@ -123,19 +137,36 @@ def test_encode_decode_callback_values(app, default_access_token):
123137 jwtM = get_jwt_manager (app )
124138 app .config ['JWT_SECRET_KEY' ] = 'foobarbaz'
125139 with app .test_request_context ():
126- assert jwtM ._decode_key_callback ({}) == 'foobarbaz'
140+ assert jwtM ._decode_key_callback ({}, {} ) == 'foobarbaz'
127141 assert jwtM ._encode_key_callback ({}) == 'foobarbaz'
128142
129- @jwtM .decode_key_loader
130- def get_decode_key_1 ( claims ):
143+ @jwtM .encode_key_loader
144+ def get_encode_key_1 ( identity ):
131145 return 'different secret'
146+ assert jwtM ._encode_key_callback ('' ) == 'different secret'
132147
133- @jwtM .encode_key_loader
134- def get_decode_key_2 ( identity ):
148+ @jwtM .decode_key_loader
149+ def get_decode_key_1 ( claims , headers ):
135150 return 'different secret'
151+ assert jwtM ._decode_key_callback ({}, {}) == 'different secret'
136152
137- assert jwtM ._decode_key_callback ({}) == 'different secret'
138- assert jwtM ._encode_key_callback ('' ) == 'different secret'
153+
154+ def test_legacy_decode_key_callback (app , default_access_token ):
155+ jwtM = get_jwt_manager (app )
156+ app .config ['JWT_SECRET_KEY' ] = 'foobarbaz'
157+
158+ # test decode key callback with one argument (backwards compatibility)
159+ with warnings .catch_warnings (record = True ) as w :
160+ warnings .simplefilter ("always" )
161+
162+ @jwtM .decode_key_loader
163+ def get_decode_key_legacy (claims ):
164+ return 'foobarbaz'
165+ with app .test_request_context ():
166+ token = encode_token (app , default_access_token )
167+ decode_token (token )
168+ assert len (w ) == 1
169+ assert issubclass (w [- 1 ].category , DeprecationWarning )
139170
140171
141172def test_custom_encode_decode_key_callbacks (app , default_access_token ):
@@ -157,7 +188,7 @@ def get_encode_key_1(identity):
157188 decode_token (token )
158189
159190 @jwtM .decode_key_loader
160- def get_decode_key_1 (claims ):
191+ def get_decode_key_1 (claims , headers ):
161192 assert claims ['identity' ] == 'username'
162193 return 'different secret'
163194
@@ -166,3 +197,19 @@ def get_decode_key_1(claims):
166197 decode_token (token )
167198 token = create_refresh_token ('username' )
168199 decode_token (token )
200+
201+
202+ def test_valid_aud (app , default_access_token ):
203+ app .config ['JWT_DECODE_AUDIENCE' ] = 'foo'
204+
205+ default_access_token ['aud' ] = 'bar'
206+ invalid_token = encode_token (app , default_access_token )
207+ with pytest .raises (InvalidAudienceError ):
208+ with app .test_request_context ():
209+ decode_token (invalid_token )
210+
211+ default_access_token ['aud' ] = 'foo'
212+ valid_token = encode_token (app , default_access_token )
213+ with app .test_request_context ():
214+ decoded = decode_token (valid_token )
215+ assert decoded ['aud' ] == 'foo'
0 commit comments