Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ General Options:
Can be set to ``False`` to disable expiration.
``JWT_ALGORITHM`` Which algorithm to sign the JWT with. `See here <https://pyjwt.readthedocs.io/en/latest/algorithms.html>`_
for the options. Defaults to ``'HS256'``.
``JWT_DECODE_ALGORITHMS`` Which algorithms are allowed to decode a JWT.
Defaults to a list with only the algorithm set in ``JWT_ALGORITHM``.
``JWT_SECRET_KEY`` The secret key needed for symmetric based signing algorithms,
such as ``HS*``. If this is not set, we use the
flask ``SECRET_KEY`` value instead.
Expand Down
9 changes: 9 additions & 0 deletions flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,15 @@ def refresh_expires(self):
def algorithm(self):
return current_app.config['JWT_ALGORITHM']

@property
def decode_algorithms(self):
algorithms = current_app.config['JWT_DECODE_ALGORITHMS']
if not algorithms:
return [self.algorithm]
if self.algorithm not in algorithms:
algorithms.append(self.algorithm)
return algorithms

@property
def blacklist_enabled(self):
return current_app.config['JWT_BLACKLIST_ENABLED']
Expand Down
3 changes: 3 additions & 0 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ def _set_default_configuration_options(app):
# https://github.com/jpadilla/pyjwt/blob/master/jwt/api_jwt.py
app.config.setdefault('JWT_ALGORITHM', 'HS256')

# What algorithms are allowed to decode a token
app.config.setdefault('JWT_DECODE_ALGORITHMS', None)

# Secret key to sign JWTs with. Only used if a symmetric algorithm is
# used (such as the HS* algorithms). We will use the app secret key
# if this is not set.
Expand Down
6 changes: 3 additions & 3 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,15 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims
json_encoder=json_encoder)


def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
def decode_jwt(encoded_token, secret, algorithms, identity_claim_key,
user_claims_key, csrf_value=None, audience=None,
leeway=0, allow_expired=False):
"""
Decodes an encoded JWT

:param encoded_token: The encoded JWT string to decode
:param secret: Secret key used to encode the JWT
:param algorithm: Algorithm used to encode the JWT
:param algorithms: Algorithms allowed to decode the token
:param identity_claim_key: expected key that contains the identity
:param user_claims_key: expected key that contains the user claims
:param csrf_value: Expected double submit csrf value
Expand All @@ -134,7 +134,7 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
options['verify_exp'] = False

# This call verifies the ext, iat, nbf, and aud claims
data = jwt.decode(encoded_token, secret, algorithms=[algorithm], audience=audience,
data = jwt.decode(encoded_token, secret, algorithms=algorithms, audience=audience,
leeway=leeway, options=options)

# Make sure that any custom claims we expect in the token are present
Expand Down
6 changes: 3 additions & 3 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False):
"""
jwt_manager = _get_jwt_manager()
unverified_claims = jwt.decode(
encoded_token, verify=False, algorithms=config.algorithm
encoded_token, verify=False, algorithms=config.decode_algorithms
)
unverified_headers = jwt.get_unverified_header(encoded_token)
# Attempt to call callback with both claims and headers, but fallback to just claims
Expand All @@ -98,7 +98,7 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False):
return decode_jwt(
encoded_token=encoded_token,
secret=secret,
algorithm=config.algorithm,
algorithms=config.decode_algorithms,
identity_claim_key=config.identity_claim_key,
user_claims_key=config.user_claims_key,
csrf_value=csrf_value,
Expand All @@ -110,7 +110,7 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False):
expired_token = decode_jwt(
encoded_token=encoded_token,
secret=secret,
algorithm=config.algorithm,
algorithms=config.decode_algorithms,
identity_claim_key=config.identity_claim_key,
user_claims_key=config.user_claims_key,
csrf_value=csrf_value,
Expand Down
11 changes: 11 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def test_default_configs(app):
assert config.access_expires == timedelta(minutes=15)
assert config.refresh_expires == timedelta(days=30)
assert config.algorithm == 'HS256'
assert config.decode_algorithms == ['HS256']
assert config.is_asymmetric is False
assert config.blacklist_enabled is False
assert config.blacklist_checks == ('access', 'refresh')
Expand Down Expand Up @@ -105,6 +106,7 @@ def test_override_configs(app, delta_func):
app.config['JWT_ACCESS_TOKEN_EXPIRES'] = delta_func(minutes=5)
app.config['JWT_REFRESH_TOKEN_EXPIRES'] = delta_func(days=5)
app.config['JWT_ALGORITHM'] = 'HS512'
app.config['JWT_DECODE_ALGORITHMS'] = ['HS512', 'HS256']

app.config['JWT_BLACKLIST_ENABLED'] = True
app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = ('refresh',)
Expand Down Expand Up @@ -156,6 +158,7 @@ class CustomJSONEncoder(JSONEncoder):
assert config.access_expires == delta_func(minutes=5)
assert config.refresh_expires == delta_func(days=5)
assert config.algorithm == 'HS512'
assert config.decode_algorithms == ['HS512', 'HS256']

assert config.blacklist_enabled is True
assert config.blacklist_checks == ('refresh',)
Expand Down Expand Up @@ -396,3 +399,11 @@ def test_depreciated_options(app):
assert len(w) == 2
assert w[0].category == DeprecationWarning
assert w[1].category == DeprecationWarning


def test_missing_algorithm_in_decode_algorithms(app):
app.config['JWT_ALGORITHM'] = 'RS256'
app.config['JWT_DECODE_ALGORITHMS'] = ['HS512']

with app.test_request_context():
assert config.decode_algorithms == ['HS512', 'RS256']