diff --git a/src/core/interfaces/command_service.py b/src/core/interfaces/command_service.py index b108654a1..403bdae56 100644 --- a/src/core/interfaces/command_service.py +++ b/src/core/interfaces/command_service.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from collections.abc import Awaitable, Callable from typing import Any @@ -24,7 +25,18 @@ def __init__(self, handler: CommandServiceHandler): async def process_commands( self, messages: list[Any], session_id: str ) -> ProcessedResult: - return await self._handler(messages, session_id) + result = self._handler(messages, session_id) + + if inspect.isawaitable(result): + return await result + + if isinstance(result, ProcessedResult): + return result + + raise TypeError( + "The command service handler must return a ProcessedResult or an awaitable" + " resolving to ProcessedResult." + ) def ensure_command_service( diff --git a/tests/unit/core/test_command_service_module.py b/tests/unit/core/test_command_service_module.py index bc9291043..f63facd5f 100644 --- a/tests/unit/core/test_command_service_module.py +++ b/tests/unit/core/test_command_service_module.py @@ -1,4 +1,5 @@ import pytest + from src.core.domain.processed_result import ProcessedResult from src.core.interfaces.command_service import ensure_command_service from src.core.interfaces.command_service_interface import ICommandService @@ -48,6 +49,23 @@ async def handler(messages: list[str], session_id: str) -> ProcessedResult: assert result.command_results == ["session"] +@pytest.mark.asyncio +async def test_ensure_command_service_wraps_sync_callable() -> None: + def handler(messages: list[str], session_id: str) -> ProcessedResult: + return ProcessedResult( + modified_messages=[value.upper() for value in messages], + command_executed=True, + command_results=[session_id], + ) + + validated_service = ensure_command_service(handler) + + result = await validated_service.process_commands(["hello"], "session") + assert result.modified_messages == ["HELLO"] + assert result.command_executed is True + assert result.command_results == ["session"] + + def test_ensure_command_service_rejects_none() -> None: with pytest.raises(ValueError) as exc: ensure_command_service(None)