From 2471147f15384b8cc8d0a52b03495684c2bcc138 Mon Sep 17 00:00:00 2001 From: Sam McKelvie Date: Thu, 29 Apr 2021 13:59:34 -0700 Subject: [PATCH 1/3] Add get_jwt_request_location() function to find out where the request JWT was located Sometimes it is desirable to change behavior of a view based on where the JWT was located. For example, if the same route is used for cookie-based access or header-based access, and you want to implicitly refresh cookie-based access tokens. With this change, a protected view can determine which location (e.g., "cookies", "headers", "query_string", or "json") was selected as the source of the current request's JWT. --- flask_jwt_extended/__init__.py | 1 + flask_jwt_extended/utils.py | 15 +++++++++++++++ flask_jwt_extended/view_decorators.py | 24 ++++++++++++++---------- tests/test_multiple_token_locations.py | 19 ++++++++++--------- 4 files changed, 40 insertions(+), 19 deletions(-) diff --git a/flask_jwt_extended/__init__.py b/flask_jwt_extended/__init__.py index 3386ebcb..528baecb 100644 --- a/flask_jwt_extended/__init__.py +++ b/flask_jwt_extended/__init__.py @@ -7,6 +7,7 @@ from .utils import get_current_user from .utils import get_jti from .utils import get_jwt +from .utils import get_jwt_request_location from .utils import get_jwt_header from .utils import get_jwt_identity from .utils import get_unverified_jwt_headers diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 28753e00..6eb6289f 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -58,6 +58,21 @@ def get_jwt_identity(): return get_jwt().get(config.identity_claim_key, None) +def get_jwt_request_location(): + """ + In a protected endpoint, this will return the "location" at which the JWT + that is accessing the endpoint was found--e.g., "cookies", "query-string", + "headers", or "json". If no JWT is present due to ``jwt_required(optional=True)``, + None is returned. + + :return: + The location of the JWT in the current request; e.g., cookies", + "query-string", "headers", or "json" + """ + location = getattr(_request_ctx_stack.top, "jwt_location", None) + return location + + def get_current_user(): """ In a protected endpoint, this will return the user object for the JWT that diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 9419fe80..07c0c126 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -60,17 +60,18 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations= try: if refresh: - jwt_data, jwt_header = _decode_jwt_from_request( + jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( locations, fresh, refresh=True ) else: - jwt_data, jwt_header = _decode_jwt_from_request(locations, fresh) + jwt_data, jwt_header, jwt_location = _decode_jwt_from_request(locations, fresh) except (NoAuthorizationError, InvalidHeaderError): if not optional: raise _request_ctx_stack.top.jwt = {} _request_ctx_stack.top.jwt_header = {} _request_ctx_stack.top.jwt_user = {"loaded_user": None} + _request_ctx_stack.top.jwt_location = None return # Save these at the very end so that they are only saved in the requet @@ -78,6 +79,7 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations= _request_ctx_stack.top.jwt_user = _load_user(jwt_header, jwt_data) _request_ctx_stack.top.jwt_header = jwt_header _request_ctx_stack.top.jwt = jwt_data + _request_ctx_stack.top.jwt_location = jwt_location return jwt_header, jwt_data @@ -235,30 +237,32 @@ def _decode_jwt_from_request(locations, fresh, refresh=False): locations = config.token_location # Get the decode functions in the order specified by locations. + # Each entry in this list is a tuple (, ) get_encoded_token_functions = [] for location in locations: if location == "cookies": - get_encoded_token_functions.append( - lambda: _decode_jwt_from_cookies(refresh) - ) + fn = lambda: _decode_jwt_from_cookies(refresh) elif location == "query_string": - get_encoded_token_functions.append(_decode_jwt_from_query_string) + fn = _decode_jwt_from_query_string elif location == "headers": - get_encoded_token_functions.append(_decode_jwt_from_headers) + fn = _decode_jwt_from_headers elif location == "json": - get_encoded_token_functions.append(lambda: _decode_jwt_from_json(refresh)) + fn = lambda: _decode_jwt_from_json(refresh) else: raise RuntimeError(f"'{location}' is not a valid location") + get_encoded_token_functions.append((location, fn)) # Try to find the token from one of these locations. It only needs to exist # in one place to be valid (not every location). errors = [] decoded_token = None jwt_header = None - for get_encoded_token_function in get_encoded_token_functions: + jwt_location = None + for location, get_encoded_token_function in get_encoded_token_functions: try: encoded_token, csrf_token = get_encoded_token_function() decoded_token = decode_token(encoded_token, csrf_token) + jwt_location = location jwt_header = get_unverified_jwt_headers(encoded_token) break except NoAuthorizationError as e: @@ -284,4 +288,4 @@ def _decode_jwt_from_request(locations, fresh, refresh=False): verify_token_not_blocklisted(jwt_header, decoded_token) custom_verification_for_token(jwt_header, decoded_token) - return decoded_token, jwt_header + return decoded_token, jwt_header, jwt_location diff --git a/tests/test_multiple_token_locations.py b/tests/test_multiple_token_locations.py index 1016a173..a1145420 100644 --- a/tests/test_multiple_token_locations.py +++ b/tests/test_multiple_token_locations.py @@ -6,6 +6,7 @@ from flask_jwt_extended import jwt_required from flask_jwt_extended import JWTManager from flask_jwt_extended import set_access_cookies +from flask_jwt_extended import get_jwt_request_location @pytest.fixture(scope="function") @@ -25,7 +26,7 @@ def cookie_login(): @app.route("/protected", methods=["GET", "POST"]) @jwt_required() def access_protected(): - return jsonify(foo="bar") + return jsonify(foo="bar", location=get_jwt_request_location()) return app @@ -48,7 +49,7 @@ def cookie_login(): @app.route("/protected", methods=["GET", "POST"]) @jwt_required(locations=locations) def access_protected(): - return jsonify(foo="bar") + return jsonify(foo="bar", location=get_jwt_request_location()) return app @@ -62,7 +63,7 @@ def test_header_access(app, app_with_locations): access_headers = {"Authorization": "Bearer {}".format(access_token)} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 200 - assert response.get_json() == {"foo": "bar"} + assert response.get_json() == {"foo": "bar", "location": "headers"} def test_cookie_access(app, app_with_locations): @@ -71,7 +72,7 @@ def test_cookie_access(app, app_with_locations): test_client.get("/cookie_login") response = test_client.get("/protected") assert response.status_code == 200 - assert response.get_json() == {"foo": "bar"} + assert response.get_json() == {"foo": "bar", "location": "cookies"} def test_query_string_access(app, app_with_locations): @@ -83,7 +84,7 @@ def test_query_string_access(app, app_with_locations): url = "/protected?jwt={}".format(access_token) response = test_client.get(url) assert response.status_code == 200 - assert response.get_json() == {"foo": "bar"} + assert response.get_json() == {"foo": "bar", "location": "query_string"} def test_json_access(app, app_with_locations): @@ -94,7 +95,7 @@ def test_json_access(app, app_with_locations): data = {"access_token": access_token} response = test_client.post("/protected", json=data) assert response.status_code == 200 - assert response.get_json() == {"foo": "bar"} + assert response.get_json() == {"foo": "bar", "location": "json"} @pytest.mark.parametrize( @@ -129,8 +130,8 @@ def test_no_jwt_in_request(app, options): @pytest.mark.parametrize( "options", [ - (["cookies", "headers"], 200, None, {"foo": "bar"}), - (["headers", "cookies"], 200, None, {"foo": "bar"}), + (["cookies", "headers"], 200, None, {"foo": "bar", "location": "cookies"}), + (["headers", "cookies"], 200, None, {"foo": "bar", "location": "cookies"}), ], ) def test_order_of_jwt_locations_in_request(app, options): @@ -151,7 +152,7 @@ def test_order_of_jwt_locations_in_request(app, options): @pytest.mark.parametrize( "options", [ - (["cookies", "headers"], 200, None, {"foo": "bar"}), + (["cookies", "headers"], 200, None, {"foo": "bar", "location": "cookies"}), (["headers", "cookies"], 422, ("Invalid header padding"), None), ], ) From fa2619ce9ca49955c667b79865d1a1809ed5369a Mon Sep 17 00:00:00 2001 From: Sam McKelvie Date: Thu, 29 Apr 2021 18:54:35 -0700 Subject: [PATCH 2/3] Fix linting errors --- flask_jwt_extended/__init__.py | 2 +- flask_jwt_extended/view_decorators.py | 4 +++- tests/test_multiple_token_locations.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/flask_jwt_extended/__init__.py b/flask_jwt_extended/__init__.py index 528baecb..6e9069bc 100644 --- a/flask_jwt_extended/__init__.py +++ b/flask_jwt_extended/__init__.py @@ -7,9 +7,9 @@ from .utils import get_current_user from .utils import get_jti from .utils import get_jwt -from .utils import get_jwt_request_location from .utils import get_jwt_header from .utils import get_jwt_identity +from .utils import get_jwt_request_location from .utils import get_unverified_jwt_headers from .utils import set_access_cookies from .utils import set_refresh_cookies diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 07c0c126..70d0df96 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -64,7 +64,9 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations= locations, fresh, refresh=True ) else: - jwt_data, jwt_header, jwt_location = _decode_jwt_from_request(locations, fresh) + jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( + locations, fresh + ) except (NoAuthorizationError, InvalidHeaderError): if not optional: raise diff --git a/tests/test_multiple_token_locations.py b/tests/test_multiple_token_locations.py index a1145420..ac551d29 100644 --- a/tests/test_multiple_token_locations.py +++ b/tests/test_multiple_token_locations.py @@ -3,10 +3,10 @@ from flask import jsonify from flask_jwt_extended import create_access_token +from flask_jwt_extended import get_jwt_request_location from flask_jwt_extended import jwt_required from flask_jwt_extended import JWTManager from flask_jwt_extended import set_access_cookies -from flask_jwt_extended import get_jwt_request_location @pytest.fixture(scope="function") From abe8a97f0f717950f6710f585768fb06bab15300 Mon Sep 17 00:00:00 2001 From: Sam McKelvie Date: Thu, 29 Apr 2021 19:17:29 -0700 Subject: [PATCH 3/3] fix linting issues --- flask_jwt_extended/view_decorators.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 70d0df96..35c1bf01 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -243,16 +243,21 @@ def _decode_jwt_from_request(locations, fresh, refresh=False): get_encoded_token_functions = [] for location in locations: if location == "cookies": - fn = lambda: _decode_jwt_from_cookies(refresh) + get_encoded_token_functions.append( + (location, lambda: _decode_jwt_from_cookies(refresh)) + ) elif location == "query_string": - fn = _decode_jwt_from_query_string + get_encoded_token_functions.append( + (location, _decode_jwt_from_query_string) + ) elif location == "headers": - fn = _decode_jwt_from_headers + get_encoded_token_functions.append((location, _decode_jwt_from_headers)) elif location == "json": - fn = lambda: _decode_jwt_from_json(refresh) + get_encoded_token_functions.append( + (location, lambda: _decode_jwt_from_json(refresh)) + ) else: raise RuntimeError(f"'{location}' is not a valid location") - get_encoded_token_functions.append((location, fn)) # Try to find the token from one of these locations. It only needs to exist # in one place to be valid (not every location).