Skip to content
114 changes: 110 additions & 4 deletions redis/asyncio/multidb/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Callable, Optional, Coroutine, Any
from typing import Callable, Optional, Coroutine, Any, List, Union, Awaitable

from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
from redis.asyncio.multidb.database import AsyncDatabase, Databases
Expand All @@ -10,6 +10,7 @@
from redis.background import BackgroundScheduler
from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands
from redis.multidb.exception import NoValidDatabaseException
from redis.typing import KeyT


class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands):
Expand Down Expand Up @@ -49,6 +50,19 @@ def __init__(self, config: MultiDbConfig):
self._hc_lock = asyncio.Lock()
self._bg_scheduler = BackgroundScheduler()
self._config = config
self._hc_task = None
self._half_open_state_task = None

async def __aenter__(self: "MultiDBClient") -> "MultiDBClient":
if not self.initialized:
await self.initialize()
return self

async def __aexit__(self, exc_type, exc_value, traceback):
if self._hc_task:
self._hc_task.cancel()
if self._half_open_state_task:
self._half_open_state_task.cancel()

async def initialize(self):
"""
Expand All @@ -61,7 +75,7 @@ async def raise_exception_on_failed_hc(error):
await self._check_databases_health(on_error=raise_exception_on_failed_hc)

# Starts recurring health checks on the background.
asyncio.create_task(self._bg_scheduler.run_recurring_async(
self._hc_task = asyncio.create_task(self._bg_scheduler.run_recurring_async(
self._health_check_interval,
self._check_databases_health,
))
Expand Down Expand Up @@ -180,6 +194,34 @@ async def execute_command(self, *args, **options):

return await self.command_executor.execute_command(*args, **options)

def pipeline(self):
"""
Enters into pipeline mode of the client.
"""
return Pipeline(self)

async def transaction(
self,
func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
*watches: KeyT,
shard_hint: Optional[str] = None,
value_from_callable: bool = False,
watch_delay: Optional[float] = None,
):
"""
Executes callable as transaction.
"""
if not self.initialized:
await self.initialize()

return await self.command_executor.execute_transaction(
func,
*watches,
shard_hint=shard_hint,
value_from_callable=value_from_callable,
watch_delay=watch_delay,
)

async def _check_databases_health(
self,
on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None,
Expand Down Expand Up @@ -227,11 +269,75 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state:
loop = asyncio.get_running_loop()

if new_state == CBState.HALF_OPEN:
asyncio.create_task(self._check_db_health(circuit.database))
self._half_open_state_task = asyncio.create_task(self._check_db_health(circuit.database))
return

if old_state == CBState.CLOSED and new_state == CBState.OPEN:
loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit)

def _half_open_circuit(circuit: CircuitBreaker):
circuit.state = CBState.HALF_OPEN
circuit.state = CBState.HALF_OPEN

class Pipeline(AsyncRedisModuleCommands, AsyncCoreCommands):
"""
Pipeline implementation for multiple logical Redis databases.
"""
def __init__(self, client: MultiDBClient):
self._command_stack = []
self._client = client

async def __aenter__(self: "Pipeline") -> "Pipeline":
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self.reset()
await self._client.__aexit__(exc_type, exc_value, traceback)

def __await__(self):
return self._async_self().__await__()

async def _async_self(self):
return self

def __len__(self) -> int:
return len(self._command_stack)

def __bool__(self) -> bool:
"""Pipeline instances should always evaluate to True"""
return True

async def reset(self) -> None:
self._command_stack = []

async def aclose(self) -> None:
"""Close the pipeline"""
await self.reset()

def pipeline_execute_command(self, *args, **options) -> "Pipeline":
"""
Stage a command to be executed when execute() is next called

Returns the current Pipeline object back so commands can be
chained together, such as:

pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')

At some other point, you can then run: pipe.execute(),
which will execute all commands queued in the pipe.
"""
self._command_stack.append((args, options))
return self

def execute_command(self, *args, **kwargs):
"""Adds a command to the stack"""
return self.pipeline_execute_command(*args, **kwargs)

async def execute(self) -> List[Any]:
"""Execute all the commands in the current pipeline"""
if not self._client.initialized:
await self._client.initialize()

try:
return await self._client.command_executor.execute_pipeline(tuple(self._command_stack))
finally:
await self.reset()
24 changes: 19 additions & 5 deletions redis/asyncio/multidb/command_executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import abstractmethod
from datetime import datetime
from typing import List, Optional, Callable, Any
from typing import List, Optional, Callable, Any, Union, Awaitable

from redis.asyncio.client import PubSub, Pipeline
from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database
Expand All @@ -13,6 +13,7 @@
from redis.event import EventDispatcherInterface, AsyncOnCommandsFailEvent
from redis.multidb.command_executor import CommandExecutor, BaseCommandExecutor
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
from redis.typing import KeyT


class AsyncCommandExecutor(CommandExecutor):
Expand Down Expand Up @@ -194,17 +195,30 @@ async def callback():

async def execute_pipeline(self, command_stack: tuple):
async def callback():
with self._active_database.client.pipeline() as pipe:
async with self._active_database.client.pipeline() as pipe:
for command, options in command_stack:
await pipe.execute_command(*command, **options)
pipe.execute_command(*command, **options)

return await pipe.execute()

return await self._execute_with_failure_detection(callback, command_stack)

async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options):
async def execute_transaction(
self,
func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
*watches: KeyT,
shard_hint: Optional[str] = None,
value_from_callable: bool = False,
watch_delay: Optional[float] = None,
):
async def callback():
return await self._active_database.client.transaction(transaction, *watches, **options)
return await self._active_database.client.transaction(
func,
*watches,
shard_hint=shard_hint,
value_from_callable=value_from_callable,
watch_delay=watch_delay
)

return await self._execute_with_failure_detection(callback)

Expand Down
Loading