|
3 | 3 | import sys |
4 | 4 | import threading |
5 | 5 | import time |
6 | | -from datetime import datetime |
| 6 | +from datetime import datetime, timedelta |
7 | 7 | from pathlib import Path |
8 | 8 | from typing import Optional |
9 | 9 |
|
|
12 | 12 | import requests |
13 | 13 | from botocore.exceptions import ClientError |
14 | 14 | from msal_extensions import FilePersistence |
| 15 | +from msal_extensions.persistence import BasePersistence |
15 | 16 | from requests.auth import AuthBase |
16 | 17 |
|
17 | 18 | from cirro.api.auth.base import AuthInfo, RequestAuthWrapper |
@@ -85,13 +86,26 @@ class ClientAuth(AuthInfo): |
85 | 86 | def __init__(self, client_id: str, region: str, auth_endpoint: str, enable_cache=True): |
86 | 87 | self.client_id = client_id |
87 | 88 | self.region = region |
88 | | - self._token_info = None |
89 | | - self._persistence = None |
| 89 | + self._token_info: Optional[OAuthTokenResponse] = None |
| 90 | + self._persistence: Optional[BasePersistence] = None |
90 | 91 |
|
91 | 92 | if enable_cache: |
92 | 93 | self._persistence = _build_token_persistence(str(TOKEN_PATH), fallback_to_plaintext=True) |
93 | 94 | self._token_info = self._load_token_info() |
94 | 95 |
|
| 96 | + # Check saved token for change in endpoint |
| 97 | + if self._token_info and self._token_info.get('client_id') != client_id: |
| 98 | + logger.debug('Different client ID found, clearing saved token info') |
| 99 | + self._clear_token_info() |
| 100 | + |
| 101 | + # Check saved token for refresh token expiry |
| 102 | + if self._token_info and self._token_info.get('refresh_expires_in'): |
| 103 | + refresh_expiry_threshold = datetime.fromtimestamp(self._token_info.get('refresh_expires_in'))\ |
| 104 | + - timedelta(hours=12) |
| 105 | + if refresh_expiry_threshold < datetime.now(): |
| 106 | + logger.debug('Refresh token expiry is too soon, re-authenticating') |
| 107 | + self._clear_token_info() |
| 108 | + |
95 | 109 | if not self._token_info: |
96 | 110 | self._token_info = _authenticate(client_id=client_id, auth_endpoint=auth_endpoint) |
97 | 111 |
|
@@ -162,3 +176,4 @@ def _clear_token_info(self): |
162 | 176 | return |
163 | 177 |
|
164 | 178 | Path(self._persistence.get_location()).unlink(missing_ok=True) |
| 179 | + self._token_info = None |
0 commit comments