Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changing_default_behavior.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ Possible loader functions are:
* - **revoked_token_loader**
- Function to call when a revoked token accesses a protected endpoint
- None
* - **user_loader_callback_loader**
- Function to call to load a user object from a token
- Takes one argument - The identity of the token to load a user from
* - **user_loader_error_loader**
- Function that is called when the user_loader callback function returns **None**
- Takes one argument - The identity of the user who failed to load

Dynamic token expires time
~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
19 changes: 19 additions & 0 deletions docs/complex_objects_from_token.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
Complex Objects from Tokens
===========================

We can also do the inverse of creating tokens from complex objects like we did
in the last section. In this case, we can take a token and every time a
protected endpoint is accessed automatically use the token to load a complex
object, for example a SQLAlchemy user object. Here's an example of how it
might look:

.. literalinclude:: ../examples/complex_objects_from_tokens.py

If you do not provide a user_loader_callback in your application, and attempt
to access the **current_user** LocalProxy, it will simply be None.

One thing to note with this is that you will now call the **user_loader_callback**
on all of your protected endpoints, which will probably incur the cost of a
database lookup. In most cases this likely isn't a big deal for your application,
but do be aware that it could slow things down if your frontend is doing several
calls to endpoints in rapid succession.
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ documentation is coming soon!
basic_usage
add_custom_data_claims
tokens_from_complex_object
complex_objects_from_token
refresh_tokens
token_freshness
changing_default_behavior
Expand Down
74 changes: 74 additions & 0 deletions examples/complex_objects_from_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from flask import Flask, jsonify, request
from flask_jwt_extended import (
JWTManager, jwt_required, create_access_token, current_user
)

app = Flask(__name__)
app.secret_key = 'super-secret' # Change this!
jwt = JWTManager(app)


# A user object that we will load our tokens
class UserObject:
def __init__(self, username, roles):
self.username = username
self.roles = roles

# An example store of users. In production, this would likely
# be a sqlalchemy instance or something similiar
users_to_roles = {
'foo': ['admin'],
'bar': ['peasant'],
'baz': ['peasant']
}


# This function is called whenever a protected endpoint is accessed.
# This should return a complex object based on the token identity.
# This is called after the token is verified, so you can use
# get_jwt_claims() in here if desired. Note that this needs to
# return None if the user could not be loaded for any reason,
# such as not being found in the underlying data store
@jwt.user_loader_callback_loader
def user_loader_callback(identity):
if identity not in users_to_roles:
return None

return UserObject(
username=identity,
roles=users_to_roles[identity]
)


# You can override the error returned to the user if the
# user_loader_callback returns None. By default, if you don't
# override this, it will return a 401 status code with the json:
# {'msg': "Error loading the user <identity>"}. You can use
# get_jwt_claims() here too if desired
@jwt.user_loader_error_loader
def custom_user_loader_error(identity):
return jsonify({"msg": "User not found"}), 404


# Create a token for any user, so this can be tested out
@app.route('/login', methods=['POST'])
def login():
username = request.json.get('username', None)
access_token = create_access_token(identity=username)
ret = {'access_token': access_token}
return jsonify(ret), 200


# If the user_loader_callback returns None, this method will
# not get hit, even if the access token is valid. You can
# access the loaded user via the ``current_user``` LocalProxy,
# or with the ```get_current_user()``` method
@app.route('/admin-only', methods=['GET'])
@jwt_required
def protected():
if 'admin' not in current_user.roles:
return jsonify({"msg": "Forbidden"}), 403
return jsonify({"secret_msg": "don't forget to drink your ovaltine"})

if __name__ == '__main__':
app.run()
2 changes: 1 addition & 1 deletion flask_jwt_extended/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .utils import (
create_refresh_token, create_access_token, get_jwt_identity,
get_jwt_claims, set_access_cookies, set_refresh_cookies,
unset_jwt_cookies, get_raw_jwt
unset_jwt_cookies, get_raw_jwt, get_current_user, current_user
)
from .blacklist import (
revoke_token, unrevoke_token, get_stored_tokens, get_all_stored_tokens,
Expand Down
9 changes: 9 additions & 0 deletions flask_jwt_extended/default_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,12 @@ def default_revoked_token_callback():
return a general error message with a 401 status code
"""
return jsonify({'msg': 'Token has been revoked'}), 401


def default_user_loader_error_callback(identity):
"""
By default, if a user_loader callback is defined and the callback
function returns None, we return a general error message with a 401
status code
"""
return jsonify({'msg': "Error loading the user {}".format(identity)}), 401
8 changes: 8 additions & 0 deletions flask_jwt_extended/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,11 @@ class FreshTokenRequired(JWTExtendedException):
protected by fresh_jwt_required
"""
pass


class UserLoadError(JWTExtendedException):
"""
Error raised when a user_loader callback function returns None, indicating
that it cannot or will not load a user for the given identity.
"""
pass
66 changes: 60 additions & 6 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,18 @@
from flask_jwt_extended.config import config
from flask_jwt_extended.exceptions import (
JWTDecodeError, NoAuthorizationError, InvalidHeaderError, WrongTokenError,
RevokedTokenError, FreshTokenRequired, CSRFError
RevokedTokenError, FreshTokenRequired, CSRFError, UserLoadError
)
from flask_jwt_extended.default_callbacks import (
default_expired_token_callback, default_user_claims_callback,
default_user_identity_callback, default_invalid_token_callback,
default_unauthorized_callback,
default_needs_fresh_token_callback,
default_revoked_token_callback
default_unauthorized_callback, default_needs_fresh_token_callback,
default_revoked_token_callback, default_user_loader_error_callback
)
from flask_jwt_extended.tokens import (
encode_refresh_token, decode_jwt,
encode_access_token
encode_refresh_token, decode_jwt, encode_access_token
)
from flask_jwt_extended.utils import get_jwt_identity


class JWTManager(object):
Expand All @@ -39,6 +38,8 @@ def __init__(self, app=None):
self._unauthorized_callback = default_unauthorized_callback
self._needs_fresh_token_callback = default_needs_fresh_token_callback
self._revoked_token_callback = default_revoked_token_callback
self._user_loader_callback = None
self._user_loader_error_callback = default_user_loader_error_callback

# Register this extension with the flask app now (if it is provided)
if app is not None:
Expand Down Expand Up @@ -101,6 +102,14 @@ def handle_revoked_token_error(e):
def handle_fresh_token_required(e):
return self._needs_fresh_token_callback()

@app.errorhandler(UserLoadError)
def handler_user_load_error(e):
# The identity is already saved before this exception was raised,
# otherwise a different exception would be raised, which is why we
# can safely call get_jwt_identity() here
identity = get_jwt_identity()
return self._user_loader_error_callback(identity)

@staticmethod
def _set_default_configuration_options(app):
"""
Expand Down Expand Up @@ -244,6 +253,50 @@ def revoked_token_loader(self, callback):
self._revoked_token_callback = callback
return callback

def user_loader_callback_loader(self, callback):
"""
Sets the callback method to be called to load a user on a protected
endpoint.

By default this is not is not used.

If a callback method is passed in here, it must take one argument,
which is the identity of the user to load. It must return the user
object, or None in the case of an error (which will cause the TODO
error handler to be hit)
"""
self._user_loader_callback = callback
return callback

def user_loader_error_loader(self, callback):
"""
Sets the callback method to be called if a user fails or is refused
to load when calling the _user_loader_callback function (indicated by
that function returning None)

The default implementation will return json:
'{"msg": "Error loading the user <identity>"}' with a 400 status code.

Callback must be a function that takes one argument, the identity of the
user who failed to load.
"""
self._user_loader_error_callback = callback
return callback

def has_user_loader(self):
"""
Returns True if a user_loader_callback has been defined in this
application, False otherwise
"""
return self._user_loader_callback is not None

def user_loader(self, identity):
"""
Calls the _user_loader_callback function (if it is defined) and returns
the resulting user from this callback.
"""
return self._user_loader_callback(identity)

def create_refresh_token(self, identity, expires_delta=None):
"""
Creates a new refresh token
Expand Down Expand Up @@ -315,3 +368,4 @@ def create_access_token(self, identity, fresh=False, expires_delta=None):
config.algorithm, csrf=config.csrf_protect)
store_token(decoded_token, revoked=False)
return access_token

25 changes: 25 additions & 0 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from flask import current_app
from werkzeug.local import LocalProxy

try:
from flask import _app_ctx_stack as ctx_stack
except ImportError: # pragma: no cover
Expand All @@ -8,6 +10,10 @@
from flask_jwt_extended.tokens import decode_jwt


# Proxy to access the current user
current_user = LocalProxy(lambda: get_current_user())


def get_raw_jwt():
"""
Returns the python dictionary which has all of the data in this JWT. If no
Expand All @@ -32,6 +38,15 @@ def get_jwt_claims():
return get_raw_jwt().get('user_claims', {})


def get_current_user():
"""
Returns the loaded user from a user_loader callback in a protected endpoint.
If no user was loaded, or if no user_loader callback was defined, this will
return None
"""
return getattr(ctx_stack.top, 'jwt_user', None)


def get_jti(encoded_token):
"""
Returns the JTI given the JWT encoded token
Expand Down Expand Up @@ -60,6 +75,16 @@ def create_refresh_token(*args, **kwargs):
return jwt_manager.create_refresh_token(*args, **kwargs)


def user_loader(*args, **kwargs):
jwt_manager = _get_jwt_manager()
return jwt_manager.user_loader(*args, **kwargs)


def has_user_loader(*args, **kwargs):
jwt_manager = _get_jwt_manager()
return jwt_manager.has_user_loader(*args, **kwargs)


def get_csrf_token(encoded_token):
token = decode_jwt(encoded_token, config.decode_key, config.algorithm, csrf=True)
return token['csrf']
Expand Down
27 changes: 15 additions & 12 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from flask_jwt_extended.config import config
from flask_jwt_extended.exceptions import (
InvalidHeaderError, NoAuthorizationError, WrongTokenError,
FreshTokenRequired, CSRFError
FreshTokenRequired, CSRFError, UserLoadError
)
from flask_jwt_extended.tokens import decode_jwt
from flask_jwt_extended.utils import has_user_loader, user_loader


def jwt_required(fn):
Expand All @@ -28,10 +29,9 @@ def jwt_required(fn):
"""
@wraps(fn)
def wrapper(*args, **kwargs):
# Save the jwt in the context so that it can be accessed later by
# the various endpoints that is using this decorator
jwt_data = _decode_jwt_from_request(request_type='access')
ctx_stack.top.jwt = jwt_data
_load_user(jwt_data['identity'])
return fn(*args, **kwargs)
return wrapper

Expand All @@ -49,15 +49,11 @@ def jwt_optional(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
try:
# If an acceptable JWT is found in the request, put it into
# the application context
jwt_data = _decode_jwt_from_request(request_type='access')
ctx_stack.top.jwt = jwt_data
_load_user(jwt_data['identity'])
except NoAuthorizationError:
# Allow request to proceed if no authorization header is present
# in the request, but don't modify application context
pass
# Return the decorated function in either case
return fn(*args, **kwargs)
return wrapper

Expand All @@ -78,9 +74,8 @@ def wrapper(*args, **kwargs):
if not jwt_data['fresh']:
raise FreshTokenRequired('Fresh token required')

# Save the jwt in the context so that it can be accessed later by
# the various endpoints that is using this decorator
ctx_stack.top.jwt = jwt_data
_load_user(jwt_data['identity'])
return fn(*args, **kwargs)
return wrapper

Expand All @@ -93,14 +88,22 @@ def jwt_refresh_token_required(fn):
"""
@wraps(fn)
def wrapper(*args, **kwargs):
# Save the jwt in the context so that it can be accessed later by
# the various endpoints that is using this decorator
jwt_data = _decode_jwt_from_request(request_type='refresh')
ctx_stack.top.jwt = jwt_data
_load_user(jwt_data['identity'])
return fn(*args, **kwargs)
return wrapper


def _load_user(identity):
if has_user_loader():
user = user_loader(identity)
if user is None:
raise UserLoadError("user_loader returned None for {}".format(identity))
else:
ctx_stack.top.jwt_user = user


def _decode_jwt_from_headers():
header_name = config.header_name
header_type = config.header_type
Expand Down
2 changes: 1 addition & 1 deletion tests/test_blacklist.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def setUp(self):

@self.app.route('/auth/login', methods=['POST'])
def login():
username = request.json['username']
username = request.get_json()['username']
ret = {
'access_token': create_access_token(username, fresh=True),
'refresh_token': create_refresh_token(username)
Expand Down
Loading