diff --git a/README.md b/README.md index 9bea3d6d6..b69528dfa 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ The Cloud SQL Python Connector is a package to be used alongside a database driv Currently supported drivers are: - [`pymysql`](https://github.com/PyMySQL/PyMySQL) (MySQL) - [`pg8000`](https://github.com/tlocke/pg8000) (PostgreSQL) + - [`psycopg`](https://github.com/psycopg/psycopg) (PostgreSQL) - [`asyncpg`](https://github.com/MagicStack/asyncpg) (PostgreSQL) - [`pytds`](https://github.com/denisenkom/pytds) (SQL Server) @@ -587,7 +588,7 @@ async def main(): # acquire connection and query Cloud SQL database async with pool.acquire() as conn: res = await conn.fetch("SELECT NOW()") - + # close Connector await connector.close_async() ``` diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index 11508ce17..ef748eb19 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -58,6 +58,7 @@ def __init__( client: Optional[aiohttp.ClientSession] = None, driver: Optional[str] = None, user_agent: Optional[str] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: """Establishes the client to be used for Cloud SQL Admin API requests. @@ -84,8 +85,7 @@ def __init__( } if quota_project: headers["x-goog-user-project"] = quota_project - - self._client = client if client else aiohttp.ClientSession(headers=headers) + self._client = client if client else aiohttp.ClientSession(headers=headers, loop=loop) self._credentials = credentials if sqladmin_api_endpoint is None: self._sqladmin_api_endpoint = DEFAULT_SERVICE_ENDPOINT diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 0229b7283..b5ecff204 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -39,6 +39,7 @@ from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 +import google.cloud.sql.connector.proxy as proxy import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds from google.cloud.sql.connector.resolver import DefaultResolver @@ -155,6 +156,7 @@ def __init__( # connection name string and enable_iam_auth boolean flag self._cache: dict[tuple[str, bool], MonitoredCache] = {} self._client: Optional[CloudSQLClient] = None + self._proxies: list[proxy.Proxy] = [] # initialize credentials scopes = ["https://www.googleapis.com/auth/sqlservice.admin"] @@ -215,6 +217,108 @@ def __init__( def universe_domain(self) -> str: return self._universe_domain or _DEFAULT_UNIVERSE_DOMAIN + async def _get_cache( + self, + instance_connection_string: str, + enable_iam_auth: bool, + ip_type: IPTypes, + driver: str | None, + ) -> MonitoredCache: + """Helper function to get instance's cache from Connector cache.""" + + # resolve instance connection name + conn_name = await self._resolver.resolve(instance_connection_string) + cache_key = (str(conn_name), enable_iam_auth) + + # if cache entry doesn't exist or is closed, create it + if cache_key not in self._cache or self._cache[cache_key].closed: + # if lazy refresh, init keys now + if self._refresh_strategy == RefreshStrategy.LAZY and self._keys is None: + self._keys = asyncio.create_task(generate_keys()) + # create cache + if self._refresh_strategy == RefreshStrategy.LAZY: + logger.debug( + f"['{conn_name}']: Refresh strategy is set to lazy refresh" + ) + cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache( + conn_name, + self._init_client(driver), + self._keys, # type: ignore + enable_iam_auth, + ) + else: + logger.debug( + f"['{conn_name}']: Refresh strategy is set to background refresh" + ) + cache = RefreshAheadCache( + conn_name, + self._init_client(driver), + self._keys, # type: ignore + enable_iam_auth, + ) + # wrap cache as a MonitoredCache + monitored_cache = MonitoredCache( + cache, + self._failover_period, + self._resolver, + ) + logger.debug(f"['{conn_name}']: Connection info added to cache") + self._cache[cache_key] = monitored_cache + + monitored_cache = self._cache[(str(conn_name), enable_iam_auth)] + + # Check that the information is valid and matches the driver and db type + try: + conn_info = await monitored_cache.connect_info() + # validate driver matches intended database engine + if driver: + DriverMapping.validate_engine(driver, conn_info.database_version) + if ip_type: + conn_info.get_preferred_ip(ip_type) + except Exception: + await self._remove_cached(str(conn_name), enable_iam_auth) + raise + + return monitored_cache + + async def connect_socket_async( + self, + instance_connection_string: str, + protocol_fn: Callable[[], asyncio.Protocol], + **kwargs: Any, + ) -> tuple[asyncio.Transport, asyncio.Protocol]: + """Helper function to connect to a Cloud SQL instance and return a socket.""" + + enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) + ip_type = kwargs.pop("ip_type", self._ip_type) + driver = kwargs.pop("driver", None) + # if ip_type is str, convert to IPTypes enum + if isinstance(ip_type, str): + ip_type = IPTypes._from_str(ip_type) + + monitored_cache = await self._get_cache( + instance_connection_string, enable_iam_auth, ip_type, driver + ) + + try: + conn_info = await monitored_cache.connect_info() + ctx = await conn_info.create_ssl_context(enable_iam_auth) + ip_address = conn_info.get_preferred_ip(ip_type) + tx, p = await self._loop.create_connection( + protocol_fn, host=ip_address, port=3307, ssl=ctx + ) + except Exception as ex: + logger.exception("exception starting tls protocol", exc_info=ex) + # with an error from Cloud SQL Admin API call or IP type, invalidate + # the cache and re-raise the error + await self._remove_cached( + instance_connection_string, + enable_iam_auth, + ) + raise + + return tx, p + def connect( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: @@ -232,7 +336,7 @@ def connect( Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect - with. Supported drivers are pymysql, pg8000, and pytds. + with. Supported drivers are pymysql, pg8000, psycopg, and pytds. **kwargs: Any driver-specific arguments to pass to the underlying driver .connect call. @@ -250,6 +354,19 @@ def connect( ) return connect_future.result() + def _init_client(self, driver: Optional[str]) -> CloudSQLClient: + """Lazy initialize the client, setting the driver name in the user agent string.""" + if self._client is None: + self._client = CloudSQLClient( + self._sqladmin_api_endpoint, + self._quota_project, + self._credentials, + user_agent=self._user_agent, + driver=driver, + loop=self._loop + ) + return self._client + async def connect_async( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: @@ -268,7 +385,8 @@ async def connect_async( Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect - with. Supported drivers are pymysql, asyncpg, pg8000, and pytds. + with. Supported drivers are pymysql, asyncpg, pg8000, psycopg, and + pytds. **kwargs: Any driver-specific arguments to pass to the underlying driver .connect call. @@ -280,8 +398,9 @@ async def connect_async( ValueError: Connection attempt with built-in database authentication and then subsequent attempt with IAM database authentication. KeyError: Unsupported database driver Must be one of pymysql, asyncpg, - pg8000, and pytds. + pg8000, psycopg, and pytds. """ + self._init_client(driver) # check if event loop is running in current thread if self._loop != asyncio.get_running_loop(): raise ConnectorLoopError( @@ -301,6 +420,7 @@ async def connect_async( self._credentials, user_agent=self._user_agent, driver=driver, + loop=self._loop ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) @@ -340,6 +460,7 @@ async def connect_async( logger.debug(f"['{conn_name}']: Connection info added to cache") self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache + # Map drivers to connect functions connect_func = { "pymysql": pymysql.connect, "pg8000": pg8000.connect, @@ -347,7 +468,7 @@ async def connect_async( "pytds": pytds.connect, } - # only accept supported database drivers + # Only accept supported database drivers try: connector: Callable = connect_func[driver] # type: ignore except KeyError: @@ -357,6 +478,7 @@ async def connect_async( # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes._from_str(ip_type) + enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) kwargs["timeout"] = kwargs.get("timeout", self._timeout) # Host and ssl options come from the certificates and metadata, so we don't @@ -365,60 +487,88 @@ async def connect_async( kwargs.pop("ssl", None) kwargs.pop("port", None) - # attempt to get connection info for Cloud SQL instance + monitored_cache = await self._get_cache( + instance_connection_string, enable_iam_auth, ip_type, driver + ) + conn_info = await monitored_cache.connect_info() + ip_address = conn_info.get_preferred_ip(ip_type) + try: - conn_info = await monitored_cache.connect_info() - # validate driver matches intended database engine - DriverMapping.validate_engine(driver, conn_info.database_version) - ip_address = conn_info.get_preferred_ip(ip_type) - except Exception: - # with an error from Cloud SQL Admin API call or IP type, invalidate - # the cache and re-raise the error - await self._remove_cached(str(conn_name), enable_iam_auth) - raise - logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") - # format `user` param for automatic IAM database authn - if enable_iam_auth: - formatted_user = format_database_user( - conn_info.database_version, kwargs["user"] - ) - if formatted_user != kwargs["user"]: - logger.debug( - f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}" + # format `user` param for automatic IAM database authn + if enable_iam_auth: + formatted_user = format_database_user( + conn_info.database_version, kwargs["user"] ) - kwargs["user"] = formatted_user - try: + if formatted_user != kwargs["user"]: + logger.debug( + f"['{instance_connection_string}']: " + "Truncated IAM database username from " + f"{kwargs['user']} to {formatted_user}" + ) + kwargs["user"] = formatted_user + + ctx = await conn_info.create_ssl_context(enable_iam_auth) # async drivers are unblocking and can be awaited directly if driver in ASYNC_DRIVERS: - return await connector( - ip_address, - await conn_info.create_ssl_context(enable_iam_auth), - **kwargs, + return await connector(ip_address, ctx, **kwargs) + else: + # Synchronous drivers are blocking and run using executor + tx, _ = await self.connect_socket_async( + instance_connection_string, asyncio.Protocol, **kwargs ) - # Create socket with SSLContext for sync drivers - ctx = await conn_info.create_ssl_context(enable_iam_auth) - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) - # If this connection was opened using a domain name, then store it - # for later in case we need to forcibly close it on failover. - if conn_info.conn_name.domain_name: - monitored_cache.sockets.append(sock) - # Synchronous drivers are blocking and run using executor - connect_partial = partial( - connector, - ip_address, - sock, - **kwargs, - ) - return await self._loop.run_in_executor(None, connect_partial) + # See https://docs.python.org/3/library/asyncio-protocol.html#asyncio.BaseTransport.get_extra_info + ctx = tx.get_extra_info("sslcontext") + sock = ctx.wrap_socket( + socket.create_connection((ip_address, SERVER_PROXY_PORT)), + server_hostname=ip_address, + ) + connect_partial = partial(connector, ip_address, sock, **kwargs) + return await self._loop.run_in_executor(None, connect_partial) except Exception: # with any exception, we attempt a force refresh, then throw the error + monitored_cache = await self._get_cache( + instance_connection_string, enable_iam_auth, ip_type, driver + ) await monitored_cache.force_refresh() raise + async def start_unix_socket_proxy_async( + self, instance_connection_string: str, local_socket_path: str, **kwargs: Any + ) -> None: + """Starts a local Unix socket proxy for a Cloud SQL instance. + + Args: + instance_connection_string (str): The instance connection name of the + Cloud SQL instance to connect to. + local_socket_path (str): The path to the local Unix socket. + driver (str): The database driver name. + **kwargs: Keyword arguments to pass to the underlying database + driver. + """ + if "driver" in kwargs: + driver = kwargs["driver"] + else: + driver = "proxy" + + self._init_client(driver) + + # check if a proxy is already running for this socket path + for p in self._proxies: + if p.unix_socket_path == local_socket_path: + raise ValueError( + f"Proxy for socket path {local_socket_path} already exists." + ) + + # Create a new proxy instance + proxy_instance = proxy.Proxy( + local_socket_path, + ConnectorSocketFactory(self, instance_connection_string, **kwargs), + self._loop + ) + await proxy_instance.start() + self._proxies.append(proxy_instance) + async def _remove_cached( self, instance_connection_string: str, enable_iam_auth: bool ) -> None: @@ -477,6 +627,9 @@ def close(self) -> None: async def close_async(self) -> None: """Helper function to cancel the cache's tasks and close aiohttp.ClientSession.""" + # close all proxies + if self._proxies: + await asyncio.gather(*[proxy.close() for proxy in self._proxies]) await asyncio.gather(*[cache.close() for cache in self._cache.values()]) if self._client: await self._client.close() @@ -571,3 +724,13 @@ async def create_async_connector( resolver=resolver, failover_period=failover_period, ) + + +class ConnectorSocketFactory(proxy.ServerConnectionFactory): + def __init__(self, connector:Connector, instance_connection_string:str, **kwargs): + self._connector = connector + self._instance_connection_string = instance_connection_string + self._connect_args=kwargs + + async def connect(self, protocol_fn: Callable[[], asyncio.Protocol]): + await self._connector.connect_socket_async(self._instance_connection_string, protocol_fn, **self._connect_args) \ No newline at end of file diff --git a/google/cloud/sql/connector/enums.py b/google/cloud/sql/connector/enums.py index e6b56af0e..5926b75a2 100644 --- a/google/cloud/sql/connector/enums.py +++ b/google/cloud/sql/connector/enums.py @@ -62,6 +62,7 @@ class DriverMapping(Enum): ASYNCPG = "POSTGRES" PG8000 = "POSTGRES" + PSYCOPG = "POSTGRES" PYMYSQL = "MYSQL" PYTDS = "SQLSERVER" diff --git a/google/cloud/sql/connector/exceptions.py b/google/cloud/sql/connector/exceptions.py index da39ea25d..3aee96430 100644 --- a/google/cloud/sql/connector/exceptions.py +++ b/google/cloud/sql/connector/exceptions.py @@ -84,3 +84,9 @@ class CacheClosedError(Exception): Exception to be raised when a ConnectionInfoCache can not be accessed after it is closed. """ + + +class LocalProxyStartupError(Exception): + """ + Exception to be raised when a the local UNIX-socket based proxy can not be started. + """ diff --git a/google/cloud/sql/connector/local_unix_socket.py b/google/cloud/sql/connector/local_unix_socket.py new file mode 100644 index 000000000..25a9d3be3 --- /dev/null +++ b/google/cloud/sql/connector/local_unix_socket.py @@ -0,0 +1,35 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import ssl +from typing import Any, TYPE_CHECKING + +def connect( + host: str, sock: ssl.SSLSocket, **kwargs: Any +) -> "ssl.SSLSocket": + """Helper function to retrieve the socket for local UNIX sockets. + + Args: + host (str): A string containing the socket path used by the local proxy. + sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL + server CA cert and ephemeral cert. + kwargs: Additional arguments to pass to the local UNIX socket connect method. + + Returns: + ssl.SSLSocket: The same socket + """ + + return sock diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py new file mode 100644 index 000000000..99a121782 --- /dev/null +++ b/google/cloud/sql/connector/proxy.py @@ -0,0 +1,266 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod +import asyncio +from functools import partial +import logging +import os +from pathlib import Path +from typing import Callable, List + +logger = logging.getLogger(name=__name__) + + +class BaseProxyProtocol(asyncio.Protocol): + """ + A protocol to proxy data between two transports. + """ + + def __init__(self, proxy: Proxy): + super().__init__() + self.proxy = proxy + self._buffer = bytearray() + self._target: asyncio.Transport | None = None + self.transport: asyncio.Transport | None = None + self._cached: List[bytes] = [] + logger.debug(f"__init__ {self}") + + def connection_made(self, transport): + logger.debug(f"connection_made {self}") + self.transport = transport + + def data_received(self, data): + if self._target is None: + self._cached.append(data) + else: + self._target.write(data) + + def set_target(self, target: asyncio.Transport): + logger.debug(f"set_target {self}") + self._target = target + if self._cached: + self._target.writelines(self._cached) + self._cached = [] + + def eof_received(self): + logger.debug(f"eof_received {self}") + if self._target is not None: + self._target.write_eof() + + def connection_lost(self, exc: Exception | None): + logger.debug(f"connection_lost {exc} {self}") + if self._target is not None: + self._target.close() + + +class ProxyClientConnection: + """ + Holds all of the tasks and details for a client proxy + """ + + def __init__( + self, + client_transport: asyncio.Transport, + client_protocol: ClientToServerProtocol, + ): + self.client_transport = client_transport + self.client_protocol = client_protocol + self.server_transport: asyncio.Transport | None = None + self.server_protocol: ServerToClientProtocol | None = None + self.task: asyncio.Task | None = None + + def close(self): + logger.debug(f"closing {self}") + if self.client_transport is not None: + self._close_transport(self.client_transport) + if self.server_transport is not None: + self._close_transport(self.server_transport) + + def _close_transport(self, transport:asyncio.Transport): + if transport.is_closing(): + return + if transport.can_write_eof(): + transport.write_eof() + else: + transport.close() + +class ClientToServerProtocol(BaseProxyProtocol): + """ + Protocol to copy bytes from the unix socket client to the database server + """ + + def __init__(self, proxy: Proxy): + super().__init__(proxy) + self._buffer = bytearray() + self._target: asyncio.Transport | None = None + logger.debug(f"__init__ {self}") + + def connection_made(self, transport): + # When a connection is made, open the server connection + super().connection_made(transport) + self.proxy._handle_client_connection(transport, self) + + +class ServerToClientProtocol(BaseProxyProtocol): + """ + Protocol to copy bytes from the database server to the client socket + """ + + def __init__(self, proxy: Proxy, cconn: ProxyClientConnection): + super().__init__(proxy) + self._buffer = bytearray() + self._target = cconn.client_transport + self._client_protocol = cconn.client_protocol + logger.debug(f"__init__ {self}") + + def connection_made(self, transport): + super().connection_made(transport) + self._client_protocol.set_target(transport) + + def connection_lost(self, exc: Exception | None): + super().connection_lost(exc) + self.proxy._handle_server_connection_lost() + +class ServerConnectionFactory(ABC): + """ + ServerConnectionFactory is an abstract class that provides connections to the service. + """ + @abstractmethod + async def connect(self, protocol_fn: Callable[[], asyncio.Protocol]): + """ + Establishes a connection to the server and configures it to use the protocol + returned from protocol_fn, with asyncio.EventLoop.create_connection(). + :param protocol_fn: the protocol function + :return: None + """ + pass + +class Proxy: + """ + A class to represent a local Unix socket proxy for a Cloud SQL instance. + This class manages a Unix socket that listens for incoming connections and + proxies them to a Cloud SQL instance. + """ + + def __init__( + self, + unix_socket_path: str, + server_connection_factory: ServerConnectionFactory, + loop: asyncio.AbstractEventLoop, + ): + """ + Creates a new Proxy + :param unix_socket_path: the path to listen for the proxy connection + :param loop: The event loop + :param instance_connect: A function that will establish the async connection to the server + + The instance_connect function is an asynchronous function that should set up a new connection. + It takes one argument - another function that + """ + self.unix_socket_path = unix_socket_path + self.alive = True + self._loop = loop + self._server: asyncio.AbstractServer | None = None + self._client_connections: set[ProxyClientConnection] = set() + self._server_connection_factory = server_connection_factory + + async def start(self) -> None: + """Starts the Unix socket server.""" + if os.path.exists(self.unix_socket_path): + os.remove(self.unix_socket_path) + + parent_dir = Path(self.unix_socket_path).parent + parent_dir.mkdir(parents=True, exist_ok=True) + + def new_protocol() -> ClientToServerProtocol: + return ClientToServerProtocol(self) + + logger.debug(f"Socket path: {self.unix_socket_path}") + self._server = await self._loop.create_unix_server( + new_protocol, path=self.unix_socket_path + ) + self._loop.create_task(self._server.serve_forever()) + + def _handle_client_connection( + self, + client_transport: asyncio.Transport, + client_protocol: ClientToServerProtocol, + ) -> None: + """ + Register a new client connection and initiate the task to create a database connection. + This is called by ClientToServerProtocol.connection_made + + :param client_transport: the client transport for the client unix socket + :param client_protocol: the instance for the + :return: None + """ + conn = ProxyClientConnection(client_transport, client_protocol) + self._client_connections.add(conn) + conn.task = self._loop.create_task(self._create_db_instance_connection(conn)) + conn.task.add_done_callback(lambda _: self._client_connections.discard(conn)) + + def _handle_server_connection_lost( + self, + ) -> None: + """ + Closes the proxy server if the connection to the server is lost + + :return: None + """ + logger.debug(f"Closing proxy server due to lost connection") + self._loop.create_task(self.close()) + + async def _create_db_instance_connection(self, conn: ProxyClientConnection) -> None: + """ + Manages a single proxy connection from a client to the Cloud SQL instance. + """ + try: + logger.debug("_proxy_connection() started") + new_protocol = partial(ServerToClientProtocol, self, conn) + + # Establish connection to the database + await self._server_connection_factory.connect(new_protocol) + logger.debug("_proxy_connection() succeeded") + + except Exception as e: + logger.error(f"Error handling proxy connection: {e}") + await self.close() + raise e + + async def close(self) -> None: + """ + Shuts down the proxy server and cleans up resources. + """ + logger.info(f"Closing Unix socket proxy at {self.unix_socket_path}") + + if self._server: + self._server.close() + await self._server.wait_closed() + + if self._client_connections: + for conn in list(self._client_connections): + conn.close() + await asyncio.wait([c.task for c in self._client_connections if c.task is not None], timeout=0.1) + + if os.path.exists(self.unix_socket_path): + os.remove(self.unix_socket_path) + + logger.info(f"Unix socket proxy for {self.unix_socket_path} closed.") + self.alive = False \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cbf0dd10f..8ffce4d63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,10 @@ exclude = ['docs/*', 'samples/*'] [tool.pytest.ini_options] asyncio_mode = "auto" +log_cli = true +log_cli_level = "DEBUG" +log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" +log_cli_date_format = "%Y-%m-%d %H:%M:%S.%f" [tool.ruff.lint] extend-select = ["I"] diff --git a/requirements-test.txt b/requirements-test.txt index 8f690bfc5..85e96fddf 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -7,6 +7,7 @@ sqlalchemy-pytds==1.0.2 sqlalchemy-stubs==0.4 PyMySQL==1.1.1 pg8000==1.31.4 +psycopg[binary]==3.2.9 asyncpg==0.30.0 python-tds==1.16.1 aioresponses==0.7.8 diff --git a/tests/conftest.py b/tests/conftest.py index 83d7a78f3..54836ee84 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,8 @@ """ import asyncio +from asyncio import Server +import logging import os import socket import ssl @@ -36,6 +38,7 @@ from google.cloud.sql.connector.utils import write_to_file SCOPES = ["https://www.googleapis.com/auth/sqlservice.admin"] +logger = logging.getLogger(name=__name__) def pytest_addoption(parser: Any) -> None: @@ -84,55 +87,138 @@ def fake_credentials() -> FakeCredentials: return FakeCredentials() -async def start_proxy_server(instance: FakeCSQLInstance) -> None: +async def start_proxy_server_async( + instance: FakeCSQLInstance, with_read_write: bool +) -> Server: """Run local proxy server capable of performing mTLS""" ip_address = "127.0.0.1" port = 3307 - # create socket - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - # create SSL/TLS context - context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - context.minimum_version = ssl.TLSVersion.TLSv1_3 - # tmpdir and its contents are automatically deleted after the CA cert - # and cert chain are loaded into the SSLcontext. The values - # need to be written to files in order to be loaded by the SSLContext - server_key_bytes = instance.server_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), + logger.debug("start_proxy_server_async started") + + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.minimum_version = ssl.TLSVersion.TLSv1_3 + # tmpdir and its contents are automatically deleted after the CA cert + # and cert chain are loaded into the SSLcontext. The values + # need to be written to files in order to be loaded by the SSLContext + server_key_bytes = instance.server_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + async with TemporaryDirectory() as tmpdir: + server_filename, _, key_filename = await write_to_file( + tmpdir, instance.server_cert_pem, "", server_key_bytes ) - async with TemporaryDirectory() as tmpdir: - server_filename, _, key_filename = await write_to_file( - tmpdir, instance.server_cert_pem, "", server_key_bytes - ) - context.load_cert_chain(server_filename, key_filename) - # allow socket to be re-used - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # bind socket to Cloud SQL proxy server port on localhost - sock.bind((ip_address, port)) - # listen for incoming connections - sock.listen(5) - - with context.wrap_socket(sock, server_side=True) as ssock: - while True: - conn, _ = ssock.accept() - conn.close() + context.load_cert_chain(server_filename, key_filename) + + async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + logger.debug("Received fake connection") + if with_read_write: + line = await reader.readline() + logger.debug(f"Received request {line}") + writer.write("world\n".encode("utf-8")) + await writer.drain() + logger.debug("Wrote response") + if writer.can_write_eof(): + writer.write_eof() + logger.debug("Closing connection") + writer.close() + await writer.wait_closed() + logger.debug("Closed connection") + + server = await asyncio.start_server( + handler, host=ip_address, port=port, ssl=context + ) + logger.debug("Listening on 127.0.0.1:3307") + asyncio.create_task(server.serve_forever()) + return server -@pytest.fixture(scope="session") -def proxy_server(fake_instance: FakeCSQLInstance) -> None: - """Run local proxy server capable of performing mTLS""" - thread = Thread( - target=asyncio.run, - args=( - start_proxy_server( - fake_instance, - ), - ), - daemon=True, +@pytest.fixture(scope="function") +def proxy_server_async(fake_instance: FakeCSQLInstance): + # Create an event loop in a different thread for the server + loop = asyncio.new_event_loop() + + def f(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + logger.debug("exiting thread") + + t = Thread(target=f, args=(loop,)) + t.start() + t.join(1) + + # Submit the server task to the thread + server_fut = asyncio.run_coroutine_threadsafe( + start_proxy_server_async(fake_instance, True), loop ) - thread.start() - thread.join(1.0) # add a delay to allow the proxy server to start + while not server_fut.done(): + t.join(0.1) + logger.debug("proxy_server_async server started") + yield + logger.debug("proxy_server_async fixture done") + + logger.debug("proxy_server_async fixture cleanup") + + # Stop the server after the test is complete + async def stop_server(): + logger.debug("inside_cleanup closing server") + server_fut.result().close() + loop.shutdown_asyncgens() + loop.stop() + logger.debug("inside_cleanup end") + + logger.debug("cleanup starting") + asyncio.run_coroutine_threadsafe(stop_server(), loop) + logger.debug("cleanup done") + while loop.is_running(): + t.join(0.1) + logger.debug("loop is not running") + loop.close() + t.join(1) + + +@pytest.fixture(scope="function") +def proxy_server(fake_instance: FakeCSQLInstance): + # Create an event loop in a different thread for the server + loop = asyncio.new_event_loop() + + def f(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + logger.debug("exiting thread") + + t = Thread(target=f, args=(loop,)) + t.start() + + # Submit the server task to the thread + server_fut = asyncio.run_coroutine_threadsafe( + start_proxy_server_async(fake_instance, False), loop + ) + while not server_fut.done(): + t.join(0.1) + logger.debug("proxy_server_async server started") + yield + logger.debug("proxy_server_async fixture done") + + logger.debug("proxy_server_async fixture cleanup") + + # Stop the server after the test is complete + async def stop_server(): + logger.debug("inside_cleanup closing server") + server_fut.result().close() + loop.shutdown_asyncgens() + loop.stop() + logger.debug("inside_cleanup end") + + logger.debug("cleanup starting") + asyncio.run_coroutine_threadsafe(stop_server(), loop) + logger.debug("cleanup done") + while loop.is_running(): + t.join(0.1) + logger.debug("loop is not running") + loop.close() @pytest.fixture @@ -191,3 +277,33 @@ async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache ) yield cache await cache.close() + + +@pytest.fixture +def connected_socket_pair() -> tuple[socket.socket, socket.socket]: + """A fixture that provides a pair of connected sockets.""" + server, client = socket.socketpair() + yield server, client + server.close() + client.close() + + +@pytest.fixture +async def echo_server() -> AsyncGenerator[tuple[str, int], None]: + """A fixture that starts an asyncio echo server.""" + + async def handle_echo(reader, writer): + while True: + data = await reader.read(100) + if not data: + break + writer.write(data) + await writer.drain() + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(handle_echo, "127.0.0.1", 0) + addr = server.sockets[0].getsockname() + yield addr + server.close() + await server.wait_closed() \ No newline at end of file diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py new file mode 100644 index 000000000..25a46f3b8 --- /dev/null +++ b/tests/system/test_psycopg_connection.py @@ -0,0 +1,83 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +from datetime import datetime +import os + +# [START cloud_sql_connector_postgres_psycopg] +from typing import Union + +from psycopg import Connection +import pytest +import logging +import sqlalchemy + +from google.cloud.sql.connector import Connector +from google.cloud.sql.connector import DefaultResolver +from google.cloud.sql.connector import DnsResolver + +SERVER_PROXY_PORT = 3307 + +logger = logging.getLogger(name=__name__) + +# [END cloud_sql_connector_postgres_psycopg] + + +@pytest.mark.asyncio +async def test_psycopg_connection() -> None: + """Basic test to get time from database.""" + instance_connection_name = os.environ["POSTGRES_CONNECTION_NAME"] + user = os.environ["POSTGRES_USER"] + password = os.environ["POSTGRES_PASS"] + db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") + + unix_socket_folder = "/tmp/conn" + unix_socket_path = f"{unix_socket_folder}/.s.PGSQL.3307" + + async with Connector( + refresh_strategy='lazy', resolver=DefaultResolver + ) as connector: + # Open proxy connection + # start the proxy server + + await connector.start_unix_socket_proxy_async( + instance_connection_name, + unix_socket_path, + ip_type=ip_type, # can be "public", "private" or "psc" + ) + + # Wait for server to start + await asyncio.sleep(0.5) + + engine = sqlalchemy.create_engine( + "postgresql+psycopg://", + creator=lambda: Connection.connect( + f"host={unix_socket_folder} port={SERVER_PROXY_PORT} dbname={db} user={user} password={password} sslmode=require", + user=user, + password=password, + dbname=db, + autocommit=True, + ) + ) + + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + assert type(curr_time) is datetime + connector.close() diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 1bcb42616..612b1d3b4 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -1,12 +1,9 @@ """ Copyright 2021 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -15,7 +12,9 @@ """ import asyncio +import logging import os +import socket from threading import Thread from typing import Union @@ -34,20 +33,28 @@ from google.cloud.sql.connector.exceptions import IncompatibleDriverError from google.cloud.sql.connector.instance import RefreshAheadCache +logger = logging.getLogger(name=__name__) + @pytest.mark.asyncio async def test_connect_enable_iam_auth_error( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + connected_socket_pair: tuple[socket.socket, socket.socket], ) -> None: """Test that calling connect() with different enable_iam_auth argument values creates two cache entries.""" connect_string = "test-project:test-region:test-instance" + server, client = connected_socket_pair async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: connector._client = fake_client # patch db connection creation - with patch("google.cloud.sql.connector.asyncpg.connect") as mock_connect: + with ( + patch("socket.create_connection", return_value=client), + patch("google.cloud.sql.connector.asyncpg.connect") as mock_connect, + ): mock_connect.return_value = True # connect with enable_iam_auth False connection = await connector.connect_async( @@ -80,6 +87,7 @@ async def test_connect_enable_iam_auth_error( async def test_connect_incompatible_driver_error( fake_credentials: Credentials, fake_client: CloudSQLClient, + proxy_server, ) -> None: """Test that calling connect() with driver that is incompatible with database version throws error.""" @@ -89,14 +97,8 @@ async def test_connect_incompatible_driver_error( ) as connector: connector._client = fake_client # try to connect using pymysql driver to a Postgres database - with pytest.raises(IncompatibleDriverError) as exc_info: + with pytest.raises(IncompatibleDriverError): await connector.connect_async(connect_string, "pymysql") - assert ( - exc_info.value.args[0] - == "Database driver 'pymysql' is incompatible with database version" - " 'POSTGRES_15'. Given driver can only be used with Cloud SQL MYSQL" - " databases." - ) def test_connect_with_unsupported_driver(fake_credentials: Credentials) -> None: @@ -237,13 +239,19 @@ def test_Connector_Init_bad_ip_type(fake_credentials: Credentials) -> None: def test_Connector_connect_bad_ip_type( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + connected_socket_pair: tuple[socket.socket, socket.socket], ) -> None: """Test that Connector.connect errors due to bad ip_type str.""" + server, client = connected_socket_pair with Connector(credentials=fake_credentials) as connector: connector._client = fake_client bad_ip_type = "bad-ip-type" - with pytest.raises(ValueError) as exc_info: + with ( + patch("socket.create_connection", return_value=client), + pytest.raises(ValueError) as exc_info, + ): connector.connect( "test-project:test-region:test-instance", "pg8000", @@ -261,15 +269,21 @@ def test_Connector_connect_bad_ip_type( @pytest.mark.asyncio async def test_Connector_connect_async( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + connected_socket_pair: tuple[socket.socket, socket.socket], ) -> None: """Test that Connector.connect_async can properly return a DB API connection.""" + server, client = connected_socket_pair async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: connector._client = fake_client # patch db connection creation - with patch("google.cloud.sql.connector.asyncpg.connect") as mock_connect: + with ( + patch("socket.create_connection", return_value=client), + patch("google.cloud.sql.connector.asyncpg.connect") as mock_connect, + ): mock_connect.return_value = True connection = await connector.connect_async( "test-project:test-region:test-instance", @@ -348,7 +362,9 @@ def test_Connector_close_called_multiple_times(fake_credentials: Credentials) -> async def test_Connector_remove_cached_bad_instance( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + proxy_server, ) -> None: """When a Connector attempts to retrieve connection info for a non-existent instance, it should delete the instance from @@ -373,7 +389,9 @@ async def test_Connector_remove_cached_bad_instance( async def test_Connector_remove_cached_no_ip_type( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + proxy_server, ) -> None: """When a Connector attempts to connect and preferred IP type is not present, it should delete the instance from the cache and ensure no background refresh @@ -502,3 +520,121 @@ def test_configured_quota_project_env_var( assert connector._quota_project == quota_project # unset env var del os.environ["GOOGLE_CLOUD_QUOTA_PROJECT"] + + +@pytest.mark.asyncio +async def test_Connector_start_unix_socket_proxy_async( + fake_credentials: Credentials, + fake_client: CloudSQLClient, + proxy_server_async: None, +) -> None: + """Test that Connector.connect_async can properly return a DB API connection.""" + async with Connector( + credentials=fake_credentials, loop=asyncio.get_running_loop() + ) as connector: + connector._client = fake_client + + # Open proxy connection + # start the proxy server + await connector.start_unix_socket_proxy_async( + "test-project:test-region:test-instance", + "/tmp/csql-python/proxytest/.s.PGSQL.5432", + driver="asyncpg", + user="my-user", + password="my-pass", + db="my-db", + ) + # Wait for server to start + await asyncio.sleep(0.5) + + reader, writer = await asyncio.open_unix_connection( + "/tmp/csql-python/proxytest/.s.PGSQL.5432" + ) + writer.write("hello\n".encode()) + await writer.drain() + await asyncio.sleep(0.5) + msg = await reader.readline() + assert msg.decode("utf-8") == "world\n" + + +class TestProtocol(asyncio.Protocol): + """ + A protocol to proxy data between two transports. + """ + + def __init__(self): + self._buffer = bytearray() + logger.debug(f"__init__ {self}") + self.received = bytearray() + self.connected = asyncio.Future() + self.future = asyncio.Future() + + def data_received(self, data): + logger.debug("received {!r}".format(data)) + self.received = data + + def connection_made(self, transport): + logger.debug(f"connection_made called {self}") + self.transport = transport + if not self.connected.done(): + self.connected.set_result(True) + # Write the request and EOF + transport.write("hello\n".encode()) + # if transport.can_write_eof(): + # transport.write_eof() + logger.debug(f"connection_made done, wrote hello{self}") + + def eof_received(self) -> bool | None: + logger.debug(f"eof_received {self}") + # If this has received data, then close. + if len(self.received) > 0: + self.transport.close() + if not self.connected.done(): + self.connected.set_result(True) + if not self.future.done(): + self.future.set_result(True) + return True + + def connection_lost(self, exc: Exception | None) -> None: + logger.debug(f"connection_lost {exc} {self}") + self.transport.abort() + if not self.connected.done(): + self.connected.set_result(True) + if not self.future.done(): + self.future.set_result(True) + super().connection_lost(exc) + + +@pytest.mark.asyncio +async def test_Connector_connect_socket_async( + fake_credentials: Credentials, + fake_client: CloudSQLClient, + proxy_server_async: None, +) -> None: + """Test that Connector.connect_async can properly return a DB API connection.""" + async with Connector( + credentials=fake_credentials, loop=asyncio.get_running_loop() + ) as connector: + logger.info("client socket opening") + connector._client = fake_client + p = TestProtocol() + + # Open proxy connection + # start the proxy server + future = connector.connect_socket_async( + "test-project:test-region:test-instance", + lambda: p, + driver="asyncpg", + user="my-user", + password="my-pass", + db="my-db", + ) + logger.info("client socket opening") + await future + logger.info("client socket opened") + await p.connected + logger.info("client socket connected") + await p.future + logger.info("client socket done") + + assert p.received.decode() == "world\n" \ No newline at end of file diff --git a/tests/unit/test_local_unix_socket.py b/tests/unit/test_local_unix_socket.py new file mode 100644 index 000000000..8672857ec --- /dev/null +++ b/tests/unit/test_local_unix_socket.py @@ -0,0 +1,37 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import socket +import ssl +from typing import Any + +from mock import patch +from mock import PropertyMock +import pytest + +from google.cloud.sql.connector.local_unix_socket import connect + + +@pytest.mark.usefixtures("proxy_server") +async def test_local_unix_socket(context: ssl.SSLContext, kwargs: Any) -> None: + """Test to verify that local_unix_socket gets to proper connection call.""" + ip_addr = "127.0.0.1" + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + ) + connection = connect(ip_addr, sock, **kwargs) + assert connection == sock diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py new file mode 100644 index 000000000..0c179186d --- /dev/null +++ b/tests/unit/test_proxy.py @@ -0,0 +1,380 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import os +import shutil +import tempfile +from unittest.mock import MagicMock + +import pytest + +from google.cloud.sql.connector.proxy import Proxy, ServerConnectionFactory + + +@pytest.fixture +def short_tmpdir(): + """Create a temporary directory with a short path.""" + dir_path = tempfile.mkdtemp(dir="/tmp") + yield dir_path + shutil.rmtree(dir_path) + + +@pytest.mark.asyncio +async def test_proxy_creates_folder_and_socket(short_tmpdir): + """ + Test to verify that the Proxy server creates the folder and socket file. + """ + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + connector = MagicMock(spec=ServerConnectionFactory) + proxy = Proxy(socket_path, connector, asyncio.get_event_loop()) + await proxy.start() + + assert os.path.exists(short_tmpdir) + assert os.path.exists(socket_path) + + await proxy.close() + + +# A mock ServerConnectionFactory for testing purposes. +class MockServerConnectionFactory(ServerConnectionFactory): + def __init__(self, loop): + self.server_protocol = None + self.server_transport = None + self.connect_called = asyncio.Event() + self.connect_ran = asyncio.Event() + self.force_connect_error = False + self.loop = loop + self.server_data = bytearray() + + async def connect(self, protocol_fn): + self.connect_called.set() + if self.force_connect_error: + raise Exception("Forced connection error") + + self.server_protocol = protocol_fn() + # Create a mock transport for server-side communication + self.server_transport = MagicMock(spec=asyncio.Transport) + self.server_transport.write.side_effect = self.server_data.extend + self.server_transport.is_closing.return_value = False + + # Simulate connection made for the server protocol + self.server_protocol.connection_made(self.server_transport) + self.connect_ran.set() + return self.server_transport, self.server_protocol + + +# Test fixture for the proxy +@pytest.fixture +async def proxy_server(short_tmpdir): + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + loop = asyncio.get_event_loop() + connector = MockServerConnectionFactory(loop) + proxy = Proxy(socket_path, connector, loop) + await proxy.start() + yield proxy, socket_path, connector + await proxy.close() + + +@pytest.mark.asyncio +async def test_proxy_client_to_server(proxy_server): + """ + 1. Create a new proxy. Open a client socket to the proxy. Write data to + the client socket. Read data from the server. Check that the data was + received by the server. + """ + proxy, socket_path, connector = proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + + # wait for server connection to be established + await connector.connect_called.wait() + + # Write data to the client socket + test_data = b"test data from client" + writer.write(test_data) + await writer.drain() + + # Check that the data was received by the server + await asyncio.sleep(0.01) # give event loop a chance to run + assert connector.server_data == test_data + + writer.close() + await writer.wait_closed() + + +@pytest.mark.asyncio +async def test_proxy_server_to_client(proxy_server): + """ + 2. Create a new proxy, Open a client socket. Write data to the server + socket. Read data from the client socket. Check that the data was + received by the client. + """ + proxy, socket_path, connector = proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + + # wait for server connection to be established + await connector.connect_called.wait() + + # Write data from the server to the client + test_data = b"test data from server" + connector.server_protocol.data_received(test_data) + + # Read data from the client socket + received_data = await reader.read(len(test_data)) + + # Check that the data was received by the client + assert received_data == test_data + + writer.close() + await writer.wait_closed() + + +@pytest.mark.asyncio +async def test_proxy_server_connect_fails(proxy_server): + """ + 3. Create a new proxy. Open a client socket. The server socket fails to + connect. Check that the client socket is closed. + """ + proxy, socket_path, connector = proxy_server + connector.force_connect_error = True + + reader, writer = await asyncio.open_unix_connection(socket_path) + + # wait for server connection to be attempted + await connector.connect_called.wait() + + assert os.path.exists(socket_path) == True + + # The client connection should be closed by the proxy + # Reading should return EOF + data = await reader.read(100) + assert data == b"" + + await asyncio.sleep(1) # give proxy a chance to shut down + + assert os.path.exists(socket_path) == False + + +@pytest.mark.asyncio +async def test_proxy_client_closes_connection(proxy_server): + """ + 4. Create a new proxy. Open a client socket. Check that the server + socket connected. Close the client socket. Check that the server socket + is closed gracefully. + """ + proxy, socket_path, connector = proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + + # wait for server connection to be established + await connector.connect_called.wait() + assert connector.server_transport is not None + + # Close the client socket + writer.close() + await writer.wait_closed() + + # Check that the server socket is closed + await asyncio.sleep(0.01) # give event loop a chance to run + connector.server_transport.close.assert_called_once() + + +# +# TCP Server Fixtures and Tests +# + + +@pytest.fixture +async def tcp_echo_server(): + """Fixture to create a TCP echo server.""" + + async def echo(reader, writer): + try: + while not reader.at_eof(): + data = await reader.read(1024) + if not data: + break + writer.write(data) + await writer.drain() + finally: + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(echo, "127.0.0.1", 0) + addr = server.sockets[0].getsockname() + host, port = addr[0], addr[1] + + yield host, port + + server.close() + await server.wait_closed() + + +@pytest.fixture +async def tcp_server_accept_and_close(): + """Fixture to create a TCP server that accepts and immediately closes.""" + + async def accept_and_close(reader, writer): + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(accept_and_close, "127.0.0.1", 0) + addr = server.sockets[0].getsockname() + host, port = addr[0], addr[1] + + yield host, port + + server.close() + await server.wait_closed() + + +class TCPServerConnectionFactory(ServerConnectionFactory): + """A ServerConnectionFactory that connects to a TCP server.""" + + def __init__(self, host, port, loop): + self.host = host + self.port = port + self.loop = loop + self.connect_called = asyncio.Event() + self.connect_ran = asyncio.Event() + self.server_transport: asyncio.Transport | None = None + self.server_protocol: asyncio.Protocol | None = None + + async def connect(self, protocol_fn): + self.connect_called.set() + transport, protocol = await asyncio.wait_for(self.loop.create_connection( + protocol_fn, self.host, self.port, + ), timeout=0.5) + self.server_transport = transport + self.server_protocol = protocol + self.connect_ran.set() + return transport, protocol + + +@pytest.fixture +async def tcp_proxy_server(short_tmpdir, tcp_echo_server): + """Fixture to set up a proxy with a TCP backend.""" + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + loop = asyncio.get_event_loop() + host, port = tcp_echo_server + connector = TCPServerConnectionFactory(host, port, loop) + proxy = Proxy(socket_path, connector, loop) + await proxy.start() + yield proxy, socket_path, connector + await proxy.close() + + +@pytest.fixture +async def tcp_proxy_server_with_closing_backend(short_tmpdir, tcp_server_accept_and_close): + """Fixture to set up a proxy with a TCP backend that closes immediately.""" + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + loop = asyncio.get_event_loop() + host, port = tcp_server_accept_and_close + connector = TCPServerConnectionFactory(host, port, loop) + proxy = Proxy(socket_path, connector, loop) + await proxy.start() + yield proxy, socket_path, connector + await proxy.close() + + +@pytest.fixture +async def tcp_proxy_server_with_no_tcp_server(short_tmpdir): + """Fixture to set up a proxy with a TCP backend that closes immediately.""" + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + loop = asyncio.get_event_loop() + connector = TCPServerConnectionFactory("localhost", "34532", loop) + proxy = Proxy(socket_path, connector, loop) + await proxy.start() + yield proxy, socket_path, connector + await proxy.close() + + +@pytest.mark.asyncio +async def test_tcp_proxy_echo(tcp_proxy_server): + """ + Tests data flow from client to a TCP server and back. + """ + proxy, socket_path, connector = tcp_proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + await connector.connect_called.wait() + + test_data = b"test data from client" + writer.write(test_data) + await writer.drain() + + # Read echoed data back from the server + received_data = await reader.read(len(test_data)) + assert received_data == test_data + + writer.close() + await writer.wait_closed() + + +@pytest.mark.asyncio +async def test_tcp_proxy_server_connection_refused(tcp_proxy_server_with_no_tcp_server): + """ + Tests that the client socket is closed when TCP connection fails. + """ + proxy, socket_path, connector = tcp_proxy_server_with_no_tcp_server + + reader, writer = await asyncio.open_unix_connection(socket_path) + await connector.connect_called.wait() + test_data = b"test data from client" + writer.write(test_data) + await writer.drain() + + await asyncio.sleep(1.5) + assert os.path.exists(socket_path) == False + + + +@pytest.mark.asyncio +async def test_tcp_proxy_server_unexpected_closed(tcp_proxy_server_with_closing_backend): + """ + Tests that the client socket is closed when TCP connection fails. + """ + proxy, socket_path, connector = tcp_proxy_server_with_closing_backend + + reader, writer = await asyncio.open_unix_connection(socket_path) + await connector.connect_called.wait() + + # The client connection should be closed by the proxy + data = await reader.read(100) + assert data == b"" + + await asyncio.sleep(0.5) # give event loop a chance to run + assert os.path.exists(socket_path) == False + + + +@pytest.mark.asyncio +async def test_tcp_proxy_client_closes_connection(tcp_proxy_server): + """ + Tests that closing the client socket closes the TCP server socket. + """ + proxy, socket_path, connector = tcp_proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + await connector.connect_ran.wait() + + assert connector.server_transport is not None + assert not connector.server_transport.is_closing() + + # Close the client socket + writer.close() + await writer.wait_closed() + + # Check that the server socket is closing + await asyncio.sleep(0.01) + assert connector.server_transport.is_closing() \ No newline at end of file