Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ General Options:
``JWT_REFRESH_TOKEN_EXPIRES`` How long a refresh token should live before it expires. This
takes a ``datetime.timedelta``, and defaults to 30 days
``JWT_ALGORITHM`` Which algorithm to sign the JWT with. `See here <https://pyjwt.readthedocs.io/en/latest/algorithms.html>`_
for the options. Defaults to ``'HS256'``. Note that Asymmetric
(Public-key) algorithms are not currently supported.
for the options. Defaults to ``'HS256'``.
``JWT_PUBLIC_KEY`` The public key needed for RSA and ECDSA based signing algorithms.
Has to be provided if any of ``RS*`` or ``ES*`` algorithms is used.
PEM format expected.
================================= =========================================


Expand Down
26 changes: 26 additions & 0 deletions flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import simplekv
from flask import current_app
from jwt.algorithms import requires_cryptography


class _Config(object):
Expand All @@ -15,6 +16,18 @@ class _Config(object):
object. All of these values are read only.
"""

@property
def is_asymmetric(self):
return self.algorithm in requires_cryptography

@property
def encode_key(self):
return self.secret_key

@property
def decode_key(self):
return self.public_key if self.is_asymmetric else self.secret_key

@property
def token_location(self):
locations = current_app.config['JWT_TOKEN_LOCATION']
Expand Down Expand Up @@ -172,6 +185,17 @@ def secret_key(self):
raise RuntimeError('flask SECRET_KEY must be set')
return key

@property
def public_key(self):
key = None
if self.algorithm in requires_cryptography:
key = current_app.config.get('JWT_PUBLIC_KEY', None)
if not key:
raise RuntimeError('JWT_PUBLIC_KEY must be set to use '
'asymmetric cryptography algorith '
'"{crypto_algorithm}"'.format(crypto_algorithm=self.algorithm))
return key

@property
def cookie_max_age(self):
# Returns the appropiate value for max_age for flask set_cookies. If
Expand All @@ -180,3 +204,5 @@ def cookie_max_age(self):
return None if self.session_cookie else 2147483647 # 2^31

config = _Config()


14 changes: 8 additions & 6 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,12 @@ def _set_default_configuration_options(app):
app.config.setdefault('JWT_REFRESH_TOKEN_EXPIRES', datetime.timedelta(days=30))

# What algorithm to use to sign the token. See here for a list of options:
# https://github.com/jpadilla/pyjwt/blob/master/jwt/api_jwt.py (note
# that public private key is not yet supported in this extension)
# https://github.com/jpadilla/pyjwt/blob/master/jwt/api_jwt.py
app.config.setdefault('JWT_ALGORITHM', 'HS256')

# must be set if using asymmetric cryptography algorithm (RS* or EC*)
app.config.setdefault('JWT_PUBLIC_KEY', None)

# Options for blacklisting/revoking tokens
app.config.setdefault('JWT_BLACKLIST_ENABLED', False)
app.config.setdefault('JWT_BLACKLIST_STORE', None)
Expand Down Expand Up @@ -251,15 +253,15 @@ def create_refresh_token(self, identity):
"""
refresh_token = encode_refresh_token(
identity=self._user_identity_callback(identity),
secret=config.secret_key,
secret=config.encode_key,
algorithm=config.algorithm,
expires_delta=config.refresh_expires,
csrf=config.csrf_protect
)

# If blacklisting is enabled, store this token in our key-value store
if config.blacklist_enabled:
decoded_token = decode_jwt(refresh_token, config.secret_key,
decoded_token = decode_jwt(refresh_token, config.decode_key,
config.algorithm, csrf=config.csrf_protect)
store_token(decoded_token, revoked=False)
return refresh_token
Expand All @@ -282,15 +284,15 @@ def create_access_token(self, identity, fresh=False):
"""
access_token = encode_access_token(
identity=self._user_identity_callback(identity),
secret=config.secret_key,
secret=config.encode_key,
algorithm=config.algorithm,
expires_delta=config.access_expires,
fresh=fresh,
user_claims=self._user_claims_callback(identity),
csrf=config.csrf_protect
)
if config.blacklist_enabled and config.blacklist_access_tokens:
decoded_token = decode_jwt(access_token, config.secret_key,
decoded_token = decode_jwt(access_token, config.decode_key,
config.algorithm, csrf=config.csrf_protect)
store_token(decoded_token, revoked=False)
return access_token
2 changes: 1 addition & 1 deletion flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def create_refresh_token(*args, **kwargs):


def get_csrf_token(encoded_token):
token = decode_jwt(encoded_token, config.secret_key, config.algorithm, csrf=True)
token = decode_jwt(encoded_token, config.decode_key, config.algorithm, csrf=True)
return token['csrf']


Expand Down
4 changes: 2 additions & 2 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _decode_jwt_from_headers():
raise InvalidHeaderError(msg)
token = parts[1]

return decode_jwt(token, config.secret_key, config.algorithm, csrf=False)
return decode_jwt(token, config.decode_key, config.algorithm, csrf=False)


def _decode_jwt_from_cookies(request_type):
Expand All @@ -115,7 +115,7 @@ def _decode_jwt_from_cookies(request_type):

decoded_token = decode_jwt(
encoded_token=encoded_token,
secret=config.secret_key,
secret=config.decode_key,
algorithm=config.algorithm,
csrf=config.csrf_protect
)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ alabaster==0.7.9
Babel==2.3.4
click==6.6
coverage==4.2
cryptography==1.8.1
docutils==0.12
Flask==0.11.1
imagesize==0.7.1
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
packages=['flask_jwt_extended'],
zip_safe=False,
platforms='any',
install_requires=['Flask', 'PyJWT', 'simplekv'],
install_requires=['Flask', 'PyJWT', 'simplekv', 'cryptography'],
classifiers=[
'Development Status :: 4 - Beta',
'Environment :: Web Environment',
Expand Down
24 changes: 24 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,15 @@ def test_default_configs(self):
self.assertEqual(config.access_expires, timedelta(minutes=15))
self.assertEqual(config.refresh_expires, timedelta(days=30))
self.assertEqual(config.algorithm, 'HS256')
self.assertEqual(config.is_asymmetric, False)
self.assertEqual(config.blacklist_enabled, False)
self.assertEqual(config.blacklist_checks, 'refresh')
self.assertEqual(config.blacklist_access_tokens, False)

self.assertEqual(config.secret_key, self.app.secret_key)
self.assertEqual(config.public_key, None)
self.assertEqual(config.encode_key, self.app.secret_key)
self.assertEqual(config.decode_key, self.app.secret_key)
self.assertEqual(config.cookie_max_age, None)

with self.assertRaises(RuntimeError):
Expand Down Expand Up @@ -166,6 +170,15 @@ def test_invalid_config_options(self):
with self.assertRaises(RuntimeError):
config.secret_key

self.app.secret_key = None
with self.assertRaises(RuntimeError):
config.encode_key

self.app.config['JWT_ALGORITHM'] = 'RS256'
self.app.config['JWT_PUBLIC_KEY'] = None
with self.assertRaises(RuntimeError):
config.decode_key

def test_depreciated_options(self):
self.app.config['JWT_CSRF_HEADER_NAME'] = 'Auth'

Expand Down Expand Up @@ -205,3 +218,14 @@ def test_special_config_options(self):
self.app.config['JWT_TOKEN_LOCATION'] = ['cookies']
self.app.config['JWT_COOKIE_CSRF_PROTECT'] = False
self.assertEqual(config.csrf_protect, False)

def test_asymmetric_encryption_key_handling(self):
self.app.secret_key = 'MOCK_RSA_PRIVATE_KEY'
self.app.config['JWT_PUBLIC_KEY'] = 'MOCK_RSA_PUBLIC_KEY'
self.app.config['JWT_ALGORITHM'] = 'RS256'

with self.app.test_request_context():
self.assertEqual(config.is_asymmetric, True)
self.assertEqual(config.secret_key, 'MOCK_RSA_PRIVATE_KEY')
self.assertEqual(config.encode_key, 'MOCK_RSA_PRIVATE_KEY')
self.assertEqual(config.decode_key, 'MOCK_RSA_PUBLIC_KEY')
110 changes: 110 additions & 0 deletions tests/test_protected_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,3 +814,113 @@ def test_accessing_endpoint_without_jwt(self):
data = json.loads(response.get_data(as_text=True))
self.assertEqual(status_code, 401)
self.assertIn('msg', data)


# random 1024bit RSA keypair
RSA_PRIVATE = """
-----BEGIN RSA PRIVATE KEY-----
MIICXgIBAAKBgQDN+p9a9oMyqRzkae8yLdJcEK0O0WesH6JiMz+KDrpUwAoAM/KP
DnxFnROJDSBHyHEmPVn5x8GqV5lQ9+6l97jdEEcPo6wkshycM82fgcxOmvtAy4Uo
xq/AeplYqplhcUTGVuo4ZldOLmN8ksGmzhWpsOdT0bkYipHCn5sWZxd21QIDAQAB
AoGBAMJ0++KVXXEDZMpjFDWsOq898xNNMHG3/8ZzmWXN161RC1/7qt/RjhLuYtX9
NV9vZRrzyrDcHAKj5pMhLgUzpColKzvdG2vKCldUs2b0c8HEGmjsmpmgoI1Tdf9D
G1QK+q9pKHlbj/MLr4vZPX6xEwAFeqRKlzL30JPD+O6mOXs1AkEA8UDzfadH1Y+H
bcNN2COvCqzqJMwLNRMXHDmUsjHfR2gtzk6D5dDyEaL+O4FLiQCaNXGWWoDTy/HJ
Clh1Z0+KYwJBANqRtJ+RvdgHMq0Yd45MMyy0ODGr1B3PoRbUK8EdXpyUNMi1g3iJ
tXMbLywNkTfcEXZTlbbkVYwrEl6P2N1r42cCQQDb9UQLBEFSTRJE2RRYQ/CL4yt3
cTGmqkkfyr/v19ii2jEpMBzBo8eQnPL+fdvIhWwT3gQfb+WqxD9v10bzcmnRAkEA
mzTgeHd7wg3KdJRtQYTmyhXn2Y3VAJ5SG+3qbCW466NqoCQVCeFwEh75rmSr/Giv
lcDhDZCzFuf3EWNAcmuMfQJARsWfM6q7v2p6vkYLLJ7+VvIwookkr6wymF5Zgb9d
E6oTM2EeUPSyyrj5IdsU2JCNBH1m3JnUflz8p8/NYCoOZg==
-----END RSA PRIVATE KEY-----
"""
RSA_PUBLIC = """
-----BEGIN RSA PUBLIC KEY-----
MIGJAoGBAM36n1r2gzKpHORp7zIt0lwQrQ7RZ6wfomIzP4oOulTACgAz8o8OfEWd
E4kNIEfIcSY9WfnHwapXmVD37qX3uN0QRw+jrCSyHJwzzZ+BzE6a+0DLhSjGr8B6
mViqmWFxRMZW6jhmV04uY3ySwabOFamw51PRuRiKkcKfmxZnF3bVAgMBAAE=
-----END RSA PUBLIC KEY-----
"""

class TestEndpointsWithAssymmetricCrypto(unittest.TestCase):

def setUp(self):
self.app = Flask(__name__)
self.app.secret_key = RSA_PRIVATE
self.app.config['JWT_PUBLIC_KEY'] = RSA_PUBLIC
self.app.config['JWT_ALGORITHM'] = 'RS256'
self.app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(seconds=1)
self.app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(seconds=1)
self.jwt_manager = JWTManager(self.app)
self.client = self.app.test_client()

@self.app.route('/auth/login', methods=['POST'])
def login():
ret = {
'access_token': create_access_token('test', fresh=True),
'refresh_token': create_refresh_token('test')
}
return jsonify(ret), 200

@self.app.route('/auth/refresh', methods=['POST'])
@jwt_refresh_token_required
def refresh():
username = get_jwt_identity()
ret = {'access_token': create_access_token(username, fresh=False)}
return jsonify(ret), 200

@self.app.route('/auth/fresh-login', methods=['POST'])
def fresh_login():
ret = {'access_token': create_access_token('test', fresh=True)}
return jsonify(ret), 200

@self.app.route('/protected')
@jwt_required
def protected():
return jsonify({'msg': "hello world"})

@self.app.route('/fresh-protected')
@fresh_jwt_required
def fresh_protected():
return jsonify({'msg': "fresh hello world"})

def _jwt_post(self, url, jwt):
response = self.client.post(url, content_type='application/json',
headers={'Authorization': 'Bearer {}'.format(jwt)})
status_code = response.status_code
data = json.loads(response.get_data(as_text=True))
return status_code, data

def _jwt_get(self, url, jwt, header_name='Authorization', header_type='Bearer'):
header_type = '{} {}'.format(header_type, jwt).strip()
response = self.client.get(url, headers={header_name: header_type})
status_code = response.status_code
data = json.loads(response.get_data(as_text=True))
return status_code, data

def test_login(self):
response = self.client.post('/auth/login')
status_code = response.status_code
data = json.loads(response.get_data(as_text=True))
self.assertEqual(status_code, 200)
self.assertIn('access_token', data)
self.assertIn('refresh_token', data)

def test_fresh_login(self):
response = self.client.post('/auth/fresh-login')
status_code = response.status_code
data = json.loads(response.get_data(as_text=True))
self.assertEqual(status_code, 200)
self.assertIn('access_token', data)
self.assertNotIn('refresh_token', data)

def test_refresh(self):
response = self.client.post('/auth/login')
data = json.loads(response.get_data(as_text=True))
access_token = data['access_token']
refresh_token = data['refresh_token']

status_code, data = self._jwt_post('/auth/refresh', refresh_token)
self.assertEqual(status_code, 200)
self.assertIn('access_token', data)
self.assertNotIn('refresh_token', data)