Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions google/cloud/alloydb/connector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import asyncio
import logging
from types import TracebackType
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING

import google.auth
from google.auth.credentials import with_scopes_if_required
Expand All @@ -29,6 +29,7 @@
from google.cloud.alloydb.connector.enums import RefreshStrategy
from google.cloud.alloydb.connector.instance import RefreshAheadCache
from google.cloud.alloydb.connector.lazy import LazyRefreshCache
from google.cloud.alloydb.connector.types import CacheTypes
from google.cloud.alloydb.connector.utils import generate_keys

if TYPE_CHECKING:
Expand Down Expand Up @@ -71,7 +72,7 @@ def __init__(
user_agent: Optional[str] = None,
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
) -> None:
self._cache: dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {}
self._cache: dict[str, CacheTypes] = {}
# initialize default params
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
Expand Down
4 changes: 2 additions & 2 deletions google/cloud/alloydb/connector/connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
if TYPE_CHECKING:
import datetime

from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes

from google.cloud.alloydb.connector.enums import IPTypes

Expand All @@ -41,7 +41,7 @@ class ConnectionInfo:

cert_chain: list[str]
ca_cert: str
key: rsa.RSAPrivateKey
key: PrivateKeyTypes
ip_addrs: dict[str, Optional[str]]
expiration: datetime.datetime
context: Optional[ssl.SSLContext] = None
Expand Down
14 changes: 12 additions & 2 deletions google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@

import asyncio
from functools import partial
import io
import logging
import socket
import struct
from threading import Thread
from types import TracebackType
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING

from google.auth import default
from google.auth.credentials import TokenState
Expand All @@ -34,6 +35,7 @@
from google.cloud.alloydb.connector.instance import RefreshAheadCache
from google.cloud.alloydb.connector.lazy import LazyRefreshCache
import google.cloud.alloydb.connector.pg8000 as pg8000
from google.cloud.alloydb.connector.types import CacheTypes
from google.cloud.alloydb.connector.utils import generate_keys
import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb

Expand Down Expand Up @@ -71,6 +73,12 @@ class Connector:
of the following: RefreshStrategy.LAZY ("LAZY") or
RefreshStrategy.BACKGROUND ("BACKGROUND").
Default: RefreshStrategy.BACKGROUND
static_conn_info (io.TextIOBase): A file-like JSON object that contains
static connection info for the StaticConnectionInfoCache.
Defaults to None, which will not use the StaticConnectionInfoCache.
This is a *dev-only* option and should not be used in production as
it will result in failed connections after the client certificate
expires.
"""

def __init__(
Expand All @@ -82,12 +90,13 @@ def __init__(
ip_type: str | IPTypes = IPTypes.PRIVATE,
user_agent: Optional[str] = None,
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
static_conn_info: Optional[io.TextIOBase] = None,
) -> None:
# create event loop and start it in background thread
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
self._thread = Thread(target=self._loop.run_forever, daemon=True)
self._thread.start()
self._cache: dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {}
self._cache: dict[str, CacheTypes] = {}
# initialize default params
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
Expand All @@ -113,6 +122,7 @@ def __init__(
loop=self._loop,
)
self._client: Optional[AlloyDBClient] = None
self._static_conn_info = static_conn_info

def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any:
"""
Expand Down
102 changes: 102 additions & 0 deletions google/cloud/alloydb/connector/static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
from datetime import timedelta
from datetime import timezone
import io
import json

from cryptography.hazmat.primitives import serialization

from google.cloud.alloydb.connector.connection_info import ConnectionInfo


class StaticConnectionInfoCache:
"""
StaticConnectionInfoCache creates a connection info cache that will always
return a pre-defined connection info. This is a *dev-only* option and
should not be used in production as it will result in failed connections
after the client certificate expires. It is also subject to breaking changes
in the format. NOTE: The static connection info is not refreshed by the
connector. The JSON format supports multiple instances, regardless of
cluster.

This static connection info should hold JSON with the following format:
{
"publicKey": "<PEM Encoded public RSA key>",
"privateKey": "<PEM Encoded private RSA key>",
"projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>": {
"ipAddress": "<PSA-based private IP address>",
"publicIpAddress": "<public IP address>",
"pscInstanceConfig": {
"pscDnsName": "<PSC DNS name>"
},
"pemCertificateChain": [
"<client cert>", "<intermediate cert>", "<CA cert>"
],
"caCert": "<CA cert>"
}
}
"""

def __init__(self, instance_uri: str, static_conn_info: io.TextIOBase) -> None:
"""
Initializes a StaticConnectionInfoCache instance.

Args:
instance_uri (str): The AlloyDB instance's connection URI.
static_conn_info (io.TextIOBase): The static connection info JSON.
"""
static_info = json.load(static_conn_info)
ca_cert = static_info[instance_uri]["caCert"]
cert_chain = static_info[instance_uri]["pemCertificateChain"]
dns = ""
if static_info[instance_uri]["pscInstanceConfig"]:
dns = static_info[instance_uri]["pscInstanceConfig"]["pscDnsName"].rstrip(
"."
)
ip_addrs = {
"PRIVATE": static_info[instance_uri]["ipAddress"],
"PUBLIC": static_info[instance_uri]["publicIpAddress"],
"PSC": dns,
}
expiration = datetime.now(timezone.utc) + timedelta(hours=1)
priv_key = static_info["privateKey"]
priv_key_bytes = serialization.load_pem_private_key(
priv_key.encode("UTF-8"), password=None
)
self._info = ConnectionInfo(
cert_chain, ca_cert, priv_key_bytes, ip_addrs, expiration
)

async def force_refresh(self) -> None:
"""
This is a no-op as the cache holds only static connection information
and does no refresh.
"""
pass

async def connect_info(self) -> ConnectionInfo:
"""
Retrieves ConnectionInfo instance for establishing a secure
connection to the AlloyDB instance.
"""
return self._info

async def close(self) -> None:
"""
This is a no-op.
"""
pass
23 changes: 23 additions & 0 deletions google/cloud/alloydb/connector/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

from google.cloud.alloydb.connector.instance import RefreshAheadCache
from google.cloud.alloydb.connector.lazy import LazyRefreshCache
from google.cloud.alloydb.connector.static import StaticConnectionInfoCache

CacheTypes = typing.Union[
RefreshAheadCache, LazyRefreshCache, StaticConnectionInfoCache
]
3 changes: 2 additions & 1 deletion google/cloud/alloydb/connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
import aiofiles
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes


async def _write_to_file(
dir_path: str, ca_cert: str, cert_chain: list[str], key: rsa.RSAPrivateKey
dir_path: str, ca_cert: str, cert_chain: list[str], key: PrivateKeyTypes
) -> tuple[str, str, str]:
"""
Helper function to write the server_ca, client certificate and
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import socket
import ssl
from threading import Thread
from typing import Generator

from aiofiles.tempfile import TemporaryDirectory
from mocks import FakeAlloyDBClient
Expand All @@ -27,6 +26,8 @@

from google.cloud.alloydb.connector.utils import _write_to_file

DELAY = 1.0


@pytest.fixture
def credentials() -> FakeCredentials:
Expand Down Expand Up @@ -66,16 +67,16 @@ async def start_proxy_server(instance: FakeInstance) -> None:
# listen for incoming connections
sock.listen(5)

while True:
with context.wrap_socket(sock, server_side=True) as ssock:
with context.wrap_socket(sock, server_side=True) as ssock:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is needed to run multiple tests that use the same proxy_server fixture. Without it, only 1 test can successfully run. The other tests fail with "Connection refused" errors.

while True:
conn, _ = ssock.accept()
metadata_exchange(conn)
conn.sendall(instance.name.encode("utf-8"))
conn.close()


@pytest.fixture(scope="session")
def proxy_server(fake_instance: FakeInstance) -> Generator:
def proxy_server(fake_instance: FakeInstance) -> None:
"""Run local proxy server capable of performing metadata exchange"""
thread = Thread(
target=asyncio.run,
Expand All @@ -87,5 +88,4 @@ def proxy_server(fake_instance: FakeInstance) -> Generator:
daemon=True,
)
thread.start()
yield thread
thread.join()
thread.join(DELAY) # add a delay to allow the proxy server to start
72 changes: 71 additions & 1 deletion tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import io
import ipaddress
import json
import ssl
import struct
from typing import Any, Callable, Literal, Optional
Expand Down Expand Up @@ -147,7 +149,7 @@ def __init__(
cluster: str = "test-cluster",
name: str = "test-instance",
ip_addrs: dict = {
"PRIVATE": "127.0.0.1",
"PRIVATE": "127.0.0.1", # "private" IP is localhost in testing
"PUBLIC": "0.0.0.0",
"PSC": "x.y.alloydb.goog",
},
Expand Down Expand Up @@ -194,6 +196,34 @@ def get_pem_certs(self) -> tuple[str, str, str]:
).decode("UTF-8")
return (pem_root, pem_intermediate, pem_server)

def generate_pem_certificate_chain(self, pub_key: str) -> tuple[str, list[str]]:
"""Generate the CA certificate and certificate chain for the AlloyDB instance."""
root_cert, intermediate_cert, server_cert = self.get_pem_certs()
# encode public key to bytes
pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key(
pub_key.encode("UTF-8"),
)
# build client cert
client_cert = (
x509.CertificateBuilder()
.subject_name(self.intermediate_cert.subject)
.issuer_name(self.intermediate_cert.issuer)
.public_key(pub_key_bytes)
.serial_number(x509.random_serial_number())
.not_valid_before(self.cert_before)
.not_valid_after(self.cert_expiry)
)
# sign client cert with intermediate cert
client_cert = client_cert.sign(self.intermediate_key, hashes.SHA256())
client_cert = client_cert.public_bytes(
encoding=serialization.Encoding.PEM
).decode("UTF-8")
return (server_cert, [client_cert, intermediate_cert, root_cert])

def uri(self) -> str:
"""The URI of the AlloyDB instance."""
return f"projects/{self.project}/locations/{self.region}/clusters/{self.cluster}/instances/{self.name}"


class FakeAlloyDBClient:
"""Fake class for testing AlloyDBClient"""
Expand Down Expand Up @@ -378,3 +408,43 @@ async def force_refresh(self) -> None:

async def close(self) -> None:
self._close_called = True


def write_static_info(i: FakeInstance) -> io.StringIO:
"""
Creates a static connection info JSON for the StaticConnectionInfoCache.

Args:
i (FakeInstance): The FakeInstance to use to create the CA cert and
chain.

Returns:
io.StringIO
"""
priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
pub_pem = (
priv_key.public_key()
.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
.decode("UTF-8")
)
priv_pem = priv_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
).decode("UTF-8")
ca_cert, chain = i.generate_pem_certificate_chain(pub_pem)
static = {
"publicKey": pub_pem,
"privateKey": priv_pem,
}
static[i.uri()] = {
"pemCertificateChain": chain,
"caCert": ca_cert,
"ipAddress": i.ip_addrs["PRIVATE"],
"publicIpAddress": i.ip_addrs["PUBLIC"],
"pscInstanceConfig": {"pscDnsName": i.ip_addrs["PSC"]},
}
return io.StringIO(json.dumps(static))
Loading
Loading