From 4ab4234a454cf71282f84f4d83e04a03116816ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Cohen?= Date: Mon, 24 Jun 2019 21:34:48 +0200 Subject: [PATCH] feat: Adds support for multiple decode algorithms --- docs/options.rst | 2 ++ flask_jwt_extended/config.py | 9 +++++++++ flask_jwt_extended/jwt_manager.py | 3 +++ flask_jwt_extended/tokens.py | 6 +++--- flask_jwt_extended/utils.py | 6 +++--- tests/test_config.py | 11 +++++++++++ 6 files changed, 31 insertions(+), 6 deletions(-) diff --git a/docs/options.rst b/docs/options.rst index 3fe7d6cd..be57b088 100644 --- a/docs/options.rst +++ b/docs/options.rst @@ -31,6 +31,8 @@ General Options: Can be set to ``False`` to disable expiration. ``JWT_ALGORITHM`` Which algorithm to sign the JWT with. `See here `_ for the options. Defaults to ``'HS256'``. +``JWT_DECODE_ALGORITHMS`` Which algorithms are allowed to decode a JWT. + Defaults to a list with only the algorithm set in ``JWT_ALGORITHM``. ``JWT_SECRET_KEY`` The secret key needed for symmetric based signing algorithms, such as ``HS*``. If this is not set, we use the flask ``SECRET_KEY`` value instead. diff --git a/flask_jwt_extended/config.py b/flask_jwt_extended/config.py index f779fbb2..29d20c61 100644 --- a/flask_jwt_extended/config.py +++ b/flask_jwt_extended/config.py @@ -218,6 +218,15 @@ def refresh_expires(self): def algorithm(self): return current_app.config['JWT_ALGORITHM'] + @property + def decode_algorithms(self): + algorithms = current_app.config['JWT_DECODE_ALGORITHMS'] + if not algorithms: + return [self.algorithm] + if self.algorithm not in algorithms: + algorithms.append(self.algorithm) + return algorithms + @property def blacklist_enabled(self): return current_app.config['JWT_BLACKLIST_ENABLED'] diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index 79f6de73..9d42ae5b 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -194,6 +194,9 @@ def _set_default_configuration_options(app): # https://github.com/jpadilla/pyjwt/blob/master/jwt/api_jwt.py app.config.setdefault('JWT_ALGORITHM', 'HS256') + # What algorithms are allowed to decode a token + app.config.setdefault('JWT_DECODE_ALGORITHMS', None) + # Secret key to sign JWTs with. Only used if a symmetric algorithm is # used (such as the HS* algorithms). We will use the app secret key # if this is not set. diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index 33561f3c..11e8b75c 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -112,7 +112,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims json_encoder=json_encoder) -def decode_jwt(encoded_token, secret, algorithm, identity_claim_key, +def decode_jwt(encoded_token, secret, algorithms, identity_claim_key, user_claims_key, csrf_value=None, audience=None, leeway=0, allow_expired=False): """ @@ -120,7 +120,7 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key, :param encoded_token: The encoded JWT string to decode :param secret: Secret key used to encode the JWT - :param algorithm: Algorithm used to encode the JWT + :param algorithms: Algorithms allowed to decode the token :param identity_claim_key: expected key that contains the identity :param user_claims_key: expected key that contains the user claims :param csrf_value: Expected double submit csrf value @@ -134,7 +134,7 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key, options['verify_exp'] = False # This call verifies the ext, iat, nbf, and aud claims - data = jwt.decode(encoded_token, secret, algorithms=[algorithm], audience=audience, + data = jwt.decode(encoded_token, secret, algorithms=algorithms, audience=audience, leeway=leeway, options=options) # Make sure that any custom claims we expect in the token are present diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 2a118c93..e12dc0ea 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -78,7 +78,7 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False): """ jwt_manager = _get_jwt_manager() unverified_claims = jwt.decode( - encoded_token, verify=False, algorithms=config.algorithm + encoded_token, verify=False, algorithms=config.decode_algorithms ) unverified_headers = jwt.get_unverified_header(encoded_token) # Attempt to call callback with both claims and headers, but fallback to just claims @@ -98,7 +98,7 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False): return decode_jwt( encoded_token=encoded_token, secret=secret, - algorithm=config.algorithm, + algorithms=config.decode_algorithms, identity_claim_key=config.identity_claim_key, user_claims_key=config.user_claims_key, csrf_value=csrf_value, @@ -110,7 +110,7 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False): expired_token = decode_jwt( encoded_token=encoded_token, secret=secret, - algorithm=config.algorithm, + algorithms=config.decode_algorithms, identity_claim_key=config.identity_claim_key, user_claims_key=config.user_claims_key, csrf_value=csrf_value, diff --git a/tests/test_config.py b/tests/test_config.py index 672a400b..97eebbd5 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -55,6 +55,7 @@ def test_default_configs(app): assert config.access_expires == timedelta(minutes=15) assert config.refresh_expires == timedelta(days=30) assert config.algorithm == 'HS256' + assert config.decode_algorithms == ['HS256'] assert config.is_asymmetric is False assert config.blacklist_enabled is False assert config.blacklist_checks == ('access', 'refresh') @@ -105,6 +106,7 @@ def test_override_configs(app, delta_func): app.config['JWT_ACCESS_TOKEN_EXPIRES'] = delta_func(minutes=5) app.config['JWT_REFRESH_TOKEN_EXPIRES'] = delta_func(days=5) app.config['JWT_ALGORITHM'] = 'HS512' + app.config['JWT_DECODE_ALGORITHMS'] = ['HS512', 'HS256'] app.config['JWT_BLACKLIST_ENABLED'] = True app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = ('refresh',) @@ -156,6 +158,7 @@ class CustomJSONEncoder(JSONEncoder): assert config.access_expires == delta_func(minutes=5) assert config.refresh_expires == delta_func(days=5) assert config.algorithm == 'HS512' + assert config.decode_algorithms == ['HS512', 'HS256'] assert config.blacklist_enabled is True assert config.blacklist_checks == ('refresh',) @@ -396,3 +399,11 @@ def test_depreciated_options(app): assert len(w) == 2 assert w[0].category == DeprecationWarning assert w[1].category == DeprecationWarning + + +def test_missing_algorithm_in_decode_algorithms(app): + app.config['JWT_ALGORITHM'] = 'RS256' + app.config['JWT_DECODE_ALGORITHMS'] = ['HS512'] + + with app.test_request_context(): + assert config.decode_algorithms == ['HS512', 'RS256']