diff --git a/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py b/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py index a2ce609b..dfb97744 100644 --- a/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py +++ b/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py @@ -50,13 +50,12 @@ def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection): """ aliases: FrozenSet[str] = host_info.as_aliases() - host: str = host_info.as_alias() - if self._rds_utils.is_rds_instance(host): - self._track_connection(host, conn) + if self._rds_utils.is_rds_instance(host_info.host): + self._track_connection(host_info.as_alias(), conn) return - instance_endpoint: Optional[str] = next((alias for alias in aliases if self._rds_utils.is_rds_instance(alias)), + instance_endpoint: Optional[str] = next((alias for alias in aliases if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))), None) if not instance_endpoint: logger.debug("OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet") @@ -82,7 +81,7 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host: return for instance in host: - if instance is not None and self._rds_utils.is_rds_instance(instance): + if instance is not None and self._rds_utils.is_rds_instance(self._rds_utils.remove_port(instance)): instance_endpoint = instance break diff --git a/aws_advanced_python_wrapper/host_list_provider.py b/aws_advanced_python_wrapper/host_list_provider.py index c5d54a08..4ead4ebe 100644 --- a/aws_advanced_python_wrapper/host_list_provider.py +++ b/aws_advanced_python_wrapper/host_list_provider.py @@ -199,6 +199,8 @@ def _initialize(self): else: self._cluster_instance_template = HostInfo( host=self._rds_utils.get_rds_instance_host_pattern(self._initial_host_info.host), + host_id=self._initial_host_info.host_id, + port=self._initial_host_info.port, host_availability_strategy=host_availability_strategy) self._validate_host_pattern(self._cluster_instance_template.host) @@ -216,14 +218,15 @@ def _initialize(self): self._cluster_id = cluster_id_suggestion.cluster_id self._is_primary_cluster_id = cluster_id_suggestion.is_primary_cluster_id else: - cluster_url = self._rds_utils.get_rds_cluster_host_url(self._initial_host_info.url) + cluster_url = self._rds_utils.get_rds_cluster_host_url(self._initial_host_info.host) if cluster_url is not None: - self._cluster_id = cluster_url + self._cluster_id = f"{cluster_url}:{self._cluster_instance_template.port}" \ + if self._cluster_instance_template.is_port_specified() else cluster_url self._is_primary_cluster_id = True self._is_primary_cluster_id_cache.put(self._cluster_id, True, self._suggested_cluster_id_refresh_ns) - self._is_initialized = True + self._is_initialized = True def _validate_host_pattern(self, host: str): if not self._rds_utils.is_dns_pattern_valid(host): diff --git a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py index 07900a1d..7e385712 100644 --- a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py +++ b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py @@ -57,10 +57,12 @@ def __init__( self, pool_configurator: Optional[Callable] = None, pool_mapping: Optional[Callable] = None, + accept_url_func: Optional[Callable] = None, pool_expiration_check_ns: int = -1, pool_cleanup_interval_ns: int = -1): self._pool_configurator = pool_configurator self._pool_mapping = pool_mapping + self._accept_url_func = accept_url_func if pool_expiration_check_ns > -1: SqlAlchemyPooledConnectionProvider._POOL_EXPIRATION_CHECK_NS = pool_expiration_check_ns @@ -80,6 +82,8 @@ def keys(self): return self._database_pools.keys() def accepts_host_info(self, host_info: HostInfo, props: Properties) -> bool: + if self._accept_url_func: + return self._accept_url_func(host_info, props) url_type = SqlAlchemyPooledConnectionProvider._rds_utils.identify_rds_type(host_info.host) return RdsUrlType.RDS_INSTANCE == url_type diff --git a/aws_advanced_python_wrapper/utils/rdsutils.py b/aws_advanced_python_wrapper/utils/rdsutils.py index 677ace03..60d340ec 100644 --- a/aws_advanced_python_wrapper/utils/rdsutils.py +++ b/aws_advanced_python_wrapper/utils/rdsutils.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from re import search, sub -from typing import Optional +from __future__ import annotations + +from re import Match, search, sub +from typing import Dict, Optional from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType @@ -58,78 +60,101 @@ class RdsUtils: Example: test-postgres-instance-1.123456789012.rds.cn-northwest-1.amazonaws.com.cn """ - AURORA_DNS_PATTERN = r"(?P.+)\." \ - r"(?Pproxy-|cluster-|cluster-ro-|cluster-custom-)?" \ + AURORA_DNS_PATTERN = r"^(?P.+)\." \ + r"(?Pproxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \ r"(?P[a-zA-Z0-9]+\." \ - r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn$)" - AURORA_INSTANCE_PATTERN = r"(?P.+)\." \ + r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn)$" + AURORA_INSTANCE_PATTERN = r"^(?P.+)\." \ r"(?P[a-zA-Z0-9]+\." \ - r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn$)" - AURORA_CLUSTER_PATTERN = r"(?P.+)\." \ + r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn)$" + AURORA_CLUSTER_PATTERN = r"^(?P.+)\." \ r"(?Pcluster-|cluster-ro-)+" \ r"(?P[a-zA-Z0-9]+\." \ - r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn$)" - AURORA_CUSTOM_CLUSTER_PATTERN = r"(?P.+)\." \ + r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn)$" + AURORA_CUSTOM_CLUSTER_PATTERN = r"^(?P.+)\." \ r"(?Pcluster-custom-)+" \ r"(?P[a-zA-Z0-9]+\." \ - r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn$)" - AURORA_PROXY_DNS_PATTERN = r"(?P.+)\." \ + r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com)(?!\.cn)$" + AURORA_PROXY_DNS_PATTERN = r"^(?P.+)\." \ r"(?Pproxy-)+" \ r"(?P[a-zA-Z0-9]+\." \ - r"(?P[a-zA-Z0-9\\-]+)\.rds\.amazonaws\.com)(?!\.cn$)" - AURORA_CHINA_DNS_PATTERN = r"(?P.+)\." \ - r"(?Pproxy-|cluster-|cluster-ro-|cluster-custom-)?" \ + r"(?P[a-zA-Z0-9\\-]+)\.rds\.amazonaws\.com)(?!\.cn)$" + AURORA_OLD_CHINA_DNS_PATTERN = r"^(?P.+)\." \ + r"(?Pproxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \ + r"(?P[a-zA-Z0-9]+\." \ + r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)$" + AURORA_CHINA_DNS_PATTERN = r"^(?P.+)\." \ + r"(?Pproxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \ r"(?P[a-zA-Z0-9]+\." \ - r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)" - AURORA_CHINA_INSTANCE_PATTERN = r"(?P.+)\." \ - r"(?P[a-zA-Z0-9]+\." \ - r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)" - AURORA_CHINA_CLUSTER_PATTERN = r"(?P.+)\." \ + r"rds\.(?P[a-zA-Z0-9\-]+)\.amazonaws\.com\.cn)$" + AURORA_OLD_CHINA_CLUSTER_PATTERN = r"^(?P.+)\." \ + r"(?Pcluster-|cluster-ro-)+" \ + r"(?P[a-zA-Z0-9]+\." \ + r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)$" + AURORA_CHINA_CLUSTER_PATTERN = r"^(?P.+)\." \ r"(?Pcluster-|cluster-ro-)+" \ r"(?P[a-zA-Z0-9]+\." \ - r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)" - AURORA_CHINA_CUSTOM_CLUSTER_PATTERN = r"(?P.+)\." \ - r"(?Pcluster-custom-)+" \ - r"(?P[a-zA-Z0-9]+\." \ - r"(?P[a-zA-Z0-9\-]+)\.rds\.amazonaws\.com\.cn)" - AURORA_CHINA_PROXY_DNS_PATTERN = r"(?P.+)\." \ - r"(?Pproxy-)+" \ - r"(?P[a-zA-Z0-9]+\." \ - r"(?P[a-zA-Z0-9\-])+\.rds\.amazonaws\.com\.cn)" + r"rds\.(?P[a-zA-Z0-9\-]+)\.amazonaws\.com\.cn)$" + AURORA_GOV_DNS_PATTERN = r"^(?P.+)\." \ + r"(?Pproxy-|cluster-|cluster-ro-|cluster-custom-|limitless-)?" \ + r"(?P[a-zA-Z0-9]+\.rds\.(?P[a-zA-Z0-9\-]+)" \ + r"\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$" + AURORA_GOV_CLUSTER_PATTERN = r"^(?P.+)\." \ + r"(?Pcluster-|cluster-ro-)+" \ + r"(?P[a-zA-Z0-9]+\.rds\.(?P[a-zA-Z0-9\-]+)" \ + r"\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$" + ELB_PATTERN = r"^(?.+)\.elb\.((?[a-zA-Z0-9\-]+)\.amazonaws\.com)$" IP_V4 = r"^(([1-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){1}" \ - r"(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){2}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$" - IP_V6 = r"^[0-9a-fA-F]{1,4}(:[0-9a-fA-F]{1,4}){7}$" - IP_V6_COMPRESSED = r"^(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)::(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)$" + r"(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){2}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])" + IP_V6 = r"^[0-9a-fA-F]{1,4}(:[0-9a-fA-F]{1,4}){7}" + IP_V6_COMPRESSED = r"^(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)::(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)" DNS_GROUP = "dns" DOMAIN_GROUP = "domain" INSTANCE_GROUP = "instance" REGION_GROUP = "region" + CACHE_DNS_PATTERNS: Dict[str, Match[str]] = {} + CACHE_PATTERNS: Dict[str, str] = {} + def is_rds_cluster_dns(self, host: str) -> bool: - return self._contains(host, [self.AURORA_CLUSTER_PATTERN, self.AURORA_CHINA_CLUSTER_PATTERN]) + dns_group = self._get_dns_group(host) + return dns_group is not None and dns_group.casefold() in ["cluster-", "cluster-ro-"] def is_rds_custom_cluster_dns(self, host: str) -> bool: - return self._contains(host, [self.AURORA_CUSTOM_CLUSTER_PATTERN, self.AURORA_CHINA_CUSTOM_CLUSTER_PATTERN]) + dns_group = self._get_dns_group(host) + return dns_group is not None and dns_group.casefold() == "cluster-custom-" def is_rds_dns(self, host: str) -> bool: - return self._contains(host, [self.AURORA_DNS_PATTERN, self.AURORA_CHINA_DNS_PATTERN]) + if not host or not host.strip(): + return False + + pattern = self._find(host, [RdsUtils.AURORA_DNS_PATTERN, + RdsUtils.AURORA_CHINA_DNS_PATTERN, + RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN, + RdsUtils.AURORA_GOV_DNS_PATTERN]) + group = self._get_regex_group(pattern, RdsUtils.DNS_GROUP) + + if group: + RdsUtils.CACHE_PATTERNS[host] = group + + return pattern is not None def is_rds_instance(self, host: str) -> bool: - return (self._contains(host, [self.AURORA_INSTANCE_PATTERN, self.AURORA_CHINA_INSTANCE_PATTERN]) - and self.is_rds_dns(host)) + return self._get_dns_group(host) is None and self.is_rds_dns(host) def is_rds_proxy_dns(self, host: str) -> bool: - return self._contains(host, [self.AURORA_PROXY_DNS_PATTERN, self.AURORA_CHINA_PROXY_DNS_PATTERN]) + dns_group = self._get_dns_group(host) + return dns_group is not None and dns_group.casefold() == "proxy-" def get_rds_instance_host_pattern(self, host: str) -> str: if not host or not host.strip(): return "?" - match = self._find(host, [self.AURORA_DNS_PATTERN, self.AURORA_CHINA_DNS_PATTERN]) + match = self._get_group(host, RdsUtils.DOMAIN_GROUP) if match: - return f"?.{match.group(self.DOMAIN_GROUP)}" + return f"?.{match}" return "?" @@ -137,56 +162,54 @@ def get_rds_region(self, host: Optional[str]): if not host or not host.strip(): return None - match = self._find(host, [self.AURORA_DNS_PATTERN, self.AURORA_CHINA_DNS_PATTERN]) - if match: - return match.group(self.REGION_GROUP) + group = self._get_group(host, RdsUtils.REGION_GROUP) + if group: + return group + elb_matcher = search(RdsUtils.ELB_PATTERN, host) + if elb_matcher: + return elb_matcher.group(RdsUtils.REGION_GROUP) return None def is_writer_cluster_dns(self, host: str) -> bool: - if not host or not host.strip(): - return False - - match = self._find(host, [self.AURORA_CLUSTER_PATTERN, self.AURORA_CHINA_CLUSTER_PATTERN]) - if match: - return "cluster-".casefold() == match.group(self.DNS_GROUP).casefold() - - return False + dns_group = self._get_dns_group(host) + return dns_group is not None and dns_group.casefold() == "cluster-" def is_reader_cluster_dns(self, host: str) -> bool: - match = self._find(host, [self.AURORA_CLUSTER_PATTERN, self.AURORA_CHINA_CLUSTER_PATTERN]) - if match: - return "cluster-ro-".casefold() == match.group(self.DNS_GROUP).casefold() - - return False + dns_group = self._get_dns_group(host) + return dns_group is not None and dns_group.casefold() == "cluster-ro-" def get_rds_cluster_host_url(self, host: str): if not host or not host.strip(): return None - if search(self.AURORA_CLUSTER_PATTERN, host): - return sub(self.AURORA_CLUSTER_PATTERN, r"\g.cluster-\g", host) - - if search(self.AURORA_CHINA_CLUSTER_PATTERN, host): - return sub(self.AURORA_CHINA_CLUSTER_PATTERN, r"\g.cluster-\g", host) + for pattern in [RdsUtils.AURORA_DNS_PATTERN, + RdsUtils.AURORA_CHINA_DNS_PATTERN, + RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN, + RdsUtils.AURORA_GOV_DNS_PATTERN]: + if m := search(pattern, host): + group = self._get_regex_group(m, RdsUtils.DNS_GROUP) + if group is not None: + return sub(pattern, r"\g.cluster-\g", host) + return None return None def get_instance_id(self, host: str) -> Optional[str]: - if not host or not host.strip(): - return None - - match = self._find(host, [self.AURORA_INSTANCE_PATTERN, self.AURORA_CHINA_INSTANCE_PATTERN]) - if match: - return match.group(self.INSTANCE_GROUP) + if self._get_dns_group(host) is None: + return self._get_group(host, self.INSTANCE_GROUP) return None def is_ipv4(self, host: str) -> bool: - return self._contains(host, [self.IP_V4]) + if host is None or not host.strip(): + return False + return search(RdsUtils.IP_V4, host) is not None def is_ipv6(self, host: str) -> bool: - return self._contains(host, [self.IP_V6, self.IP_V6_COMPRESSED]) + if host is None or not host.strip(): + return False + return search(RdsUtils.IP_V6_COMPRESSED, host) is not None or search(RdsUtils.IP_V6, host) is not None def is_dns_pattern_valid(self, host: str) -> bool: return "?" in host @@ -210,17 +233,48 @@ def identify_rds_type(self, host: Optional[str]) -> RdsUrlType: return RdsUrlType.OTHER - def _contains(self, host: str, patterns: list) -> bool: - if not host or not host.strip(): - return False - - return len([pattern for pattern in patterns if search(pattern, host)]) > 0 - def _find(self, host: str, patterns: list): if not host or not host.strip(): return None for pattern in patterns: + match = RdsUtils.CACHE_DNS_PATTERNS.get(host) + if match: + return match + match = search(pattern, host) if match: + RdsUtils.CACHE_DNS_PATTERNS[host] = match return match + + return None + + def _get_regex_group(self, pattern: Match[str], group_name: str): + if pattern is None: + return None + return pattern.group(group_name) + + def _get_group(self, host: str, group: str): + if not host or not host.strip(): + return None + + pattern = self._find(host, [RdsUtils.AURORA_DNS_PATTERN, + RdsUtils.AURORA_CHINA_DNS_PATTERN, + RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN, + RdsUtils.AURORA_GOV_DNS_PATTERN]) + return self._get_regex_group(pattern, group) + + def _get_dns_group(self, host: str): + return self._get_group(host, RdsUtils.DNS_GROUP) + + def remove_port(self, url: str): + if not url or not url.strip(): + return None + if ":" in url: + return url.split(":")[0] + return url + + @staticmethod + def clear_cache(): + RdsUtils.CACHE_PATTERNS.clear() + RdsUtils.CACHE_DNS_PATTERNS.clear() diff --git a/tests/integration/container/conftest.py b/tests/integration/container/conftest.py index d0c4acf4..3a8358ae 100644 --- a/tests/integration/container/conftest.py +++ b/tests/integration/container/conftest.py @@ -28,6 +28,7 @@ from aws_advanced_python_wrapper.host_list_provider import RdsHostListProvider from aws_advanced_python_wrapper.plugin_service import PluginServiceImpl from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils if TYPE_CHECKING: from .utils.test_driver import TestDriver @@ -124,6 +125,7 @@ def pytest_runtest_setup(item): assert cluster_ip == writer_ip + RdsUtils.clear_cache() RdsHostListProvider._topology_cache.clear() RdsHostListProvider._is_primary_cluster_id_cache.clear() RdsHostListProvider._cluster_ids_to_update.clear() diff --git a/tests/integration/container/test_autoscaling.py b/tests/integration/container/test_autoscaling.py index bd0f5d98..61c9e7d2 100644 --- a/tests/integration/container/test_autoscaling.py +++ b/tests/integration/container/test_autoscaling.py @@ -101,6 +101,7 @@ def test_pooled_connection_auto_scaling__set_read_only_on_old_connection( provider = SqlAlchemyPooledConnectionProvider( lambda _, __: {"pool_size": original_cluster_size}, None, + None, 120000000000, # 2 minutes 180000000000) # 3 minutes ConnectionProviderManager.set_connection_provider(provider) @@ -167,6 +168,7 @@ def test_pooled_connection_auto_scaling__failover_from_deleted_reader( provider = SqlAlchemyPooledConnectionProvider( lambda _, __: {"pool_size": len(instances) * 5}, None, + None, 120000000000, # 2 minutes 180000000000) # 3 minutes ConnectionProviderManager.set_connection_provider(provider) diff --git a/tests/integration/container/test_read_write_splitting.py b/tests/integration/container/test_read_write_splitting.py index 4fdcbf37..86ad049a 100644 --- a/tests/integration/container/test_read_write_splitting.py +++ b/tests/integration/container/test_read_write_splitting.py @@ -515,7 +515,8 @@ def test_pooled_connection__cluster_url_failover( def test_pooled_connection__failover_failed( self, test_environment: TestEnvironment, test_driver: TestDriver, rds_utils, conn_utils, proxied_failover_props): - provider = SqlAlchemyPooledConnectionProvider(lambda _, __: {"pool_size": 1}) + writer_host = test_environment.get_writer().get_host() + provider = SqlAlchemyPooledConnectionProvider(lambda _, __: {"pool_size": 1}, None, lambda host_info, props: writer_host in host_info.host) ConnectionProviderManager.set_connection_provider(provider) WrapperProperties.PLUGINS.set(proxied_failover_props, "read_write_splitting,failover,host_monitoring") diff --git a/tests/unit/test_iam_plugin.py b/tests/unit/test_iam_plugin.py index 6438be94..04273698 100644 --- a/tests/unit/test_iam_plugin.py +++ b/tests/unit/test_iam_plugin.py @@ -357,7 +357,7 @@ def test_connect_with_specified_region(mocker, mock_plugin_service, mock_session pytest.param("test-.cluster-ro-123456789012.us-east-2.rds.amazonaws.com"), pytest.param("test.cluster-custom-123456789012.us-east-2.rds.amazonaws.com"), pytest.param("test-.proxy-123456789012.us-east-2.rds.amazonaws.com.cn"), - pytest.param("test-.proxy-123456789012.us-east-2.rds.amazonaws.com.proxy"), + pytest.param("test-.proxy-123456789012.us-east-2.rds.amazonaws.com"), ]) @patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_connect_with_specified_host(iam_host: str, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): diff --git a/tests/unit/test_rds_utils.py b/tests/unit/test_rds_utils.py index 6b0a37c0..dacd100f 100644 --- a/tests/unit/test_rds_utils.py +++ b/tests/unit/test_rds_utils.py @@ -27,12 +27,46 @@ china_region_proxy = "proxy-test-name.proxy-XYZ.cn-northwest-1.rds.amazonaws.com.cn" china_region_custom_domain = "custom-test-name.cluster-custom-XYZ.cn-northwest-1.rds.amazonaws.com.cn" +china_alt_region_cluster = "database-test-name.cluster-XYZ.rds.cn-northwest-1.amazonaws.com.cn" +china_alt_region_cluster_read_only = "database-test-name.cluster-ro-XYZ.rds.cn-northwest-1.amazonaws.com.cn" +china_alt_region_instance = "instance-test-name.XYZ.rds.cn-northwest-1.amazonaws.com.cn" +china_alt_region_proxy = "proxy-test-name.proxy-XYZ.rds.cn-northwest-1.amazonaws.com.cn" +china_alt_region_custom_domain = "custom-test-name.cluster-custom-XYZ.rds.cn-northwest-1.amazonaws.com.cn" +china_alt_region_limitless_db_shard_group = "database-test-name.limitless-XYZ.cn-northwest-1.rds.amazonaws.com.cn" +extra_rds_china_path = "database-test-name.cluster-XYZ.rds.cn-northwest-1.rds.amazonaws.com.cn" +missing_cn_china_path = "database-test-name.cluster-XYZ.rds.cn-northwest-1.amazonaws.com" +missing_region_china_path = "database-test-name.cluster-XYZ.rds.amazonaws.com.cn" + +us_east_region_elb_url = "elb-name.elb.us-east-2.amazonaws.com" +us_isob_east_region_cluster = "database-test-name.cluster-XYZ.rds.us-isob-east-1.sc2s.sgov.gov" +us_isob_east_region_cluster_read_only = "database-test-name.cluster-ro-XYZ.rds.us-isob-east-1.sc2s.sgov.gov" +us_isob_east_region_instance = "instance-test-name.XYZ.rds.us-isob-east-1.sc2s.sgov.gov" +us_isob_east_region_proxy = "proxy-test-name.proxy-XYZ.rds.us-isob-east-1.sc2s.sgov.gov" +us_isob_east_region_custom_domain = "custom-test-name.cluster-custom-XYZ.rds.us-isob-east-1.sc2s.sgov.gov" +us_isob_east_region_limitless_db_shard_group = "database-test-name.limitless-XYZ.rds.us-isob-east-1.sc2s.sgov.gov" +us_gov_east_region_cluster = "database-test-name.cluster-XYZ.rds.us-gov-east-1.amazonaws.com" + +us_iso_east_region_cluster = "database-test-name.cluster-XYZ.rds.us-iso-east-1.c2s.ic.gov" +us_iso_east_region_cluster_read_only = "database-test-name.cluster-ro-XYZ.rds.us-iso-east-1.c2s.ic.gov" +us_iso_east_region_instance = "instance-test-name.XYZ.rds.us-iso-east-1.c2s.ic.gov" +us_iso_east_region_proxy = "proxy-test-name.proxy-XYZ.rds.us-iso-east-1.c2s.ic.gov" +us_iso_east_region_custom_domain = "custom-test-name.cluster-custom-XYZ.rds.us-iso-east-1.c2s.ic.gov" + +us_iso_east_region_limitless_db_shard_group = "database-test-name.limitless-XYZ.rds.us-iso-east-1.c2s.ic.gov" + @pytest.mark.parametrize("test_value", [ us_east_region_cluster, us_east_region_cluster_read_only, china_region_cluster, + china_alt_region_cluster, china_region_cluster_read_only, + china_alt_region_cluster_read_only, + us_isob_east_region_cluster, + us_isob_east_region_cluster_read_only, + us_gov_east_region_cluster, + us_iso_east_region_cluster, + us_iso_east_region_cluster_read_only ]) def test_is_rds_cluster_dns(test_value): target = RdsUtils() @@ -46,7 +80,18 @@ def test_is_rds_cluster_dns(test_value): us_east_region_custom_domain, china_region_instance, china_region_proxy, - china_region_custom_domain + china_region_custom_domain, + china_alt_region_instance, + china_alt_region_proxy, + china_alt_region_custom_domain, + china_alt_region_limitless_db_shard_group, + us_east_region_elb_url, + us_isob_east_region_instance, + us_isob_east_region_proxy, + us_isob_east_region_limitless_db_shard_group, + us_iso_east_region_instance, + us_iso_east_region_proxy, + us_iso_east_region_limitless_db_shard_group, ]) def test_is_not_rds_cluster_dns(test_value): target = RdsUtils() @@ -64,7 +109,20 @@ def test_is_not_rds_cluster_dns(test_value): china_region_cluster_read_only, china_region_instance, china_region_proxy, - china_region_custom_domain + china_region_custom_domain, + china_alt_region_cluster, + china_alt_region_cluster_read_only, + china_alt_region_instance, + china_alt_region_proxy, + china_alt_region_custom_domain, + china_alt_region_limitless_db_shard_group, + us_isob_east_region_cluster, + us_isob_east_region_cluster_read_only, + us_isob_east_region_instance, + us_isob_east_region_proxy, + us_isob_east_region_custom_domain, + us_isob_east_region_limitless_db_shard_group, + us_gov_east_region_cluster, ]) def test_is_rds_dns(test_value): target = RdsUtils() @@ -81,11 +139,24 @@ def test_is_rds_dns(test_value): ("?.XYZ.cn-northwest-1.rds.amazonaws.com.cn", china_region_cluster_read_only), ("?.XYZ.cn-northwest-1.rds.amazonaws.com.cn", china_region_instance), ("?.XYZ.cn-northwest-1.rds.amazonaws.com.cn", china_region_proxy), - ("?.XYZ.cn-northwest-1.rds.amazonaws.com.cn", china_region_custom_domain) + ("?.XYZ.cn-northwest-1.rds.amazonaws.com.cn", china_region_custom_domain), + ("?.XYZ.rds.cn-northwest-1.amazonaws.com.cn", china_alt_region_cluster), + ("?.XYZ.rds.cn-northwest-1.amazonaws.com.cn", china_alt_region_cluster_read_only), + ("?.XYZ.rds.cn-northwest-1.amazonaws.com.cn", china_alt_region_instance), + ("?.XYZ.rds.cn-northwest-1.amazonaws.com.cn", china_alt_region_proxy), + ("?.XYZ.rds.cn-northwest-1.amazonaws.com.cn", china_alt_region_custom_domain), + ("?.XYZ.cn-northwest-1.rds.amazonaws.com.cn", china_alt_region_limitless_db_shard_group), + ("?.XYZ.rds.us-isob-east-1.sc2s.sgov.gov", us_isob_east_region_cluster), + ("?.XYZ.rds.us-isob-east-1.sc2s.sgov.gov", us_isob_east_region_cluster_read_only), + ("?.XYZ.rds.us-isob-east-1.sc2s.sgov.gov", us_isob_east_region_instance), + ("?.XYZ.rds.us-isob-east-1.sc2s.sgov.gov", us_isob_east_region_proxy), + ("?.XYZ.rds.us-isob-east-1.sc2s.sgov.gov", us_isob_east_region_custom_domain), + ("?.XYZ.rds.us-isob-east-1.sc2s.sgov.gov", us_isob_east_region_limitless_db_shard_group), + ("?.XYZ.rds.us-gov-east-1.amazonaws.com", us_gov_east_region_cluster), ]) def test_get_rds_instance_host_pattern(expected, test_value): target = RdsUtils() - assert expected == target.get_rds_instance_host_pattern(test_value) + assert target.get_rds_instance_host_pattern(test_value) == expected @pytest.mark.parametrize("expected, test_value", [ @@ -98,11 +169,24 @@ def test_get_rds_instance_host_pattern(expected, test_value): ("cn-northwest-1", china_region_cluster_read_only), ("cn-northwest-1", china_region_instance), ("cn-northwest-1", china_region_proxy), - ("cn-northwest-1", china_region_custom_domain) + ("cn-northwest-1", china_region_custom_domain), + ("cn-northwest-1", china_alt_region_cluster), + ("cn-northwest-1", china_alt_region_cluster_read_only), + ("cn-northwest-1", china_alt_region_instance), + ("cn-northwest-1", china_alt_region_proxy), + ("cn-northwest-1", china_alt_region_custom_domain), + ("cn-northwest-1", china_alt_region_limitless_db_shard_group), + ("us-isob-east-1", us_isob_east_region_cluster), + ("us-isob-east-1", us_isob_east_region_cluster_read_only), + ("us-isob-east-1", us_isob_east_region_instance), + ("us-isob-east-1", us_isob_east_region_proxy), + ("us-isob-east-1", us_isob_east_region_custom_domain), + ("us-isob-east-1", us_isob_east_region_limitless_db_shard_group), + ("us-gov-east-1", us_gov_east_region_cluster), ]) def test_get_rds_region(expected, test_value): target = RdsUtils() - assert expected == target.get_rds_region(test_value) + assert target.get_rds_region(test_value) == expected @pytest.mark.parametrize("test_value", [ @@ -123,7 +207,17 @@ def test_is_writer_cluster_dns(test_value): china_region_cluster_read_only, china_region_instance, china_region_proxy, - china_region_custom_domain + china_region_custom_domain, + china_alt_region_cluster_read_only, + china_alt_region_instance, + china_alt_region_proxy, + china_alt_region_custom_domain, + china_alt_region_limitless_db_shard_group, + us_isob_east_region_cluster_read_only, + us_isob_east_region_instance, + us_isob_east_region_proxy, + us_isob_east_region_custom_domain, + us_isob_east_region_limitless_db_shard_group, ]) def test_is_not_writer_cluster_dns(test_value): target = RdsUtils() @@ -134,6 +228,7 @@ def test_is_not_writer_cluster_dns(test_value): @pytest.mark.parametrize("test_value", [ us_east_region_cluster_read_only, china_region_cluster_read_only, + us_isob_east_region_cluster_read_only, ]) def test_is_reader_cluster_dns(test_value): target = RdsUtils() @@ -149,7 +244,18 @@ def test_is_reader_cluster_dns(test_value): china_region_cluster, china_region_instance, china_region_proxy, - china_region_custom_domain + china_region_custom_domain, + china_region_cluster, + china_region_instance, + china_region_proxy, + china_region_custom_domain, + china_alt_region_limitless_db_shard_group, + us_isob_east_region_cluster, + us_isob_east_region_instance, + us_isob_east_region_proxy, + us_isob_east_region_custom_domain, + us_isob_east_region_limitless_db_shard_group, + us_gov_east_region_cluster, ]) def test_is_not_reader_cluster_dns(test_value): target = RdsUtils() @@ -166,8 +272,8 @@ def test_get_rds_cluster_host_url(): target = RdsUtils() - assert expected == target.get_rds_cluster_host_url(ro_endpoint) - assert expected2 == target.get_rds_cluster_host_url(china_ro_endpoint) + assert target.get_rds_cluster_host_url(ro_endpoint) == expected + assert target.get_rds_cluster_host_url(china_ro_endpoint) == expected2 @pytest.mark.parametrize( @@ -177,4 +283,4 @@ def test_get_rds_cluster_host_url(): ) def test_get_instance_id(host: str, expected_id: str): target = RdsUtils() - assert expected_id == target.get_instance_id(host) + assert target.get_instance_id(host) == expected_id