Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ General Options:
Defaults to ``'identity'`` for legacy reasons.
``JWT_USER_CLAIMS`` Claim in the tokens that is used to store user claims.
Defaults to ``'user_claims'``.
``JWT_CLAIMS_IN_REFRESH_TOKEN`` If user claims should be included in refresh tokens.
Defaults to ``False``.
================================= =========================================


Expand Down
4 changes: 4 additions & 0 deletions flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ def identity_claim_key(self):
def user_claims_key(self):
return current_app.config['JWT_USER_CLAIMS']

@property
def user_claims_in_refresh_token(self):
return current_app.config['JWT_CLAIMS_IN_REFRESH_TOKEN']

@property
def exempt_methods(self):
return {"OPTIONS"}
Expand Down
9 changes: 9 additions & 0 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def _set_default_configuration_options(app):
app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')
app.config.setdefault('JWT_USER_CLAIMS', 'user_claims')

app.config.setdefault('JWT_CLAIMS_IN_REFRESH_TOKEN', False)

def user_claims_loader(self, callback):
"""
This decorator sets the callback function for adding custom claims to an
Expand Down Expand Up @@ -375,13 +377,20 @@ def _create_refresh_token(self, identity, expires_delta=None):
if expires_delta is None:
expires_delta = config.refresh_expires

if config.user_claims_in_refresh_token:
user_claims = self._user_claims_callback(identity)
else:
user_claims = None

refresh_token = encode_refresh_token(
identity=self._user_identity_callback(identity),
secret=config.encode_key,
algorithm=config.algorithm,
expires_delta=expires_delta,
user_claims=user_claims,
csrf=config.csrf_protect,
identity_claim_key=config.identity_claim_key,
user_claims_key=config.user_claims_key,
json_encoder=config.json_encoder
)
return refresh_token
Expand Down
17 changes: 13 additions & 4 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
json_encoder=json_encoder)


def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf,
identity_claim_key, json_encoder=None):
def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims,
csrf, identity_claim_key, user_claims_key,
json_encoder=None):
"""
Creates a new encoded (utf-8) refresh token.

Expand All @@ -88,15 +89,23 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf,
:param expires_delta: How far in the future this token should expire
(set to False to disable expiration)
:type expires_delta: datetime.timedelta or False
:param user_claims: Custom claims to include in this token. This data must
be json serializable
:param csrf: Whether to include a csrf double submit claim in this token
(boolean)
:param identity_claim_key: Which key should be used to store the identity
:param user_claims_key: Which key should be used to store the user claims
:return: Encoded refresh token
"""
token_data = {
identity_claim_key: identity,
'type': 'refresh',
}

# Don't add extra data to the token if user_claims is empty.
if user_claims:
token_data[user_claims_key] = user_claims

if csrf:
token_data['csrf'] = _create_csrf_token()
return _encode_jwt(token_data, expires_delta, secret, algorithm,
Expand Down Expand Up @@ -129,8 +138,8 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
if data['type'] == 'access':
if 'fresh' not in data:
raise JWTDecodeError("Missing claim: fresh")
if user_claims_key not in data:
data[user_claims_key] = {}
if user_claims_key not in data:
data[user_claims_key] = {}
if csrf_value:
if 'csrf' not in data:
raise JWTDecodeError("Missing claim: csrf")
Expand Down
6 changes: 6 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def test_default_configs(app):
assert config.identity_claim_key == 'identity'
assert config.user_claims_key == 'user_claims'

assert config.user_claims_in_refresh_token is False

assert config.json_encoder is app.json_encoder


Expand Down Expand Up @@ -100,6 +102,8 @@ def test_override_configs(app):
app.config['JWT_IDENTITY_CLAIM'] = 'foo'
app.config['JWT_USER_CLAIMS'] = 'bar'

app.config['JWT_CLAIMS_IN_REFRESH_TOKEN'] = True

class CustomJSONEncoder(JSONEncoder):
pass

Expand Down Expand Up @@ -148,6 +152,8 @@ class CustomJSONEncoder(JSONEncoder):
assert config.identity_claim_key == 'foo'
assert config.user_claims_key == 'bar'

assert config.user_claims_in_refresh_token is True

assert config.json_encoder is CustomJSONEncoder


Expand Down
40 changes: 39 additions & 1 deletion tests/test_user_claims_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from flask_jwt_extended import (
JWTManager, create_access_token, jwt_required, get_jwt_claims,
decode_token
decode_token, jwt_refresh_token_required, create_refresh_token
)
from tests.utils import get_jwt_manager, make_headers

Expand All @@ -19,6 +19,11 @@ def app():
def get_claims():
return jsonify(get_jwt_claims())

@app.route('/protected2', methods=['GET'])
@jwt_refresh_token_required
def get_refresh_claims():
return jsonify(get_jwt_claims())

return app


Expand Down Expand Up @@ -99,3 +104,36 @@ def add_claims(identity):
response = test_client.get('/protected', headers=make_headers(access_token))
assert response.get_json() == {'foo': 'bar'}
assert response.status_code == 200


def test_user_claim_not_in_refresh_token(app):
jwt = get_jwt_manager(app)

@jwt.user_claims_loader
def add_claims(identity):
return {'foo': 'bar'}

with app.test_request_context():
refresh_token = create_refresh_token('username')

test_client = app.test_client()
response = test_client.get('/protected2', headers=make_headers(refresh_token))
assert response.get_json() == {}
assert response.status_code == 200


def test_user_claim_in_refresh_token(app):
app.config['JWT_CLAIMS_IN_REFRESH_TOKEN'] = True
jwt = get_jwt_manager(app)

@jwt.user_claims_loader
def add_claims(identity):
return {'foo': 'bar'}

with app.test_request_context():
refresh_token = create_refresh_token('username')

test_client = app.test_client()
response = test_client.get('/protected2', headers=make_headers(refresh_token))
assert response.get_json() == {'foo': 'bar'}
assert response.status_code == 200