Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions cirro/api/auth/oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import threading
import time
from datetime import datetime
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional

Expand All @@ -12,6 +12,7 @@
import requests
from botocore.exceptions import ClientError
from msal_extensions import FilePersistence
from msal_extensions.persistence import BasePersistence
from requests.auth import AuthBase

from cirro.api.auth.base import AuthInfo, RequestAuthWrapper
Expand Down Expand Up @@ -85,13 +86,26 @@ class ClientAuth(AuthInfo):
def __init__(self, client_id: str, region: str, auth_endpoint: str, enable_cache=True):
self.client_id = client_id
self.region = region
self._token_info = None
self._persistence = None
self._token_info: Optional[OAuthTokenResponse] = None
self._persistence: Optional[BasePersistence] = None

if enable_cache:
self._persistence = _build_token_persistence(str(TOKEN_PATH), fallback_to_plaintext=True)
self._token_info = self._load_token_info()

# Check saved token for change in endpoint
if self._token_info and self._token_info.get('client_id') != client_id:
logger.debug('Different client ID found, clearing saved token info')
self._clear_token_info()

# Check saved token for refresh token expiry
if self._token_info and self._token_info.get('refresh_expires_in'):
refresh_expiry_threshold = datetime.fromtimestamp(self._token_info.get('refresh_expires_in'))\
- timedelta(hours=12)
if refresh_expiry_threshold < datetime.now():
logger.debug('Refresh token expiry is too soon, re-authenticating')
self._clear_token_info()

if not self._token_info:
self._token_info = _authenticate(client_id=client_id, auth_endpoint=auth_endpoint)

Expand Down Expand Up @@ -162,3 +176,4 @@ def _clear_token_info(self):
return

Path(self._persistence.get_location()).unlink(missing_ok=True)
self._token_info = None
2 changes: 2 additions & 0 deletions cirro/api/models/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ class OAuthTokenResponse(TypedDict):
id_token: str
token_type: str
expires_in: int
refresh_expires_in: int
client_id: str
message: Optional[str]