From f1106d1028daa1f38bdeb848ae2f88adf982c793 Mon Sep 17 00:00:00 2001 From: Juan Lee Date: Fri, 13 Jun 2025 03:02:04 -0700 Subject: [PATCH] chore: check for cached token and exception type before retrying --- .../federated_plugin.py | 7 +++-- aws_advanced_python_wrapper/okta_plugin.py | 7 +++-- tests/unit/test_federated_auth_plugin.py | 31 +++++++++++++++++++ tests/unit/test_okta_plugin.py | 30 ++++++++++++++++++ 4 files changed, 71 insertions(+), 4 deletions(-) diff --git a/aws_advanced_python_wrapper/federated_plugin.py b/aws_advanced_python_wrapper/federated_plugin.py index 4ab0fbf0..0eb8258f 100644 --- a/aws_advanced_python_wrapper/federated_plugin.py +++ b/aws_advanced_python_wrapper/federated_plugin.py @@ -98,7 +98,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl region ) - token_info = FederatedAuthPlugin._token_cache.get(cache_key) + token_info: Optional[TokenInfo] = FederatedAuthPlugin._token_cache.get(cache_key) if token_info is not None and not token_info.is_expired(): logger.debug("FederatedAuthPlugin.UseCachedToken", token_info.token) @@ -110,7 +110,10 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl try: return connect_func() - except Exception: + except Exception as e: + if token_info is None or token_info.is_expired() or not self._plugin_service.is_login_exception(e): + raise e + self._update_authentication_token(host_info, props, user, region, cache_key) try: diff --git a/aws_advanced_python_wrapper/okta_plugin.py b/aws_advanced_python_wrapper/okta_plugin.py index 88b92f13..55bd9980 100644 --- a/aws_advanced_python_wrapper/okta_plugin.py +++ b/aws_advanced_python_wrapper/okta_plugin.py @@ -94,7 +94,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl region ) - token_info = OktaAuthPlugin._token_cache.get(cache_key) + token_info: Optional[TokenInfo] = OktaAuthPlugin._token_cache.get(cache_key) if token_info is not None and not token_info.is_expired(): logger.debug("OktaAuthPlugin.UseCachedToken", token_info.token) @@ -106,7 +106,10 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl try: return connect_func() - except Exception: + except Exception as e: + if token_info is None or token_info.is_expired() or not self._plugin_service.is_login_exception(e): + raise e + self._update_authentication_token(host_info, props, user, region, cache_key) try: diff --git a/tests/unit/test_federated_auth_plugin.py b/tests/unit/test_federated_auth_plugin.py index 2014680f..1c3a77e3 100644 --- a/tests/unit/test_federated_auth_plugin.py +++ b/tests/unit/test_federated_auth_plugin.py @@ -173,6 +173,37 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m assert WrapperProperties.PASSWORD.get(test_props) == _TEST_TOKEN +@patch("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache", _token_cache) +def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect, + mock_credentials_provider_factory): + test_props: Properties = Properties( + {"plugins": "federated_auth", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) + WrapperProperties.DB_USER.set(test_props, _DB_USER) + + exception_message = "generic exception" + mock_func.side_effect = Exception(exception_message) + + target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, + mock_session) + with pytest.raises(Exception) as e_info: + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_auth_token.assert_called_with( + DBHostname="pg.testdb.us-east-2.rds.amazonaws.com", + Port=5432, + DBUsername="postgresqlUser" + ) + + assert e_info.type == Exception + assert str(e_info.value) == exception_message + + @patch("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache", _token_cache) def test_connect_with_specified_iam_host_port_region(mocker, mock_plugin_service, diff --git a/tests/unit/test_okta_plugin.py b/tests/unit/test_okta_plugin.py index 236bdd35..72f9727a 100644 --- a/tests/unit/test_okta_plugin.py +++ b/tests/unit/test_okta_plugin.py @@ -170,6 +170,36 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m assert WrapperProperties.PASSWORD.get(test_props) == _TEST_TOKEN +@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache) +def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_session, mock_func, mock_client, + mock_dialect, mock_credentials_provider_factory): + test_props: Properties = Properties({"plugins": "okta", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) + WrapperProperties.DB_USER.set(test_props, _DB_USER) + + exception_message = "generic exception" + mock_func.side_effect = Exception(exception_message) + + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + + with pytest.raises(Exception) as e_info: + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_auth_token.assert_called_with( + DBHostname="pg.testdb.us-east-2.rds.amazonaws.com", + Port=5432, + DBUsername="postgresqlUser" + ) + + assert e_info.type == Exception + assert str(e_info.value) == exception_message + + @patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache) def test_connect_with_specified_iam_host_port_region(mocker, mock_plugin_service,