From 298b0a8cf453fd1919ff62c52b920163cf3953f2 Mon Sep 17 00:00:00 2001 From: matdev83 <211248003+matdev83@users.noreply.github.com> Date: Mon, 13 Oct 2025 00:43:37 +0200 Subject: [PATCH] Optimize tool call handler ordering cache --- .../services/tool_call_reactor_service.py | 28 +++++++++--- .../test_tool_call_reactor_service.py | 44 +++++++++++++++++++ 2 files changed, 67 insertions(+), 5 deletions(-) diff --git a/src/core/services/tool_call_reactor_service.py b/src/core/services/tool_call_reactor_service.py index d41b72645..7192d7c51 100644 --- a/src/core/services/tool_call_reactor_service.py +++ b/src/core/services/tool_call_reactor_service.py @@ -41,6 +41,7 @@ def __init__(self, history_tracker: IToolCallHistoryTracker | None = None) -> No self._handlers: dict[str, IToolCallHandler] = {} self._history_tracker = history_tracker self._lock = asyncio.Lock() + self._handlers_cache: tuple[IToolCallHandler, ...] | None = None def register_handler_sync(self, handler: IToolCallHandler) -> None: """Register a tool call handler synchronously. @@ -61,6 +62,7 @@ def register_handler_sync(self, handler: IToolCallHandler) -> None: ) self._handlers[handler.name] = handler + self._invalidate_handler_cache() logger.info(f"Registered tool call handler synchronously: {handler.name}") async def register_handler(self, handler: IToolCallHandler) -> None: @@ -79,6 +81,7 @@ async def register_handler(self, handler: IToolCallHandler) -> None: ) self._handlers[handler.name] = handler + self._invalidate_handler_cache() logger.info(f"Registered tool call handler: {handler.name}") async def unregister_handler(self, handler_name: str) -> None: @@ -97,6 +100,7 @@ async def unregister_handler(self, handler_name: str) -> None: ) del self._handlers[handler_name] + self._invalidate_handler_cache() logger.info(f"Unregistered tool call handler: {handler_name}") async def process_tool_call( @@ -139,11 +143,7 @@ async def process_tool_call( ) # Get handlers sorted by priority (highest first) - handlers = sorted( - self._handlers.values(), - key=lambda h: h.priority, - reverse=True, - ) + handlers = self._get_handlers_by_priority() # Process through handlers for handler in handlers: @@ -183,6 +183,24 @@ def get_registered_handlers(self) -> list[str]: """ return list(self._handlers.keys()) + def _invalidate_handler_cache(self) -> None: + """Invalidate the cached handler ordering.""" + self._handlers_cache = None + + def _get_handlers_by_priority(self) -> tuple[IToolCallHandler, ...]: + """Return handlers sorted by priority, caching the order.""" + cached_handlers = self._handlers_cache + if cached_handlers is None: + cached_handlers = tuple( + sorted( + self._handlers.values(), + key=lambda handler: handler.priority, + reverse=True, + ) + ) + self._handlers_cache = cached_handlers + return cached_handlers + class InMemoryToolCallHistoryTracker(IToolCallHistoryTracker): """In-memory implementation of tool call history tracking.""" diff --git a/tests/unit/core/services/test_tool_call_reactor_service.py b/tests/unit/core/services/test_tool_call_reactor_service.py index 9d47c5194..2e2f0f142 100644 --- a/tests/unit/core/services/test_tool_call_reactor_service.py +++ b/tests/unit/core/services/test_tool_call_reactor_service.py @@ -274,6 +274,50 @@ async def test_process_tool_call_multiple_handlers_priority(self, reactor): assert low_priority_handler.can_handle_call_count == 1 assert low_priority_handler.handle_call_count == 1 + @pytest.mark.asyncio + async def test_handler_cache_refresh_after_unregister(self, reactor): + """Handler ordering cache should refresh after unregister events.""" + + swallowing_result = ToolCallReactionResult( + should_swallow=True, + replacement_response="handled", + ) + primary_handler = MockToolCallHandler( + "primary", + priority=100, + can_handle_return=True, + handle_result=swallowing_result, + ) + secondary_handler = MockToolCallHandler( + "secondary", + priority=10, + can_handle_return=True, + handle_result=swallowing_result, + ) + + await reactor.register_handler(secondary_handler) + await reactor.register_handler(primary_handler) + + context = ToolCallContext( + session_id="test_session", + backend_name="test_backend", + model_name="test_model", + full_response='{"content": "test"}', + tool_name="test_tool", + tool_arguments={"arg": "value"}, + ) + + result_first = await reactor.process_tool_call(context) + assert result_first is not None and result_first.should_swallow is True + assert primary_handler.handle_call_count == 1 + assert secondary_handler.handle_call_count == 0 + + await reactor.unregister_handler(primary_handler.name) + + result_second = await reactor.process_tool_call(context) + assert result_second is not None and result_second.should_swallow is True + assert secondary_handler.handle_call_count == 1 + @pytest.mark.asyncio async def test_process_tool_call_handler_error_handling(self, reactor): """Test that handler errors don't crash the reactor."""