From b26a03ef543f022b6dd8b0db8293006a549bff2c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 27 Aug 2025 16:42:32 +0300 Subject: [PATCH] Extract additional interfaces and abstract classes --- redis/multidb/circuit.py | 82 ++++++----- redis/multidb/client.py | 25 ++-- redis/multidb/command_executor.py | 152 ++++++++++---------- redis/multidb/config.py | 8 +- redis/multidb/database.py | 100 +++++++------ redis/multidb/event.py | 13 +- redis/multidb/failover.py | 11 +- redis/multidb/failure_detector.py | 1 - tests/test_multidb/conftest.py | 12 +- tests/test_multidb/test_circuit.py | 4 +- tests/test_multidb/test_client.py | 4 +- tests/test_multidb/test_config.py | 10 +- tests/test_multidb/test_failure_detector.py | 12 +- 13 files changed, 225 insertions(+), 209 deletions(-) diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 79c8a5f379..221dc556a3 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -45,8 +45,49 @@ def database(self, database): """Set database associated with this circuit.""" pass +class BaseCircuitBreaker(CircuitBreaker): + """ + Base implementation of Circuit Breaker interface. + """ + def __init__(self, cb: pybreaker.CircuitBreaker): + self._cb = cb + self._state_pb_mapper = { + State.CLOSED: self._cb.close, + State.OPEN: self._cb.open, + State.HALF_OPEN: self._cb.half_open, + } + self._database = None + + @property + def grace_period(self) -> float: + return self._cb.reset_timeout + + @grace_period.setter + def grace_period(self, grace_period: float): + self._cb.reset_timeout = grace_period + + @property + def state(self) -> State: + return State(value=self._cb.state.name) + + @state.setter + def state(self, state: State): + self._state_pb_mapper[state]() + + @property + def database(self): + return self._database + + @database.setter + def database(self, database): + self._database = database + +class SyncCircuitBreaker(CircuitBreaker): + """ + Synchronous implementation of Circuit Breaker interface. + """ @abstractmethod - def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): """Callback called when the state of the circuit changes.""" pass @@ -54,7 +95,7 @@ class PBListener(pybreaker.CircuitBreakerListener): """Wrapper for callback to be compatible with pybreaker implementation.""" def __init__( self, - cb: Callable[[CircuitBreaker, State, State], None], + cb: Callable[[SyncCircuitBreaker, State, State], None], database, ): """ @@ -75,8 +116,7 @@ def state_change(self, cb, old_state, new_state): new_state = State(value=new_state.name) self._cb(cb, old_state, new_state) - -class PBCircuitBreakerAdapter(CircuitBreaker): +class PBCircuitBreakerAdapter(SyncCircuitBreaker, BaseCircuitBreaker): def __init__(self, cb: pybreaker.CircuitBreaker): """ Initialize a PBCircuitBreakerAdapter instance. @@ -87,38 +127,8 @@ def __init__(self, cb: pybreaker.CircuitBreaker): Args: cb: A pybreaker CircuitBreaker instance to be adapted. """ - self._cb = cb - self._state_pb_mapper = { - State.CLOSED: self._cb.close, - State.OPEN: self._cb.open, - State.HALF_OPEN: self._cb.half_open, - } - self._database = None - - @property - def grace_period(self) -> float: - return self._cb.reset_timeout - - @grace_period.setter - def grace_period(self, grace_period: float): - self._cb.reset_timeout = grace_period - - @property - def state(self) -> State: - return State(value=self._cb.state.name) - - @state.setter - def state(self, state: State): - self._state_pb_mapper[state]() - - @property - def database(self): - return self._database - - @database.setter - def database(self, database): - self._database = database + super().__init__(cb) - def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): listener = PBListener(cb, self.database) self._cb.add_listener(listener) \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 56342a7a53..8a0e006977 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,15 +1,12 @@ import threading -import socket from typing import List, Any, Callable, Optional from redis.background import BackgroundScheduler -from redis.client import PubSubWorkerThread -from redis.exceptions import ConnectionError, TimeoutError from redis.commands import RedisModuleCommands, CoreCommands from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD -from redis.multidb.circuit import State as CBState, CircuitBreaker -from redis.multidb.database import Database, AbstractDatabase, Databases +from redis.multidb.circuit import State as CBState, SyncCircuitBreaker +from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck @@ -92,7 +89,7 @@ def get_databases(self) -> Databases: """ return self._databases - def set_active_database(self, database: AbstractDatabase) -> None: + def set_active_database(self, database: SyncDatabase) -> None: """ Promote one of the existing databases to become an active. """ @@ -115,7 +112,7 @@ def set_active_database(self, database: AbstractDatabase) -> None: raise NoValidDatabaseException('Cannot set active database, database is unhealthy') - def add_database(self, database: AbstractDatabase): + def add_database(self, database: SyncDatabase): """ Adds a new database to the database list. """ @@ -129,7 +126,7 @@ def add_database(self, database: AbstractDatabase): self._databases.add(database, database.weight) self._change_active_database(database, highest_weighted_db) - def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase): + def _change_active_database(self, new_database: SyncDatabase, highest_weight_database: SyncDatabase): if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: self.command_executor.active_database = new_database @@ -143,7 +140,7 @@ def remove_database(self, database: Database): if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: self.command_executor.active_database = highest_weighted_db - def update_database_weight(self, database: AbstractDatabase, weight: float): + def update_database_weight(self, database: SyncDatabase, weight: float): """ Updates a database from the database list. """ @@ -210,7 +207,7 @@ def pubsub(self, **kwargs): return PubSub(self, **kwargs) - def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Exception], None] = None) -> None: + def _check_db_health(self, database: SyncDatabase, on_error: Callable[[Exception], None] = None) -> None: """ Runs health checks on the given database until first failure. """ @@ -247,7 +244,7 @@ def _check_databases_health(self, on_error: Callable[[Exception], None] = None): for database, _ in self._databases: self._check_db_health(database, on_error) - def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): + def _on_circuit_state_change_callback(self, circuit: SyncCircuitBreaker, old_state: CBState, new_state: CBState): if new_state == CBState.HALF_OPEN: self._check_db_health(circuit.database) return @@ -255,7 +252,7 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: if old_state == CBState.CLOSED and new_state == CBState.OPEN: self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) -def _half_open_circuit(circuit: CircuitBreaker): +def _half_open_circuit(circuit: SyncCircuitBreaker): circuit.state = CBState.HALF_OPEN @@ -450,8 +447,8 @@ def run_in_thread( exception_handler: Optional[Callable] = None, sharded_pubsub: bool = False, ) -> "PubSubWorkerThread": - return self._client.command_executor.execute_pubsub_run_in_thread( - sleep_time=sleep_time, + return self._client.command_executor.execute_pubsub_run( + sleep_time, daemon=daemon, exception_handler=exception_handler, pubsub=self, diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 094230a31d..364c0a07ea 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import List, Optional, Callable +from typing import List, Optional, Callable, Any from redis.client import Pipeline, PubSub, PubSubWorkerThread from redis.event import EventDispatcherInterface, OnCommandsFailEvent from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL -from redis.multidb.database import Database, AbstractDatabase, Databases +from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.circuit import State as CBState from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged from redis.multidb.failover import FailoverStrategy @@ -17,15 +17,40 @@ class CommandExecutor(ABC): @property @abstractmethod - def failure_detectors(self) -> List[FailureDetector]: - """Returns a list of failure detectors.""" + def auto_fallback_interval(self) -> float: + """Returns auto-fallback interval.""" pass + @auto_fallback_interval.setter @abstractmethod - def add_failure_detector(self, failure_detector: FailureDetector) -> None: - """Adds new failure detector to the list of failure detectors.""" + def auto_fallback_interval(self, auto_fallback_interval: float) -> None: + """Sets auto-fallback interval.""" pass +class BaseCommandExecutor(CommandExecutor): + def __init__( + self, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + ): + self._auto_fallback_interval = auto_fallback_interval + self._next_fallback_attempt: datetime + + @property + def auto_fallback_interval(self) -> float: + return self._auto_fallback_interval + + @auto_fallback_interval.setter + def auto_fallback_interval(self, auto_fallback_interval: int) -> None: + self._auto_fallback_interval = auto_fallback_interval + + def _schedule_next_fallback(self) -> None: + if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: + return + + self._next_fallback_attempt = datetime.now() + timedelta(seconds=self._auto_fallback_interval) + +class SyncCommandExecutor(CommandExecutor): + @property @abstractmethod def databases(self) -> Databases: @@ -34,19 +59,25 @@ def databases(self) -> Databases: @property @abstractmethod - def active_database(self) -> Optional[Database]: - """Returns currently active database.""" + def failure_detectors(self) -> List[FailureDetector]: + """Returns a list of failure detectors.""" pass - @active_database.setter @abstractmethod - def active_database(self, database: AbstractDatabase) -> None: - """Sets currently active database.""" + def add_failure_detector(self, failure_detector: FailureDetector) -> None: + """Adds a new failure detector to the list of failure detectors.""" pass + @property @abstractmethod - def pubsub(self, **kwargs): - """Initializes a PubSub object on a currently active database""" + def active_database(self) -> Optional[Database]: + """Returns currently active database.""" + pass + + @active_database.setter + @abstractmethod + def active_database(self, database: SyncDatabase) -> None: + """Sets the currently active database.""" pass @property @@ -69,30 +100,41 @@ def failover_strategy(self) -> FailoverStrategy: @property @abstractmethod - def auto_fallback_interval(self) -> float: - """Returns auto-fallback interval.""" + def command_retry(self) -> Retry: + """Returns command retry object.""" pass - @auto_fallback_interval.setter @abstractmethod - def auto_fallback_interval(self, auto_fallback_interval: float) -> None: - """Sets auto-fallback interval.""" + def pubsub(self, **kwargs): + """Initializes a PubSub object on a currently active database""" pass - @property @abstractmethod - def command_retry(self) -> Retry: - """Returns command retry object.""" + def execute_command(self, *args, **options): + """Executes a command and returns the result.""" pass @abstractmethod - def execute_command(self, *args, **options): - """Executes a command and returns the result.""" + def execute_pipeline(self, command_stack: tuple): + """Executes a stack of commands in pipeline.""" pass + @abstractmethod + def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + """Executes a transaction block wrapped in callback.""" + pass -class DefaultCommandExecutor(CommandExecutor): + @abstractmethod + def execute_pubsub_method(self, method_name: str, *args, **kwargs): + """Executes a given method on active pub/sub.""" + pass + @abstractmethod + def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + """Executes pub/sub run in a thread.""" + pass + +class DefaultCommandExecutor(SyncCommandExecutor, BaseCommandExecutor): def __init__( self, failure_detectors: List[FailureDetector], @@ -113,22 +155,26 @@ def __init__( event_dispatcher: Interface for dispatching events auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database """ + super().__init__(auto_fallback_interval) + for fd in failure_detectors: fd.set_command_executor(command_executor=self) - self._failure_detectors = failure_detectors self._databases = databases + self._failure_detectors = failure_detectors self._command_retry = command_retry self._failover_strategy = failover_strategy self._event_dispatcher = event_dispatcher - self._auto_fallback_interval = auto_fallback_interval - self._next_fallback_attempt: datetime self._active_database: Optional[Database] = None self._active_pubsub: Optional[PubSub] = None self._active_pubsub_kwargs = {} self._setup_event_dispatcher() self._schedule_next_fallback() + @property + def databases(self) -> Databases: + return self._databases + @property def failure_detectors(self) -> List[FailureDetector]: return self._failure_detectors @@ -136,20 +182,16 @@ def failure_detectors(self) -> List[FailureDetector]: def add_failure_detector(self, failure_detector: FailureDetector) -> None: self._failure_detectors.append(failure_detector) - @property - def databases(self) -> Databases: - return self._databases - @property def command_retry(self) -> Retry: return self._command_retry @property - def active_database(self) -> Optional[AbstractDatabase]: + def active_database(self) -> Optional[SyncDatabase]: return self._active_database @active_database.setter - def active_database(self, database: AbstractDatabase) -> None: + def active_database(self, database: SyncDatabase) -> None: old_active = self._active_database self._active_database = database @@ -170,25 +212,13 @@ def active_pubsub(self, pubsub: PubSub) -> None: def failover_strategy(self) -> FailoverStrategy: return self._failover_strategy - @property - def auto_fallback_interval(self) -> float: - return self._auto_fallback_interval - - @auto_fallback_interval.setter - def auto_fallback_interval(self, auto_fallback_interval: int) -> None: - self._auto_fallback_interval = auto_fallback_interval - def execute_command(self, *args, **options): - """Executes a command and returns the result.""" def callback(): return self._active_database.client.execute_command(*args, **options) return self._execute_with_failure_detection(callback, args) def execute_pipeline(self, command_stack: tuple): - """ - Executes a stack of commands in pipeline. - """ def callback(): with self._active_database.client.pipeline() as pipe: for command, options in command_stack: @@ -199,18 +229,12 @@ def callback(): return self._execute_with_failure_detection(callback, command_stack) def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): - """ - Executes a transaction block wrapped in callback. - """ def callback(): return self._active_database.client.transaction(transaction, *watches, **options) return self._execute_with_failure_detection(callback) def pubsub(self, **kwargs): - """ - Initializes a PubSub object on a currently active database. - """ def callback(): if self._active_pubsub is None: self._active_pubsub = self._active_database.client.pubsub(**kwargs) @@ -220,31 +244,15 @@ def callback(): return self._execute_with_failure_detection(callback) def execute_pubsub_method(self, method_name: str, *args, **kwargs): - """ - Executes given method on active pub/sub. - """ def callback(): method = getattr(self.active_pubsub, method_name) return method(*args, **kwargs) return self._execute_with_failure_detection(callback, *args) - def execute_pubsub_run_in_thread( - self, - pubsub, - sleep_time: float = 0.0, - daemon: bool = False, - exception_handler: Optional[Callable] = None, - sharded_pubsub: bool = False, - ) -> "PubSubWorkerThread": + def execute_pubsub_run(self, sleep_time, **kwargs) -> "PubSubWorkerThread": def callback(): - return self._active_pubsub.run_in_thread( - sleep_time, - daemon=daemon, - exception_handler=exception_handler, - pubsub=pubsub, - sharded_pubsub=sharded_pubsub - ) + return self._active_pubsub.run_in_thread(sleep_time, **kwargs) return self._execute_with_failure_detection(callback) @@ -280,12 +288,6 @@ def _check_active_database(self): self.active_database = self._failover_strategy.database self._schedule_next_fallback() - def _schedule_next_fallback(self) -> None: - if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: - return - - self._next_fallback_attempt = datetime.now() + timedelta(seconds=self._auto_fallback_interval) - def _setup_event_dispatcher(self): """ Registers necessary listeners. diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 5555baec44..a966ec329a 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -9,7 +9,7 @@ from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface -from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker from redis.multidb.database import Database, Databases from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ @@ -44,7 +44,7 @@ class DatabaseConfig: client_kwargs (dict): Additional parameters for the database client connection. from_url (Optional[str]): Redis URL way of connecting to the database. from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. - circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. + circuit (Optional[SyncCircuitBreaker]): Custom circuit breaker implementation. grace_period (float): Grace period after which we need to check if the circuit could be closed again. health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used on public Redis Enterprise endpoints. @@ -57,11 +57,11 @@ class DatabaseConfig: client_kwargs: dict = field(default_factory=dict) from_url: Optional[str] = None from_pool: Optional[ConnectionPool] = None - circuit: Optional[CircuitBreaker] = None + circuit: Optional[SyncCircuitBreaker] = None grace_period: float = DEFAULT_GRACE_PERIOD health_check_url: Optional[str] = None - def default_circuit_breaker(self) -> CircuitBreaker: + def default_circuit_breaker(self) -> SyncCircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) return PBCircuitBreakerAdapter(circuit_breaker) diff --git a/redis/multidb/database.py b/redis/multidb/database.py index b03e77bd70..75a662d904 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -5,65 +5,92 @@ from redis import RedisCluster from redis.data_structure import WeightedList -from redis.multidb.circuit import CircuitBreaker +from redis.multidb.circuit import SyncCircuitBreaker from redis.typing import Number class AbstractDatabase(ABC): @property @abstractmethod - def client(self) -> Union[redis.Redis, RedisCluster]: - """The underlying redis client.""" + def weight(self) -> float: + """The weight of this database in compare to others. Used to determine the database failover to.""" pass - @client.setter + @weight.setter @abstractmethod - def client(self, client: Union[redis.Redis, RedisCluster]): - """Set the underlying redis client.""" + def weight(self, weight: float): + """Set the weight of this database in compare to others.""" pass @property @abstractmethod - def weight(self) -> float: - """The weight of this database in compare to others. Used to determine the database failover to.""" + def health_check_url(self) -> Optional[str]: + """Health check URL associated with the current database.""" pass - @weight.setter + @health_check_url.setter @abstractmethod - def weight(self, weight: float): - """Set the weight of this database in compare to others.""" + def health_check_url(self, health_check_url: Optional[str]): + """Set the health check URL associated with the current database.""" pass +class BaseDatabase(AbstractDatabase): + def __init__( + self, + weight: float, + health_check_url: Optional[str] = None, + ): + self._weight = weight + self._health_check_url = health_check_url + + @property + def weight(self) -> float: + return self._weight + + @weight.setter + def weight(self, weight: float): + self._weight = weight + + @property + def health_check_url(self) -> Optional[str]: + return self._health_check_url + + @health_check_url.setter + def health_check_url(self, health_check_url: Optional[str]): + self._health_check_url = health_check_url + +class SyncDatabase(AbstractDatabase): + """Database with an underlying synchronous redis client.""" @property @abstractmethod - def circuit(self) -> CircuitBreaker: - """Circuit breaker for the current database.""" + def client(self) -> Union[redis.Redis, RedisCluster]: + """The underlying redis client.""" pass - @circuit.setter + @client.setter @abstractmethod - def circuit(self, circuit: CircuitBreaker): - """Set the circuit breaker for the current database.""" + def client(self, client: Union[redis.Redis, RedisCluster]): + """Set the underlying redis client.""" pass @property @abstractmethod - def health_check_url(self) -> Optional[str]: - """Health check URL associated with the current database.""" + def circuit(self) -> SyncCircuitBreaker: + """Circuit breaker for the current database.""" pass - @health_check_url.setter + @circuit.setter @abstractmethod - def health_check_url(self, health_check_url: Optional[str]): - """Set the health check URL associated with the current database.""" + def circuit(self, circuit: SyncCircuitBreaker): + """Set the circuit breaker for the current database.""" pass -Databases = WeightedList[tuple[AbstractDatabase, Number]] +Databases = WeightedList[tuple[SyncDatabase, Number]] -class Database(AbstractDatabase): +class Database(BaseDatabase, SyncDatabase): def __init__( self, client: Union[redis.Redis, RedisCluster], - circuit: CircuitBreaker, + circuit: SyncCircuitBreaker, weight: float, health_check_url: Optional[str] = None, ): @@ -79,8 +106,7 @@ def __init__( self._client = client self._cb = circuit self._cb.database = self - self._weight = weight - self._health_check_url = health_check_url + super().__init__(weight, health_check_url) @property def client(self) -> Union[redis.Redis, RedisCluster]: @@ -91,25 +117,9 @@ def client(self, client: Union[redis.Redis, RedisCluster]): self._client = client @property - def weight(self) -> float: - return self._weight - - @weight.setter - def weight(self, weight: float): - self._weight = weight - - @property - def circuit(self) -> CircuitBreaker: + def circuit(self) -> SyncCircuitBreaker: return self._cb @circuit.setter - def circuit(self, circuit: CircuitBreaker): - self._cb = circuit - - @property - def health_check_url(self) -> Optional[str]: - return self._health_check_url - - @health_check_url.setter - def health_check_url(self, health_check_url: Optional[str]): - self._health_check_url = health_check_url + def circuit(self, circuit: SyncCircuitBreaker): + self._cb = circuit \ No newline at end of file diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 2598bc4d06..bca9482347 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,8 +1,7 @@ from typing import List from redis.event import EventListenerInterface, OnCommandsFailEvent -from redis.multidb.config import Databases -from redis.multidb.database import AbstractDatabase +from redis.multidb.database import SyncDatabase from redis.multidb.failure_detector import FailureDetector class ActiveDatabaseChanged: @@ -11,8 +10,8 @@ class ActiveDatabaseChanged: """ def __init__( self, - old_database: AbstractDatabase, - new_database: AbstractDatabase, + old_database: SyncDatabase, + new_database: SyncDatabase, command_executor, **kwargs ): @@ -22,11 +21,11 @@ def __init__( self._kwargs = kwargs @property - def old_database(self) -> AbstractDatabase: + def old_database(self) -> SyncDatabase: return self._old_database @property - def new_database(self) -> AbstractDatabase: + def new_database(self) -> SyncDatabase: return self._new_database @property @@ -39,7 +38,7 @@ def kwargs(self): class ResubscribeOnActiveDatabaseChanged(EventListenerInterface): """ - Re-subscribe currently active pub/sub to a new active database. + Re-subscribe the currently active pub / sub to a new active database. """ def listen(self, event: ActiveDatabaseChanged): old_pubsub = event.command_executor.active_pubsub diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index d6cf198678..fd08b77ecd 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from redis.data_structure import WeightedList -from redis.multidb.database import Databases -from redis.multidb.database import AbstractDatabase +from redis.multidb.database import Databases, SyncDatabase from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException from redis.retry import Retry @@ -13,13 +12,13 @@ class FailoverStrategy(ABC): @property @abstractmethod - def database(self) -> AbstractDatabase: + def database(self) -> SyncDatabase: """Select the database according to the strategy.""" pass @abstractmethod def set_databases(self, databases: Databases) -> None: - """Set the databases strategy operates on.""" + """Set the database strategy operates on.""" pass class WeightBasedFailoverStrategy(FailoverStrategy): @@ -35,7 +34,7 @@ def __init__( self._databases = WeightedList() @property - def database(self) -> AbstractDatabase: + def database(self) -> SyncDatabase: return self._retry.call_with_retry( lambda: self._get_active_database(), lambda _: dummy_fail() @@ -44,7 +43,7 @@ def database(self) -> AbstractDatabase: def set_databases(self, databases: Databases) -> None: self._databases = databases - def _get_active_database(self) -> AbstractDatabase: + def _get_active_database(self) -> SyncDatabase: for database, _ in self._databases: if database.circuit.state == CBState.CLOSED: return database diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py index 3280fa6c32..ef4bd35f69 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -24,7 +24,6 @@ class CommandFailureDetector(FailureDetector): """ Detects a failure based on a threshold of failed commands during a specific period of time. """ - def __init__( self, threshold: int, diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index a34ef01476..9503d79d9b 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -4,7 +4,7 @@ from redis import Redis from redis.data_structure import WeightedList -from redis.multidb.circuit import CircuitBreaker, State as CBState +from redis.multidb.circuit import State as CBState, SyncCircuitBreaker from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, Databases @@ -19,8 +19,8 @@ def mock_client() -> Redis: return Mock(spec=Redis) @pytest.fixture() -def mock_cb() -> CircuitBreaker: - return Mock(spec=CircuitBreaker) +def mock_cb() -> SyncCircuitBreaker: + return Mock(spec=SyncCircuitBreaker) @pytest.fixture() def mock_fd() -> FailureDetector: @@ -41,7 +41,7 @@ def mock_db(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) + mock_cb = Mock(spec=SyncCircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -55,7 +55,7 @@ def mock_db1(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) + mock_cb = Mock(spec=SyncCircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -69,7 +69,7 @@ def mock_db2(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) + mock_cb = Mock(spec=SyncCircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py index 7dc642373b..f5f39c3f6b 100644 --- a/tests/test_multidb/test_circuit.py +++ b/tests/test_multidb/test_circuit.py @@ -1,7 +1,7 @@ import pybreaker import pytest -from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker, SyncCircuitBreaker class TestPBCircuitBreaker: @@ -39,7 +39,7 @@ def test_cb_executes_callback_on_state_changed(self): adapter = PBCircuitBreakerAdapter(cb=pb_circuit) called_count = 0 - def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): + def callback(cb: SyncCircuitBreaker, old_state: CbState, new_state: CbState): nonlocal called_count assert old_state == CbState.CLOSED assert new_state == CbState.HALF_OPEN diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 193980d37c..c7c15fe684 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -8,7 +8,7 @@ from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ DEFAULT_FAILOVER_BACKOFF -from redis.multidb.database import AbstractDatabase +from redis.multidb.database import SyncDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failover import WeightBasedFailoverStrategy @@ -458,7 +458,7 @@ def test_set_active_database( assert client.set('key', 'value') == 'OK' with pytest.raises(ValueError, match='Given database is not a member of database list'): - client.set_active_database(Mock(spec=AbstractDatabase)) + client.set_active_database(Mock(spec=SyncDatabase)) mock_hc.check_health.return_value = False diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index 87aae701a9..e428b3ce7a 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -1,6 +1,6 @@ from unittest.mock import Mock from redis.connection import ConnectionPool -from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker from redis.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig, DEFAULT_GRACE_PERIOD from redis.multidb.database import Database @@ -49,11 +49,11 @@ def test_overridden_config(self): mock_connection_pools[0].connection_kwargs = {} mock_connection_pools[1].connection_kwargs = {} mock_connection_pools[2].connection_kwargs = {} - mock_cb1 = Mock(spec=CircuitBreaker) + mock_cb1 = Mock(spec=SyncCircuitBreaker) mock_cb1.grace_period = grace_period - mock_cb2 = Mock(spec=CircuitBreaker) + mock_cb2 = Mock(spec=SyncCircuitBreaker) mock_cb2.grace_period = grace_period - mock_cb3 = Mock(spec=CircuitBreaker) + mock_cb3 = Mock(spec=SyncCircuitBreaker) mock_cb3.grace_period = grace_period mock_failure_detectors = [Mock(spec=FailureDetector), Mock(spec=FailureDetector)] mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] @@ -113,7 +113,7 @@ def test_default_config(self): def test_overridden_config(self): mock_connection_pool = Mock(spec=ConnectionPool) - mock_circuit = Mock(spec=CircuitBreaker) + mock_circuit = Mock(spec=SyncCircuitBreaker) config = DatabaseConfig( client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit diff --git a/tests/test_multidb/test_failure_detector.py b/tests/test_multidb/test_failure_detector.py index 86d6e1cd82..28687f2a11 100644 --- a/tests/test_multidb/test_failure_detector.py +++ b/tests/test_multidb/test_failure_detector.py @@ -3,7 +3,7 @@ import pytest -from redis.multidb.command_executor import CommandExecutor +from redis.multidb.command_executor import SyncCommandExecutor from redis.multidb.failure_detector import CommandFailureDetector from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError @@ -19,7 +19,7 @@ class TestCommandFailureDetector: ) def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): fd = CommandFailureDetector(5, 1) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -41,7 +41,7 @@ def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exce ) def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): fd = CommandFailureDetector(5, 1) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -62,7 +62,7 @@ def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interv ) def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db): fd = CommandFailureDetector(5, 0.3) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -96,7 +96,7 @@ def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_e ) def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): fd = CommandFailureDetector(5, 0.3) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -128,7 +128,7 @@ def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): ) def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db): fd = CommandFailureDetector(5, 1, error_types=[ConnectionError]) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED