diff --git a/docs/changing_default_behavior.rst b/docs/changing_default_behavior.rst index 60b233a2..796bbc96 100644 --- a/docs/changing_default_behavior.rst +++ b/docs/changing_default_behavior.rst @@ -37,6 +37,12 @@ Possible loader functions are: * - **revoked_token_loader** - Function to call when a revoked token accesses a protected endpoint - None + * - **user_loader_callback_loader** + - Function to call to load a user object from a token + - Takes one argument - The identity of the token to load a user from + * - **user_loader_error_loader** + - Function that is called when the user_loader callback function returns **None** + - Takes one argument - The identity of the user who failed to load Dynamic token expires time ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/complex_objects_from_token.rst b/docs/complex_objects_from_token.rst new file mode 100644 index 00000000..35453920 --- /dev/null +++ b/docs/complex_objects_from_token.rst @@ -0,0 +1,19 @@ +Complex Objects from Tokens +=========================== + +We can also do the inverse of creating tokens from complex objects like we did +in the last section. In this case, we can take a token and every time a +protected endpoint is accessed automatically use the token to load a complex +object, for example a SQLAlchemy user object. Here's an example of how it +might look: + +.. literalinclude:: ../examples/complex_objects_from_tokens.py + +If you do not provide a user_loader_callback in your application, and attempt +to access the **current_user** LocalProxy, it will simply be None. + +One thing to note with this is that you will now call the **user_loader_callback** +on all of your protected endpoints, which will probably incur the cost of a +database lookup. In most cases this likely isn't a big deal for your application, +but do be aware that it could slow things down if your frontend is doing several +calls to endpoints in rapid succession. diff --git a/docs/index.rst b/docs/index.rst index eeff3aaf..935a400b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -16,6 +16,7 @@ documentation is coming soon! basic_usage add_custom_data_claims tokens_from_complex_object + complex_objects_from_token refresh_tokens token_freshness changing_default_behavior diff --git a/examples/complex_objects_from_tokens.py b/examples/complex_objects_from_tokens.py new file mode 100644 index 00000000..a1f7bd82 --- /dev/null +++ b/examples/complex_objects_from_tokens.py @@ -0,0 +1,74 @@ +from flask import Flask, jsonify, request +from flask_jwt_extended import ( + JWTManager, jwt_required, create_access_token, current_user +) + +app = Flask(__name__) +app.secret_key = 'super-secret' # Change this! +jwt = JWTManager(app) + + +# A user object that we will load our tokens +class UserObject: + def __init__(self, username, roles): + self.username = username + self.roles = roles + +# An example store of users. In production, this would likely +# be a sqlalchemy instance or something similiar +users_to_roles = { + 'foo': ['admin'], + 'bar': ['peasant'], + 'baz': ['peasant'] +} + + +# This function is called whenever a protected endpoint is accessed. +# This should return a complex object based on the token identity. +# This is called after the token is verified, so you can use +# get_jwt_claims() in here if desired. Note that this needs to +# return None if the user could not be loaded for any reason, +# such as not being found in the underlying data store +@jwt.user_loader_callback_loader +def user_loader_callback(identity): + if identity not in users_to_roles: + return None + + return UserObject( + username=identity, + roles=users_to_roles[identity] + ) + + +# You can override the error returned to the user if the +# user_loader_callback returns None. By default, if you don't +# override this, it will return a 401 status code with the json: +# {'msg': "Error loading the user "}. You can use +# get_jwt_claims() here too if desired +@jwt.user_loader_error_loader +def custom_user_loader_error(identity): + return jsonify({"msg": "User not found"}), 404 + + +# Create a token for any user, so this can be tested out +@app.route('/login', methods=['POST']) +def login(): + username = request.json.get('username', None) + access_token = create_access_token(identity=username) + ret = {'access_token': access_token} + return jsonify(ret), 200 + + +# If the user_loader_callback returns None, this method will +# not get hit, even if the access token is valid. You can +# access the loaded user via the ``current_user``` LocalProxy, +# or with the ```get_current_user()``` method +@app.route('/admin-only', methods=['GET']) +@jwt_required +def protected(): + if 'admin' not in current_user.roles: + return jsonify({"msg": "Forbidden"}), 403 + return jsonify({"secret_msg": "don't forget to drink your ovaltine"}) + +if __name__ == '__main__': + app.run() diff --git a/flask_jwt_extended/__init__.py b/flask_jwt_extended/__init__.py index 46867ac7..4dac6577 100644 --- a/flask_jwt_extended/__init__.py +++ b/flask_jwt_extended/__init__.py @@ -6,7 +6,7 @@ from .utils import ( create_refresh_token, create_access_token, get_jwt_identity, get_jwt_claims, set_access_cookies, set_refresh_cookies, - unset_jwt_cookies, get_raw_jwt + unset_jwt_cookies, get_raw_jwt, get_current_user, current_user ) from .blacklist import ( revoke_token, unrevoke_token, get_stored_tokens, get_all_stored_tokens, diff --git a/flask_jwt_extended/default_callbacks.py b/flask_jwt_extended/default_callbacks.py index d00d3a7f..2c962671 100644 --- a/flask_jwt_extended/default_callbacks.py +++ b/flask_jwt_extended/default_callbacks.py @@ -74,3 +74,12 @@ def default_revoked_token_callback(): return a general error message with a 401 status code """ return jsonify({'msg': 'Token has been revoked'}), 401 + + +def default_user_loader_error_callback(identity): + """ + By default, if a user_loader callback is defined and the callback + function returns None, we return a general error message with a 401 + status code + """ + return jsonify({'msg': "Error loading the user {}".format(identity)}), 401 diff --git a/flask_jwt_extended/exceptions.py b/flask_jwt_extended/exceptions.py index 91b77651..6b16dca5 100644 --- a/flask_jwt_extended/exceptions.py +++ b/flask_jwt_extended/exceptions.py @@ -54,3 +54,11 @@ class FreshTokenRequired(JWTExtendedException): protected by fresh_jwt_required """ pass + + +class UserLoadError(JWTExtendedException): + """ + Error raised when a user_loader callback function returns None, indicating + that it cannot or will not load a user for the given identity. + """ + pass diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index fdf61238..710d0f67 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -6,19 +6,18 @@ from flask_jwt_extended.config import config from flask_jwt_extended.exceptions import ( JWTDecodeError, NoAuthorizationError, InvalidHeaderError, WrongTokenError, - RevokedTokenError, FreshTokenRequired, CSRFError + RevokedTokenError, FreshTokenRequired, CSRFError, UserLoadError ) from flask_jwt_extended.default_callbacks import ( default_expired_token_callback, default_user_claims_callback, default_user_identity_callback, default_invalid_token_callback, - default_unauthorized_callback, - default_needs_fresh_token_callback, - default_revoked_token_callback + default_unauthorized_callback, default_needs_fresh_token_callback, + default_revoked_token_callback, default_user_loader_error_callback ) from flask_jwt_extended.tokens import ( - encode_refresh_token, decode_jwt, - encode_access_token + encode_refresh_token, decode_jwt, encode_access_token ) +from flask_jwt_extended.utils import get_jwt_identity class JWTManager(object): @@ -39,6 +38,8 @@ def __init__(self, app=None): self._unauthorized_callback = default_unauthorized_callback self._needs_fresh_token_callback = default_needs_fresh_token_callback self._revoked_token_callback = default_revoked_token_callback + self._user_loader_callback = None + self._user_loader_error_callback = default_user_loader_error_callback # Register this extension with the flask app now (if it is provided) if app is not None: @@ -101,6 +102,14 @@ def handle_revoked_token_error(e): def handle_fresh_token_required(e): return self._needs_fresh_token_callback() + @app.errorhandler(UserLoadError) + def handler_user_load_error(e): + # The identity is already saved before this exception was raised, + # otherwise a different exception would be raised, which is why we + # can safely call get_jwt_identity() here + identity = get_jwt_identity() + return self._user_loader_error_callback(identity) + @staticmethod def _set_default_configuration_options(app): """ @@ -244,6 +253,50 @@ def revoked_token_loader(self, callback): self._revoked_token_callback = callback return callback + def user_loader_callback_loader(self, callback): + """ + Sets the callback method to be called to load a user on a protected + endpoint. + + By default this is not is not used. + + If a callback method is passed in here, it must take one argument, + which is the identity of the user to load. It must return the user + object, or None in the case of an error (which will cause the TODO + error handler to be hit) + """ + self._user_loader_callback = callback + return callback + + def user_loader_error_loader(self, callback): + """ + Sets the callback method to be called if a user fails or is refused + to load when calling the _user_loader_callback function (indicated by + that function returning None) + + The default implementation will return json: + '{"msg": "Error loading the user "}' with a 400 status code. + + Callback must be a function that takes one argument, the identity of the + user who failed to load. + """ + self._user_loader_error_callback = callback + return callback + + def has_user_loader(self): + """ + Returns True if a user_loader_callback has been defined in this + application, False otherwise + """ + return self._user_loader_callback is not None + + def user_loader(self, identity): + """ + Calls the _user_loader_callback function (if it is defined) and returns + the resulting user from this callback. + """ + return self._user_loader_callback(identity) + def create_refresh_token(self, identity, expires_delta=None): """ Creates a new refresh token @@ -315,3 +368,4 @@ def create_access_token(self, identity, fresh=False, expires_delta=None): config.algorithm, csrf=config.csrf_protect) store_token(decoded_token, revoked=False) return access_token + diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index f2ab9e8f..65d94ce9 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -1,4 +1,6 @@ from flask import current_app +from werkzeug.local import LocalProxy + try: from flask import _app_ctx_stack as ctx_stack except ImportError: # pragma: no cover @@ -8,6 +10,10 @@ from flask_jwt_extended.tokens import decode_jwt +# Proxy to access the current user +current_user = LocalProxy(lambda: get_current_user()) + + def get_raw_jwt(): """ Returns the python dictionary which has all of the data in this JWT. If no @@ -32,6 +38,15 @@ def get_jwt_claims(): return get_raw_jwt().get('user_claims', {}) +def get_current_user(): + """ + Returns the loaded user from a user_loader callback in a protected endpoint. + If no user was loaded, or if no user_loader callback was defined, this will + return None + """ + return getattr(ctx_stack.top, 'jwt_user', None) + + def get_jti(encoded_token): """ Returns the JTI given the JWT encoded token @@ -60,6 +75,16 @@ def create_refresh_token(*args, **kwargs): return jwt_manager.create_refresh_token(*args, **kwargs) +def user_loader(*args, **kwargs): + jwt_manager = _get_jwt_manager() + return jwt_manager.user_loader(*args, **kwargs) + + +def has_user_loader(*args, **kwargs): + jwt_manager = _get_jwt_manager() + return jwt_manager.has_user_loader(*args, **kwargs) + + def get_csrf_token(encoded_token): token = decode_jwt(encoded_token, config.decode_key, config.algorithm, csrf=True) return token['csrf'] diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 5c417ded..7b0a756a 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -11,9 +11,10 @@ from flask_jwt_extended.config import config from flask_jwt_extended.exceptions import ( InvalidHeaderError, NoAuthorizationError, WrongTokenError, - FreshTokenRequired, CSRFError + FreshTokenRequired, CSRFError, UserLoadError ) from flask_jwt_extended.tokens import decode_jwt +from flask_jwt_extended.utils import has_user_loader, user_loader def jwt_required(fn): @@ -28,10 +29,9 @@ def jwt_required(fn): """ @wraps(fn) def wrapper(*args, **kwargs): - # Save the jwt in the context so that it can be accessed later by - # the various endpoints that is using this decorator jwt_data = _decode_jwt_from_request(request_type='access') ctx_stack.top.jwt = jwt_data + _load_user(jwt_data['identity']) return fn(*args, **kwargs) return wrapper @@ -49,15 +49,11 @@ def jwt_optional(fn): @wraps(fn) def wrapper(*args, **kwargs): try: - # If an acceptable JWT is found in the request, put it into - # the application context jwt_data = _decode_jwt_from_request(request_type='access') ctx_stack.top.jwt = jwt_data + _load_user(jwt_data['identity']) except NoAuthorizationError: - # Allow request to proceed if no authorization header is present - # in the request, but don't modify application context pass - # Return the decorated function in either case return fn(*args, **kwargs) return wrapper @@ -78,9 +74,8 @@ def wrapper(*args, **kwargs): if not jwt_data['fresh']: raise FreshTokenRequired('Fresh token required') - # Save the jwt in the context so that it can be accessed later by - # the various endpoints that is using this decorator ctx_stack.top.jwt = jwt_data + _load_user(jwt_data['identity']) return fn(*args, **kwargs) return wrapper @@ -93,14 +88,22 @@ def jwt_refresh_token_required(fn): """ @wraps(fn) def wrapper(*args, **kwargs): - # Save the jwt in the context so that it can be accessed later by - # the various endpoints that is using this decorator jwt_data = _decode_jwt_from_request(request_type='refresh') ctx_stack.top.jwt = jwt_data + _load_user(jwt_data['identity']) return fn(*args, **kwargs) return wrapper +def _load_user(identity): + if has_user_loader(): + user = user_loader(identity) + if user is None: + raise UserLoadError("user_loader returned None for {}".format(identity)) + else: + ctx_stack.top.jwt_user = user + + def _decode_jwt_from_headers(): header_name = config.header_name header_type = config.header_type diff --git a/tests/test_blacklist.py b/tests/test_blacklist.py index a18537ea..cdd3db84 100644 --- a/tests/test_blacklist.py +++ b/tests/test_blacklist.py @@ -29,7 +29,7 @@ def setUp(self): @self.app.route('/auth/login', methods=['POST']) def login(): - username = request.json['username'] + username = request.get_json()['username'] ret = { 'access_token': create_access_token(username, fresh=True), 'refresh_token': create_refresh_token(username) diff --git a/tests/test_jwt_manager.py b/tests/test_jwt_manager.py index 7ee72277..067ee56f 100644 --- a/tests/test_jwt_manager.py +++ b/tests/test_jwt_manager.py @@ -32,7 +32,12 @@ def test_class_init(self): def test_default_user_claims_callback(self): identity = 'foobar' m = JWTManager(self.app) - assert m._user_claims_callback(identity) == {} + self.assertEqual(m._user_claims_callback(identity), {}) + + def test_default_user_identity_callback(self): + identity = 'foobar' + m = JWTManager(self.app) + self.assertEqual(m._user_identity_callback(identity), identity) def test_default_expired_token_callback(self): with self.app.test_request_context(): @@ -80,6 +85,24 @@ def test_default_revoked_token_callback(self): self.assertEqual(status_code, 401) self.assertEqual(data, {'msg': 'Token has been revoked'}) + def test_default_user_loader_callback(self): + m = JWTManager(self.app) + self.assertEqual(m._user_loader_callback, None) + + def test_default_user_loader_error_callback(self): + with self.app.test_request_context(): + identity = 'foobar' + m = JWTManager(self.app) + result = m._user_loader_error_callback(identity) + status_code, data = self._parse_callback_result(result) + + self.assertEqual(status_code, 401) + self.assertEqual(data, {'msg': 'Error loading the user foobar'}) + + def test_default_has_user_loader(self): + m = JWTManager(self.app) + self.assertEqual(m.has_user_loader(), False) + def test_custom_user_claims_callback(self): identity = 'foobar' m = JWTManager(self.app) @@ -159,3 +182,33 @@ def custom_revoken_token(): self.assertEqual(status_code, 422) self.assertEqual(data, {'err': 'Nice knowing you!'}) + + def test_custom_user_loader(self): + with self.app.test_request_context(): + m = JWTManager(self.app) + + @m.user_loader_callback_loader + def custom_user_loader(identity): + if identity == 'foo': + return None + return identity + + identity = 'foobar' + result = m._user_loader_callback(identity) + self.assertEqual(result, identity) + self.assertEqual(m.has_user_loader(), True) + + def test_custom_user_loader_error_callback(self): + with self.app.test_request_context(): + m = JWTManager(self.app) + + @m.user_loader_error_loader + def custom_user_loader_error(identity): + return jsonify({'msg': 'Not found'}), 404 + + identity = 'foobar' + result = m._user_loader_error_callback(identity) + status_code, data = self._parse_callback_result(result) + + self.assertEqual(status_code, 404) + self.assertEqual(data, {'msg': 'Not found'}) diff --git a/tests/test_protected_endpoints.py b/tests/test_protected_endpoints.py index 9f462ea1..6cbb80ca 100644 --- a/tests/test_protected_endpoints.py +++ b/tests/test_protected_endpoints.py @@ -72,7 +72,6 @@ def partially_protected(): return jsonify({'msg': "protected hello world"}) return jsonify({'msg': "unprotected hello world"}) - def _jwt_post(self, url, jwt): response = self.client.post(url, content_type='application/json', headers={'Authorization': 'Bearer {}'.format(jwt)}) diff --git a/tests/test_user_loader.py b/tests/test_user_loader.py new file mode 100644 index 00000000..33e27ac1 --- /dev/null +++ b/tests/test_user_loader.py @@ -0,0 +1,135 @@ +import json +import unittest +from datetime import timedelta + +from flask import Flask, jsonify, request + +from flask_jwt_extended import ( + JWTManager, create_access_token, create_refresh_token, + jwt_refresh_token_required, jwt_required, fresh_jwt_required, + jwt_optional, current_user +) + + +class TestUserLoader(unittest.TestCase): + + def setUp(self): + self.app = Flask(__name__) + self.app.secret_key = 'super=secret' + self.jwt_manager = JWTManager(self.app) + self.client = self.app.test_client() + + @self.jwt_manager.user_loader_callback_loader + def user_loader(identity): + if identity == 'foobar': + return None + return identity + + @self.app.route('/auth/login', methods=['POST']) + def login(): + username = request.get_json()['username'] + ret = { + 'access_token': create_access_token(username, fresh=True), + 'refresh_token': create_refresh_token(username) + } + return jsonify(ret), 200 + + @self.app.route('/refresh-protected') + @jwt_refresh_token_required + def refresh_endpoint(): + return jsonify({'username': str(current_user)}) + + @self.app.route('/protected') + @jwt_required + def protected_endpoint(): + return jsonify({'username': str(current_user)}) + + @self.app.route('/fresh-protected') + @fresh_jwt_required + def fresh_protected_endpoint(): + return jsonify({'username': str(current_user)}) + + @self.app.route('/partially-protected') + @jwt_optional + def optional_endpoint(): + return jsonify({'username': str(current_user)}) + + def _jwt_get(self, url, jwt): + response = self.client.get(url, content_type='application/json', + headers={'Authorization': 'Bearer {}'.format(jwt)}) + status_code = response.status_code + data = json.loads(response.get_data(as_text=True)) + return status_code, data + + def test_user_loads(self): + response = self.client.post('/auth/login', content_type='application/json', + data=json.dumps({'username': 'test'})) + data = json.loads(response.get_data(as_text=True)) + access_token = data['access_token'] + refresh_token = data['refresh_token'] + + status, data = self._jwt_get('/protected', access_token) + self.assertEqual(status, 200) + self.assertEqual(data, {'username': 'test'}) + + status, data = self._jwt_get('/fresh-protected', access_token) + self.assertEqual(status, 200) + self.assertEqual(data, {'username': 'test'}) + + status, data = self._jwt_get('/partially-protected', access_token) + self.assertEqual(status, 200) + self.assertEqual(data, {'username': 'test'}) + + status, data = self._jwt_get('/refresh-protected', refresh_token) + self.assertEqual(status, 200) + self.assertEqual(data, {'username': 'test'}) + + def test_failed_user_loads(self): + response = self.client.post('/auth/login', content_type='application/json', + data=json.dumps({'username': 'foobar'})) + data = json.loads(response.get_data(as_text=True)) + access_token = data['access_token'] + refresh_token = data['refresh_token'] + + status, data = self._jwt_get('/protected', access_token) + self.assertEqual(status, 401) + self.assertEqual(data, {'msg': 'Error loading the user foobar'}) + + status, data = self._jwt_get('/fresh-protected', access_token) + self.assertEqual(status, 401) + self.assertEqual(data, {'msg': 'Error loading the user foobar'}) + + status, data = self._jwt_get('/partially-protected', access_token) + self.assertEqual(status, 401) + self.assertEqual(data, {'msg': 'Error loading the user foobar'}) + + status, data = self._jwt_get('/refresh-protected', refresh_token) + self.assertEqual(status, 401) + self.assertEqual(data, {'msg': 'Error loading the user foobar'}) + + def test_custom_error_callback(self): + @self.jwt_manager.user_loader_error_loader + def custom_user_loader_error_callback(identity): + return jsonify({"msg": "Not found"}), 404 + + response = self.client.post('/auth/login', content_type='application/json', + data=json.dumps({'username': 'foobar'})) + data = json.loads(response.get_data(as_text=True)) + access_token = data['access_token'] + refresh_token = data['refresh_token'] + + status, data = self._jwt_get('/protected', access_token) + self.assertEqual(status, 404) + self.assertEqual(data, {'msg': 'Not found'}) + + status, data = self._jwt_get('/fresh-protected', access_token) + self.assertEqual(status, 404) + self.assertEqual(data, {'msg': 'Not found'}) + + status, data = self._jwt_get('/partially-protected', access_token) + self.assertEqual(status, 404) + self.assertEqual(data, {'msg': 'Not found'}) + + status, data = self._jwt_get('/refresh-protected', refresh_token) + self.assertEqual(status, 404) + self.assertEqual(data, {'msg': 'Not found'})