Skip to content

Commit 61b4d00

Browse files
committed
chore: check for cached token and exception type before retrying
1 parent 383b605 commit 61b4d00

File tree

4 files changed

+71
-4
lines changed

4 files changed

+71
-4
lines changed

aws_advanced_python_wrapper/federated_plugin.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
9898
region
9999
)
100100

101-
token_info = FederatedAuthPlugin._token_cache.get(cache_key)
101+
token_info: Optional[TokenInfo] = FederatedAuthPlugin._token_cache.get(cache_key)
102102

103103
if token_info is not None and not token_info.is_expired():
104104
logger.debug("FederatedAuthPlugin.UseCachedToken", token_info.token)
@@ -110,7 +110,10 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
110110

111111
try:
112112
return connect_func()
113-
except Exception:
113+
except Exception as e:
114+
if token_info is None or token_info.is_expired() or not self._plugin_service.is_login_exception(e):
115+
raise e
116+
114117
self._update_authentication_token(host_info, props, user, region, cache_key)
115118

116119
try:

aws_advanced_python_wrapper/okta_plugin.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
9494
region
9595
)
9696

97-
token_info = OktaAuthPlugin._token_cache.get(cache_key)
97+
token_info: Optional[TokenInfo] = OktaAuthPlugin._token_cache.get(cache_key)
98+
9899

99100
if token_info is not None and not token_info.is_expired():
100101
logger.debug("OktaAuthPlugin.UseCachedToken", token_info.token)
@@ -106,7 +107,10 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
106107

107108
try:
108109
return connect_func()
109-
except Exception:
110+
except Exception as e:
111+
if token_info is None or token_info.is_expired() or not self._plugin_service.is_login_exception(e):
112+
raise e
113+
110114
self._update_authentication_token(host_info, props, user, region, cache_key)
111115

112116
try:

tests/unit/test_federated_auth_plugin.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,37 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m
173173
assert WrapperProperties.PASSWORD.get(test_props) == _TEST_TOKEN
174174

175175

176+
@patch("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache", _token_cache)
177+
def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect,
178+
mock_credentials_provider_factory):
179+
test_props: Properties = Properties(
180+
{"plugins": "federated_auth", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"})
181+
WrapperProperties.DB_USER.set(test_props, _DB_USER)
182+
183+
exception_message = "generic exception"
184+
mock_func.side_effect = Exception(exception_message)
185+
186+
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory,
187+
mock_session)
188+
with pytest.raises(Exception) as e_info:
189+
target_plugin.connect(
190+
target_driver_func=mocker.MagicMock(),
191+
driver_dialect=mock_dialect,
192+
host_info=_PG_HOST_INFO,
193+
props=test_props,
194+
is_initial_connection=False,
195+
connect_func=mock_func)
196+
197+
mock_client.generate_db_auth_token.assert_called_with(
198+
DBHostname="pg.testdb.us-east-2.rds.amazonaws.com",
199+
Port=5432,
200+
DBUsername="postgresqlUser"
201+
)
202+
203+
assert e_info.type == Exception
204+
assert str(e_info.value) == exception_message
205+
206+
176207
@patch("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache", _token_cache)
177208
def test_connect_with_specified_iam_host_port_region(mocker,
178209
mock_plugin_service,

tests/unit/test_okta_plugin.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,35 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m
170170
assert WrapperProperties.PASSWORD.get(test_props) == _TEST_TOKEN
171171

172172

173+
@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache)
174+
def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect, mock_credentials_provider_factory):
175+
test_props: Properties = Properties({"plugins": "okta", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"})
176+
WrapperProperties.DB_USER.set(test_props, _DB_USER)
177+
178+
exception_message = "generic exception"
179+
mock_func.side_effect = Exception(exception_message)
180+
181+
target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session)
182+
183+
with pytest.raises(Exception) as e_info:
184+
target_plugin.connect(
185+
target_driver_func=mocker.MagicMock(),
186+
driver_dialect=mock_dialect,
187+
host_info=_PG_HOST_INFO,
188+
props=test_props,
189+
is_initial_connection=False,
190+
connect_func=mock_func)
191+
192+
mock_client.generate_db_auth_token.assert_called_with(
193+
DBHostname="pg.testdb.us-east-2.rds.amazonaws.com",
194+
Port=5432,
195+
DBUsername="postgresqlUser"
196+
)
197+
198+
assert e_info.type == Exception
199+
assert str(e_info.value) == exception_message
200+
201+
173202
@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache)
174203
def test_connect_with_specified_iam_host_port_region(mocker,
175204
mock_plugin_service,

0 commit comments

Comments
 (0)