From 61dc665f903f482ec0774c784e9da1458b1494c1 Mon Sep 17 00:00:00 2001 From: Juan Lee Date: Fri, 27 Jun 2025 23:57:52 -0700 Subject: [PATCH 1/2] feat: limitless plugin implementation --- .../database_dialect.py | 12 +- aws_advanced_python_wrapper/default_plugin.py | 9 +- .../fastest_response_strategy_plugin.py | 6 +- .../host_list_provider.py | 10 + .../limitless_connection_plugin.py | 531 ++++++++++++++++++ aws_advanced_python_wrapper/plugin.py | 5 +- aws_advanced_python_wrapper/plugin_service.py | 17 +- ...dvanced_python_wrapper_messages.properties | 35 ++ .../utils/properties.py | 35 +- benchmarks/benchmark_plugin.py | 4 +- tests/unit/test_host_response_time_monitor.py | 2 +- tests/unit/test_plugin_manager.py | 2 +- 12 files changed, 646 insertions(+), 22 deletions(-) create mode 100644 aws_advanced_python_wrapper/limitless_connection_plugin.py diff --git a/aws_advanced_python_wrapper/database_dialect.py b/aws_advanced_python_wrapper/database_dialect.py index 4d997e7dd..0706ad1a5 100644 --- a/aws_advanced_python_wrapper/database_dialect.py +++ b/aws_advanced_python_wrapper/database_dialect.py @@ -98,6 +98,15 @@ def is_reader_query(self) -> str: return self._IS_READER_QUERY +@runtime_checkable +class AuroraLimitlessDialect(Protocol): + _LIMITLESS_ROUTER_ENDPOINT_QUERY: str + + @property + def limitless_router_endpoint_query(self) -> str: + return self._LIMITLESS_ROUTER_ENDPOINT_QUERY + + class DatabaseDialect(Protocol): """ Database dialects help the AWS Advanced Python Driver determine what kind of underlying database is being used, @@ -342,7 +351,7 @@ def get_host_list_provider_supplier(self) -> Callable: return lambda provider_service, props: RdsHostListProvider(provider_service, props) -class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect): +class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect, AuroraLimitlessDialect): _DIALECT_UPDATE_CANDIDATES: Tuple[DialectCode, ...] = (DialectCode.MULTI_AZ_PG,) _EXTENSIONS_QUERY = "SELECT (setting LIKE '%aurora_stat_utils%') AS aurora_stat_utils " \ @@ -359,6 +368,7 @@ class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect): _HOST_ID_QUERY = "SELECT aurora_db_instance_identifier()" _IS_READER_QUERY = "SELECT pg_is_in_recovery()" + _LIMITLESS_ROUTER_ENDPOINT_QUERY = "SELECT router_endpoint, load FROM aurora_limitless_router_endpoints()" @property def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]: diff --git a/aws_advanced_python_wrapper/default_plugin.py b/aws_advanced_python_wrapper/default_plugin.py index 79ceb253d..86934f4df 100644 --- a/aws_advanced_python_wrapper/default_plugin.py +++ b/aws_advanced_python_wrapper/default_plugin.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional if TYPE_CHECKING: from aws_advanced_python_wrapper.connection_provider import (ConnectionProvider, @@ -118,7 +118,7 @@ def accepts_strategy(self, role: HostRole, strategy: str) -> bool: return False return self._connection_provider_manager.accepts_strategy(role, strategy) - def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo: + def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Optional[List[HostInfo]] = None) -> HostInfo: if HostRole.UNKNOWN == role: raise AwsWrapperError(Messages.get("DefaultPlugin.UnknownHosts")) @@ -127,7 +127,10 @@ def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo: if len(hosts) < 1: raise AwsWrapperError(Messages.get("DefaultPlugin.EmptyHosts")) - return self._connection_provider_manager.get_host_info_by_strategy(hosts, role, strategy, self._plugin_service.props) + if host_list is None: + return self._connection_provider_manager.get_host_info_by_strategy(hosts, role, strategy, self._plugin_service.props) + else: + return self._connection_provider_manager.get_host_info_by_strategy(tuple(host_list), role, strategy, self._plugin_service.props) @property def subscribed_methods(self) -> Set[str]: diff --git a/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py b/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py index 26c915cc8..e85d40cc6 100644 --- a/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py +++ b/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py @@ -58,8 +58,8 @@ def __init__(self, plugin_service: PluginService, props: Properties): self._plugin_service = plugin_service self._properties = props self._host_response_time_service: HostResponseTimeService = \ - HostResponseTimeService(plugin_service, props, WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MILLIS.get_int(props)) - self._cache_expiration_nanos = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MILLIS.get_int(props) * 10 ^ 6 + HostResponseTimeService(plugin_service, props, WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props)) + self._cache_expiration_nanos = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props) * 10 ^ 6 self._random_host_selector = RandomHostSelector() self._cached_fastest_response_host_by_role: CacheMap[str, HostInfo] = CacheMap() self._hosts: Tuple[HostInfo, ...] = () @@ -86,7 +86,7 @@ def connect( def accepts_strategy(self, role: HostRole, strategy: str) -> bool: return strategy == FastestResponseStrategyPlugin._FASTEST_RESPONSE_STRATEGY_NAME - def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo: + def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Optional[List[HostInfo]] = None) -> HostInfo: if not self.accepts_strategy(role, strategy): logger.error("FastestResponseStrategyPlugin.UnsupportedHostSelectorStrategy", strategy) raise AwsWrapperError(Messages.get_formatted("FastestResponseStrategyPlugin.UnsupportedHostSelectorStrategy", strategy)) diff --git a/aws_advanced_python_wrapper/host_list_provider.py b/aws_advanced_python_wrapper/host_list_provider.py index 4ead4ebe6..16cf7b7cd 100644 --- a/aws_advanced_python_wrapper/host_list_provider.py +++ b/aws_advanced_python_wrapper/host_list_provider.py @@ -69,6 +69,9 @@ def get_host_role(self, connection: Connection) -> HostRole: def identify_connection(self, connection: Optional[Connection]) -> Optional[HostInfo]: ... + def get_cluster_id(self) -> str: + ... + @runtime_checkable class DynamicHostListProvider(HostListProvider, Protocol): @@ -519,6 +522,10 @@ def _identify_connection(self, conn: Connection): cursor.execute(self._dialect.host_id_query) return cursor.fetchone() + def get_cluster_id(self): + self._initialize() + return self._cluster_id + @dataclass() class ClusterIdSuggestion: cluster_id: str @@ -646,3 +653,6 @@ def get_host_role(self, connection: Connection) -> HostRole: def identify_connection(self, connection: Optional[Connection]) -> Optional[HostInfo]: raise UnsupportedOperationError( Messages.get_formatted("ConnectionStringHostListProvider.UnsupportedMethod", "identify_connection")) + + def get_cluster_id(self): + return "" diff --git a/aws_advanced_python_wrapper/limitless_connection_plugin.py b/aws_advanced_python_wrapper/limitless_connection_plugin.py new file mode 100644 index 000000000..41b89cd18 --- /dev/null +++ b/aws_advanced_python_wrapper/limitless_connection_plugin.py @@ -0,0 +1,531 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import copy +import time +from contextlib import closing +from threading import Event, RLock, Thread +from time import sleep +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, + Set, Tuple) + +from aws_advanced_python_wrapper.database_dialect import ( + AuroraLimitlessDialect, DatabaseDialect) +from aws_advanced_python_wrapper.errors import (AwsWrapperError, + UnsupportedOperationError) +from aws_advanced_python_wrapper.host_availability import HostAvailability +from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.plugin import Plugin +from aws_advanced_python_wrapper.utils.concurrent import ConcurrentDict +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) +from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \ + SlidingExpirationCacheWithCleanupThread +from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( + TelemetryContext, TelemetryFactory, TelemetryTraceLevel) + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.driver_dialect import DriverDialect + from aws_advanced_python_wrapper.pep249 import Connection + from aws_advanced_python_wrapper.plugin_service import PluginService + +logger = Logger(__name__) + + +class LimitlessConnectionPlugin(Plugin): + _SUBSCRIBED_METHODS: Set[str] = {"connect"} + + def __init__(self, plugin_service: PluginService, props: Properties): + self._plugin_service = plugin_service + self._properties = props + + @property + def subscribed_methods(self) -> Set[str]: + return self._SUBSCRIBED_METHODS + + def connect( + self, + target_driver_func: Callable, + driver_dialect: DriverDialect, + host_info: HostInfo, + props: Properties, + is_initial_connection: bool, + connect_func: Callable) -> Connection: + + connection: Optional[Connection] = None + + dialect: DatabaseDialect = self._plugin_service.database_dialect + if not isinstance(dialect, AuroraLimitlessDialect): + connection = connect_func() + refreshed_dialect = self._plugin_service.database_dialect + + if not isinstance(refreshed_dialect, AuroraLimitlessDialect): + raise UnsupportedOperationError( + Messages.get_formatted("LimitlessConnectionPlugin.UnsupportedDialectOrDatabase", + type(refreshed_dialect).__name__)) + + limitless_router_service = LimitlessRouterService( + self._plugin_service, + LimitlessQueryHelper(self._plugin_service) + ) + + if is_initial_connection: + limitless_router_service.start_monitoring(host_info, props) + + context: LimitlessConnectionContext = LimitlessConnectionContext( + host_info, + props, + connection, + connect_func, + [], + self + ) + limitless_router_service.establish_connection(context) + connection = context.get_connection() + if connection is not None and not self._plugin_service.driver_dialect.is_closed(connection): + return context.get_connection() + + raise AwsWrapperError(Messages.get_formatted("LimitlessConnectionPlugin.FailedToConnectToHost", host_info.host)) + + +class LimitlessConnectionPluginFactory: + + def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + return LimitlessConnectionPlugin(plugin_service, props) + + +class LimitlessRouterMonitor: + _MONITORING_PROPERTY_PREFIX: str = "limitless-router-monitor-" + + def __init__(self, + plugin_service: PluginService, + host_info: HostInfo, + limitless_router_cache: SlidingExpirationCacheWithCleanupThread, + limitless_router_cache_key: str, + props: Properties, + interval_ms: int): + self._plugin_service = plugin_service + self._host_info = host_info + self._limitless_router_cache = limitless_router_cache + self._limitless_router_cache_key = limitless_router_cache_key + + self._properties = copy.deepcopy(props) + for property_key in self._properties.keys(): + if property_key.startswith(self._MONITORING_PROPERTY_PREFIX): + self._properties[property_key[len(self._MONITORING_PROPERTY_PREFIX):]] = self._properties[property_key] + self._properties.pop(property_key) + + WrapperProperties.WAIT_FOR_ROUTER_INFO.set(self._properties, False) + + self._interval_ms = interval_ms + self._query_helper = LimitlessQueryHelper(plugin_service) + self._telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory() + self._monitoring_conn: Optional[Connection] = None + self._is_stopped: Event = Event() + + self._daemon_thread: Thread = Thread(daemon=True, target=self.run) + self._daemon_thread.start() + + @property + def host_info(self): + return self._host_info + + @property + def is_stopped(self): + return self._is_stopped.is_set() + + def close(self): + self._is_stopped.set() + if self._monitoring_conn: + self._monitoring_conn.close() + self._daemon_thread.join(5) + logger.debug("LimitlessRouterMonitor.Stopped", self.host_info.host) + + def run(self): + logger.debug("LimitlessRouterMonitor.Running", self.host_info.host) + + telemetry_context: TelemetryContext = self._telemetry_factory.open_telemetry_context( + "limitless router monitor thread", TelemetryTraceLevel.TOP_LEVEL) + if telemetry_context is not None: + telemetry_context.set_attribute("url", self._host_info.url) + + try: + while not self.is_stopped: + self._open_connection() + + if self._monitoring_conn is not None: + + new_limitless_routers = self._query_helper.query_for_limitless_routers(self._monitoring_conn, self._host_info.port) + self._limitless_router_cache.compute_if_absent(self._limitless_router_cache_key, + lambda _: new_limitless_routers, + WrapperProperties.LIMITLESS_MONITOR_DISPOSAL_TIME_MS.get( + self._properties) * 1_000_000) + + sleep(self._interval_ms / 1000) + + except InterruptedError as e: + logger.debug("LimitlessRouterMonitor.InterruptedExceptionDuringMonitoring", self.host_info.host) + if telemetry_context: + telemetry_context.set_exception(e) + telemetry_context.set_success(False) + except Exception as e: + # this should not be reached; log and exit thread + logger.debug("LimitlessRouterMonitor.ExceptionDuringMonitoringStop", self.host_info.host, e) + if telemetry_context: + telemetry_context.set_exception(e) + telemetry_context.set_success(False) + + finally: + self._is_stopped.set() + if self._monitoring_conn is not None: + try: + self._monitoring_conn.close() + except Exception: + # Ignore + pass + + if telemetry_context is not None: + telemetry_context.close_context() + + def _open_connection(self): + try: + driver_dialect = self._plugin_service.driver_dialect + if self._monitoring_conn is None or driver_dialect.is_closed(self._monitoring_conn): + logger.debug("LimitlessRouterMonitor.OpeningConnection", self.host_info.url) + self._monitoring_conn = self._plugin_service.force_connect(self._host_info, self._properties, None) + logger.debug("LimitlessRouterMonitor.OpenedConnection", self._monitoring_conn) + + except Exception as e: + if self._monitoring_conn is not None: + try: + self._monitoring_conn.close() + except Exception: + pass # ignore + + self._monitoring_conn = None + raise e + + +class LimitlessQueryHelper: + _DEFAULT_QUERY_TIMEOUT_SEC: int = 5 + + def __init__(self, plugin_service: PluginService): + self._plugin_service = plugin_service + + def query_for_limitless_routers(self, connection: Connection, host_port_to_map: int) -> List[HostInfo]: + + database_dialect = self._plugin_service.database_dialect + if not isinstance(database_dialect, AuroraLimitlessDialect): + raise UnsupportedOperationError(Messages.get("LimitlessQueryHelper.UnsupportedDialectOrDatabase")) + aurora_limitless_dialect: AuroraLimitlessDialect = database_dialect + query = aurora_limitless_dialect.limitless_router_endpoint_query + + with closing(connection.cursor()) as cursor: + self._plugin_service.driver_dialect.execute("Cursor.execute", + lambda: cursor.execute(query), + query, + exec_timeout=LimitlessQueryHelper._DEFAULT_QUERY_TIMEOUT_SEC) + return self._map_result_set_to_host_info_list(cursor.fetchall(), host_port_to_map) + + def _map_result_set_to_host_info_list(self, result_set: List[Tuple[Any, Any]], host_port_to_map: int) -> List[HostInfo]: + list_of_host_infos: List[HostInfo] = [] + for result in result_set: + list_of_host_infos.append(self._create_host_info(result, host_port_to_map)) + return list_of_host_infos + + def _create_host_info(self, result: Tuple[Any, Any], host_port_to_map: int) -> HostInfo: + host_name: str = result[0] + cpu: float = float(result[1]) + + weight: int = round(10 - (cpu * 10)) + if weight < 1 or weight > 10: + weight = 1 + logger.debug("LimitlessRouterMonitor.InvalidRouterLoad", host_name, cpu) + + return HostInfo(host_name, host_port_to_map, weight=weight, host_id=host_name) + + +class LimitlessConnectionContext: + + def __init__(self, + host_info: HostInfo, + props: Properties, + connection: Optional[Connection], + connect_func: Callable, + limitless_routers: List[HostInfo], + connection_plugin: LimitlessConnectionPlugin) -> None: + self._host_info = host_info + self._props = props + self._connection = connection + self._connect_func = connect_func + self._limitless_routers = limitless_routers + self._connection_plugin = connection_plugin + + def get_host_info(self): + return self._host_info + + def get_props(self): + return self._props + + def get_connection(self): + return self._connection + + def set_connection(self, connection: Connection): + if self._connection is not None and self._connection != connection: + try: + self._connection.close() + except Exception: + pass + + self._connection = connection + + def get_connect_func(self) -> Callable: + return self._connect_func + + def get_limitless_routers(self): + return self._limitless_routers + + def set_limitless_routers(self, limitless_routers: List[HostInfo]): + self._limitless_routers = limitless_routers + + def get_connection_plugin(self): + return self._connection_plugin + + def is_any_router_available(self): + for router in self.get_limitless_routers(): + if router.get_availability() == HostAvailability.AVAILABLE: + return True + return False + + +class LimitlessRouterService: + _CACHE_CLEANUP_NS: int = 6 * 10 ^ 10 # 1 minute + _limitless_router_cache: ClassVar[SlidingExpirationCacheWithCleanupThread[str, List[HostInfo]]] = \ + SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NS) + + _limitless_router_monitor: ClassVar[SlidingExpirationCacheWithCleanupThread[str, LimitlessRouterMonitor]] = \ + SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NS, + should_dispose_func=lambda monitor: True, + item_disposal_func=lambda monitor: monitor.close()) + + _force_get_limitless_routers_lock_map: ClassVar[ConcurrentDict[str, RLock]] = ConcurrentDict() + + def __init__(self, plugin_service: PluginService, query_helper: LimitlessQueryHelper): + self._plugin_service = plugin_service + self._query_helper = query_helper + + def establish_connection(self, context: LimitlessConnectionContext) -> None: + context.set_limitless_routers(self._get_limitless_routers( + self._plugin_service.host_list_provider.get_cluster_id(), context.get_props())) + + if context.get_limitless_routers() is None or len(context.get_limitless_routers()) == 0: + logger.debug("LimitlessRouterServiceImpl.limitlessRouterCacheEmpty") + + wait_for_router_info = WrapperProperties.WAIT_FOR_ROUTER_INFO.get(context.get_props()) + if wait_for_router_info: + self._synchronously_get_limitless_routers_with_retry(context) + else: + logger.debug("LimitlessRouterServiceImpl.UsingProvidedConnectUrl") + if context.get_connection() is None or self._plugin_service.driver_dialect.is_closed(context.get_connection()): + context.set_connection(context.get_connect_func()()) + + if context.get_host_info in context.get_limitless_routers(): + logger.debug("LimitlessRouterServiceImpl.ConnectWithHost") + if context.get_connection() is None: + try: + context.set_connection(context.get_connect_func()()) + except Exception as e: + if self._is_login_exception(e): + raise e + + self._retry_connection_with_least_loaded_routers(context) + return + + try: + selected_host_info = self._plugin_service.get_host_info_by_strategy( + HostRole.WRITER, "weighted_random", context.get_limitless_routers()) + logger.debug("LimitlessRouterServiceImpl.SelectedHost", "None" if selected_host_info is None else selected_host_info.host) + except Exception as e: + if self._is_login_exception(e) or isinstance(e, UnsupportedOperationError): + raise e + + self._retry_connection_with_least_loaded_routers(context) + return + + if selected_host_info is None: + self._retry_connection_with_least_loaded_routers(context) + return + + try: + context.set_connection(self._plugin_service.connect(selected_host_info, context.get_props(), context.get_connection_plugin())) + except Exception as e: + if self._is_login_exception(e): + raise e + + if selected_host_info is not None: + logger.debug("LimitlessRouterServiceImpl.FailedToConnectToHost", selected_host_info.host) + selected_host_info.set_availability(HostAvailability.UNAVAILABLE) + + self._retry_connection_with_least_loaded_routers(context) + + def _get_limitless_routers(self, cluster_id: str, props: Properties) -> List[HostInfo]: + # Convert milliseconds to nanoseconds + cache_expiration_nano: int = WrapperProperties.LIMITLESS_MONITOR_DISPOSAL_TIME_MS.get_int(props) * 1_000_000 + LimitlessRouterService._limitless_router_cache.set_cleanup_interval_ns(cache_expiration_nano) + routers = LimitlessRouterService._limitless_router_cache.get(cluster_id) + if routers is None: + return [] + return routers + + def _retry_connection_with_least_loaded_routers(self, context: LimitlessConnectionContext) -> None: + retry_count = 0 + max_retries = WrapperProperties.MAX_RETRIES_MS.get_int(context.get_props()) + while retry_count < max_retries: + retry_count += 1 + if context.get_limitless_routers() is None or len(context.get_limitless_routers()) == 0 or not context.is_any_router_available(): + self._synchronously_get_limitless_routers_with_retry(context) + + if (context.get_limitless_routers() is None + or len(context.get_limitless_routers()) == 0 + or not context.is_any_router_available()): + logger.debug("LimitlessRouterServiceImpl.NoRoutersAvailableForRetry") + + if context.get_connection() is not None and not self._plugin_service.driver_dialect.is_closed(context.get_connection()): + return + else: + try: + context.set_connection(context.get_connect_func()()) + return + except Exception as e: + if self._is_login_exception(e): + raise e + + raise AwsWrapperError(Messages.get_formatted("LimitlessRouterService.UnableToConnectNoRoutersAvailable"), + context.get_host_info().host) from e + + try: + selected_host_info = self._plugin_service.get_host_info_by_strategy( + HostRole.WRITER, "weighted_random", context.get_limitless_routers()) + logger.debug("LimitlessRouterServiceImpl.SelectedHostForRetry", + "None" if selected_host_info is None else selected_host_info.host) + if selected_host_info is None: + continue + + except UnsupportedOperationError as e: + logger.error("LimitlessRouterServiceImpl.IncorrectConfiguration") + raise e + except AwsWrapperError: + continue + + try: + context.set_connection(self._plugin_service.connect(selected_host_info, context.get_props(), context.get_connection_plugin())) + if context.get_connection() is not None: + return + + except Exception as e: + if self._is_login_exception(e): + raise e + selected_host_info.set_availability(HostAvailability.UNAVAILABLE) + logger.debug("LimitlessRouterServiceImpl.FailedToConnectToHost", selected_host_info.host) + + raise AwsWrapperError(Messages.get("LimitlessRouterService.MaxRetriesExceeded")) + + def _synchronously_get_limitless_routers_with_retry(self, context: LimitlessConnectionContext) -> None: + logger.debug("LimitlessRouterServiceImpl.SynchronouslyGetLimitlessRouters") + retry_count = -1 + max_retries = WrapperProperties.MAX_RETRIES_MS.get_int(context.get_props()) + retry_interval_ms = WrapperProperties.GET_ROUTER_RETRY_INTERVAL_MS.get_float(context.get_props()) + first_iteration = True + while first_iteration or retry_count < max_retries: + # Emulate do while loop + first_iteration = False + try: + self._synchronously_get_limitless_routers(context) + if context.get_limitless_routers() is not None or len(context.get_limitless_routers()) > 0: + return + + time.sleep(retry_interval_ms) + except InterruptedError as e: + raise AwsWrapperError(Messages.get("LimitlessRouterService.InterruptedSynchronousGetRouter"), e) + except Exception as e: + if self._is_login_exception(e): + raise e + finally: + retry_count += 1 + + raise AwsWrapperError(Messages.get("LimitlessRouterService.NoRoutersAvailable")) + + def _synchronously_get_limitless_routers(self, context: LimitlessConnectionContext) -> None: + cache_expiration_nano: int = WrapperProperties.LIMITLESS_MONITOR_DISPOSAL_TIME_MS.get_int(context.get_props()) * 1_000_000 + + lock = LimitlessRouterService._force_get_limitless_routers_lock_map.compute_if_absent( + self._plugin_service.host_list_provider.get_cluster_id(), + lambda _: RLock() + ) + if lock is None: + raise AwsWrapperError(Messages.get("LimitlessRouterService.LockFailedToAcquire")) + + lock.acquire() + try: + limitless_routers = LimitlessRouterService._limitless_router_cache.get( + self._plugin_service.host_list_provider.get_cluster_id()) + if limitless_routers is not None and len(limitless_routers) != 0: + context.set_limitless_routers(limitless_routers) + return + connection = context.get_connection() + if connection is None or self._plugin_service.driver_dialect.is_closed(connection): + context.set_connection(context.get_connect_func()()) + + new_limitless_routers: List[HostInfo] = self._query_helper.query_for_limitless_routers( + connection, context.get_host_info().port) + + if new_limitless_routers is not None and len(new_limitless_routers) != 0: + context.set_limitless_routers(new_limitless_routers) + LimitlessRouterService._limitless_router_cache.compute_if_absent( + self._plugin_service.host_list_provider.get_cluster_id(), + lambda _: new_limitless_routers, + cache_expiration_nano + ) + else: + raise AwsWrapperError(Messages.get("LimitlessRouterService.FetchedEmptyRouterList")) + + finally: + lock.release() + + def _is_login_exception(self, error: Optional[Exception] = None): + self._plugin_service.is_login_exception(error) + + def start_monitoring(self, host_info: HostInfo, + props: Properties) -> None: + try: + limitless_router_monitor_key: str = self._plugin_service.host_list_provider.get_cluster_id() + cache_expiration_nano: int = WrapperProperties.LIMITLESS_MONITOR_DISPOSAL_TIME_MS.get_int(props) * 1_000_000 + intervals_ms: int = WrapperProperties.LIMITLESS_INTERVAL_MILLIS.get_int(props) + + LimitlessRouterService._limitless_router_monitor.compute_if_absent( + limitless_router_monitor_key, + lambda _: LimitlessRouterMonitor(self._plugin_service, + host_info, + LimitlessRouterService._limitless_router_cache, + limitless_router_monitor_key, + props, + intervals_ms), cache_expiration_nano) + except Exception as e: + logger.debug("LimitlessRouterService.ErrorStartingMonitor", e) + raise e + + def clear_cache(self) -> None: + LimitlessRouterService._force_get_limitless_routers_lock_map.clear() + LimitlessRouterService._limitless_router_cache.clear() diff --git a/aws_advanced_python_wrapper/plugin.py b/aws_advanced_python_wrapper/plugin.py index 2ce34dd46..6a1d1be4b 100644 --- a/aws_advanced_python_wrapper/plugin.py +++ b/aws_advanced_python_wrapper/plugin.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, runtime_checkable +from typing import TYPE_CHECKING, List, Optional, runtime_checkable if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect @@ -113,7 +113,7 @@ def accepts_strategy(self, role: HostRole, strategy: str) -> bool: """ return False - def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo: + def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Optional[List[HostInfo]] = None) -> HostInfo: """ Selects a :py:class:`HostInfo` with the requested role from available hosts using the requested strategy. :py:method:`ConnectionPlugin.accepts_strategy` should be called first to evaluate if this py:class:`ConnectionPlugin` supports the selection @@ -121,6 +121,7 @@ def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo: :param role: the desired role of the selected host - either a reader host or a writer host. :param strategy: the strategy that should be used to pick a host (eg "random"). + :param host_list: Optional list to select host from given input. :return: a py:class:`HostInfo` with the requested role. """ raise UnsupportedOperationError(Messages.get_formatted("Plugin.UnsupportedMethod", "get_host_info_by_strategy")) diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index ae15cacaf..11c13c687 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -24,6 +24,8 @@ FastestResponseStrategyPluginFactory from aws_advanced_python_wrapper.federated_plugin import \ FederatedAuthPluginFactory +from aws_advanced_python_wrapper.limitless_connection_plugin import \ + LimitlessConnectionPluginFactory from aws_advanced_python_wrapper.okta_plugin import OktaAuthPluginFactory from aws_advanced_python_wrapper.states.session_state_service import ( SessionStateService, SessionStateServiceImpl) @@ -222,7 +224,7 @@ def accepts_strategy(self, role: HostRole, strategy: str) -> bool: """ ... - def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> Optional[HostInfo]: + def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Optional[List[HostInfo]] = None) -> Optional[HostInfo]: """ Selects a :py:class:`HostInfo` with the requested role from available hosts using the requested strategy. :py:method:`PluginService.accepts_strategy` should be called first to evaluate if any of the configured :py:class:`ConnectionPlugin` @@ -230,6 +232,7 @@ def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> Optional[H :param role: the desired role of the selected host - either a reader host or a writer host. :param strategy: the strategy that should be used to pick a host (eg "random"). + :param host_list: Optional list to select host from given input. :return: a py:class:`HostInfo` with the requested role. """ ... @@ -506,9 +509,9 @@ def accepts_strategy(self, role: HostRole, strategy: str) -> bool: plugin_manager: PluginManager = self._container.plugin_manager return plugin_manager.accepts_strategy(role, strategy) - def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> Optional[HostInfo]: + def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Optional[List[HostInfo]] = None) -> Optional[HostInfo]: plugin_manager: PluginManager = self._container.plugin_manager - return plugin_manager.get_host_info_by_strategy(role, strategy) + return plugin_manager.get_host_info_by_strategy(role, strategy, host_list) def get_host_role(self, connection: Optional[Connection] = None) -> HostRole: connection = connection if connection is not None else self.current_connection @@ -727,7 +730,8 @@ class PluginManager(CanReleaseResources): "dev": DeveloperPluginFactory, "federated_auth": FederatedAuthPluginFactory, "okta": OktaAuthPluginFactory, - "initial_connection": AuroraInitialConnectionStrategyPluginFactory + "initial_connection": AuroraInitialConnectionStrategyPluginFactory, + "limitless": LimitlessConnectionPluginFactory, } WEIGHT_RELATIVE_TO_PRIOR_PLUGIN = -1 @@ -747,6 +751,7 @@ class PluginManager(CanReleaseResources): IamAuthPluginFactory: 700, AwsSecretsManagerPluginFactory: 800, FederatedAuthPluginFactory: 900, + LimitlessConnectionPluginFactory: 950, OktaAuthPluginFactory: 1000, ConnectTimePluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN, ExecuteTimePluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN, @@ -1030,7 +1035,7 @@ def accepts_strategy(self, role: HostRole, strategy: str) -> bool: return False - def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> Optional[HostInfo]: + def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Optional[List[HostInfo]] = None) -> Optional[HostInfo]: for plugin in self._plugins: plugin_subscribed_methods = plugin.subscribed_methods is_subscribed = \ @@ -1039,7 +1044,7 @@ def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> Optional[H if is_subscribed: try: - host: HostInfo = plugin.get_host_info_by_strategy(role, strategy) + host: HostInfo = plugin.get_host_info_by_strategy(role, strategy, host_list) if host is not None: return host except UnsupportedOperationError: diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index 5fdb2aed5..28fd0582b 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -153,6 +153,40 @@ IamAuthPlugin.InvalidHost=[IamAuthPlugin] Invalid IAM host {}. The IAM host must IamAuthPlugin.IsNoneOrEmpty=[IamAuthPlugin] Property "{}" is None or empty. IamAuthUtils.GeneratedNewAuthToken=Generated new authentication token = {} +LimitlessConnectionPlugin.FailedToConnectToHost=[LimitlessConnectionPlugin] Failed to connect to host {}. +LimitlessConnectionPlugin.UnsupportedDialectOrDatabase=[LimitlessConnectionPlugin] Unsupported dialect '{}' encountered. Please ensure the connection parameters are correct, and refer to the documentation to ensure that the connecting database is compatible with the Limitless Connection Plugin. + +LimitlessQueryHelper.UnsupportedDialectOrDatabase=[LimitlessQueryHelper] Unsupported dialect '{}' encountered. Please ensure JDBC connection parameters are correct, and refer to the documentation to ensure that the connecting database is compatible with the Limitless Connection Plugin. + +LimitlessRouterMonitor.errorDuringMonitoringStop=[LimitlessRouterMonitor] Stopping monitoring after unhandled error was thrown in Limitless Router Monitoring thread for node {}. Error: {} +LimitlessRouterMonitor.InterruptedErrorDuringMonitoring=[LimitlessRouterMonitor] Limitless Router Monitoring thread for node {} was interrupted. +LimitlessRouterMonitor.InvalidQuery=[LimitlessRouterMonitor] Limitless Connection Plugin has encountered an error obtaining Limitless Router endpoints. Please ensure that you are connecting to an Aurora Limitless Database Shard Group Endpoint URL. +LimitlessRouterMonitor.InvalidRouterLoad=[LimitlessRouterMonitor] Invalid load metric value of '{}' from the transaction router query aurora_limitless_router_endpoints() for transaction router '{}'. The load metric value must be a decimal value between 0 and 1. Host weight be assigned a default weight of 1. +LimitlessRouterMonitor.GetNetworkTimeoutError=[LimitlessRouterMonitor] An error occurred while getting the connection network timeout: {} +LimitlessRouterMonitor.OpeningConnection=[LimitlessRouterMonitor] Opening Limitless Router Monitor connection to '{}'. +LimitlessRouterMonitor.OpenedConnection=[LimitlessRouterMonitor] Opened Limitless Router Monitor connection: {}. +LimitlessRouterMonitor.Running=[LimitlessRouterMonitor] Limitless Router Monitor thread running on node {}. +LimitlessRouterMonitor.Stopped=[LimitlessRouterMonitor] Limitless Router Monitor thread stopped on node {}. + +LimitlessRouterService.ConnectWithHost=[LimitlessRouterService] Connecting to host {}. +LimitlessRouterService.ErrorClosingMonitor=[LimitlessRouterService] An error occurred while closing Limitless Router Monitor: {} +LimitlessRouterService.ErrorStartingMonitor=[LimitlessRouterService] An error occurred while starting Limitless Router Monitor: {} +LimitlessRouterService.FailedToConnectToHost=[LimitlessRouterService] Failed to connect to host {}. +LimitlessRouterService.FetchedEmptyRouterList=[LimitlessRouterService] Empty router list was fetched. +LimitlessRouterService.GetLimitlessRoutersError=[LimitlessRouterService] error encountered getting Limitless Routers. {} +LimitlessRouterService.IncorrectConfiguration=[LimitlessRouterService] Limitless Connection Plugin is unable to run. Please ensure the connection settings are correct. +LimitlessRouterService.InterruptedSynchronousGetRouter=[LimitlessRouterService] Limitless Router Service thread was interrupted while waiting to fetch Limitless Transaction Routers. +LimitlessRouterService.LimitlessRouterCacheEmpty=[LimitlessRouterService] Limitless Router cache is empty. This is normal during application start up when the cache is not yet populated. +LimitlessRouterService.LockFailedToAcquire=[LimitlessRouterService] Failed to acquire Lock. +LimitlessRouterService.MaxRetriesExceeded=[LimitlessRouterService] Max number of connection retries has been exceeded. Unable to connect to any transaction router. +LimitlessRouterService.NoRoutersAvailable=[LimitlessRouterService] Unable to connect to any transaction router. +LimitlessRouterService.NoRoutersAvailableForRetry=[LimitlessRouterService] No transaction routers available for connection retry. Retrying with original connection. +LimitlessRouterService.UnableToConnectNoRoutersAvailable=[LimitlessRouterService] Unable to connect to original host {}. All transaction routers are unavailable. Please verify connection credentials and network connectivity. +LimitlessRouterService.SelectedHost=[LimitlessRouterService] Host {} has been selected. +LimitlessRouterService.SelectedHostForRetry=[LimitlessRouterService] Host {} has been selected for connection retry. +LimitlessRouterService.SynchronouslyGetLimitlessRouters=[LimitlessRouterService] Fetching Limitless Routers synchronously. +LimitlessRouterService.UsingProvidedConnectUrl=[LimitlessRouterService] Connecting using provided connection URL. + LogUtils.Topology=[LogUtils] Topology {} Monitor.ContextNone=[Monitor] Parameter 'context' should not evaluate to None. @@ -208,6 +242,7 @@ PluginServiceImpl.IncorrectStatusType=[PluginServiceImpl] Received an unexpected PluginServiceImpl.NonEmptyAliases=[PluginServiceImpl] fill_aliases called when HostInfo already contains the following aliases: {}. PluginServiceImpl.UnableToUpdateTransactionStatus=[PluginServiceImpl] Unable to update transaction status, current connection is None. PluginServiceImpl.UpdateDialectConnectionNone=[PluginServiceImpl] The plugin service attempted to update the current dialect but could not identify a connection to use. +PluginServiceImpl.UnsupportedStrategy=[PluginServiceImpl] The driver does not support the requested host selection strategy: {} PropertiesUtils.ErrorParsingConnectionString=[PropertiesUtils] An error occurred while parsing the connection string: '{}'. Please ensure the format of your connection string is valid. PropertiesUtils.InvalidPgSchemeUrl=[PropertiesUtils] PropertiesUtils.parse_pg_scheme_url was called, but the passed in string did not begin with 'postgresql://' or 'postgres://'. Detected connection string: '{}'. diff --git a/aws_advanced_python_wrapper/utils/properties.py b/aws_advanced_python_wrapper/utils/properties.py index 50133e438..2dc060d01 100644 --- a/aws_advanced_python_wrapper/utils/properties.py +++ b/aws_advanced_python_wrapper/utils/properties.py @@ -340,9 +340,38 @@ class WrapperProperties: APP_ID = WrapperProperty("app_id", "The ID of the AWS application configured on Okta", None) # Fastest Response Strategy - RESPONSE_MEASUREMENT_INTERVAL_MILLIS = WrapperProperty("response_measurement_interval_ms", - "Interval in milliseconds between measuring response time to a database host", - 30_000) + RESPONSE_MEASUREMENT_INTERVAL_MS = WrapperProperty("response_measurement_interval_ms", + "Interval in milliseconds between measuring response time to a database host", + 30_000) + + # Limitless + LIMITLESS_MONITOR_DISPOSAL_TIME_MS = WrapperProperty("limitless_transaction_router_monitor_disposal_time_ms", + "Interval in milliseconds for an Limitless router monitor to be " + "considered inactive and to be disposed.", + 600_000) + + LIMITLESS_INTERVAL_MILLIS = WrapperProperty("limitless_transaction_router_monitor_interval_ms", + "Interval in millis between polling for Limitless Transaction Routers to the database.", + 7_500) + + WAIT_FOR_ROUTER_INFO = WrapperProperty("limitless_wait_for_transaction_router_info", + "If the cache of transaction router info is empty " + "and a new connection is made, this property toggles whether " + "the plugin will wait and synchronously fetch transaction router info before selecting a transaction " + "router to connect to, or to fall back to using the provided DB Shard Group endpoint URL.", + True) + + GET_ROUTER_RETRY_INTERVAL_MS = WrapperProperty("limitless_get_transaction_router_retry_interval_ms", + "Interval in milliseconds between retries fetching Limitless Transaction Router information.", + 300) + + GET_ROUTER_MAX_RETRIES = WrapperProperty("limitless_get_transaction_router_max_retries", + "Max number of connection retries the Limitless Connection Plugin will attempt.", + 5) + + MAX_RETRIES_MS = WrapperProperty("limitless_connection_max_retries_ms", + "Interval in milliseconds between polling for Limitless Transaction Routers to the database.", + 7_500) # Telemetry ENABLE_TELEMETRY = WrapperProperty( diff --git a/benchmarks/benchmark_plugin.py b/benchmarks/benchmark_plugin.py index 6800935e1..602f70a37 100644 --- a/benchmarks/benchmark_plugin.py +++ b/benchmarks/benchmark_plugin.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect @@ -74,7 +74,7 @@ def notify_connection_changed(self, changes: Set[ConnectionEvent]) -> OldConnect def accepts_strategy(self, role: HostRole, strategy: str) -> bool: return False - def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo: + def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Optional[List[HostInfo]] = None) -> HostInfo: self.resources.append("get_host_info_by_strategy") return HostInfo("host", 1234, role) diff --git a/tests/unit/test_host_response_time_monitor.py b/tests/unit/test_host_response_time_monitor.py index e9a9b4e7c..fa6e44bc5 100644 --- a/tests/unit/test_host_response_time_monitor.py +++ b/tests/unit/test_host_response_time_monitor.py @@ -37,7 +37,7 @@ def host_info(): @pytest.fixture def props(): - return Properties({WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MILLIS.name: 30000, "frt-some_prop": "some_value"}) + return Properties({WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.name: 30000, "frt-some_prop": "some_value"}) @pytest.fixture diff --git a/tests/unit/test_plugin_manager.py b/tests/unit/test_plugin_manager.py index 286b3f3a6..3fbb1e09c 100644 --- a/tests/unit/test_plugin_manager.py +++ b/tests/unit/test_plugin_manager.py @@ -520,7 +520,7 @@ def notify_connection_changed(self, changes: Set[ConnectionEvent]) -> OldConnect def accepts_strategy(self, role: HostRole, strategy: str) -> bool: return False - def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo: + def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Optional[List[HostInfo]] = None) -> HostInfo: return HostInfo(type(self).__name__ + ":host_info") From 1c67c10b7db52b2e54ddf758237ac407b409d4f9 Mon Sep 17 00:00:00 2001 From: Juan Lee Date: Wed, 9 Jul 2025 20:05:28 -0700 Subject: [PATCH 2/2] feat: limitless plugin implementation --- .../connection_provider.py | 7 +- .../database_dialect.py | 6 + aws_advanced_python_wrapper/host_selector.py | 12 + ...nnection_plugin.py => limitless_plugin.py} | 72 +-- aws_advanced_python_wrapper/plugin_service.py | 7 +- ...dvanced_python_wrapper_messages.properties | 4 +- .../sql_alchemy_connection_provider.py | 7 +- .../utils/properties.py | 2 +- .../utils/rds_url_type.py | 1 + aws_advanced_python_wrapper/utils/rdsutils.py | 29 +- .../unit/test_highest_weight_host_selector.py | 61 ++ tests/unit/test_limitless_plugin.py | 143 +++++ tests/unit/test_limitless_router_service.py | 549 ++++++++++++++++++ tests/unit/test_rds_utils.py | 9 +- 14 files changed, 850 insertions(+), 59 deletions(-) rename aws_advanced_python_wrapper/{limitless_connection_plugin.py => limitless_plugin.py} (90%) create mode 100644 tests/unit/test_highest_weight_host_selector.py create mode 100644 tests/unit/test_limitless_plugin.py create mode 100644 tests/unit/test_limitless_router_service.py diff --git a/aws_advanced_python_wrapper/connection_provider.py b/aws_advanced_python_wrapper/connection_provider.py index e3749807b..16ba12fec 100644 --- a/aws_advanced_python_wrapper/connection_provider.py +++ b/aws_advanced_python_wrapper/connection_provider.py @@ -26,8 +26,8 @@ from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.host_selector import ( - HostSelector, RandomHostSelector, RoundRobinHostSelector, - WeightedRandomHostSelector) + HighestWeightHostSelector, HostSelector, RandomHostSelector, + RoundRobinHostSelector, WeightedRandomHostSelector) from aws_advanced_python_wrapper.plugin import CanReleaseResources from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages @@ -98,7 +98,8 @@ def connect( class DriverConnectionProvider(ConnectionProvider): _accepted_strategies: Dict[str, HostSelector] = {"random": RandomHostSelector(), "round_robin": RoundRobinHostSelector(), - "weighted_random": WeightedRandomHostSelector()} + "weighted_random": WeightedRandomHostSelector(), + "highest_weight": HighestWeightHostSelector()} def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool: return True diff --git a/aws_advanced_python_wrapper/database_dialect.py b/aws_advanced_python_wrapper/database_dialect.py index 0706ad1a5..40d1bc7d2 100644 --- a/aws_advanced_python_wrapper/database_dialect.py +++ b/aws_advanced_python_wrapper/database_dialect.py @@ -18,6 +18,7 @@ Protocol, Tuple, runtime_checkable) from aws_advanced_python_wrapper.driver_info import DriverInfo +from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType if TYPE_CHECKING: from aws_advanced_python_wrapper.pep249 import Connection @@ -631,6 +632,11 @@ def get_dialect(self, driver_dialect: str, props: Properties) -> DatabaseDialect if target_driver_type is TargetDriverType.POSTGRES: rds_type = self._rds_helper.identify_rds_type(host) + if rds_type == RdsUrlType.RDS_AURORA_LIMITLESS_DB_SHARD_GROUP: + self._can_update = False + self._dialect_code = DialectCode.AURORA_PG + self._dialect = DatabaseDialectManager._known_dialects_by_code[DialectCode.AURORA_PG] + return self._dialect if rds_type.is_rds_cluster: self._can_update = True self._dialect_code = DialectCode.AURORA_PG diff --git a/aws_advanced_python_wrapper/host_selector.py b/aws_advanced_python_wrapper/host_selector.py index f6a388056..c8125ff58 100644 --- a/aws_advanced_python_wrapper/host_selector.py +++ b/aws_advanced_python_wrapper/host_selector.py @@ -260,3 +260,15 @@ def _update_host_weight_map_from_string(self, props: Optional[Properties] = None except ValueError: logger.error(message, pair) raise AwsWrapperError(Messages.get_formatted(message, pair)) + + +class HighestWeightHostSelector(HostSelector): + + def get_host(self, hosts: Tuple[HostInfo, ...], role: HostRole, props: Optional[Properties] = None) -> HostInfo: + eligible_hosts: List[HostInfo] = [host for host in hosts if + host.role == role and host.get_availability() == HostAvailability.AVAILABLE] + + if len(eligible_hosts) == 0: + raise AwsWrapperError(Messages.get_formatted("HostSelector.NoHostsMatchingRole", role)) + + return max(eligible_hosts, key=lambda host: host.weight) diff --git a/aws_advanced_python_wrapper/limitless_connection_plugin.py b/aws_advanced_python_wrapper/limitless_plugin.py similarity index 90% rename from aws_advanced_python_wrapper/limitless_connection_plugin.py rename to aws_advanced_python_wrapper/limitless_plugin.py index 41b89cd18..52b05260b 100644 --- a/aws_advanced_python_wrapper/limitless_connection_plugin.py +++ b/aws_advanced_python_wrapper/limitless_plugin.py @@ -35,6 +35,7 @@ SlidingExpirationCacheWithCleanupThread from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( TelemetryContext, TelemetryFactory, TelemetryTraceLevel) +from aws_advanced_python_wrapper.utils.utils import LogUtils if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect @@ -44,12 +45,16 @@ logger = Logger(__name__) -class LimitlessConnectionPlugin(Plugin): +class LimitlessPlugin(Plugin): _SUBSCRIBED_METHODS: Set[str] = {"connect"} def __init__(self, plugin_service: PluginService, props: Properties): self._plugin_service = plugin_service self._properties = props + self._limitless_router_service = LimitlessRouterService( + self._plugin_service, + LimitlessQueryHelper(self._plugin_service) + ) @property def subscribed_methods(self) -> Set[str]: @@ -68,23 +73,16 @@ def connect( dialect: DatabaseDialect = self._plugin_service.database_dialect if not isinstance(dialect, AuroraLimitlessDialect): - connection = connect_func() refreshed_dialect = self._plugin_service.database_dialect - if not isinstance(refreshed_dialect, AuroraLimitlessDialect): raise UnsupportedOperationError( - Messages.get_formatted("LimitlessConnectionPlugin.UnsupportedDialectOrDatabase", + Messages.get_formatted("LimitlessPlugin.UnsupportedDialectOrDatabase", type(refreshed_dialect).__name__)) - limitless_router_service = LimitlessRouterService( - self._plugin_service, - LimitlessQueryHelper(self._plugin_service) - ) - if is_initial_connection: - limitless_router_service.start_monitoring(host_info, props) + self._limitless_router_service.start_monitoring(host_info, props) - context: LimitlessConnectionContext = LimitlessConnectionContext( + self._context: LimitlessContext = LimitlessContext( host_info, props, connection, @@ -92,18 +90,18 @@ def connect( [], self ) - limitless_router_service.establish_connection(context) - connection = context.get_connection() + self._limitless_router_service.establish_connection(self._context) + connection = self._context.get_connection() if connection is not None and not self._plugin_service.driver_dialect.is_closed(connection): - return context.get_connection() + return connection - raise AwsWrapperError(Messages.get_formatted("LimitlessConnectionPlugin.FailedToConnectToHost", host_info.host)) + raise AwsWrapperError(Messages.get_formatted("LimitlessPlugin.FailedToConnectToHost", host_info.host)) -class LimitlessConnectionPluginFactory: +class LimitlessPluginFactory: def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: - return LimitlessConnectionPlugin(plugin_service, props) + return LimitlessPlugin(plugin_service, props) class LimitlessRouterMonitor: @@ -172,6 +170,7 @@ def run(self): lambda _: new_limitless_routers, WrapperProperties.LIMITLESS_MONITOR_DISPOSAL_TIME_MS.get( self._properties) * 1_000_000) + logger.debug(LogUtils.log_topology(tuple(new_limitless_routers), "[limitlessRouterMonitor] Topology:")) sleep(self._interval_ms / 1000) @@ -257,7 +256,7 @@ def _create_host_info(self, result: Tuple[Any, Any], host_port_to_map: int) -> H return HostInfo(host_name, host_port_to_map, weight=weight, host_id=host_name) -class LimitlessConnectionContext: +class LimitlessContext: def __init__(self, host_info: HostInfo, @@ -265,7 +264,7 @@ def __init__(self, connection: Optional[Connection], connect_func: Callable, limitless_routers: List[HostInfo], - connection_plugin: LimitlessConnectionPlugin) -> None: + connection_plugin: LimitlessPlugin) -> None: self._host_info = host_info self._props = props self._connection = connection @@ -326,23 +325,24 @@ def __init__(self, plugin_service: PluginService, query_helper: LimitlessQueryHe self._plugin_service = plugin_service self._query_helper = query_helper - def establish_connection(self, context: LimitlessConnectionContext) -> None: + def establish_connection(self, context: LimitlessContext) -> None: context.set_limitless_routers(self._get_limitless_routers( self._plugin_service.host_list_provider.get_cluster_id(), context.get_props())) if context.get_limitless_routers() is None or len(context.get_limitless_routers()) == 0: - logger.debug("LimitlessRouterServiceImpl.limitlessRouterCacheEmpty") + logger.debug("LimitlessRouterService.LimitlessRouterCacheEmpty") wait_for_router_info = WrapperProperties.WAIT_FOR_ROUTER_INFO.get(context.get_props()) if wait_for_router_info: self._synchronously_get_limitless_routers_with_retry(context) else: - logger.debug("LimitlessRouterServiceImpl.UsingProvidedConnectUrl") + logger.debug("LimitlessRouterService.UsingProvidedConnectUrl") if context.get_connection() is None or self._plugin_service.driver_dialect.is_closed(context.get_connection()): context.set_connection(context.get_connect_func()()) + return - if context.get_host_info in context.get_limitless_routers(): - logger.debug("LimitlessRouterServiceImpl.ConnectWithHost") + if context.get_host_info() in context.get_limitless_routers(): + logger.debug(Messages.get_formatted("LimitlessRouterService.ConnectWithHost", context.get_host_info().host)) if context.get_connection() is None: try: context.set_connection(context.get_connect_func()()) @@ -356,7 +356,7 @@ def establish_connection(self, context: LimitlessConnectionContext) -> None: try: selected_host_info = self._plugin_service.get_host_info_by_strategy( HostRole.WRITER, "weighted_random", context.get_limitless_routers()) - logger.debug("LimitlessRouterServiceImpl.SelectedHost", "None" if selected_host_info is None else selected_host_info.host) + logger.debug("LimitlessRouterService.SelectedHost", "None" if selected_host_info is None else selected_host_info.host) except Exception as e: if self._is_login_exception(e) or isinstance(e, UnsupportedOperationError): raise e @@ -375,7 +375,7 @@ def establish_connection(self, context: LimitlessConnectionContext) -> None: raise e if selected_host_info is not None: - logger.debug("LimitlessRouterServiceImpl.FailedToConnectToHost", selected_host_info.host) + logger.debug("LimitlessRouterService.FailedToConnectToHost", selected_host_info.host) selected_host_info.set_availability(HostAvailability.UNAVAILABLE) self._retry_connection_with_least_loaded_routers(context) @@ -389,7 +389,7 @@ def _get_limitless_routers(self, cluster_id: str, props: Properties) -> List[Hos return [] return routers - def _retry_connection_with_least_loaded_routers(self, context: LimitlessConnectionContext) -> None: + def _retry_connection_with_least_loaded_routers(self, context: LimitlessContext) -> None: retry_count = 0 max_retries = WrapperProperties.MAX_RETRIES_MS.get_int(context.get_props()) while retry_count < max_retries: @@ -400,7 +400,7 @@ def _retry_connection_with_least_loaded_routers(self, context: LimitlessConnecti if (context.get_limitless_routers() is None or len(context.get_limitless_routers()) == 0 or not context.is_any_router_available()): - logger.debug("LimitlessRouterServiceImpl.NoRoutersAvailableForRetry") + logger.debug("LimitlessRouterService.NoRoutersAvailableForRetry") if context.get_connection() is not None and not self._plugin_service.driver_dialect.is_closed(context.get_connection()): return @@ -417,14 +417,14 @@ def _retry_connection_with_least_loaded_routers(self, context: LimitlessConnecti try: selected_host_info = self._plugin_service.get_host_info_by_strategy( - HostRole.WRITER, "weighted_random", context.get_limitless_routers()) - logger.debug("LimitlessRouterServiceImpl.SelectedHostForRetry", + HostRole.WRITER, "highest_weight", context.get_limitless_routers()) + logger.debug("LimitlessRouterService.SelectedHostForRetry", "None" if selected_host_info is None else selected_host_info.host) if selected_host_info is None: continue except UnsupportedOperationError as e: - logger.error("LimitlessRouterServiceImpl.IncorrectConfiguration") + logger.error("LimitlessRouterService.IncorrectConfiguration") raise e except AwsWrapperError: continue @@ -438,14 +438,14 @@ def _retry_connection_with_least_loaded_routers(self, context: LimitlessConnecti if self._is_login_exception(e): raise e selected_host_info.set_availability(HostAvailability.UNAVAILABLE) - logger.debug("LimitlessRouterServiceImpl.FailedToConnectToHost", selected_host_info.host) + logger.debug("LimitlessRouterService.FailedToConnectToHost", selected_host_info.host) raise AwsWrapperError(Messages.get("LimitlessRouterService.MaxRetriesExceeded")) - def _synchronously_get_limitless_routers_with_retry(self, context: LimitlessConnectionContext) -> None: - logger.debug("LimitlessRouterServiceImpl.SynchronouslyGetLimitlessRouters") + def _synchronously_get_limitless_routers_with_retry(self, context: LimitlessContext) -> None: + logger.debug("LimitlessRouterService.SynchronouslyGetLimitlessRouters") retry_count = -1 - max_retries = WrapperProperties.MAX_RETRIES_MS.get_int(context.get_props()) + max_retries = WrapperProperties.GET_ROUTER_MAX_RETRIES.get_int(context.get_props()) retry_interval_ms = WrapperProperties.GET_ROUTER_RETRY_INTERVAL_MS.get_float(context.get_props()) first_iteration = True while first_iteration or retry_count < max_retries: @@ -467,7 +467,7 @@ def _synchronously_get_limitless_routers_with_retry(self, context: LimitlessConn raise AwsWrapperError(Messages.get("LimitlessRouterService.NoRoutersAvailable")) - def _synchronously_get_limitless_routers(self, context: LimitlessConnectionContext) -> None: + def _synchronously_get_limitless_routers(self, context: LimitlessContext) -> None: cache_expiration_nano: int = WrapperProperties.LIMITLESS_MONITOR_DISPOSAL_TIME_MS.get_int(context.get_props()) * 1_000_000 lock = LimitlessRouterService._force_get_limitless_routers_lock_map.compute_if_absent( diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index 11c13c687..ca276ad0b 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -24,8 +24,7 @@ FastestResponseStrategyPluginFactory from aws_advanced_python_wrapper.federated_plugin import \ FederatedAuthPluginFactory -from aws_advanced_python_wrapper.limitless_connection_plugin import \ - LimitlessConnectionPluginFactory +from aws_advanced_python_wrapper.limitless_plugin import LimitlessPluginFactory from aws_advanced_python_wrapper.okta_plugin import OktaAuthPluginFactory from aws_advanced_python_wrapper.states.session_state_service import ( SessionStateService, SessionStateServiceImpl) @@ -731,7 +730,7 @@ class PluginManager(CanReleaseResources): "federated_auth": FederatedAuthPluginFactory, "okta": OktaAuthPluginFactory, "initial_connection": AuroraInitialConnectionStrategyPluginFactory, - "limitless": LimitlessConnectionPluginFactory, + "limitless": LimitlessPluginFactory, } WEIGHT_RELATIVE_TO_PRIOR_PLUGIN = -1 @@ -751,7 +750,7 @@ class PluginManager(CanReleaseResources): IamAuthPluginFactory: 700, AwsSecretsManagerPluginFactory: 800, FederatedAuthPluginFactory: 900, - LimitlessConnectionPluginFactory: 950, + LimitlessPluginFactory: 950, OktaAuthPluginFactory: 1000, ConnectTimePluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN, ExecuteTimePluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN, diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index 28fd0582b..a7c37c484 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -153,8 +153,8 @@ IamAuthPlugin.InvalidHost=[IamAuthPlugin] Invalid IAM host {}. The IAM host must IamAuthPlugin.IsNoneOrEmpty=[IamAuthPlugin] Property "{}" is None or empty. IamAuthUtils.GeneratedNewAuthToken=Generated new authentication token = {} -LimitlessConnectionPlugin.FailedToConnectToHost=[LimitlessConnectionPlugin] Failed to connect to host {}. -LimitlessConnectionPlugin.UnsupportedDialectOrDatabase=[LimitlessConnectionPlugin] Unsupported dialect '{}' encountered. Please ensure the connection parameters are correct, and refer to the documentation to ensure that the connecting database is compatible with the Limitless Connection Plugin. +LimitlessPlugin.FailedToConnectToHost=[LimitlessPlugin] Failed to connect to host {}. +LimitlessPlugin.UnsupportedDialectOrDatabase=[LimitlessPlugin] Unsupported dialect '{}' encountered. Please ensure the connection parameters are correct, and refer to the documentation to ensure that the connecting database is compatible with the Limitless Connection Plugin. LimitlessQueryHelper.UnsupportedDialectOrDatabase=[LimitlessQueryHelper] Unsupported dialect '{}' encountered. Please ensure JDBC connection parameters are correct, and refer to the documentation to ensure that the connecting database is compatible with the Limitless Connection Plugin. diff --git a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py index 4bb12fe77..2f3c746d4 100644 --- a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py +++ b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py @@ -26,8 +26,8 @@ from aws_advanced_python_wrapper.connection_provider import ConnectionProvider from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.host_selector import ( - HostSelector, RandomHostSelector, RoundRobinHostSelector, - WeightedRandomHostSelector) + HighestWeightHostSelector, HostSelector, RandomHostSelector, + RoundRobinHostSelector, WeightedRandomHostSelector) from aws_advanced_python_wrapper.plugin import CanReleaseResources from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, @@ -48,7 +48,8 @@ class SqlAlchemyPooledConnectionProvider(ConnectionProvider, CanReleaseResources _LEAST_CONNECTIONS: ClassVar[str] = "least_connections" _accepted_strategies: Dict[str, HostSelector] = {"random": RandomHostSelector(), "round_robin": RoundRobinHostSelector(), - "weighted_random": WeightedRandomHostSelector()} + "weighted_random": WeightedRandomHostSelector(), + "highest_weight": HighestWeightHostSelector()} _rds_utils: ClassVar[RdsUtils] = RdsUtils() _database_pools: ClassVar[SlidingExpirationCache[PoolKey, QueuePool]] = SlidingExpirationCache( should_dispose_func=lambda queue_pool: queue_pool.checkedout() == 0, diff --git a/aws_advanced_python_wrapper/utils/properties.py b/aws_advanced_python_wrapper/utils/properties.py index 2dc060d01..ccd98eb75 100644 --- a/aws_advanced_python_wrapper/utils/properties.py +++ b/aws_advanced_python_wrapper/utils/properties.py @@ -369,7 +369,7 @@ class WrapperProperties: "Max number of connection retries the Limitless Connection Plugin will attempt.", 5) - MAX_RETRIES_MS = WrapperProperty("limitless_connection_max_retries_ms", + MAX_RETRIES_MS = WrapperProperty("limitless_max_retries_ms", "Interval in milliseconds between polling for Limitless Transaction Routers to the database.", 7_500) diff --git a/aws_advanced_python_wrapper/utils/rds_url_type.py b/aws_advanced_python_wrapper/utils/rds_url_type.py index ac1837a70..7226c33ce 100644 --- a/aws_advanced_python_wrapper/utils/rds_url_type.py +++ b/aws_advanced_python_wrapper/utils/rds_url_type.py @@ -33,4 +33,5 @@ def __init__(self, is_rds: bool, is_rds_cluster: bool): RDS_CUSTOM_CLUSTER = True, True, RDS_PROXY = True, False, RDS_INSTANCE = True, False, + RDS_AURORA_LIMITLESS_DB_SHARD_GROUP = True, False, OTHER = False, False diff --git a/aws_advanced_python_wrapper/utils/rdsutils.py b/aws_advanced_python_wrapper/utils/rdsutils.py index 28b3f014b..7e289d2db 100644 --- a/aws_advanced_python_wrapper/utils/rdsutils.py +++ b/aws_advanced_python_wrapper/utils/rdsutils.py @@ -61,7 +61,7 @@ class RdsUtils: """ AURORA_DNS_PATTERN = r"^(?P.+)\." \ - r"(?Pproxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \ + r"(?Pproxy-|cluster-|cluster-ro-|cluster-custom-|shardgrp-)?" \ r"(?P[a-zA-Z0-9]+\." \ r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn)$" AURORA_INSTANCE_PATTERN = r"^(?P.+)\." \ @@ -71,6 +71,11 @@ class RdsUtils: r"(?Pcluster-|cluster-ro-)+" \ r"(?P[a-zA-Z0-9]+\." \ r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn)$" + AURORA_LIMITLESS_CLUSTER_PATTERN = r"^(?P.+)\." \ + r"(?Pshardgrp-)+" \ + r"(?P[a-zA-Z0-9]+\." \ + r"(?P[a-zA-Z0-9\-]+)" \ + r"\.rds\.(amazonaws\.com\.?|amazonaws\.com\.cn\.?|sc2s\.sgov\.gov\.?|c2s\.ic\.gov\.?))$" AURORA_CUSTOM_CLUSTER_PATTERN = r"^(?P.+)\." \ r"(?Pcluster-custom-)+" \ r"(?P[a-zA-Z0-9]+\." \ @@ -80,11 +85,11 @@ class RdsUtils: r"(?P[a-zA-Z0-9]+\." \ r"(?P[a-zA-Z0-9\\-]+)\.rds\.amazonaws\.com)(?!\.cn)$" AURORA_OLD_CHINA_DNS_PATTERN = r"^(?P.+)\." \ - r"(?Pproxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \ + r"(?Pproxy-|cluster-|cluster-ro-|cluster-custom-|shardgrp-)?" \ r"(?P[a-zA-Z0-9]+\." \ r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)$" AURORA_CHINA_DNS_PATTERN = r"^(?P.+)\." \ - r"(?Pproxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \ + r"(?Pproxy-|cluster-|cluster-ro-|cluster-custom-|shardgrp-)?" \ r"(?P[a-zA-Z0-9]+\." \ r"rds\.(?P[a-zA-Z0-9\-]+)\.amazonaws\.com\.cn)$" AURORA_OLD_CHINA_CLUSTER_PATTERN = r"^(?P.+)\." \ @@ -96,7 +101,7 @@ class RdsUtils: r"(?P[a-zA-Z0-9]+\." \ r"rds\.(?P[a-zA-Z0-9\-]+)\.amazonaws\.com\.cn)$" AURORA_GOV_DNS_PATTERN = r"^(?P.+)\." \ - r"(?Pproxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \ + r"(?Pproxy-|cluster-|cluster-ro-|cluster-custom-|shardgrp-)?" \ r"(?P[a-zA-Z0-9]+\.rds\.(?P[a-zA-Z0-9\-]+)" \ r"\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$" AURORA_GOV_CLUSTER_PATTERN = r"^(?P.+)\." \ @@ -179,18 +184,26 @@ def is_reader_cluster_dns(self, host: str) -> bool: dns_group = self._get_dns_group(host) return dns_group is not None and dns_group.casefold() == "cluster-ro-" + def is_limitless_database_shard_group_dns(self, host: str) -> bool: + dns_group = self._get_dns_group(host) + return dns_group is not None and dns_group.casefold() == "shardgrp-" + def get_rds_cluster_host_url(self, host: str): if not host or not host.strip(): return None - for pattern in [RdsUtils.AURORA_DNS_PATTERN, + for pattern in [RdsUtils.AURORA_CLUSTER_PATTERN, RdsUtils.AURORA_CHINA_DNS_PATTERN, RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN, - RdsUtils.AURORA_GOV_DNS_PATTERN]: + RdsUtils.AURORA_GOV_DNS_PATTERN, + RdsUtils.AURORA_LIMITLESS_CLUSTER_PATTERN]: if m := search(pattern, host): group = self._get_regex_group(m, RdsUtils.DNS_GROUP) if group is not None: - return sub(pattern, r"\g.cluster-\g", host) + if pattern == RdsUtils.AURORA_LIMITLESS_CLUSTER_PATTERN: + return sub(pattern, r"\g.shardgrp-\g", host) + else: + return sub(pattern, r"\g.cluster-\g", host) return None return None @@ -236,6 +249,8 @@ def identify_rds_type(self, host: Optional[str]) -> RdsUrlType: return RdsUrlType.RDS_WRITER_CLUSTER elif self.is_reader_cluster_dns(host): return RdsUrlType.RDS_READER_CLUSTER + elif self.is_limitless_database_shard_group_dns(host): + return RdsUrlType.RDS_AURORA_LIMITLESS_DB_SHARD_GROUP elif self.is_rds_custom_cluster_dns(host): return RdsUrlType.RDS_CUSTOM_CLUSTER elif self.is_rds_proxy_dns(host): diff --git a/tests/unit/test_highest_weight_host_selector.py b/tests/unit/test_highest_weight_host_selector.py new file mode 100644 index 000000000..fb4f1d637 --- /dev/null +++ b/tests/unit/test_highest_weight_host_selector.py @@ -0,0 +1,61 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from aws_advanced_python_wrapper.host_availability import HostAvailability +from aws_advanced_python_wrapper.host_selector import HighestWeightHostSelector +from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.utils.properties import Properties + +HOST_ROLE = HostRole.READER + + +def test_get_host_given_unavailable_host(): + unavailable_host: HostInfo = HostInfo(host="some_unavailable_host", role=HOST_ROLE, availability=HostAvailability.UNAVAILABLE) + available_host: HostInfo = HostInfo(host="some_available_host", role=HOST_ROLE, availability=HostAvailability.AVAILABLE) + + host_selector = HighestWeightHostSelector() + actual_host = host_selector.get_host((unavailable_host, available_host), HOST_ROLE, Properties()) + + assert available_host == actual_host + + +def test_get_host_given_multiple_unavailable_hosts(): + hosts = ( + HostInfo(host="some_unavailable_host", role=HOST_ROLE, availability=HostAvailability.UNAVAILABLE), + HostInfo(host="some_unavailable_host", role=HOST_ROLE, availability=HostAvailability.UNAVAILABLE), + HostInfo(host="some_available_host", role=HOST_ROLE, availability=HostAvailability.AVAILABLE) + ) + + host_selector = HighestWeightHostSelector() + actual_host = host_selector.get_host(hosts, HOST_ROLE, Properties()) + + assert HostAvailability.AVAILABLE == actual_host.get_availability() + + +def test_get_host_given_different_weights(): + + highest_weight_host = HostInfo(host="some_available_host", role=HOST_ROLE, availability=HostAvailability.AVAILABLE, weight=3) + + hosts = ( + HostInfo(host="some_unavailable_host", role=HOST_ROLE, availability=HostAvailability.UNAVAILABLE), + HostInfo(host="some_unavailable_host", role=HOST_ROLE, availability=HostAvailability.UNAVAILABLE), + HostInfo(host="some_available_host", role=HOST_ROLE, availability=HostAvailability.AVAILABLE, weight=1), + HostInfo(host="some_available_host", role=HOST_ROLE, availability=HostAvailability.AVAILABLE, weight=2), + highest_weight_host + ) + + host_selector = HighestWeightHostSelector() + actual_host = host_selector.get_host(hosts, HOST_ROLE, Properties()) + + assert actual_host == highest_weight_host diff --git a/tests/unit/test_limitless_plugin.py b/tests/unit/test_limitless_plugin.py new file mode 100644 index 000000000..8bcadf1c2 --- /dev/null +++ b/tests/unit/test_limitless_plugin.py @@ -0,0 +1,143 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import PropertyMock + +import psycopg +import pytest + +from aws_advanced_python_wrapper.database_dialect import (DatabaseDialect, + MysqlDatabaseDialect) +from aws_advanced_python_wrapper.errors import (AwsWrapperError, + UnsupportedOperationError) +from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.limitless_plugin import LimitlessPlugin +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.properties import Properties + + +@pytest.fixture +def mock_driver_dialect(mocker): + driver_dialect_mock = mocker.MagicMock() + driver_dialect_mock.is_closed.return_value = False + return driver_dialect_mock + + +@pytest.fixture +def mock_plugin_service(mocker, mock_driver_dialect, mock_conn, host_info): + service_mock = mocker.MagicMock() + service_mock.current_connection = mock_conn + service_mock.current_host_info = host_info + + type(service_mock).driver_dialect = mocker.PropertyMock(return_value=mock_driver_dialect) + return service_mock + + +@pytest.fixture +def mock_conn(mocker): + return mocker.MagicMock(spec=psycopg.Connection) + + +@pytest.fixture +def mock_limitless_router_service(mocker): + limitless_router_service_mock = mocker.MagicMock() + return limitless_router_service_mock + + +@pytest.fixture +def host_info(): + return HostInfo(host="host-info", role=HostRole.READER) + + +@pytest.fixture +def props(): + return Properties() + + +@pytest.fixture +def plugin(mock_plugin_service, props, mock_limitless_router_service): + plugin = LimitlessPlugin(mock_plugin_service, props) + plugin._limitless_router_service = mock_limitless_router_service + return plugin + + +def test_connect(mocker, plugin, host_info, props, mock_conn, mock_limitless_router_service): + def replace_context_connection(invocation): + context = invocation._connection_plugin._context + context._connection = mock_conn + return None + + mock_connect_func = mocker.MagicMock() + mock_connect_func.return_value = None + plugin._limitless_router_service.establish_connection.side_effect = replace_context_connection + + connection = plugin.connect(mocker.MagicMock(), mocker.MagicMock(), host_info, props, True, mock_connect_func) + mock_connect_func.assert_not_called() + mock_limitless_router_service.start_monitoring.assert_called_once_with(host_info, props) + mock_limitless_router_service.establish_connection.assert_called_once() + assert mock_conn == connection + + +def test_connect_none_connection(mocker, plugin, host_info, props, mock_conn, mock_limitless_router_service): + def replace_context_connection_to_none(invocation): + context = invocation._connection_plugin._context + context._connection = None + return None + + mock_connect_func = mocker.MagicMock() + mock_connect_func.return_value = mock_conn + plugin._limitless_router_service.establish_connection.side_effect = replace_context_connection_to_none + + with pytest.raises(Exception) as e_info: + plugin.connect(mocker.MagicMock(), mocker.MagicMock(), host_info, props, True, mock_connect_func) + + mock_connect_func.assert_not_called() + mock_limitless_router_service.start_monitoring.assert_called_once_with(host_info, props) + mock_limitless_router_service.establish_connection.assert_called_once() + assert e_info.type == AwsWrapperError + assert str(e_info.value) == Messages.get_formatted("LimitlessPlugin.FailedToConnectToHost", host_info.host) + + +def test_connect_unsupported_dialect(mocker, plugin, host_info, props, mock_conn, mock_plugin_service, + mock_limitless_router_service): + unsupported_dialect: DatabaseDialect = MysqlDatabaseDialect() + mock_plugin_service.database_dialect = unsupported_dialect + + mock_connect_func = mocker.MagicMock() + mock_connect_func.return_value = mock_conn + + with pytest.raises(Exception) as e_info: + plugin.connect(mocker.MagicMock(), mocker.MagicMock(), host_info, props, True, mock_connect_func) + + assert e_info.type == UnsupportedOperationError + assert str(e_info.value) == Messages.get_formatted("LimitlessPlugin.UnsupportedDialectOrDatabase", type(unsupported_dialect).__name__) + + +def test_connect_supported_dialect_after_refresh( + mocker, plugin, host_info, props, mock_conn, mock_plugin_service, mock_limitless_router_service, mock_driver_dialect +): + unsupported_dialect: DatabaseDialect = MysqlDatabaseDialect() + type(mock_plugin_service).database_dialect = PropertyMock(side_effect=[unsupported_dialect, mock_driver_dialect]) + + def replace_context_connection(invocation): + context = invocation._connection_plugin._context + context._connection = mock_conn + return None + + mock_connect_func = mocker.MagicMock() + mock_connect_func.return_value = mock_conn + plugin._limitless_router_service.establish_connection.side_effect = replace_context_connection + + connection = plugin.connect(mocker.MagicMock(), mocker.MagicMock(), host_info, props, True, mock_connect_func) + mock_connect_func.assert_not_called() + mock_limitless_router_service.start_monitoring.assert_called_once_with(host_info, props) + mock_limitless_router_service.establish_connection.assert_called_once() + assert mock_conn == connection diff --git a/tests/unit/test_limitless_router_service.py b/tests/unit/test_limitless_router_service.py new file mode 100644 index 000000000..53ba45ab4 --- /dev/null +++ b/tests/unit/test_limitless_router_service.py @@ -0,0 +1,549 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import psycopg +import pytest + +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.host_availability import HostAvailability +from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.limitless_plugin import ( + LimitlessContext, LimitlessPlugin, LimitlessRouterService) +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) + +CLUSTER_ID: str = "some_cluster_id" +EXPIRATION_NANO_SECONDS: int = 60 * 60 * 1_000_000_000 + + +@pytest.fixture +def writer_host(): + return HostInfo("instance-0", 5432, HostRole.WRITER, HostAvailability.AVAILABLE) + + +@pytest.fixture +def reader_host1() -> HostInfo: + return HostInfo("instance-1", 5432, HostRole.READER, HostAvailability.AVAILABLE) + + +@pytest.fixture +def reader_host2(): + return HostInfo("instance-2", 5432, HostRole.READER, HostAvailability.AVAILABLE) + + +@pytest.fixture +def reader_host3(): + return HostInfo("instance-3", 5432, HostRole.READER, HostAvailability.AVAILABLE) + + +@pytest.fixture +def default_hosts(writer_host, reader_host2, reader_host3, reader_host1): + return [writer_host, reader_host1, reader_host2, reader_host3] + + +@pytest.fixture +def limitless_router1(): + return HostInfo("limitless-router-1", 5432, HostRole.READER, HostAvailability.AVAILABLE) + + +@pytest.fixture +def limitless_router2(): + return HostInfo("limitless-router-2", 5432, HostRole.WRITER, HostAvailability.AVAILABLE) + + +@pytest.fixture +def limitless_router3(): + return HostInfo("limitless-router-3", 5432, HostRole.READER, HostAvailability.UNAVAILABLE) + + +@pytest.fixture +def limitless_router4(): + return HostInfo("limitless-router-4", 5432, HostRole.READER, HostAvailability.AVAILABLE) + + +@pytest.fixture +def limitless_routers(limitless_router1, limitless_router2, limitless_router3, limitless_router4): + return [limitless_router1, limitless_router2, limitless_router3, limitless_router4] + + +@pytest.fixture +def mock_driver_dialect(mocker): + driver_dialect_mock = mocker.MagicMock() + driver_dialect_mock.is_closed.return_value = False + return driver_dialect_mock + + +@pytest.fixture +def mock_plugin_service(mocker, mock_driver_dialect, mock_conn, host_info, default_hosts): + service_mock = mocker.MagicMock() + service_mock.current_connection = mock_conn + service_mock.current_host_info = host_info + service_mock.hosts = default_hosts + service_mock.host_list_provider = mocker.MagicMock() + service_mock.host_list_provider.get_cluster_id.return_value = CLUSTER_ID + + type(service_mock).driver_dialect = mocker.PropertyMock(return_value=mock_driver_dialect) + return service_mock + + +@pytest.fixture +def mock_conn(mocker): + return mocker.MagicMock(spec=psycopg.Connection) + + +@pytest.fixture +def mock_limitless_router_service(mocker): + limitless_router_service_mock = mocker.MagicMock() + return limitless_router_service_mock + + +@pytest.fixture +def mock_limitless_query_helper(mocker): + limitless_query_helper_mock = mocker.MagicMock() + return limitless_query_helper_mock + + +@pytest.fixture +def host_info(): + return HostInfo(host="host-info", role=HostRole.READER) + + +@pytest.fixture +def props(): + return Properties() + + +@pytest.fixture +def plugin(mock_plugin_service, props, mock_limitless_router_service): + plugin = LimitlessPlugin(mock_plugin_service, props) + plugin._limitless_router_service = mock_limitless_router_service + return plugin + + +@pytest.fixture(autouse=True) +def run_before_and_after_tests(mock_limitless_router_service): + # Before + + yield + + # After + + LimitlessRouterService._limitless_router_cache.clear() + + +def test_establish_connection_empty_routers_list_then_wait_for_router_info_then_raises_exception(mocker, + mock_conn, + mock_limitless_query_helper, + host_info, + props, + mock_plugin_service): + mock_connect_func = mocker.MagicMock() + mock_connect_func.return_value = mock_conn + + input_context: LimitlessContext = LimitlessContext( + host_info, + props, + None, + mock_connect_func, + [], + mock_plugin_service + ) + + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + + with pytest.raises(Exception) as e_info: + limitless_router_service.establish_connection(input_context) + + assert e_info.type == AwsWrapperError + assert str(e_info.value) == Messages.get("LimitlessRouterService.NoRoutersAvailable") + + +def test_establish_connection_empty_routers_list_do_not_wait_for_router_info_then_call_connection_function(mocker, + mock_conn, + mock_limitless_query_helper, + host_info, + props, + mock_plugin_service): + WrapperProperties.WAIT_FOR_ROUTER_INFO.set(props, False) + mock_connect_func = mocker.MagicMock() + mock_connect_func.return_value = mock_conn + + input_context: LimitlessContext = LimitlessContext( + host_info, + props, + None, + mock_connect_func, + [], + mock_plugin_service + ) + + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + limitless_router_service.establish_connection(input_context) + + assert mock_conn == input_context.get_connection() + mock_connect_func.assert_called_once() + + +def test_establish_connection_host_info_in_router_cache_then_call_connection_function(mocker, + mock_conn, + mock_limitless_query_helper, + limitless_router1, + props, + mock_plugin_service, + limitless_routers): + LimitlessRouterService._limitless_router_cache.compute_if_absent(CLUSTER_ID, lambda _: limitless_routers, + EXPIRATION_NANO_SECONDS) + + mock_connect_func = mocker.MagicMock() + mock_connect_func.return_value = mock_conn + + input_context: LimitlessContext = LimitlessContext( + limitless_router1, + props, + None, + mock_connect_func, + [], + mock_plugin_service + ) + + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + limitless_router_service.establish_connection(input_context) + + assert mock_conn == input_context.get_connection() + mock_connect_func.assert_called_once() + + +def test_establish_connection_fetch_router_list_and_host_info_in_router_list_then_call_connection_function(mocker, + mock_conn, + mock_limitless_query_helper, + host_info, + props, + mock_plugin_service, + limitless_router1, + limitless_routers): + mock_limitless_query_helper.query_for_limitless_routers.return_value = limitless_routers + mock_connect_func = mocker.MagicMock() + mock_connect_func.return_value = mock_conn + + input_context: LimitlessContext = LimitlessContext( + limitless_router1, + props, + None, + mock_connect_func, + [], + mock_plugin_service + ) + + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + limitless_router_service.establish_connection(input_context) + + assert mock_conn == input_context.get_connection() + assert limitless_routers == LimitlessRouterService._limitless_router_cache.get(CLUSTER_ID) + mock_limitless_query_helper.query_for_limitless_routers.assert_called_once() + mock_connect_func.assert_called_once() + + +def test_establish_connection_router_cache_then_select_host(mocker, + mock_conn, + mock_limitless_query_helper, + host_info, + props, + mock_plugin_service, + plugin, + limitless_router1, + limitless_routers): + LimitlessRouterService._limitless_router_cache.compute_if_absent(CLUSTER_ID, lambda _: limitless_routers, + EXPIRATION_NANO_SECONDS) + mock_plugin_service.get_host_info_by_strategy.return_value = limitless_router1 + mock_plugin_service.connect.return_value = mock_conn + + mock_connect_func = mocker.MagicMock() + mock_connect_func.return_value = None + + input_context: LimitlessContext = LimitlessContext( + host_info, + props, + None, + mock_connect_func, + [], + plugin + ) + + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + limitless_router_service.establish_connection(input_context) + + assert mock_conn == input_context.get_connection() + assert limitless_routers == LimitlessRouterService._limitless_router_cache.get(CLUSTER_ID) + mock_plugin_service.get_host_info_by_strategy.assert_called_once() + mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "weighted_random", + limitless_routers) + mock_plugin_service.connect.assert_called_once() + mock_plugin_service.connect.assert_called_with(limitless_router1, props, plugin) + mock_connect_func.assert_not_called() + + +def test_establish_connection_fetch_router_list_then_select_host(mocker, + mock_conn, + mock_limitless_query_helper, + host_info, + props, + mock_plugin_service, + plugin, + limitless_router1, + limitless_routers): + mock_limitless_query_helper.query_for_limitless_routers.return_value = limitless_routers + mock_plugin_service.get_host_info_by_strategy.return_value = limitless_router1 + mock_plugin_service.connect.return_value = mock_conn + + mock_connect_func = mocker.MagicMock() + mock_connect_func.return_value = None + + input_context: LimitlessContext = LimitlessContext( + host_info, + props, + None, + mock_connect_func, + [], + plugin + ) + + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + limitless_router_service.establish_connection(input_context) + + assert mock_conn == input_context.get_connection() + assert limitless_routers == LimitlessRouterService._limitless_router_cache.get(CLUSTER_ID) + mock_limitless_query_helper.query_for_limitless_routers.assert_called_once() + mock_plugin_service.get_host_info_by_strategy.assert_called_once() + mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "weighted_random", + limitless_routers) + mock_plugin_service.connect.assert_called_once() + mock_plugin_service.connect.assert_called_with(limitless_router1, props, plugin) + mock_connect_func.assert_called_once() + + +def test_establish_connection_host_info_in_router_cache_can_call_connection_function_then_raises_exception_and_retries( + mocker, + mock_conn, + mock_limitless_query_helper, + props, + mock_plugin_service, + plugin, + limitless_router1, + limitless_routers): + LimitlessRouterService._limitless_router_cache.compute_if_absent(CLUSTER_ID, lambda _: limitless_routers, + EXPIRATION_NANO_SECONDS) + mock_plugin_service.get_host_info_by_strategy.return_value = limitless_router1 + mock_plugin_service.connect.return_value = mock_conn + + mock_connect_func = mocker.MagicMock() + mock_connect_func.side_effect = Exception() + + input_context: LimitlessContext = LimitlessContext( + limitless_router1, + props, + None, + mock_connect_func, + [], + plugin + ) + + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + limitless_router_service.establish_connection(input_context) + + assert mock_conn == input_context.get_connection() + assert limitless_routers == LimitlessRouterService._limitless_router_cache.get(CLUSTER_ID) + mock_plugin_service.get_host_info_by_strategy.assert_called_once() + mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight", + limitless_routers) + mock_plugin_service.connect.assert_called_once() + mock_plugin_service.connect.assert_called_with(limitless_router1, props, plugin) + mock_connect_func.assert_called_once() + + +def test_establish_connection_selected_host_raises_exception_and_retries(mocker, + mock_conn, + mock_limitless_query_helper, + host_info, + props, + mock_plugin_service, + plugin, + limitless_router1, + limitless_routers): + LimitlessRouterService._limitless_router_cache.compute_if_absent(CLUSTER_ID, lambda _: limitless_routers, + EXPIRATION_NANO_SECONDS) + mock_plugin_service.get_host_info_by_strategy.side_effect = [ + Exception(), + limitless_router1 + ] + mock_plugin_service.connect.return_value = mock_conn + + mock_connect_func = mocker.MagicMock() + mock_connect_func.side_effect = Exception() + + input_context: LimitlessContext = LimitlessContext( + host_info, + props, + None, + mock_connect_func, + [], + plugin + ) + + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + limitless_router_service.establish_connection(input_context) + + assert mock_conn == input_context.get_connection() + assert limitless_routers == LimitlessRouterService._limitless_router_cache.get(CLUSTER_ID) + assert mock_plugin_service.get_host_info_by_strategy.call_count == 2 + mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight", + limitless_routers) + mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight", + limitless_routers) + mock_plugin_service.connect.assert_called_once() + mock_plugin_service.connect.assert_called_with(limitless_router1, props, plugin) + + +def test_establish_connection_selected_host_none_then_retry(mocker, + mock_conn, + mock_limitless_query_helper, + host_info, + props, + mock_plugin_service, + plugin, + limitless_router1, + limitless_routers): + LimitlessRouterService._limitless_router_cache.compute_if_absent(CLUSTER_ID, lambda _: limitless_routers, + EXPIRATION_NANO_SECONDS) + mock_plugin_service.get_host_info_by_strategy.side_effect = [ + None, + limitless_router1 + ] + mock_plugin_service.connect.return_value = mock_conn + + mock_connect_func = mocker.MagicMock() + mock_connect_func.side_effect = Exception() + + input_context: LimitlessContext = LimitlessContext( + host_info, + props, + None, + mock_connect_func, + [], + plugin + ) + + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + limitless_router_service.establish_connection(input_context) + + assert mock_conn == input_context.get_connection() + assert limitless_routers == LimitlessRouterService._limitless_router_cache.get(CLUSTER_ID) + assert mock_plugin_service.get_host_info_by_strategy.call_count == 2 + mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight", + limitless_routers) + mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight", + limitless_routers) + mock_plugin_service.connect.assert_called_once() + mock_plugin_service.connect.assert_called_with(limitless_router1, props, plugin) + + +def test_establish_connection_plugin_service_connect_raises_exception_then_retry(mocker, + mock_conn, + mock_limitless_query_helper, + host_info, + props, + mock_plugin_service, + plugin, + limitless_router1, + limitless_router2, + limitless_routers): + LimitlessRouterService._limitless_router_cache.compute_if_absent(CLUSTER_ID, lambda _: limitless_routers, + EXPIRATION_NANO_SECONDS) + mock_plugin_service.get_host_info_by_strategy.side_effect = [ + limitless_router1, + limitless_router2 + ] + mock_plugin_service.connect.side_effect = [ + Exception(), + mock_conn + ] + + mock_connect_func = mocker.MagicMock() + mock_connect_func.side_effect = Exception() + + input_context: LimitlessContext = LimitlessContext( + host_info, + props, + None, + mock_connect_func, + [], + plugin + ) + + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + limitless_router_service.establish_connection(input_context) + + assert mock_conn == input_context.get_connection() + assert limitless_routers == LimitlessRouterService._limitless_router_cache.get(CLUSTER_ID) + assert mock_plugin_service.get_host_info_by_strategy.call_count == 2 + mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight", + limitless_routers) + mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight", + limitless_routers) + assert mock_plugin_service.connect.call_count == 2 + mock_plugin_service.connect.assert_called_with(limitless_router2, props, plugin) + + +def test_establish_connection_retry_and_max_retries_exceeded_then_raise_exception(mocker, + mock_conn, + mock_limitless_query_helper, + host_info, + props, + mock_plugin_service, + plugin, + limitless_router1, + limitless_routers): + LimitlessRouterService._limitless_router_cache.compute_if_absent(CLUSTER_ID, lambda _: limitless_routers, + EXPIRATION_NANO_SECONDS) + mock_plugin_service.get_host_info_by_strategy.return_value = limitless_router1 + mock_plugin_service.connect.side_effect = Exception() + + mock_connect_func = mocker.MagicMock() + mock_connect_func.side_effect = Exception() + + input_context: LimitlessContext = LimitlessContext( + limitless_router1, + props, + None, + mock_connect_func, + [], + plugin + ) + + limitless_router_service: LimitlessRouterService = LimitlessRouterService(mock_plugin_service, + mock_limitless_query_helper) + with pytest.raises(Exception) as e_info: + limitless_router_service.establish_connection(input_context) + + assert e_info.type == AwsWrapperError + assert str(e_info.value) == Messages.get("LimitlessRouterService.MaxRetriesExceeded") + assert mock_plugin_service.connect.call_count == WrapperProperties.MAX_RETRIES_MS.get(props) + assert mock_plugin_service.get_host_info_by_strategy.call_count == WrapperProperties.MAX_RETRIES_MS.get(props) diff --git a/tests/unit/test_rds_utils.py b/tests/unit/test_rds_utils.py index d3d03bcad..1c9cb69b7 100644 --- a/tests/unit/test_rds_utils.py +++ b/tests/unit/test_rds_utils.py @@ -32,7 +32,7 @@ china_alt_region_instance = "instance-test-name.XYZ.rds.cn-northwest-1.amazonaws.com.cn" china_alt_region_proxy = "proxy-test-name.proxy-XYZ.rds.cn-northwest-1.amazonaws.com.cn" china_alt_region_custom_domain = "custom-test-name.cluster-custom-XYZ.rds.cn-northwest-1.amazonaws.com.cn" -china_alt_region_limitless_db_shard_group = "database-test-name.limitless-XYZ.cn-northwest-1.rds.amazonaws.com.cn" +china_alt_region_limitless_db_shard_group = "database-test-name.shardgrp-XYZ.cn-northwest-1.rds.amazonaws.com.cn" extra_rds_china_path = "database-test-name.cluster-XYZ.rds.cn-northwest-1.rds.amazonaws.com.cn" missing_cn_china_path = "database-test-name.cluster-XYZ.rds.cn-northwest-1.amazonaws.com" missing_region_china_path = "database-test-name.cluster-XYZ.rds.amazonaws.com.cn" @@ -43,7 +43,7 @@ us_isob_east_region_instance = "instance-test-name.XYZ.rds.us-isob-east-1.sc2s.sgov.gov" us_isob_east_region_proxy = "proxy-test-name.proxy-XYZ.rds.us-isob-east-1.sc2s.sgov.gov" us_isob_east_region_custom_domain = "custom-test-name.cluster-custom-XYZ.rds.us-isob-east-1.sc2s.sgov.gov" -us_isob_east_region_limitless_db_shard_group = "database-test-name.limitless-XYZ.rds.us-isob-east-1.sc2s.sgov.gov" +us_isob_east_region_limitless_db_shard_group = "database-test-name.shardgrp-XYZ.rds.us-isob-east-1.sc2s.sgov.gov" us_gov_east_region_cluster = "database-test-name.cluster-XYZ.rds.us-gov-east-1.amazonaws.com" us_iso_east_region_cluster = "database-test-name.cluster-XYZ.rds.us-iso-east-1.c2s.ic.gov" @@ -52,7 +52,7 @@ us_iso_east_region_proxy = "proxy-test-name.proxy-XYZ.rds.us-iso-east-1.c2s.ic.gov" us_iso_east_region_custom_domain = "custom-test-name.cluster-custom-XYZ.rds.us-iso-east-1.c2s.ic.gov" -us_iso_east_region_limitless_db_shard_group = "database-test-name.limitless-XYZ.rds.us-iso-east-1.c2s.ic.gov" +us_iso_east_region_limitless_db_shard_group = "database-test-name.shardgrp-XYZ.rds.us-iso-east-1.c2s.ic.gov" @pytest.mark.parametrize("test_value", [ @@ -266,14 +266,17 @@ def test_is_not_reader_cluster_dns(test_value): def test_get_rds_cluster_host_url(): expected: str = "foo.cluster-xyz.us-west-1.rds.amazonaws.com" expected2: str = "foo-1.cluster-xyz.us-west-1.rds.amazonaws.com.cn" + expected_limitless: str = "foo.shardgrp-xyz.us-west-1.rds.amazonaws.com" ro_endpoint: str = "foo.cluster-ro-xyz.us-west-1.rds.amazonaws.com" china_ro_endpoint: str = "foo-1.cluster-ro-xyz.us-west-1.rds.amazonaws.com.cn" + limitless_endpoint: str = "foo.shardgrp-xyz.us-west-1.rds.amazonaws.com" target = RdsUtils() assert target.get_rds_cluster_host_url(ro_endpoint) == expected assert target.get_rds_cluster_host_url(china_ro_endpoint) == expected2 + assert target.get_rds_cluster_host_url(limitless_endpoint) == expected_limitless @pytest.mark.parametrize(