Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/core/interfaces/command_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import inspect
from collections.abc import Awaitable, Callable
from typing import Any

Expand All @@ -24,7 +25,18 @@ def __init__(self, handler: CommandServiceHandler):
async def process_commands(
self, messages: list[Any], session_id: str
) -> ProcessedResult:
return await self._handler(messages, session_id)
result = self._handler(messages, session_id)

if inspect.isawaitable(result):
return await result

if isinstance(result, ProcessedResult):
return result

raise TypeError(
"The command service handler must return a ProcessedResult or an awaitable"
" resolving to ProcessedResult."
)


def ensure_command_service(
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/core/test_command_service_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest

from src.core.domain.processed_result import ProcessedResult
from src.core.interfaces.command_service import ensure_command_service
from src.core.interfaces.command_service_interface import ICommandService
Expand Down Expand Up @@ -48,6 +49,23 @@ async def handler(messages: list[str], session_id: str) -> ProcessedResult:
assert result.command_results == ["session"]


@pytest.mark.asyncio
async def test_ensure_command_service_wraps_sync_callable() -> None:
def handler(messages: list[str], session_id: str) -> ProcessedResult:
return ProcessedResult(
modified_messages=[value.upper() for value in messages],
command_executed=True,
command_results=[session_id],
)

validated_service = ensure_command_service(handler)

result = await validated_service.process_commands(["hello"], "session")
assert result.modified_messages == ["HELLO"]
assert result.command_executed is True
assert result.command_results == ["session"]


def test_ensure_command_service_rejects_none() -> None:
with pytest.raises(ValueError) as exc:
ensure_command_service(None)
Expand Down
Loading