diff --git a/docs/changing_default_behavior.rst b/docs/changing_default_behavior.rst index 286e979e..60b233a2 100644 --- a/docs/changing_default_behavior.rst +++ b/docs/changing_default_behavior.rst @@ -1,6 +1,9 @@ Changing Default Behaviors ========================== +Changing callback functions +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + We provide what we think are sensible behaviors when attempting to access a protected endpoint. If the access token is not valid for any reason (missing, expired, tampered with, etc) we will return json in the format of {'msg': 'why @@ -34,3 +37,25 @@ Possible loader functions are: * - **revoked_token_loader** - Function to call when a revoked token accesses a protected endpoint - None + +Dynamic token expires time +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can also change the expires time for a token via the **expires_delta** kwarg +in the **create_refresh_token** and **create_access_token** functions. This takes +a **datetime.timedelta** and overrides the **JWT_REFRESH_TOKEN_EXPIRES** and +**JWT_ACCESS_TOKEN_EXPIRES** options. This can be useful if you have different +use cases for different tokens. An example of this might be you use short lived +access tokens used in your web application, but you allow the creation of long +lived access tokens that other developers can generate and use to interact with +your api in their programs. + +.. code-block:: python + + @app.route('/create-dev-token', methods=[POST]) + @jwt_required + def create_dev_token(): + username = get_jwt_identity() + expires = datatime.timedelta(days=365) + token = create_access_token(username, expires_delta=expires) + return jsonify({'token': token}), 201 diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index d84059b3..fdf61238 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -244,7 +244,7 @@ def revoked_token_loader(self, callback): self._revoked_token_callback = callback return callback - def create_refresh_token(self, identity): + def create_refresh_token(self, identity, expires_delta=None): """ Creates a new refresh token @@ -256,13 +256,19 @@ def create_refresh_token(self, identity): query disk twice, once for initially finding the identity in your login endpoint, and once for setting addition data in the JWT via the user_claims_loader + :param expires_delta: A datetime.timedelta for how long this token should + last before it expires. If this is None, it will + use the 'JWT_REFRESH_TOKEN_EXPIRES` config value :return: A new refresh token """ + if expires_delta is None: + expires_delta = config.refresh_expires + refresh_token = encode_refresh_token( identity=self._user_identity_callback(identity), secret=config.encode_key, algorithm=config.algorithm, - expires_delta=config.refresh_expires, + expires_delta=expires_delta, csrf=config.csrf_protect ) @@ -273,7 +279,7 @@ def create_refresh_token(self, identity): store_token(decoded_token, revoked=False) return refresh_token - def create_access_token(self, identity, fresh=False): + def create_access_token(self, identity, fresh=False, expires_delta=None): """ Creates a new access token @@ -287,13 +293,19 @@ def create_access_token(self, identity, fresh=False): in the JWT via the user_claims_loader :param fresh: If this token should be marked as fresh, and can thus access fresh_jwt_required protected endpoints. Defaults to False + :param expires_delta: A datetime.timedelta for how long this token should + last before it expires. If this is None, it will + use the 'JWT_ACCESS_TOKEN_EXPIRES` config value :return: A new access token """ + if expires_delta is None: + expires_delta = config.access_expires + access_token = encode_access_token( identity=self._user_identity_callback(identity), secret=config.encode_key, algorithm=config.algorithm, - expires_delta=config.access_expires, + expires_delta=expires_delta, fresh=fresh, user_claims=self._user_claims_callback(identity), csrf=config.csrf_protect diff --git a/tests/test_protected_endpoints.py b/tests/test_protected_endpoints.py index a4cddaf7..9f462ea1 100644 --- a/tests/test_protected_endpoints.py +++ b/tests/test_protected_endpoints.py @@ -33,6 +33,16 @@ def login(): } return jsonify(ret), 200 + @self.app.route('/auth/login2', methods=['POST']) + def login2(): + expires = timedelta(minutes=5) + ret = { + 'access_token': create_access_token('test', fresh=True, + expires_delta=expires), + 'refresh_token': create_refresh_token('test', expires_delta=expires), + } + return jsonify(ret), 200 + @self.app.route('/auth/refresh', methods=['POST']) @jwt_refresh_token_required def refresh(): @@ -342,6 +352,26 @@ def test_bad_tokens(self): self.assertEqual(status_code, 422) self.assertIn('msg', data) + def test_expires_time_override(self): + # Test access token + response = self.client.post('/auth/login2') + data = json.loads(response.get_data(as_text=True)) + access_token = data['access_token'] + time.sleep(2) + status_code, data = self._jwt_get('/partially-protected', access_token) + self.assertEqual(status_code, 200) + self.assertEqual(data, {'msg': 'protected hello world'}) + + # Test refresh token + response = self.client.post('/auth/login2') + data = json.loads(response.get_data(as_text=True)) + refresh_token = data['refresh_token'] + time.sleep(2) + status_code, data = self._jwt_post('/auth/refresh', refresh_token) + self.assertEqual(status_code, 200) + self.assertIn('access_token', data) + self.assertNotIn('msg', data) + def test_optional_jwt_bad_tokens(self): # Test expired access token response = self.client.post('/auth/login')