From 22168755181f9d3107ebedd19fdb87fb14fbc439 Mon Sep 17 00:00:00 2001 From: Alex Kotenko Date: Fri, 5 May 2017 16:20:40 +0100 Subject: [PATCH] - cryptography added to dependencies - JWT_PUBLIC_KEY config added - symmetric/asymmetric distinction added - test coverage for asymmetric handling added - docs updated --- docs/options.rst | 6 +- flask_jwt_extended/config.py | 26 ++++++ flask_jwt_extended/jwt_manager.py | 14 ++-- flask_jwt_extended/utils.py | 2 +- flask_jwt_extended/view_decorators.py | 4 +- requirements.txt | 1 + setup.py | 2 +- tests/test_config.py | 24 ++++++ tests/test_protected_endpoints.py | 110 ++++++++++++++++++++++++++ 9 files changed, 177 insertions(+), 12 deletions(-) diff --git a/docs/options.rst b/docs/options.rst index 1d2811eb..b7d0d0a2 100644 --- a/docs/options.rst +++ b/docs/options.rst @@ -22,8 +22,10 @@ General Options: ``JWT_REFRESH_TOKEN_EXPIRES`` How long a refresh token should live before it expires. This takes a ``datetime.timedelta``, and defaults to 30 days ``JWT_ALGORITHM`` Which algorithm to sign the JWT with. `See here `_ - for the options. Defaults to ``'HS256'``. Note that Asymmetric - (Public-key) algorithms are not currently supported. + for the options. Defaults to ``'HS256'``. +``JWT_PUBLIC_KEY`` The public key needed for RSA and ECDSA based signing algorithms. + Has to be provided if any of ``RS*`` or ``ES*`` algorithms is used. + PEM format expected. ================================= ========================================= diff --git a/flask_jwt_extended/config.py b/flask_jwt_extended/config.py index ab6c0c1c..014041d6 100644 --- a/flask_jwt_extended/config.py +++ b/flask_jwt_extended/config.py @@ -3,6 +3,7 @@ import simplekv from flask import current_app +from jwt.algorithms import requires_cryptography class _Config(object): @@ -15,6 +16,18 @@ class _Config(object): object. All of these values are read only. """ + @property + def is_asymmetric(self): + return self.algorithm in requires_cryptography + + @property + def encode_key(self): + return self.secret_key + + @property + def decode_key(self): + return self.public_key if self.is_asymmetric else self.secret_key + @property def token_location(self): locations = current_app.config['JWT_TOKEN_LOCATION'] @@ -172,6 +185,17 @@ def secret_key(self): raise RuntimeError('flask SECRET_KEY must be set') return key + @property + def public_key(self): + key = None + if self.algorithm in requires_cryptography: + key = current_app.config.get('JWT_PUBLIC_KEY', None) + if not key: + raise RuntimeError('JWT_PUBLIC_KEY must be set to use ' + 'asymmetric cryptography algorith ' + '"{crypto_algorithm}"'.format(crypto_algorithm=self.algorithm)) + return key + @property def cookie_max_age(self): # Returns the appropiate value for max_age for flask set_cookies. If @@ -180,3 +204,5 @@ def cookie_max_age(self): return None if self.session_cookie else 2147483647 # 2^31 config = _Config() + + diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index 66fcb0b9..2583706d 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -138,10 +138,12 @@ def _set_default_configuration_options(app): app.config.setdefault('JWT_REFRESH_TOKEN_EXPIRES', datetime.timedelta(days=30)) # What algorithm to use to sign the token. See here for a list of options: - # https://github.com/jpadilla/pyjwt/blob/master/jwt/api_jwt.py (note - # that public private key is not yet supported in this extension) + # https://github.com/jpadilla/pyjwt/blob/master/jwt/api_jwt.py app.config.setdefault('JWT_ALGORITHM', 'HS256') + # must be set if using asymmetric cryptography algorithm (RS* or EC*) + app.config.setdefault('JWT_PUBLIC_KEY', None) + # Options for blacklisting/revoking tokens app.config.setdefault('JWT_BLACKLIST_ENABLED', False) app.config.setdefault('JWT_BLACKLIST_STORE', None) @@ -251,7 +253,7 @@ def create_refresh_token(self, identity): """ refresh_token = encode_refresh_token( identity=self._user_identity_callback(identity), - secret=config.secret_key, + secret=config.encode_key, algorithm=config.algorithm, expires_delta=config.refresh_expires, csrf=config.csrf_protect @@ -259,7 +261,7 @@ def create_refresh_token(self, identity): # If blacklisting is enabled, store this token in our key-value store if config.blacklist_enabled: - decoded_token = decode_jwt(refresh_token, config.secret_key, + decoded_token = decode_jwt(refresh_token, config.decode_key, config.algorithm, csrf=config.csrf_protect) store_token(decoded_token, revoked=False) return refresh_token @@ -282,7 +284,7 @@ def create_access_token(self, identity, fresh=False): """ access_token = encode_access_token( identity=self._user_identity_callback(identity), - secret=config.secret_key, + secret=config.encode_key, algorithm=config.algorithm, expires_delta=config.access_expires, fresh=fresh, @@ -290,7 +292,7 @@ def create_access_token(self, identity, fresh=False): csrf=config.csrf_protect ) if config.blacklist_enabled and config.blacklist_access_tokens: - decoded_token = decode_jwt(access_token, config.secret_key, + decoded_token = decode_jwt(access_token, config.decode_key, config.algorithm, csrf=config.csrf_protect) store_token(decoded_token, revoked=False) return access_token diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 018300fd..01813cec 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -51,7 +51,7 @@ def create_refresh_token(*args, **kwargs): def get_csrf_token(encoded_token): - token = decode_jwt(encoded_token, config.secret_key, config.algorithm, csrf=True) + token = decode_jwt(encoded_token, config.decode_key, config.algorithm, csrf=True) return token['csrf'] diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 14b4f67d..596c5ca0 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -98,7 +98,7 @@ def _decode_jwt_from_headers(): raise InvalidHeaderError(msg) token = parts[1] - return decode_jwt(token, config.secret_key, config.algorithm, csrf=False) + return decode_jwt(token, config.decode_key, config.algorithm, csrf=False) def _decode_jwt_from_cookies(request_type): @@ -115,7 +115,7 @@ def _decode_jwt_from_cookies(request_type): decoded_token = decode_jwt( encoded_token=encoded_token, - secret=config.secret_key, + secret=config.decode_key, algorithm=config.algorithm, csrf=config.csrf_protect ) diff --git a/requirements.txt b/requirements.txt index 0d39a180..c4f5714a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ alabaster==0.7.9 Babel==2.3.4 click==6.6 coverage==4.2 +cryptography==1.8.1 docutils==0.12 Flask==0.11.1 imagesize==0.7.1 diff --git a/setup.py b/setup.py index 10896da8..4f74e831 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ packages=['flask_jwt_extended'], zip_safe=False, platforms='any', - install_requires=['Flask', 'PyJWT', 'simplekv'], + install_requires=['Flask', 'PyJWT', 'simplekv', 'cryptography'], classifiers=[ 'Development Status :: 4 - Beta', 'Environment :: Web Environment', diff --git a/tests/test_config.py b/tests/test_config.py index 91da0395..b4ce7769 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -45,11 +45,15 @@ def test_default_configs(self): self.assertEqual(config.access_expires, timedelta(minutes=15)) self.assertEqual(config.refresh_expires, timedelta(days=30)) self.assertEqual(config.algorithm, 'HS256') + self.assertEqual(config.is_asymmetric, False) self.assertEqual(config.blacklist_enabled, False) self.assertEqual(config.blacklist_checks, 'refresh') self.assertEqual(config.blacklist_access_tokens, False) self.assertEqual(config.secret_key, self.app.secret_key) + self.assertEqual(config.public_key, None) + self.assertEqual(config.encode_key, self.app.secret_key) + self.assertEqual(config.decode_key, self.app.secret_key) self.assertEqual(config.cookie_max_age, None) with self.assertRaises(RuntimeError): @@ -166,6 +170,15 @@ def test_invalid_config_options(self): with self.assertRaises(RuntimeError): config.secret_key + self.app.secret_key = None + with self.assertRaises(RuntimeError): + config.encode_key + + self.app.config['JWT_ALGORITHM'] = 'RS256' + self.app.config['JWT_PUBLIC_KEY'] = None + with self.assertRaises(RuntimeError): + config.decode_key + def test_depreciated_options(self): self.app.config['JWT_CSRF_HEADER_NAME'] = 'Auth' @@ -205,3 +218,14 @@ def test_special_config_options(self): self.app.config['JWT_TOKEN_LOCATION'] = ['cookies'] self.app.config['JWT_COOKIE_CSRF_PROTECT'] = False self.assertEqual(config.csrf_protect, False) + + def test_asymmetric_encryption_key_handling(self): + self.app.secret_key = 'MOCK_RSA_PRIVATE_KEY' + self.app.config['JWT_PUBLIC_KEY'] = 'MOCK_RSA_PUBLIC_KEY' + self.app.config['JWT_ALGORITHM'] = 'RS256' + + with self.app.test_request_context(): + self.assertEqual(config.is_asymmetric, True) + self.assertEqual(config.secret_key, 'MOCK_RSA_PRIVATE_KEY') + self.assertEqual(config.encode_key, 'MOCK_RSA_PRIVATE_KEY') + self.assertEqual(config.decode_key, 'MOCK_RSA_PUBLIC_KEY') diff --git a/tests/test_protected_endpoints.py b/tests/test_protected_endpoints.py index 0d3f9057..7ec77cc1 100644 --- a/tests/test_protected_endpoints.py +++ b/tests/test_protected_endpoints.py @@ -814,3 +814,113 @@ def test_accessing_endpoint_without_jwt(self): data = json.loads(response.get_data(as_text=True)) self.assertEqual(status_code, 401) self.assertIn('msg', data) + + +# random 1024bit RSA keypair +RSA_PRIVATE = """ +-----BEGIN RSA PRIVATE KEY----- +MIICXgIBAAKBgQDN+p9a9oMyqRzkae8yLdJcEK0O0WesH6JiMz+KDrpUwAoAM/KP +DnxFnROJDSBHyHEmPVn5x8GqV5lQ9+6l97jdEEcPo6wkshycM82fgcxOmvtAy4Uo +xq/AeplYqplhcUTGVuo4ZldOLmN8ksGmzhWpsOdT0bkYipHCn5sWZxd21QIDAQAB +AoGBAMJ0++KVXXEDZMpjFDWsOq898xNNMHG3/8ZzmWXN161RC1/7qt/RjhLuYtX9 +NV9vZRrzyrDcHAKj5pMhLgUzpColKzvdG2vKCldUs2b0c8HEGmjsmpmgoI1Tdf9D +G1QK+q9pKHlbj/MLr4vZPX6xEwAFeqRKlzL30JPD+O6mOXs1AkEA8UDzfadH1Y+H +bcNN2COvCqzqJMwLNRMXHDmUsjHfR2gtzk6D5dDyEaL+O4FLiQCaNXGWWoDTy/HJ +Clh1Z0+KYwJBANqRtJ+RvdgHMq0Yd45MMyy0ODGr1B3PoRbUK8EdXpyUNMi1g3iJ +tXMbLywNkTfcEXZTlbbkVYwrEl6P2N1r42cCQQDb9UQLBEFSTRJE2RRYQ/CL4yt3 +cTGmqkkfyr/v19ii2jEpMBzBo8eQnPL+fdvIhWwT3gQfb+WqxD9v10bzcmnRAkEA +mzTgeHd7wg3KdJRtQYTmyhXn2Y3VAJ5SG+3qbCW466NqoCQVCeFwEh75rmSr/Giv +lcDhDZCzFuf3EWNAcmuMfQJARsWfM6q7v2p6vkYLLJ7+VvIwookkr6wymF5Zgb9d +E6oTM2EeUPSyyrj5IdsU2JCNBH1m3JnUflz8p8/NYCoOZg== +-----END RSA PRIVATE KEY----- +""" +RSA_PUBLIC = """ +-----BEGIN RSA PUBLIC KEY----- +MIGJAoGBAM36n1r2gzKpHORp7zIt0lwQrQ7RZ6wfomIzP4oOulTACgAz8o8OfEWd +E4kNIEfIcSY9WfnHwapXmVD37qX3uN0QRw+jrCSyHJwzzZ+BzE6a+0DLhSjGr8B6 +mViqmWFxRMZW6jhmV04uY3ySwabOFamw51PRuRiKkcKfmxZnF3bVAgMBAAE= +-----END RSA PUBLIC KEY----- +""" + +class TestEndpointsWithAssymmetricCrypto(unittest.TestCase): + + def setUp(self): + self.app = Flask(__name__) + self.app.secret_key = RSA_PRIVATE + self.app.config['JWT_PUBLIC_KEY'] = RSA_PUBLIC + self.app.config['JWT_ALGORITHM'] = 'RS256' + self.app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(seconds=1) + self.app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(seconds=1) + self.jwt_manager = JWTManager(self.app) + self.client = self.app.test_client() + + @self.app.route('/auth/login', methods=['POST']) + def login(): + ret = { + 'access_token': create_access_token('test', fresh=True), + 'refresh_token': create_refresh_token('test') + } + return jsonify(ret), 200 + + @self.app.route('/auth/refresh', methods=['POST']) + @jwt_refresh_token_required + def refresh(): + username = get_jwt_identity() + ret = {'access_token': create_access_token(username, fresh=False)} + return jsonify(ret), 200 + + @self.app.route('/auth/fresh-login', methods=['POST']) + def fresh_login(): + ret = {'access_token': create_access_token('test', fresh=True)} + return jsonify(ret), 200 + + @self.app.route('/protected') + @jwt_required + def protected(): + return jsonify({'msg': "hello world"}) + + @self.app.route('/fresh-protected') + @fresh_jwt_required + def fresh_protected(): + return jsonify({'msg': "fresh hello world"}) + + def _jwt_post(self, url, jwt): + response = self.client.post(url, content_type='application/json', + headers={'Authorization': 'Bearer {}'.format(jwt)}) + status_code = response.status_code + data = json.loads(response.get_data(as_text=True)) + return status_code, data + + def _jwt_get(self, url, jwt, header_name='Authorization', header_type='Bearer'): + header_type = '{} {}'.format(header_type, jwt).strip() + response = self.client.get(url, headers={header_name: header_type}) + status_code = response.status_code + data = json.loads(response.get_data(as_text=True)) + return status_code, data + + def test_login(self): + response = self.client.post('/auth/login') + status_code = response.status_code + data = json.loads(response.get_data(as_text=True)) + self.assertEqual(status_code, 200) + self.assertIn('access_token', data) + self.assertIn('refresh_token', data) + + def test_fresh_login(self): + response = self.client.post('/auth/fresh-login') + status_code = response.status_code + data = json.loads(response.get_data(as_text=True)) + self.assertEqual(status_code, 200) + self.assertIn('access_token', data) + self.assertNotIn('refresh_token', data) + + def test_refresh(self): + response = self.client.post('/auth/login') + data = json.loads(response.get_data(as_text=True)) + access_token = data['access_token'] + refresh_token = data['refresh_token'] + + status_code, data = self._jwt_post('/auth/refresh', refresh_token) + self.assertEqual(status_code, 200) + self.assertIn('access_token', data) + self.assertNotIn('refresh_token', data)