Skip to content

Commit 426fb8a

Browse files
committed
fix: simplify context injection tests and improve server implementation
- Use find_context_parameter utility in server.py for cleaner code - Simplify test assertions to check for context injection consistently - Remove conditional checks and try-catch blocks in tests - Improve readability by removing unnecessary complexity Co-Authored-By: David S <[email protected]>
1 parent b52dd32 commit 426fb8a

File tree

2 files changed

+20
-63
lines changed

2 files changed

+20
-63
lines changed

src/mcp/server/fastmcp/server.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from mcp.server.fastmcp.prompts import Prompt, PromptManager
3131
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
3232
from mcp.server.fastmcp.tools import Tool, ToolManager
33+
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
3334
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
3435
from mcp.server.lowlevel.helper_types import ReadResourceContents
3536
from mcp.server.lowlevel.server import LifespanResultT
@@ -511,31 +512,19 @@ async def get_weather(city: str) -> str:
511512

512513
def decorator(fn: AnyFunction) -> AnyFunction:
513514
# Check if this should be a template
515+
sig = inspect.signature(fn)
514516
has_uri_params = "{" in uri and "}" in uri
515-
has_func_params = bool(inspect.signature(fn).parameters)
517+
has_func_params = bool(sig.parameters)
516518

517519
if has_uri_params or has_func_params:
518520
# Check for Context parameter to exclude from validation
519-
from typing import get_origin
520-
521-
sig = inspect.signature(fn)
522-
context_param = None
523-
for param_name, param in sig.parameters.items():
524-
if param.annotation is not inspect.Parameter.empty:
525-
# Check if it's a Context type
526-
if get_origin(param.annotation) is None:
527-
# Try to check if it's a Context without importing
528-
# (to avoid circular imports)
529-
annotation_str = str(param.annotation)
530-
if "Context" in annotation_str:
531-
context_param = param_name
532-
break
521+
context_param = find_context_parameter(fn)
533522

534523
# Validate that URI params match function params (excluding context)
535524
uri_params = set(re.findall(r"{(\w+)}", uri))
536-
func_params = set(sig.parameters.keys())
537-
if context_param:
538-
func_params.discard(context_param)
525+
# We need to remove the context_param from the resource function if
526+
# there is any.
527+
func_params = {p for p in sig.parameters.keys() if p != context_param}
539528

540529
if uri_params != func_params:
541530
raise ValueError(
@@ -1001,8 +990,7 @@ async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -
1001990
if not prompt:
1002991
raise ValueError(f"Unknown prompt: {name}")
1003992

1004-
context = self.get_context()
1005-
messages = await prompt.render(arguments, context=context)
993+
messages = await prompt.render(arguments, context=self.get_context())
1006994

1007995
return GetPromptResult(
1008996
description=prompt.description,

tests/server/fastmcp/test_server.py

Lines changed: 12 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -958,14 +958,7 @@ async def test_resource_with_context(self):
958958
def resource_with_context(name: str, ctx: Context[ServerSession, None]) -> str:
959959
"""Resource that receives context."""
960960
assert ctx is not None
961-
# Context should be provided even if request_id might not be accessible
962-
# in all contexts
963-
try:
964-
request_id = ctx.request_id
965-
return f"Resource {name} - request_id: {request_id}"
966-
except (AttributeError, ValueError):
967-
# Context was injected but request context not available
968-
return f"Resource {name} - context injected"
961+
return f"Resource {name} - context injected"
969962

970963
# Verify template has context_kwarg set
971964
templates = mcp._resource_manager.list_templates()
@@ -981,8 +974,7 @@ def resource_with_context(name: str, ctx: Context[ServerSession, None]) -> str:
981974
content = result.contents[0]
982975
assert isinstance(content, TextResourceContents)
983976
# Should have either request_id or indication that context was injected
984-
assert "Resource test" in content.text
985-
assert "request_id:" in content.text or "context injected" in content.text
977+
assert "Resource test - context injected" == content.text
986978

987979
@pytest.mark.anyio
988980
async def test_resource_without_context(self):
@@ -998,8 +990,7 @@ def resource_no_context(name: str) -> str:
998990
templates = mcp._resource_manager.list_templates()
999991
assert len(templates) == 1
1000992
template = templates[0]
1001-
if hasattr(template, "context_kwarg"):
1002-
assert template.context_kwarg is None
993+
assert template.context_kwarg is None
1003994

1004995
# Test via client
1005996
async with client_session(mcp._mcp_server) as client:
@@ -1024,8 +1015,7 @@ def resource_custom_ctx(id: str, my_ctx: Context[ServerSession, None]) -> str:
10241015
templates = mcp._resource_manager.list_templates()
10251016
assert len(templates) == 1
10261017
template = templates[0]
1027-
if hasattr(template, "context_kwarg"):
1028-
assert template.context_kwarg == "my_ctx"
1018+
assert template.context_kwarg == "my_ctx"
10291019

10301020
# Test via client
10311021
async with client_session(mcp._mcp_server) as client:
@@ -1044,42 +1034,21 @@ async def test_prompt_with_context(self):
10441034
def prompt_with_context(text: str, ctx: Context[ServerSession, None]) -> str:
10451035
"""Prompt that expects context."""
10461036
assert ctx is not None
1047-
try:
1048-
request_id = ctx.request_id
1049-
return f"Prompt '{text}' with context: {request_id}"
1050-
except (AttributeError, ValueError):
1051-
return f"Prompt '{text}' - context injected"
1037+
return f"Prompt '{text}' - context injected"
10521038

10531039
# Check if prompt has context parameter detection
10541040
prompts = mcp._prompt_manager.list_prompts()
10551041
assert len(prompts) == 1
1056-
prompt = prompts[0]
1057-
1058-
# Check if context_kwarg attribute exists (for future implementation)
1059-
has_context_kwarg = hasattr(prompt, "context_kwarg")
10601042

10611043
# Test via client
10621044
async with client_session(mcp._mcp_server) as client:
1063-
try:
1064-
# Try calling without passing ctx explicitly
1065-
result = await client.get_prompt("prompt_with_ctx", {"text": "test"})
1066-
# If this succeeds, check if context was injected
1067-
assert len(result.messages) == 1
1068-
message = result.messages[0]
1069-
content = message.content
1070-
assert isinstance(content, TextContent)
1071-
if "context injected" in content.text or "with context:" in content.text:
1072-
# Context injection is working for prompts
1073-
assert has_context_kwarg, "Prompt should have context_kwarg attribute"
1074-
else:
1075-
# Context was not injected
1076-
pytest.skip("Prompt context injection not yet implemented")
1077-
except Exception as e:
1078-
if "Missing required arguments" in str(e) and "ctx" in str(e):
1079-
# Context injection not working - expected for now
1080-
pytest.skip("Prompt context injection not yet implemented")
1081-
else:
1082-
raise
1045+
# Try calling without passing ctx explicitly
1046+
result = await client.get_prompt("prompt_with_ctx", {"text": "test"})
1047+
# If this succeeds, check if context was injected
1048+
assert len(result.messages) == 1
1049+
content = result.messages[0].content
1050+
assert isinstance(content, TextContent)
1051+
assert "Prompt 'test' - context injected" in content.text
10831052

10841053
@pytest.mark.anyio
10851054
async def test_prompt_without_context(self):

0 commit comments

Comments
 (0)