Skip to content

Commit ef3da3c

Browse files
authored
Add get_jwt_request_location() function to find out where the request JWT was located (#420)
* Add get_jwt_request_location() function to find out where the request JWT was located Sometimes it is desirable to change behavior of a view based on where the JWT was located. For example, if the same route is used for cookie-based access or header-based access, and you want to implicitly refresh cookie-based access tokens. With this change, a protected view can determine which location (e.g., "cookies", "headers", "query_string", or "json") was selected as the source of the current request's JWT. * Fix linting errors * fix linting issues
1 parent fdc0602 commit ef3da3c

File tree

4 files changed

+45
-17
lines changed

4 files changed

+45
-17
lines changed

flask_jwt_extended/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .utils import get_jwt
1010
from .utils import get_jwt_header
1111
from .utils import get_jwt_identity
12+
from .utils import get_jwt_request_location
1213
from .utils import get_unverified_jwt_headers
1314
from .utils import set_access_cookies
1415
from .utils import set_refresh_cookies

flask_jwt_extended/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,21 @@ def get_jwt_identity():
5858
return get_jwt().get(config.identity_claim_key, None)
5959

6060

61+
def get_jwt_request_location():
62+
"""
63+
In a protected endpoint, this will return the "location" at which the JWT
64+
that is accessing the endpoint was found--e.g., "cookies", "query-string",
65+
"headers", or "json". If no JWT is present due to ``jwt_required(optional=True)``,
66+
None is returned.
67+
68+
:return:
69+
The location of the JWT in the current request; e.g., cookies",
70+
"query-string", "headers", or "json"
71+
"""
72+
location = getattr(_request_ctx_stack.top, "jwt_location", None)
73+
return location
74+
75+
6176
def get_current_user():
6277
"""
6378
In a protected endpoint, this will return the user object for the JWT that

flask_jwt_extended/view_decorators.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,24 +60,28 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=
6060

6161
try:
6262
if refresh:
63-
jwt_data, jwt_header = _decode_jwt_from_request(
63+
jwt_data, jwt_header, jwt_location = _decode_jwt_from_request(
6464
locations, fresh, refresh=True
6565
)
6666
else:
67-
jwt_data, jwt_header = _decode_jwt_from_request(locations, fresh)
67+
jwt_data, jwt_header, jwt_location = _decode_jwt_from_request(
68+
locations, fresh
69+
)
6870
except (NoAuthorizationError, InvalidHeaderError):
6971
if not optional:
7072
raise
7173
_request_ctx_stack.top.jwt = {}
7274
_request_ctx_stack.top.jwt_header = {}
7375
_request_ctx_stack.top.jwt_user = {"loaded_user": None}
76+
_request_ctx_stack.top.jwt_location = None
7477
return
7578

7679
# Save these at the very end so that they are only saved in the requet
7780
# context if the token is valid and all callbacks succeed
7881
_request_ctx_stack.top.jwt_user = _load_user(jwt_header, jwt_data)
7982
_request_ctx_stack.top.jwt_header = jwt_header
8083
_request_ctx_stack.top.jwt = jwt_data
84+
_request_ctx_stack.top.jwt_location = jwt_location
8185

8286
return jwt_header, jwt_data
8387

@@ -235,18 +239,23 @@ def _decode_jwt_from_request(locations, fresh, refresh=False):
235239
locations = config.token_location
236240

237241
# Get the decode functions in the order specified by locations.
242+
# Each entry in this list is a tuple (<location>, <encoded-token-function>)
238243
get_encoded_token_functions = []
239244
for location in locations:
240245
if location == "cookies":
241246
get_encoded_token_functions.append(
242-
lambda: _decode_jwt_from_cookies(refresh)
247+
(location, lambda: _decode_jwt_from_cookies(refresh))
243248
)
244249
elif location == "query_string":
245-
get_encoded_token_functions.append(_decode_jwt_from_query_string)
250+
get_encoded_token_functions.append(
251+
(location, _decode_jwt_from_query_string)
252+
)
246253
elif location == "headers":
247-
get_encoded_token_functions.append(_decode_jwt_from_headers)
254+
get_encoded_token_functions.append((location, _decode_jwt_from_headers))
248255
elif location == "json":
249-
get_encoded_token_functions.append(lambda: _decode_jwt_from_json(refresh))
256+
get_encoded_token_functions.append(
257+
(location, lambda: _decode_jwt_from_json(refresh))
258+
)
250259
else:
251260
raise RuntimeError(f"'{location}' is not a valid location")
252261

@@ -255,10 +264,12 @@ def _decode_jwt_from_request(locations, fresh, refresh=False):
255264
errors = []
256265
decoded_token = None
257266
jwt_header = None
258-
for get_encoded_token_function in get_encoded_token_functions:
267+
jwt_location = None
268+
for location, get_encoded_token_function in get_encoded_token_functions:
259269
try:
260270
encoded_token, csrf_token = get_encoded_token_function()
261271
decoded_token = decode_token(encoded_token, csrf_token)
272+
jwt_location = location
262273
jwt_header = get_unverified_jwt_headers(encoded_token)
263274
break
264275
except NoAuthorizationError as e:
@@ -284,4 +295,4 @@ def _decode_jwt_from_request(locations, fresh, refresh=False):
284295
verify_token_not_blocklisted(jwt_header, decoded_token)
285296
custom_verification_for_token(jwt_header, decoded_token)
286297

287-
return decoded_token, jwt_header
298+
return decoded_token, jwt_header, jwt_location

tests/test_multiple_token_locations.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from flask import jsonify
44

55
from flask_jwt_extended import create_access_token
6+
from flask_jwt_extended import get_jwt_request_location
67
from flask_jwt_extended import jwt_required
78
from flask_jwt_extended import JWTManager
89
from flask_jwt_extended import set_access_cookies
@@ -25,7 +26,7 @@ def cookie_login():
2526
@app.route("/protected", methods=["GET", "POST"])
2627
@jwt_required()
2728
def access_protected():
28-
return jsonify(foo="bar")
29+
return jsonify(foo="bar", location=get_jwt_request_location())
2930

3031
return app
3132

@@ -48,7 +49,7 @@ def cookie_login():
4849
@app.route("/protected", methods=["GET", "POST"])
4950
@jwt_required(locations=locations)
5051
def access_protected():
51-
return jsonify(foo="bar")
52+
return jsonify(foo="bar", location=get_jwt_request_location())
5253

5354
return app
5455

@@ -62,7 +63,7 @@ def test_header_access(app, app_with_locations):
6263
access_headers = {"Authorization": "Bearer {}".format(access_token)}
6364
response = test_client.get("/protected", headers=access_headers)
6465
assert response.status_code == 200
65-
assert response.get_json() == {"foo": "bar"}
66+
assert response.get_json() == {"foo": "bar", "location": "headers"}
6667

6768

6869
def test_cookie_access(app, app_with_locations):
@@ -71,7 +72,7 @@ def test_cookie_access(app, app_with_locations):
7172
test_client.get("/cookie_login")
7273
response = test_client.get("/protected")
7374
assert response.status_code == 200
74-
assert response.get_json() == {"foo": "bar"}
75+
assert response.get_json() == {"foo": "bar", "location": "cookies"}
7576

7677

7778
def test_query_string_access(app, app_with_locations):
@@ -83,7 +84,7 @@ def test_query_string_access(app, app_with_locations):
8384
url = "/protected?jwt={}".format(access_token)
8485
response = test_client.get(url)
8586
assert response.status_code == 200
86-
assert response.get_json() == {"foo": "bar"}
87+
assert response.get_json() == {"foo": "bar", "location": "query_string"}
8788

8889

8990
def test_json_access(app, app_with_locations):
@@ -94,7 +95,7 @@ def test_json_access(app, app_with_locations):
9495
data = {"access_token": access_token}
9596
response = test_client.post("/protected", json=data)
9697
assert response.status_code == 200
97-
assert response.get_json() == {"foo": "bar"}
98+
assert response.get_json() == {"foo": "bar", "location": "json"}
9899

99100

100101
@pytest.mark.parametrize(
@@ -129,8 +130,8 @@ def test_no_jwt_in_request(app, options):
129130
@pytest.mark.parametrize(
130131
"options",
131132
[
132-
(["cookies", "headers"], 200, None, {"foo": "bar"}),
133-
(["headers", "cookies"], 200, None, {"foo": "bar"}),
133+
(["cookies", "headers"], 200, None, {"foo": "bar", "location": "cookies"}),
134+
(["headers", "cookies"], 200, None, {"foo": "bar", "location": "cookies"}),
134135
],
135136
)
136137
def test_order_of_jwt_locations_in_request(app, options):
@@ -151,7 +152,7 @@ def test_order_of_jwt_locations_in_request(app, options):
151152
@pytest.mark.parametrize(
152153
"options",
153154
[
154-
(["cookies", "headers"], 200, None, {"foo": "bar"}),
155+
(["cookies", "headers"], 200, None, {"foo": "bar", "location": "cookies"}),
155156
(["headers", "cookies"], 422, ("Invalid header padding"), None),
156157
],
157158
)

0 commit comments

Comments
 (0)