Skip to content

Commit 4817a26

Browse files
authored
Added pub/sub support for MultiDBClient (#3764)
* Extract additional interfaces and abstract classes * Added base async components * Added command executor * Added recurring background tasks with event loop only * Added MultiDBClient * Added scenario and config tests * Added pipeline and transaction support for MultiDBClient * Added pub/sub support for MultiDBClient * Added check for couroutines methods for pub/sub
1 parent 1dfffd2 commit 4817a26

File tree

5 files changed

+197
-19
lines changed

5 files changed

+197
-19
lines changed

redis/asyncio/client.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,6 +1191,7 @@ async def run(
11911191
*,
11921192
exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
11931193
poll_timeout: float = 1.0,
1194+
pubsub = None
11941195
) -> None:
11951196
"""Process pub/sub messages using registered callbacks.
11961197
@@ -1215,9 +1216,14 @@ async def run(
12151216
await self.connect()
12161217
while True:
12171218
try:
1218-
await self.get_message(
1219-
ignore_subscribe_messages=True, timeout=poll_timeout
1220-
)
1219+
if pubsub is None:
1220+
await self.get_message(
1221+
ignore_subscribe_messages=True, timeout=poll_timeout
1222+
)
1223+
else:
1224+
await pubsub.get_message(
1225+
ignore_subscribe_messages=True, timeout=poll_timeout
1226+
)
12211227
except asyncio.CancelledError:
12221228
raise
12231229
except BaseException as e:

redis/asyncio/multidb/client.py

Lines changed: 133 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
from typing import Callable, Optional, Coroutine, Any, List, Union, Awaitable
33

4+
from redis.asyncio.client import PubSubHandler
45
from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
56
from redis.asyncio.multidb.database import AsyncDatabase, Databases
67
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
@@ -10,7 +11,7 @@
1011
from redis.background import BackgroundScheduler
1112
from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands
1213
from redis.multidb.exception import NoValidDatabaseException
13-
from redis.typing import KeyT
14+
from redis.typing import KeyT, EncodableT, ChannelT
1415

1516

1617
class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands):
@@ -222,6 +223,17 @@ async def transaction(
222223
watch_delay=watch_delay,
223224
)
224225

226+
async def pubsub(self, **kwargs):
227+
"""
228+
Return a Publish/Subscribe object. With this object, you can
229+
subscribe to channels and listen for messages that get published to
230+
them.
231+
"""
232+
if not self.initialized:
233+
await self.initialize()
234+
235+
return PubSub(self, **kwargs)
236+
225237
async def _check_databases_health(
226238
self,
227239
on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None,
@@ -340,4 +352,123 @@ async def execute(self) -> List[Any]:
340352
try:
341353
return await self._client.command_executor.execute_pipeline(tuple(self._command_stack))
342354
finally:
343-
await self.reset()
355+
await self.reset()
356+
357+
class PubSub:
358+
"""
359+
PubSub object for multi database client.
360+
"""
361+
def __init__(self, client: MultiDBClient, **kwargs):
362+
"""Initialize the PubSub object for a multi-database client.
363+
364+
Args:
365+
client: MultiDBClient instance to use for pub/sub operations
366+
**kwargs: Additional keyword arguments to pass to the underlying pubsub implementation
367+
"""
368+
369+
self._client = client
370+
self._client.command_executor.pubsub(**kwargs)
371+
372+
async def __aenter__(self) -> "PubSub":
373+
return self
374+
375+
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
376+
await self.aclose()
377+
378+
async def aclose(self):
379+
return await self._client.command_executor.execute_pubsub_method('aclose')
380+
381+
@property
382+
def subscribed(self) -> bool:
383+
return self._client.command_executor.active_pubsub.subscribed
384+
385+
async def execute_command(self, *args: EncodableT):
386+
return await self._client.command_executor.execute_pubsub_method('execute_command', *args)
387+
388+
async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
389+
"""
390+
Subscribe to channel patterns. Patterns supplied as keyword arguments
391+
expect a pattern name as the key and a callable as the value. A
392+
pattern's callable will be invoked automatically when a message is
393+
received on that pattern rather than producing a message via
394+
``listen()``.
395+
"""
396+
return await self._client.command_executor.execute_pubsub_method(
397+
'psubscribe',
398+
*args,
399+
**kwargs
400+
)
401+
402+
async def punsubscribe(self, *args: ChannelT):
403+
"""
404+
Unsubscribe from the supplied patterns. If empty, unsubscribe from
405+
all patterns.
406+
"""
407+
return await self._client.command_executor.execute_pubsub_method(
408+
'punsubscribe',
409+
*args
410+
)
411+
412+
async def subscribe(self, *args: ChannelT, **kwargs: Callable):
413+
"""
414+
Subscribe to channels. Channels supplied as keyword arguments expect
415+
a channel name as the key and a callable as the value. A channel's
416+
callable will be invoked automatically when a message is received on
417+
that channel rather than producing a message via ``listen()`` or
418+
``get_message()``.
419+
"""
420+
return await self._client.command_executor.execute_pubsub_method(
421+
'subscribe',
422+
*args,
423+
**kwargs
424+
)
425+
426+
async def unsubscribe(self, *args):
427+
"""
428+
Unsubscribe from the supplied channels. If empty, unsubscribe from
429+
all channels
430+
"""
431+
return await self._client.command_executor.execute_pubsub_method(
432+
'unsubscribe',
433+
*args
434+
)
435+
436+
async def get_message(
437+
self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
438+
):
439+
"""
440+
Get the next message if one is available, otherwise None.
441+
442+
If timeout is specified, the system will wait for `timeout` seconds
443+
before returning. Timeout should be specified as a floating point
444+
number or None to wait indefinitely.
445+
"""
446+
return await self._client.command_executor.execute_pubsub_method(
447+
'get_message',
448+
ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout
449+
)
450+
451+
async def run(
452+
self,
453+
*,
454+
exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
455+
poll_timeout: float = 1.0,
456+
) -> None:
457+
"""Process pub/sub messages using registered callbacks.
458+
459+
This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
460+
redis-py, but it is a coroutine. To launch it as a separate task, use
461+
``asyncio.create_task``:
462+
463+
>>> task = asyncio.create_task(pubsub.run())
464+
465+
To shut it down, use asyncio cancellation:
466+
467+
>>> task.cancel()
468+
>>> await task
469+
"""
470+
return await self._client.command_executor.execute_pubsub_run(
471+
exception_handler=exception_handler,
472+
sleep_time=poll_timeout,
473+
pubsub=self
474+
)

redis/asyncio/multidb/command_executor.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import abstractmethod
2+
from asyncio import iscoroutinefunction
23
from datetime import datetime
34
from typing import List, Optional, Callable, Any, Union, Awaitable
45

@@ -178,14 +179,10 @@ def failover_strategy(self) -> AsyncFailoverStrategy:
178179
def command_retry(self) -> Retry:
179180
return self._command_retry
180181

181-
async def pubsub(self, **kwargs):
182-
async def callback():
183-
if self._active_pubsub is None:
184-
self._active_pubsub = self._active_database.client.pubsub(**kwargs)
185-
self._active_pubsub_kwargs = kwargs
186-
return None
187-
188-
return await self._execute_with_failure_detection(callback)
182+
def pubsub(self, **kwargs):
183+
if self._active_pubsub is None:
184+
self._active_pubsub = self._active_database.client.pubsub(**kwargs)
185+
self._active_pubsub_kwargs = kwargs
189186

190187
async def execute_command(self, *args, **options):
191188
async def callback():
@@ -225,7 +222,10 @@ async def callback():
225222
async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
226223
async def callback():
227224
method = getattr(self.active_pubsub, method_name)
228-
return await method(*args, **kwargs)
225+
if iscoroutinefunction(method):
226+
return await method(*args, **kwargs)
227+
else:
228+
return method(*args, **kwargs)
229229

230230
return await self._execute_with_failure_detection(callback, *args)
231231

redis/multidb/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,6 @@ def __init__(self, client: MultiDBClient, **kwargs):
337337
def __enter__(self) -> "PubSub":
338338
return self
339339

340-
def __exit__(self, exc_type, exc_value, traceback) -> None:
341-
self.reset()
342-
343340
def __del__(self) -> None:
344341
try:
345342
# if this object went out of scope prior to shutting down
@@ -350,7 +347,7 @@ def __del__(self) -> None:
350347
pass
351348

352349
def reset(self) -> None:
353-
pass
350+
return self._client.command_executor.execute_pubsub_method('reset')
354351

355352
def close(self) -> None:
356353
self.reset()
@@ -359,6 +356,9 @@ def close(self) -> None:
359356
def subscribed(self) -> bool:
360357
return self._client.command_executor.active_pubsub.subscribed
361358

359+
def execute_command(self, *args):
360+
return self._client.command_executor.execute_pubsub_method('execute_command', *args)
361+
362362
def psubscribe(self, *args, **kwargs):
363363
"""
364364
Subscribe to channel patterns. Patterns supplied as keyword arguments

tests/test_asyncio/test_scenario/test_active_active.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import json
23
import logging
34
from time import sleep
45

@@ -186,4 +187,44 @@ async def callback(pipe: Pipeline):
186187
# Execute transaction until database failover
187188
while not listener.is_changed_flag:
188189
await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3']
189-
await asyncio.sleep(0.5)
190+
await asyncio.sleep(0.5)
191+
192+
@pytest.mark.asyncio
193+
@pytest.mark.parametrize(
194+
"r_multi_db",
195+
[{"failure_threshold": 2}],
196+
indirect=True
197+
)
198+
@pytest.mark.timeout(50)
199+
async def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client):
200+
r_multi_db, listener, config = r_multi_db
201+
202+
event = asyncio.Event()
203+
asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event))
204+
205+
data = json.dumps({'message': 'test'})
206+
messages_count = 0
207+
208+
async def handler(message):
209+
nonlocal messages_count
210+
messages_count += 1
211+
212+
pubsub = await r_multi_db.pubsub()
213+
214+
# Assign a handler and run in a separate thread.
215+
await pubsub.subscribe(**{'test-channel': handler})
216+
task = asyncio.create_task(pubsub.run(poll_timeout=0.1))
217+
218+
# Execute publish before network failure
219+
while not event.is_set():
220+
await r_multi_db.publish('test-channel', data)
221+
await asyncio.sleep(0.5)
222+
223+
# Execute publish until database failover
224+
while not listener.is_changed_flag:
225+
await r_multi_db.publish('test-channel', data)
226+
await asyncio.sleep(0.5)
227+
228+
task.cancel()
229+
await pubsub.unsubscribe('test-channel') is True
230+
assert messages_count > 1

0 commit comments

Comments
 (0)