Skip to content

Commit ff37321

Browse files
authored
Add experimental support for pinning chain certificates
`SSLObject.get_verified_chain()` and `Certificate.public_bytes()` are private APIs in CPython 3.10. They're not documented anywhere yet but seem to work and we need them for Security on by Default. See: python/cpython#25467
1 parent ebbc467 commit ff37321

File tree

7 files changed

+277
-52
lines changed

7 files changed

+277
-52
lines changed

elastic_transport/_models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,11 @@ class NodeConfig:
297297
#: SHA-256 fingerprint of the node's certificate. If this value is
298298
#: given then root-of-trust verification isn't done and only the
299299
#: node's certificate fingerprint is verified.
300+
#:
301+
#: On CPython 3.10+ this also verifies if any certificate in the
302+
#: chain including the Root CA matches this fingerprint. However
303+
#: because this requires using private APIs support for this is
304+
#: **experimental**.
300305
ssl_assert_fingerprint: Optional[str] = None
301306
#: Minimum TLS version to use to connect to the node.
302307
ssl_version: Optional[int] = None

elastic_transport/_node/_http_aiohttp.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343

4444

4545
class AiohttpHttpNode(BaseAsyncNode):
46+
"""Default asynchronous node class using the ``aiohttp`` library via HTTP"""
47+
4648
_ELASTIC_CLIENT_META = ("ai", _AIOHTTP_META_VERSION)
4749

4850
def __init__(self, config: NodeConfig):
@@ -178,10 +180,10 @@ async def perform_request(
178180
):
179181
raise ConnectionTimeout(
180182
"Connection timed out during request", errors=(e,)
181-
)
183+
) from None
182184
elif isinstance(e, (ssl.SSLError, aiohttp_exceptions.ClientSSLError)):
183-
raise TlsError(str(e), errors=(e,))
184-
raise ConnectionError(str(e), errors=(e,))
185+
raise TlsError(str(e), errors=(e,)) from None
186+
raise ConnectionError(str(e), errors=(e,)) from None
185187

186188
return (
187189
ApiResponseMeta(
@@ -221,7 +223,7 @@ def _create_aiohttp_session(self) -> None:
221223

222224

223225
@functools.lru_cache(maxsize=64, typed=True)
224-
def aiohttp_fingerprint(ssl_assert_fingerprint: str) -> aiohttp.Fingerprint:
226+
def aiohttp_fingerprint(ssl_assert_fingerprint: str) -> "aiohttp.Fingerprint":
225227
"""Changes 'ssl_assert_fingerprint' into a configured 'aiohttp.Fingerprint' instance.
226228
Uses a cache to prevent creating tons of objects needlessly.
227229
"""

elastic_transport/_node/_http_requests.py

Lines changed: 34 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,39 @@
3535

3636
_REQUESTS_AVAILABLE = True
3737
_REQUESTS_META_VERSION = client_meta_version(requests.__version__)
38+
39+
# Use our custom HTTPSConnectionPool for chain cert fingerprint support.
40+
try:
41+
from ._urllib3_chain_certs import HTTPSConnectionPool
42+
except (ImportError, AttributeError):
43+
HTTPSConnectionPool = urllib3.HTTPSConnectionPool
44+
45+
class _ElasticHTTPAdapter(HTTPAdapter):
46+
def __init__(self, node_config: NodeConfig, **kwargs):
47+
self._node_config = node_config
48+
super().__init__(**kwargs)
49+
50+
def init_poolmanager(
51+
self, connections, maxsize, block=False, **pool_kwargs
52+
) -> None:
53+
if self._node_config.ssl_context:
54+
pool_kwargs.setdefault("ssl_context", self._node_config.ssl_context)
55+
if self._node_config.ssl_assert_fingerprint:
56+
pool_kwargs.setdefault(
57+
"assert_fingerprint", self._node_config.ssl_assert_fingerprint
58+
)
59+
60+
super().init_poolmanager(connections, maxsize, block=block, **pool_kwargs)
61+
self.poolmanager.pool_classes_by_scheme["https"] = HTTPSConnectionPool
62+
63+
3864
except ImportError: # pragma: nocover
3965
_REQUESTS_AVAILABLE = False
4066
_REQUESTS_META_VERSION = ""
4167

4268

4369
class RequestsHttpNode(BaseNode):
44-
"""
45-
Connection using the `requests` library communicating via HTTP.
46-
47-
:arg use_ssl: use ssl for the connection if `True`
48-
:arg verify_certs: whether to verify SSL certificates
49-
:arg ssl_show_warn: show warning when verify certs is disabled
50-
:arg ca_certs: optional path to CA bundle. By default standard requests'
51-
bundle will be used.
52-
:arg client_cert: path to the file containing the private key and the
53-
certificate, or cert only if using client_key
54-
:arg client_key: path to the file containing the private key if using
55-
separate cert and key files (client_cert will contain only the cert)
56-
:arg headers: any custom http headers to be add to requests
57-
:arg http_compress: Use gzip compression
58-
:arg opaque_id: Send this value in the 'X-Opaque-Id' HTTP header
59-
For tracing all requests made by this transport.
60-
"""
70+
"""Synchronous node using the ``requests`` library communicating via HTTP"""
6171

6272
_ELASTIC_CLIENT_META = ("rq", _REQUESTS_META_VERSION)
6373

@@ -109,8 +119,10 @@ def __init__(self, config: NodeConfig):
109119
pool_maxsize=config.connections_per_node,
110120
pool_block=True,
111121
)
112-
for prefix in ("http://", "https://"):
113-
self.session.mount(prefix=prefix, adapter=adapter)
122+
# Preload the HTTPConnectionPool so initialization issues
123+
# are raised here instead of in perform_request()
124+
adapter.get_connection(self.base_url)
125+
self.session.mount(prefix=f"{self.scheme}://", adapter=adapter)
114126

115127
def perform_request(
116128
self,
@@ -161,10 +173,10 @@ def perform_request(
161173
if isinstance(e, requests.Timeout):
162174
raise ConnectionTimeout(
163175
"Connection timed out during request", errors=(e,)
164-
)
176+
) from None
165177
elif isinstance(e, (ssl.SSLError, requests.exceptions.SSLError)):
166-
raise TlsError(str(e), errors=(e,))
167-
raise ConnectionError(str(e), errors=(e,))
178+
raise TlsError(str(e), errors=(e,)) from None
179+
raise ConnectionError(str(e), errors=(e,)) from None
168180

169181
response = ApiResponseMeta(
170182
node=self.config,
@@ -180,22 +192,3 @@ def close(self) -> None:
180192
Explicitly closes connections
181193
"""
182194
self.session.close()
183-
184-
185-
class _ElasticHTTPAdapter(HTTPAdapter):
186-
def __init__(self, node_config: NodeConfig, **kwargs):
187-
self._node_config = node_config
188-
super().__init__(**kwargs)
189-
190-
def init_poolmanager(
191-
self, connections, maxsize, block=False, **pool_kwargs
192-
) -> urllib3.PoolManager:
193-
if self._node_config.ssl_context:
194-
pool_kwargs.setdefault("ssl_context", self._node_config.ssl_context)
195-
if self._node_config.ssl_assert_fingerprint:
196-
pool_kwargs.setdefault(
197-
"ssl_assert_fingerprint", self._node_config.ssl_assert_fingerprint
198-
)
199-
return super().init_poolmanager(
200-
connections, maxsize, block=block, **pool_kwargs
201-
)

elastic_transport/_node/_http_urllib3.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,14 @@
3131
from ..client_utils import DEFAULT, client_meta_version
3232
from ._base import DEFAULT_CA_CERTS, RERAISE_EXCEPTIONS, BaseNode
3333

34+
try:
35+
from ._urllib3_chain_certs import HTTPSConnectionPool
36+
except (ImportError, AttributeError):
37+
HTTPSConnectionPool = urllib3.HTTPSConnectionPool
38+
3439

3540
class Urllib3HttpNode(BaseNode):
36-
"""Default synchronous node class using the `urllib3` library via HTTP."""
41+
"""Default synchronous node class using the ``urllib3`` library via HTTP"""
3742

3843
_ELASTIC_CLIENT_META = ("ur", client_meta_version(urllib3.__version__))
3944

@@ -45,7 +50,7 @@ def __init__(self, config: NodeConfig):
4550

4651
# if ssl_context provided use SSL by default
4752
if config.scheme == "https" and config.ssl_context:
48-
pool_class = urllib3.HTTPSConnectionPool
53+
pool_class = HTTPSConnectionPool
4954
kw.update(
5055
{
5156
"assert_fingerprint": config.ssl_assert_fingerprint,
@@ -54,7 +59,7 @@ def __init__(self, config: NodeConfig):
5459
)
5560

5661
elif config.scheme == "https":
57-
pool_class = urllib3.HTTPSConnectionPool
62+
pool_class = HTTPSConnectionPool
5863
kw.update(
5964
{
6065
"ssl_version": config.ssl_version,
@@ -147,10 +152,10 @@ def perform_request(
147152
if isinstance(e, (ConnectTimeoutError, ReadTimeoutError)):
148153
raise ConnectionTimeout(
149154
"Connection timed out during request", errors=(e,)
150-
)
155+
) from None
151156
elif isinstance(e, (ssl.SSLError, urllib3.exceptions.SSLError)):
152-
raise TlsError(str(e), errors=(e,))
153-
raise ConnectionError(str(e), errors=(e,))
157+
raise TlsError(str(e), errors=(e,)) from None
158+
raise ConnectionError(str(e), errors=(e,)) from None
154159

155160
response = ApiResponseMeta(
156161
node=self.config,
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Licensed to Elasticsearch B.V. under one or more contributor
2+
# license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright
4+
# ownership. Elasticsearch B.V. licenses this file to you under
5+
# the Apache License, Version 2.0 (the "License"); you may
6+
# not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import hashlib
19+
import sys
20+
from binascii import hexlify, unhexlify
21+
from hmac import compare_digest
22+
from typing import Any, List, Optional
23+
24+
import _ssl
25+
import urllib3
26+
27+
from ._base import RERAISE_EXCEPTIONS
28+
29+
if sys.version_info < (3, 10) or sys.implementation.name != "cpython":
30+
raise ImportError("Only supported on CPython 3.10+")
31+
32+
_ENCODING_DER = _ssl.ENCODING_DER
33+
_HASHES_BY_LENGTH = {32: hashlib.md5, 40: hashlib.sha1, 64: hashlib.sha256}
34+
35+
__all__ = ["HTTPSConnectionPool"]
36+
37+
38+
class HTTPSConnectionPool(urllib3.HTTPSConnectionPool):
39+
"""HTTPSConnectionPool implementation which supports ``assert_fingerprint``
40+
on certificates within the chain instead of only the leaf cert using private
41+
APIs in CPython 3.10+
42+
"""
43+
44+
def __init__(
45+
self, *args: Any, assert_fingerprint: Optional[str] = None, **kwargs: Any
46+
) -> None:
47+
self._elastic_assert_fingerprint = (
48+
assert_fingerprint.replace(":", "").lower() if assert_fingerprint else None
49+
)
50+
51+
# Complain about fingerprint length earlier than urllib3 does.
52+
if (
53+
self._elastic_assert_fingerprint
54+
and len(self._elastic_assert_fingerprint) not in _HASHES_BY_LENGTH
55+
):
56+
valid_lengths = "', '".join(map(str, sorted(_HASHES_BY_LENGTH.keys())))
57+
raise ValueError(
58+
f"Fingerprint of invalid length '{len(self._elastic_assert_fingerprint)}'"
59+
f", should be one of '{valid_lengths}'"
60+
)
61+
62+
if assert_fingerprint:
63+
# Falsey but not None. This is a hack to skip fingerprinting by urllib3
64+
# but still set 'is_verified=True' within HTTPSConnectionPool._validate_conn()
65+
kwargs["assert_fingerprint"] = ""
66+
67+
super().__init__(*args, **kwargs)
68+
69+
def _validate_conn(self, conn):
70+
"""
71+
Called right before a request is made, after the socket is created.
72+
"""
73+
super(HTTPSConnectionPool, self)._validate_conn(conn)
74+
75+
if self._elastic_assert_fingerprint:
76+
hash_func = _HASHES_BY_LENGTH[len(self._elastic_assert_fingerprint)]
77+
assert_fingerprint = unhexlify(
78+
self._elastic_assert_fingerprint.lower()
79+
.replace(":", "")
80+
.encode("ascii")
81+
)
82+
83+
fingerprints: List[bytes]
84+
try:
85+
# 'get_verified_chain()' and 'Certificate.public_bytes()' are private APIs
86+
# in CPython 3.10. They're not documented anywhere yet but seem to work
87+
# and we need them for Security on by Default so... onwards we go!
88+
# See: https://github.com/python/cpython/pull/25467
89+
fingerprints = [
90+
hash_func(cert.public_bytes(_ENCODING_DER)).digest()
91+
for cert in conn.sock._sslobj.get_verified_chain()
92+
]
93+
except RERAISE_EXCEPTIONS: # pragma: nocover
94+
raise
95+
# Because these are private APIs we are super careful here
96+
# so that if anything "goes wrong" we fallback on the old behavior.
97+
except Exception: # pragma: nocover
98+
fingerprints = []
99+
100+
# Only add the peercert in front of the chain if it's not there for some reason.
101+
# This is to make sure old behavior of 'ssl_assert_fingerprint' still works.
102+
peercert_fingerprint = hash_func(conn.sock.getpeercert(True)).digest()
103+
if peercert_fingerprint not in fingerprints: # pragma: nocover
104+
fingerprints.insert(0, peercert_fingerprint)
105+
106+
# If any match then that's a success! We always run them
107+
# all through though because of constant time concerns.
108+
success = False
109+
for fingerprint in fingerprints:
110+
success |= compare_digest(fingerprint, assert_fingerprint)
111+
112+
# Give users all the fingerprints we checked against in
113+
# order of peer -> root CA.
114+
if not success:
115+
raise urllib3.exceptions.SSLError(
116+
'Fingerprints did not match. Expected "{0}", got "{1}".'.format(
117+
self._elastic_assert_fingerprint,
118+
'", "'.join([x.decode() for x in map(hexlify, fingerprints)]),
119+
)
120+
)
121+
conn.is_verified = success

tests/node/test_http_aiohttp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ async def test_ssl_assert_fingerprint(httpbin_cert_fingerprint):
264264
resp, _ = await node.perform_request("GET", "/")
265265

266266
assert resp.status == 200
267-
assert w == []
267+
assert [str(x.message) for x in w if x.category != DeprecationWarning] == []
268268

269269

270270
async def test_default_headers():

0 commit comments

Comments
 (0)