|
42 | 42 | MaintenanceEventConnectionHandler,
|
43 | 43 | MaintenanceEventPoolHandler,
|
44 | 44 | MaintenanceEventsConfig,
|
| 45 | + MaintenanceState, |
45 | 46 | )
|
46 | 47 | from .retry import Retry
|
47 | 48 | from .utils import (
|
@@ -285,6 +286,7 @@ def __init__(
|
285 | 286 | maintenance_events_config: Optional[MaintenanceEventsConfig] = None,
|
286 | 287 | tmp_host_address: Optional[str] = None,
|
287 | 288 | tmp_relax_timeout: Optional[float] = -1,
|
| 289 | + maintenance_state: "MaintenanceState" = MaintenanceState.NONE, |
288 | 290 | ):
|
289 | 291 | """
|
290 | 292 | Initialize a new Connection.
|
@@ -374,6 +376,7 @@ def __init__(
|
374 | 376 | self._should_reconnect = False
|
375 | 377 | self.tmp_host_address = tmp_host_address
|
376 | 378 | self.tmp_relax_timeout = tmp_relax_timeout
|
| 379 | + self.maintenance_state = maintenance_state |
377 | 380 |
|
378 | 381 | def __repr__(self):
|
379 | 382 | repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
|
@@ -835,6 +838,9 @@ def update_tmp_settings(
|
835 | 838 | if tmp_relax_timeout is not SENTINEL:
|
836 | 839 | self.tmp_relax_timeout = tmp_relax_timeout
|
837 | 840 |
|
| 841 | + def set_maintenance_state(self, state: "MaintenanceState"): |
| 842 | + self.maintenance_state = state |
| 843 | + |
838 | 844 |
|
839 | 845 | class Connection(AbstractConnection):
|
840 | 846 | "Manages TCP communication to and from a Redis server"
|
@@ -1724,11 +1730,18 @@ def make_connection(self) -> "ConnectionInterface":
|
1724 | 1730 | raise MaxConnectionsError("Too many connections")
|
1725 | 1731 | self._created_connections += 1
|
1726 | 1732 |
|
| 1733 | + # Pass current maintenance_state to new connections |
| 1734 | + maintenance_state = self.connection_kwargs.get( |
| 1735 | + "maintenance_state", MaintenanceState.NONE |
| 1736 | + ) |
| 1737 | + kwargs = dict(self.connection_kwargs) |
| 1738 | + kwargs["maintenance_state"] = maintenance_state |
| 1739 | + |
1727 | 1740 | if self.cache is not None:
|
1728 | 1741 | return CacheProxyConnection(
|
1729 |
| - self.connection_class(**self.connection_kwargs), self.cache, self._lock |
| 1742 | + self.connection_class(**kwargs), self.cache, self._lock |
1730 | 1743 | )
|
1731 |
| - return self.connection_class(**self.connection_kwargs) |
| 1744 | + return self.connection_class(**kwargs) |
1732 | 1745 |
|
1733 | 1746 | def release(self, connection: "Connection") -> None:
|
1734 | 1747 | "Releases the connection back to the pool"
|
@@ -1953,6 +1966,16 @@ async def _mock(self, error: RedisError):
|
1953 | 1966 | """
|
1954 | 1967 | pass
|
1955 | 1968 |
|
| 1969 | + def set_maintenance_state_for_all(self, state: "MaintenanceState"): |
| 1970 | + with self._lock: |
| 1971 | + for conn in self._available_connections: |
| 1972 | + conn.set_maintenance_state(state) |
| 1973 | + for conn in self._in_use_connections: |
| 1974 | + conn.set_maintenance_state(state) |
| 1975 | + |
| 1976 | + def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): |
| 1977 | + self.connection_kwargs["maintenance_state"] = state |
| 1978 | + |
1956 | 1979 |
|
1957 | 1980 | class BlockingConnectionPool(ConnectionPool):
|
1958 | 1981 | """
|
@@ -2047,15 +2070,20 @@ def make_connection(self):
|
2047 | 2070 | if self._in_maintenance:
|
2048 | 2071 | self._lock.acquire()
|
2049 | 2072 | self._locked = True
|
| 2073 | + # Pass current maintenance_state to new connections |
| 2074 | + maintenance_state = self.connection_kwargs.get( |
| 2075 | + "maintenance_state", MaintenanceState.NONE |
| 2076 | + ) |
| 2077 | + kwargs = dict(self.connection_kwargs) |
| 2078 | + kwargs["maintenance_state"] = maintenance_state |
2050 | 2079 | if self.cache is not None:
|
2051 | 2080 | connection = CacheProxyConnection(
|
2052 |
| - self.connection_class(**self.connection_kwargs), |
| 2081 | + self.connection_class(**kwargs), |
2053 | 2082 | self.cache,
|
2054 | 2083 | self._lock,
|
2055 | 2084 | )
|
2056 | 2085 | else:
|
2057 |
| - connection = self.connection_class(**self.connection_kwargs) |
2058 |
| - |
| 2086 | + connection = self.connection_class(**kwargs) |
2059 | 2087 | self._connections.append(connection)
|
2060 | 2088 | return connection
|
2061 | 2089 | finally:
|
@@ -2266,3 +2294,12 @@ def _update_maintenance_events_configs_for_connections(
|
2266 | 2294 | def set_in_maintenance(self, in_maintenance: bool):
|
2267 | 2295 | """Set the maintenance mode for the connection pool."""
|
2268 | 2296 | self._in_maintenance = in_maintenance
|
| 2297 | + |
| 2298 | + def set_maintenance_state_for_all(self, state: "MaintenanceState"): |
| 2299 | + with self._lock: |
| 2300 | + for conn in getattr(self, "_connections", []): |
| 2301 | + if conn: |
| 2302 | + conn.set_maintenance_state(state) |
| 2303 | + |
| 2304 | + def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): |
| 2305 | + self.connection_kwargs["maintenance_state"] = state |
0 commit comments