Empty file added examples/shared/__init__.py
Empty file.
179 changes: 179 additions & 0 deletions examples/shared/in_memory_task_store.py
@@ -0,0 +1,179 @@
"""
In-memory implementation of TaskStore for demonstration purposes.

This implementation stores all tasks in memory and provides automatic cleanup
based on the keepAlive duration specified in the task metadata.

Note: This is not suitable for production use as all data is lost on restart.
For production, consider implementing TaskStore with a database or distributed cache.
"""

import asyncio
from dataclasses import dataclass
from typing import Any

from mcp.shared.task import TaskStatus, TaskStore, is_terminal
from mcp.types import Request, RequestId, Result, Task, TaskMetadata


@dataclass
class StoredTask:
"""Internal storage representation of a task."""

task: Task
request: Request[Any, Any]
request_id: RequestId
result: Result | None = None


class InMemoryTaskStore(TaskStore):
"""
A simple in-memory implementation of TaskStore for demonstration purposes.

This implementation stores all tasks in memory and provides automatic cleanup
based on the keepAlive duration specified in the task metadata.

Note: This is not suitable for production use as all data is lost on restart.
For production, consider implementing TaskStore with a database or distributed cache.
"""

def __init__(self) -> None:
self._tasks: dict[str, StoredTask] = {}
self._cleanup_tasks: dict[str, asyncio.Task[None]] = {}

async def create_task(
self, task: TaskMetadata, request_id: RequestId, request: Request[Any, Any], session_id: str | None = None
) -> None:
"""Create a new task with the given metadata and original request."""
task_id = task.taskId

if task_id in self._tasks:
raise ValueError(f"Task with ID {task_id} already exists")

task_obj = Task(
taskId=task_id,
status="submitted",
keepAlive=task.keepAlive,
            pollInterval=500,  # Suggest clients poll every 500ms
)

self._tasks[task_id] = StoredTask(task=task_obj, request=request, request_id=request_id)

# Schedule cleanup if keepAlive is specified
if task.keepAlive is not None:
self._schedule_cleanup(task_id, task.keepAlive / 1000.0)

async def get_task(self, task_id: str, session_id: str | None = None) -> Task | None:
"""Get the current status of a task."""
stored = self._tasks.get(task_id)
if stored is None:
return None

# Return a copy to prevent external modification
return Task(**stored.task.model_dump())

async def store_task_result(self, task_id: str, result: Result, session_id: str | None = None) -> None:
"""Store the result of a completed task."""
stored = self._tasks.get(task_id)
if stored is None:
raise ValueError(f"Task with ID {task_id} not found")

stored.result = result
stored.task.status = "completed"

# Reset cleanup timer to start from now (if keepAlive is set)
if stored.task.keepAlive is not None:
self._cancel_cleanup(task_id)
self._schedule_cleanup(task_id, stored.task.keepAlive / 1000.0)

async def get_task_result(self, task_id: str, session_id: str | None = None) -> Result:
"""Retrieve the stored result of a task."""
stored = self._tasks.get(task_id)
if stored is None:
raise ValueError(f"Task with ID {task_id} not found")

if stored.result is None:
raise ValueError(f"Task {task_id} has no result stored")

return stored.result

async def update_task_status(
self, task_id: str, status: TaskStatus, error: str | None = None, session_id: str | None = None
) -> None:
"""Update a task's status."""
stored = self._tasks.get(task_id)
if stored is None:
raise ValueError(f"Task with ID {task_id} not found")

stored.task.status = status
if error is not None:
stored.task.error = error

# If task is in a terminal state and has keepAlive, start cleanup timer
if is_terminal(status) and stored.task.keepAlive is not None:
self._cancel_cleanup(task_id)
self._schedule_cleanup(task_id, stored.task.keepAlive / 1000.0)

async def list_tasks(self, cursor: str | None = None, session_id: str | None = None) -> dict[str, Any]:
"""
List tasks, optionally starting from a pagination cursor.

Returns a dict with 'tasks' list and optional 'nextCursor' string.
"""
PAGE_SIZE = 10
all_task_ids = list(self._tasks.keys())

start_index = 0
if cursor is not None:
try:
cursor_index = all_task_ids.index(cursor)
start_index = cursor_index + 1
            except ValueError:
                raise ValueError(f"Invalid cursor: {cursor}") from None

page_task_ids = all_task_ids[start_index : start_index + PAGE_SIZE]
tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids]

next_cursor = page_task_ids[-1] if start_index + PAGE_SIZE < len(all_task_ids) and page_task_ids else None

return {"tasks": tasks, "nextCursor": next_cursor}

async def delete_task(self, task_id: str, session_id: str | None = None) -> None:
"""Delete a task from storage."""
if task_id not in self._tasks:
raise ValueError(f"Task with ID {task_id} not found")

# Cancel any scheduled cleanup
self._cancel_cleanup(task_id)

# Remove the task
self._tasks.pop(task_id)

def _schedule_cleanup(self, task_id: str, delay_seconds: float) -> None:
"""Schedule automatic cleanup of a task after the specified delay."""

async def cleanup() -> None:
await asyncio.sleep(delay_seconds)
self._tasks.pop(task_id, None)
self._cleanup_tasks.pop(task_id, None)

task = asyncio.create_task(cleanup())
self._cleanup_tasks[task_id] = task

def _cancel_cleanup(self, task_id: str) -> None:
"""Cancel any scheduled cleanup for a task."""
cleanup_task = self._cleanup_tasks.pop(task_id, None)
if cleanup_task is not None and not cleanup_task.done():
cleanup_task.cancel()

def cleanup(self) -> None:
"""Cleanup all timers and tasks (useful for testing or graceful shutdown)."""
for task in self._cleanup_tasks.values():
if not task.done():
task.cancel()
self._cleanup_tasks.clear()
self._tasks.clear()

def get_all_tasks(self) -> list[Task]:
"""Get all tasks (useful for debugging). Returns copies to prevent modification."""
return [Task(**stored.task.model_dump()) for stored in self._tasks.values()]
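For reference, a minimal sketch of how the store can be exercised on its own, outside a server. It assumes TaskMetadata accepts taskId and keepAlive (milliseconds) keyword arguments, as implied by the attribute access above, and uses a tools/call request purely as an example payload:

import asyncio

from examples.shared.in_memory_task_store import InMemoryTaskStore
from mcp.types import CallToolRequest, CallToolRequestParams, CallToolResult, TaskMetadata, TextContent


async def demo() -> None:
    store = InMemoryTaskStore()

    # The stored request is opaque to the store; any Request/Result pair works.
    request = CallToolRequest(
        method="tools/call",
        params=CallToolRequestParams(name="long_running_computation", arguments={"data": "abc"}),
    )

    # Create a task that is kept for 60 seconds after it reaches a terminal state.
    await store.create_task(TaskMetadata(taskId="task-1", keepAlive=60_000), request_id=1, request=request)

    # Attach a result; this marks the task completed and resets the keepAlive timer.
    await store.store_task_result("task-1", CallToolResult(content=[TextContent(type="text", text="done")]))

    print(await store.get_task("task-1"))  # status == "completed"
    print(await store.get_task_result("task-1"))  # the CallToolResult stored above

    store.cleanup()  # cancel outstanding keepAlive timers before exiting


if __name__ == "__main__":
    asyncio.run(demo())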
34 changes: 34 additions & 0 deletions examples/snippets/clients/streamable_task_get_result.py
@@ -0,0 +1,34 @@
"""
Run from the repository root:
uv run examples/snippets/clients/task_based_tool_client.py

Prerequisites:
The task_based_tool server must be running on http://localhost:8000
Start it with:
cd examples/snippets && uv run server task_based_tool streamable-http
"""

import asyncio

from mcp import ClientSession
from mcp.client.streamable_http import MCP_SESSION_ID, streamablehttp_client
from mcp.types import CallToolResult


async def main():
    async with streamablehttp_client(
        "http://localhost:8000/mcp",
        # Placeholder session ID: replace with the value negotiated in a previous run.
        headers={MCP_SESSION_ID: "5771f709-66f5-4176-9f32-ce91e3117df2"},
terminate_on_close=False,
) as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
            # Placeholder task ID: replace with the taskId printed while the task was running.
            result = await session.get_task_result("736054ac-5f10-409e-a06a-526761ea827a", CallToolResult)
print(result)


if __name__ == "__main__":
asyncio.run(main())
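The session ID and task ID above are placeholders from an earlier run. A minimal sketch of how the session ID could be captured, assuming the third value yielded by streamablehttp_client is a callable returning the negotiated session ID (it is discarded as _ in the snippets here); the task ID itself is printed by the on_task_status callback in streamable_tasks.py:

import asyncio

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client


async def main() -> None:
    # Assumes the task_based_tool server is running on http://localhost:8000.
    async with streamablehttp_client("http://localhost:8000/mcp", terminate_on_close=False) as (
        read_stream,
        write_stream,
        get_session_id,
    ):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            # Save this value and pass it as the MCP_SESSION_ID header when reconnecting.
            print(f"Session ID: {get_session_id()}")


if __name__ == "__main__":
    asyncio.run(main())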
122 changes: 122 additions & 0 deletions examples/snippets/clients/streamable_tasks.py
@@ -0,0 +1,122 @@
"""
Run from the repository root:
uv run examples/snippets/clients/task_based_tool_client.py

Prerequisites:
The task_based_tool server must be running on http://localhost:8000
Start it with:
cd examples/snippets && uv run server task_based_tool streamable-http
"""

import asyncio

from mcp import ClientSession, types
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.context import RequestContext
from mcp.shared.request import TaskHandlerOptions


async def elicitation_handler(
context: RequestContext[ClientSession, None], params: types.ElicitRequestParams
) -> types.ElicitResult | types.ErrorData:
"""
Handle elicitation requests from the server.

This handler collects user feedback with a predefined schema including:
- rating (1-5, required)
- comments (optional text up to 500 chars)
- recommend (boolean, required)
"""
print(f"\n🎯 Elicitation request received: {params.message}")
print(f"Schema: {params.requestedSchema}")
await asyncio.sleep(5)

# In a real application, you would collect this data from the user
# For this example, we'll return mock data
feedback_data: dict[str, str | int | float | bool | None] = {
"rating": 5,
"comments": "The task execution was excellent and fast!",
"recommend": True,
}

print(f"📝 Returning feedback: {feedback_data}")

return types.ElicitResult(action="accept", content=feedback_data)


async def main():
"""
Demonstrate task-based execution with begin_call_tool.

This example shows how to:
1. Start a long-running tool call with begin_call_tool()
2. Get task status updates through callbacks
3. Wait for the final result with polling
4. Handle elicitation requests from the server
"""
# Connect to the task-based tool example server via streamable HTTP
async with streamablehttp_client("http://localhost:3000/mcp", terminate_on_close=False) as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream, elicitation_callback=elicitation_handler) as session:
# Initialize the connection
await session.initialize()

print("Starting task-based tool execution...")

# Track callback invocations
task_created = False
status_updates: list[str] = []

async def on_task_created() -> None:
"""Called when the task is first created."""
nonlocal task_created
task_created = True
print("✓ Task created on server")

async def on_task_status(task_result: types.GetTaskResult) -> None:
"""Called whenever the task status is polled."""
status_updates.append(task_result.status)
print(f" Status ({task_result.taskId}): {task_result.status}")

            # Begin the tool call (returns immediately with a PendingRequest)
            print("\nCalling begin_call_tool...")
            pending_request = session.begin_call_tool(
                "long_running_computation",
                arguments={"data": "hello world", "delay_seconds": 5.0},
            )

print("Tool call initiated! Now waiting for result with task polling...\n")

# Wait for the result with task callbacks
result = await pending_request.result(
TaskHandlerOptions(on_task_created=on_task_created, on_task_status=on_task_status)
)

# Display the result
print("\n✓ Tool execution completed!")
if result.content:
content_block = result.content[0]
if isinstance(content_block, types.TextContent):
print(f"Result: {content_block.text}")
else:
print(f"Result: {content_block}")
else:
print("Result: No content")

# Show callback statistics
print("\nTask callbacks:")
print(f" - Task created callback: {'Yes' if task_created else 'No'}")
print(f" - Status updates received: {len(status_updates)}")
if status_updates:
print(f" - Final status: {status_updates[-1]}")


if __name__ == "__main__":
asyncio.run(main())
2 changes: 1 addition & 1 deletion examples/snippets/servers/__init__.py
@@ -22,7 +22,7 @@ def run_server():
        print("Usage: server <server-name> [transport]")
        print("Available servers: basic_tool, basic_resource, basic_prompt, tool_progress,")
        print("                   sampling, elicitation, completion, notifications,")
-       print("                   fastmcp_quickstart, structured_output, images")
+       print("                   fastmcp_quickstart, structured_output, images, task_based_tool")
        print("Available transports: stdio (default), sse, streamable-http")
        sys.exit(1)
32 changes: 32 additions & 0 deletions examples/snippets/servers/task_based_tool.py
@@ -0,0 +1,32 @@
"""Example server demonstrating task-based execution with long-running tools."""

import asyncio

from examples.shared.in_memory_task_store import InMemoryTaskStore
from mcp.server.fastmcp import FastMCP

# Create a task store to enable task-based execution
task_store = InMemoryTaskStore()
mcp = FastMCP(name="Task-Based Tool Example", task_store=task_store)


@mcp.tool()
async def long_running_computation(data: str, delay_seconds: float = 2.0) -> str:
"""
Simulate a long-running computation that benefits from task-based execution.

This tool demonstrates the 'call-now, fetch-later' pattern where clients can:
1. Initiate the task without waiting
2. Disconnect and reconnect later
3. Poll for status and retrieve results when ready

Args:
data: Input data to process
delay_seconds: Simulated processing time
"""
# Simulate long-running work
await asyncio.sleep(delay_seconds)

# Return processed result
result = f"Processed: {data.upper()} (took {delay_seconds}s)"
return result
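The client snippet streamable_tasks.py registers an elicitation_handler, but this server never elicits input. A hedged sketch of a tool that would exercise that handler if added to the module above, modelled on the SDK's elicitation snippet; the tool name and schema are illustrative and not part of this PR:

from pydantic import BaseModel

from mcp.server.fastmcp import Context
from mcp.server.session import ServerSession


class Feedback(BaseModel):
    """Schema mirroring what the client-side elicitation_handler returns."""

    rating: int
    comments: str = ""
    recommend: bool


@mcp.tool()
async def collect_user_info(ctx: Context[ServerSession, None]) -> str:
    """Ask the connected client for feedback via elicitation."""
    result = await ctx.elicit(message="How was this task run?", schema=Feedback)
    if result.action == "accept" and result.data:
        return f"Got rating {result.data.rating} (recommend={result.data.recommend})"
    return "User declined to give feedback"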