Skip to content

Commit a983c20

Browse files
committed
Adding oss maint notifications handler configurations to parsers. Placeholder for smigrated handler in OSSMaintNotificationsHandler class
1 parent e0847c9 commit a983c20

File tree

6 files changed

+300
-64
lines changed

6 files changed

+300
-64
lines changed

redis/_parsers/base.py

Lines changed: 88 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
NodeMigratedNotification,
1212
NodeMigratingNotification,
1313
NodeMovingNotification,
14+
OSSNodeMigratedNotification,
15+
OSSNodeMigratingNotification,
1416
)
1517

1618
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
@@ -179,6 +181,34 @@ async def read_response(
179181
class MaintenanceNotificationsParser:
180182
"""Protocol defining maintenance push notification parsing functionality"""
181183

184+
@staticmethod
185+
def parse_oss_maintenance_start_msg(response):
186+
# TODO This format is not the final - will be changed later
187+
# Expected message format is: SMIGRATING <seq_number> <src_node> <dest_node> <slots>
188+
id = response[1]
189+
190+
address_value = response[3]
191+
if isinstance(address_value, bytes):
192+
address_value = address_value.decode()
193+
if response[2] in ("TO", b"TO"):
194+
dest_node = address_value
195+
src_node = None
196+
else:
197+
dest_node = None
198+
src_node = address_value
199+
200+
slots = response[4]
201+
return OSSNodeMigratingNotification(id, src_node, dest_node, slots)
202+
203+
@staticmethod
204+
def parse_oss_maintenance_completed_msg(response):
205+
# TODO This format is not the final - will be changed later
206+
# Expected message format is: SMIGRATED <seq_number> <node_address> <slots>
207+
id = response[1]
208+
node_address = response[2]
209+
slots = response[3]
210+
return OSSNodeMigratedNotification(id, node_address, slots)
211+
182212
@staticmethod
183213
def parse_maintenance_start_msg(response, notification_type):
184214
# Expected message format is: <notification_type> <seq_number> <time>
@@ -215,12 +245,15 @@ def parse_moving_msg(response):
215245
_MIGRATED_MESSAGE = "MIGRATED"
216246
_FAILING_OVER_MESSAGE = "FAILING_OVER"
217247
_FAILED_OVER_MESSAGE = "FAILED_OVER"
248+
_SMIGRATING_MESSAGE = "SMIGRATING"
249+
_SMIGRATED_MESSAGE = "SMIGRATED"
218250

219251
_MAINTENANCE_MESSAGES = (
220252
_MIGRATING_MESSAGE,
221253
_MIGRATED_MESSAGE,
222254
_FAILING_OVER_MESSAGE,
223255
_FAILED_OVER_MESSAGE,
256+
_SMIGRATING_MESSAGE,
224257
)
225258

226259
MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING: dict[
@@ -246,6 +279,14 @@ def parse_moving_msg(response):
246279
NodeMovingNotification,
247280
MaintenanceNotificationsParser.parse_moving_msg,
248281
),
282+
_SMIGRATING_MESSAGE: (
283+
OSSNodeMigratingNotification,
284+
MaintenanceNotificationsParser.parse_oss_maintenance_start_msg,
285+
),
286+
_SMIGRATED_MESSAGE: (
287+
OSSNodeMigratedNotification,
288+
MaintenanceNotificationsParser.parse_oss_maintenance_completed_msg,
289+
),
249290
}
250291

251292

@@ -256,6 +297,7 @@ class PushNotificationsParser(Protocol):
256297
invalidation_push_handler_func: Optional[Callable] = None
257298
node_moving_push_handler_func: Optional[Callable] = None
258299
maintenance_push_handler_func: Optional[Callable] = None
300+
oss_maintenance_push_handler_func: Optional[Callable] = None
259301

260302
def handle_pubsub_push_response(self, response):
261303
"""Handle pubsub push responses"""
@@ -270,6 +312,7 @@ def handle_push_response(self, response, **kwargs):
270312
_INVALIDATION_MESSAGE,
271313
*_MAINTENANCE_MESSAGES,
272314
_MOVING_MESSAGE,
315+
_SMIGRATED_MESSAGE,
273316
):
274317
return self.pubsub_push_handler_func(response)
275318

@@ -292,13 +335,27 @@ def handle_push_response(self, response, **kwargs):
292335
parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
293336
msg_type
294337
][1]
295-
notification_type = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
296-
msg_type
297-
][0]
298-
notification = parser_function(response, notification_type)
338+
if msg_type == _SMIGRATING_MESSAGE:
339+
notification = parser_function(response)
340+
else:
341+
notification_type = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
342+
msg_type
343+
][0]
344+
notification = parser_function(response, notification_type)
299345

300346
if notification is not None:
301347
return self.maintenance_push_handler_func(notification)
348+
if (
349+
msg_type == _SMIGRATED_MESSAGE
350+
and self.oss_maintenance_push_handler_func
351+
):
352+
parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
353+
msg_type
354+
][1]
355+
notification = parser_function(response)
356+
357+
if notification is not None:
358+
return self.oss_maintenance_push_handler_func(notification)
302359
except Exception as e:
303360
logger.error(
304361
"Error handling {} message ({}): {}".format(msg_type, response, e)
@@ -318,6 +375,9 @@ def set_node_moving_push_handler(self, node_moving_push_handler_func):
318375
def set_maintenance_push_handler(self, maintenance_push_handler_func):
319376
self.maintenance_push_handler_func = maintenance_push_handler_func
320377

378+
def set_oss_maintenance_push_handler(self, oss_maintenance_push_handler_func):
379+
self.oss_maintenance_push_handler_func = oss_maintenance_push_handler_func
380+
321381

322382
class AsyncPushNotificationsParser(Protocol):
323383
"""Protocol defining async RESP3-specific parsing functionality"""
@@ -326,6 +386,7 @@ class AsyncPushNotificationsParser(Protocol):
326386
invalidation_push_handler_func: Optional[Callable] = None
327387
node_moving_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
328388
maintenance_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
389+
oss_maintenance_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
329390

330391
async def handle_pubsub_push_response(self, response):
331392
"""Handle pubsub push responses asynchronously"""
@@ -342,6 +403,7 @@ async def handle_push_response(self, response, **kwargs):
342403
_INVALIDATION_MESSAGE,
343404
*_MAINTENANCE_MESSAGES,
344405
_MOVING_MESSAGE,
406+
_SMIGRATED_MESSAGE,
345407
):
346408
return await self.pubsub_push_handler_func(response)
347409

@@ -366,13 +428,28 @@ async def handle_push_response(self, response, **kwargs):
366428
parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
367429
msg_type
368430
][1]
369-
notification_type = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
370-
msg_type
371-
][0]
372-
notification = parser_function(response, notification_type)
431+
if msg_type == _SMIGRATING_MESSAGE:
432+
notification = parser_function(response)
433+
else:
434+
notification_type = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
435+
msg_type
436+
][0]
437+
notification = parser_function(response, notification_type)
373438

374439
if notification is not None:
375440
return await self.maintenance_push_handler_func(notification)
441+
if (
442+
msg_type == _SMIGRATED_MESSAGE
443+
and self.oss_maintenance_push_handler_func
444+
):
445+
parser_function = MSG_TYPE_TO_MAINT_NOTIFICATION_PARSER_MAPPING[
446+
msg_type
447+
][1]
448+
notification = parser_function(response)
449+
if notification is not None:
450+
return await self.oss_maintenance_push_handler_func(
451+
notification
452+
)
376453
except Exception as e:
377454
logger.error(
378455
"Error handling {} message ({}): {}".format(msg_type, response, e)
@@ -394,6 +471,9 @@ def set_node_moving_push_handler(self, node_moving_push_handler_func):
394471
def set_maintenance_push_handler(self, maintenance_push_handler_func):
395472
self.maintenance_push_handler_func = maintenance_push_handler_func
396473

474+
def set_oss_maintenance_push_handler(self, oss_maintenance_push_handler_func):
475+
self.oss_maintenance_push_handler_func = oss_maintenance_push_handler_func
476+
397477

398478
class _AsyncRESPBase(AsyncBaseParser):
399479
"""Base class for async resp parsing"""

redis/_parsers/hiredis.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self, socket_read_size):
4949
self.pubsub_push_handler_func = self.handle_pubsub_push_response
5050
self.node_moving_push_handler_func = None
5151
self.maintenance_push_handler_func = None
52+
self.oss_maintenance_push_handler_func = None
5253
self.invalidation_push_handler_func = None
5354
self._hiredis_PushNotificationType = None
5455

redis/_parsers/resp3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(self, socket_read_size):
2020
self.pubsub_push_handler_func = self.handle_pubsub_push_response
2121
self.node_moving_push_handler_func = None
2222
self.maintenance_push_handler_func = None
23+
self.oss_maintenance_push_handler_func = None
2324
self.invalidation_push_handler_func = None
2425

2526
def handle_pubsub_push_response(self, response):

redis/maint_notifications.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010
from redis.typing import Number
1111

12+
if TYPE_CHECKING:
13+
from redis.cluster import NodesManager
14+
1215

1316
class MaintenanceState(enum.Enum):
1417
NONE = "none"
@@ -886,6 +889,7 @@ class MaintNotificationsConnectionHandler:
886889
_NOTIFICATION_TYPES: dict[type["MaintenanceNotification"], int] = {
887890
NodeMigratingNotification: 1,
888891
NodeFailingOverNotification: 1,
892+
OSSNodeMigratingNotification: 1,
889893
NodeMigratedNotification: 0,
890894
NodeFailedOverNotification: 0,
891895
}
@@ -939,3 +943,31 @@ def handle_maintenance_completed_notification(self):
939943
# timeouts by providing -1 as the relaxed timeout
940944
self.connection.update_current_socket_timeout(-1)
941945
self.connection.maintenance_state = MaintenanceState.NONE
946+
947+
948+
class OSSMaintNotificationsHandler:
949+
def __init__(
950+
self, nodes_manager: "NodesManager", config: MaintNotificationsConfig
951+
) -> None:
952+
self.nodes_manager = nodes_manager
953+
self.config = config
954+
self._processed_notifications = set()
955+
self._lock = threading.RLock()
956+
957+
def remove_expired_notifications(self):
958+
with self._lock:
959+
for notification in tuple(self._processed_notifications):
960+
if notification.is_expired():
961+
self._processed_notifications.remove(notification)
962+
963+
def handle_notification(self, notification: MaintenanceNotification):
964+
if isinstance(notification, OSSNodeMigratedNotification):
965+
self.handle_oss_maintenance_completed_notification(notification)
966+
else:
967+
logging.error(f"Unhandled notification type: {notification}")
968+
969+
def handle_oss_maintenance_completed_notification(
970+
self, notification: OSSNodeMigratedNotification
971+
):
972+
self.remove_expired_notifications()
973+
logging.info(f"Received OSS maintenance completed notification: {notification}")

tests/maint_notifications/proxy_server_helpers.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import base64
22
import logging
3+
import re
34
from typing import Union
45

56
from redis.http.http_client import HttpClient, HttpError
@@ -10,6 +11,19 @@
1011
class RespTranslator:
1112
"""Helper class to translate between RESP and other encodings."""
1213

14+
@staticmethod
15+
def str_or_list_to_resp(txt: str) -> str:
16+
"""
17+
Convert specific string or list to RESP format.
18+
"""
19+
if re.match(r"^<.*>$", txt):
20+
items = txt[1:-1].split(",")
21+
return f"*{len(items)}\r\n" + "\r\n".join(
22+
f"${len(x)}\r\n{x}" for x in items
23+
)
24+
else:
25+
return f"${len(txt)}\r\n{txt}"
26+
1327
@staticmethod
1428
def cluster_slots_to_resp(resp: str) -> str:
1529
"""Convert query to RESP format."""
@@ -24,7 +38,9 @@ def smigrating_to_resp(resp: str) -> str:
2438
"""Convert query to RESP format."""
2539
return (
2640
f">{len(resp.split())}\r\n"
27-
+ "\r\n".join(f"${len(x)}\r\n{x}" for x in resp.split())
41+
+ "\r\n".join(
42+
f"{RespTranslator.str_or_list_to_resp(x)}" for x in resp.split()
43+
)
2844
+ "\r\n"
2945
)
3046

@@ -118,6 +134,8 @@ def get_stats(self) -> dict:
118134

119135
try:
120136
response = self.http_client.get(url)
137+
if isinstance(response, dict):
138+
return response
121139
return response.json()
122140

123141
except HttpError as e:
@@ -134,6 +152,8 @@ def get_connections(self) -> dict:
134152

135153
try:
136154
response = self.http_client.get(url)
155+
if isinstance(response, dict):
156+
return response
137157
return response.json()
138158
except HttpError as e:
139159
raise RuntimeError(f"Failed to get connections: {e}")
@@ -192,7 +212,9 @@ def send_notification(
192212
data = base64.b64encode(notification.encode("utf-8"))
193213

194214
try:
195-
response = self.http_client.post(url, json_body=data)
215+
response = self.http_client.post(url, data=data)
216+
if isinstance(response, dict):
217+
return response
196218
results = response.json()
197219
except HttpError as e:
198220
results = {"error": str(e)}

0 commit comments

Comments
 (0)