Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 46 additions & 36 deletions redis/multidb/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,57 @@ def database(self, database):
"""Set database associated with this circuit."""
pass

class BaseCircuitBreaker(CircuitBreaker):
"""
Base implementation of Circuit Breaker interface.
"""
def __init__(self, cb: pybreaker.CircuitBreaker):
self._cb = cb
self._state_pb_mapper = {
State.CLOSED: self._cb.close,
State.OPEN: self._cb.open,
State.HALF_OPEN: self._cb.half_open,
}
self._database = None

@property
def grace_period(self) -> float:
return self._cb.reset_timeout

@grace_period.setter
def grace_period(self, grace_period: float):
self._cb.reset_timeout = grace_period

@property
def state(self) -> State:
return State(value=self._cb.state.name)

@state.setter
def state(self, state: State):
self._state_pb_mapper[state]()

@property
def database(self):
return self._database

@database.setter
def database(self, database):
self._database = database

class SyncCircuitBreaker(CircuitBreaker):
"""
Synchronous implementation of Circuit Breaker interface.
"""
@abstractmethod
def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]):
"""Callback called when the state of the circuit changes."""
pass

class PBListener(pybreaker.CircuitBreakerListener):
"""Wrapper for callback to be compatible with pybreaker implementation."""
def __init__(
self,
cb: Callable[[CircuitBreaker, State, State], None],
cb: Callable[[SyncCircuitBreaker, State, State], None],
database,
):
"""
Expand All @@ -75,8 +116,7 @@ def state_change(self, cb, old_state, new_state):
new_state = State(value=new_state.name)
self._cb(cb, old_state, new_state)


class PBCircuitBreakerAdapter(CircuitBreaker):
class PBCircuitBreakerAdapter(SyncCircuitBreaker, BaseCircuitBreaker):
def __init__(self, cb: pybreaker.CircuitBreaker):
"""
Initialize a PBCircuitBreakerAdapter instance.
Expand All @@ -87,38 +127,8 @@ def __init__(self, cb: pybreaker.CircuitBreaker):
Args:
cb: A pybreaker CircuitBreaker instance to be adapted.
"""
self._cb = cb
self._state_pb_mapper = {
State.CLOSED: self._cb.close,
State.OPEN: self._cb.open,
State.HALF_OPEN: self._cb.half_open,
}
self._database = None

@property
def grace_period(self) -> float:
return self._cb.reset_timeout

@grace_period.setter
def grace_period(self, grace_period: float):
self._cb.reset_timeout = grace_period

@property
def state(self) -> State:
return State(value=self._cb.state.name)

@state.setter
def state(self, state: State):
self._state_pb_mapper[state]()

@property
def database(self):
return self._database

@database.setter
def database(self, database):
self._database = database
super().__init__(cb)

def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]):
listener = PBListener(cb, self.database)
self._cb.add_listener(listener)
25 changes: 11 additions & 14 deletions redis/multidb/client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import threading
import socket
from typing import List, Any, Callable, Optional

from redis.background import BackgroundScheduler
from redis.client import PubSubWorkerThread
from redis.exceptions import ConnectionError, TimeoutError
from redis.commands import RedisModuleCommands, CoreCommands
from redis.multidb.command_executor import DefaultCommandExecutor
from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD
from redis.multidb.circuit import State as CBState, CircuitBreaker
from redis.multidb.database import Database, AbstractDatabase, Databases
from redis.multidb.circuit import State as CBState, SyncCircuitBreaker
from redis.multidb.database import Database, Databases, SyncDatabase
from redis.multidb.exception import NoValidDatabaseException
from redis.multidb.failure_detector import FailureDetector
from redis.multidb.healthcheck import HealthCheck
Expand Down Expand Up @@ -92,7 +89,7 @@ def get_databases(self) -> Databases:
"""
return self._databases

def set_active_database(self, database: AbstractDatabase) -> None:
def set_active_database(self, database: SyncDatabase) -> None:
"""
Promote one of the existing databases to become an active.
"""
Expand All @@ -115,7 +112,7 @@ def set_active_database(self, database: AbstractDatabase) -> None:

raise NoValidDatabaseException('Cannot set active database, database is unhealthy')

def add_database(self, database: AbstractDatabase):
def add_database(self, database: SyncDatabase):
"""
Adds a new database to the database list.
"""
Expand All @@ -129,7 +126,7 @@ def add_database(self, database: AbstractDatabase):
self._databases.add(database, database.weight)
self._change_active_database(database, highest_weighted_db)

def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase):
def _change_active_database(self, new_database: SyncDatabase, highest_weight_database: SyncDatabase):
if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED:
self.command_executor.active_database = new_database

Expand All @@ -143,7 +140,7 @@ def remove_database(self, database: Database):
if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED:
self.command_executor.active_database = highest_weighted_db

def update_database_weight(self, database: AbstractDatabase, weight: float):
def update_database_weight(self, database: SyncDatabase, weight: float):
"""
Updates a database from the database list.
"""
Expand Down Expand Up @@ -210,7 +207,7 @@ def pubsub(self, **kwargs):

return PubSub(self, **kwargs)

def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Exception], None] = None) -> None:
def _check_db_health(self, database: SyncDatabase, on_error: Callable[[Exception], None] = None) -> None:
"""
Runs health checks on the given database until first failure.
"""
Expand Down Expand Up @@ -247,15 +244,15 @@ def _check_databases_health(self, on_error: Callable[[Exception], None] = None):
for database, _ in self._databases:
self._check_db_health(database, on_error)

def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState):
def _on_circuit_state_change_callback(self, circuit: SyncCircuitBreaker, old_state: CBState, new_state: CBState):
if new_state == CBState.HALF_OPEN:
self._check_db_health(circuit.database)
return

if old_state == CBState.CLOSED and new_state == CBState.OPEN:
self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit)

def _half_open_circuit(circuit: CircuitBreaker):
def _half_open_circuit(circuit: SyncCircuitBreaker):
circuit.state = CBState.HALF_OPEN


Expand Down Expand Up @@ -450,8 +447,8 @@ def run_in_thread(
exception_handler: Optional[Callable] = None,
sharded_pubsub: bool = False,
) -> "PubSubWorkerThread":
return self._client.command_executor.execute_pubsub_run_in_thread(
sleep_time=sleep_time,
return self._client.command_executor.execute_pubsub_run(
sleep_time,
daemon=daemon,
exception_handler=exception_handler,
pubsub=self,
Expand Down
Loading