diff --git a/flask_jwt_extended/internal_utils.py b/flask_jwt_extended/internal_utils.py index 2527e971..c7ee8c21 100644 --- a/flask_jwt_extended/internal_utils.py +++ b/flask_jwt_extended/internal_utils.py @@ -25,12 +25,14 @@ def user_lookup(*args, **kwargs): return jwt_manager._user_lookup_callback(*args, **kwargs) -def verify_token_type(decoded_token, expected_type): - if decoded_token["type"] != expected_type: - raise WrongTokenError("Only {} tokens are allowed".format(expected_type)) +def verify_token_type(decoded_token, refresh): + if not refresh and decoded_token["type"] == "refresh": + raise WrongTokenError("Only non-refresh tokens are allowed") + elif refresh and decoded_token["type"] != "refresh": + raise WrongTokenError("Only refresh tokens are allowed") -def verify_token_not_blocklisted(jwt_header, jwt_data, request_type): +def verify_token_not_blocklisted(jwt_header, jwt_data): jwt_manager = get_jwt_manager() if jwt_manager._token_in_blocklist_callback(jwt_header, jwt_data): raise RevokedTokenError(jwt_header, jwt_data) diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index ce0aff17..3fcfa473 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -98,9 +98,6 @@ def _decode_jwt( if "type" not in decoded_token: decoded_token["type"] = "access" - if decoded_token["type"] not in ("access", "refresh"): - raise JWTDecodeError("Invalid token type: {}".format(decoded_token["type"])) - if "fresh" not in decoded_token: decoded_token["fresh"] = False diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 0a4cc7f0..6a4b864e 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -47,8 +47,7 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations= Defaults to ``False``. :param refresh: - If ``True``, require a refresh JWT to be verified. If ``False`` require an access - JWT to be verified. Defaults to ``False``. + If ``True``, require a refresh JWT to be verified. :param locations: A list of locations to look for the JWT in this request, for example: @@ -61,9 +60,11 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations= try: if refresh: - jwt_data, jwt_header = _decode_jwt_from_request("refresh", locations, fresh) + jwt_data, jwt_header = _decode_jwt_from_request( + locations, fresh, refresh=True + ) else: - jwt_data, jwt_header = _decode_jwt_from_request("access", locations, fresh) + jwt_data, jwt_header = _decode_jwt_from_request(locations, fresh) except (NoAuthorizationError, InvalidHeaderError): if not optional: raise @@ -170,15 +171,15 @@ def _decode_jwt_from_headers(): return encoded_token, None -def _decode_jwt_from_cookies(token_type): - if token_type == "access": - cookie_key = config.access_cookie_name - csrf_header_key = config.access_csrf_header_name - csrf_field_key = config.access_csrf_field_name - else: +def _decode_jwt_from_cookies(refresh): + if refresh: cookie_key = config.refresh_cookie_name csrf_header_key = config.refresh_csrf_header_name csrf_field_key = config.refresh_csrf_field_name + else: + cookie_key = config.access_cookie_name + csrf_header_key = config.access_csrf_header_name + csrf_field_key = config.access_csrf_field_name encoded_token = request.cookies.get(cookie_key) if not encoded_token: @@ -205,15 +206,15 @@ def _decode_jwt_from_query_string(): return encoded_token, None -def _decode_jwt_from_json(token_type): +def _decode_jwt_from_json(refresh): content_type = request.content_type or "" if not content_type.startswith("application/json"): raise NoAuthorizationError("Invalid content-type. Must be application/json.") - if token_type == "access": - token_key = config.json_key - else: + if refresh: token_key = config.refresh_json_key + else: + token_key = config.json_key try: encoded_token = request.json.get(token_key, None) @@ -225,7 +226,7 @@ def _decode_jwt_from_json(token_type): return encoded_token, None -def _decode_jwt_from_request(token_type, locations, fresh): +def _decode_jwt_from_request(locations, fresh, refresh=False): # All the places we can get a JWT from in this request get_encoded_token_functions = [] @@ -238,16 +239,14 @@ def _decode_jwt_from_request(token_type, locations, fresh): for location in locations: if location == "cookies": get_encoded_token_functions.append( - lambda: _decode_jwt_from_cookies(token_type) + lambda: _decode_jwt_from_cookies(refresh) ) if location == "query_string": get_encoded_token_functions.append(_decode_jwt_from_query_string) if location == "headers": get_encoded_token_functions.append(_decode_jwt_from_headers) if location == "json": - get_encoded_token_functions.append( - lambda: _decode_jwt_from_json(token_type) - ) + get_encoded_token_functions.append(lambda: _decode_jwt_from_json(refresh)) # Try to find the token from one of these locations. It only needs to exist # in one place to be valid (not every location). @@ -277,10 +276,10 @@ def _decode_jwt_from_request(token_type, locations, fresh): raise NoAuthorizationError(errors[0]) # Additional verifications provided by this extension - verify_token_type(decoded_token, expected_type=token_type) + verify_token_type(decoded_token, refresh) if fresh: _verify_token_is_fresh(jwt_header, decoded_token) - verify_token_not_blocklisted(jwt_header, decoded_token, token_type) + verify_token_not_blocklisted(jwt_header, decoded_token) custom_verification_for_token(jwt_header, decoded_token) return decoded_token, jwt_header diff --git a/tests/test_decode_tokens.py b/tests/test_decode_tokens.py index 458a1944..8c1f0e5f 100644 --- a/tests/test_decode_tokens.py +++ b/tests/test_decode_tokens.py @@ -70,13 +70,13 @@ def test_default_decode_token_values(app, default_access_token): assert decoded["fresh"] is False -def test_bad_token_type(app, default_access_token): - default_access_token["type"] = "banana" - bad_type_token = encode_token(app, default_access_token) +def test_supports_decoding_other_token_types(app, default_access_token): + default_access_token["type"] = "app" + other_token = encode_token(app, default_access_token) - with pytest.raises(JWTDecodeError): - with app.test_request_context(): - decode_token(bad_type_token) + with app.test_request_context(): + decoded = decode_token(other_token) + assert decoded["type"] == "app" def test_encode_decode_audience(app): diff --git a/tests/test_view_decorators.py b/tests/test_view_decorators.py index 3ba99856..53f715f6 100644 --- a/tests/test_view_decorators.py +++ b/tests/test_view_decorators.py @@ -72,7 +72,7 @@ def test_jwt_required(app): # Test refresh token access to jwt_required response = test_client.get(url, headers=make_headers(refresh_token)) assert response.status_code == 422 - assert response.get_json() == {"msg": "Only access tokens are allowed"} + assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"} def test_fresh_jwt_required(app): @@ -113,7 +113,7 @@ def test_fresh_jwt_required(app): response = test_client.get(url, headers=make_headers(refresh_token)) assert response.status_code == 422 - assert response.get_json() == {"msg": "Only access tokens are allowed"} + assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"} # Test with custom response @jwtM.needs_fresh_token_loader @@ -176,7 +176,7 @@ def test_jwt_optional(app, delta_func): response = test_client.get(url, headers=make_headers(refresh_token)) assert response.status_code == 422 - assert response.get_json() == {"msg": "Only access tokens are allowed"} + assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"} response = test_client.get(url, headers=None) assert response.status_code == 200