@@ -42,7 +42,7 @@ class Connection(metaclass=ConnectionMeta):
42
42
"""
43
43
44
44
__slots__ = ('_protocol' , '_transport' , '_loop' ,
45
- '_top_xact' , '_aborted' ,
45
+ '_top_xact' , '_aborted' , '_middlewares'
46
46
'_pool_release_ctr' , '_stmt_cache' , '_stmts_to_close' ,
47
47
'_listeners' , '_server_version' , '_server_caps' ,
48
48
'_intro_query' , '_reset_query' , '_proxy' ,
@@ -53,7 +53,8 @@ class Connection(metaclass=ConnectionMeta):
53
53
def __init__ (self , protocol , transport , loop ,
54
54
addr : (str , int ) or str ,
55
55
config : connect_utils ._ClientConfiguration ,
56
- params : connect_utils ._ConnectionParameters ):
56
+ params : connect_utils ._ConnectionParameters ,
57
+ middlewares = None ):
57
58
self ._protocol = protocol
58
59
self ._transport = transport
59
60
self ._loop = loop
@@ -92,7 +93,7 @@ def __init__(self, protocol, transport, loop,
92
93
93
94
self ._reset_query = None
94
95
self ._proxy = None
95
-
96
+ self . _middlewares = _middlewares
96
97
# Used to serialize operations that might involve anonymous
97
98
# statements. Specifically, we want to make the following
98
99
# operation atomic:
@@ -1410,8 +1411,12 @@ async def reload_schema_state(self):
1410
1411
1411
1412
async def _execute (self , query , args , limit , timeout , return_status = False ):
1412
1413
with self ._stmt_exclusive_section :
1413
- result , _ = await self .__execute (
1414
- query , args , limit , timeout , return_status = return_status )
1414
+ wrapped = self .__execute
1415
+ if self ._middlewares :
1416
+ for m in reversed (self ._middlewares ):
1417
+ wrapped = await m (self , wrapped )
1418
+
1419
+ result , _ = await wrapped (query , args , limit , timeout , return_status = return_status )
1415
1420
return result
1416
1421
1417
1422
async def __execute (self , query , args , limit , timeout ,
@@ -1502,6 +1507,7 @@ async def connect(dsn=None, *,
1502
1507
max_cacheable_statement_size = 1024 * 15 ,
1503
1508
command_timeout = None ,
1504
1509
ssl = None ,
1510
+ middlewares = None ,
1505
1511
connection_class = Connection ,
1506
1512
server_settings = None ):
1507
1513
r"""A coroutine to establish a connection to a PostgreSQL server.
@@ -1618,6 +1624,10 @@ async def connect(dsn=None, *,
1618
1624
PostgreSQL documentation for
1619
1625
a `list of supported options <server settings>`_.
1620
1626
1627
+ :param middlewares:
1628
+ An optional list of middleware functions. Refer to documentation
1629
+ on create_pool.
1630
+
1621
1631
:param Connection connection_class:
1622
1632
Class of the returned connection object. Must be a subclass of
1623
1633
:class:`~asyncpg.connection.Connection`.
@@ -1683,6 +1693,7 @@ async def connect(dsn=None, *,
1683
1693
ssl = ssl , database = database ,
1684
1694
server_settings = server_settings ,
1685
1695
command_timeout = command_timeout ,
1696
+ middlewares = middlewares ,
1686
1697
statement_cache_size = statement_cache_size ,
1687
1698
max_cached_statement_lifetime = max_cached_statement_lifetime ,
1688
1699
max_cacheable_statement_size = max_cacheable_statement_size )
0 commit comments