Skip to content

Commit c65f700

Browse files
feat: support static connection info
1 parent 58ef2d2 commit c65f700

File tree

7 files changed

+281
-26
lines changed

7 files changed

+281
-26
lines changed

google/cloud/alloydb/connector/async_connector.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import asyncio
18+
import io
1819
import logging
1920
from types import TracebackType
2021
from typing import Any, Dict, Optional, Type, TYPE_CHECKING, Union
@@ -29,6 +30,7 @@
2930
from google.cloud.alloydb.connector.enums import RefreshStrategy
3031
from google.cloud.alloydb.connector.instance import RefreshAheadCache
3132
from google.cloud.alloydb.connector.lazy import LazyRefreshCache
33+
from google.cloud.alloydb.connector.static import StaticConnectionInfoCache
3234
from google.cloud.alloydb.connector.utils import generate_keys
3335

3436
if TYPE_CHECKING:
@@ -59,6 +61,9 @@ class AsyncConnector:
5961
of the following: RefreshStrategy.LAZY ("LAZY") or
6062
RefreshStrategy.BACKGROUND ("BACKGROUND").
6163
Default: RefreshStrategy.BACKGROUND
64+
static_conn_info (io.TextIOBase): A file-like JSON object that contains
65+
static connection info for the StaticConnectionInfoCache.
66+
Defaults to None, which will not use the StaticConnectionInfoCache.
6267
"""
6368

6469
def __init__(
@@ -70,6 +75,7 @@ def __init__(
7075
ip_type: str | IPTypes = IPTypes.PRIVATE,
7176
user_agent: Optional[str] = None,
7277
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
78+
static_conn_info: io.TextIOBase = None,
7379
) -> None:
7480
self._cache: Dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {}
7581
# initialize default params
@@ -100,6 +106,7 @@ def __init__(
100106
except RuntimeError:
101107
self._keys = None
102108
self._client: Optional[AlloyDBClient] = None
109+
self._static_conn_info = static_conn_info
103110

104111
async def connect(
105112
self,
@@ -138,10 +145,13 @@ async def connect(
138145
)
139146

140147
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
148+
static_conn_info = kwargs.pop("static_conn_info", self._static_conn_info)
141149

142150
# use existing connection info if possible
143151
if instance_uri in self._cache:
144152
cache = self._cache[instance_uri]
153+
elif static_conn_info:
154+
cache = StaticConnectionInfoCache(instance_uri, static_conn_info)
145155
else:
146156
if self._refresh_strategy == RefreshStrategy.LAZY:
147157
logger.debug(

google/cloud/alloydb/connector/connector.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import asyncio
1818
from functools import partial
19+
import io
1920
import logging
2021
import socket
2122
import struct
@@ -34,6 +35,7 @@
3435
from google.cloud.alloydb.connector.instance import RefreshAheadCache
3536
from google.cloud.alloydb.connector.lazy import LazyRefreshCache
3637
import google.cloud.alloydb.connector.pg8000 as pg8000
38+
from google.cloud.alloydb.connector.static import StaticConnectionInfoCache
3739
from google.cloud.alloydb.connector.utils import generate_keys
3840
import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb
3941

@@ -71,6 +73,9 @@ class Connector:
7173
of the following: RefreshStrategy.LAZY ("LAZY") or
7274
RefreshStrategy.BACKGROUND ("BACKGROUND").
7375
Default: RefreshStrategy.BACKGROUND
76+
static_conn_info (io.TextIOBase): A file-like JSON object that contains
77+
static connection info for the StaticConnectionInfoCache.
78+
Defaults to None, which will not use the StaticConnectionInfoCache.
7479
"""
7580

7681
def __init__(
@@ -82,6 +87,7 @@ def __init__(
8287
ip_type: str | IPTypes = IPTypes.PRIVATE,
8388
user_agent: Optional[str] = None,
8489
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
90+
static_conn_info: io.TextIOBase = None,
8591
) -> None:
8692
# create event loop and start it in background thread
8793
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
@@ -113,6 +119,7 @@ def __init__(
113119
loop=self._loop,
114120
)
115121
self._client: Optional[AlloyDBClient] = None
122+
self._static_conn_info = static_conn_info
116123

117124
def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any:
118125
"""
@@ -168,9 +175,12 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
168175
driver=driver,
169176
)
170177
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
178+
static_conn_info = kwargs.pop("static_conn_info", self._static_conn_info)
171179
# use existing connection info if possible
172180
if instance_uri in self._cache:
173181
cache = self._cache[instance_uri]
182+
elif static_conn_info:
183+
cache = StaticConnectionInfoCache(instance_uri, static_conn_info)
174184
else:
175185
if self._refresh_strategy == RefreshStrategy.LAZY:
176186
logger.debug(
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from cryptography.hazmat.primitives import serialization
16+
from cryptography.hazmat.primitives.asymmetric import rsa
17+
from datetime import datetime
18+
from datetime import timedelta
19+
from datetime import timezone
20+
import io
21+
import json
22+
23+
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
24+
25+
26+
class StaticConnectionInfoCache:
27+
"""
28+
StaticConnectionInfoCache creates a connection info cache that will always
29+
return a pre-defined connection info.
30+
31+
This static connection info should hold JSON with the following format:
32+
{
33+
"publicKey": "<PEM Encoded public RSA key>",
34+
"privateKey": "<PEM Encoded private RSA key>",
35+
"projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>": {
36+
"ipAddress": "<PSA-based private IP address>",
37+
"publicIpAddress": "<public IP address>",
38+
"pscInstanceConfig": {
39+
"pscDnsName": "<PSC DNS name>"
40+
},
41+
"pemCertificateChain": [
42+
"<client cert>", "<intermediate cert>", "<CA cert>"
43+
],
44+
"caCert": "<CA cert>"
45+
}
46+
}
47+
"""
48+
49+
def __init__(self, instance_uri: str, static_conn_info: io.TextIOBase) -> None:
50+
"""
51+
Initializes a StaticConnectionInfoCache instance.
52+
53+
Args:
54+
instance_uri (str): The AlloyDB instance's connection URI.
55+
static_conn_info (io.TextIOBase): The static connection info JSON.
56+
"""
57+
static_info = json.load(static_conn_info)
58+
ca_cert = static_info[instance_uri]["caCert"]
59+
cert_chain = static_info[instance_uri]["pemCertificateChain"]
60+
ip_addrs = {
61+
"PRIVATE": static_info[instance_uri]["ipAddress"],
62+
"PUBLIC": static_info[instance_uri]["publicIpAddress"],
63+
"PSC": static_info[instance_uri]["pscInstanceConfig"]["pscDnsName"],
64+
}
65+
expiration = datetime.now(timezone.utc) + timedelta(hours=1)
66+
priv_key = static_info["privateKey"]
67+
priv_key_bytes: rsa.RSAPrivateKey = serialization.load_pem_private_key(
68+
priv_key.encode("UTF-8"), password=None,
69+
)
70+
self._info = ConnectionInfo(cert_chain, ca_cert, priv_key_bytes, ip_addrs, expiration)
71+
72+
async def force_refresh(self) -> None:
73+
"""
74+
This is a no-op as the cache holds only static connection information
75+
and does no refresh.
76+
"""
77+
pass
78+
79+
async def connect_info(self) -> ConnectionInfo:
80+
"""
81+
Retrieves ConnectionInfo instance for establishing a secure
82+
connection to the AlloyDB instance.
83+
"""
84+
return self._info
85+
86+
async def close(self) -> None:
87+
"""
88+
This is a no-op.
89+
"""
90+
pass

tests/unit/conftest.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,16 @@ async def start_proxy_server(instance: FakeInstance) -> None:
6666
# listen for incoming connections
6767
sock.listen(5)
6868

69-
while True:
70-
with context.wrap_socket(sock, server_side=True) as ssock:
69+
with context.wrap_socket(sock, server_side=True) as ssock:
70+
while True:
7171
conn, _ = ssock.accept()
7272
metadata_exchange(conn)
7373
conn.sendall(instance.name.encode("utf-8"))
7474
conn.close()
7575

7676

7777
@pytest.fixture(scope="session")
78-
def proxy_server(fake_instance: FakeInstance) -> Generator:
78+
def proxy_server(fake_instance: FakeInstance) -> None:
7979
"""Run local proxy server capable of performing metadata exchange"""
8080
thread = Thread(
8181
target=asyncio.run,
@@ -87,5 +87,4 @@ def proxy_server(fake_instance: FakeInstance) -> Generator:
8787
daemon=True,
8888
)
8989
thread.start()
90-
yield thread
91-
thread.join()
90+
thread.join(0.1) # wait 100ms to allow the proxy server to start

tests/unit/mocks.py

Lines changed: 74 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from datetime import datetime
1717
from datetime import timedelta
1818
from datetime import timezone
19+
import io
1920
import ipaddress
21+
import json
2022
import ssl
2123
import struct
2224
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
@@ -193,6 +195,34 @@ def get_pem_certs(self) -> Tuple[str, str, str]:
193195
encoding=serialization.Encoding.PEM
194196
).decode("UTF-8")
195197
return (pem_root, pem_intermediate, pem_server)
198+
199+
def generate_pem_certificate_chain(self, pub_key: str) -> Tuple[str, List[str]]:
200+
"""Generate the CA certificate and certificate chain for the AlloyDB instance."""
201+
root_cert, intermediate_cert, server_cert = self.get_pem_certs()
202+
# encode public key to bytes
203+
pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key(
204+
pub_key.encode("UTF-8"),
205+
)
206+
# build client cert
207+
client_cert = (
208+
x509.CertificateBuilder()
209+
.subject_name(self.intermediate_cert.subject)
210+
.issuer_name(self.intermediate_cert.issuer)
211+
.public_key(pub_key_bytes)
212+
.serial_number(x509.random_serial_number())
213+
.not_valid_before(self.cert_before)
214+
.not_valid_after(self.cert_expiry)
215+
)
216+
# sign client cert with intermediate cert
217+
client_cert = client_cert.sign(self.intermediate_key, hashes.SHA256())
218+
client_cert = client_cert.public_bytes(
219+
encoding=serialization.Encoding.PEM
220+
).decode("UTF-8")
221+
return (server_cert, [client_cert, intermediate_cert, root_cert])
222+
223+
def uri(self) -> str:
224+
"""The URI of the AlloyDB instance."""
225+
return f"projects/{self.project}/locations/{self.region}/clusters/{self.cluster}/instances/{self.name}"
196226

197227

198228
class FakeAlloyDBClient:
@@ -216,27 +246,7 @@ async def _get_client_certificate(
216246
cluster: str,
217247
pub_key: str,
218248
) -> Tuple[str, List[str]]:
219-
root_cert, intermediate_cert, server_cert = self.instance.get_pem_certs()
220-
# encode public key to bytes
221-
pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key(
222-
pub_key.encode("UTF-8"),
223-
)
224-
# build client cert
225-
client_cert = (
226-
x509.CertificateBuilder()
227-
.subject_name(self.instance.intermediate_cert.subject)
228-
.issuer_name(self.instance.intermediate_cert.issuer)
229-
.public_key(pub_key_bytes)
230-
.serial_number(x509.random_serial_number())
231-
.not_valid_before(self.instance.cert_before)
232-
.not_valid_after(self.instance.cert_expiry)
233-
)
234-
# sign client cert with intermediate cert
235-
client_cert = client_cert.sign(self.instance.intermediate_key, hashes.SHA256())
236-
client_cert = client_cert.public_bytes(
237-
encoding=serialization.Encoding.PEM
238-
).decode("UTF-8")
239-
return (server_cert, [client_cert, intermediate_cert, root_cert])
249+
return self.instance.generate_pem_certificate_chain(pub_key)
240250

241251
async def get_connection_info(
242252
self,
@@ -378,3 +388,46 @@ async def force_refresh(self) -> None:
378388

379389
async def close(self) -> None:
380390
self._close_called = True
391+
392+
393+
def write_static_info(i: FakeInstance) -> io.StringIO:
394+
"""
395+
Creates a static connection info JSON for the StaticConnectionInfoCache.
396+
397+
Args:
398+
i (FakeInstance): The FakeInstance to use to create the CA cert and
399+
chain.
400+
401+
Returns:
402+
io.StringIO
403+
"""
404+
priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
405+
pub_pem = (
406+
priv_key.public_key()
407+
.public_bytes(
408+
encoding=serialization.Encoding.PEM,
409+
format=serialization.PublicFormat.SubjectPublicKeyInfo,
410+
)
411+
.decode("UTF-8")
412+
)
413+
priv_pem = (
414+
priv_key.private_bytes(
415+
encoding=serialization.Encoding.PEM,
416+
format=serialization.PrivateFormat.TraditionalOpenSSL,
417+
encryption_algorithm=serialization.NoEncryption(),
418+
)
419+
.decode("UTF-8")
420+
)
421+
ca_cert, chain = i.generate_pem_certificate_chain(pub_pem)
422+
static = {
423+
"publicKey": pub_pem,
424+
"privateKey": priv_pem,
425+
}
426+
static[i.uri()] = {
427+
"pemCertificateChain": chain,
428+
"caCert": ca_cert,
429+
"ipAddress": "127.0.0.1", # "private" IP is localhost in testing
430+
"publicIpAddress": "",
431+
"pscInstanceConfig": {"pscDnsName": ""},
432+
}
433+
return io.StringIO(json.dumps(static))

0 commit comments

Comments
 (0)