From a3f4c4882cd433f93caa164334b8965941f5d2c4 Mon Sep 17 00:00:00 2001 From: Nathan Thorpe Date: Wed, 19 Apr 2023 10:21:49 -0700 Subject: [PATCH 1/2] reauthenticate if refresh token expiry is too close and if endpoint changes --- cirro/api/auth/oauth_client.py | 19 ++++++++++++++++--- cirro/api/models/auth.py | 2 ++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/cirro/api/auth/oauth_client.py b/cirro/api/auth/oauth_client.py index 51f916b1..67a30cd1 100644 --- a/cirro/api/auth/oauth_client.py +++ b/cirro/api/auth/oauth_client.py @@ -3,7 +3,7 @@ import sys import threading import time -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path from typing import Optional @@ -12,6 +12,7 @@ import requests from botocore.exceptions import ClientError from msal_extensions import FilePersistence +from msal_extensions.persistence import BasePersistence from requests.auth import AuthBase from cirro.api.auth.base import AuthInfo, RequestAuthWrapper @@ -85,13 +86,24 @@ class ClientAuth(AuthInfo): def __init__(self, client_id: str, region: str, auth_endpoint: str, enable_cache=True): self.client_id = client_id self.region = region - self._token_info = None - self._persistence = None + self._token_info: Optional[OAuthTokenResponse] = None + self._persistence: Optional[BasePersistence] = None if enable_cache: self._persistence = _build_token_persistence(str(TOKEN_PATH), fallback_to_plaintext=True) self._token_info = self._load_token_info() + # Check saved token for endpoint and refresh expiry + if self._token_info: + if self._token_info.get('client_id') != client_id: + logger.debug('Different client ID found, clearing saved token info') + self._clear_token_info() + + refresh_expiry = datetime.fromtimestamp(self._token_info.get('refresh_expires_in')) + if refresh_expiry < datetime.now() - timedelta(hours=12): + logger.debug('Refresh token expiry is too soon, re-authenticating') + self._clear_token_info() + if not self._token_info: self._token_info = _authenticate(client_id=client_id, auth_endpoint=auth_endpoint) @@ -162,3 +174,4 @@ def _clear_token_info(self): return Path(self._persistence.get_location()).unlink(missing_ok=True) + self._token_info = None diff --git a/cirro/api/models/auth.py b/cirro/api/models/auth.py index 1a3f65b3..77fc5aa2 100644 --- a/cirro/api/models/auth.py +++ b/cirro/api/models/auth.py @@ -24,4 +24,6 @@ class OAuthTokenResponse(TypedDict): id_token: str token_type: str expires_in: int + refresh_expires_in: int + client_id: str message: Optional[str] From ccbcb4d39f34ee29adda240b97f32ae02d817a48 Mon Sep 17 00:00:00 2001 From: Nathan Thorpe Date: Wed, 19 Apr 2023 12:06:42 -0700 Subject: [PATCH 2/2] improve --- cirro/api/auth/oauth_client.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/cirro/api/auth/oauth_client.py b/cirro/api/auth/oauth_client.py index 67a30cd1..29053661 100644 --- a/cirro/api/auth/oauth_client.py +++ b/cirro/api/auth/oauth_client.py @@ -93,14 +93,16 @@ def __init__(self, client_id: str, region: str, auth_endpoint: str, enable_cache self._persistence = _build_token_persistence(str(TOKEN_PATH), fallback_to_plaintext=True) self._token_info = self._load_token_info() - # Check saved token for endpoint and refresh expiry - if self._token_info: - if self._token_info.get('client_id') != client_id: - logger.debug('Different client ID found, clearing saved token info') - self._clear_token_info() + # Check saved token for change in endpoint + if self._token_info and self._token_info.get('client_id') != client_id: + logger.debug('Different client ID found, clearing saved token info') + self._clear_token_info() - refresh_expiry = datetime.fromtimestamp(self._token_info.get('refresh_expires_in')) - if refresh_expiry < datetime.now() - timedelta(hours=12): + # Check saved token for refresh token expiry + if self._token_info and self._token_info.get('refresh_expires_in'): + refresh_expiry_threshold = datetime.fromtimestamp(self._token_info.get('refresh_expires_in'))\ + - timedelta(hours=12) + if refresh_expiry_threshold < datetime.now(): logger.debug('Refresh token expiry is too soon, re-authenticating') self._clear_token_info()