diff --git a/flask_jwt_extended/__init__.py b/flask_jwt_extended/__init__.py index 6a4cfe07..46867ac7 100644 --- a/flask_jwt_extended/__init__.py +++ b/flask_jwt_extended/__init__.py @@ -1,6 +1,7 @@ from .jwt_manager import JWTManager from .view_decorators import ( - jwt_required, fresh_jwt_required, jwt_refresh_token_required + jwt_required, fresh_jwt_required, jwt_refresh_token_required, + jwt_optional ) from .utils import ( create_refresh_token, create_access_token, get_jwt_identity, diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 596c5ca0..5c417ded 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -36,6 +36,32 @@ def wrapper(*args, **kwargs): return wrapper +def jwt_optional(fn): + """ + If you decorate a view with this, it will check the request for a valid + JWT and put it into the Flask application context before calling the view. + If no authorization header is present, the view will be called without the + application context being changed. Other authentication errors are not + affected. + + :param fn: The view function to decorate + """ + @wraps(fn) + def wrapper(*args, **kwargs): + try: + # If an acceptable JWT is found in the request, put it into + # the application context + jwt_data = _decode_jwt_from_request(request_type='access') + ctx_stack.top.jwt = jwt_data + except NoAuthorizationError: + # Allow request to proceed if no authorization header is present + # in the request, but don't modify application context + pass + # Return the decorated function in either case + return fn(*args, **kwargs) + return wrapper + + def fresh_jwt_required(fn): """ If you decorate a vew with this, it will ensure that the requester has a diff --git a/tests/test_protected_endpoints.py b/tests/test_protected_endpoints.py index ad06c1f2..a4cddaf7 100644 --- a/tests/test_protected_endpoints.py +++ b/tests/test_protected_endpoints.py @@ -11,7 +11,7 @@ get_jwt_identity, set_refresh_cookies, set_access_cookies, unset_jwt_cookies from flask_jwt_extended import JWTManager, create_refresh_token, \ jwt_refresh_token_required, create_access_token, fresh_jwt_required, \ - jwt_required, get_raw_jwt + jwt_optional, jwt_required, get_raw_jwt class TestEndpoints(unittest.TestCase): @@ -55,6 +55,14 @@ def protected(): def fresh_protected(): return jsonify({'msg': "fresh hello world"}) + @self.app.route('/partially-protected') + @jwt_optional + def partially_protected(): + if get_jwt_identity(): + return jsonify({'msg': "protected hello world"}) + return jsonify({'msg': "unprotected hello world"}) + + def _jwt_post(self, url, jwt): response = self.client.post(url, content_type='application/json', headers={'Authorization': 'Bearer {}'.format(jwt)}) @@ -124,6 +132,32 @@ def test_jwt_required(self): self.assertEqual(status, 200) self.assertEqual(data, {'msg': 'hello world'}) + def test_jwt_optional_no_jwt(self): + response = self.client.get('/partially-protected') + data = json.loads(response.get_data(as_text=True)) + status = response.status_code + self.assertEqual(status, 200) + self.assertEqual(data, {'msg': 'unprotected hello world'}) + + def test_jwt_optional_with_jwt(self): + response = self.client.post('/auth/login') + data = json.loads(response.get_data(as_text=True)) + fresh_access_token = data['access_token'] + refresh_token = data['refresh_token'] + + # Test it works with a fresh token + status, data = self._jwt_get('/partially-protected', + fresh_access_token) + self.assertEqual(data, {'msg': 'protected hello world'}) + self.assertEqual(status, 200) + + # Test it works with a non-fresh access token + _, data = self._jwt_post('/auth/refresh', refresh_token) + non_fresh_token = data['access_token'] + status, data = self._jwt_get('/partially-protected', non_fresh_token) + self.assertEqual(status, 200) + self.assertEqual(data, {'msg': 'protected hello world'}) + def test_jwt_required_wrong_token(self): response = self.client.post('/auth/login') data = json.loads(response.get_data(as_text=True)) @@ -133,6 +167,15 @@ def test_jwt_required_wrong_token(self): status, text = self._jwt_get('/protected', refresh_token) self.assertEqual(status, 422) + def test_jwt_optional_wrong_token(self): + response = self.client.post('/auth/login') + data = json.loads(response.get_data(as_text=True)) + refresh_token = data['refresh_token'] + + # Shouldn't work with a refresh token + status, text = self._jwt_get('/partially-protected', refresh_token) + self.assertEqual(status, 422) + def test_fresh_jwt_required(self): response = self.client.post('/auth/login') data = json.loads(response.get_data(as_text=True)) @@ -209,6 +252,38 @@ def test_bad_jwt_requests(self): self.assertEqual(status_code, 422) self.assertIn('msg', data) + def test_optional_bad_jwt_requests(self): + response = self.client.post('/auth/login') + data = json.loads(response.get_data(as_text=True)) + access_token = data['access_token'] + + # Test with missing type in authorization header + auth_header = access_token + response = self.client.get('/partially-protected', + headers={'Authorization': auth_header}) + data = json.loads(response.get_data(as_text=True)) + status_code = response.status_code + self.assertEqual(status_code, 422) + self.assertIn('msg', data) + + # Test with type not being Bearer in authorization header + auth_header = "BANANA {}".format(access_token) + response = self.client.get('/partially-protected', + headers={'Authorization': auth_header}) + data = json.loads(response.get_data(as_text=True)) + status_code = response.status_code + self.assertEqual(status_code, 422) + self.assertIn('msg', data) + + # Test with too many items in auth header + auth_header = "Bearer {} BANANA".format(access_token) + response = self.client.get('/partially-protected', + headers={'Authorization': auth_header}) + data = json.loads(response.get_data(as_text=True)) + status_code = response.status_code + self.assertEqual(status_code, 422) + self.assertIn('msg', data) + def test_bad_tokens(self): # Test expired access token response = self.client.post('/auth/login') @@ -267,6 +342,54 @@ def test_bad_tokens(self): self.assertEqual(status_code, 422) self.assertIn('msg', data) + def test_optional_jwt_bad_tokens(self): + # Test expired access token + response = self.client.post('/auth/login') + data = json.loads(response.get_data(as_text=True)) + access_token = data['access_token'] + status_code, data = self._jwt_get('/partially-protected', access_token) + self.assertEqual(status_code, 200) + self.assertEqual(data, {'msg': 'protected hello world'}) + time.sleep(2) + status_code, data = self._jwt_get('/partially-protected', access_token) + self.assertEqual(status_code, 401) + self.assertIn('msg', data) + + # Test Bogus token + auth_header = "Bearer {}".format('this_is_totally_an_access_token') + response = self.client.get('/partially-protected', + headers={'Authorization': auth_header}) + data = json.loads(response.get_data(as_text=True)) + status_code = response.status_code + self.assertEqual(status_code, 422) + self.assertIn('msg', data) + + # Test token that was signed with a different key + with self.app.test_request_context(): + token = encode_access_token('foo', 'newsecret', 'HS256', + timedelta(minutes=5), True, {}, + csrf=False) + auth_header = "Bearer {}".format(token) + response = self.client.get('/partially-protected', + headers={'Authorization': auth_header}) + data = json.loads(response.get_data(as_text=True)) + status_code = response.status_code + self.assertEqual(status_code, 422) + self.assertIn('msg', data) + + # Test with valid token that is missing required claims + now = datetime.utcnow() + token_data = {'exp': now + timedelta(minutes=5)} + encoded_token = jwt.encode(token_data, self.app.config['SECRET_KEY'], + self.app.config['JWT_ALGORITHM']).decode('utf-8') + auth_header = "Bearer {}".format(encoded_token) + response = self.client.get('/partially-protected', + headers={'Authorization': auth_header}) + data = json.loads(response.get_data(as_text=True)) + status_code = response.status_code + self.assertEqual(status_code, 422) + self.assertIn('msg', data) + def test_jwt_identity_claims(self): # Setup custom claims @self.jwt_manager.user_claims_loader @@ -349,6 +472,43 @@ def test_different_headers(self): header_type='Bearer') self.assertIn('msg', data) self.assertEqual(status, 401) + self.assertEqual(data, {'msg': 'Missing Auth Header'}) + + def test_different_headers_jwt_optional(self): + response = self.client.post('/auth/login') + data = json.loads(response.get_data(as_text=True)) + access_token = data['access_token'] + + self.app.config['JWT_HEADER_TYPE'] = 'JWT' + status, data = self._jwt_get('/partially-protected', access_token, + header_type='JWT') + self.assertEqual(data, {'msg': 'protected hello world'}) + self.assertEqual(status, 200) + + self.app.config['JWT_HEADER_TYPE'] = '' + status, data = self._jwt_get('/partially-protected', access_token, + header_type='') + self.assertEqual(data, {'msg': 'protected hello world'}) + self.assertEqual(status, 200) + + self.app.config['JWT_HEADER_TYPE'] = '' + status, data = self._jwt_get('/partially-protected', access_token, + header_type='Bearer') + self.assertIn('msg', data) + self.assertEqual(status, 422) + + self.app.config['JWT_HEADER_TYPE'] = 'Bearer' + self.app.config['JWT_HEADER_NAME'] = 'Auth' + status, data = self._jwt_get('/partially-protected', access_token, + header_name='Auth', header_type='Bearer') + self.assertEqual(data, {'msg': 'protected hello world'}) + self.assertEqual(status, 200) + + status, data = self._jwt_get('/partially-protected', access_token, + header_name='Authorization', + header_type='Bearer') + self.assertEqual(status, 200) + self.assertEqual(data, {'msg': 'unprotected hello world'}) def test_cookie_methods_fail_with_headers_configured(self): app = Flask(__name__) @@ -401,6 +561,22 @@ def test_jwt_with_different_algorithm(self): self.assertEqual(status, 422) self.assertIn('msg', data) + def test_optional_jwt_with_different_algorithm(self): + self.app.config['JWT_ALGORITHM'] = 'HS256' + self.app.secret_key = 'test_secret' + access_token = encode_access_token( + identity='bobdobbs', + secret='test_secret', + algorithm='HS512', + expires_delta=timedelta(minutes=5), + fresh=True, + user_claims={}, + csrf=False + ) + status, data = self._jwt_get('/partially-protected', access_token) + self.assertEqual(status, 422) + self.assertIn('msg', data) + class TestEndpointsWithCookies(unittest.TestCase):