Skip to content

Commit ebbc467

Browse files
authored
Add support for async Transport sniffing
1 parent fb251d0 commit ebbc467

File tree

6 files changed

+437
-27
lines changed

6 files changed

+437
-27
lines changed

elastic_transport/_async_transport.py

Lines changed: 140 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,25 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
from typing import Any, List, Optional, Tuple, Type, Union
18+
import asyncio
19+
from typing import Any, Awaitable, Callable, List, Optional, Tuple, Type, Union
1920

21+
from ._compat import await_if_coro, get_running_loop
2022
from ._exceptions import (
2123
HTTP_STATUS_TO_ERROR,
2224
ApiError,
2325
ConnectionError,
2426
ConnectionTimeout,
2527
TransportError,
2628
)
27-
from ._models import ApiResponseMeta, NodeConfig
29+
from ._models import ApiResponseMeta, NodeConfig, SniffOptions
2830
from ._node import AiohttpHttpNode, BaseNode
2931
from ._node_pool import NodePool, NodeSelector
30-
from ._transport import NOT_DEAD_NODE_HTTP_STATUSES, Transport
32+
from ._transport import (
33+
NOT_DEAD_NODE_HTTP_STATUSES,
34+
Transport,
35+
validate_sniffing_options,
36+
)
3137
from .client_utils import DEFAULT, normalize_headers
3238

3339

@@ -52,6 +58,17 @@ def __init__(
5258
max_retries: int = 3,
5359
retry_on_status=(429, 502, 503, 504),
5460
retry_on_timeout: bool = False,
61+
sniff_on_start: bool = False,
62+
sniff_before_requests: bool = False,
63+
sniff_on_node_failure: bool = False,
64+
sniff_timeout: Optional[float] = 1.0,
65+
min_delay_between_sniffing: float = 10.0,
66+
sniff_callback: Optional[
67+
Callable[
68+
["Transport", "SniffOptions"],
69+
Union[List[NodeConfig], Awaitable[List[NodeConfig]]],
70+
]
71+
] = None,
5572
):
5673
"""
5774
:arg node_configs: List of 'NodeConfig' instances to create initial set of nodes.
@@ -74,11 +91,33 @@ def __init__(
7491
on a different node. defaults to ``(429, 502, 503, 504)``
7592
:arg retry_on_timeout: should timeout trigger a retry on different
7693
node? (default ``False``)
77-
78-
Any extra keyword arguments will be passed to the `node_class`
79-
when creating and instance unless overridden by that node's
80-
options provided as part of the hosts parameter.
94+
:arg sniff_on_start: If ``True`` will sniff for additional nodes as soon
95+
as possible, guaranteed before the first request.
96+
:arg sniff_on_node_failure: If ``True`` will sniff for additional nodees
97+
after a node is marked as dead in the pool.
98+
:arg sniff_before_requests: If ``True`` will occasionally sniff for additional
99+
nodes as requests are sent.
100+
:arg sniff_timeout: Timeout value in seconds to use for sniffing requests.
101+
Defaults to 1 second.
102+
:arg min_delay_between_sniffing: Number of seconds to wait between calls to
103+
:meth:`elastic_transport.Transport.sniff` to avoid sniffing too frequently.
104+
Defaults to 10 seconds.
105+
:arg sniff_callback: Function that is passed a :class:`elastic_transport.Transport` and
106+
:class:`elastic_transport.SniffOptions` and should do node discovery and
107+
return a list of :class:`elastic_transport.NodeConfig` instances or a coroutine
108+
that returns the list.
81109
"""
110+
111+
# Since we don't pass all the sniffing options to super().__init__()
112+
# we want to validate the sniffing options here too.
113+
validate_sniffing_options(
114+
node_configs=node_configs,
115+
sniff_on_start=sniff_on_start,
116+
sniff_before_requests=sniff_before_requests,
117+
sniff_on_node_failure=sniff_on_node_failure,
118+
sniff_callback=sniff_callback,
119+
)
120+
82121
super().__init__(
83122
node_configs=node_configs,
84123
node_class=node_class,
@@ -91,8 +130,20 @@ def __init__(
91130
max_retries=max_retries,
92131
retry_on_status=retry_on_status,
93132
retry_on_timeout=retry_on_timeout,
133+
sniff_timeout=sniff_timeout,
134+
min_delay_between_sniffing=min_delay_between_sniffing,
94135
)
95136

137+
self._sniff_on_start = sniff_on_start
138+
self._sniff_before_requests = sniff_before_requests
139+
self._sniff_on_node_failure = sniff_on_node_failure
140+
self._sniff_timeout = sniff_timeout
141+
self._sniff_callback = sniff_callback
142+
self._sniffing_lock = None
143+
self._sniffing_task: Optional[asyncio.Task] = None
144+
self._last_sniffed_at = 0.0
145+
self._loop: Optional[asyncio.BaseEventLoop] = None
146+
96147
async def perform_request(
97148
self,
98149
method: str,
@@ -123,6 +174,8 @@ async def perform_request(
123174
:arg ignore_status: Collection of HTTP status codes to not raise an error for.
124175
:returns: Tuple of the HttpResponse with the deserialized response.
125176
"""
177+
await self._async_init()
178+
126179
if isinstance(ignore_status, int):
127180
ignore_status = {ignore_status}
128181

@@ -143,8 +196,13 @@ async def perform_request(
143196
errors = []
144197

145198
for attempt in range(self.max_retries + 1):
146-
node = self.node_pool.get()
147199

200+
# If we sniff before requests are made we want to do so before
201+
# 'node_pool.get()' is called so our sniffed nodes show up in the pool.
202+
if self._sniff_before_requests:
203+
await self.sniff(False)
204+
205+
node = self.node_pool.get()
148206
try:
149207
response, raw_data = await node.perform_request(
150208
method,
@@ -184,6 +242,14 @@ async def perform_request(
184242
if node_failure:
185243
self.node_pool.mark_dead(node)
186244

245+
if self._sniff_on_node_failure:
246+
try:
247+
await self.sniff(False)
248+
except TransportError:
249+
# If sniffing on failure, it could fail too. Catch the
250+
# exception not to interrupt the retries.
251+
pass
252+
187253
if retry:
188254
# raise exception on last retry
189255
if attempt == self.max_retries:
@@ -200,9 +266,75 @@ async def perform_request(
200266
self.node_pool.mark_live(node)
201267
return response, data
202268

269+
async def sniff(self, is_initial_sniff: bool) -> None:
270+
await self._async_init()
271+
task = self._create_sniffing_task(is_initial_sniff)
272+
273+
# Only block on the task if this is the initial sniff.
274+
# Otherwise we do the sniffing in the background.
275+
if is_initial_sniff and task:
276+
await task
277+
203278
async def close(self) -> None:
204279
"""
205280
Explicitly closes all nodes in the transport's pool
206281
"""
207282
for node in self.node_pool.all():
208283
await node.close()
284+
285+
def _should_sniff(self, is_initial_sniff: bool) -> bool:
286+
"""Decide if we should sniff or not. _async_init() must be called
287+
before using this function.The async implementation doesn't have a lock.
288+
"""
289+
if is_initial_sniff:
290+
return True
291+
292+
# Only start a new sniff if the previous run is completed.
293+
if self._sniffing_task:
294+
if not self._sniffing_task.done():
295+
return False
296+
# If there was a previous run we collect the sniffing task's
297+
# result as it could have failed with an exception.
298+
self._sniffing_task.result()
299+
300+
return (
301+
self._loop.time() - self._last_sniffed_at
302+
>= self._min_delay_between_sniffing
303+
)
304+
305+
def _create_sniffing_task(self, is_initial_sniff: bool) -> Optional[asyncio.Task]:
306+
"""Creates a sniffing task if one should be created and returns the task if created."""
307+
task = None
308+
if self._should_sniff(is_initial_sniff):
309+
# 'self._sniffing_task' is unset within the task implementation.
310+
task = self._loop.create_task(self._sniffing_task_impl(is_initial_sniff))
311+
self._sniffing_task = task
312+
return task
313+
314+
async def _sniffing_task_impl(self, is_initial_sniff: bool) -> None:
315+
"""Implementation of the sniffing task"""
316+
previously_sniffed_at = self._last_sniffed_at
317+
try:
318+
self._last_sniffed_at = self._loop.time()
319+
options = SniffOptions(
320+
is_initial_sniff=is_initial_sniff, sniff_timeout=self._sniff_timeout
321+
)
322+
for node_config in await await_if_coro(self._sniff_callback(self, options)):
323+
self.node_pool.add(node_config)
324+
325+
# If sniffing failed for any reason we
326+
# want to allow retrying immediately.
327+
except BaseException:
328+
self._last_sniffed_at = previously_sniffed_at
329+
raise
330+
331+
async def _async_init(self) -> None:
332+
"""Async constructor which is called on the first call to perform_request()
333+
because we're not guaranteed to be within an active asyncio event loop
334+
when __init__() is called.
335+
"""
336+
if self._loop is not None:
337+
return # Call at most once!
338+
self._loop = get_running_loop()
339+
if self._sniff_on_start:
340+
await self.sniff(True)

elastic_transport/_compat.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import inspect
2020
import sys
2121
from pathlib import Path
22+
from typing import Awaitable, TypeVar, Union
2223
from urllib.parse import quote as _quote
2324
from urllib.parse import urlencode, urlparse
2425

@@ -45,6 +46,15 @@ def get_running_loop():
4546
return loop
4647

4748

49+
T = TypeVar("T")
50+
51+
52+
async def await_if_coro(coro: Union[T, Awaitable[T]]) -> T:
53+
if inspect.iscoroutine(coro):
54+
return await coro
55+
return coro
56+
57+
4858
_QUOTE_ALWAYS_SAFE = frozenset(
4959
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_.-~"
5060
)
@@ -69,7 +79,7 @@ def __enter__(self) -> None:
6979
def __exit__(self, *_) -> None:
7080
pass
7181

72-
def acquire(self, blocking: bool = True) -> bool:
82+
def acquire(self, _: bool = True) -> bool:
7383
return True
7484

7585
def release(self) -> None:
@@ -109,12 +119,11 @@ def warn_stacklevel() -> int:
109119
return level
110120
except KeyError:
111121
pass
112-
except Exception:
113-
return 2
114122
return 0
115123

116124

117125
__all__ = [
126+
"await_if_coro",
118127
"get_running_loop",
119128
"ordered_dict",
120129
"quote",

elastic_transport/_transport.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -126,21 +126,13 @@ def __init__(
126126
)
127127
node_class = NODE_CLASS_NAMES[node_class]
128128

129-
# Additional requirements for when sniffing is enabled
130-
sniffing_enabled = (
131-
sniff_before_requests or sniff_on_start or sniff_on_node_failure
129+
validate_sniffing_options(
130+
node_configs=node_configs,
131+
sniff_on_start=sniff_on_start,
132+
sniff_before_requests=sniff_before_requests,
133+
sniff_on_node_failure=sniff_on_node_failure,
134+
sniff_callback=sniff_callback,
132135
)
133-
if sniffing_enabled and not sniff_callback:
134-
raise ValueError("Enabling sniffing requires specifying a 'sniff_callback'")
135-
if not sniffing_enabled and sniff_callback:
136-
raise ValueError(
137-
"Using 'sniff_callback' requires enabling sniffing via 'sniff_on_start', "
138-
"'sniff_before_requests' or 'sniff_on_node_failure'"
139-
)
140-
141-
# If we're sniffing we want to warn the user for non-homogenous NodeConfigs.
142-
if sniffing_enabled and len(node_configs) > 1:
143-
warn_if_varying_node_config_options(node_configs)
144136

145137
# Create the default metadata for the x-elastic-client-meta
146138
# HTTP header. Only requires adding the (service, service_version)
@@ -358,6 +350,30 @@ def _should_sniff(self, is_initial_sniff: bool) -> bool:
358350
return self._sniffing_lock.acquire(False)
359351

360352

353+
def validate_sniffing_options(
354+
*,
355+
node_configs: List[NodeConfig],
356+
sniff_before_requests: bool,
357+
sniff_on_start: bool,
358+
sniff_on_node_failure: bool,
359+
sniff_callback: Optional[Any],
360+
) -> None:
361+
"""Validates the Transport configurations for sniffing"""
362+
363+
sniffing_enabled = sniff_before_requests or sniff_on_start or sniff_on_node_failure
364+
if sniffing_enabled and not sniff_callback:
365+
raise ValueError("Enabling sniffing requires specifying a 'sniff_callback'")
366+
if not sniffing_enabled and sniff_callback:
367+
raise ValueError(
368+
"Using 'sniff_callback' requires enabling sniffing via 'sniff_on_start', "
369+
"'sniff_before_requests' or 'sniff_on_node_failure'"
370+
)
371+
372+
# If we're sniffing we want to warn the user for non-homogenous NodeConfigs.
373+
if sniffing_enabled and len(node_configs) > 1:
374+
warn_if_varying_node_config_options(node_configs)
375+
376+
361377
def warn_if_varying_node_config_options(node_configs: List[NodeConfig]) -> None:
362378
"""Function which detects situations when sniffing may product incorrect configs"""
363379
exempt_attrs = {"host", "port", "connections_per_node", "_extras"}

0 commit comments

Comments
 (0)