From 4862a3e6a093636598fe4375683b95168b3e21e2 Mon Sep 17 00:00:00 2001 From: matdev83 <211248003+matdev83@users.noreply.github.com> Date: Mon, 13 Oct 2025 00:57:12 +0200 Subject: [PATCH] Optimize tool call handler ordering cache --- .../services/tool_call_reactor_service.py | 28 ++++- .../test_tool_call_reactor_service.py | 100 ++++++++++++++++++ 2 files changed, 123 insertions(+), 5 deletions(-) diff --git a/src/core/services/tool_call_reactor_service.py b/src/core/services/tool_call_reactor_service.py index d41b72645..f385b7b56 100644 --- a/src/core/services/tool_call_reactor_service.py +++ b/src/core/services/tool_call_reactor_service.py @@ -41,6 +41,25 @@ def __init__(self, history_tracker: IToolCallHistoryTracker | None = None) -> No self._handlers: dict[str, IToolCallHandler] = {} self._history_tracker = history_tracker self._lock = asyncio.Lock() + self._sorted_handlers: tuple[IToolCallHandler, ...] | None = None + + def _invalidate_sorted_handlers(self) -> None: + """Invalidate cached handler ordering.""" + + self._sorted_handlers = None + + def _get_sorted_handlers(self) -> tuple[IToolCallHandler, ...]: + """Return handlers sorted by priority, caching the result.""" + + if self._sorted_handlers is None: + self._sorted_handlers = tuple( + sorted( + self._handlers.values(), + key=lambda h: h.priority, + reverse=True, + ) + ) + return self._sorted_handlers def register_handler_sync(self, handler: IToolCallHandler) -> None: """Register a tool call handler synchronously. @@ -61,6 +80,7 @@ def register_handler_sync(self, handler: IToolCallHandler) -> None: ) self._handlers[handler.name] = handler + self._invalidate_sorted_handlers() logger.info(f"Registered tool call handler synchronously: {handler.name}") async def register_handler(self, handler: IToolCallHandler) -> None: @@ -79,6 +99,7 @@ async def register_handler(self, handler: IToolCallHandler) -> None: ) self._handlers[handler.name] = handler + self._invalidate_sorted_handlers() logger.info(f"Registered tool call handler: {handler.name}") async def unregister_handler(self, handler_name: str) -> None: @@ -97,6 +118,7 @@ async def unregister_handler(self, handler_name: str) -> None: ) del self._handlers[handler_name] + self._invalidate_sorted_handlers() logger.info(f"Unregistered tool call handler: {handler_name}") async def process_tool_call( @@ -139,11 +161,7 @@ async def process_tool_call( ) # Get handlers sorted by priority (highest first) - handlers = sorted( - self._handlers.values(), - key=lambda h: h.priority, - reverse=True, - ) + handlers = self._get_sorted_handlers() # Process through handlers for handler in handlers: diff --git a/tests/unit/core/services/test_tool_call_reactor_service.py b/tests/unit/core/services/test_tool_call_reactor_service.py index 9d47c5194..f375535c0 100644 --- a/tests/unit/core/services/test_tool_call_reactor_service.py +++ b/tests/unit/core/services/test_tool_call_reactor_service.py @@ -314,6 +314,106 @@ async def test_get_registered_handlers(self, reactor): assert "handler1" in handlers assert "handler2" in handlers + @pytest.mark.asyncio + async def test_handler_cache_invalidation_on_register(self, reactor): + """Registering a new handler should rebuild cached ordering.""" + + swallow_result = ToolCallReactionResult(should_swallow=True) + low_priority_handler = MockToolCallHandler( + "low_priority", priority=10, handle_result=swallow_result + ) + await reactor.register_handler(low_priority_handler) + + context = ToolCallContext( + session_id="test_session", + backend_name="test_backend", + model_name="test_model", + full_response='{"content": "test"}', + tool_name="test_tool", + tool_arguments={"arg": "value"}, + ) + + # Prime cached ordering with the existing handler + result = await reactor.process_tool_call(context) + assert result is not None + assert low_priority_handler.handle_call_count == 1 + + high_priority_handler = MockToolCallHandler( + "high_priority", + priority=100, + handle_result=ToolCallReactionResult(should_swallow=True), + ) + + await reactor.register_handler(high_priority_handler) + + context2 = ToolCallContext( + session_id="test_session", + backend_name="test_backend", + model_name="test_model", + full_response='{"content": "test"}', + tool_name="test_tool", + tool_arguments={"arg": "value"}, + ) + + result2 = await reactor.process_tool_call(context2) + + assert result2 is not None and result2.should_swallow is True + assert high_priority_handler.handle_call_count == 1 + assert high_priority_handler.can_handle_call_count == 1 + # High priority handler should swallow before low priority handler is invoked again + assert low_priority_handler.handle_call_count == 1 + + @pytest.mark.asyncio + async def test_handler_cache_invalidation_on_unregister(self, reactor): + """Removing a handler should evict it from the cached ordering.""" + + high_priority_handler = MockToolCallHandler( + "high_priority", + priority=100, + handle_result=ToolCallReactionResult(should_swallow=True), + ) + low_priority_handler = MockToolCallHandler( + "low_priority", + priority=10, + handle_result=ToolCallReactionResult(should_swallow=True), + ) + + await reactor.register_handler(low_priority_handler) + await reactor.register_handler(high_priority_handler) + + context = ToolCallContext( + session_id="test_session", + backend_name="test_backend", + model_name="test_model", + full_response='{"content": "test"}', + tool_name="test_tool", + tool_arguments={"arg": "value"}, + ) + + # First call should be swallowed by the high priority handler + result = await reactor.process_tool_call(context) + assert result is not None and result.should_swallow is True + assert high_priority_handler.handle_call_count == 1 + assert low_priority_handler.handle_call_count == 0 + + await reactor.unregister_handler("high_priority") + + context2 = ToolCallContext( + session_id="test_session", + backend_name="test_backend", + model_name="test_model", + full_response='{"content": "test"}', + tool_name="test_tool", + tool_arguments={"arg": "value"}, + ) + + result2 = await reactor.process_tool_call(context2) + + assert result2 is not None and result2.should_swallow is True + # Low priority handler should now handle the call and high priority handler should not be invoked again + assert low_priority_handler.handle_call_count == 1 + assert high_priority_handler.handle_call_count == 1 + class TestInMemoryToolCallHistoryTracker: """Test cases for InMemoryToolCallHistoryTracker."""