From 280b72c2aa0440ddd789a9997d041f3ca84b78c4 Mon Sep 17 00:00:00 2001 From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com> Date: Fri, 11 Jul 2025 15:00:07 -0700 Subject: [PATCH 01/14] chore(main): release 1.18.3 (#1309) Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> --- CHANGELOG.md | 7 +++ README.md | 3 +- google/cloud/sql/connector/connector.py | 9 ++-- google/cloud/sql/connector/enums.py | 1 + google/cloud/sql/connector/proxy.py | 60 +++++++++++++++++++++ google/cloud/sql/connector/psycopg.py | 72 +++++++++++++++++++++++++ google/cloud/sql/connector/version.py | 2 +- pyproject.toml | 1 + requirements-test.txt | 1 + tests/system/test_psycopg_connection.py | 60 +++++++++++++++++++++ tests/unit/test_psycopg.py | 40 ++++++++++++++ 11 files changed, 251 insertions(+), 5 deletions(-) create mode 100644 google/cloud/sql/connector/proxy.py create mode 100644 google/cloud/sql/connector/psycopg.py create mode 100644 tests/system/test_psycopg_connection.py create mode 100644 tests/unit/test_psycopg.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 26374034b..5883ebe65 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [1.18.3](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/compare/v1.18.2...v1.18.3) (2025-07-11) + + +### Bug Fixes + +* suppress lint check for _scopes property ([#1308](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/issues/1308)) ([821245c](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/commit/821245c1911fb970e3409b3e249698937a8b7867)) + ## [1.18.2](https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/compare/v1.18.1...v1.18.2) (2025-05-20) diff --git a/README.md b/README.md index 1c5489e04..d11ac0883 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ The Cloud SQL Python Connector is a package to be used alongside a database driv Currently supported drivers are: - [`pymysql`](https://github.com/PyMySQL/PyMySQL) (MySQL) - [`pg8000`](https://github.com/tlocke/pg8000) (PostgreSQL) + - [`psycopg`](https://github.com/psycopg/psycopg) (PostgreSQL) - [`asyncpg`](https://github.com/MagicStack/asyncpg) (PostgreSQL) - [`pytds`](https://github.com/denisenkom/pytds) (SQL Server) @@ -587,7 +588,7 @@ async def main(): # acquire connection and query Cloud SQL database async with pool.acquire() as conn: res = await conn.fetch("SELECT NOW()") - + # close Connector await connector.close_async() ``` diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 05eaa51df..2926fa87a 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -37,6 +37,7 @@ from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 +import google.cloud.sql.connector.psycopg as psycopg import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds from google.cloud.sql.connector.resolver import DefaultResolver @@ -230,7 +231,7 @@ def connect( Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect - with. Supported drivers are pymysql, pg8000, and pytds. + with. Supported drivers are pymysql, pg8000, psycopg, and pytds. **kwargs: Any driver-specific arguments to pass to the underlying driver .connect call. @@ -266,7 +267,8 @@ async def connect_async( Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect - with. Supported drivers are pymysql, asyncpg, pg8000, and pytds. + with. Supported drivers are pymysql, asyncpg, pg8000, psycopg, and + pytds. **kwargs: Any driver-specific arguments to pass to the underlying driver .connect call. @@ -278,7 +280,7 @@ async def connect_async( ValueError: Connection attempt with built-in database authentication and then subsequent attempt with IAM database authentication. KeyError: Unsupported database driver Must be one of pymysql, asyncpg, - pg8000, and pytds. + pg8000, psycopg, and pytds. """ if self._keys is None: self._keys = asyncio.create_task(generate_keys()) @@ -332,6 +334,7 @@ async def connect_async( connect_func = { "pymysql": pymysql.connect, "pg8000": pg8000.connect, + "psycopg": psycopg.connect, "asyncpg": asyncpg.connect, "pytds": pytds.connect, } diff --git a/google/cloud/sql/connector/enums.py b/google/cloud/sql/connector/enums.py index e6b56af0e..5926b75a2 100644 --- a/google/cloud/sql/connector/enums.py +++ b/google/cloud/sql/connector/enums.py @@ -62,6 +62,7 @@ class DriverMapping(Enum): ASYNCPG = "POSTGRES" PG8000 = "POSTGRES" + PSYCOPG = "POSTGRES" PYMYSQL = "MYSQL" PYTDS = "SQLSERVER" diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py new file mode 100644 index 000000000..1f398461a --- /dev/null +++ b/google/cloud/sql/connector/proxy.py @@ -0,0 +1,60 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import socket +import os +import threading +from pathlib import Path + +SERVER_PROXY_PORT = 3307 + +def start_local_proxy( + ssl_sock, + socket_path, +): + desired_path = Path(socket_path) + desired_path.mkdir(parents=True, exist_ok=True) + + if os.path.exists(socket_path): + os.remove(socket_path) + conn_unix = None + unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + unix_socket.bind(socket_path) + unix_socket.listen(1) + + threading.Thread(target=local_communication, args=(unix_socket, ssl_sock, socket_path)).start() + + +def local_communication( + unix_socket, ssl_sock, socket_path +): + try: + conn_unix, addr_unix = unix_socket.accept() + + while True: + data = conn_unix.recv(10485760) + if not data: + break + ssl_sock.sendall(data) + response = ssl_sock.recv(10485760) + conn_unix.sendall(response) + + finally: + if conn_unix is not None: + conn_unix.close() + unix_socket.close() + os.remove(socket_path) # Clean up the socket file diff --git a/google/cloud/sql/connector/psycopg.py b/google/cloud/sql/connector/psycopg.py new file mode 100644 index 000000000..c34dcb1e1 --- /dev/null +++ b/google/cloud/sql/connector/psycopg.py @@ -0,0 +1,72 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import ssl +from typing import Any, TYPE_CHECKING +import threading + +SERVER_PROXY_PORT = 3307 + +if TYPE_CHECKING: + import psycopg + + +def connect( + ip_address: str, sock: ssl.SSLSocket, **kwargs: Any +) -> "psycopg.Connection": + """Helper function to create a psycopg DB-API connection object. + + Args: + ip_address (str): A string containing an IP address for the Cloud SQL + instance. + sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL + server CA cert and ephemeral cert. + kwargs: Additional arguments to pass to the psycopg connect method. + + Returns: + psycopg.Connection: A psycopg connection to the Cloud SQL + instance. + + Raises: + ImportError: The psycopg module cannot be imported. + """ + try: + from psycopg.rows import dict_row + from psycopg import Connection + import threading + from google.cloud.sql.connector.proxy import start_local_proxy + except ImportError: + raise ImportError( + 'Unable to import module "psycopg." Please install and try again.' + ) + + user = kwargs.pop("user") + db = kwargs.pop("db") + passwd = kwargs.pop("password", None) + + kwargs.pop("timeout", None) + + start_local_proxy(sock, f"/tmp/connector-socket/.s.PGSQL.3307") + + conn = Connection.connect( + f"host=/tmp/connector-socket port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require", + autocommit=True, + row_factory=dict_row, + **kwargs + ) + + conn.autocommit = True + return conn diff --git a/google/cloud/sql/connector/version.py b/google/cloud/sql/connector/version.py index ad763d279..a60f3fe11 100644 --- a/google/cloud/sql/connector/version.py +++ b/google/cloud/sql/connector/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.18.2" +__version__ = "1.18.3" diff --git a/pyproject.toml b/pyproject.toml index cbf0dd10f..1c933cf6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ Changelog = "https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/b [project.optional-dependencies] pymysql = ["PyMySQL>=1.1.0"] pg8000 = ["pg8000>=1.31.1"] +psycopg = ["psycopg>=3.2.9"] pytds = ["python-tds>=1.15.0"] asyncpg = ["asyncpg>=0.30.0"] diff --git a/requirements-test.txt b/requirements-test.txt index a858da78f..624aac6d9 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -7,6 +7,7 @@ sqlalchemy-pytds==1.0.2 sqlalchemy-stubs==0.4 PyMySQL==1.1.1 pg8000==1.31.2 +psycopg[binary]==3.2.9 asyncpg==0.30.0 python-tds==1.16.1 aioresponses==0.7.8 diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py new file mode 100644 index 000000000..5b3a3c79e --- /dev/null +++ b/tests/system/test_psycopg_connection.py @@ -0,0 +1,60 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from datetime import datetime +import os + +# [START cloud_sql_connector_postgres_psycopg] + +from google.cloud.sql.connector import Connector +from google.cloud.sql.connector import DefaultResolver + +from sqlalchemy.dialects.postgresql.base import PGDialect +PGDialect._get_server_version_info = lambda *args: (9, 2) + +# [END cloud_sql_connector_postgres_psycopg] + + +def test_psycopg_connection() -> None: + """Basic test to get time from database.""" + inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"] + user = os.environ["POSTGRES_USER"] + password = os.environ["POSTGRES_PASS"] + db = os.environ["POSTGRES_DB"] + ip_type = os.environ.get("IP_TYPE", "public") + + connector = Connector(refresh_strategy="background", resolver=DefaultResolver) + + pool = connector.connect( + inst_conn_name, + "psycopg", + user=user, + password=password, + db=db, + ip_type=ip_type, # can be "public", "private" or "psc" + ) + + with pool as conn: + + # Open a cursor to perform database operations + with conn.cursor() as cur: + + # Query the database and obtain data as Python objects. + cur.execute("SELECT NOW()") + curr_time = cur.fetchone()["now"] + assert type(curr_time) is datetime + + diff --git a/tests/unit/test_psycopg.py b/tests/unit/test_psycopg.py new file mode 100644 index 000000000..8d9fe1f85 --- /dev/null +++ b/tests/unit/test_psycopg.py @@ -0,0 +1,40 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import socket +import ssl +from typing import Any + +from mock import patch, PropertyMock +import pytest + +from google.cloud.sql.connector.psycopg import connect + + +@pytest.mark.usefixtures("proxy_server") +async def test_psycopg(context: ssl.SSLContext, kwargs: Any) -> None: + """Test to verify that psycopg gets to proper connection call.""" + ip_addr = "127.0.0.1" + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + ) + with patch("psycopg.connect") as mock_connect: + type(mock_connect.return_value).autocommit = PropertyMock(return_value=True) + connection = connect(ip_addr, sock, **kwargs) + assert connection.autocommit is True + # verify that driver connection call would be made + assert mock_connect.assert_called_once From 89719a4e586ed2395c3403848377512a6b42de3e Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Wed, 16 Jul 2025 18:06:56 -0600 Subject: [PATCH 02/14] feat(main): add support for psycopg Changelog: - Add proxy for connections that can only be made through an unix socket, to support the TLS connection - Add support for psycopg, using the proxy server - Add unit and integration tests - Update docs --- google/cloud/sql/connector/proxy.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index 1f398461a..5f46003bb 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -25,7 +25,10 @@ def start_local_proxy( ssl_sock, socket_path, ): - desired_path = Path(socket_path) + path_parts = socket_path.rsplit('/', 1) + parent_directory = '/'.join(path_parts[:-1]) + + desired_path = Path(parent_directory) desired_path.mkdir(parents=True, exist_ok=True) if os.path.exists(socket_path): From 28c1c40c7b5e29f858f30d849c8b2c733d623f30 Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Thu, 17 Jul 2025 21:44:09 -0600 Subject: [PATCH 03/14] fix(main): fix feedback PR Changelog: - Make local_socket_path configurable - Set right file permissions - Handle exceptions properly - Use asyncio and the main loop to stop the local proxy and clear the file when the connector is stopped --- google/cloud/sql/connector/connector.py | 18 +++- google/cloud/sql/connector/exceptions.py | 6 ++ google/cloud/sql/connector/proxy.py | 91 +++++++++++++------- google/cloud/sql/connector/psycopg.py | 15 +--- tests/system/test_psycopg_connection.py | 105 ++++++++++++++++++----- 5 files changed, 166 insertions(+), 69 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 2926fa87a..3c7378471 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -44,10 +44,12 @@ from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.utils import format_database_user from google.cloud.sql.connector.utils import generate_keys +from google.cloud.sql.connector.proxy import start_local_proxy logger = logging.getLogger(name=__name__) ASYNC_DRIVERS = ["asyncpg"] +LOCAL_PROXY_DRIVERS = ["psycopg"] SERVER_PROXY_PORT = 3307 _DEFAULT_SCHEME = "https://" _DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" @@ -383,7 +385,7 @@ async def connect_async( # async drivers are unblocking and can be awaited directly if driver in ASYNC_DRIVERS: return await connector( - ip_address, + host, await conn_info.create_ssl_context(enable_iam_auth), **kwargs, ) @@ -393,6 +395,18 @@ async def connect_async( socket.create_connection((ip_address, SERVER_PROXY_PORT)), server_hostname=ip_address, ) + + host = ip_address + # start local proxy if driver needs it + if driver in LOCAL_PROXY_DRIVERS: + local_socket_path = kwargs.pop("local_socket_path", "/tmp/connector-socket") + host = local_socket_path + start_local_proxy( + sock, + socket_path=f"{local_socket_path}/.s.PGSQL.{SERVER_PROXY_PORT}", + loop=self._loop + ) + # If this connection was opened using a domain name, then store it # for later in case we need to forcibly close it on failover. if conn_info.conn_name.domain_name: @@ -400,7 +414,7 @@ async def connect_async( # Synchronous drivers are blocking and run using executor connect_partial = partial( connector, - ip_address, + host, sock, **kwargs, ) diff --git a/google/cloud/sql/connector/exceptions.py b/google/cloud/sql/connector/exceptions.py index da39ea25d..3aee96430 100644 --- a/google/cloud/sql/connector/exceptions.py +++ b/google/cloud/sql/connector/exceptions.py @@ -84,3 +84,9 @@ class CacheClosedError(Exception): Exception to be raised when a ConnectionInfoCache can not be accessed after it is closed. """ + + +class LocalProxyStartupError(Exception): + """ + Exception to be raised when a the local UNIX-socket based proxy can not be started. + """ diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index 5f46003bb..06b623e82 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -16,48 +16,75 @@ import socket import os -import threading +import ssl +import asyncio from pathlib import Path +from typing import Optional + +from google.cloud.sql.connector.exceptions import LocalProxyStartupError SERVER_PROXY_PORT = 3307 +LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760 def start_local_proxy( - ssl_sock, - socket_path, + ssl_sock: ssl.SSLSocket, + socket_path: Optional[str] = "/tmp/connector-socket", + loop: Optional[asyncio.AbstractEventLoop] = None, ): - path_parts = socket_path.rsplit('/', 1) - parent_directory = '/'.join(path_parts[:-1]) + """Helper function to start a UNIX based local proxy for + transport messages through the SSL Socket. + + Args: + ssl_sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL + server CA cert and ephemeral cert. + socket_path: A system path that is going to be used to store the socket. + loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. + + Raises: + LocalProxyStartupError: Local UNIX socket based proxy was not able to + get started. + """ + unix_socket = None - desired_path = Path(parent_directory) - desired_path.mkdir(parents=True, exist_ok=True) + try: + path_parts = socket_path.rsplit('/', 1) + parent_directory = '/'.join(path_parts[:-1]) - if os.path.exists(socket_path): - os.remove(socket_path) - conn_unix = None - unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + desired_path = Path(parent_directory) + desired_path.mkdir(parents=True, exist_ok=True) - unix_socket.bind(socket_path) - unix_socket.listen(1) + if os.path.exists(socket_path): + os.remove(socket_path) + unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - threading.Thread(target=local_communication, args=(unix_socket, ssl_sock, socket_path)).start() + unix_socket.bind(socket_path) + unix_socket.listen(1) + unix_socket.setblocking(False) + os.chmod(socket_path, 0o600) + except Exception: + raise LocalProxyStartupError( + 'Local UNIX socket based proxy was not able to get started.' + ) + loop.create_task(local_communication(unix_socket, ssl_sock, socket_path, loop)) -def local_communication( - unix_socket, ssl_sock, socket_path + +async def local_communication( + unix_socket, ssl_sock, socket_path, loop ): - try: - conn_unix, addr_unix = unix_socket.accept() - - while True: - data = conn_unix.recv(10485760) - if not data: - break - ssl_sock.sendall(data) - response = ssl_sock.recv(10485760) - conn_unix.sendall(response) - - finally: - if conn_unix is not None: - conn_unix.close() - unix_socket.close() - os.remove(socket_path) # Clean up the socket file + try: + client, _ = await loop.sock_accept(unix_socket) + + while True: + data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE) + if not data: + client.close() + break + ssl_sock.sendall(data) + response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE) + await loop.sock_sendall(client, response) + except Exception: + pass + finally: + client.close() + os.remove(socket_path) # Clean up the socket file diff --git a/google/cloud/sql/connector/psycopg.py b/google/cloud/sql/connector/psycopg.py index c34dcb1e1..fe862bc24 100644 --- a/google/cloud/sql/connector/psycopg.py +++ b/google/cloud/sql/connector/psycopg.py @@ -25,13 +25,12 @@ def connect( - ip_address: str, sock: ssl.SSLSocket, **kwargs: Any + host: str, sock: ssl.SSLSocket, **kwargs: Any ) -> "psycopg.Connection": """Helper function to create a psycopg DB-API connection object. Args: - ip_address (str): A string containing an IP address for the Cloud SQL - instance. + host (str): A string containing the socket path used by the local proxy. sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL server CA cert and ephemeral cert. kwargs: Additional arguments to pass to the psycopg connect method. @@ -44,10 +43,7 @@ def connect( ImportError: The psycopg module cannot be imported. """ try: - from psycopg.rows import dict_row from psycopg import Connection - import threading - from google.cloud.sql.connector.proxy import start_local_proxy except ImportError: raise ImportError( 'Unable to import module "psycopg." Please install and try again.' @@ -59,14 +55,9 @@ def connect( kwargs.pop("timeout", None) - start_local_proxy(sock, f"/tmp/connector-socket/.s.PGSQL.3307") - conn = Connection.connect( - f"host=/tmp/connector-socket port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require", - autocommit=True, - row_factory=dict_row, + f"host={host} port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require", **kwargs ) - conn.autocommit = True return conn diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py index 5b3a3c79e..9d6363d5c 100644 --- a/tests/system/test_psycopg_connection.py +++ b/tests/system/test_psycopg_connection.py @@ -18,12 +18,84 @@ import os # [START cloud_sql_connector_postgres_psycopg] +from typing import Union + +import sqlalchemy from google.cloud.sql.connector import Connector from google.cloud.sql.connector import DefaultResolver +from google.cloud.sql.connector import DnsResolver + + +def create_sqlalchemy_engine( + instance_connection_name: str, + user: str, + password: str, + db: str, + ip_type: str = "public", + refresh_strategy: str = "background", + resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver, +) -> tuple[sqlalchemy.engine.Engine, Connector]: + """Creates a connection pool for a Cloud SQL instance and returns the pool + and the connector. Callers are responsible for closing the pool and the + connector. + + A sample invocation looks like: + + engine, connector = create_sqlalchemy_engine( + inst_conn_name, + user, + password, + db, + ) + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + # do something with query result + connector.close() + + Args: + instance_connection_name (str): + The instance connection name specifies the instance relative to the + project and region. For example: "my-project:my-region:my-instance" + user (str): + The database user name, e.g., root + password (str): + The database user's password, e.g., secret-password + db (str): + The name of the database, e.g., mydb + ip_type (str): + The IP type of the Cloud SQL instance to connect to. Can be one + of "public", "private", or "psc". + refresh_strategy (Optional[str]): + Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" + or "background". For serverless environments use "lazy" to avoid + errors resulting from CPU being throttled. + resolver (Optional[google.cloud.sql.connector.DefaultResolver]): + Resolver class for resolving instance connection name. Use + google.cloud.sql.connector.DnsResolver when resolving DNS domain + names or google.cloud.sql.connector.DefaultResolver for regular + instance connection names ("my-project:my-region:my-instance"). + """ + connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver) + + # create SQLAlchemy connection pool + engine = sqlalchemy.create_engine( + "postgresql+psycopg://", + creator=lambda: connector.connect( + instance_connection_name, + "psycopg", + user=user, + password=password, + db=db, + local_socket_path="/tmp/conn", + ip_type=ip_type, # can be "public", "private" or "psc" + autocommit=True, + ), + ) + return engine, connector -from sqlalchemy.dialects.postgresql.base import PGDialect -PGDialect._get_server_version_info = lambda *args: (9, 2) # [END cloud_sql_connector_postgres_psycopg] @@ -36,25 +108,12 @@ def test_psycopg_connection() -> None: db = os.environ["POSTGRES_DB"] ip_type = os.environ.get("IP_TYPE", "public") - connector = Connector(refresh_strategy="background", resolver=DefaultResolver) - - pool = connector.connect( - inst_conn_name, - "psycopg", - user=user, - password=password, - db=db, - ip_type=ip_type, # can be "public", "private" or "psc" + engine, connector = create_sqlalchemy_engine( + inst_conn_name, user, password, db, ip_type ) - - with pool as conn: - - # Open a cursor to perform database operations - with conn.cursor() as cur: - - # Query the database and obtain data as Python objects. - cur.execute("SELECT NOW()") - curr_time = cur.fetchone()["now"] - assert type(curr_time) is datetime - - + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + assert type(curr_time) is datetime + connector.close() From f5be3ac11c374fe6b9dfdf3c8df0375150f7de4c Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Tue, 22 Jul 2025 21:13:46 -0600 Subject: [PATCH 04/14] fix(main): Prevent asyncio destroyed task warning Changelog: - Return the asyncio task from `start_local_proxy` - Handle it in `close_async` to cancel it gracefully --- google/cloud/sql/connector/connector.py | 9 ++++++++- google/cloud/sql/connector/proxy.py | 7 +++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 3c7378471..48dac6a67 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -401,7 +401,7 @@ async def connect_async( if driver in LOCAL_PROXY_DRIVERS: local_socket_path = kwargs.pop("local_socket_path", "/tmp/connector-socket") host = local_socket_path - start_local_proxy( + self._proxy = start_local_proxy( sock, socket_path=f"{local_socket_path}/.s.PGSQL.{SERVER_PROXY_PORT}", loop=self._loop @@ -486,6 +486,13 @@ async def close_async(self) -> None: await asyncio.gather(*[cache.close() for cache in self._cache.values()]) if self._client: await self._client.close() + if self._proxy: + proxy_task = asyncio.gather(self._proxy) + try: + await asyncio.wait_for(proxy_task, timeout=0.1) + except TimeoutError: + pass # This task runs forever so it is expected to throw this exception + async def create_async_connector( diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index 06b623e82..85ece82a1 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -30,7 +30,7 @@ def start_local_proxy( ssl_sock: ssl.SSLSocket, socket_path: Optional[str] = "/tmp/connector-socket", loop: Optional[asyncio.AbstractEventLoop] = None, -): +) -> asyncio.Task: """Helper function to start a UNIX based local proxy for transport messages through the SSL Socket. @@ -40,6 +40,9 @@ def start_local_proxy( socket_path: A system path that is going to be used to store the socket. loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. + Returns: + asyncio.Task: The asyncio task containing the proxy server process. + Raises: LocalProxyStartupError: Local UNIX socket based proxy was not able to get started. @@ -66,7 +69,7 @@ def start_local_proxy( 'Local UNIX socket based proxy was not able to get started.' ) - loop.create_task(local_communication(unix_socket, ssl_sock, socket_path, loop)) + return loop.create_task(local_communication(unix_socket, ssl_sock, socket_path, loop)) async def local_communication( From 2702590b75ac8557f8a117ab8947480bb1efb7aa Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Thu, 24 Jul 2025 20:35:21 -0600 Subject: [PATCH 05/14] fix(main) Fix linting and undefined cases Changelog: - Fix linting issues - Define `self.proxy` on the constructor - Prevent issues with undefined variables --- google/cloud/sql/connector/connector.py | 5 +++-- google/cloud/sql/connector/proxy.py | 6 +++--- google/cloud/sql/connector/psycopg.py | 1 - tests/unit/test_psycopg.py | 3 ++- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 48dac6a67..a358574d6 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -37,6 +37,7 @@ from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 +from google.cloud.sql.connector.proxy import start_local_proxy import google.cloud.sql.connector.psycopg as psycopg import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds @@ -44,7 +45,6 @@ from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.utils import format_database_user from google.cloud.sql.connector.utils import generate_keys -from google.cloud.sql.connector.proxy import start_local_proxy logger = logging.getLogger(name=__name__) @@ -156,6 +156,7 @@ def __init__( # connection name string and enable_iam_auth boolean flag self._cache: dict[tuple[str, bool], MonitoredCache] = {} self._client: Optional[CloudSQLClient] = None + self._proxy: Optional[asyncio.Task] = None # initialize credentials scopes = ["https://www.googleapis.com/auth/sqlservice.admin"] @@ -385,7 +386,7 @@ async def connect_async( # async drivers are unblocking and can be awaited directly if driver in ASYNC_DRIVERS: return await connector( - host, + ip_address, await conn_info.create_ssl_context(enable_iam_auth), **kwargs, ) diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index 85ece82a1..cbb795ea5 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -14,11 +14,11 @@ limitations under the License. """ -import socket -import os -import ssl import asyncio +import os from pathlib import Path +import socket +import ssl from typing import Optional from google.cloud.sql.connector.exceptions import LocalProxyStartupError diff --git a/google/cloud/sql/connector/psycopg.py b/google/cloud/sql/connector/psycopg.py index fe862bc24..80e824002 100644 --- a/google/cloud/sql/connector/psycopg.py +++ b/google/cloud/sql/connector/psycopg.py @@ -16,7 +16,6 @@ import ssl from typing import Any, TYPE_CHECKING -import threading SERVER_PROXY_PORT = 3307 diff --git a/tests/unit/test_psycopg.py b/tests/unit/test_psycopg.py index 8d9fe1f85..aa30a9c4b 100644 --- a/tests/unit/test_psycopg.py +++ b/tests/unit/test_psycopg.py @@ -18,7 +18,8 @@ import ssl from typing import Any -from mock import patch, PropertyMock +from mock import patch +from mock import PropertyMock import pytest from google.cloud.sql.connector.psycopg import connect From 42dab35f06d0fc69a2e82340da162a8877b8c220 Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Thu, 24 Jul 2025 21:11:28 -0600 Subject: [PATCH 06/14] fix(main): Fix psycopg unit test --- tests/unit/test_psycopg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_psycopg.py b/tests/unit/test_psycopg.py index aa30a9c4b..8088751c5 100644 --- a/tests/unit/test_psycopg.py +++ b/tests/unit/test_psycopg.py @@ -33,7 +33,7 @@ async def test_psycopg(context: ssl.SSLContext, kwargs: Any) -> None: socket.create_connection((ip_addr, 3307)), server_hostname=ip_addr, ) - with patch("psycopg.connect") as mock_connect: + with patch("psycopg.Connection.connect") as mock_connect: type(mock_connect.return_value).autocommit = PropertyMock(return_value=True) connection = connect(ip_addr, sock, **kwargs) assert connection.autocommit is True From 8c3ce218c60abaf1a0eed1dcc2a9acac09621c73 Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Mon, 28 Jul 2025 13:06:57 -0600 Subject: [PATCH 07/14] fix(main): Fix linting --- google/cloud/sql/connector/proxy.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index cbb795ea5..385f2ff51 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -19,7 +19,6 @@ from pathlib import Path import socket import ssl -from typing import Optional from google.cloud.sql.connector.exceptions import LocalProxyStartupError @@ -28,8 +27,8 @@ def start_local_proxy( ssl_sock: ssl.SSLSocket, - socket_path: Optional[str] = "/tmp/connector-socket", - loop: Optional[asyncio.AbstractEventLoop] = None, + socket_path: str, + loop: asyncio.AbstractEventLoop ) -> asyncio.Task: """Helper function to start a UNIX based local proxy for transport messages through the SSL Socket. From d8c23a275f11417272a0b4bde53d63b95d2eb86b Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Mon, 28 Jul 2025 18:23:15 -0600 Subject: [PATCH 08/14] fix(main) Increase code coverage to 94% Changelog: - Add unit tests for proxy - Add test case to connector for drivers that require the local proxy - Make proper adjustments to code --- google/cloud/sql/connector/connector.py | 4 +- google/cloud/sql/connector/proxy.py | 6 +- tests/unit/test_connector.py | 39 +++++++++++++ tests/unit/test_proxy.py | 75 +++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 6 deletions(-) create mode 100644 tests/unit/test_proxy.py diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index a358574d6..a85518ea4 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -37,7 +37,7 @@ from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 -from google.cloud.sql.connector.proxy import start_local_proxy +import google.cloud.sql.connector.proxy as proxy import google.cloud.sql.connector.psycopg as psycopg import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds @@ -402,7 +402,7 @@ async def connect_async( if driver in LOCAL_PROXY_DRIVERS: local_socket_path = kwargs.pop("local_socket_path", "/tmp/connector-socket") host = local_socket_path - self._proxy = start_local_proxy( + self._proxy = proxy.start_local_proxy( sock, socket_path=f"{local_socket_path}/.s.PGSQL.{SERVER_PROXY_PORT}", loop=self._loop diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index 385f2ff51..e3bd69614 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -74,9 +74,9 @@ def start_local_proxy( async def local_communication( unix_socket, ssl_sock, socket_path, loop ): + client, _ = await loop.sock_accept(unix_socket) + try: - client, _ = await loop.sock_accept(unix_socket) - while True: data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE) if not data: @@ -85,8 +85,6 @@ async def local_communication( ssl_sock.sendall(data) response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE) await loop.sock_sendall(client, response) - except Exception: - pass finally: client.close() os.remove(socket_path) # Clean up the socket file diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 157697723..76db18ae2 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -16,6 +16,8 @@ import asyncio import os +import socket +import ssl from typing import Union from aiohttp import ClientResponseError @@ -31,6 +33,7 @@ from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import IncompatibleDriverError from google.cloud.sql.connector.instance import RefreshAheadCache +from google.cloud.sql.connector.proxy import start_local_proxy @pytest.mark.asyncio @@ -279,6 +282,42 @@ async def test_Connector_connect_async( # verify connector made connection call assert connection is True +@pytest.mark.usefixtures("proxy_server") +@pytest.mark.asyncio +async def test_Connector_connect_local_proxy( + fake_credentials: Credentials, fake_client: CloudSQLClient, context: ssl.SSLContext +) -> None: + """Test that Connector.connect can launch start_local_proxy.""" + async with Connector( + credentials=fake_credentials, loop=asyncio.get_running_loop() + ) as connector: + connector._client = fake_client + socket_path = "/tmp/connector-socket/socket" + ip_addr = "127.0.0.1" + ssl_sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + ) + loop = asyncio.get_running_loop() + task = start_local_proxy(ssl_sock, socket_path, loop) + # patch db connection creation + with patch("google.cloud.sql.connector.proxy.start_local_proxy") as mock_proxy: + with patch("google.cloud.sql.connector.psycopg.connect") as mock_connect: + mock_connect.return_value = True + mock_proxy.return_value = task + connection = await connector.connect_async( + "test-project:test-region:test-instance", + "psycopg", + user="my-user", + password="my-pass", + db="my-db", + local_socket_path=socket_path, + ) + # verify connector called local proxy + mock_connect.assert_called_once() + mock_proxy.assert_called_once() + assert connection is True + @pytest.mark.asyncio async def test_create_async_connector(fake_credentials: Credentials) -> None: diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py new file mode 100644 index 000000000..7e1e5c79d --- /dev/null +++ b/tests/unit/test_proxy.py @@ -0,0 +1,75 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import socket +import ssl +from typing import Any + +from mock import Mock +import pytest + +from google.cloud.sql.connector import proxy + +LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760 + +@pytest.mark.usefixtures("proxy_server") +@pytest.mark.asyncio +async def test_proxy_creates_folder(context: ssl.SSLContext, kwargs: Any) -> None: + """Test to verify that the proxy server is getting back the task.""" + ip_addr = "127.0.0.1" + path = "/tmp/connector-socket/socket" + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + ) + loop = asyncio.get_running_loop() + + task = proxy.start_local_proxy(sock, path, loop) + assert (task is not None) + + proxy_task = asyncio.gather(task) + try: + await asyncio.wait_for(proxy_task, timeout=0.1) + except TimeoutError: + pass # This task runs forever so it is expected to throw this exception + +@pytest.mark.usefixtures("proxy_server") +@pytest.mark.asyncio +async def test_local_proxy_communication(context: ssl.SSLContext, kwargs: Any) -> None: + """Test to verify that the communication is getting through.""" + socket_path = "/tmp/connector-socket/socket" + ssl_sock = Mock(spec=ssl.SSLSocket) + loop = asyncio.get_running_loop() + + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client: + ssl_sock.recv.return_value = b"Received" + + task = proxy.start_local_proxy(ssl_sock, socket_path, loop) + + client.connect(socket_path) + client.sendall(b"Test") + await asyncio.sleep(1) + + ssl_sock.sendall.assert_called_with(b"Test") + response = client.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE) + assert (response == b"Received") + + client.close() + await asyncio.sleep(1) + + proxy_task = asyncio.gather(task) + await asyncio.wait_for(proxy_task, timeout=2) From c8985b65eec05f8b90b0eedc645e9ac2a1af73d6 Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Mon, 28 Jul 2025 20:26:45 -0600 Subject: [PATCH 09/14] fix(main): Add support for Python 3.9 --- google/cloud/sql/connector/connector.py | 2 +- tests/unit/test_connector.py | 6 ++++++ tests/unit/test_proxy.py | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index a85518ea4..f196f1ce9 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -491,7 +491,7 @@ async def close_async(self) -> None: proxy_task = asyncio.gather(self._proxy) try: await asyncio.wait_for(proxy_task, timeout=0.1) - except TimeoutError: + except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): pass # This task runs forever so it is expected to throw this exception diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 76db18ae2..704fb52a7 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -317,6 +317,12 @@ async def test_Connector_connect_local_proxy( mock_connect.assert_called_once() mock_proxy.assert_called_once() assert connection is True + + proxy_task = asyncio.gather(task) + try: + await asyncio.wait_for(proxy_task, timeout=0.1) + except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): + pass # This task runs forever so it is expected to throw this exception @pytest.mark.asyncio diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py index 7e1e5c79d..c1143f19f 100644 --- a/tests/unit/test_proxy.py +++ b/tests/unit/test_proxy.py @@ -44,7 +44,7 @@ async def test_proxy_creates_folder(context: ssl.SSLContext, kwargs: Any) -> Non proxy_task = asyncio.gather(task) try: await asyncio.wait_for(proxy_task, timeout=0.1) - except TimeoutError: + except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): pass # This task runs forever so it is expected to throw this exception @pytest.mark.usefixtures("proxy_server") From 4f6f388f7eea53a51bf82238336acc00f820e48d Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Mon, 18 Aug 2025 18:38:31 -0600 Subject: [PATCH 10/14] fix(main): Make local proxy to accept multiple connections (WIP) --- google/cloud/sql/connector/connector.py | 55 +++--- google/cloud/sql/connector/enums.py | 4 +- .../{psycopg.py => local_unix_socket.py} | 37 +--- google/cloud/sql/connector/proxy.py | 176 ++++++++++++------ pyproject.toml | 1 - tests/system/test_psycopg_connection.py | 17 +- ...t_psycopg.py => test_local_unix_socket.py} | 14 +- 7 files changed, 169 insertions(+), 135 deletions(-) rename google/cloud/sql/connector/{psycopg.py => local_unix_socket.py} (51%) rename tests/unit/{test_psycopg.py => test_local_unix_socket.py} (61%) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index f196f1ce9..d8c699b09 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -35,10 +35,10 @@ from google.cloud.sql.connector.enums import RefreshStrategy from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.lazy import LazyRefreshCache +import google.cloud.sql.connector.local_unix_socket as local_unix_socket from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 -import google.cloud.sql.connector.proxy as proxy -import google.cloud.sql.connector.psycopg as psycopg +from google.cloud.sql.connector.proxy import Proxy import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds from google.cloud.sql.connector.resolver import DefaultResolver @@ -49,7 +49,6 @@ logger = logging.getLogger(name=__name__) ASYNC_DRIVERS = ["asyncpg"] -LOCAL_PROXY_DRIVERS = ["psycopg"] SERVER_PROXY_PORT = 3307 _DEFAULT_SCHEME = "https://" _DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" @@ -156,7 +155,7 @@ def __init__( # connection name string and enable_iam_auth boolean flag self._cache: dict[tuple[str, bool], MonitoredCache] = {} self._client: Optional[CloudSQLClient] = None - self._proxy: Optional[asyncio.Task] = None + self._proxies: Optional[Proxy] = None # initialize credentials scopes = ["https://www.googleapis.com/auth/sqlservice.admin"] @@ -217,6 +216,29 @@ def __init__( def universe_domain(self) -> str: return self._universe_domain or _DEFAULT_UNIVERSE_DOMAIN + def start_unix_socket_proxy_async( + self, + instance_connection_name: str, + local_socket_path: str, + **kwargs: Any + ) -> None: + """Creates a new Proxy instance and stores it to properly disposal + + Args: + instance_connection_string (str): The instance connection name of the + Cloud SQL instance to connect to. Takes the form of + "project-id:region:instance-name" + + Example: "my-project:us-central1:my-instance" + + local_socket_path (str): A string representing the location of the local socket. + + **kwargs: Any driver-specific arguments to pass to the underlying + driver .connect call. + """ + # TODO: validates the local socket path is not the same as other invocation + self._proxies.append(new Proxy(self, instance_connection_name, local_socket_path, self.loop, **kwargs)) + def connect( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: @@ -234,7 +256,7 @@ def connect( Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect - with. Supported drivers are pymysql, pg8000, psycopg, and pytds. + with. Supported drivers are pymysql, pg8000, local_unix_socket, and pytds. **kwargs: Any driver-specific arguments to pass to the underlying driver .connect call. @@ -270,7 +292,7 @@ async def connect_async( Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect - with. Supported drivers are pymysql, asyncpg, pg8000, psycopg, and + with. Supported drivers are pymysql, asyncpg, pg8000, local_unix_socket, and pytds. **kwargs: Any driver-specific arguments to pass to the underlying @@ -283,7 +305,7 @@ async def connect_async( ValueError: Connection attempt with built-in database authentication and then subsequent attempt with IAM database authentication. KeyError: Unsupported database driver Must be one of pymysql, asyncpg, - pg8000, psycopg, and pytds. + pg8000, local_unix_socket, and pytds. """ if self._keys is None: self._keys = asyncio.create_task(generate_keys()) @@ -337,7 +359,7 @@ async def connect_async( connect_func = { "pymysql": pymysql.connect, "pg8000": pg8000.connect, - "psycopg": psycopg.connect, + "local_unix_socket": local_unix_socket.connect, "asyncpg": asyncpg.connect, "pytds": pytds.connect, } @@ -397,17 +419,6 @@ async def connect_async( server_hostname=ip_address, ) - host = ip_address - # start local proxy if driver needs it - if driver in LOCAL_PROXY_DRIVERS: - local_socket_path = kwargs.pop("local_socket_path", "/tmp/connector-socket") - host = local_socket_path - self._proxy = proxy.start_local_proxy( - sock, - socket_path=f"{local_socket_path}/.s.PGSQL.{SERVER_PROXY_PORT}", - loop=self._loop - ) - # If this connection was opened using a domain name, then store it # for later in case we need to forcibly close it on failover. if conn_info.conn_name.domain_name: @@ -488,11 +499,7 @@ async def close_async(self) -> None: if self._client: await self._client.close() if self._proxy: - proxy_task = asyncio.gather(self._proxy) - try: - await asyncio.wait_for(proxy_task, timeout=0.1) - except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): - pass # This task runs forever so it is expected to throw this exception + await asyncio.wait_for([ proxy.close_async() for proxy in self._proxies]) diff --git a/google/cloud/sql/connector/enums.py b/google/cloud/sql/connector/enums.py index 5926b75a2..7bde045ed 100644 --- a/google/cloud/sql/connector/enums.py +++ b/google/cloud/sql/connector/enums.py @@ -62,7 +62,7 @@ class DriverMapping(Enum): ASYNCPG = "POSTGRES" PG8000 = "POSTGRES" - PSYCOPG = "POSTGRES" + LOCAL_UNIX_SOCKET = "ANY" PYMYSQL = "MYSQL" PYTDS = "SQLSERVER" @@ -79,7 +79,7 @@ def validate_engine(driver: str, engine_version: str) -> None: the given engine. """ mapping = DriverMapping[driver.upper()] - if not engine_version.startswith(mapping.value): + if not mapping.value == "ANY" and not engine_version.startswith(mapping.value): raise IncompatibleDriverError( f"Database driver '{driver}' is incompatible with database " f"version '{engine_version}'. Given driver can " diff --git a/google/cloud/sql/connector/psycopg.py b/google/cloud/sql/connector/local_unix_socket.py similarity index 51% rename from google/cloud/sql/connector/psycopg.py rename to google/cloud/sql/connector/local_unix_socket.py index 80e824002..497e503e0 100644 --- a/google/cloud/sql/connector/psycopg.py +++ b/google/cloud/sql/connector/local_unix_socket.py @@ -19,44 +19,19 @@ SERVER_PROXY_PORT = 3307 -if TYPE_CHECKING: - import psycopg - - def connect( host: str, sock: ssl.SSLSocket, **kwargs: Any -) -> "psycopg.Connection": - """Helper function to create a psycopg DB-API connection object. +) -> "ssl.SSLSocket": + """Helper function to retrieve the socket for local UNIX sockets. Args: host (str): A string containing the socket path used by the local proxy. sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL server CA cert and ephemeral cert. - kwargs: Additional arguments to pass to the psycopg connect method. + kwargs: Additional arguments to pass to the local UNIX socket connect method. Returns: - psycopg.Connection: A psycopg connection to the Cloud SQL - instance. - - Raises: - ImportError: The psycopg module cannot be imported. + ssl.SSLSocket: The same socket """ - try: - from psycopg import Connection - except ImportError: - raise ImportError( - 'Unable to import module "psycopg." Please install and try again.' - ) - - user = kwargs.pop("user") - db = kwargs.pop("db") - passwd = kwargs.pop("password", None) - - kwargs.pop("timeout", None) - - conn = Connection.connect( - f"host={host} port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require", - **kwargs - ) - - return conn + + return sock diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index e3bd69614..568c463ad 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -18,73 +18,125 @@ import os from pathlib import Path import socket +import selectors import ssl +from google.cloud.sql.connector import Connector from google.cloud.sql.connector.exceptions import LocalProxyStartupError SERVER_PROXY_PORT = 3307 LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760 -def start_local_proxy( - ssl_sock: ssl.SSLSocket, - socket_path: str, - loop: asyncio.AbstractEventLoop -) -> asyncio.Task: - """Helper function to start a UNIX based local proxy for - transport messages through the SSL Socket. - - Args: - ssl_sock (ssl.SSLSocket): An SSLSocket object created from the Cloud SQL - server CA cert and ephemeral cert. - socket_path: A system path that is going to be used to store the socket. - loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. - - Returns: - asyncio.Task: The asyncio task containing the proxy server process. - - Raises: - LocalProxyStartupError: Local UNIX socket based proxy was not able to - get started. - """ - unix_socket = None - - try: - path_parts = socket_path.rsplit('/', 1) - parent_directory = '/'.join(path_parts[:-1]) - - desired_path = Path(parent_directory) - desired_path.mkdir(parents=True, exist_ok=True) - - if os.path.exists(socket_path): - os.remove(socket_path) - unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - - unix_socket.bind(socket_path) - unix_socket.listen(1) - unix_socket.setblocking(False) - os.chmod(socket_path, 0o600) - except Exception: - raise LocalProxyStartupError( - 'Local UNIX socket based proxy was not able to get started.' - ) - - return loop.create_task(local_communication(unix_socket, ssl_sock, socket_path, loop)) - - -async def local_communication( - unix_socket, ssl_sock, socket_path, loop -): - client, _ = await loop.sock_accept(unix_socket) - - try: + +class Proxy: + """Creates an "accept loop" async task which will open the unix server socket and listen for new connections.""" + + def __init__( + self, + connector: Connector, + instance_connection_string: str, + socket_path: str, + loop: asyncio.AbstractEventLoop, + **kwargs: Any + ) -> None: + """Keeps track of all the async tasks and starts the accept loop for new connections. + + Args: + connector (Connector): The instance where this Proxy class was created. + + instance_connection_string (str): The instance connection name of the + Cloud SQL instance to connect to. Takes the form of + "project-id:region:instance-name" + + Example: "my-project:us-central1:my-instance" + + socket_path (str): A system path that is going to be used to store the socket. + + loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. + + **kwargs: Any driver-specific arguments to pass to the underlying + driver .connect call. + """ + self._connection_tasks = [] + self._addr = instance_connection_string + self._kwargs = kwargs + self._connector = connector + self._task = loop.create_task(accept_loop(socket_path, loop, **kwargs)) + + async def accept_loop( + self + socket_path: str, + loop: asyncio.AbstractEventLoop + ) -> asyncio.Task: + """Starts a UNIX based local proxy for transporting messages through + the SSL Socket, and waits until there is a new connection to accept, to register it + and keep track of it. + + Args: + socket_path: A system path that is going to be used to store the socket. + + loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. + + Raises: + LocalProxyStartupError: Local UNIX socket based proxy was not able to + get started. + """ + unix_socket = None + sel = selectors.DefaultSelector() + + try: + path_parts = socket_path.rsplit('/', 1) + parent_directory = '/'.join(path_parts[:-1]) + + desired_path = Path(parent_directory) + desired_path.mkdir(parents=True, exist_ok=True) + + if os.path.exists(socket_path): + os.remove(socket_path) + + unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + unix_socket.bind(socket_path) + unix_socket.listen(1) + unix_socket.setblocking(False) + os.chmod(socket_path, 0o600) + + sel.register(unix_socket, selectors.EVENT_READ, data=None) + + except Exception: + raise LocalProxyStartupError( + 'Local UNIX socket based proxy was not able to get started.' + ) + while True: - data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE) - if not data: - client.close() - break - ssl_sock.sendall(data) - response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE) - await loop.sock_sendall(client, response) - finally: - client.close() - os.remove(socket_path) # Clean up the socket file + client, _ = await loop.sock_accept(unix_socket) + self._connection_tasks.append(loop.create_task(self.client_socket(client, unix_socket, socket_path, loop))) + + async def close_async(self): + proxy_task = asyncio.gather(self._task) + try: + await asyncio.wait_for(proxy_task, timeout=0.1) + except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): + pass # This task runs forever so it is expected to throw this exception + + + async def client_socket( + self, client, unix_socket, socket_path, loop + ): + try: + ssl_sock = self.connector.connect( + self._addr, + 'local_unix_socket', + **self._kwargs + ) + while True: + data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE) + if not data: + client.close() + break + ssl_sock.sendall(data) + response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE) + await loop.sock_sendall(client, response) + finally: + client.close() + os.remove(socket_path) # Clean up the socket file diff --git a/pyproject.toml b/pyproject.toml index 1c933cf6a..cbf0dd10f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,6 @@ Changelog = "https://github.com/GoogleCloudPlatform/cloud-sql-python-connector/b [project.optional-dependencies] pymysql = ["PyMySQL>=1.1.0"] pg8000 = ["pg8000>=1.31.1"] -psycopg = ["psycopg>=3.2.9"] pytds = ["python-tds>=1.15.0"] asyncpg = ["asyncpg>=0.30.0"] diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py index 9d6363d5c..e54d957c8 100644 --- a/tests/system/test_psycopg_connection.py +++ b/tests/system/test_psycopg_connection.py @@ -20,6 +20,7 @@ # [START cloud_sql_connector_postgres_psycopg] from typing import Union +from psycopg import Connection import sqlalchemy from google.cloud.sql.connector import Connector @@ -79,21 +80,25 @@ def create_sqlalchemy_engine( instance connection names ("my-project:my-region:my-instance"). """ connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver) + unix_socket_path = "/tmp/conn" + await connector.start_unix_socket_proxy_async( + instance_connection_name, + unix_socket_path, + ip_type=ip_type, # can be "public", "private" or "psc" + ) # create SQLAlchemy connection pool engine = sqlalchemy.create_engine( "postgresql+psycopg://", - creator=lambda: connector.connect( - instance_connection_name, - "psycopg", + creator=lambda: Connection.connect( + f"host={unix_socket_path} port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require", user=user, password=password, db=db, - local_socket_path="/tmp/conn", - ip_type=ip_type, # can be "public", "private" or "psc" autocommit=True, - ), + ) ) + return engine, connector diff --git a/tests/unit/test_psycopg.py b/tests/unit/test_local_unix_socket.py similarity index 61% rename from tests/unit/test_psycopg.py rename to tests/unit/test_local_unix_socket.py index 8088751c5..8672857ec 100644 --- a/tests/unit/test_psycopg.py +++ b/tests/unit/test_local_unix_socket.py @@ -22,20 +22,16 @@ from mock import PropertyMock import pytest -from google.cloud.sql.connector.psycopg import connect +from google.cloud.sql.connector.local_unix_socket import connect @pytest.mark.usefixtures("proxy_server") -async def test_psycopg(context: ssl.SSLContext, kwargs: Any) -> None: - """Test to verify that psycopg gets to proper connection call.""" +async def test_local_unix_socket(context: ssl.SSLContext, kwargs: Any) -> None: + """Test to verify that local_unix_socket gets to proper connection call.""" ip_addr = "127.0.0.1" sock = context.wrap_socket( socket.create_connection((ip_addr, 3307)), server_hostname=ip_addr, ) - with patch("psycopg.Connection.connect") as mock_connect: - type(mock_connect.return_value).autocommit = PropertyMock(return_value=True) - connection = connect(ip_addr, sock, **kwargs) - assert connection.autocommit is True - # verify that driver connection call would be made - assert mock_connect.assert_called_once + connection = connect(ip_addr, sock, **kwargs) + assert connection == sock From 73622330d7edd54cdc12d4f9709e74c26d3eee8d Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Wed, 10 Sep 2025 13:18:07 -0600 Subject: [PATCH 11/14] test: Fix compilation errors --- google/cloud/sql/connector/connector.py | 66 +++++++++------ .../cloud/sql/connector/local_unix_socket.py | 2 - google/cloud/sql/connector/proxy.py | 52 ++++++------ tests/system/test_psycopg_connection.py | 10 ++- tests/unit/test_connector.py | 82 +++++++++---------- 5 files changed, 111 insertions(+), 101 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index d8c699b09..06c6fb98f 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -155,7 +155,7 @@ def __init__( # connection name string and enable_iam_auth boolean flag self._cache: dict[tuple[str, bool], MonitoredCache] = {} self._client: Optional[CloudSQLClient] = None - self._proxies: Optional[Proxy] = None + self._proxies: list[proxy.Proxy] = [] # initialize credentials scopes = ["https://www.googleapis.com/auth/sqlservice.admin"] @@ -216,29 +216,6 @@ def __init__( def universe_domain(self) -> str: return self._universe_domain or _DEFAULT_UNIVERSE_DOMAIN - def start_unix_socket_proxy_async( - self, - instance_connection_name: str, - local_socket_path: str, - **kwargs: Any - ) -> None: - """Creates a new Proxy instance and stores it to properly disposal - - Args: - instance_connection_string (str): The instance connection name of the - Cloud SQL instance to connect to. Takes the form of - "project-id:region:instance-name" - - Example: "my-project:us-central1:my-instance" - - local_socket_path (str): A string representing the location of the local socket. - - **kwargs: Any driver-specific arguments to pass to the underlying - driver .connect call. - """ - # TODO: validates the local socket path is not the same as other invocation - self._proxies.append(new Proxy(self, instance_connection_name, local_socket_path, self.loop, **kwargs)) - def connect( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: @@ -426,7 +403,7 @@ async def connect_async( # Synchronous drivers are blocking and run using executor connect_partial = partial( connector, - host, + ip_address, sock, **kwargs, ) @@ -437,6 +414,42 @@ async def connect_async( await monitored_cache.force_refresh() raise + async def start_unix_socket_proxy_async( + self, instance_connection_string: str, local_socket_path: str, **kwargs: Any + ) -> None: + """Starts a local Unix socket proxy for a Cloud SQL instance. + + Args: + instance_connection_string (str): The instance connection name of the + Cloud SQL instance to connect to. + local_socket_path (str): The path to the local Unix socket. + driver (str): The database driver name. + **kwargs: Keyword arguments to pass to the underlying database + driver. + """ + if "driver" in kwargs: + driver = kwargs["driver"] + else: + driver = "proxy" + + self._init_client(driver) + + # check if a proxy is already running for this socket path + for p in self._proxies: + if p.unix_socket_path == local_socket_path: + raise ValueError( + f"Proxy for socket path {local_socket_path} already exists." + ) + + # Create a new proxy instance + proxy_instance = proxy.Proxy( + local_socket_path, + ConnectorSocketFactory(self, instance_connection_string, **kwargs), + self._loop + ) + await proxy_instance.start() + self._proxies.append(proxy_instance) + async def _remove_cached( self, instance_connection_string: str, enable_iam_auth: bool ) -> None: @@ -496,10 +509,9 @@ async def close_async(self) -> None: """Helper function to cancel the cache's tasks and close aiohttp.ClientSession.""" await asyncio.gather(*[cache.close() for cache in self._cache.values()]) + await asyncio.wait_for(asyncio.gather(*[ proxy.close_async() for proxy in self._proxies]), timeout=2.0) if self._client: await self._client.close() - if self._proxy: - await asyncio.wait_for([ proxy.close_async() for proxy in self._proxies]) diff --git a/google/cloud/sql/connector/local_unix_socket.py b/google/cloud/sql/connector/local_unix_socket.py index 497e503e0..25a9d3be3 100644 --- a/google/cloud/sql/connector/local_unix_socket.py +++ b/google/cloud/sql/connector/local_unix_socket.py @@ -17,8 +17,6 @@ import ssl from typing import Any, TYPE_CHECKING -SERVER_PROXY_PORT = 3307 - def connect( host: str, sock: ssl.SSLSocket, **kwargs: Any ) -> "ssl.SSLSocket": diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index 568c463ad..e5d9ab536 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -21,10 +21,8 @@ import selectors import ssl -from google.cloud.sql.connector import Connector from google.cloud.sql.connector.exceptions import LocalProxyStartupError -SERVER_PROXY_PORT = 3307 LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760 @@ -33,11 +31,11 @@ class Proxy: def __init__( self, - connector: Connector, + connector, instance_connection_string: str, socket_path: str, loop: asyncio.AbstractEventLoop, - **kwargs: Any + **kwargs ) -> None: """Keeps track of all the async tasks and starts the accept loop for new connections. @@ -61,28 +59,8 @@ def __init__( self._addr = instance_connection_string self._kwargs = kwargs self._connector = connector - self._task = loop.create_task(accept_loop(socket_path, loop, **kwargs)) - async def accept_loop( - self - socket_path: str, - loop: asyncio.AbstractEventLoop - ) -> asyncio.Task: - """Starts a UNIX based local proxy for transporting messages through - the SSL Socket, and waits until there is a new connection to accept, to register it - and keep track of it. - - Args: - socket_path: A system path that is going to be used to store the socket. - - loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. - - Raises: - LocalProxyStartupError: Local UNIX socket based proxy was not able to - get started. - """ unix_socket = None - sel = selectors.DefaultSelector() try: path_parts = socket_path.rsplit('/', 1) @@ -100,14 +78,34 @@ async def accept_loop( unix_socket.listen(1) unix_socket.setblocking(False) os.chmod(socket_path, 0o600) - - sel.register(unix_socket, selectors.EVENT_READ, data=None) + + self._task = loop.create_task(self.accept_loop(unix_socket, socket_path, loop)) except Exception: raise LocalProxyStartupError( 'Local UNIX socket based proxy was not able to get started.' ) + async def accept_loop( + self, + unix_socket, + socket_path: str, + loop: asyncio.AbstractEventLoop + ) -> asyncio.Task: + """Starts a UNIX based local proxy for transporting messages through + the SSL Socket, and waits until there is a new connection to accept, to register it + and keep track of it. + + Args: + socket_path: A system path that is going to be used to store the socket. + + loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. + + Raises: + LocalProxyStartupError: Local UNIX socket based proxy was not able to + get started. + """ + print("on accept loop") while True: client, _ = await loop.sock_accept(unix_socket) self._connection_tasks.append(loop.create_task(self.client_socket(client, unix_socket, socket_path, loop))) @@ -124,7 +122,7 @@ async def client_socket( self, client, unix_socket, socket_path, loop ): try: - ssl_sock = self.connector.connect( + ssl_sock = self._connector.connect( self._addr, 'local_unix_socket', **self._kwargs diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py index e54d957c8..6f0e07c42 100644 --- a/tests/system/test_psycopg_connection.py +++ b/tests/system/test_psycopg_connection.py @@ -27,6 +27,7 @@ from google.cloud.sql.connector import DefaultResolver from google.cloud.sql.connector import DnsResolver +SERVER_PROXY_PORT = 3307 def create_sqlalchemy_engine( instance_connection_name: str, @@ -80,8 +81,9 @@ def create_sqlalchemy_engine( instance connection names ("my-project:my-region:my-instance"). """ connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver) - unix_socket_path = "/tmp/conn" - await connector.start_unix_socket_proxy_async( + unix_socket_folder = "/tmp/conn" + unix_socket_path = f"{unix_socket_folder}/.s.PGSQL.3307" + connector.start_unix_socket_proxy_async( instance_connection_name, unix_socket_path, ip_type=ip_type, # can be "public", "private" or "psc" @@ -91,10 +93,10 @@ def create_sqlalchemy_engine( engine = sqlalchemy.create_engine( "postgresql+psycopg://", creator=lambda: Connection.connect( - f"host={unix_socket_path} port={SERVER_PROXY_PORT} dbname={db} user={user} password={passwd} sslmode=require", + f"host={unix_socket_folder} port={SERVER_PROXY_PORT} dbname={db} user={user} password={password} sslmode=require", user=user, password=password, - db=db, + dbname=db, autocommit=True, ) ) diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 704fb52a7..47f242adc 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -33,7 +33,7 @@ from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import IncompatibleDriverError from google.cloud.sql.connector.instance import RefreshAheadCache -from google.cloud.sql.connector.proxy import start_local_proxy +# from google.cloud.sql.connector.proxy import start_local_proxy @pytest.mark.asyncio @@ -282,47 +282,47 @@ async def test_Connector_connect_async( # verify connector made connection call assert connection is True -@pytest.mark.usefixtures("proxy_server") -@pytest.mark.asyncio -async def test_Connector_connect_local_proxy( - fake_credentials: Credentials, fake_client: CloudSQLClient, context: ssl.SSLContext -) -> None: - """Test that Connector.connect can launch start_local_proxy.""" - async with Connector( - credentials=fake_credentials, loop=asyncio.get_running_loop() - ) as connector: - connector._client = fake_client - socket_path = "/tmp/connector-socket/socket" - ip_addr = "127.0.0.1" - ssl_sock = context.wrap_socket( - socket.create_connection((ip_addr, 3307)), - server_hostname=ip_addr, - ) - loop = asyncio.get_running_loop() - task = start_local_proxy(ssl_sock, socket_path, loop) - # patch db connection creation - with patch("google.cloud.sql.connector.proxy.start_local_proxy") as mock_proxy: - with patch("google.cloud.sql.connector.psycopg.connect") as mock_connect: - mock_connect.return_value = True - mock_proxy.return_value = task - connection = await connector.connect_async( - "test-project:test-region:test-instance", - "psycopg", - user="my-user", - password="my-pass", - db="my-db", - local_socket_path=socket_path, - ) - # verify connector called local proxy - mock_connect.assert_called_once() - mock_proxy.assert_called_once() - assert connection is True +# @pytest.mark.usefixtures("proxy_server") +# @pytest.mark.asyncio +# async def test_Connector_connect_local_proxy( +# fake_credentials: Credentials, fake_client: CloudSQLClient, context: ssl.SSLContext +# ) -> None: +# """Test that Connector.connect can launch start_local_proxy.""" +# async with Connector( +# credentials=fake_credentials, loop=asyncio.get_running_loop() +# ) as connector: +# connector._client = fake_client +# socket_path = "/tmp/connector-socket/socket" +# ip_addr = "127.0.0.1" +# ssl_sock = context.wrap_socket( +# socket.create_connection((ip_addr, 3307)), +# server_hostname=ip_addr, +# ) +# loop = asyncio.get_running_loop() +# task = start_local_proxy(ssl_sock, socket_path, loop) +# # patch db connection creation +# with patch("google.cloud.sql.connector.proxy.start_local_proxy") as mock_proxy: +# with patch("google.cloud.sql.connector.psycopg.connect") as mock_connect: +# mock_connect.return_value = True +# mock_proxy.return_value = task +# connection = await connector.connect_async( +# "test-project:test-region:test-instance", +# "psycopg", +# user="my-user", +# password="my-pass", +# db="my-db", +# local_socket_path=socket_path, +# ) +# # verify connector called local proxy +# mock_connect.assert_called_once() +# mock_proxy.assert_called_once() +# assert connection is True - proxy_task = asyncio.gather(task) - try: - await asyncio.wait_for(proxy_task, timeout=0.1) - except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): - pass # This task runs forever so it is expected to throw this exception +# proxy_task = asyncio.gather(task) +# try: +# await asyncio.wait_for(proxy_task, timeout=0.1) +# except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): +# pass # This task runs forever so it is expected to throw this exception @pytest.mark.asyncio From 1eb6af21dc2056d7d78021cad697f1add8e2d958 Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Thu, 2 Oct 2025 14:51:13 -0600 Subject: [PATCH 12/14] feat: Add proxy server and fix all unit tests --- .gitignore | 1 + google/cloud/sql/connector/connector.py | 270 ++++++++++------ google/cloud/sql/connector/enums.py | 4 +- google/cloud/sql/connector/proxy.py | 308 ++++++++++++------ pyproject.toml | 4 + tests/conftest.py | 200 +++++++++--- tests/system/test_psycopg_connection.py | 8 +- tests/unit/test_connector.py | 215 +++++++++---- tests/unit/test_proxy.py | 407 +++++++++++++++++++++--- 9 files changed, 1062 insertions(+), 355 deletions(-) diff --git a/.gitignore b/.gitignore index 9f449ce4a..07f89e077 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ dist/ .idea .coverage sponge_log.xml +*.iml diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 06c6fb98f..61a662080 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -20,7 +20,6 @@ from functools import partial import logging import os -import socket from threading import Thread from types import TracebackType from typing import Any, Callable, Optional, Union @@ -35,10 +34,9 @@ from google.cloud.sql.connector.enums import RefreshStrategy from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.lazy import LazyRefreshCache -import google.cloud.sql.connector.local_unix_socket as local_unix_socket from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 -from google.cloud.sql.connector.proxy import Proxy +import google.cloud.sql.connector.proxy as proxy import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds from google.cloud.sql.connector.resolver import DefaultResolver @@ -216,6 +214,108 @@ def __init__( def universe_domain(self) -> str: return self._universe_domain or _DEFAULT_UNIVERSE_DOMAIN + async def _get_cache( + self, + instance_connection_string: str, + enable_iam_auth: bool, + ip_type: IPTypes, + driver: str | None, + ) -> MonitoredCache: + """Helper function to get instance's cache from Connector cache.""" + + # resolve instance connection name + conn_name = await self._resolver.resolve(instance_connection_string) + cache_key = (str(conn_name), enable_iam_auth) + + # if cache entry doesn't exist or is closed, create it + if cache_key not in self._cache or self._cache[cache_key].closed: + # if lazy refresh, init keys now + if self._refresh_strategy == RefreshStrategy.LAZY and self._keys is None: + self._keys = asyncio.create_task(generate_keys()) + # create cache + if self._refresh_strategy == RefreshStrategy.LAZY: + logger.debug( + f"['{conn_name}']: Refresh strategy is set to lazy refresh" + ) + cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache( + conn_name, + self._init_client(driver), + self._keys, # type: ignore + enable_iam_auth, + ) + else: + logger.debug( + f"['{conn_name}']: Refresh strategy is set to background refresh" + ) + cache = RefreshAheadCache( + conn_name, + self._init_client(driver), + self._keys, # type: ignore + enable_iam_auth, + ) + # wrap cache as a MonitoredCache + monitored_cache = MonitoredCache( + cache, + self._failover_period, + self._resolver, + ) + logger.debug(f"['{conn_name}']: Connection info added to cache") + self._cache[cache_key] = monitored_cache + + monitored_cache = self._cache[(str(conn_name), enable_iam_auth)] + + # Check that the information is valid and matches the driver and db type + try: + conn_info = await monitored_cache.connect_info() + # validate driver matches intended database engine + if driver: + DriverMapping.validate_engine(driver, conn_info.database_version) + if ip_type: + conn_info.get_preferred_ip(ip_type) + except Exception: + await self._remove_cached(str(conn_name), enable_iam_auth) + raise + + return monitored_cache + + async def connect_socket_async( + self, + instance_connection_string: str, + protocol_fn: Callable[[], asyncio.Protocol], + **kwargs: Any, + ) -> tuple[asyncio.Transport, asyncio.Protocol]: + """Helper function to connect to a Cloud SQL instance and return a socket.""" + + enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) + ip_type = kwargs.pop("ip_type", self._ip_type) + driver = kwargs.pop("driver", None) + # if ip_type is str, convert to IPTypes enum + if isinstance(ip_type, str): + ip_type = IPTypes._from_str(ip_type) + + monitored_cache = await self._get_cache( + instance_connection_string, enable_iam_auth, ip_type, driver + ) + + try: + conn_info = await monitored_cache.connect_info() + ctx = await conn_info.create_ssl_context(enable_iam_auth) + ip_address = conn_info.get_preferred_ip(ip_type) + tx, p = await self._loop.create_connection( + protocol_fn, host=ip_address, port=3307, ssl=ctx + ) + except Exception as ex: + logger.exception("exception starting tls protocol", exc_info=ex) + # with an error from Cloud SQL Admin API call or IP type, invalidate + # the cache and re-raise the error + await self._remove_cached( + instance_connection_string, + enable_iam_auth, + ) + raise + + return tx, p + def connect( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: @@ -233,7 +333,7 @@ def connect( Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect - with. Supported drivers are pymysql, pg8000, local_unix_socket, and pytds. + with. Supported drivers are pymysql, pg8000, psycopg, and pytds. **kwargs: Any driver-specific arguments to pass to the underlying driver .connect call. @@ -251,6 +351,18 @@ def connect( ) return connect_future.result() + def _init_client(self, driver: Optional[str]) -> CloudSQLClient: + """Lazy initialize the client, setting the driver name in the user agent string.""" + if self._client is None: + self._client = CloudSQLClient( + self._sqladmin_api_endpoint, + self._quota_project, + self._credentials, + user_agent=self._user_agent, + driver=driver, + ) + return self._client + async def connect_async( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: @@ -269,7 +381,7 @@ async def connect_async( Example: "my-project:us-central1:my-instance" driver (str): A string representing the database driver to connect - with. Supported drivers are pymysql, asyncpg, pg8000, local_unix_socket, and + with. Supported drivers are pymysql, asyncpg, pg8000, psycopg, and pytds. **kwargs: Any driver-specific arguments to pass to the underlying @@ -282,66 +394,19 @@ async def connect_async( ValueError: Connection attempt with built-in database authentication and then subsequent attempt with IAM database authentication. KeyError: Unsupported database driver Must be one of pymysql, asyncpg, - pg8000, local_unix_socket, and pytds. + pg8000, psycopg, and pytds. """ - if self._keys is None: - self._keys = asyncio.create_task(generate_keys()) - if self._client is None: - # lazy init client as it has to be initialized in async context - self._client = CloudSQLClient( - self._sqladmin_api_endpoint, - self._quota_project, - self._credentials, - user_agent=self._user_agent, - driver=driver, - ) - enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) - - conn_name = await self._resolver.resolve(instance_connection_string) - # Cache entry must exist and not be closed - if (str(conn_name), enable_iam_auth) in self._cache and not self._cache[ - (str(conn_name), enable_iam_auth) - ].closed: - monitored_cache = self._cache[(str(conn_name), enable_iam_auth)] - else: - if self._refresh_strategy == RefreshStrategy.LAZY: - logger.debug( - f"['{conn_name}']: Refresh strategy is set to lazy refresh" - ) - cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache( - conn_name, - self._client, - self._keys, - enable_iam_auth, - ) - else: - logger.debug( - f"['{conn_name}']: Refresh strategy is set to backgound refresh" - ) - cache = RefreshAheadCache( - conn_name, - self._client, - self._keys, - enable_iam_auth, - ) - # wrap cache as a MonitoredCache - monitored_cache = MonitoredCache( - cache, - self._failover_period, - self._resolver, - ) - logger.debug(f"['{conn_name}']: Connection info added to cache") - self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache + self._init_client(driver) + # Map drivers to connect functions connect_func = { "pymysql": pymysql.connect, "pg8000": pg8000.connect, - "local_unix_socket": local_unix_socket.connect, "asyncpg": asyncpg.connect, "pytds": pytds.connect, } - # only accept supported database drivers + # Only accept supported database drivers try: connector: Callable = connect_func[driver] # type: ignore except KeyError: @@ -351,6 +416,7 @@ async def connect_async( # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes._from_str(ip_type) + enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) kwargs["timeout"] = kwargs.get("timeout", self._timeout) # Host and ssl options come from the certificates and metadata, so we don't @@ -359,58 +425,45 @@ async def connect_async( kwargs.pop("ssl", None) kwargs.pop("port", None) - # attempt to get connection info for Cloud SQL instance + monitored_cache = await self._get_cache( + instance_connection_string, enable_iam_auth, ip_type, driver + ) + conn_info = await monitored_cache.connect_info() + ip_address = conn_info.get_preferred_ip(ip_type) + try: - conn_info = await monitored_cache.connect_info() - # validate driver matches intended database engine - DriverMapping.validate_engine(driver, conn_info.database_version) - ip_address = conn_info.get_preferred_ip(ip_type) - except Exception: - # with an error from Cloud SQL Admin API call or IP type, invalidate - # the cache and re-raise the error - await self._remove_cached(str(conn_name), enable_iam_auth) - raise - logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") - # format `user` param for automatic IAM database authn - if enable_iam_auth: - formatted_user = format_database_user( - conn_info.database_version, kwargs["user"] - ) - if formatted_user != kwargs["user"]: - logger.debug( - f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}" + # format `user` param for automatic IAM database authn + if enable_iam_auth: + formatted_user = format_database_user( + conn_info.database_version, kwargs["user"] ) - kwargs["user"] = formatted_user - try: + if formatted_user != kwargs["user"]: + logger.debug( + f"['{instance_connection_string}']: " + "Truncated IAM database username from " + f"{kwargs['user']} to {formatted_user}" + ) + kwargs["user"] = formatted_user + + ctx = await conn_info.create_ssl_context(enable_iam_auth) # async drivers are unblocking and can be awaited directly if driver in ASYNC_DRIVERS: - return await connector( - ip_address, - await conn_info.create_ssl_context(enable_iam_auth), - **kwargs, + return await connector(ip_address, ctx, **kwargs) + else: + # Synchronous drivers are blocking and run using executor + tx, _ = await self.connect_socket_async( + instance_connection_string, asyncio.Protocol, **kwargs ) - # Create socket with SSLContext for sync drivers - ctx = await conn_info.create_ssl_context(enable_iam_auth) - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) - - # If this connection was opened using a domain name, then store it - # for later in case we need to forcibly close it on failover. - if conn_info.conn_name.domain_name: - monitored_cache.sockets.append(sock) - # Synchronous drivers are blocking and run using executor - connect_partial = partial( - connector, - ip_address, - sock, - **kwargs, - ) - return await self._loop.run_in_executor(None, connect_partial) + # See https://docs.python.org/3/library/asyncio-protocol.html#asyncio.BaseTransport.get_extra_info + sock = tx.get_extra_info("ssl_object") + connect_partial = partial(connector, ip_address, sock, **kwargs) + return await self._loop.run_in_executor(None, connect_partial) except Exception: # with any exception, we attempt a force refresh, then throw the error + monitored_cache = await self._get_cache( + instance_connection_string, enable_iam_auth, ip_type, driver + ) await monitored_cache.force_refresh() raise @@ -449,7 +502,7 @@ async def start_unix_socket_proxy_async( ) await proxy_instance.start() self._proxies.append(proxy_instance) - + async def _remove_cached( self, instance_connection_string: str, enable_iam_auth: bool ) -> None: @@ -508,13 +561,14 @@ def close(self) -> None: async def close_async(self) -> None: """Helper function to cancel the cache's tasks and close aiohttp.ClientSession.""" + # close all proxies + if self._proxies: + await asyncio.gather(*[proxy.close() for proxy in self._proxies]) await asyncio.gather(*[cache.close() for cache in self._cache.values()]) - await asyncio.wait_for(asyncio.gather(*[ proxy.close_async() for proxy in self._proxies]), timeout=2.0) if self._client: await self._client.close() - async def create_async_connector( ip_type: str | IPTypes = IPTypes.PUBLIC, enable_iam_auth: bool = False, @@ -604,3 +658,13 @@ async def create_async_connector( resolver=resolver, failover_period=failover_period, ) + + +class ConnectorSocketFactory(proxy.ServerConnectionFactory): + def __init__(self, connector:Connector, instance_connection_string:str, **kwargs): + self._connector = connector + self._instance_connection_string = instance_connection_string + self._connect_args=kwargs + + async def connect(self, protocol_fn: Callable[[], asyncio.Protocol]): + await self._connector.connect_socket_async(self._instance_connection_string, protocol_fn, **self._connect_args) \ No newline at end of file diff --git a/google/cloud/sql/connector/enums.py b/google/cloud/sql/connector/enums.py index 7bde045ed..5926b75a2 100644 --- a/google/cloud/sql/connector/enums.py +++ b/google/cloud/sql/connector/enums.py @@ -62,7 +62,7 @@ class DriverMapping(Enum): ASYNCPG = "POSTGRES" PG8000 = "POSTGRES" - LOCAL_UNIX_SOCKET = "ANY" + PSYCOPG = "POSTGRES" PYMYSQL = "MYSQL" PYTDS = "SQLSERVER" @@ -79,7 +79,7 @@ def validate_engine(driver: str, engine_version: str) -> None: the given engine. """ mapping = DriverMapping[driver.upper()] - if not mapping.value == "ANY" and not engine_version.startswith(mapping.value): + if not engine_version.startswith(mapping.value): raise IncompatibleDriverError( f"Database driver '{driver}' is incompatible with database " f"version '{engine_version}'. Given driver can " diff --git a/google/cloud/sql/connector/proxy.py b/google/cloud/sql/connector/proxy.py index e5d9ab536..99a121782 100644 --- a/google/cloud/sql/connector/proxy.py +++ b/google/cloud/sql/connector/proxy.py @@ -14,127 +14,253 @@ limitations under the License. """ +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod import asyncio +from functools import partial +import logging import os from pathlib import Path -import socket -import selectors -import ssl +from typing import Callable, List -from google.cloud.sql.connector.exceptions import LocalProxyStartupError +logger = logging.getLogger(name=__name__) -LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760 +class BaseProxyProtocol(asyncio.Protocol): + """ + A protocol to proxy data between two transports. + """ -class Proxy: - """Creates an "accept loop" async task which will open the unix server socket and listen for new connections.""" + def __init__(self, proxy: Proxy): + super().__init__() + self.proxy = proxy + self._buffer = bytearray() + self._target: asyncio.Transport | None = None + self.transport: asyncio.Transport | None = None + self._cached: List[bytes] = [] + logger.debug(f"__init__ {self}") + + def connection_made(self, transport): + logger.debug(f"connection_made {self}") + self.transport = transport + + def data_received(self, data): + if self._target is None: + self._cached.append(data) + else: + self._target.write(data) + + def set_target(self, target: asyncio.Transport): + logger.debug(f"set_target {self}") + self._target = target + if self._cached: + self._target.writelines(self._cached) + self._cached = [] + + def eof_received(self): + logger.debug(f"eof_received {self}") + if self._target is not None: + self._target.write_eof() + + def connection_lost(self, exc: Exception | None): + logger.debug(f"connection_lost {exc} {self}") + if self._target is not None: + self._target.close() + + +class ProxyClientConnection: + """ + Holds all of the tasks and details for a client proxy + """ def __init__( self, - connector, - instance_connection_string: str, - socket_path: str, - loop: asyncio.AbstractEventLoop, - **kwargs - ) -> None: - """Keeps track of all the async tasks and starts the accept loop for new connections. - - Args: - connector (Connector): The instance where this Proxy class was created. + client_transport: asyncio.Transport, + client_protocol: ClientToServerProtocol, + ): + self.client_transport = client_transport + self.client_protocol = client_protocol + self.server_transport: asyncio.Transport | None = None + self.server_protocol: ServerToClientProtocol | None = None + self.task: asyncio.Task | None = None + + def close(self): + logger.debug(f"closing {self}") + if self.client_transport is not None: + self._close_transport(self.client_transport) + if self.server_transport is not None: + self._close_transport(self.server_transport) + + def _close_transport(self, transport:asyncio.Transport): + if transport.is_closing(): + return + if transport.can_write_eof(): + transport.write_eof() + else: + transport.close() + +class ClientToServerProtocol(BaseProxyProtocol): + """ + Protocol to copy bytes from the unix socket client to the database server + """ + + def __init__(self, proxy: Proxy): + super().__init__(proxy) + self._buffer = bytearray() + self._target: asyncio.Transport | None = None + logger.debug(f"__init__ {self}") + + def connection_made(self, transport): + # When a connection is made, open the server connection + super().connection_made(transport) + self.proxy._handle_client_connection(transport, self) - instance_connection_string (str): The instance connection name of the - Cloud SQL instance to connect to. Takes the form of - "project-id:region:instance-name" - Example: "my-project:us-central1:my-instance" +class ServerToClientProtocol(BaseProxyProtocol): + """ + Protocol to copy bytes from the database server to the client socket + """ - socket_path (str): A system path that is going to be used to store the socket. + def __init__(self, proxy: Proxy, cconn: ProxyClientConnection): + super().__init__(proxy) + self._buffer = bytearray() + self._target = cconn.client_transport + self._client_protocol = cconn.client_protocol + logger.debug(f"__init__ {self}") - loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. + def connection_made(self, transport): + super().connection_made(transport) + self._client_protocol.set_target(transport) - **kwargs: Any driver-specific arguments to pass to the underlying - driver .connect call. + def connection_lost(self, exc: Exception | None): + super().connection_lost(exc) + self.proxy._handle_server_connection_lost() + +class ServerConnectionFactory(ABC): + """ + ServerConnectionFactory is an abstract class that provides connections to the service. + """ + @abstractmethod + async def connect(self, protocol_fn: Callable[[], asyncio.Protocol]): + """ + Establishes a connection to the server and configures it to use the protocol + returned from protocol_fn, with asyncio.EventLoop.create_connection(). + :param protocol_fn: the protocol function + :return: None """ - self._connection_tasks = [] - self._addr = instance_connection_string - self._kwargs = kwargs - self._connector = connector + pass - unix_socket = None +class Proxy: + """ + A class to represent a local Unix socket proxy for a Cloud SQL instance. + This class manages a Unix socket that listens for incoming connections and + proxies them to a Cloud SQL instance. + """ - try: - path_parts = socket_path.rsplit('/', 1) - parent_directory = '/'.join(path_parts[:-1]) + def __init__( + self, + unix_socket_path: str, + server_connection_factory: ServerConnectionFactory, + loop: asyncio.AbstractEventLoop, + ): + """ + Creates a new Proxy + :param unix_socket_path: the path to listen for the proxy connection + :param loop: The event loop + :param instance_connect: A function that will establish the async connection to the server + + The instance_connect function is an asynchronous function that should set up a new connection. + It takes one argument - another function that + """ + self.unix_socket_path = unix_socket_path + self.alive = True + self._loop = loop + self._server: asyncio.AbstractServer | None = None + self._client_connections: set[ProxyClientConnection] = set() + self._server_connection_factory = server_connection_factory - desired_path = Path(parent_directory) - desired_path.mkdir(parents=True, exist_ok=True) + async def start(self) -> None: + """Starts the Unix socket server.""" + if os.path.exists(self.unix_socket_path): + os.remove(self.unix_socket_path) - if os.path.exists(socket_path): - os.remove(socket_path) + parent_dir = Path(self.unix_socket_path).parent + parent_dir.mkdir(parents=True, exist_ok=True) - unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + def new_protocol() -> ClientToServerProtocol: + return ClientToServerProtocol(self) - unix_socket.bind(socket_path) - unix_socket.listen(1) - unix_socket.setblocking(False) - os.chmod(socket_path, 0o600) + logger.debug(f"Socket path: {self.unix_socket_path}") + self._server = await self._loop.create_unix_server( + new_protocol, path=self.unix_socket_path + ) + self._loop.create_task(self._server.serve_forever()) - self._task = loop.create_task(self.accept_loop(unix_socket, socket_path, loop)) + def _handle_client_connection( + self, + client_transport: asyncio.Transport, + client_protocol: ClientToServerProtocol, + ) -> None: + """ + Register a new client connection and initiate the task to create a database connection. + This is called by ClientToServerProtocol.connection_made - except Exception: - raise LocalProxyStartupError( - 'Local UNIX socket based proxy was not able to get started.' - ) + :param client_transport: the client transport for the client unix socket + :param client_protocol: the instance for the + :return: None + """ + conn = ProxyClientConnection(client_transport, client_protocol) + self._client_connections.add(conn) + conn.task = self._loop.create_task(self._create_db_instance_connection(conn)) + conn.task.add_done_callback(lambda _: self._client_connections.discard(conn)) - async def accept_loop( + def _handle_server_connection_lost( self, - unix_socket, - socket_path: str, - loop: asyncio.AbstractEventLoop - ) -> asyncio.Task: - """Starts a UNIX based local proxy for transporting messages through - the SSL Socket, and waits until there is a new connection to accept, to register it - and keep track of it. + ) -> None: + """ + Closes the proxy server if the connection to the server is lost - Args: - socket_path: A system path that is going to be used to store the socket. + :return: None + """ + logger.debug(f"Closing proxy server due to lost connection") + self._loop.create_task(self.close()) + + async def _create_db_instance_connection(self, conn: ProxyClientConnection) -> None: + """ + Manages a single proxy connection from a client to the Cloud SQL instance. + """ + try: + logger.debug("_proxy_connection() started") + new_protocol = partial(ServerToClientProtocol, self, conn) + + # Establish connection to the database + await self._server_connection_factory.connect(new_protocol) + logger.debug("_proxy_connection() succeeded") - loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. + except Exception as e: + logger.error(f"Error handling proxy connection: {e}") + await self.close() + raise e - Raises: - LocalProxyStartupError: Local UNIX socket based proxy was not able to - get started. + async def close(self) -> None: """ - print("on accept loop") - while True: - client, _ = await loop.sock_accept(unix_socket) - self._connection_tasks.append(loop.create_task(self.client_socket(client, unix_socket, socket_path, loop))) + Shuts down the proxy server and cleans up resources. + """ + logger.info(f"Closing Unix socket proxy at {self.unix_socket_path}") - async def close_async(self): - proxy_task = asyncio.gather(self._task) - try: - await asyncio.wait_for(proxy_task, timeout=0.1) - except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): - pass # This task runs forever so it is expected to throw this exception + if self._server: + self._server.close() + await self._server.wait_closed() + if self._client_connections: + for conn in list(self._client_connections): + conn.close() + await asyncio.wait([c.task for c in self._client_connections if c.task is not None], timeout=0.1) - async def client_socket( - self, client, unix_socket, socket_path, loop - ): - try: - ssl_sock = self._connector.connect( - self._addr, - 'local_unix_socket', - **self._kwargs - ) - while True: - data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE) - if not data: - client.close() - break - ssl_sock.sendall(data) - response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE) - await loop.sock_sendall(client, response) - finally: - client.close() - os.remove(socket_path) # Clean up the socket file + if os.path.exists(self.unix_socket_path): + os.remove(self.unix_socket_path) + + logger.info(f"Unix socket proxy for {self.unix_socket_path} closed.") + self.alive = False \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cbf0dd10f..8ffce4d63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,10 @@ exclude = ['docs/*', 'samples/*'] [tool.pytest.ini_options] asyncio_mode = "auto" +log_cli = true +log_cli_level = "DEBUG" +log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" +log_cli_date_format = "%Y-%m-%d %H:%M:%S.%f" [tool.ruff.lint] extend-select = ["I"] diff --git a/tests/conftest.py b/tests/conftest.py index 83d7a78f3..54836ee84 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,8 @@ """ import asyncio +from asyncio import Server +import logging import os import socket import ssl @@ -36,6 +38,7 @@ from google.cloud.sql.connector.utils import write_to_file SCOPES = ["https://www.googleapis.com/auth/sqlservice.admin"] +logger = logging.getLogger(name=__name__) def pytest_addoption(parser: Any) -> None: @@ -84,55 +87,138 @@ def fake_credentials() -> FakeCredentials: return FakeCredentials() -async def start_proxy_server(instance: FakeCSQLInstance) -> None: +async def start_proxy_server_async( + instance: FakeCSQLInstance, with_read_write: bool +) -> Server: """Run local proxy server capable of performing mTLS""" ip_address = "127.0.0.1" port = 3307 - # create socket - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - # create SSL/TLS context - context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - context.minimum_version = ssl.TLSVersion.TLSv1_3 - # tmpdir and its contents are automatically deleted after the CA cert - # and cert chain are loaded into the SSLcontext. The values - # need to be written to files in order to be loaded by the SSLContext - server_key_bytes = instance.server_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), + logger.debug("start_proxy_server_async started") + + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.minimum_version = ssl.TLSVersion.TLSv1_3 + # tmpdir and its contents are automatically deleted after the CA cert + # and cert chain are loaded into the SSLcontext. The values + # need to be written to files in order to be loaded by the SSLContext + server_key_bytes = instance.server_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + async with TemporaryDirectory() as tmpdir: + server_filename, _, key_filename = await write_to_file( + tmpdir, instance.server_cert_pem, "", server_key_bytes ) - async with TemporaryDirectory() as tmpdir: - server_filename, _, key_filename = await write_to_file( - tmpdir, instance.server_cert_pem, "", server_key_bytes - ) - context.load_cert_chain(server_filename, key_filename) - # allow socket to be re-used - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # bind socket to Cloud SQL proxy server port on localhost - sock.bind((ip_address, port)) - # listen for incoming connections - sock.listen(5) - - with context.wrap_socket(sock, server_side=True) as ssock: - while True: - conn, _ = ssock.accept() - conn.close() + context.load_cert_chain(server_filename, key_filename) + + async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + logger.debug("Received fake connection") + if with_read_write: + line = await reader.readline() + logger.debug(f"Received request {line}") + writer.write("world\n".encode("utf-8")) + await writer.drain() + logger.debug("Wrote response") + if writer.can_write_eof(): + writer.write_eof() + logger.debug("Closing connection") + writer.close() + await writer.wait_closed() + logger.debug("Closed connection") + + server = await asyncio.start_server( + handler, host=ip_address, port=port, ssl=context + ) + logger.debug("Listening on 127.0.0.1:3307") + asyncio.create_task(server.serve_forever()) + return server -@pytest.fixture(scope="session") -def proxy_server(fake_instance: FakeCSQLInstance) -> None: - """Run local proxy server capable of performing mTLS""" - thread = Thread( - target=asyncio.run, - args=( - start_proxy_server( - fake_instance, - ), - ), - daemon=True, +@pytest.fixture(scope="function") +def proxy_server_async(fake_instance: FakeCSQLInstance): + # Create an event loop in a different thread for the server + loop = asyncio.new_event_loop() + + def f(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + logger.debug("exiting thread") + + t = Thread(target=f, args=(loop,)) + t.start() + t.join(1) + + # Submit the server task to the thread + server_fut = asyncio.run_coroutine_threadsafe( + start_proxy_server_async(fake_instance, True), loop ) - thread.start() - thread.join(1.0) # add a delay to allow the proxy server to start + while not server_fut.done(): + t.join(0.1) + logger.debug("proxy_server_async server started") + yield + logger.debug("proxy_server_async fixture done") + + logger.debug("proxy_server_async fixture cleanup") + + # Stop the server after the test is complete + async def stop_server(): + logger.debug("inside_cleanup closing server") + server_fut.result().close() + loop.shutdown_asyncgens() + loop.stop() + logger.debug("inside_cleanup end") + + logger.debug("cleanup starting") + asyncio.run_coroutine_threadsafe(stop_server(), loop) + logger.debug("cleanup done") + while loop.is_running(): + t.join(0.1) + logger.debug("loop is not running") + loop.close() + t.join(1) + + +@pytest.fixture(scope="function") +def proxy_server(fake_instance: FakeCSQLInstance): + # Create an event loop in a different thread for the server + loop = asyncio.new_event_loop() + + def f(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + logger.debug("exiting thread") + + t = Thread(target=f, args=(loop,)) + t.start() + + # Submit the server task to the thread + server_fut = asyncio.run_coroutine_threadsafe( + start_proxy_server_async(fake_instance, False), loop + ) + while not server_fut.done(): + t.join(0.1) + logger.debug("proxy_server_async server started") + yield + logger.debug("proxy_server_async fixture done") + + logger.debug("proxy_server_async fixture cleanup") + + # Stop the server after the test is complete + async def stop_server(): + logger.debug("inside_cleanup closing server") + server_fut.result().close() + loop.shutdown_asyncgens() + loop.stop() + logger.debug("inside_cleanup end") + + logger.debug("cleanup starting") + asyncio.run_coroutine_threadsafe(stop_server(), loop) + logger.debug("cleanup done") + while loop.is_running(): + t.join(0.1) + logger.debug("loop is not running") + loop.close() @pytest.fixture @@ -191,3 +277,33 @@ async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache ) yield cache await cache.close() + + +@pytest.fixture +def connected_socket_pair() -> tuple[socket.socket, socket.socket]: + """A fixture that provides a pair of connected sockets.""" + server, client = socket.socketpair() + yield server, client + server.close() + client.close() + + +@pytest.fixture +async def echo_server() -> AsyncGenerator[tuple[str, int], None]: + """A fixture that starts an asyncio echo server.""" + + async def handle_echo(reader, writer): + while True: + data = await reader.read(100) + if not data: + break + writer.write(data) + await writer.drain() + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(handle_echo, "127.0.0.1", 0) + addr = server.sockets[0].getsockname() + yield addr + server.close() + await server.wait_closed() \ No newline at end of file diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py index 6f0e07c42..1fed06356 100644 --- a/tests/system/test_psycopg_connection.py +++ b/tests/system/test_psycopg_connection.py @@ -29,7 +29,7 @@ SERVER_PROXY_PORT = 3307 -def create_sqlalchemy_engine( +async def create_sqlalchemy_engine( instance_connection_name: str, user: str, password: str, @@ -83,7 +83,7 @@ def create_sqlalchemy_engine( connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver) unix_socket_folder = "/tmp/conn" unix_socket_path = f"{unix_socket_folder}/.s.PGSQL.3307" - connector.start_unix_socket_proxy_async( + await connector.start_unix_socket_proxy_async( instance_connection_name, unix_socket_path, ip_type=ip_type, # can be "public", "private" or "psc" @@ -107,7 +107,7 @@ def create_sqlalchemy_engine( # [END cloud_sql_connector_postgres_psycopg] -def test_psycopg_connection() -> None: +async def test_psycopg_connection() -> None: """Basic test to get time from database.""" inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"] user = os.environ["POSTGRES_USER"] @@ -115,7 +115,7 @@ def test_psycopg_connection() -> None: db = os.environ["POSTGRES_DB"] ip_type = os.environ.get("IP_TYPE", "public") - engine, connector = create_sqlalchemy_engine( + engine, connector = await create_sqlalchemy_engine( inst_conn_name, user, password, db, ip_type ) with engine.connect() as conn: diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 47f242adc..2a7b64c82 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -1,12 +1,9 @@ """ Copyright 2021 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -15,9 +12,9 @@ """ import asyncio +import logging import os import socket -import ssl from typing import Union from aiohttp import ClientResponseError @@ -33,22 +30,29 @@ from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import IncompatibleDriverError from google.cloud.sql.connector.instance import RefreshAheadCache -# from google.cloud.sql.connector.proxy import start_local_proxy + +logger = logging.getLogger(name=__name__) @pytest.mark.asyncio async def test_connect_enable_iam_auth_error( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + connected_socket_pair: tuple[socket.socket, socket.socket], ) -> None: """Test that calling connect() with different enable_iam_auth argument values creates two cache entries.""" connect_string = "test-project:test-region:test-instance" + server, client = connected_socket_pair async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: connector._client = fake_client # patch db connection creation - with patch("google.cloud.sql.connector.asyncpg.connect") as mock_connect: + with ( + patch("socket.create_connection", return_value=client), + patch("google.cloud.sql.connector.asyncpg.connect") as mock_connect, + ): mock_connect.return_value = True # connect with enable_iam_auth False connection = await connector.connect_async( @@ -81,6 +85,7 @@ async def test_connect_enable_iam_auth_error( async def test_connect_incompatible_driver_error( fake_credentials: Credentials, fake_client: CloudSQLClient, + proxy_server, ) -> None: """Test that calling connect() with driver that is incompatible with database version throws error.""" @@ -90,14 +95,8 @@ async def test_connect_incompatible_driver_error( ) as connector: connector._client = fake_client # try to connect using pymysql driver to a Postgres database - with pytest.raises(IncompatibleDriverError) as exc_info: + with pytest.raises(IncompatibleDriverError): await connector.connect_async(connect_string, "pymysql") - assert ( - exc_info.value.args[0] - == "Database driver 'pymysql' is incompatible with database version" - " 'POSTGRES_15'. Given driver can only be used with Cloud SQL MYSQL" - " databases." - ) def test_connect_with_unsupported_driver(fake_credentials: Credentials) -> None: @@ -238,13 +237,19 @@ def test_Connector_Init_bad_ip_type(fake_credentials: Credentials) -> None: def test_Connector_connect_bad_ip_type( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + connected_socket_pair: tuple[socket.socket, socket.socket], ) -> None: """Test that Connector.connect errors due to bad ip_type str.""" + server, client = connected_socket_pair with Connector(credentials=fake_credentials) as connector: connector._client = fake_client bad_ip_type = "bad-ip-type" - with pytest.raises(ValueError) as exc_info: + with ( + patch("socket.create_connection", return_value=client), + pytest.raises(ValueError) as exc_info, + ): connector.connect( "test-project:test-region:test-instance", "pg8000", @@ -262,15 +267,21 @@ def test_Connector_connect_bad_ip_type( @pytest.mark.asyncio async def test_Connector_connect_async( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + connected_socket_pair: tuple[socket.socket, socket.socket], ) -> None: """Test that Connector.connect_async can properly return a DB API connection.""" + server, client = connected_socket_pair async with Connector( credentials=fake_credentials, loop=asyncio.get_running_loop() ) as connector: connector._client = fake_client # patch db connection creation - with patch("google.cloud.sql.connector.asyncpg.connect") as mock_connect: + with ( + patch("socket.create_connection", return_value=client), + patch("google.cloud.sql.connector.asyncpg.connect") as mock_connect, + ): mock_connect.return_value = True connection = await connector.connect_async( "test-project:test-region:test-instance", @@ -282,48 +293,6 @@ async def test_Connector_connect_async( # verify connector made connection call assert connection is True -# @pytest.mark.usefixtures("proxy_server") -# @pytest.mark.asyncio -# async def test_Connector_connect_local_proxy( -# fake_credentials: Credentials, fake_client: CloudSQLClient, context: ssl.SSLContext -# ) -> None: -# """Test that Connector.connect can launch start_local_proxy.""" -# async with Connector( -# credentials=fake_credentials, loop=asyncio.get_running_loop() -# ) as connector: -# connector._client = fake_client -# socket_path = "/tmp/connector-socket/socket" -# ip_addr = "127.0.0.1" -# ssl_sock = context.wrap_socket( -# socket.create_connection((ip_addr, 3307)), -# server_hostname=ip_addr, -# ) -# loop = asyncio.get_running_loop() -# task = start_local_proxy(ssl_sock, socket_path, loop) -# # patch db connection creation -# with patch("google.cloud.sql.connector.proxy.start_local_proxy") as mock_proxy: -# with patch("google.cloud.sql.connector.psycopg.connect") as mock_connect: -# mock_connect.return_value = True -# mock_proxy.return_value = task -# connection = await connector.connect_async( -# "test-project:test-region:test-instance", -# "psycopg", -# user="my-user", -# password="my-pass", -# db="my-db", -# local_socket_path=socket_path, -# ) -# # verify connector called local proxy -# mock_connect.assert_called_once() -# mock_proxy.assert_called_once() -# assert connection is True - -# proxy_task = asyncio.gather(task) -# try: -# await asyncio.wait_for(proxy_task, timeout=0.1) -# except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): -# pass # This task runs forever so it is expected to throw this exception - @pytest.mark.asyncio async def test_create_async_connector(fake_credentials: Credentials) -> None: @@ -359,7 +328,9 @@ def test_Connector_close_called_multiple_times(fake_credentials: Credentials) -> async def test_Connector_remove_cached_bad_instance( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + proxy_server, ) -> None: """When a Connector attempts to retrieve connection info for a non-existent instance, it should delete the instance from @@ -384,7 +355,9 @@ async def test_Connector_remove_cached_bad_instance( async def test_Connector_remove_cached_no_ip_type( - fake_credentials: Credentials, fake_client: CloudSQLClient + fake_credentials: Credentials, + fake_client: CloudSQLClient, + proxy_server, ) -> None: """When a Connector attempts to connect and preferred IP type is not present, it should delete the instance from the cache and ensure no background refresh @@ -513,3 +486,121 @@ def test_configured_quota_project_env_var( assert connector._quota_project == quota_project # unset env var del os.environ["GOOGLE_CLOUD_QUOTA_PROJECT"] + + +@pytest.mark.asyncio +async def test_Connector_start_unix_socket_proxy_async( + fake_credentials: Credentials, + fake_client: CloudSQLClient, + proxy_server_async: None, +) -> None: + """Test that Connector.connect_async can properly return a DB API connection.""" + async with Connector( + credentials=fake_credentials, loop=asyncio.get_running_loop() + ) as connector: + connector._client = fake_client + + # Open proxy connection + # start the proxy server + await connector.start_unix_socket_proxy_async( + "test-project:test-region:test-instance", + "/tmp/csql-python/proxytest/.s.PGSQL.5432", + driver="asyncpg", + user="my-user", + password="my-pass", + db="my-db", + ) + # Wait for server to start + await asyncio.sleep(0.5) + + reader, writer = await asyncio.open_unix_connection( + "/tmp/csql-python/proxytest/.s.PGSQL.5432" + ) + writer.write("hello\n".encode()) + await writer.drain() + await asyncio.sleep(0.5) + msg = await reader.readline() + assert msg.decode("utf-8") == "world\n" + + +class TestProtocol(asyncio.Protocol): + """ + A protocol to proxy data between two transports. + """ + + def __init__(self): + self._buffer = bytearray() + logger.debug(f"__init__ {self}") + self.received = bytearray() + self.connected = asyncio.Future() + self.future = asyncio.Future() + + def data_received(self, data): + logger.debug("received {!r}".format(data)) + self.received = data + + def connection_made(self, transport): + logger.debug(f"connection_made called {self}") + self.transport = transport + if not self.connected.done(): + self.connected.set_result(True) + # Write the request and EOF + transport.write("hello\n".encode()) + # if transport.can_write_eof(): + # transport.write_eof() + logger.debug(f"connection_made done, wrote hello{self}") + + def eof_received(self) -> bool | None: + logger.debug(f"eof_received {self}") + # If this has received data, then close. + if len(self.received) > 0: + self.transport.close() + if not self.connected.done(): + self.connected.set_result(True) + if not self.future.done(): + self.future.set_result(True) + return True + + def connection_lost(self, exc: Exception | None) -> None: + logger.debug(f"connection_lost {exc} {self}") + self.transport.abort() + if not self.connected.done(): + self.connected.set_result(True) + if not self.future.done(): + self.future.set_result(True) + super().connection_lost(exc) + + +@pytest.mark.asyncio +async def test_Connector_connect_socket_async( + fake_credentials: Credentials, + fake_client: CloudSQLClient, + proxy_server_async: None, +) -> None: + """Test that Connector.connect_async can properly return a DB API connection.""" + async with Connector( + credentials=fake_credentials, loop=asyncio.get_running_loop() + ) as connector: + logger.info("client socket opening") + connector._client = fake_client + p = TestProtocol() + + # Open proxy connection + # start the proxy server + future = connector.connect_socket_async( + "test-project:test-region:test-instance", + lambda: p, + driver="asyncpg", + user="my-user", + password="my-pass", + db="my-db", + ) + logger.info("client socket opening") + await future + logger.info("client socket opened") + await p.connected + logger.info("client socket connected") + await p.future + logger.info("client socket done") + + assert p.received.decode() == "world\n" \ No newline at end of file diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py index c1143f19f..0c179186d 100644 --- a/tests/unit/test_proxy.py +++ b/tests/unit/test_proxy.py @@ -15,61 +15,366 @@ """ import asyncio -import socket -import ssl -from typing import Any +import os +import shutil +import tempfile +from unittest.mock import MagicMock -from mock import Mock import pytest -from google.cloud.sql.connector import proxy +from google.cloud.sql.connector.proxy import Proxy, ServerConnectionFactory + + +@pytest.fixture +def short_tmpdir(): + """Create a temporary directory with a short path.""" + dir_path = tempfile.mkdtemp(dir="/tmp") + yield dir_path + shutil.rmtree(dir_path) -LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760 -@pytest.mark.usefixtures("proxy_server") @pytest.mark.asyncio -async def test_proxy_creates_folder(context: ssl.SSLContext, kwargs: Any) -> None: - """Test to verify that the proxy server is getting back the task.""" - ip_addr = "127.0.0.1" - path = "/tmp/connector-socket/socket" - sock = context.wrap_socket( - socket.create_connection((ip_addr, 3307)), - server_hostname=ip_addr, - ) - loop = asyncio.get_running_loop() - - task = proxy.start_local_proxy(sock, path, loop) - assert (task is not None) - - proxy_task = asyncio.gather(task) - try: - await asyncio.wait_for(proxy_task, timeout=0.1) - except (asyncio.CancelledError, asyncio.TimeoutError, TimeoutError): - pass # This task runs forever so it is expected to throw this exception - -@pytest.mark.usefixtures("proxy_server") +async def test_proxy_creates_folder_and_socket(short_tmpdir): + """ + Test to verify that the Proxy server creates the folder and socket file. + """ + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + connector = MagicMock(spec=ServerConnectionFactory) + proxy = Proxy(socket_path, connector, asyncio.get_event_loop()) + await proxy.start() + + assert os.path.exists(short_tmpdir) + assert os.path.exists(socket_path) + + await proxy.close() + + +# A mock ServerConnectionFactory for testing purposes. +class MockServerConnectionFactory(ServerConnectionFactory): + def __init__(self, loop): + self.server_protocol = None + self.server_transport = None + self.connect_called = asyncio.Event() + self.connect_ran = asyncio.Event() + self.force_connect_error = False + self.loop = loop + self.server_data = bytearray() + + async def connect(self, protocol_fn): + self.connect_called.set() + if self.force_connect_error: + raise Exception("Forced connection error") + + self.server_protocol = protocol_fn() + # Create a mock transport for server-side communication + self.server_transport = MagicMock(spec=asyncio.Transport) + self.server_transport.write.side_effect = self.server_data.extend + self.server_transport.is_closing.return_value = False + + # Simulate connection made for the server protocol + self.server_protocol.connection_made(self.server_transport) + self.connect_ran.set() + return self.server_transport, self.server_protocol + + +# Test fixture for the proxy +@pytest.fixture +async def proxy_server(short_tmpdir): + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + loop = asyncio.get_event_loop() + connector = MockServerConnectionFactory(loop) + proxy = Proxy(socket_path, connector, loop) + await proxy.start() + yield proxy, socket_path, connector + await proxy.close() + + @pytest.mark.asyncio -async def test_local_proxy_communication(context: ssl.SSLContext, kwargs: Any) -> None: - """Test to verify that the communication is getting through.""" - socket_path = "/tmp/connector-socket/socket" - ssl_sock = Mock(spec=ssl.SSLSocket) - loop = asyncio.get_running_loop() - - with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client: - ssl_sock.recv.return_value = b"Received" - - task = proxy.start_local_proxy(ssl_sock, socket_path, loop) - - client.connect(socket_path) - client.sendall(b"Test") - await asyncio.sleep(1) - - ssl_sock.sendall.assert_called_with(b"Test") - response = client.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE) - assert (response == b"Received") - - client.close() - await asyncio.sleep(1) - - proxy_task = asyncio.gather(task) - await asyncio.wait_for(proxy_task, timeout=2) +async def test_proxy_client_to_server(proxy_server): + """ + 1. Create a new proxy. Open a client socket to the proxy. Write data to + the client socket. Read data from the server. Check that the data was + received by the server. + """ + proxy, socket_path, connector = proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + + # wait for server connection to be established + await connector.connect_called.wait() + + # Write data to the client socket + test_data = b"test data from client" + writer.write(test_data) + await writer.drain() + + # Check that the data was received by the server + await asyncio.sleep(0.01) # give event loop a chance to run + assert connector.server_data == test_data + + writer.close() + await writer.wait_closed() + + +@pytest.mark.asyncio +async def test_proxy_server_to_client(proxy_server): + """ + 2. Create a new proxy, Open a client socket. Write data to the server + socket. Read data from the client socket. Check that the data was + received by the client. + """ + proxy, socket_path, connector = proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + + # wait for server connection to be established + await connector.connect_called.wait() + + # Write data from the server to the client + test_data = b"test data from server" + connector.server_protocol.data_received(test_data) + + # Read data from the client socket + received_data = await reader.read(len(test_data)) + + # Check that the data was received by the client + assert received_data == test_data + + writer.close() + await writer.wait_closed() + + +@pytest.mark.asyncio +async def test_proxy_server_connect_fails(proxy_server): + """ + 3. Create a new proxy. Open a client socket. The server socket fails to + connect. Check that the client socket is closed. + """ + proxy, socket_path, connector = proxy_server + connector.force_connect_error = True + + reader, writer = await asyncio.open_unix_connection(socket_path) + + # wait for server connection to be attempted + await connector.connect_called.wait() + + assert os.path.exists(socket_path) == True + + # The client connection should be closed by the proxy + # Reading should return EOF + data = await reader.read(100) + assert data == b"" + + await asyncio.sleep(1) # give proxy a chance to shut down + + assert os.path.exists(socket_path) == False + + +@pytest.mark.asyncio +async def test_proxy_client_closes_connection(proxy_server): + """ + 4. Create a new proxy. Open a client socket. Check that the server + socket connected. Close the client socket. Check that the server socket + is closed gracefully. + """ + proxy, socket_path, connector = proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + + # wait for server connection to be established + await connector.connect_called.wait() + assert connector.server_transport is not None + + # Close the client socket + writer.close() + await writer.wait_closed() + + # Check that the server socket is closed + await asyncio.sleep(0.01) # give event loop a chance to run + connector.server_transport.close.assert_called_once() + + +# +# TCP Server Fixtures and Tests +# + + +@pytest.fixture +async def tcp_echo_server(): + """Fixture to create a TCP echo server.""" + + async def echo(reader, writer): + try: + while not reader.at_eof(): + data = await reader.read(1024) + if not data: + break + writer.write(data) + await writer.drain() + finally: + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(echo, "127.0.0.1", 0) + addr = server.sockets[0].getsockname() + host, port = addr[0], addr[1] + + yield host, port + + server.close() + await server.wait_closed() + + +@pytest.fixture +async def tcp_server_accept_and_close(): + """Fixture to create a TCP server that accepts and immediately closes.""" + + async def accept_and_close(reader, writer): + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(accept_and_close, "127.0.0.1", 0) + addr = server.sockets[0].getsockname() + host, port = addr[0], addr[1] + + yield host, port + + server.close() + await server.wait_closed() + + +class TCPServerConnectionFactory(ServerConnectionFactory): + """A ServerConnectionFactory that connects to a TCP server.""" + + def __init__(self, host, port, loop): + self.host = host + self.port = port + self.loop = loop + self.connect_called = asyncio.Event() + self.connect_ran = asyncio.Event() + self.server_transport: asyncio.Transport | None = None + self.server_protocol: asyncio.Protocol | None = None + + async def connect(self, protocol_fn): + self.connect_called.set() + transport, protocol = await asyncio.wait_for(self.loop.create_connection( + protocol_fn, self.host, self.port, + ), timeout=0.5) + self.server_transport = transport + self.server_protocol = protocol + self.connect_ran.set() + return transport, protocol + + +@pytest.fixture +async def tcp_proxy_server(short_tmpdir, tcp_echo_server): + """Fixture to set up a proxy with a TCP backend.""" + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + loop = asyncio.get_event_loop() + host, port = tcp_echo_server + connector = TCPServerConnectionFactory(host, port, loop) + proxy = Proxy(socket_path, connector, loop) + await proxy.start() + yield proxy, socket_path, connector + await proxy.close() + + +@pytest.fixture +async def tcp_proxy_server_with_closing_backend(short_tmpdir, tcp_server_accept_and_close): + """Fixture to set up a proxy with a TCP backend that closes immediately.""" + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + loop = asyncio.get_event_loop() + host, port = tcp_server_accept_and_close + connector = TCPServerConnectionFactory(host, port, loop) + proxy = Proxy(socket_path, connector, loop) + await proxy.start() + yield proxy, socket_path, connector + await proxy.close() + + +@pytest.fixture +async def tcp_proxy_server_with_no_tcp_server(short_tmpdir): + """Fixture to set up a proxy with a TCP backend that closes immediately.""" + socket_path = os.path.join(short_tmpdir, ".s.PGSQL.5432") + loop = asyncio.get_event_loop() + connector = TCPServerConnectionFactory("localhost", "34532", loop) + proxy = Proxy(socket_path, connector, loop) + await proxy.start() + yield proxy, socket_path, connector + await proxy.close() + + +@pytest.mark.asyncio +async def test_tcp_proxy_echo(tcp_proxy_server): + """ + Tests data flow from client to a TCP server and back. + """ + proxy, socket_path, connector = tcp_proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + await connector.connect_called.wait() + + test_data = b"test data from client" + writer.write(test_data) + await writer.drain() + + # Read echoed data back from the server + received_data = await reader.read(len(test_data)) + assert received_data == test_data + + writer.close() + await writer.wait_closed() + + +@pytest.mark.asyncio +async def test_tcp_proxy_server_connection_refused(tcp_proxy_server_with_no_tcp_server): + """ + Tests that the client socket is closed when TCP connection fails. + """ + proxy, socket_path, connector = tcp_proxy_server_with_no_tcp_server + + reader, writer = await asyncio.open_unix_connection(socket_path) + await connector.connect_called.wait() + test_data = b"test data from client" + writer.write(test_data) + await writer.drain() + + await asyncio.sleep(1.5) + assert os.path.exists(socket_path) == False + + + +@pytest.mark.asyncio +async def test_tcp_proxy_server_unexpected_closed(tcp_proxy_server_with_closing_backend): + """ + Tests that the client socket is closed when TCP connection fails. + """ + proxy, socket_path, connector = tcp_proxy_server_with_closing_backend + + reader, writer = await asyncio.open_unix_connection(socket_path) + await connector.connect_called.wait() + + # The client connection should be closed by the proxy + data = await reader.read(100) + assert data == b"" + + await asyncio.sleep(0.5) # give event loop a chance to run + assert os.path.exists(socket_path) == False + + + +@pytest.mark.asyncio +async def test_tcp_proxy_client_closes_connection(tcp_proxy_server): + """ + Tests that closing the client socket closes the TCP server socket. + """ + proxy, socket_path, connector = tcp_proxy_server + reader, writer = await asyncio.open_unix_connection(socket_path) + await connector.connect_ran.wait() + + assert connector.server_transport is not None + assert not connector.server_transport.is_closing() + + # Close the client socket + writer.close() + await writer.wait_closed() + + # Check that the server socket is closing + await asyncio.sleep(0.01) + assert connector.server_transport.is_closing() \ No newline at end of file From f8391aa4496db0e2b056c7bfd4326497d2eaa2bb Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Fri, 17 Oct 2025 16:09:19 -0600 Subject: [PATCH 13/14] feat: replace ssl_object with sslcontext and build a socket manually --- google/cloud/sql/connector/connector.py | 7 +- tests/system/test_psycopg_connection.py | 129 +++++++++--------------- 2 files changed, 51 insertions(+), 85 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 16195cc10..128dc7341 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -20,6 +20,7 @@ from functools import partial import logging import os +import socket from threading import Thread from types import TracebackType from typing import Any, Callable, Optional, Union @@ -514,7 +515,11 @@ async def connect_async( instance_connection_string, asyncio.Protocol, **kwargs ) # See https://docs.python.org/3/library/asyncio-protocol.html#asyncio.BaseTransport.get_extra_info - sock = tx.get_extra_info("ssl_object") + ctx = tx.get_extra_info("sslcontext") + sock = ctx.wrap_socket( + socket.create_connection((ip_address, SERVER_PROXY_PORT)), + server_hostname=ip_address, + ) connect_partial = partial(connector, ip_address, sock, **kwargs) return await self._loop.run_in_executor(None, connect_partial) diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py index 1fed06356..ae744dacf 100644 --- a/tests/system/test_psycopg_connection.py +++ b/tests/system/test_psycopg_connection.py @@ -14,6 +14,7 @@ limitations under the License. """ +import asyncio from datetime import datetime import os @@ -21,6 +22,8 @@ from typing import Union from psycopg import Connection +import pytest +import logging import sqlalchemy from google.cloud.sql.connector import Connector @@ -29,98 +32,56 @@ SERVER_PROXY_PORT = 3307 -async def create_sqlalchemy_engine( - instance_connection_name: str, - user: str, - password: str, - db: str, - ip_type: str = "public", - refresh_strategy: str = "background", - resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver, -) -> tuple[sqlalchemy.engine.Engine, Connector]: - """Creates a connection pool for a Cloud SQL instance and returns the pool - and the connector. Callers are responsible for closing the pool and the - connector. - - A sample invocation looks like: - - engine, connector = create_sqlalchemy_engine( - inst_conn_name, - user, - password, - db, - ) - with engine.connect() as conn: - time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() - conn.commit() - curr_time = time[0] - # do something with query result - connector.close() - - Args: - instance_connection_name (str): - The instance connection name specifies the instance relative to the - project and region. For example: "my-project:my-region:my-instance" - user (str): - The database user name, e.g., root - password (str): - The database user's password, e.g., secret-password - db (str): - The name of the database, e.g., mydb - ip_type (str): - The IP type of the Cloud SQL instance to connect to. Can be one - of "public", "private", or "psc". - refresh_strategy (Optional[str]): - Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" - or "background". For serverless environments use "lazy" to avoid - errors resulting from CPU being throttled. - resolver (Optional[google.cloud.sql.connector.DefaultResolver]): - Resolver class for resolving instance connection name. Use - google.cloud.sql.connector.DnsResolver when resolving DNS domain - names or google.cloud.sql.connector.DefaultResolver for regular - instance connection names ("my-project:my-region:my-instance"). - """ - connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver) - unix_socket_folder = "/tmp/conn" - unix_socket_path = f"{unix_socket_folder}/.s.PGSQL.3307" - await connector.start_unix_socket_proxy_async( - instance_connection_name, - unix_socket_path, - ip_type=ip_type, # can be "public", "private" or "psc" - ) - - # create SQLAlchemy connection pool - engine = sqlalchemy.create_engine( - "postgresql+psycopg://", - creator=lambda: Connection.connect( - f"host={unix_socket_folder} port={SERVER_PROXY_PORT} dbname={db} user={user} password={password} sslmode=require", - user=user, - password=password, - dbname=db, - autocommit=True, - ) - ) - - return engine, connector - +logger = logging.getLogger(name=__name__) # [END cloud_sql_connector_postgres_psycopg] +@pytest.mark.asyncio async def test_psycopg_connection() -> None: """Basic test to get time from database.""" - inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"] + instance_connection_name = os.environ["POSTGRES_CONNECTION_NAME"] user = os.environ["POSTGRES_USER"] password = os.environ["POSTGRES_PASS"] db = os.environ["POSTGRES_DB"] ip_type = os.environ.get("IP_TYPE", "public") - engine, connector = await create_sqlalchemy_engine( - inst_conn_name, user, password, db, ip_type - ) - with engine.connect() as conn: - time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() - conn.commit() - curr_time = time[0] - assert type(curr_time) is datetime - connector.close() + unix_socket_folder = "/tmp/conn" + unix_socket_path = f"{unix_socket_folder}/.s.PGSQL.3307" + + async with Connector( + refresh_strategy='lazy', resolver=DefaultResolver + ) as connector: + # Open proxy connection + # start the proxy server + + await connector.start_unix_socket_proxy_async( + instance_connection_name, + unix_socket_path, + driver="psycopg", + user=user, + password=password, + db=db, + ip_type=ip_type, # can be "public", "private" or "psc" + ) + + # Wait for server to start + await asyncio.sleep(0.5) + + engine = sqlalchemy.create_engine( + "postgresql+psycopg://", + creator=lambda: Connection.connect( + f"host={unix_socket_folder} port={SERVER_PROXY_PORT} dbname={db} user={user} password={password} sslmode=require", + user=user, + password=password, + dbname=db, + autocommit=True, + ) + ) + + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + assert type(curr_time) is datetime + connector.close() From 3791368e084f52a9ec740d842d977397ab281d43 Mon Sep 17 00:00:00 2001 From: Uziel Silva Date: Fri, 17 Oct 2025 16:51:22 -0600 Subject: [PATCH 14/14] feat: Pass loop to aiohttp client --- google/cloud/sql/connector/client.py | 4 ++-- google/cloud/sql/connector/connector.py | 2 ++ tests/system/test_psycopg_connection.py | 4 ---- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index 11508ce17..ef748eb19 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -58,6 +58,7 @@ def __init__( client: Optional[aiohttp.ClientSession] = None, driver: Optional[str] = None, user_agent: Optional[str] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: """Establishes the client to be used for Cloud SQL Admin API requests. @@ -84,8 +85,7 @@ def __init__( } if quota_project: headers["x-goog-user-project"] = quota_project - - self._client = client if client else aiohttp.ClientSession(headers=headers) + self._client = client if client else aiohttp.ClientSession(headers=headers, loop=loop) self._credentials = credentials if sqladmin_api_endpoint is None: self._sqladmin_api_endpoint = DEFAULT_SERVICE_ENDPOINT diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 128dc7341..b5ecff204 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -363,6 +363,7 @@ def _init_client(self, driver: Optional[str]) -> CloudSQLClient: self._credentials, user_agent=self._user_agent, driver=driver, + loop=self._loop ) return self._client @@ -419,6 +420,7 @@ async def connect_async( self._credentials, user_agent=self._user_agent, driver=driver, + loop=self._loop ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) diff --git a/tests/system/test_psycopg_connection.py b/tests/system/test_psycopg_connection.py index ae744dacf..25a46f3b8 100644 --- a/tests/system/test_psycopg_connection.py +++ b/tests/system/test_psycopg_connection.py @@ -58,10 +58,6 @@ async def test_psycopg_connection() -> None: await connector.start_unix_socket_proxy_async( instance_connection_name, unix_socket_path, - driver="psycopg", - user=user, - password=password, - db=db, ip_type=ip_type, # can be "public", "private" or "psc" )