diff --git a/docs/options.rst b/docs/options.rst index 3a3c3824..4b9e66ab 100644 --- a/docs/options.rst +++ b/docs/options.rst @@ -44,6 +44,9 @@ General Options: ``JWT_ERROR_MESSAGE_KEY`` The key of the error message in a JSON error response when using the default error handlers. Defaults to ``'msg'``. +``JWT_DECODE_AUDIENCE`` The audience you expect in a JWT when decoding it. + If this option differs from the 'aud' claim in a JWT, the ``'invalid_token_callback'`` is invoked. + Defaults to ``'None'``. ================================= ========================================= diff --git a/flask_jwt_extended/config.py b/flask_jwt_extended/config.py index 0fcd2f12..71e05b45 100644 --- a/flask_jwt_extended/config.py +++ b/flask_jwt_extended/config.py @@ -278,5 +278,9 @@ def error_msg_key(self): def json_encoder(self): return current_app.json_encoder + @property + def audience(self): + return current_app.config['JWT_DECODE_AUDIENCE'] + config = _Config() diff --git a/flask_jwt_extended/default_callbacks.py b/flask_jwt_extended/default_callbacks.py index b80bad60..7c85e8cc 100644 --- a/flask_jwt_extended/default_callbacks.py +++ b/flask_jwt_extended/default_callbacks.py @@ -103,7 +103,7 @@ def default_verify_claims_failed_callback(): return jsonify({config.error_msg_key: 'User claims verification failed'}), 400 -def default_decode_key_callback(claims): +def default_decode_key_callback(claims, headers): """ By default, the decode key specified via the JWT_SECRET_KEY or JWT_PUBLIC_KEY settings will be used to decode all tokens diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index f37dc8fd..d60fd05b 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -1,6 +1,6 @@ import datetime -from jwt import ExpiredSignatureError, InvalidTokenError +from jwt import ExpiredSignatureError, InvalidTokenError, InvalidAudienceError from flask_jwt_extended.config import config from flask_jwt_extended.exceptions import ( @@ -108,6 +108,10 @@ def handle_jwt_decode_error(e): def handle_wrong_token_error(e): return self._invalid_token_callback(str(e)) + @app.errorhandler(InvalidAudienceError) + def handle_invalid_audience_error(e): + return self._invalid_token_callback(str(e)) + @app.errorhandler(RevokedTokenError) def handle_revoked_token_error(e): return self._revoked_token_callback() @@ -192,6 +196,7 @@ def _set_default_configuration_options(app): app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity') app.config.setdefault('JWT_USER_CLAIMS', 'user_claims') + app.config.setdefault('JWT_DECODE_AUDIENCE', None) app.config.setdefault('JWT_CLAIMS_IN_REFRESH_TOKEN', False) @@ -390,9 +395,10 @@ def decode_key_loader(self, callback): The default implementation returns the decode key specified by `JWT_SECRET_KEY` or `JWT_PUBLIC_KEY`, depending on the signing algorithm. - *HINT*: The callback function must be a function that takes only **one** argument, - which is the unverified claims of the jwt (dictionary) and must return a *string* - which is the decode key to verify the token. + *HINT*: The callback function should be a function that takes + **two** arguments, which are the unverified claims and headers of the jwt + (dictionaries). The function must return a *string* which is the decode key + in PEM format to verify the token. """ self._decode_key_callback = callback return callback diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index 7f500d9b..f990407f 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -15,7 +15,7 @@ def _create_csrf_token(): def _encode_jwt(additional_token_data, expires_delta, secret, algorithm, json_encoder=None): - uid = str(uuid.uuid4()) + uid = _create_csrf_token() now = datetime.datetime.utcnow() token_data = { 'iat': now, @@ -113,7 +113,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims def decode_jwt(encoded_token, secret, algorithm, identity_claim_key, - user_claims_key, csrf_value=None): + user_claims_key, csrf_value=None, audience=None): """ Decodes an encoded JWT @@ -123,21 +123,24 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key, :param identity_claim_key: expected key that contains the identity :param user_claims_key: expected key that contains the user claims :param csrf_value: Expected double submit csrf value + :param audience: expected audience in the JWT :return: Dictionary containing contents of the JWT """ - # This call verifies the ext, iat, and nbf claims - data = jwt.decode(encoded_token, secret, algorithms=[algorithm]) + # This call verifies the ext, iat, nbf, and aud claims + data = jwt.decode(encoded_token, secret, algorithms=[algorithm], audience=audience) # Make sure that any custom claims we expect in the token are present if 'jti' not in data: - raise JWTDecodeError("Missing claim: jti") + data['jti'] = None if identity_claim_key not in data: raise JWTDecodeError("Missing claim: {}".format(identity_claim_key)) - if 'type' not in data or data['type'] not in ('refresh', 'access'): + if 'type' not in data: + data['type'] = 'access' + if data['type'] not in ('refresh', 'access'): raise JWTDecodeError("Missing or invalid claim: type") if data['type'] == 'access': if 'fresh' not in data: - raise JWTDecodeError("Missing claim: fresh") + data['fresh'] = False if user_claims_key not in data: data[user_claims_key] = {} if csrf_value: diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 29141be3..9a8d9d9d 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -1,5 +1,6 @@ from flask import current_app from werkzeug.local import LocalProxy +from warnings import warn try: from flask import _app_ctx_stack as ctx_stack @@ -76,14 +77,26 @@ def decode_token(encoded_token, csrf_value=None): unverified_claims = jwt.decode( encoded_token, verify=False, algorithms=config.algorithm ) - secret = jwt_manager._decode_key_callback(unverified_claims) + unverified_headers = jwt.get_unverified_header(encoded_token) + # Attempt to call callback with both claims and headers, but fallback to just claims + # for backwards compatibility + try: + secret = jwt_manager._decode_key_callback(unverified_claims, unverified_headers) + except TypeError: + msg = ( + "The single-argument (unverified_claims) form of decode_key_callback is deprecated. " + "Update your code to use the two-argument form (unverified_claims, unverified_headers)." + ) + warn(msg, DeprecationWarning) + secret = jwt_manager._decode_key_callback(unverified_claims) return decode_jwt( encoded_token=encoded_token, secret=secret, algorithm=config.algorithm, identity_claim_key=config.identity_claim_key, user_claims_key=config.user_claims_key, - csrf_value=csrf_value + csrf_value=csrf_value, + audience=config.audience ) diff --git a/tests/test_decode_tokens.py b/tests/test_decode_tokens.py index 65718581..adf37e0b 100644 --- a/tests/test_decode_tokens.py +++ b/tests/test_decode_tokens.py @@ -1,9 +1,10 @@ import jwt import pytest -from datetime import timedelta +from datetime import datetime, timedelta +import warnings from flask import Flask -from jwt import ExpiredSignatureError, InvalidSignatureError +from jwt import ExpiredSignatureError, InvalidSignatureError, InvalidAudienceError from flask_jwt_extended import ( JWTManager, create_access_token, decode_token, create_refresh_token, @@ -54,9 +55,9 @@ def empty_user_loader_return(identity): assert config.user_claims_key in extension_decoded -@pytest.mark.parametrize("missing_claim", ['jti', 'type', 'identity', 'fresh', 'csrf']) -def test_missing_jti_claim(app, default_access_token, missing_claim): - del default_access_token[missing_claim] +@pytest.mark.parametrize("missing_claims", ['identity', 'csrf']) +def test_missing_claims(app, default_access_token, missing_claims): + del default_access_token[missing_claims] missing_jwt_token = encode_token(app, default_access_token) with pytest.raises(JWTDecodeError): @@ -64,6 +65,19 @@ def test_missing_jti_claim(app, default_access_token, missing_claim): decode_token(missing_jwt_token, csrf_value='abcd') +def test_default_decode_token_values(app, default_access_token): + del default_access_token['type'] + del default_access_token['jti'] + del default_access_token['fresh'] + token = encode_token(app, default_access_token) + + with app.test_request_context(): + decoded = decode_token(token) + assert decoded['type'] == 'access' + assert decoded['jti'] is None + assert decoded['fresh'] is False + + def test_bad_token_type(app, default_access_token): default_access_token['type'] = 'banana' bad_type_token = encode_token(app, default_access_token) @@ -123,19 +137,36 @@ def test_encode_decode_callback_values(app, default_access_token): jwtM = get_jwt_manager(app) app.config['JWT_SECRET_KEY'] = 'foobarbaz' with app.test_request_context(): - assert jwtM._decode_key_callback({}) == 'foobarbaz' + assert jwtM._decode_key_callback({}, {}) == 'foobarbaz' assert jwtM._encode_key_callback({}) == 'foobarbaz' - @jwtM.decode_key_loader - def get_decode_key_1(claims): + @jwtM.encode_key_loader + def get_encode_key_1(identity): return 'different secret' + assert jwtM._encode_key_callback('') == 'different secret' - @jwtM.encode_key_loader - def get_decode_key_2(identity): + @jwtM.decode_key_loader + def get_decode_key_1(claims, headers): return 'different secret' + assert jwtM._decode_key_callback({}, {}) == 'different secret' - assert jwtM._decode_key_callback({}) == 'different secret' - assert jwtM._encode_key_callback('') == 'different secret' + +def test_legacy_decode_key_callback(app, default_access_token): + jwtM = get_jwt_manager(app) + app.config['JWT_SECRET_KEY'] = 'foobarbaz' + + # test decode key callback with one argument (backwards compatibility) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @jwtM.decode_key_loader + def get_decode_key_legacy(claims): + return 'foobarbaz' + with app.test_request_context(): + token = encode_token(app, default_access_token) + decode_token(token) + assert len(w) == 1 + assert issubclass(w[-1].category, DeprecationWarning) def test_custom_encode_decode_key_callbacks(app, default_access_token): @@ -157,7 +188,7 @@ def get_encode_key_1(identity): decode_token(token) @jwtM.decode_key_loader - def get_decode_key_1(claims): + def get_decode_key_1(claims, headers): assert claims['identity'] == 'username' return 'different secret' @@ -166,3 +197,19 @@ def get_decode_key_1(claims): decode_token(token) token = create_refresh_token('username') decode_token(token) + + +def test_valid_aud(app, default_access_token): + app.config['JWT_DECODE_AUDIENCE'] = 'foo' + + default_access_token['aud'] = 'bar' + invalid_token = encode_token(app, default_access_token) + with pytest.raises(InvalidAudienceError): + with app.test_request_context(): + decode_token(invalid_token) + + default_access_token['aud'] = 'foo' + valid_token = encode_token(app, default_access_token) + with app.test_request_context(): + decoded = decode_token(valid_token) + assert decoded['aud'] == 'foo' diff --git a/tests/test_view_decorators.py b/tests/test_view_decorators.py index 946922aa..ae366075 100644 --- a/tests/test_view_decorators.py +++ b/tests/test_view_decorators.py @@ -207,7 +207,31 @@ def test_jwt_missing_claims(app): response = test_client.get(url, headers=make_headers(token)) assert response.status_code == 422 - assert response.get_json() == {'msg': 'Missing claim: jti'} + assert response.get_json() == {'msg': 'Missing claim: identity'} + + +def test_jwt_invalid_audience(app): + url = '/protected' + jwtM = get_jwt_manager(app) + test_client = app.test_client() + + # No audience claim expected or provided - OK + access_token = encode_token(app, {'identity': 'me'}) + response = test_client.get(url, headers=make_headers(access_token)) + assert response.status_code == 200 + + # Audience claim expected and not provided - not OK + app.config['JWT_DECODE_AUDIENCE'] = 'my_audience' + access_token = encode_token(app, {'identity': 'me'}) + response = test_client.get(url, headers=make_headers(access_token)) + assert response.status_code == 422 + assert response.get_json() == {'msg': 'Token is missing the "aud" claim'} + + # Audience claim still expected and wrong one provided - not OK + access_token = encode_token(app, {'aud': 'different_audience', 'identity': 'me'}) + response = test_client.get(url, headers=make_headers(access_token)) + assert response.status_code == 422 + assert response.get_json() == {'msg': 'Invalid audience'} def test_expired_token(app):