diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index 5e8e6bd6..bd8e88d3 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -17,7 +17,7 @@ import asyncio import logging from types import TracebackType -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, Optional, TYPE_CHECKING import google.auth from google.auth.credentials import with_scopes_if_required @@ -29,6 +29,7 @@ from google.cloud.alloydb.connector.enums import RefreshStrategy from google.cloud.alloydb.connector.instance import RefreshAheadCache from google.cloud.alloydb.connector.lazy import LazyRefreshCache +from google.cloud.alloydb.connector.types import CacheTypes from google.cloud.alloydb.connector.utils import generate_keys if TYPE_CHECKING: @@ -71,7 +72,7 @@ def __init__( user_agent: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, ) -> None: - self._cache: dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {} + self._cache: dict[str, CacheTypes] = {} # initialize default params self._quota_project = quota_project self._alloydb_api_endpoint = alloydb_api_endpoint diff --git a/google/cloud/alloydb/connector/connection_info.py b/google/cloud/alloydb/connector/connection_info.py index 65ca231d..aed28adb 100644 --- a/google/cloud/alloydb/connector/connection_info.py +++ b/google/cloud/alloydb/connector/connection_info.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: import datetime - from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes from google.cloud.alloydb.connector.enums import IPTypes @@ -41,7 +41,7 @@ class ConnectionInfo: cert_chain: list[str] ca_cert: str - key: rsa.RSAPrivateKey + key: PrivateKeyTypes ip_addrs: dict[str, Optional[str]] expiration: datetime.datetime context: Optional[ssl.SSLContext] = None diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 8112aa9b..41a13fa3 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -16,12 +16,13 @@ import asyncio from functools import partial +import io import logging import socket import struct from threading import Thread from types import TracebackType -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, Optional, TYPE_CHECKING from google.auth import default from google.auth.credentials import TokenState @@ -34,6 +35,7 @@ from google.cloud.alloydb.connector.instance import RefreshAheadCache from google.cloud.alloydb.connector.lazy import LazyRefreshCache import google.cloud.alloydb.connector.pg8000 as pg8000 +from google.cloud.alloydb.connector.types import CacheTypes from google.cloud.alloydb.connector.utils import generate_keys import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb @@ -71,6 +73,12 @@ class Connector: of the following: RefreshStrategy.LAZY ("LAZY") or RefreshStrategy.BACKGROUND ("BACKGROUND"). Default: RefreshStrategy.BACKGROUND + static_conn_info (io.TextIOBase): A file-like JSON object that contains + static connection info for the StaticConnectionInfoCache. + Defaults to None, which will not use the StaticConnectionInfoCache. + This is a *dev-only* option and should not be used in production as + it will result in failed connections after the client certificate + expires. """ def __init__( @@ -82,12 +90,13 @@ def __init__( ip_type: str | IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, + static_conn_info: Optional[io.TextIOBase] = None, ) -> None: # create event loop and start it in background thread self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() self._thread = Thread(target=self._loop.run_forever, daemon=True) self._thread.start() - self._cache: dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {} + self._cache: dict[str, CacheTypes] = {} # initialize default params self._quota_project = quota_project self._alloydb_api_endpoint = alloydb_api_endpoint @@ -113,6 +122,7 @@ def __init__( loop=self._loop, ) self._client: Optional[AlloyDBClient] = None + self._static_conn_info = static_conn_info def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any: """ diff --git a/google/cloud/alloydb/connector/static.py b/google/cloud/alloydb/connector/static.py new file mode 100644 index 00000000..8ee6732e --- /dev/null +++ b/google/cloud/alloydb/connector/static.py @@ -0,0 +1,102 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +from datetime import timedelta +from datetime import timezone +import io +import json + +from cryptography.hazmat.primitives import serialization + +from google.cloud.alloydb.connector.connection_info import ConnectionInfo + + +class StaticConnectionInfoCache: + """ + StaticConnectionInfoCache creates a connection info cache that will always + return a pre-defined connection info. This is a *dev-only* option and + should not be used in production as it will result in failed connections + after the client certificate expires. It is also subject to breaking changes + in the format. NOTE: The static connection info is not refreshed by the + connector. The JSON format supports multiple instances, regardless of + cluster. + + This static connection info should hold JSON with the following format: + { + "publicKey": "", + "privateKey": "", + "projects//locations//clusters//instances/": { + "ipAddress": "", + "publicIpAddress": "", + "pscInstanceConfig": { + "pscDnsName": "" + }, + "pemCertificateChain": [ + "", "", "" + ], + "caCert": "" + } + } + """ + + def __init__(self, instance_uri: str, static_conn_info: io.TextIOBase) -> None: + """ + Initializes a StaticConnectionInfoCache instance. + + Args: + instance_uri (str): The AlloyDB instance's connection URI. + static_conn_info (io.TextIOBase): The static connection info JSON. + """ + static_info = json.load(static_conn_info) + ca_cert = static_info[instance_uri]["caCert"] + cert_chain = static_info[instance_uri]["pemCertificateChain"] + dns = "" + if static_info[instance_uri]["pscInstanceConfig"]: + dns = static_info[instance_uri]["pscInstanceConfig"]["pscDnsName"].rstrip( + "." + ) + ip_addrs = { + "PRIVATE": static_info[instance_uri]["ipAddress"], + "PUBLIC": static_info[instance_uri]["publicIpAddress"], + "PSC": dns, + } + expiration = datetime.now(timezone.utc) + timedelta(hours=1) + priv_key = static_info["privateKey"] + priv_key_bytes = serialization.load_pem_private_key( + priv_key.encode("UTF-8"), password=None + ) + self._info = ConnectionInfo( + cert_chain, ca_cert, priv_key_bytes, ip_addrs, expiration + ) + + async def force_refresh(self) -> None: + """ + This is a no-op as the cache holds only static connection information + and does no refresh. + """ + pass + + async def connect_info(self) -> ConnectionInfo: + """ + Retrieves ConnectionInfo instance for establishing a secure + connection to the AlloyDB instance. + """ + return self._info + + async def close(self) -> None: + """ + This is a no-op. + """ + pass diff --git a/google/cloud/alloydb/connector/types.py b/google/cloud/alloydb/connector/types.py new file mode 100644 index 00000000..486370c7 --- /dev/null +++ b/google/cloud/alloydb/connector/types.py @@ -0,0 +1,23 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +from google.cloud.alloydb.connector.instance import RefreshAheadCache +from google.cloud.alloydb.connector.lazy import LazyRefreshCache +from google.cloud.alloydb.connector.static import StaticConnectionInfoCache + +CacheTypes = typing.Union[ + RefreshAheadCache, LazyRefreshCache, StaticConnectionInfoCache +] diff --git a/google/cloud/alloydb/connector/utils.py b/google/cloud/alloydb/connector/utils.py index 908ce406..e4c99393 100644 --- a/google/cloud/alloydb/connector/utils.py +++ b/google/cloud/alloydb/connector/utils.py @@ -17,10 +17,11 @@ import aiofiles from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes async def _write_to_file( - dir_path: str, ca_cert: str, cert_chain: list[str], key: rsa.RSAPrivateKey + dir_path: str, ca_cert: str, cert_chain: list[str], key: PrivateKeyTypes ) -> tuple[str, str, str]: """ Helper function to write the server_ca, client certificate and diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 45648fa7..ca5d0910 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -16,7 +16,6 @@ import socket import ssl from threading import Thread -from typing import Generator from aiofiles.tempfile import TemporaryDirectory from mocks import FakeAlloyDBClient @@ -27,6 +26,8 @@ from google.cloud.alloydb.connector.utils import _write_to_file +DELAY = 1.0 + @pytest.fixture def credentials() -> FakeCredentials: @@ -66,8 +67,8 @@ async def start_proxy_server(instance: FakeInstance) -> None: # listen for incoming connections sock.listen(5) - while True: - with context.wrap_socket(sock, server_side=True) as ssock: + with context.wrap_socket(sock, server_side=True) as ssock: + while True: conn, _ = ssock.accept() metadata_exchange(conn) conn.sendall(instance.name.encode("utf-8")) @@ -75,7 +76,7 @@ async def start_proxy_server(instance: FakeInstance) -> None: @pytest.fixture(scope="session") -def proxy_server(fake_instance: FakeInstance) -> Generator: +def proxy_server(fake_instance: FakeInstance) -> None: """Run local proxy server capable of performing metadata exchange""" thread = Thread( target=asyncio.run, @@ -87,5 +88,4 @@ def proxy_server(fake_instance: FakeInstance) -> Generator: daemon=True, ) thread.start() - yield thread - thread.join() + thread.join(DELAY) # add a delay to allow the proxy server to start diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index be8e4750..cb64c819 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -16,7 +16,9 @@ from datetime import datetime from datetime import timedelta from datetime import timezone +import io import ipaddress +import json import ssl import struct from typing import Any, Callable, Literal, Optional @@ -147,7 +149,7 @@ def __init__( cluster: str = "test-cluster", name: str = "test-instance", ip_addrs: dict = { - "PRIVATE": "127.0.0.1", + "PRIVATE": "127.0.0.1", # "private" IP is localhost in testing "PUBLIC": "0.0.0.0", "PSC": "x.y.alloydb.goog", }, @@ -194,6 +196,34 @@ def get_pem_certs(self) -> tuple[str, str, str]: ).decode("UTF-8") return (pem_root, pem_intermediate, pem_server) + def generate_pem_certificate_chain(self, pub_key: str) -> tuple[str, list[str]]: + """Generate the CA certificate and certificate chain for the AlloyDB instance.""" + root_cert, intermediate_cert, server_cert = self.get_pem_certs() + # encode public key to bytes + pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key( + pub_key.encode("UTF-8"), + ) + # build client cert + client_cert = ( + x509.CertificateBuilder() + .subject_name(self.intermediate_cert.subject) + .issuer_name(self.intermediate_cert.issuer) + .public_key(pub_key_bytes) + .serial_number(x509.random_serial_number()) + .not_valid_before(self.cert_before) + .not_valid_after(self.cert_expiry) + ) + # sign client cert with intermediate cert + client_cert = client_cert.sign(self.intermediate_key, hashes.SHA256()) + client_cert = client_cert.public_bytes( + encoding=serialization.Encoding.PEM + ).decode("UTF-8") + return (server_cert, [client_cert, intermediate_cert, root_cert]) + + def uri(self) -> str: + """The URI of the AlloyDB instance.""" + return f"projects/{self.project}/locations/{self.region}/clusters/{self.cluster}/instances/{self.name}" + class FakeAlloyDBClient: """Fake class for testing AlloyDBClient""" @@ -378,3 +408,43 @@ async def force_refresh(self) -> None: async def close(self) -> None: self._close_called = True + + +def write_static_info(i: FakeInstance) -> io.StringIO: + """ + Creates a static connection info JSON for the StaticConnectionInfoCache. + + Args: + i (FakeInstance): The FakeInstance to use to create the CA cert and + chain. + + Returns: + io.StringIO + """ + priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + pub_pem = ( + priv_key.public_key() + .public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + .decode("UTF-8") + ) + priv_pem = priv_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ).decode("UTF-8") + ca_cert, chain = i.generate_pem_certificate_chain(pub_pem) + static = { + "publicKey": pub_pem, + "privateKey": priv_pem, + } + static[i.uri()] = { + "pemCertificateChain": chain, + "caCert": ca_cert, + "ipAddress": i.ip_addrs["PRIVATE"], + "publicIpAddress": i.ip_addrs["PUBLIC"], + "pscInstanceConfig": {"pscDnsName": i.ip_addrs["PSC"]}, + } + return io.StringIO(json.dumps(static)) diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index a02ad30e..c7660d1c 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -20,6 +20,7 @@ from mock import patch from mocks import FakeAlloyDBClient from mocks import FakeCredentials +from mocks import write_static_info import pytest from google.cloud.alloydb.connector import Connector @@ -248,3 +249,28 @@ async def test_Connector_remove_cached_no_ip_type(credentials: FakeCredentials) await connector.connect_async(instance_uri, "pg8000", ip_type="private") # check that cache has been removed from dict assert instance_uri not in connector._cache + + +@pytest.mark.usefixtures("proxy_server") +def test_Connector_static_connection_info( + credentials: FakeCredentials, fake_client: FakeAlloyDBClient +) -> None: + """ + Test that Connector.__init__() can specify a static connection info to + connect to an instance. + """ + static_info = write_static_info(fake_client.instance) + with Connector(credentials=credentials, static_conn_info=static_info) as connector: + connector._client = fake_client + # patch db connection creation + with patch("google.cloud.alloydb.connector.pg8000.connect") as mock_connect: + mock_connect.return_value = True + connection = connector.connect( + fake_client.instance.uri(), + "pg8000", + user="test-user", + password="test-password", + db="test-db", + ) + # check connection is returned + assert connection is True diff --git a/tests/unit/test_static.py b/tests/unit/test_static.py new file mode 100644 index 00000000..77c33eba --- /dev/null +++ b/tests/unit/test_static.py @@ -0,0 +1,101 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mocks import FakeInstance +from mocks import write_static_info + +from google.cloud.alloydb.connector.connection_info import ConnectionInfo +from google.cloud.alloydb.connector.static import StaticConnectionInfoCache + + +def test_StaticConnectionInfoCache_init() -> None: + """ + Test that StaticConnectionInfoCache.__init__ populates its ConnectionInfo + object. + """ + i = FakeInstance() + static_info = write_static_info(i) + cache = StaticConnectionInfoCache(i.uri(), static_info) + assert len(cache._info.cert_chain) == 3 + assert cache._info.ca_cert + assert cache._info.key + assert cache._info.ip_addrs == { + "PRIVATE": i.ip_addrs["PRIVATE"], + "PUBLIC": i.ip_addrs["PUBLIC"], + "PSC": i.ip_addrs["PSC"], + } + assert cache._info.expiration + + +def test_StaticConnectionInfoCache_init_trailing_dot_dns() -> None: + """ + Test that StaticConnectionInfoCache.__init__ populates its ConnectionInfo + object correctly when its PSC DNS name contains a trailing dot. + """ + i = FakeInstance() + no_trailing_dot_dns = i.ip_addrs["PSC"] + i.ip_addrs["PSC"] += "." + static_info = write_static_info(i) + cache = StaticConnectionInfoCache(i.uri(), static_info) + assert len(cache._info.cert_chain) == 3 + assert cache._info.ca_cert + assert cache._info.key + assert cache._info.ip_addrs == { + "PRIVATE": i.ip_addrs["PRIVATE"], + "PUBLIC": i.ip_addrs["PUBLIC"], + "PSC": no_trailing_dot_dns, + } + assert cache._info.expiration + + +async def test_StaticConnectionInfoCache_force_refresh() -> None: + """ + Test that StaticConnectionInfoCache.force_refresh is a no-op. + """ + i = FakeInstance() + static_info = write_static_info(i) + cache = StaticConnectionInfoCache(i.uri(), static_info) + conn_info = cache._info + await cache.force_refresh() + conn_info2 = cache._info + assert conn_info2 == conn_info + + +async def test_StaticConnectionInfoCache_connect_info() -> None: + """ + Test that StaticConnectionInfoCache.connect_info returns the ConnectionInfo + object. + """ + i = FakeInstance() + static_info = write_static_info(i) + cache = StaticConnectionInfoCache(i.uri(), static_info) + # check that cached connection info is now set + assert isinstance(cache._info, ConnectionInfo) + conn_info = cache._info + # check that calling connect_info uses cached info + conn_info2 = await cache.connect_info() + assert conn_info2 == conn_info + + +async def test_StaticConnectionInfoCache_close() -> None: + """ + Test that StaticConnectionInfoCache.close is a no-op. + """ + i = FakeInstance() + static_info = write_static_info(i) + cache = StaticConnectionInfoCache(i.uri(), static_info) + conn_info = cache._info + await cache.close() + conn_info2 = cache._info + assert conn_info2 == conn_info