Skip to content

Commit e2f5eaf

Browse files
committed
feat: Adds support for multiple decode algorithms
1 parent 23584dd commit e2f5eaf

File tree

5 files changed

+29
-6
lines changed

5 files changed

+29
-6
lines changed

flask_jwt_extended/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,15 @@ def refresh_expires(self):
218218
def algorithm(self):
219219
return current_app.config['JWT_ALGORITHM']
220220

221+
@property
222+
def decode_algorithms(self):
223+
algorithms = current_app.config['JWT_DECODE_ALGORITHMS']
224+
if not algorithms:
225+
return [self.algorithm]
226+
if self.algorithm not in algorithms:
227+
algorithms.append(self.algorithm)
228+
return algorithms
229+
221230
@property
222231
def blacklist_enabled(self):
223232
return current_app.config['JWT_BLACKLIST_ENABLED']

flask_jwt_extended/jwt_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ def _set_default_configuration_options(app):
194194
# https://github.com/jpadilla/pyjwt/blob/master/jwt/api_jwt.py
195195
app.config.setdefault('JWT_ALGORITHM', 'HS256')
196196

197+
# What algorithms are allowed to decode a token
198+
app.config.setdefault('JWT_DECODE_ALGORITHMS', None)
199+
197200
# Secret key to sign JWTs with. Only used if a symmetric algorithm is
198201
# used (such as the HS* algorithms). We will use the app secret key
199202
# if this is not set.

flask_jwt_extended/tokens.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,15 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims
112112
json_encoder=json_encoder)
113113

114114

115-
def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
115+
def decode_jwt(encoded_token, secret, algorithms, identity_claim_key,
116116
user_claims_key, csrf_value=None, audience=None,
117117
leeway=0, allow_expired=False):
118118
"""
119119
Decodes an encoded JWT
120120
121121
:param encoded_token: The encoded JWT string to decode
122122
:param secret: Secret key used to encode the JWT
123-
:param algorithm: Algorithm used to encode the JWT
123+
:param algorithms: Algorithms allowed to decode the token
124124
:param identity_claim_key: expected key that contains the identity
125125
:param user_claims_key: expected key that contains the user claims
126126
:param csrf_value: Expected double submit csrf value
@@ -134,7 +134,7 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
134134
options['verify_exp'] = False
135135

136136
# This call verifies the ext, iat, nbf, and aud claims
137-
data = jwt.decode(encoded_token, secret, algorithms=[algorithm], audience=audience,
137+
data = jwt.decode(encoded_token, secret, algorithms=algorithms, audience=audience,
138138
leeway=leeway, options=options)
139139

140140
# Make sure that any custom claims we expect in the token are present

flask_jwt_extended/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False):
7878
"""
7979
jwt_manager = _get_jwt_manager()
8080
unverified_claims = jwt.decode(
81-
encoded_token, verify=False, algorithms=config.algorithm
81+
encoded_token, verify=False, algorithms=config.decode_algorithms
8282
)
8383
unverified_headers = jwt.get_unverified_header(encoded_token)
8484
# Attempt to call callback with both claims and headers, but fallback to just claims
@@ -98,7 +98,7 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False):
9898
return decode_jwt(
9999
encoded_token=encoded_token,
100100
secret=secret,
101-
algorithm=config.algorithm,
101+
algorithms=config.decode_algorithms,
102102
identity_claim_key=config.identity_claim_key,
103103
user_claims_key=config.user_claims_key,
104104
csrf_value=csrf_value,
@@ -110,7 +110,7 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False):
110110
expired_token = decode_jwt(
111111
encoded_token=encoded_token,
112112
secret=secret,
113-
algorithm=config.algorithm,
113+
algorithms=config.decode_algorithms,
114114
identity_claim_key=config.identity_claim_key,
115115
user_claims_key=config.user_claims_key,
116116
csrf_value=csrf_value,

tests/test_config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def test_default_configs(app):
5555
assert config.access_expires == timedelta(minutes=15)
5656
assert config.refresh_expires == timedelta(days=30)
5757
assert config.algorithm == 'HS256'
58+
assert config.decode_algorithms == ['HS256']
5859
assert config.is_asymmetric is False
5960
assert config.blacklist_enabled is False
6061
assert config.blacklist_checks == ('access', 'refresh')
@@ -105,6 +106,7 @@ def test_override_configs(app, delta_func):
105106
app.config['JWT_ACCESS_TOKEN_EXPIRES'] = delta_func(minutes=5)
106107
app.config['JWT_REFRESH_TOKEN_EXPIRES'] = delta_func(days=5)
107108
app.config['JWT_ALGORITHM'] = 'HS512'
109+
app.config['JWT_DECODE_ALGORITHMS'] = ['HS512', 'HS256']
108110

109111
app.config['JWT_BLACKLIST_ENABLED'] = True
110112
app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = ('refresh',)
@@ -156,6 +158,7 @@ class CustomJSONEncoder(JSONEncoder):
156158
assert config.access_expires == delta_func(minutes=5)
157159
assert config.refresh_expires == delta_func(days=5)
158160
assert config.algorithm == 'HS512'
161+
assert config.decode_algorithms == ['HS512', 'HS256']
159162

160163
assert config.blacklist_enabled is True
161164
assert config.blacklist_checks == ('refresh',)
@@ -396,3 +399,11 @@ def test_depreciated_options(app):
396399
assert len(w) == 2
397400
assert w[0].category == DeprecationWarning
398401
assert w[1].category == DeprecationWarning
402+
403+
404+
def test_missing_algorithm_in_decode_algorithms(app):
405+
app.config['JWT_ALGORITHM'] = 'RS256'
406+
app.config['JWT_DECODE_ALGORITHMS'] = ['HS512']
407+
408+
with app.test_request_context():
409+
assert config.decode_algorithms == ['HS512', 'RS256']

0 commit comments

Comments
 (0)