Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions aws_advanced_python_wrapper/federated_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
region
)

token_info = FederatedAuthPlugin._token_cache.get(cache_key)
token_info: Optional[TokenInfo] = FederatedAuthPlugin._token_cache.get(cache_key)

if token_info is not None and not token_info.is_expired():
logger.debug("FederatedAuthPlugin.UseCachedToken", token_info.token)
Expand All @@ -110,7 +110,10 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl

try:
return connect_func()
except Exception:
except Exception as e:
if token_info is None or token_info.is_expired() or not self._plugin_service.is_login_exception(e):
raise e

self._update_authentication_token(host_info, props, user, region, cache_key)

try:
Expand Down
7 changes: 5 additions & 2 deletions aws_advanced_python_wrapper/okta_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
region
)

token_info = OktaAuthPlugin._token_cache.get(cache_key)
token_info: Optional[TokenInfo] = OktaAuthPlugin._token_cache.get(cache_key)

if token_info is not None and not token_info.is_expired():
logger.debug("OktaAuthPlugin.UseCachedToken", token_info.token)
Expand All @@ -106,7 +106,10 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl

try:
return connect_func()
except Exception:
except Exception as e:
if token_info is None or token_info.is_expired() or not self._plugin_service.is_login_exception(e):
raise e

self._update_authentication_token(host_info, props, user, region, cache_key)

try:
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/test_federated_auth_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,37 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m
assert WrapperProperties.PASSWORD.get(test_props) == _TEST_TOKEN


@patch("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache", _token_cache)
def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't see any unit tests associated with the JDBC wrapper so not sure if this is appropriate.

mock_credentials_provider_factory):
test_props: Properties = Properties(
{"plugins": "federated_auth", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"})
WrapperProperties.DB_USER.set(test_props, _DB_USER)

exception_message = "generic exception"
mock_func.side_effect = Exception(exception_message)

target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory,
mock_session)
with pytest.raises(Exception) as e_info:
target_plugin.connect(
target_driver_func=mocker.MagicMock(),
driver_dialect=mock_dialect,
host_info=_PG_HOST_INFO,
props=test_props,
is_initial_connection=False,
connect_func=mock_func)

mock_client.generate_db_auth_token.assert_called_with(
DBHostname="pg.testdb.us-east-2.rds.amazonaws.com",
Port=5432,
DBUsername="postgresqlUser"
)

assert e_info.type == Exception
assert str(e_info.value) == exception_message


@patch("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache", _token_cache)
def test_connect_with_specified_iam_host_port_region(mocker,
mock_plugin_service,
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/test_okta_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,36 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m
assert WrapperProperties.PASSWORD.get(test_props) == _TEST_TOKEN


@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache)
def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_session, mock_func, mock_client,
mock_dialect, mock_credentials_provider_factory):
test_props: Properties = Properties({"plugins": "okta", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"})
WrapperProperties.DB_USER.set(test_props, _DB_USER)

exception_message = "generic exception"
mock_func.side_effect = Exception(exception_message)

target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session)

with pytest.raises(Exception) as e_info:
target_plugin.connect(
target_driver_func=mocker.MagicMock(),
driver_dialect=mock_dialect,
host_info=_PG_HOST_INFO,
props=test_props,
is_initial_connection=False,
connect_func=mock_func)

mock_client.generate_db_auth_token.assert_called_with(
DBHostname="pg.testdb.us-east-2.rds.amazonaws.com",
Port=5432,
DBUsername="postgresqlUser"
)

assert e_info.type == Exception
assert str(e_info.value) == exception_message


@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache)
def test_connect_with_specified_iam_host_port_region(mocker,
mock_plugin_service,
Expand Down