Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flask_jwt_extended/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .jwt_manager import JWTManager
from .view_decorators import (
jwt_required, fresh_jwt_required, jwt_refresh_token_required
jwt_required, fresh_jwt_required, jwt_refresh_token_required,
jwt_optional
)
from .utils import (
create_refresh_token, create_access_token, get_jwt_identity,
Expand Down
26 changes: 26 additions & 0 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,32 @@ def wrapper(*args, **kwargs):
return wrapper


def jwt_optional(fn):
"""
If you decorate a view with this, it will check the request for a valid
JWT and put it into the Flask application context before calling the view.
If no authorization header is present, the view will be called without the
application context being changed. Other authentication errors are not
affected.

:param fn: The view function to decorate
"""
@wraps(fn)
def wrapper(*args, **kwargs):
try:
# If an acceptable JWT is found in the request, put it into
# the application context
jwt_data = _decode_jwt_from_request(request_type='access')
ctx_stack.top.jwt = jwt_data
except NoAuthorizationError:
# Allow request to proceed if no authorization header is present
# in the request, but don't modify application context
pass
# Return the decorated function in either case
return fn(*args, **kwargs)
return wrapper


def fresh_jwt_required(fn):
"""
If you decorate a vew with this, it will ensure that the requester has a
Expand Down
178 changes: 177 additions & 1 deletion tests/test_protected_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
get_jwt_identity, set_refresh_cookies, set_access_cookies, unset_jwt_cookies
from flask_jwt_extended import JWTManager, create_refresh_token, \
jwt_refresh_token_required, create_access_token, fresh_jwt_required, \
jwt_required, get_raw_jwt
jwt_optional, jwt_required, get_raw_jwt


class TestEndpoints(unittest.TestCase):
Expand Down Expand Up @@ -55,6 +55,14 @@ def protected():
def fresh_protected():
return jsonify({'msg': "fresh hello world"})

@self.app.route('/partially-protected')
@jwt_optional
def partially_protected():
if get_jwt_identity():
return jsonify({'msg': "protected hello world"})
return jsonify({'msg': "unprotected hello world"})


def _jwt_post(self, url, jwt):
response = self.client.post(url, content_type='application/json',
headers={'Authorization': 'Bearer {}'.format(jwt)})
Expand Down Expand Up @@ -124,6 +132,32 @@ def test_jwt_required(self):
self.assertEqual(status, 200)
self.assertEqual(data, {'msg': 'hello world'})

def test_jwt_optional_no_jwt(self):
response = self.client.get('/partially-protected')
data = json.loads(response.get_data(as_text=True))
status = response.status_code
self.assertEqual(status, 200)
self.assertEqual(data, {'msg': 'unprotected hello world'})

def test_jwt_optional_with_jwt(self):
response = self.client.post('/auth/login')
data = json.loads(response.get_data(as_text=True))
fresh_access_token = data['access_token']
refresh_token = data['refresh_token']

# Test it works with a fresh token
status, data = self._jwt_get('/partially-protected',
fresh_access_token)
self.assertEqual(data, {'msg': 'protected hello world'})
self.assertEqual(status, 200)

# Test it works with a non-fresh access token
_, data = self._jwt_post('/auth/refresh', refresh_token)
non_fresh_token = data['access_token']
status, data = self._jwt_get('/partially-protected', non_fresh_token)
self.assertEqual(status, 200)
self.assertEqual(data, {'msg': 'protected hello world'})

def test_jwt_required_wrong_token(self):
response = self.client.post('/auth/login')
data = json.loads(response.get_data(as_text=True))
Expand All @@ -133,6 +167,15 @@ def test_jwt_required_wrong_token(self):
status, text = self._jwt_get('/protected', refresh_token)
self.assertEqual(status, 422)

def test_jwt_optional_wrong_token(self):
response = self.client.post('/auth/login')
data = json.loads(response.get_data(as_text=True))
refresh_token = data['refresh_token']

# Shouldn't work with a refresh token
status, text = self._jwt_get('/partially-protected', refresh_token)
self.assertEqual(status, 422)

def test_fresh_jwt_required(self):
response = self.client.post('/auth/login')
data = json.loads(response.get_data(as_text=True))
Expand Down Expand Up @@ -209,6 +252,38 @@ def test_bad_jwt_requests(self):
self.assertEqual(status_code, 422)
self.assertIn('msg', data)

def test_optional_bad_jwt_requests(self):
response = self.client.post('/auth/login')
data = json.loads(response.get_data(as_text=True))
access_token = data['access_token']

# Test with missing type in authorization header
auth_header = access_token
response = self.client.get('/partially-protected',
headers={'Authorization': auth_header})
data = json.loads(response.get_data(as_text=True))
status_code = response.status_code
self.assertEqual(status_code, 422)
self.assertIn('msg', data)

# Test with type not being Bearer in authorization header
auth_header = "BANANA {}".format(access_token)
response = self.client.get('/partially-protected',
headers={'Authorization': auth_header})
data = json.loads(response.get_data(as_text=True))
status_code = response.status_code
self.assertEqual(status_code, 422)
self.assertIn('msg', data)

# Test with too many items in auth header
auth_header = "Bearer {} BANANA".format(access_token)
response = self.client.get('/partially-protected',
headers={'Authorization': auth_header})
data = json.loads(response.get_data(as_text=True))
status_code = response.status_code
self.assertEqual(status_code, 422)
self.assertIn('msg', data)

def test_bad_tokens(self):
# Test expired access token
response = self.client.post('/auth/login')
Expand Down Expand Up @@ -267,6 +342,54 @@ def test_bad_tokens(self):
self.assertEqual(status_code, 422)
self.assertIn('msg', data)

def test_optional_jwt_bad_tokens(self):
# Test expired access token
response = self.client.post('/auth/login')
data = json.loads(response.get_data(as_text=True))
access_token = data['access_token']
status_code, data = self._jwt_get('/partially-protected', access_token)
self.assertEqual(status_code, 200)
self.assertEqual(data, {'msg': 'protected hello world'})
time.sleep(2)
status_code, data = self._jwt_get('/partially-protected', access_token)
self.assertEqual(status_code, 401)
self.assertIn('msg', data)

# Test Bogus token
auth_header = "Bearer {}".format('this_is_totally_an_access_token')
response = self.client.get('/partially-protected',
headers={'Authorization': auth_header})
data = json.loads(response.get_data(as_text=True))
status_code = response.status_code
self.assertEqual(status_code, 422)
self.assertIn('msg', data)

# Test token that was signed with a different key
with self.app.test_request_context():
token = encode_access_token('foo', 'newsecret', 'HS256',
timedelta(minutes=5), True, {},
csrf=False)
auth_header = "Bearer {}".format(token)
response = self.client.get('/partially-protected',
headers={'Authorization': auth_header})
data = json.loads(response.get_data(as_text=True))
status_code = response.status_code
self.assertEqual(status_code, 422)
self.assertIn('msg', data)

# Test with valid token that is missing required claims
now = datetime.utcnow()
token_data = {'exp': now + timedelta(minutes=5)}
encoded_token = jwt.encode(token_data, self.app.config['SECRET_KEY'],
self.app.config['JWT_ALGORITHM']).decode('utf-8')
auth_header = "Bearer {}".format(encoded_token)
response = self.client.get('/partially-protected',
headers={'Authorization': auth_header})
data = json.loads(response.get_data(as_text=True))
status_code = response.status_code
self.assertEqual(status_code, 422)
self.assertIn('msg', data)

def test_jwt_identity_claims(self):
# Setup custom claims
@self.jwt_manager.user_claims_loader
Expand Down Expand Up @@ -349,6 +472,43 @@ def test_different_headers(self):
header_type='Bearer')
self.assertIn('msg', data)
self.assertEqual(status, 401)
self.assertEqual(data, {'msg': 'Missing Auth Header'})

def test_different_headers_jwt_optional(self):
response = self.client.post('/auth/login')
data = json.loads(response.get_data(as_text=True))
access_token = data['access_token']

self.app.config['JWT_HEADER_TYPE'] = 'JWT'
status, data = self._jwt_get('/partially-protected', access_token,
header_type='JWT')
self.assertEqual(data, {'msg': 'protected hello world'})
self.assertEqual(status, 200)

self.app.config['JWT_HEADER_TYPE'] = ''
status, data = self._jwt_get('/partially-protected', access_token,
header_type='')
self.assertEqual(data, {'msg': 'protected hello world'})
self.assertEqual(status, 200)

self.app.config['JWT_HEADER_TYPE'] = ''
status, data = self._jwt_get('/partially-protected', access_token,
header_type='Bearer')
self.assertIn('msg', data)
self.assertEqual(status, 422)

self.app.config['JWT_HEADER_TYPE'] = 'Bearer'
self.app.config['JWT_HEADER_NAME'] = 'Auth'
status, data = self._jwt_get('/partially-protected', access_token,
header_name='Auth', header_type='Bearer')
self.assertEqual(data, {'msg': 'protected hello world'})
self.assertEqual(status, 200)

status, data = self._jwt_get('/partially-protected', access_token,
header_name='Authorization',
header_type='Bearer')
self.assertEqual(status, 200)
self.assertEqual(data, {'msg': 'unprotected hello world'})

def test_cookie_methods_fail_with_headers_configured(self):
app = Flask(__name__)
Expand Down Expand Up @@ -401,6 +561,22 @@ def test_jwt_with_different_algorithm(self):
self.assertEqual(status, 422)
self.assertIn('msg', data)

def test_optional_jwt_with_different_algorithm(self):
self.app.config['JWT_ALGORITHM'] = 'HS256'
self.app.secret_key = 'test_secret'
access_token = encode_access_token(
identity='bobdobbs',
secret='test_secret',
algorithm='HS512',
expires_delta=timedelta(minutes=5),
fresh=True,
user_claims={},
csrf=False
)
status, data = self._jwt_get('/partially-protected', access_token)
self.assertEqual(status, 422)
self.assertIn('msg', data)


class TestEndpointsWithCookies(unittest.TestCase):

Expand Down