Skip to content

Commit 1dfffd2

Browse files
authored
Added pipeline and transaction support for MultiDBClient (#3763)
* Extract additional interfaces and abstract classes * Added base async components * Added command executor * Added recurring background tasks with event loop only * Added MultiDBClient * Added scenario and config tests * Added pipeline and transaction support for MultiDBClient * Updated scenario tests to check failover
1 parent ec8113b commit 1dfffd2

File tree

5 files changed

+585
-12
lines changed

5 files changed

+585
-12
lines changed

redis/asyncio/multidb/client.py

Lines changed: 110 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from typing import Callable, Optional, Coroutine, Any
2+
from typing import Callable, Optional, Coroutine, Any, List, Union, Awaitable
33

44
from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
55
from redis.asyncio.multidb.database import AsyncDatabase, Databases
@@ -10,6 +10,7 @@
1010
from redis.background import BackgroundScheduler
1111
from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands
1212
from redis.multidb.exception import NoValidDatabaseException
13+
from redis.typing import KeyT
1314

1415

1516
class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands):
@@ -49,6 +50,19 @@ def __init__(self, config: MultiDbConfig):
4950
self._hc_lock = asyncio.Lock()
5051
self._bg_scheduler = BackgroundScheduler()
5152
self._config = config
53+
self._hc_task = None
54+
self._half_open_state_task = None
55+
56+
async def __aenter__(self: "MultiDBClient") -> "MultiDBClient":
57+
if not self.initialized:
58+
await self.initialize()
59+
return self
60+
61+
async def __aexit__(self, exc_type, exc_value, traceback):
62+
if self._hc_task:
63+
self._hc_task.cancel()
64+
if self._half_open_state_task:
65+
self._half_open_state_task.cancel()
5266

5367
async def initialize(self):
5468
"""
@@ -61,7 +75,7 @@ async def raise_exception_on_failed_hc(error):
6175
await self._check_databases_health(on_error=raise_exception_on_failed_hc)
6276

6377
# Starts recurring health checks on the background.
64-
asyncio.create_task(self._bg_scheduler.run_recurring_async(
78+
self._hc_task = asyncio.create_task(self._bg_scheduler.run_recurring_async(
6579
self._health_check_interval,
6680
self._check_databases_health,
6781
))
@@ -180,6 +194,34 @@ async def execute_command(self, *args, **options):
180194

181195
return await self.command_executor.execute_command(*args, **options)
182196

197+
def pipeline(self):
198+
"""
199+
Enters into pipeline mode of the client.
200+
"""
201+
return Pipeline(self)
202+
203+
async def transaction(
204+
self,
205+
func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
206+
*watches: KeyT,
207+
shard_hint: Optional[str] = None,
208+
value_from_callable: bool = False,
209+
watch_delay: Optional[float] = None,
210+
):
211+
"""
212+
Executes callable as transaction.
213+
"""
214+
if not self.initialized:
215+
await self.initialize()
216+
217+
return await self.command_executor.execute_transaction(
218+
func,
219+
*watches,
220+
shard_hint=shard_hint,
221+
value_from_callable=value_from_callable,
222+
watch_delay=watch_delay,
223+
)
224+
183225
async def _check_databases_health(
184226
self,
185227
on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None,
@@ -227,11 +269,75 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state:
227269
loop = asyncio.get_running_loop()
228270

229271
if new_state == CBState.HALF_OPEN:
230-
asyncio.create_task(self._check_db_health(circuit.database))
272+
self._half_open_state_task = asyncio.create_task(self._check_db_health(circuit.database))
231273
return
232274

233275
if old_state == CBState.CLOSED and new_state == CBState.OPEN:
234276
loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit)
235277

236278
def _half_open_circuit(circuit: CircuitBreaker):
237-
circuit.state = CBState.HALF_OPEN
279+
circuit.state = CBState.HALF_OPEN
280+
281+
class Pipeline(AsyncRedisModuleCommands, AsyncCoreCommands):
282+
"""
283+
Pipeline implementation for multiple logical Redis databases.
284+
"""
285+
def __init__(self, client: MultiDBClient):
286+
self._command_stack = []
287+
self._client = client
288+
289+
async def __aenter__(self: "Pipeline") -> "Pipeline":
290+
return self
291+
292+
async def __aexit__(self, exc_type, exc_value, traceback):
293+
await self.reset()
294+
await self._client.__aexit__(exc_type, exc_value, traceback)
295+
296+
def __await__(self):
297+
return self._async_self().__await__()
298+
299+
async def _async_self(self):
300+
return self
301+
302+
def __len__(self) -> int:
303+
return len(self._command_stack)
304+
305+
def __bool__(self) -> bool:
306+
"""Pipeline instances should always evaluate to True"""
307+
return True
308+
309+
async def reset(self) -> None:
310+
self._command_stack = []
311+
312+
async def aclose(self) -> None:
313+
"""Close the pipeline"""
314+
await self.reset()
315+
316+
def pipeline_execute_command(self, *args, **options) -> "Pipeline":
317+
"""
318+
Stage a command to be executed when execute() is next called
319+
320+
Returns the current Pipeline object back so commands can be
321+
chained together, such as:
322+
323+
pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
324+
325+
At some other point, you can then run: pipe.execute(),
326+
which will execute all commands queued in the pipe.
327+
"""
328+
self._command_stack.append((args, options))
329+
return self
330+
331+
def execute_command(self, *args, **kwargs):
332+
"""Adds a command to the stack"""
333+
return self.pipeline_execute_command(*args, **kwargs)
334+
335+
async def execute(self) -> List[Any]:
336+
"""Execute all the commands in the current pipeline"""
337+
if not self._client.initialized:
338+
await self._client.initialize()
339+
340+
try:
341+
return await self._client.command_executor.execute_pipeline(tuple(self._command_stack))
342+
finally:
343+
await self.reset()

redis/asyncio/multidb/command_executor.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import abstractmethod
22
from datetime import datetime
3-
from typing import List, Optional, Callable, Any
3+
from typing import List, Optional, Callable, Any, Union, Awaitable
44

55
from redis.asyncio.client import PubSub, Pipeline
66
from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database
@@ -13,6 +13,7 @@
1313
from redis.event import EventDispatcherInterface, AsyncOnCommandsFailEvent
1414
from redis.multidb.command_executor import CommandExecutor, BaseCommandExecutor
1515
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
16+
from redis.typing import KeyT
1617

1718

1819
class AsyncCommandExecutor(CommandExecutor):
@@ -194,17 +195,30 @@ async def callback():
194195

195196
async def execute_pipeline(self, command_stack: tuple):
196197
async def callback():
197-
with self._active_database.client.pipeline() as pipe:
198+
async with self._active_database.client.pipeline() as pipe:
198199
for command, options in command_stack:
199-
await pipe.execute_command(*command, **options)
200+
pipe.execute_command(*command, **options)
200201

201202
return await pipe.execute()
202203

203204
return await self._execute_with_failure_detection(callback, command_stack)
204205

205-
async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options):
206+
async def execute_transaction(
207+
self,
208+
func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
209+
*watches: KeyT,
210+
shard_hint: Optional[str] = None,
211+
value_from_callable: bool = False,
212+
watch_delay: Optional[float] = None,
213+
):
206214
async def callback():
207-
return await self._active_database.client.transaction(transaction, *watches, **options)
215+
return await self._active_database.client.transaction(
216+
func,
217+
*watches,
218+
shard_hint=shard_hint,
219+
value_from_callable=value_from_callable,
220+
watch_delay=watch_delay
221+
)
208222

209223
return await self._execute_with_failure_detection(callback)
210224

0 commit comments

Comments
 (0)