diff --git a/redis/connection.py b/redis/connection.py index 310f06e204..e2dff43e03 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -1043,6 +1043,7 @@ def reset(self): def make_connection(self): "Make a fresh connection." connection = self.connection_class(**self.connection_kwargs) + connection.pool_generation = self.generation self._connections.append(connection) return connection @@ -1071,6 +1072,12 @@ def get_connection(self, command_name, *keys, **options): # raised unless handled by application code. If you want never to raise ConnectionError("No connection available.") + # If the pool generation differs, close the connection and open a new one. + if connection is not None and connection.pool_generation != self.generation: + self._remove_connection(connection) + connection.disconnect() + connection = None + # If the ``connection`` is actually ``None`` then that's a cue to make # a new connection to add to the pool. if connection is None: @@ -1085,7 +1092,14 @@ def release(self, connection): if connection.pid != self.pid: return - # Put the connection back into the pool. + # If we are releasing a connection that is no longer the same as the pool's generation, + # we will disconnect it. + if connection.pool_generation != self.generation: + self._remove_connection(connection) + connection.disconnect() + connection = None + + # Put the connection, or None back into the pool. try: self.pool.put_nowait(connection) except Full: @@ -1093,7 +1107,17 @@ def release(self, connection): # we don't want this connection pass - def disconnect(self): + def disconnect(self, immediate=False): "Disconnects all connections in the pool." - for connection in self._connections: - connection.disconnect() + self.generation += 1 + + if immediate: + for connection in self._connections: + connection.disconnect() + + def _remove_connection(self, connection): + "Remove a connection from the list of connections." + try: + self._connections.remove(connection) + except IndexError: + pass \ No newline at end of file diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 6b2478aae8..fe2f6e70f3 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -16,6 +16,10 @@ class DummyConnection(object): def __init__(self, **kwargs): self.kwargs = kwargs self.pid = os.getpid() + self.disconnected = False + + def disconnect(self): + self.disconnected = True class TestConnectionPool(object): @@ -127,6 +131,35 @@ def test_reuse_previously_released_connection(self): c2 = pool.get_connection('_') assert c1 == c2 + def test_disconnect_changes_generation_and_returns_new_connection(self): + pool = self.get_pool() + c1 = pool.get_connection('_') + pool.release(c1) + assert pool.generation == 0 + pool.disconnect() + assert pool.generation == 1 + c2 = pool.get_connection('_') + assert c1 != c2 + + def test_disconnect_happens_when_releasing_connection_when_pool_generation_changes(self): + pool = self.get_pool() + c1 = pool.get_connection('_') + pool.disconnect() + assert not c1.disconnected + pool.release(c1) + assert c1.disconnected + c2 = pool.get_connection('_') + assert c1 != c2 + + def test_disconnect_disconnects_immediately(self): + pool = self.get_pool() + c1 = pool.get_connection('_') + pool.disconnect(immediate=True) + assert c1.disconnected + pool.release(c1) + c2 = pool.get_connection('_') + assert c1 != c2 + def test_repr_contains_db_info_tcp(self): pool = redis.ConnectionPool(host='localhost', port=6379, db=0) expected = 'ConnectionPool>'