diff --git a/cirro/api/auth/oauth_client.py b/cirro/api/auth/oauth_client.py index 51f916b1..29053661 100644 --- a/cirro/api/auth/oauth_client.py +++ b/cirro/api/auth/oauth_client.py @@ -3,7 +3,7 @@ import sys import threading import time -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path from typing import Optional @@ -12,6 +12,7 @@ import requests from botocore.exceptions import ClientError from msal_extensions import FilePersistence +from msal_extensions.persistence import BasePersistence from requests.auth import AuthBase from cirro.api.auth.base import AuthInfo, RequestAuthWrapper @@ -85,13 +86,26 @@ class ClientAuth(AuthInfo): def __init__(self, client_id: str, region: str, auth_endpoint: str, enable_cache=True): self.client_id = client_id self.region = region - self._token_info = None - self._persistence = None + self._token_info: Optional[OAuthTokenResponse] = None + self._persistence: Optional[BasePersistence] = None if enable_cache: self._persistence = _build_token_persistence(str(TOKEN_PATH), fallback_to_plaintext=True) self._token_info = self._load_token_info() + # Check saved token for change in endpoint + if self._token_info and self._token_info.get('client_id') != client_id: + logger.debug('Different client ID found, clearing saved token info') + self._clear_token_info() + + # Check saved token for refresh token expiry + if self._token_info and self._token_info.get('refresh_expires_in'): + refresh_expiry_threshold = datetime.fromtimestamp(self._token_info.get('refresh_expires_in'))\ + - timedelta(hours=12) + if refresh_expiry_threshold < datetime.now(): + logger.debug('Refresh token expiry is too soon, re-authenticating') + self._clear_token_info() + if not self._token_info: self._token_info = _authenticate(client_id=client_id, auth_endpoint=auth_endpoint) @@ -162,3 +176,4 @@ def _clear_token_info(self): return Path(self._persistence.get_location()).unlink(missing_ok=True) + self._token_info = None diff --git a/cirro/api/models/auth.py b/cirro/api/models/auth.py index 1a3f65b3..77fc5aa2 100644 --- a/cirro/api/models/auth.py +++ b/cirro/api/models/auth.py @@ -24,4 +24,6 @@ class OAuthTokenResponse(TypedDict): id_token: str token_type: str expires_in: int + refresh_expires_in: int + client_id: str message: Optional[str]