1515# specific language governing permissions and limitations
1616# under the License.
1717
18- from typing import Any , List , Optional , Tuple , Type , Union
18+ import asyncio
19+ from typing import Any , Awaitable , Callable , List , Optional , Tuple , Type , Union
1920
21+ from ._compat import await_if_coro , get_running_loop
2022from ._exceptions import (
2123 HTTP_STATUS_TO_ERROR ,
2224 ApiError ,
2325 ConnectionError ,
2426 ConnectionTimeout ,
2527 TransportError ,
2628)
27- from ._models import ApiResponseMeta , NodeConfig
29+ from ._models import ApiResponseMeta , NodeConfig , SniffOptions
2830from ._node import AiohttpHttpNode , BaseNode
2931from ._node_pool import NodePool , NodeSelector
30- from ._transport import NOT_DEAD_NODE_HTTP_STATUSES , Transport
32+ from ._transport import (
33+ NOT_DEAD_NODE_HTTP_STATUSES ,
34+ Transport ,
35+ validate_sniffing_options ,
36+ )
3137from .client_utils import DEFAULT , normalize_headers
3238
3339
@@ -52,6 +58,17 @@ def __init__(
5258 max_retries : int = 3 ,
5359 retry_on_status = (429 , 502 , 503 , 504 ),
5460 retry_on_timeout : bool = False ,
61+ sniff_on_start : bool = False ,
62+ sniff_before_requests : bool = False ,
63+ sniff_on_node_failure : bool = False ,
64+ sniff_timeout : Optional [float ] = 1.0 ,
65+ min_delay_between_sniffing : float = 10.0 ,
66+ sniff_callback : Optional [
67+ Callable [
68+ ["Transport" , "SniffOptions" ],
69+ Union [List [NodeConfig ], Awaitable [List [NodeConfig ]]],
70+ ]
71+ ] = None ,
5572 ):
5673 """
5774 :arg node_configs: List of 'NodeConfig' instances to create initial set of nodes.
@@ -74,11 +91,33 @@ def __init__(
7491 on a different node. defaults to ``(429, 502, 503, 504)``
7592 :arg retry_on_timeout: should timeout trigger a retry on different
7693 node? (default ``False``)
77-
78- Any extra keyword arguments will be passed to the `node_class`
79- when creating and instance unless overridden by that node's
80- options provided as part of the hosts parameter.
94+ :arg sniff_on_start: If ``True`` will sniff for additional nodes as soon
95+ as possible, guaranteed before the first request.
96+ :arg sniff_on_node_failure: If ``True`` will sniff for additional nodees
97+ after a node is marked as dead in the pool.
98+ :arg sniff_before_requests: If ``True`` will occasionally sniff for additional
99+ nodes as requests are sent.
100+ :arg sniff_timeout: Timeout value in seconds to use for sniffing requests.
101+ Defaults to 1 second.
102+ :arg min_delay_between_sniffing: Number of seconds to wait between calls to
103+ :meth:`elastic_transport.Transport.sniff` to avoid sniffing too frequently.
104+ Defaults to 10 seconds.
105+ :arg sniff_callback: Function that is passed a :class:`elastic_transport.Transport` and
106+ :class:`elastic_transport.SniffOptions` and should do node discovery and
107+ return a list of :class:`elastic_transport.NodeConfig` instances or a coroutine
108+ that returns the list.
81109 """
110+
111+ # Since we don't pass all the sniffing options to super().__init__()
112+ # we want to validate the sniffing options here too.
113+ validate_sniffing_options (
114+ node_configs = node_configs ,
115+ sniff_on_start = sniff_on_start ,
116+ sniff_before_requests = sniff_before_requests ,
117+ sniff_on_node_failure = sniff_on_node_failure ,
118+ sniff_callback = sniff_callback ,
119+ )
120+
82121 super ().__init__ (
83122 node_configs = node_configs ,
84123 node_class = node_class ,
@@ -91,8 +130,20 @@ def __init__(
91130 max_retries = max_retries ,
92131 retry_on_status = retry_on_status ,
93132 retry_on_timeout = retry_on_timeout ,
133+ sniff_timeout = sniff_timeout ,
134+ min_delay_between_sniffing = min_delay_between_sniffing ,
94135 )
95136
137+ self ._sniff_on_start = sniff_on_start
138+ self ._sniff_before_requests = sniff_before_requests
139+ self ._sniff_on_node_failure = sniff_on_node_failure
140+ self ._sniff_timeout = sniff_timeout
141+ self ._sniff_callback = sniff_callback
142+ self ._sniffing_lock = None
143+ self ._sniffing_task : Optional [asyncio .Task ] = None
144+ self ._last_sniffed_at = 0.0
145+ self ._loop : Optional [asyncio .BaseEventLoop ] = None
146+
96147 async def perform_request (
97148 self ,
98149 method : str ,
@@ -123,6 +174,8 @@ async def perform_request(
123174 :arg ignore_status: Collection of HTTP status codes to not raise an error for.
124175 :returns: Tuple of the HttpResponse with the deserialized response.
125176 """
177+ await self ._async_init ()
178+
126179 if isinstance (ignore_status , int ):
127180 ignore_status = {ignore_status }
128181
@@ -143,8 +196,13 @@ async def perform_request(
143196 errors = []
144197
145198 for attempt in range (self .max_retries + 1 ):
146- node = self .node_pool .get ()
147199
200+ # If we sniff before requests are made we want to do so before
201+ # 'node_pool.get()' is called so our sniffed nodes show up in the pool.
202+ if self ._sniff_before_requests :
203+ await self .sniff (False )
204+
205+ node = self .node_pool .get ()
148206 try :
149207 response , raw_data = await node .perform_request (
150208 method ,
@@ -184,6 +242,14 @@ async def perform_request(
184242 if node_failure :
185243 self .node_pool .mark_dead (node )
186244
245+ if self ._sniff_on_node_failure :
246+ try :
247+ await self .sniff (False )
248+ except TransportError :
249+ # If sniffing on failure, it could fail too. Catch the
250+ # exception not to interrupt the retries.
251+ pass
252+
187253 if retry :
188254 # raise exception on last retry
189255 if attempt == self .max_retries :
@@ -200,9 +266,75 @@ async def perform_request(
200266 self .node_pool .mark_live (node )
201267 return response , data
202268
269+ async def sniff (self , is_initial_sniff : bool ) -> None :
270+ await self ._async_init ()
271+ task = self ._create_sniffing_task (is_initial_sniff )
272+
273+ # Only block on the task if this is the initial sniff.
274+ # Otherwise we do the sniffing in the background.
275+ if is_initial_sniff and task :
276+ await task
277+
203278 async def close (self ) -> None :
204279 """
205280 Explicitly closes all nodes in the transport's pool
206281 """
207282 for node in self .node_pool .all ():
208283 await node .close ()
284+
285+ def _should_sniff (self , is_initial_sniff : bool ) -> bool :
286+ """Decide if we should sniff or not. _async_init() must be called
287+ before using this function.The async implementation doesn't have a lock.
288+ """
289+ if is_initial_sniff :
290+ return True
291+
292+ # Only start a new sniff if the previous run is completed.
293+ if self ._sniffing_task :
294+ if not self ._sniffing_task .done ():
295+ return False
296+ # If there was a previous run we collect the sniffing task's
297+ # result as it could have failed with an exception.
298+ self ._sniffing_task .result ()
299+
300+ return (
301+ self ._loop .time () - self ._last_sniffed_at
302+ >= self ._min_delay_between_sniffing
303+ )
304+
305+ def _create_sniffing_task (self , is_initial_sniff : bool ) -> Optional [asyncio .Task ]:
306+ """Creates a sniffing task if one should be created and returns the task if created."""
307+ task = None
308+ if self ._should_sniff (is_initial_sniff ):
309+ # 'self._sniffing_task' is unset within the task implementation.
310+ task = self ._loop .create_task (self ._sniffing_task_impl (is_initial_sniff ))
311+ self ._sniffing_task = task
312+ return task
313+
314+ async def _sniffing_task_impl (self , is_initial_sniff : bool ) -> None :
315+ """Implementation of the sniffing task"""
316+ previously_sniffed_at = self ._last_sniffed_at
317+ try :
318+ self ._last_sniffed_at = self ._loop .time ()
319+ options = SniffOptions (
320+ is_initial_sniff = is_initial_sniff , sniff_timeout = self ._sniff_timeout
321+ )
322+ for node_config in await await_if_coro (self ._sniff_callback (self , options )):
323+ self .node_pool .add (node_config )
324+
325+ # If sniffing failed for any reason we
326+ # want to allow retrying immediately.
327+ except BaseException :
328+ self ._last_sniffed_at = previously_sniffed_at
329+ raise
330+
331+ async def _async_init (self ) -> None :
332+ """Async constructor which is called on the first call to perform_request()
333+ because we're not guaranteed to be within an active asyncio event loop
334+ when __init__() is called.
335+ """
336+ if self ._loop is not None :
337+ return # Call at most once!
338+ self ._loop = get_running_loop ()
339+ if self ._sniff_on_start :
340+ await self .sniff (True )
0 commit comments