diff --git a/src/core/services/streaming/tool_call_repair_processor.py b/src/core/services/streaming/tool_call_repair_processor.py
index 530314629..b502c5c0c 100644
--- a/src/core/services/streaming/tool_call_repair_processor.py
+++ b/src/core/services/streaming/tool_call_repair_processor.py
@@ -24,6 +24,7 @@ class ToolCallRepairProcessor(IStreamProcessor):
     def __init__(self, tool_call_repair_service: IToolCallRepairService) -> None:
         self.tool_call_repair_service = tool_call_repair_service
         self._buffers: dict[str, str] = {}
+        self._max_buffer_bytes = self._resolve_buffer_cap(tool_call_repair_service)
 
     async def process(self, content: StreamingContent) -> StreamingContent:
         """
@@ -36,6 +37,7 @@ async def process(self, content: StreamingContent) -> StreamingContent:
 
         buffer = self._buffers.get(stream_id, "")
         buffer += content.content or ""
+        buffer = self._enforce_buffer_cap(stream_id, buffer)
 
         repaired_content_parts: list[str] = []
         remaining_buffer = buffer
@@ -101,3 +103,58 @@ async def process(self, content: StreamingContent) -> StreamingContent:
             content="",
             is_cancellation=content.is_cancellation,
         )  # Return empty if nothing to yield
+
+    def _resolve_buffer_cap(self, service: IToolCallRepairService) -> int:
+        """Determine the maximum buffer size supported by the repair service."""
+
+        default_cap = 64 * 1024
+        candidate = getattr(service, "max_buffer_bytes", default_cap)
+        try:
+            cap_value = int(candidate)
+        except (TypeError, ValueError):
+            logger.warning(
+                "Invalid tool call repair buffer cap %r; using default %d bytes",
+                candidate,
+                default_cap,
+            )
+            return default_cap
+        if cap_value < 0:
+            logger.warning(
+                "Negative tool call repair buffer cap %d received; treating as zero",
+                cap_value,
+            )
+            return 0
+        return cap_value
+
+    def _enforce_buffer_cap(self, stream_id: str, buffer: str) -> str:
+        """Ensure per-stream buffer usage stays within configured limits."""
+
+        cap = self._max_buffer_bytes
+        if cap == 0:
+            if buffer:
+                logger.warning(
+                    "Dropping streaming tool call buffer for stream %s because cap is 0",
+                    stream_id,
+                )
+            return ""
+
+        if cap < 0:
+            return buffer
+
+        if not buffer:
+            return buffer
+
+        buffer_bytes = buffer.encode("utf-8")
+        current_size = len(buffer_bytes)
+        if current_size <= cap:
+            return buffer
+
+        truncated_bytes = buffer_bytes[-cap:]
+        dropped = current_size - cap
+        logger.warning(
+            "Tool call repair buffer for stream %s exceeded %d bytes; dropping %d bytes",
+            stream_id,
+            cap,
+            dropped,
+        )
+        return truncated_bytes.decode("utf-8", errors="ignore")
diff --git a/src/core/services/tool_call_repair_service.py b/src/core/services/tool_call_repair_service.py
index 3dbc6f747..28a67e334 100644
--- a/src/core/services/tool_call_repair_service.py
+++ b/src/core/services/tool_call_repair_service.py
@@ -35,7 +35,17 @@ def __init__(self, max_buffer_bytes: int | None = None) -> None:
         )
 
         # Cap per-session buffer to guard against pathological streams
-        self._max_buffer_bytes: int = max_buffer_bytes or (64 * 1024)  # default 64 KB
+        self._max_buffer_bytes: int = (
+            int(max_buffer_bytes) if max_buffer_bytes is not None else 64 * 1024
+        )
+        if self._max_buffer_bytes < 0:
+            self._max_buffer_bytes = 0
+
+    @property
+    def max_buffer_bytes(self) -> int:
+        """Return the configured maximum buffer size for streaming repair."""
+
+        return self._max_buffer_bytes
 
     def repair_tool_calls(self, response_content: str) -> dict[str, Any] | None:
         """
diff --git a/tests/unit/core/services/test_tool_call_repair.py b/tests/unit/core/services/test_tool_call_repair.py
index 80b6e3780..ec4c85370 100644
--- a/tests/unit/core/services/test_tool_call_repair.py
+++ b/tests/unit/core/services/test_tool_call_repair.py
@@ -3,6 +3,7 @@
 import pytest
 from pytest_mock import MockerFixture
 
+from src.core.domain.streaming_response_processor import StreamingContent
 from src.core.interfaces.response_processor_interface import ProcessedResponse
 from src.core.services.streaming.tool_call_repair_processor import (
     ToolCallRepairProcessor,
@@ -72,8 +73,6 @@ async def test_process_chunks_with_tool_call(
     streaming_processor: StreamingToolCallRepairProcessor,
     mocker: MockerFixture,
 ) -> None:
-    from src.core.domain.streaming_response_processor import StreamingContent
-
     # Mock the underlying ToolCallRepairProcessor's process method
     # This is where the actual repair logic is now encapsulated
     mock_tool_call_repair_processor_process = mocker.AsyncMock(
@@ -150,3 +149,47 @@ async def mock_async_chunks_generator() -> (
     )
     assert actual_calls[2].content == "World."
     assert actual_calls[3].is_done is True and actual_calls[3].content == ""
+
+
+class TestToolCallRepairProcessorBehavior:
+    @pytest.mark.asyncio
+    async def test_buffer_truncated_when_cap_exceeded(
+        self, mocker: MockerFixture
+    ) -> None:
+        service = ToolCallRepairService(max_buffer_bytes=32)
+        processor = ToolCallRepairProcessor(service)
+
+        stream_id = "stream-cap"
+        large_payload = "x" * 100
+
+        repair_mock = mocker.patch.object(
+            service, "repair_tool_calls", return_value=None
+        )
+
+        await processor.process(
+            StreamingContent(content=large_payload, metadata={"stream_id": stream_id})
+        )
+
+        stored_buffer = processor._buffers.get(stream_id, "")
+        assert len(stored_buffer.encode("utf-8")) <= service.max_buffer_bytes
+
+        repair_mock.assert_called_once()
+        processed_buffer = repair_mock.call_args[0][0]
+        assert len(processed_buffer.encode("utf-8")) <= service.max_buffer_bytes
+
+    @pytest.mark.asyncio
+    async def test_buffer_dropped_when_cap_zero(self, mocker: MockerFixture) -> None:
+        service = ToolCallRepairService(max_buffer_bytes=0)
+        processor = ToolCallRepairProcessor(service)
+
+        stream_id = "stream-zero"
+        repair_mock = mocker.patch.object(
+            service, "repair_tool_calls", return_value=None
+        )
+
+        await processor.process(
+            StreamingContent(content="payload", metadata={"stream_id": stream_id})
+        )
+
+        assert stream_id not in processor._buffers
+        repair_mock.assert_called_once_with("")
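
Usage sketch (illustrative, not part of the patch): the tail-truncation rule that the new _enforce_buffer_cap helper applies, namely keeping only the most recent max_buffer_bytes of UTF-8 data and discarding any multi-byte character split at the cut point. The standalone enforce_cap function and the sample values below are assumptions made for this example, not code from the repository.

def enforce_cap(buffer: str, cap: int) -> str:
    """Keep only the last `cap` bytes of `buffer`, measured in UTF-8."""
    if cap == 0:
        return ""  # a cap of zero disables buffering entirely
    if cap < 0 or not buffer:
        return buffer  # negative cap is treated as "unlimited" in this sketch
    raw = buffer.encode("utf-8")
    if len(raw) <= cap:
        return buffer  # under the limit, keep everything
    # Keep the newest bytes; a multi-byte character cut at the front is dropped.
    return raw[-cap:].decode("utf-8", errors="ignore")


if __name__ == "__main__":
    payload = "noise " * 10 + '{"name": "get_weather"}'
    capped = enforce_cap(payload, 32)
    assert len(capped.encode("utf-8")) <= 32
    print(capped)  # tail of the stream, where a trailing tool call would appear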