1
1
import asyncio
2
- from typing import Callable , Optional , Coroutine , Any
2
+ from typing import Callable , Optional , Coroutine , Any , List , Union , Awaitable
3
3
4
4
from redis .asyncio .multidb .command_executor import DefaultCommandExecutor
5
5
from redis .asyncio .multidb .database import AsyncDatabase , Databases
10
10
from redis .background import BackgroundScheduler
11
11
from redis .commands import AsyncRedisModuleCommands , AsyncCoreCommands
12
12
from redis .multidb .exception import NoValidDatabaseException
13
+ from redis .typing import KeyT
13
14
14
15
15
16
class MultiDBClient (AsyncRedisModuleCommands , AsyncCoreCommands ):
@@ -49,6 +50,19 @@ def __init__(self, config: MultiDbConfig):
49
50
self ._hc_lock = asyncio .Lock ()
50
51
self ._bg_scheduler = BackgroundScheduler ()
51
52
self ._config = config
53
+ self ._hc_task = None
54
+ self ._half_open_state_task = None
55
+
56
+ async def __aenter__ (self : "MultiDBClient" ) -> "MultiDBClient" :
57
+ if not self .initialized :
58
+ await self .initialize ()
59
+ return self
60
+
61
+ async def __aexit__ (self , exc_type , exc_value , traceback ):
62
+ if self ._hc_task :
63
+ self ._hc_task .cancel ()
64
+ if self ._half_open_state_task :
65
+ self ._half_open_state_task .cancel ()
52
66
53
67
async def initialize (self ):
54
68
"""
@@ -61,7 +75,7 @@ async def raise_exception_on_failed_hc(error):
61
75
await self ._check_databases_health (on_error = raise_exception_on_failed_hc )
62
76
63
77
# Starts recurring health checks on the background.
64
- asyncio .create_task (self ._bg_scheduler .run_recurring_async (
78
+ self . _hc_task = asyncio .create_task (self ._bg_scheduler .run_recurring_async (
65
79
self ._health_check_interval ,
66
80
self ._check_databases_health ,
67
81
))
@@ -180,6 +194,34 @@ async def execute_command(self, *args, **options):
180
194
181
195
return await self .command_executor .execute_command (* args , ** options )
182
196
197
+ def pipeline (self ):
198
+ """
199
+ Enters into pipeline mode of the client.
200
+ """
201
+ return Pipeline (self )
202
+
203
+ async def transaction (
204
+ self ,
205
+ func : Callable [["Pipeline" ], Union [Any , Awaitable [Any ]]],
206
+ * watches : KeyT ,
207
+ shard_hint : Optional [str ] = None ,
208
+ value_from_callable : bool = False ,
209
+ watch_delay : Optional [float ] = None ,
210
+ ):
211
+ """
212
+ Executes callable as transaction.
213
+ """
214
+ if not self .initialized :
215
+ await self .initialize ()
216
+
217
+ return await self .command_executor .execute_transaction (
218
+ func ,
219
+ * watches ,
220
+ shard_hint = shard_hint ,
221
+ value_from_callable = value_from_callable ,
222
+ watch_delay = watch_delay ,
223
+ )
224
+
183
225
async def _check_databases_health (
184
226
self ,
185
227
on_error : Optional [Callable [[Exception ], Coroutine [Any , Any , None ]]] = None ,
@@ -227,11 +269,75 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state:
227
269
loop = asyncio .get_running_loop ()
228
270
229
271
if new_state == CBState .HALF_OPEN :
230
- asyncio .create_task (self ._check_db_health (circuit .database ))
272
+ self . _half_open_state_task = asyncio .create_task (self ._check_db_health (circuit .database ))
231
273
return
232
274
233
275
if old_state == CBState .CLOSED and new_state == CBState .OPEN :
234
276
loop .call_later (DEFAULT_GRACE_PERIOD , _half_open_circuit , circuit )
235
277
236
278
def _half_open_circuit (circuit : CircuitBreaker ):
237
- circuit .state = CBState .HALF_OPEN
279
+ circuit .state = CBState .HALF_OPEN
280
+
281
+ class Pipeline (AsyncRedisModuleCommands , AsyncCoreCommands ):
282
+ """
283
+ Pipeline implementation for multiple logical Redis databases.
284
+ """
285
+ def __init__ (self , client : MultiDBClient ):
286
+ self ._command_stack = []
287
+ self ._client = client
288
+
289
+ async def __aenter__ (self : "Pipeline" ) -> "Pipeline" :
290
+ return self
291
+
292
+ async def __aexit__ (self , exc_type , exc_value , traceback ):
293
+ await self .reset ()
294
+ await self ._client .__aexit__ (exc_type , exc_value , traceback )
295
+
296
+ def __await__ (self ):
297
+ return self ._async_self ().__await__ ()
298
+
299
+ async def _async_self (self ):
300
+ return self
301
+
302
+ def __len__ (self ) -> int :
303
+ return len (self ._command_stack )
304
+
305
+ def __bool__ (self ) -> bool :
306
+ """Pipeline instances should always evaluate to True"""
307
+ return True
308
+
309
+ async def reset (self ) -> None :
310
+ self ._command_stack = []
311
+
312
+ async def aclose (self ) -> None :
313
+ """Close the pipeline"""
314
+ await self .reset ()
315
+
316
+ def pipeline_execute_command (self , * args , ** options ) -> "Pipeline" :
317
+ """
318
+ Stage a command to be executed when execute() is next called
319
+
320
+ Returns the current Pipeline object back so commands can be
321
+ chained together, such as:
322
+
323
+ pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
324
+
325
+ At some other point, you can then run: pipe.execute(),
326
+ which will execute all commands queued in the pipe.
327
+ """
328
+ self ._command_stack .append ((args , options ))
329
+ return self
330
+
331
+ def execute_command (self , * args , ** kwargs ):
332
+ """Adds a command to the stack"""
333
+ return self .pipeline_execute_command (* args , ** kwargs )
334
+
335
+ async def execute (self ) -> List [Any ]:
336
+ """Execute all the commands in the current pipeline"""
337
+ if not self ._client .initialized :
338
+ await self ._client .initialize ()
339
+
340
+ try :
341
+ return await self ._client .command_executor .execute_pipeline (tuple (self ._command_stack ))
342
+ finally :
343
+ await self .reset ()
0 commit comments