From cd6df8581a8fbbf0287246ad737f3966970814ad Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Tue, 24 Dec 2024 03:37:35 +0000 Subject: [PATCH 01/14] feat: support static connection info --- .../alloydb/connector/async_connector.py | 10 +++ google/cloud/alloydb/connector/connector.py | 10 +++ google/cloud/alloydb/connector/static.py | 90 +++++++++++++++++++ tests/unit/conftest.py | 9 +- tests/unit/mocks.py | 73 +++++++++++++++ tests/unit/test_async_connector.py | 45 ++++++++++ tests/unit/test_connector.py | 48 ++++++++++ 7 files changed, 280 insertions(+), 5 deletions(-) create mode 100644 google/cloud/alloydb/connector/static.py diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index 5e8e6bd6..097da9e2 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -15,6 +15,7 @@ from __future__ import annotations import asyncio +import io import logging from types import TracebackType from typing import Any, Optional, TYPE_CHECKING, Union @@ -29,6 +30,7 @@ from google.cloud.alloydb.connector.enums import RefreshStrategy from google.cloud.alloydb.connector.instance import RefreshAheadCache from google.cloud.alloydb.connector.lazy import LazyRefreshCache +from google.cloud.alloydb.connector.static import StaticConnectionInfoCache from google.cloud.alloydb.connector.utils import generate_keys if TYPE_CHECKING: @@ -59,6 +61,9 @@ class AsyncConnector: of the following: RefreshStrategy.LAZY ("LAZY") or RefreshStrategy.BACKGROUND ("BACKGROUND"). Default: RefreshStrategy.BACKGROUND + static_conn_info (io.TextIOBase): A file-like JSON object that contains + static connection info for the StaticConnectionInfoCache. + Defaults to None, which will not use the StaticConnectionInfoCache. """ def __init__( @@ -70,6 +75,7 @@ def __init__( ip_type: str | IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, + static_conn_info: io.TextIOBase = None, ) -> None: self._cache: dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {} # initialize default params @@ -100,6 +106,7 @@ def __init__( except RuntimeError: self._keys = None self._client: Optional[AlloyDBClient] = None + self._static_conn_info = static_conn_info async def connect( self, @@ -138,10 +145,13 @@ async def connect( ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) + static_conn_info = kwargs.pop("static_conn_info", self._static_conn_info) # use existing connection info if possible if instance_uri in self._cache: cache = self._cache[instance_uri] + elif static_conn_info: + cache = StaticConnectionInfoCache(instance_uri, static_conn_info) else: if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 8112aa9b..d5c930df 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -16,6 +16,7 @@ import asyncio from functools import partial +import io import logging import socket import struct @@ -34,6 +35,7 @@ from google.cloud.alloydb.connector.instance import RefreshAheadCache from google.cloud.alloydb.connector.lazy import LazyRefreshCache import google.cloud.alloydb.connector.pg8000 as pg8000 +from google.cloud.alloydb.connector.static import StaticConnectionInfoCache from google.cloud.alloydb.connector.utils import generate_keys import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb @@ -71,6 +73,9 @@ class Connector: of the following: RefreshStrategy.LAZY ("LAZY") or RefreshStrategy.BACKGROUND ("BACKGROUND"). Default: RefreshStrategy.BACKGROUND + static_conn_info (io.TextIOBase): A file-like JSON object that contains + static connection info for the StaticConnectionInfoCache. + Defaults to None, which will not use the StaticConnectionInfoCache. """ def __init__( @@ -82,6 +87,7 @@ def __init__( ip_type: str | IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, + static_conn_info: io.TextIOBase = None, ) -> None: # create event loop and start it in background thread self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() @@ -113,6 +119,7 @@ def __init__( loop=self._loop, ) self._client: Optional[AlloyDBClient] = None + self._static_conn_info = static_conn_info def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any: """ @@ -168,9 +175,12 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> driver=driver, ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) + static_conn_info = kwargs.pop("static_conn_info", self._static_conn_info) # use existing connection info if possible if instance_uri in self._cache: cache = self._cache[instance_uri] + elif static_conn_info: + cache = StaticConnectionInfoCache(instance_uri, static_conn_info) else: if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( diff --git a/google/cloud/alloydb/connector/static.py b/google/cloud/alloydb/connector/static.py new file mode 100644 index 00000000..9731e594 --- /dev/null +++ b/google/cloud/alloydb/connector/static.py @@ -0,0 +1,90 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from datetime import datetime +from datetime import timedelta +from datetime import timezone +import io +import json + +from google.cloud.alloydb.connector.connection_info import ConnectionInfo + + +class StaticConnectionInfoCache: + """ + StaticConnectionInfoCache creates a connection info cache that will always + return a pre-defined connection info. + + This static connection info should hold JSON with the following format: + { + "publicKey": "", + "privateKey": "", + "projects//locations//clusters//instances/": { + "ipAddress": "", + "publicIpAddress": "", + "pscInstanceConfig": { + "pscDnsName": "" + }, + "pemCertificateChain": [ + "", "", "" + ], + "caCert": "" + } + } + """ + + def __init__(self, instance_uri: str, static_conn_info: io.TextIOBase) -> None: + """ + Initializes a StaticConnectionInfoCache instance. + + Args: + instance_uri (str): The AlloyDB instance's connection URI. + static_conn_info (io.TextIOBase): The static connection info JSON. + """ + static_info = json.load(static_conn_info) + ca_cert = static_info[instance_uri]["caCert"] + cert_chain = static_info[instance_uri]["pemCertificateChain"] + ip_addrs = { + "PRIVATE": static_info[instance_uri]["ipAddress"], + "PUBLIC": static_info[instance_uri]["publicIpAddress"], + "PSC": static_info[instance_uri]["pscInstanceConfig"]["pscDnsName"], + } + expiration = datetime.now(timezone.utc) + timedelta(hours=1) + priv_key = static_info["privateKey"] + priv_key_bytes: rsa.RSAPrivateKey = serialization.load_pem_private_key( + priv_key.encode("UTF-8"), password=None, + ) + self._info = ConnectionInfo(cert_chain, ca_cert, priv_key_bytes, ip_addrs, expiration) + + async def force_refresh(self) -> None: + """ + This is a no-op as the cache holds only static connection information + and does no refresh. + """ + pass + + async def connect_info(self) -> ConnectionInfo: + """ + Retrieves ConnectionInfo instance for establishing a secure + connection to the AlloyDB instance. + """ + return self._info + + async def close(self) -> None: + """ + This is a no-op. + """ + pass \ No newline at end of file diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 45648fa7..c7e27f4a 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -66,8 +66,8 @@ async def start_proxy_server(instance: FakeInstance) -> None: # listen for incoming connections sock.listen(5) - while True: - with context.wrap_socket(sock, server_side=True) as ssock: + with context.wrap_socket(sock, server_side=True) as ssock: + while True: conn, _ = ssock.accept() metadata_exchange(conn) conn.sendall(instance.name.encode("utf-8")) @@ -75,7 +75,7 @@ async def start_proxy_server(instance: FakeInstance) -> None: @pytest.fixture(scope="session") -def proxy_server(fake_instance: FakeInstance) -> Generator: +def proxy_server(fake_instance: FakeInstance) -> None: """Run local proxy server capable of performing metadata exchange""" thread = Thread( target=asyncio.run, @@ -87,5 +87,4 @@ def proxy_server(fake_instance: FakeInstance) -> Generator: daemon=True, ) thread.start() - yield thread - thread.join() + thread.join(0.1) # wait 100ms to allow the proxy server to start diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index be8e4750..b46314d4 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -16,7 +16,9 @@ from datetime import datetime from datetime import timedelta from datetime import timezone +import io import ipaddress +import json import ssl import struct from typing import Any, Callable, Literal, Optional @@ -193,6 +195,34 @@ def get_pem_certs(self) -> tuple[str, str, str]: encoding=serialization.Encoding.PEM ).decode("UTF-8") return (pem_root, pem_intermediate, pem_server) + + def generate_pem_certificate_chain(self, pub_key: str) -> Tuple[str, List[str]]: + """Generate the CA certificate and certificate chain for the AlloyDB instance.""" + root_cert, intermediate_cert, server_cert = self.get_pem_certs() + # encode public key to bytes + pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key( + pub_key.encode("UTF-8"), + ) + # build client cert + client_cert = ( + x509.CertificateBuilder() + .subject_name(self.intermediate_cert.subject) + .issuer_name(self.intermediate_cert.issuer) + .public_key(pub_key_bytes) + .serial_number(x509.random_serial_number()) + .not_valid_before(self.cert_before) + .not_valid_after(self.cert_expiry) + ) + # sign client cert with intermediate cert + client_cert = client_cert.sign(self.intermediate_key, hashes.SHA256()) + client_cert = client_cert.public_bytes( + encoding=serialization.Encoding.PEM + ).decode("UTF-8") + return (server_cert, [client_cert, intermediate_cert, root_cert]) + + def uri(self) -> str: + """The URI of the AlloyDB instance.""" + return f"projects/{self.project}/locations/{self.region}/clusters/{self.cluster}/instances/{self.name}" class FakeAlloyDBClient: @@ -378,3 +408,46 @@ async def force_refresh(self) -> None: async def close(self) -> None: self._close_called = True + + +def write_static_info(i: FakeInstance) -> io.StringIO: + """ + Creates a static connection info JSON for the StaticConnectionInfoCache. + + Args: + i (FakeInstance): The FakeInstance to use to create the CA cert and + chain. + + Returns: + io.StringIO + """ + priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + pub_pem = ( + priv_key.public_key() + .public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + .decode("UTF-8") + ) + priv_pem = ( + priv_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + .decode("UTF-8") + ) + ca_cert, chain = i.generate_pem_certificate_chain(pub_pem) + static = { + "publicKey": pub_pem, + "privateKey": priv_pem, + } + static[i.uri()] = { + "pemCertificateChain": chain, + "caCert": ca_cert, + "ipAddress": "127.0.0.1", # "private" IP is localhost in testing + "publicIpAddress": "", + "pscInstanceConfig": {"pscDnsName": ""}, + } + return io.StringIO(json.dumps(static)) diff --git a/tests/unit/test_async_connector.py b/tests/unit/test_async_connector.py index 0f150875..8e518276 100644 --- a/tests/unit/test_async_connector.py +++ b/tests/unit/test_async_connector.py @@ -20,6 +20,7 @@ from mocks import FakeAlloyDBClient from mocks import FakeConnectionInfo from mocks import FakeCredentials +from mocks import write_static_info import pytest from google.cloud.alloydb.connector import AsyncConnector @@ -333,3 +334,47 @@ async def test_Connector_remove_cached_no_ip_type(credentials: FakeCredentials) await connector.connect(instance_uri, "asyncpg", ip_type="private") # check that cache has been removed from dict assert instance_uri not in connector._cache + +async def test_Connector_static_connection_info(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: + """ + Test that AsyncConnector.__init__() can specify a static connection info to + connect to an instance. + """ + static_info = write_static_info(fake_client.instance) + async with AsyncConnector(credentials=credentials, static_conn_info=static_info) as connector: + connector._client = fake_client + # patch db connection creation + with patch("google.cloud.alloydb.connector.asyncpg.connect") as mock_connect: + mock_connect.return_value = True + connection = await connector.connect( + fake_client.instance.uri(), + "asyncpg", + user="test-user", + password="test-password", + db="test-db", + ) + # check connection is returned + assert connection is True + + +async def test_connect_static_connection_info(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: + """ + Test that AsyncConnector.connect() can specify a static connection info to + connect to an instance. + """ + static_info = write_static_info(fake_client.instance) + async with AsyncConnector(credentials=credentials) as connector: + connector._client = fake_client + # patch db connection creation + with patch("google.cloud.alloydb.connector.asyncpg.connect") as mock_connect: + mock_connect.return_value = True + connection = await connector.connect( + fake_client.instance.uri(), + "asyncpg", + user="test-user", + password="test-password", + db="test-db", + static_conn_info=static_info, + ) + # check connection is returned + assert connection is True \ No newline at end of file diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index a02ad30e..902a032f 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -20,6 +20,7 @@ from mock import patch from mocks import FakeAlloyDBClient from mocks import FakeCredentials +from mocks import write_static_info import pytest from google.cloud.alloydb.connector import Connector @@ -248,3 +249,50 @@ async def test_Connector_remove_cached_no_ip_type(credentials: FakeCredentials) await connector.connect_async(instance_uri, "pg8000", ip_type="private") # check that cache has been removed from dict assert instance_uri not in connector._cache + + +@pytest.mark.usefixtures("proxy_server") +def test_Connector_static_connection_info(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: + """ + Test that Connector.__init__() can specify a static connection info to + connect to an instance. + """ + static_info = write_static_info(fake_client.instance) + with Connector(credentials=credentials, static_conn_info=static_info) as connector: + connector._client = fake_client + # patch db connection creation + with patch("google.cloud.alloydb.connector.pg8000.connect") as mock_connect: + mock_connect.return_value = True + connection = connector.connect( + fake_client.instance.uri(), + "pg8000", + user="test-user", + password="test-password", + db="test-db", + ) + # check connection is returned + assert connection is True + + +@pytest.mark.usefixtures("proxy_server") +def test_connect_static_connection_info(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: + """ + Test that Connector.connect() can specify a static connection info to + connect to an instance. + """ + static_info = write_static_info(fake_client.instance) + with Connector(credentials=credentials) as connector: + connector._client = fake_client + # patch db connection creation + with patch("google.cloud.alloydb.connector.pg8000.connect") as mock_connect: + mock_connect.return_value = True + connection = connector.connect( + fake_client.instance.uri(), + "pg8000", + user="test-user", + password="test-password", + db="test-db", + static_conn_info=static_info, + ) + # check connection is returned + assert connection is True \ No newline at end of file From 7d8e3f322196c24f5a6655b7957b34d139eb53af Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Fri, 3 Jan 2025 01:04:55 +0000 Subject: [PATCH 02/14] Add support for static.json with no PSC config --- google/cloud/alloydb/connector/static.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/google/cloud/alloydb/connector/static.py b/google/cloud/alloydb/connector/static.py index 9731e594..fa94b8bd 100644 --- a/google/cloud/alloydb/connector/static.py +++ b/google/cloud/alloydb/connector/static.py @@ -57,10 +57,13 @@ def __init__(self, instance_uri: str, static_conn_info: io.TextIOBase) -> None: static_info = json.load(static_conn_info) ca_cert = static_info[instance_uri]["caCert"] cert_chain = static_info[instance_uri]["pemCertificateChain"] + dns = "" + if static_info[instance_uri]["pscInstanceConfig"]: + dns = static_info[instance_uri]["pscInstanceConfig"]["pscDnsName"].rstrip(".") ip_addrs = { "PRIVATE": static_info[instance_uri]["ipAddress"], "PUBLIC": static_info[instance_uri]["publicIpAddress"], - "PSC": static_info[instance_uri]["pscInstanceConfig"]["pscDnsName"], + "PSC": dns, } expiration = datetime.now(timezone.utc) + timedelta(hours=1) priv_key = static_info["privateKey"] From a8e5780f7f3b2229fa42a68ca529387cc0f639b2 Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Fri, 31 Jan 2025 02:44:19 +0000 Subject: [PATCH 03/14] Remove static_conn_info from connect() and fix lint issues --- .../alloydb/connector/async_connector.py | 10 +++--- google/cloud/alloydb/connector/connector.py | 17 ++++------ google/cloud/alloydb/connector/static.py | 11 +++--- tests/unit/conftest.py | 8 ++--- tests/unit/mocks.py | 12 +++---- tests/unit/test_async_connector.py | 34 +++---------------- tests/unit/test_connector.py | 33 ++---------------- 7 files changed, 30 insertions(+), 95 deletions(-) diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index 097da9e2..e01468e6 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -21,13 +21,12 @@ from typing import Any, Optional, TYPE_CHECKING, Union import google.auth -from google.auth.credentials import with_scopes_if_required import google.auth.transport.requests +from google.auth.credentials import with_scopes_if_required import google.cloud.alloydb.connector.asyncpg as asyncpg from google.cloud.alloydb.connector.client import AlloyDBClient -from google.cloud.alloydb.connector.enums import IPTypes -from google.cloud.alloydb.connector.enums import RefreshStrategy +from google.cloud.alloydb.connector.enums import IPTypes, RefreshStrategy from google.cloud.alloydb.connector.instance import RefreshAheadCache from google.cloud.alloydb.connector.lazy import LazyRefreshCache from google.cloud.alloydb.connector.static import StaticConnectionInfoCache @@ -145,13 +144,12 @@ async def connect( ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) - static_conn_info = kwargs.pop("static_conn_info", self._static_conn_info) # use existing connection info if possible if instance_uri in self._cache: cache = self._cache[instance_uri] - elif static_conn_info: - cache = StaticConnectionInfoCache(instance_uri, static_conn_info) + elif self._static_conn_info: + cache = StaticConnectionInfoCache(instance_uri, self._static_conn_info) else: if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index d5c930df..7aff380a 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -15,29 +15,27 @@ from __future__ import annotations import asyncio -from functools import partial import io import logging import socket import struct +from functools import partial from threading import Thread from types import TracebackType from typing import Any, Optional, TYPE_CHECKING, Union from google.auth import default -from google.auth.credentials import TokenState -from google.auth.credentials import with_scopes_if_required +from google.auth.credentials import TokenState, with_scopes_if_required from google.auth.transport import requests +import google.cloud.alloydb.connector.pg8000 as pg8000 +import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb from google.cloud.alloydb.connector.client import AlloyDBClient -from google.cloud.alloydb.connector.enums import IPTypes -from google.cloud.alloydb.connector.enums import RefreshStrategy +from google.cloud.alloydb.connector.enums import IPTypes, RefreshStrategy from google.cloud.alloydb.connector.instance import RefreshAheadCache from google.cloud.alloydb.connector.lazy import LazyRefreshCache -import google.cloud.alloydb.connector.pg8000 as pg8000 from google.cloud.alloydb.connector.static import StaticConnectionInfoCache from google.cloud.alloydb.connector.utils import generate_keys -import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb if TYPE_CHECKING: import ssl @@ -175,12 +173,11 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> driver=driver, ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) - static_conn_info = kwargs.pop("static_conn_info", self._static_conn_info) # use existing connection info if possible if instance_uri in self._cache: cache = self._cache[instance_uri] - elif static_conn_info: - cache = StaticConnectionInfoCache(instance_uri, static_conn_info) + elif self._static_conn_info: + cache = StaticConnectionInfoCache(instance_uri, self._static_conn_info) else: if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( diff --git a/google/cloud/alloydb/connector/static.py b/google/cloud/alloydb/connector/static.py index fa94b8bd..f8a63fca 100644 --- a/google/cloud/alloydb/connector/static.py +++ b/google/cloud/alloydb/connector/static.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from datetime import datetime -from datetime import timedelta -from datetime import timezone import io import json +from datetime import datetime, timedelta, timezone + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa from google.cloud.alloydb.connector.connection_info import ConnectionInfo diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index c7e27f4a..a04faa4a 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -18,12 +18,10 @@ from threading import Thread from typing import Generator -from aiofiles.tempfile import TemporaryDirectory -from mocks import FakeAlloyDBClient -from mocks import FakeCredentials -from mocks import FakeInstance -from mocks import metadata_exchange import pytest +from aiofiles.tempfile import TemporaryDirectory +from mocks import (FakeAlloyDBClient, FakeCredentials, FakeInstance, + metadata_exchange) from google.cloud.alloydb.connector.utils import _write_to_file diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index b46314d4..81c741b8 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -13,27 +13,23 @@ # limitations under the License. import asyncio -from datetime import datetime -from datetime import timedelta -from datetime import timezone import io import ipaddress import json import ssl import struct +from datetime import datetime, timedelta, timezone from typing import Any, Callable, Literal, Optional from cryptography import x509 -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID -from google.auth.credentials import _helpers -from google.auth.credentials import TokenState +from google.auth.credentials import TokenState, _helpers from google.auth.transport import requests -from google.cloud.alloydb.connector.connection_info import ConnectionInfo import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb +from google.cloud.alloydb.connector.connection_info import ConnectionInfo class FakeCredentials: diff --git a/tests/unit/test_async_connector.py b/tests/unit/test_async_connector.py index 8e518276..bd522195 100644 --- a/tests/unit/test_async_connector.py +++ b/tests/unit/test_async_connector.py @@ -15,16 +15,13 @@ import asyncio from typing import Union +import pytest from aiohttp import ClientResponseError from mock import patch -from mocks import FakeAlloyDBClient -from mocks import FakeConnectionInfo -from mocks import FakeCredentials -from mocks import write_static_info -import pytest +from mocks import (FakeAlloyDBClient, FakeConnectionInfo, FakeCredentials, + write_static_info) -from google.cloud.alloydb.connector import AsyncConnector -from google.cloud.alloydb.connector import IPTypes +from google.cloud.alloydb.connector import AsyncConnector, IPTypes from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError from google.cloud.alloydb.connector.instance import RefreshAheadCache @@ -354,27 +351,4 @@ async def test_Connector_static_connection_info(credentials: FakeCredentials, fa db="test-db", ) # check connection is returned - assert connection is True - - -async def test_connect_static_connection_info(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: - """ - Test that AsyncConnector.connect() can specify a static connection info to - connect to an instance. - """ - static_info = write_static_info(fake_client.instance) - async with AsyncConnector(credentials=credentials) as connector: - connector._client = fake_client - # patch db connection creation - with patch("google.cloud.alloydb.connector.asyncpg.connect") as mock_connect: - mock_connect.return_value = True - connection = await connector.connect( - fake_client.instance.uri(), - "asyncpg", - user="test-user", - password="test-password", - db="test-db", - static_conn_info=static_info, - ) - # check connection is returned assert connection is True \ No newline at end of file diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 902a032f..b1d0f99e 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -16,15 +16,12 @@ from threading import Thread from typing import Union +import pytest from aiohttp import ClientResponseError from mock import patch -from mocks import FakeAlloyDBClient -from mocks import FakeCredentials -from mocks import write_static_info -import pytest +from mocks import FakeAlloyDBClient, FakeCredentials, write_static_info -from google.cloud.alloydb.connector import Connector -from google.cloud.alloydb.connector import IPTypes +from google.cloud.alloydb.connector import Connector, IPTypes from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError from google.cloud.alloydb.connector.instance import RefreshAheadCache from google.cloud.alloydb.connector.utils import generate_keys @@ -271,28 +268,4 @@ def test_Connector_static_connection_info(credentials: FakeCredentials, fake_cli db="test-db", ) # check connection is returned - assert connection is True - - -@pytest.mark.usefixtures("proxy_server") -def test_connect_static_connection_info(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: - """ - Test that Connector.connect() can specify a static connection info to - connect to an instance. - """ - static_info = write_static_info(fake_client.instance) - with Connector(credentials=credentials) as connector: - connector._client = fake_client - # patch db connection creation - with patch("google.cloud.alloydb.connector.pg8000.connect") as mock_connect: - mock_connect.return_value = True - connection = connector.connect( - fake_client.instance.uri(), - "pg8000", - user="test-user", - password="test-password", - db="test-db", - static_conn_info=static_info, - ) - # check connection is returned assert connection is True \ No newline at end of file From a59cd2fe340c4ce8a760ad68ec00814f78800dbf Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Fri, 31 Jan 2025 02:54:43 +0000 Subject: [PATCH 04/14] Change Tuple to tuple --- tests/unit/mocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 81c741b8..5db30854 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -192,7 +192,7 @@ def get_pem_certs(self) -> tuple[str, str, str]: ).decode("UTF-8") return (pem_root, pem_intermediate, pem_server) - def generate_pem_certificate_chain(self, pub_key: str) -> Tuple[str, List[str]]: + def generate_pem_certificate_chain(self, pub_key: str) -> tuple[str, List[str]]: """Generate the CA certificate and certificate chain for the AlloyDB instance.""" root_cert, intermediate_cert, server_cert = self.get_pem_certs() # encode public key to bytes From 7f0e89cdf9454743193d1795b75fbe5676a716b3 Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Fri, 31 Jan 2025 02:56:05 +0000 Subject: [PATCH 05/14] Change List to list --- tests/unit/mocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 5db30854..c119f4fe 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -192,7 +192,7 @@ def get_pem_certs(self) -> tuple[str, str, str]: ).decode("UTF-8") return (pem_root, pem_intermediate, pem_server) - def generate_pem_certificate_chain(self, pub_key: str) -> tuple[str, List[str]]: + def generate_pem_certificate_chain(self, pub_key: str) -> tuple[str, list[str]]: """Generate the CA certificate and certificate chain for the AlloyDB instance.""" root_cert, intermediate_cert, server_cert = self.get_pem_certs() # encode public key to bytes From 61660c332890c455299eb4587678f7eb6a48f98f Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Fri, 31 Jan 2025 02:59:56 +0000 Subject: [PATCH 06/14] Fix lint errors using black --- google/cloud/alloydb/connector/static.py | 15 ++++++++++----- tests/unit/conftest.py | 5 ++--- tests/unit/mocks.py | 19 ++++++++----------- tests/unit/test_async_connector.py | 19 ++++++++++++++----- tests/unit/test_connector.py | 6 ++++-- 5 files changed, 38 insertions(+), 26 deletions(-) diff --git a/google/cloud/alloydb/connector/static.py b/google/cloud/alloydb/connector/static.py index f8a63fca..b6a3b52a 100644 --- a/google/cloud/alloydb/connector/static.py +++ b/google/cloud/alloydb/connector/static.py @@ -58,7 +58,9 @@ def __init__(self, instance_uri: str, static_conn_info: io.TextIOBase) -> None: cert_chain = static_info[instance_uri]["pemCertificateChain"] dns = "" if static_info[instance_uri]["pscInstanceConfig"]: - dns = static_info[instance_uri]["pscInstanceConfig"]["pscDnsName"].rstrip(".") + dns = static_info[instance_uri]["pscInstanceConfig"]["pscDnsName"].rstrip( + "." + ) ip_addrs = { "PRIVATE": static_info[instance_uri]["ipAddress"], "PUBLIC": static_info[instance_uri]["publicIpAddress"], @@ -67,9 +69,12 @@ def __init__(self, instance_uri: str, static_conn_info: io.TextIOBase) -> None: expiration = datetime.now(timezone.utc) + timedelta(hours=1) priv_key = static_info["privateKey"] priv_key_bytes: rsa.RSAPrivateKey = serialization.load_pem_private_key( - priv_key.encode("UTF-8"), password=None, + priv_key.encode("UTF-8"), + password=None, + ) + self._info = ConnectionInfo( + cert_chain, ca_cert, priv_key_bytes, ip_addrs, expiration ) - self._info = ConnectionInfo(cert_chain, ca_cert, priv_key_bytes, ip_addrs, expiration) async def force_refresh(self) -> None: """ @@ -84,9 +89,9 @@ async def connect_info(self) -> ConnectionInfo: connection to the AlloyDB instance. """ return self._info - + async def close(self) -> None: """ This is a no-op. """ - pass \ No newline at end of file + pass diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index a04faa4a..7b3788d9 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -20,8 +20,7 @@ import pytest from aiofiles.tempfile import TemporaryDirectory -from mocks import (FakeAlloyDBClient, FakeCredentials, FakeInstance, - metadata_exchange) +from mocks import FakeAlloyDBClient, FakeCredentials, FakeInstance, metadata_exchange from google.cloud.alloydb.connector.utils import _write_to_file @@ -85,4 +84,4 @@ def proxy_server(fake_instance: FakeInstance) -> None: daemon=True, ) thread.start() - thread.join(0.1) # wait 100ms to allow the proxy server to start + thread.join(0.1) # wait 100ms to allow the proxy server to start diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index c119f4fe..f12ba458 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -191,7 +191,7 @@ def get_pem_certs(self) -> tuple[str, str, str]: encoding=serialization.Encoding.PEM ).decode("UTF-8") return (pem_root, pem_intermediate, pem_server) - + def generate_pem_certificate_chain(self, pub_key: str) -> tuple[str, list[str]]: """Generate the CA certificate and certificate chain for the AlloyDB instance.""" root_cert, intermediate_cert, server_cert = self.get_pem_certs() @@ -215,7 +215,7 @@ def generate_pem_certificate_chain(self, pub_key: str) -> tuple[str, list[str]]: encoding=serialization.Encoding.PEM ).decode("UTF-8") return (server_cert, [client_cert, intermediate_cert, root_cert]) - + def uri(self) -> str: """The URI of the AlloyDB instance.""" return f"projects/{self.project}/locations/{self.region}/clusters/{self.cluster}/instances/{self.name}" @@ -426,14 +426,11 @@ def write_static_info(i: FakeInstance) -> io.StringIO: ) .decode("UTF-8") ) - priv_pem = ( - priv_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ) - .decode("UTF-8") - ) + priv_pem = priv_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ).decode("UTF-8") ca_cert, chain = i.generate_pem_certificate_chain(pub_pem) static = { "publicKey": pub_pem, @@ -442,7 +439,7 @@ def write_static_info(i: FakeInstance) -> io.StringIO: static[i.uri()] = { "pemCertificateChain": chain, "caCert": ca_cert, - "ipAddress": "127.0.0.1", # "private" IP is localhost in testing + "ipAddress": "127.0.0.1", # "private" IP is localhost in testing "publicIpAddress": "", "pscInstanceConfig": {"pscDnsName": ""}, } diff --git a/tests/unit/test_async_connector.py b/tests/unit/test_async_connector.py index bd522195..3d5b1718 100644 --- a/tests/unit/test_async_connector.py +++ b/tests/unit/test_async_connector.py @@ -18,8 +18,12 @@ import pytest from aiohttp import ClientResponseError from mock import patch -from mocks import (FakeAlloyDBClient, FakeConnectionInfo, FakeCredentials, - write_static_info) +from mocks import ( + FakeAlloyDBClient, + FakeConnectionInfo, + FakeCredentials, + write_static_info, +) from google.cloud.alloydb.connector import AsyncConnector, IPTypes from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError @@ -332,13 +336,18 @@ async def test_Connector_remove_cached_no_ip_type(credentials: FakeCredentials) # check that cache has been removed from dict assert instance_uri not in connector._cache -async def test_Connector_static_connection_info(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: + +async def test_Connector_static_connection_info( + credentials: FakeCredentials, fake_client: FakeAlloyDBClient +) -> None: """ Test that AsyncConnector.__init__() can specify a static connection info to connect to an instance. """ static_info = write_static_info(fake_client.instance) - async with AsyncConnector(credentials=credentials, static_conn_info=static_info) as connector: + async with AsyncConnector( + credentials=credentials, static_conn_info=static_info + ) as connector: connector._client = fake_client # patch db connection creation with patch("google.cloud.alloydb.connector.asyncpg.connect") as mock_connect: @@ -351,4 +360,4 @@ async def test_Connector_static_connection_info(credentials: FakeCredentials, fa db="test-db", ) # check connection is returned - assert connection is True \ No newline at end of file + assert connection is True diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index b1d0f99e..359036e7 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -249,7 +249,9 @@ async def test_Connector_remove_cached_no_ip_type(credentials: FakeCredentials) @pytest.mark.usefixtures("proxy_server") -def test_Connector_static_connection_info(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: +def test_Connector_static_connection_info( + credentials: FakeCredentials, fake_client: FakeAlloyDBClient +) -> None: """ Test that Connector.__init__() can specify a static connection info to connect to an instance. @@ -268,4 +270,4 @@ def test_Connector_static_connection_info(credentials: FakeCredentials, fake_cli db="test-db", ) # check connection is returned - assert connection is True \ No newline at end of file + assert connection is True From aac315c4a74cb05ec29f5190c632f45e8d30e843 Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Fri, 31 Jan 2025 03:11:21 +0000 Subject: [PATCH 07/14] Address lint issues with imports --- google/cloud/alloydb/connector/static.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/google/cloud/alloydb/connector/static.py b/google/cloud/alloydb/connector/static.py index b6a3b52a..287e29ba 100644 --- a/google/cloud/alloydb/connector/static.py +++ b/google/cloud/alloydb/connector/static.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from datetime import datetime +from datetime import timedelta +from datetime import timezone import io import json -from datetime import datetime, timedelta, timezone from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa From 5abf2124dc2dfbd702e5d972d29dd0625bd65d3a Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Fri, 31 Jan 2025 03:27:19 +0000 Subject: [PATCH 08/14] Fix lint errors for imports --- tests/unit/conftest.py | 7 +++++-- tests/unit/mocks.py | 12 ++++++++---- tests/unit/test_async_connector.py | 15 +++++++-------- tests/unit/test_connector.py | 9 ++++++--- 4 files changed, 26 insertions(+), 17 deletions(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7b3788d9..82932a25 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -18,9 +18,12 @@ from threading import Thread from typing import Generator -import pytest from aiofiles.tempfile import TemporaryDirectory -from mocks import FakeAlloyDBClient, FakeCredentials, FakeInstance, metadata_exchange +import pytest +from mocks import FakeAlloyDBClient +from mocks import FakeCredentials +from mocks import FakeInstance +from mocks import metadata_exchange from google.cloud.alloydb.connector.utils import _write_to_file diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index f12ba458..4dc0b257 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -13,23 +13,27 @@ # limitations under the License. import asyncio +from datetime import datetime +from datetime import timedelta +from datetime import timezone import io import ipaddress import json import ssl import struct -from datetime import datetime, timedelta, timezone from typing import Any, Callable, Literal, Optional from cryptography import x509 -from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID -from google.auth.credentials import TokenState, _helpers +from google.auth.credentials import _helpers +from google.auth.credentials import TokenState from google.auth.transport import requests -import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb from google.cloud.alloydb.connector.connection_info import ConnectionInfo +import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb class FakeCredentials: diff --git a/tests/unit/test_async_connector.py b/tests/unit/test_async_connector.py index 3d5b1718..e2da0ee4 100644 --- a/tests/unit/test_async_connector.py +++ b/tests/unit/test_async_connector.py @@ -15,17 +15,16 @@ import asyncio from typing import Union -import pytest from aiohttp import ClientResponseError from mock import patch -from mocks import ( - FakeAlloyDBClient, - FakeConnectionInfo, - FakeCredentials, - write_static_info, -) +from mocks import FakeAlloyDBClient +from mocks import FakeConnectionInfo +from mocks import FakeCredentials +from mocks import write_static_info +import pytest -from google.cloud.alloydb.connector import AsyncConnector, IPTypes +from google.cloud.alloydb.connector import AsyncConnector +from google.cloud.alloydb.connector import IPTypes from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError from google.cloud.alloydb.connector.instance import RefreshAheadCache diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 359036e7..c7660d1c 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -16,12 +16,15 @@ from threading import Thread from typing import Union -import pytest from aiohttp import ClientResponseError from mock import patch -from mocks import FakeAlloyDBClient, FakeCredentials, write_static_info +from mocks import FakeAlloyDBClient +from mocks import FakeCredentials +from mocks import write_static_info +import pytest -from google.cloud.alloydb.connector import Connector, IPTypes +from google.cloud.alloydb.connector import Connector +from google.cloud.alloydb.connector import IPTypes from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError from google.cloud.alloydb.connector.instance import RefreshAheadCache from google.cloud.alloydb.connector.utils import generate_keys From 6b3290c5e376eaf0d979e0b22d9f3d57a10bb672 Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Fri, 31 Jan 2025 03:39:31 +0000 Subject: [PATCH 09/14] Fix lint issues related to ordering of imports --- google/cloud/alloydb/connector/async_connector.py | 5 +++-- google/cloud/alloydb/connector/connector.py | 12 +++++++----- tests/unit/conftest.py | 2 +- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index e01468e6..2eb5283b 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -21,12 +21,13 @@ from typing import Any, Optional, TYPE_CHECKING, Union import google.auth -import google.auth.transport.requests from google.auth.credentials import with_scopes_if_required +import google.auth.transport.requests import google.cloud.alloydb.connector.asyncpg as asyncpg from google.cloud.alloydb.connector.client import AlloyDBClient -from google.cloud.alloydb.connector.enums import IPTypes, RefreshStrategy +from google.cloud.alloydb.connector.enums import IPTypes +from google.cloud.alloydb.connector.enums import RefreshStrategy from google.cloud.alloydb.connector.instance import RefreshAheadCache from google.cloud.alloydb.connector.lazy import LazyRefreshCache from google.cloud.alloydb.connector.static import StaticConnectionInfoCache diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 7aff380a..bcfe2e0e 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -15,27 +15,29 @@ from __future__ import annotations import asyncio +from functools import partial import io import logging import socket import struct -from functools import partial from threading import Thread from types import TracebackType from typing import Any, Optional, TYPE_CHECKING, Union from google.auth import default -from google.auth.credentials import TokenState, with_scopes_if_required +from google.auth.credentials import TokenState +from google.auth.credentials import with_scopes_if_required from google.auth.transport import requests -import google.cloud.alloydb.connector.pg8000 as pg8000 -import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb from google.cloud.alloydb.connector.client import AlloyDBClient -from google.cloud.alloydb.connector.enums import IPTypes, RefreshStrategy +from google.cloud.alloydb.connector.enums import IPTypes +from google.cloud.alloydb.connector.enums import RefreshStrategy from google.cloud.alloydb.connector.instance import RefreshAheadCache from google.cloud.alloydb.connector.lazy import LazyRefreshCache +import google.cloud.alloydb.connector.pg8000 as pg8000 from google.cloud.alloydb.connector.static import StaticConnectionInfoCache from google.cloud.alloydb.connector.utils import generate_keys +import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb if TYPE_CHECKING: import ssl diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 82932a25..84a6c85e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -19,11 +19,11 @@ from typing import Generator from aiofiles.tempfile import TemporaryDirectory -import pytest from mocks import FakeAlloyDBClient from mocks import FakeCredentials from mocks import FakeInstance from mocks import metadata_exchange +import pytest from google.cloud.alloydb.connector.utils import _write_to_file From 30505c25ef83ae85c189457432c0d7f97d4c0ffe Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Fri, 31 Jan 2025 03:47:27 +0000 Subject: [PATCH 10/14] Remove Generator from conftest.py --- tests/unit/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 84a6c85e..b81c7ccb 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -16,7 +16,6 @@ import socket import ssl from threading import Thread -from typing import Generator from aiofiles.tempfile import TemporaryDirectory from mocks import FakeAlloyDBClient From 1ff96c7e3302ee85c6c56e098d8cf9310c12e59f Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Fri, 31 Jan 2025 19:01:51 +0000 Subject: [PATCH 11/14] fix lint errors --- google/cloud/alloydb/connector/async_connector.py | 3 ++- google/cloud/alloydb/connector/connection_info.py | 4 ++-- google/cloud/alloydb/connector/connector.py | 3 ++- google/cloud/alloydb/connector/static.py | 6 ++---- google/cloud/alloydb/connector/utils.py | 3 ++- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index 2eb5283b..0db6f3db 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -75,7 +75,7 @@ def __init__( ip_type: str | IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, - static_conn_info: io.TextIOBase = None, + static_conn_info: Optional[io.TextIOBase] = None, ) -> None: self._cache: dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {} # initialize default params @@ -147,6 +147,7 @@ async def connect( enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) # use existing connection info if possible + cache: Union[RefreshAheadCache, LazyRefreshCache, StaticConnectionInfoCache] if instance_uri in self._cache: cache = self._cache[instance_uri] elif self._static_conn_info: diff --git a/google/cloud/alloydb/connector/connection_info.py b/google/cloud/alloydb/connector/connection_info.py index 65ca231d..aed28adb 100644 --- a/google/cloud/alloydb/connector/connection_info.py +++ b/google/cloud/alloydb/connector/connection_info.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: import datetime - from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes from google.cloud.alloydb.connector.enums import IPTypes @@ -41,7 +41,7 @@ class ConnectionInfo: cert_chain: list[str] ca_cert: str - key: rsa.RSAPrivateKey + key: PrivateKeyTypes ip_addrs: dict[str, Optional[str]] expiration: datetime.datetime context: Optional[ssl.SSLContext] = None diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index bcfe2e0e..755e2b68 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -87,7 +87,7 @@ def __init__( ip_type: str | IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, - static_conn_info: io.TextIOBase = None, + static_conn_info: Optional[io.TextIOBase] = None, ) -> None: # create event loop and start it in background thread self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() @@ -176,6 +176,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) # use existing connection info if possible + cache: Union[RefreshAheadCache, LazyRefreshCache, StaticConnectionInfoCache] if instance_uri in self._cache: cache = self._cache[instance_uri] elif self._static_conn_info: diff --git a/google/cloud/alloydb/connector/static.py b/google/cloud/alloydb/connector/static.py index 287e29ba..da90245d 100644 --- a/google/cloud/alloydb/connector/static.py +++ b/google/cloud/alloydb/connector/static.py @@ -19,7 +19,6 @@ import json from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import rsa from google.cloud.alloydb.connector.connection_info import ConnectionInfo @@ -70,9 +69,8 @@ def __init__(self, instance_uri: str, static_conn_info: io.TextIOBase) -> None: } expiration = datetime.now(timezone.utc) + timedelta(hours=1) priv_key = static_info["privateKey"] - priv_key_bytes: rsa.RSAPrivateKey = serialization.load_pem_private_key( - priv_key.encode("UTF-8"), - password=None, + priv_key_bytes = serialization.load_pem_private_key( + priv_key.encode("UTF-8"), password=None ) self._info = ConnectionInfo( cert_chain, ca_cert, priv_key_bytes, ip_addrs, expiration diff --git a/google/cloud/alloydb/connector/utils.py b/google/cloud/alloydb/connector/utils.py index 908ce406..e4c99393 100644 --- a/google/cloud/alloydb/connector/utils.py +++ b/google/cloud/alloydb/connector/utils.py @@ -17,10 +17,11 @@ import aiofiles from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes async def _write_to_file( - dir_path: str, ca_cert: str, cert_chain: list[str], key: rsa.RSAPrivateKey + dir_path: str, ca_cert: str, cert_chain: list[str], key: PrivateKeyTypes ) -> tuple[str, str, str]: """ Helper function to write the server_ca, client certificate and From 5d9c473adedeb402493cd2d4b547735b51af2124 Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Tue, 4 Feb 2025 02:57:27 +0000 Subject: [PATCH 12/14] Address PR comments --- .../alloydb/connector/async_connector.py | 15 +++-------- google/cloud/alloydb/connector/connector.py | 9 +++---- google/cloud/alloydb/connector/types.py | 23 ++++++++++++++++ tests/unit/conftest.py | 5 +++- tests/unit/test_async_connector.py | 27 ------------------- 5 files changed, 33 insertions(+), 46 deletions(-) create mode 100644 google/cloud/alloydb/connector/types.py diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index 0db6f3db..bd8e88d3 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -15,10 +15,9 @@ from __future__ import annotations import asyncio -import io import logging from types import TracebackType -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, Optional, TYPE_CHECKING import google.auth from google.auth.credentials import with_scopes_if_required @@ -30,7 +29,7 @@ from google.cloud.alloydb.connector.enums import RefreshStrategy from google.cloud.alloydb.connector.instance import RefreshAheadCache from google.cloud.alloydb.connector.lazy import LazyRefreshCache -from google.cloud.alloydb.connector.static import StaticConnectionInfoCache +from google.cloud.alloydb.connector.types import CacheTypes from google.cloud.alloydb.connector.utils import generate_keys if TYPE_CHECKING: @@ -61,9 +60,6 @@ class AsyncConnector: of the following: RefreshStrategy.LAZY ("LAZY") or RefreshStrategy.BACKGROUND ("BACKGROUND"). Default: RefreshStrategy.BACKGROUND - static_conn_info (io.TextIOBase): A file-like JSON object that contains - static connection info for the StaticConnectionInfoCache. - Defaults to None, which will not use the StaticConnectionInfoCache. """ def __init__( @@ -75,9 +71,8 @@ def __init__( ip_type: str | IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, - static_conn_info: Optional[io.TextIOBase] = None, ) -> None: - self._cache: dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {} + self._cache: dict[str, CacheTypes] = {} # initialize default params self._quota_project = quota_project self._alloydb_api_endpoint = alloydb_api_endpoint @@ -106,7 +101,6 @@ def __init__( except RuntimeError: self._keys = None self._client: Optional[AlloyDBClient] = None - self._static_conn_info = static_conn_info async def connect( self, @@ -147,11 +141,8 @@ async def connect( enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) # use existing connection info if possible - cache: Union[RefreshAheadCache, LazyRefreshCache, StaticConnectionInfoCache] if instance_uri in self._cache: cache = self._cache[instance_uri] - elif self._static_conn_info: - cache = StaticConnectionInfoCache(instance_uri, self._static_conn_info) else: if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 755e2b68..23e567b3 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -22,7 +22,7 @@ import struct from threading import Thread from types import TracebackType -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, Optional, TYPE_CHECKING from google.auth import default from google.auth.credentials import TokenState @@ -35,7 +35,7 @@ from google.cloud.alloydb.connector.instance import RefreshAheadCache from google.cloud.alloydb.connector.lazy import LazyRefreshCache import google.cloud.alloydb.connector.pg8000 as pg8000 -from google.cloud.alloydb.connector.static import StaticConnectionInfoCache +from google.cloud.alloydb.connector.types import CacheTypes from google.cloud.alloydb.connector.utils import generate_keys import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb @@ -93,7 +93,7 @@ def __init__( self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() self._thread = Thread(target=self._loop.run_forever, daemon=True) self._thread.start() - self._cache: dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {} + self._cache: dict[str, CacheTypes] = {} # initialize default params self._quota_project = quota_project self._alloydb_api_endpoint = alloydb_api_endpoint @@ -176,11 +176,8 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) # use existing connection info if possible - cache: Union[RefreshAheadCache, LazyRefreshCache, StaticConnectionInfoCache] if instance_uri in self._cache: cache = self._cache[instance_uri] - elif self._static_conn_info: - cache = StaticConnectionInfoCache(instance_uri, self._static_conn_info) else: if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( diff --git a/google/cloud/alloydb/connector/types.py b/google/cloud/alloydb/connector/types.py new file mode 100644 index 00000000..486370c7 --- /dev/null +++ b/google/cloud/alloydb/connector/types.py @@ -0,0 +1,23 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +from google.cloud.alloydb.connector.instance import RefreshAheadCache +from google.cloud.alloydb.connector.lazy import LazyRefreshCache +from google.cloud.alloydb.connector.static import StaticConnectionInfoCache + +CacheTypes = typing.Union[ + RefreshAheadCache, LazyRefreshCache, StaticConnectionInfoCache +] diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index b81c7ccb..890378d0 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -16,6 +16,7 @@ import socket import ssl from threading import Thread +from typing import Final from aiofiles.tempfile import TemporaryDirectory from mocks import FakeAlloyDBClient @@ -26,6 +27,8 @@ from google.cloud.alloydb.connector.utils import _write_to_file +DELAY: Final[float] = 1.0 + @pytest.fixture def credentials() -> FakeCredentials: @@ -86,4 +89,4 @@ def proxy_server(fake_instance: FakeInstance) -> None: daemon=True, ) thread.start() - thread.join(0.1) # wait 100ms to allow the proxy server to start + thread.join(DELAY) # add a delay to allow the proxy server to start diff --git a/tests/unit/test_async_connector.py b/tests/unit/test_async_connector.py index e2da0ee4..0f150875 100644 --- a/tests/unit/test_async_connector.py +++ b/tests/unit/test_async_connector.py @@ -20,7 +20,6 @@ from mocks import FakeAlloyDBClient from mocks import FakeConnectionInfo from mocks import FakeCredentials -from mocks import write_static_info import pytest from google.cloud.alloydb.connector import AsyncConnector @@ -334,29 +333,3 @@ async def test_Connector_remove_cached_no_ip_type(credentials: FakeCredentials) await connector.connect(instance_uri, "asyncpg", ip_type="private") # check that cache has been removed from dict assert instance_uri not in connector._cache - - -async def test_Connector_static_connection_info( - credentials: FakeCredentials, fake_client: FakeAlloyDBClient -) -> None: - """ - Test that AsyncConnector.__init__() can specify a static connection info to - connect to an instance. - """ - static_info = write_static_info(fake_client.instance) - async with AsyncConnector( - credentials=credentials, static_conn_info=static_info - ) as connector: - connector._client = fake_client - # patch db connection creation - with patch("google.cloud.alloydb.connector.asyncpg.connect") as mock_connect: - mock_connect.return_value = True - connection = await connector.connect( - fake_client.instance.uri(), - "asyncpg", - user="test-user", - password="test-password", - db="test-db", - ) - # check connection is returned - assert connection is True From 171e6e14c21baf0fe7bc5b0fecda162dcb8a110b Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Fri, 7 Feb 2025 03:05:03 +0000 Subject: [PATCH 13/14] Add unit tests for StaticConnectionInfoCache --- tests/unit/mocks.py | 8 ++-- tests/unit/test_static.py | 95 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 4 deletions(-) create mode 100644 tests/unit/test_static.py diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 4dc0b257..cb64c819 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -149,7 +149,7 @@ def __init__( cluster: str = "test-cluster", name: str = "test-instance", ip_addrs: dict = { - "PRIVATE": "127.0.0.1", + "PRIVATE": "127.0.0.1", # "private" IP is localhost in testing "PUBLIC": "0.0.0.0", "PSC": "x.y.alloydb.goog", }, @@ -443,8 +443,8 @@ def write_static_info(i: FakeInstance) -> io.StringIO: static[i.uri()] = { "pemCertificateChain": chain, "caCert": ca_cert, - "ipAddress": "127.0.0.1", # "private" IP is localhost in testing - "publicIpAddress": "", - "pscInstanceConfig": {"pscDnsName": ""}, + "ipAddress": i.ip_addrs["PRIVATE"], + "publicIpAddress": i.ip_addrs["PUBLIC"], + "pscInstanceConfig": {"pscDnsName": i.ip_addrs["PSC"]}, } return io.StringIO(json.dumps(static)) diff --git a/tests/unit/test_static.py b/tests/unit/test_static.py new file mode 100644 index 00000000..30e82cff --- /dev/null +++ b/tests/unit/test_static.py @@ -0,0 +1,95 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mocks import FakeInstance, write_static_info + +from google.cloud.alloydb.connector.connection_info import ConnectionInfo +from google.cloud.alloydb.connector.static import StaticConnectionInfoCache + + +def test_StaticConnectionInfoCache_init() -> None: + """ + Test that StaticConnectionInfoCache.__init__ populates its ConnectionInfo + object. + """ + i = FakeInstance() + static_info = write_static_info(i) + cache = StaticConnectionInfoCache(i.uri(), static_info) + assert len(cache._info.cert_chain) == 3 + assert cache._info.ca_cert + assert cache._info.key + assert cache._info.ip_addrs == { + "PRIVATE": i.ip_addrs["PRIVATE"], + "PUBLIC": i.ip_addrs["PUBLIC"], + "PSC": i.ip_addrs["PSC"], + } + assert cache._info.expiration + +def test_StaticConnectionInfoCache_init_trailing_dot_dns() -> None: + """ + Test that StaticConnectionInfoCache.__init__ populates its ConnectionInfo + object correctly when its PSC DNS name contains a trailing dot. + """ + i = FakeInstance() + no_trailing_dot_dns = i.ip_addrs["PSC"] + i.ip_addrs["PSC"] += "." + static_info = write_static_info(i) + cache = StaticConnectionInfoCache(i.uri(), static_info) + assert len(cache._info.cert_chain) == 3 + assert cache._info.ca_cert + assert cache._info.key + assert cache._info.ip_addrs == { + "PRIVATE": i.ip_addrs["PRIVATE"], + "PUBLIC": i.ip_addrs["PUBLIC"], + "PSC": no_trailing_dot_dns, + } + assert cache._info.expiration + +async def test_StaticConnectionInfoCache_force_refresh() -> None: + """ + Test that StaticConnectionInfoCache.force_refresh is a no-op. + """ + i = FakeInstance() + static_info = write_static_info(i) + cache = StaticConnectionInfoCache(i.uri(), static_info) + conn_info = cache._info + await cache.force_refresh() + conn_info2 = cache._info + assert conn_info2 == conn_info + +async def test_StaticConnectionInfoCache_connect_info() -> None: + """ + Test that StaticConnectionInfoCache.connect_info works as expected. + """ + i = FakeInstance() + static_info = write_static_info(i) + cache = StaticConnectionInfoCache(i.uri(), static_info) + # check that cached connection info is now set + assert isinstance(cache._info, ConnectionInfo) + conn_info = await cache.connect_info() + # check that calling connect_info uses cached info + conn_info2 = await cache.connect_info() + assert conn_info2 == conn_info + +async def test_StaticConnectionInfoCache_close() -> None: + """ + Test that StaticConnectionInfoCache.close is a no-op. + """ + i = FakeInstance() + static_info = write_static_info(i) + cache = StaticConnectionInfoCache(i.uri(), static_info) + conn_info = cache._info + await cache.close() + conn_info2 = cache._info + assert conn_info2 == conn_info From 0d75bd9ffa8d58901f56da723129fc9fc13e64e3 Mon Sep 17 00:00:00 2001 From: rhatgadkar-goog Date: Fri, 7 Feb 2025 18:09:31 +0000 Subject: [PATCH 14/14] Address PR comments --- google/cloud/alloydb/connector/connector.py | 3 +++ google/cloud/alloydb/connector/static.py | 7 ++++++- tests/unit/conftest.py | 3 +-- tests/unit/test_static.py | 12 +++++++++--- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 23e567b3..41a13fa3 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -76,6 +76,9 @@ class Connector: static_conn_info (io.TextIOBase): A file-like JSON object that contains static connection info for the StaticConnectionInfoCache. Defaults to None, which will not use the StaticConnectionInfoCache. + This is a *dev-only* option and should not be used in production as + it will result in failed connections after the client certificate + expires. """ def __init__( diff --git a/google/cloud/alloydb/connector/static.py b/google/cloud/alloydb/connector/static.py index da90245d..8ee6732e 100644 --- a/google/cloud/alloydb/connector/static.py +++ b/google/cloud/alloydb/connector/static.py @@ -26,7 +26,12 @@ class StaticConnectionInfoCache: """ StaticConnectionInfoCache creates a connection info cache that will always - return a pre-defined connection info. + return a pre-defined connection info. This is a *dev-only* option and + should not be used in production as it will result in failed connections + after the client certificate expires. It is also subject to breaking changes + in the format. NOTE: The static connection info is not refreshed by the + connector. The JSON format supports multiple instances, regardless of + cluster. This static connection info should hold JSON with the following format: { diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 890378d0..ca5d0910 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -16,7 +16,6 @@ import socket import ssl from threading import Thread -from typing import Final from aiofiles.tempfile import TemporaryDirectory from mocks import FakeAlloyDBClient @@ -27,7 +26,7 @@ from google.cloud.alloydb.connector.utils import _write_to_file -DELAY: Final[float] = 1.0 +DELAY = 1.0 @pytest.fixture diff --git a/tests/unit/test_static.py b/tests/unit/test_static.py index 30e82cff..77c33eba 100644 --- a/tests/unit/test_static.py +++ b/tests/unit/test_static.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mocks import FakeInstance, write_static_info +from mocks import FakeInstance +from mocks import write_static_info from google.cloud.alloydb.connector.connection_info import ConnectionInfo from google.cloud.alloydb.connector.static import StaticConnectionInfoCache @@ -36,6 +37,7 @@ def test_StaticConnectionInfoCache_init() -> None: } assert cache._info.expiration + def test_StaticConnectionInfoCache_init_trailing_dot_dns() -> None: """ Test that StaticConnectionInfoCache.__init__ populates its ConnectionInfo @@ -56,6 +58,7 @@ def test_StaticConnectionInfoCache_init_trailing_dot_dns() -> None: } assert cache._info.expiration + async def test_StaticConnectionInfoCache_force_refresh() -> None: """ Test that StaticConnectionInfoCache.force_refresh is a no-op. @@ -68,20 +71,23 @@ async def test_StaticConnectionInfoCache_force_refresh() -> None: conn_info2 = cache._info assert conn_info2 == conn_info + async def test_StaticConnectionInfoCache_connect_info() -> None: """ - Test that StaticConnectionInfoCache.connect_info works as expected. + Test that StaticConnectionInfoCache.connect_info returns the ConnectionInfo + object. """ i = FakeInstance() static_info = write_static_info(i) cache = StaticConnectionInfoCache(i.uri(), static_info) # check that cached connection info is now set assert isinstance(cache._info, ConnectionInfo) - conn_info = await cache.connect_info() + conn_info = cache._info # check that calling connect_info uses cached info conn_info2 = await cache.connect_info() assert conn_info2 == conn_info + async def test_StaticConnectionInfoCache_close() -> None: """ Test that StaticConnectionInfoCache.close is a no-op.