Skip to content

Commit 52e6298

Browse files
committed
feat: implement context injection for resources and prompts in FastMCP
This adds automatic context injection support for both resources and prompts in the FastMCP framework, allowing these handlers to receive the request context without explicitly passing it as an argument. Changes: - Add context parameter detection in ResourceTemplate and Prompt classes - Use func_metadata utility to generate schemas while excluding Context params - Pass context through resource_manager.get_resource() and prompt.render() - Update server validation to exclude Context params from URI parameter checks - Add comprehensive test coverage for context injection scenarios The implementation follows the same pattern already used for tools, ensuring consistency across all FastMCP handler types.
1 parent 91686f7 commit 52e6298

File tree

7 files changed

+348
-220
lines changed

7 files changed

+348
-220
lines changed

src/mcp/server/fastmcp/prompts/base.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
"""Base classes for FastMCP prompts."""
22

3+
from __future__ import annotations
4+
35
import inspect
46
from collections.abc import Awaitable, Callable, Sequence
5-
from typing import Any, Literal
7+
from typing import TYPE_CHECKING, Any, Literal, get_origin
68

79
import pydantic_core
810
from pydantic import BaseModel, Field, TypeAdapter, validate_call
911

12+
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
1013
from mcp.types import ContentBlock, TextContent
1114

15+
if TYPE_CHECKING:
16+
from mcp.server.fastmcp.server import Context
17+
from mcp.server.session import ServerSessionT
18+
from mcp.shared.context import LifespanContextT, RequestT
19+
1220

1321
class Message(BaseModel):
1422
"""Base class for all prompt messages."""
@@ -62,6 +70,7 @@ class Prompt(BaseModel):
6270
description: str | None = Field(None, description="Description of what the prompt does")
6371
arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt")
6472
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
73+
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context", exclude=True)
6574

6675
@classmethod
6776
def from_function(
@@ -70,7 +79,8 @@ def from_function(
7079
name: str | None = None,
7180
title: str | None = None,
7281
description: str | None = None,
73-
) -> "Prompt":
82+
context_kwarg: str | None = None,
83+
) -> Prompt:
7484
"""Create a Prompt from a function.
7585
7686
The function can return:
@@ -84,8 +94,29 @@ def from_function(
8494
if func_name == "<lambda>":
8595
raise ValueError("You must provide a name for lambda functions")
8696

87-
# Get schema from TypeAdapter - will fail if function isn't properly typed
88-
parameters = TypeAdapter(fn).json_schema()
97+
# Find context parameter if it exists
98+
if context_kwarg is None:
99+
from mcp.server.fastmcp.server import Context
100+
101+
sig = inspect.signature(fn)
102+
for param_name, param in sig.parameters.items():
103+
if get_origin(param.annotation) is not None:
104+
continue
105+
if param.annotation is not inspect.Parameter.empty:
106+
try:
107+
if issubclass(param.annotation, Context):
108+
context_kwarg = param_name
109+
break
110+
except TypeError:
111+
# issubclass raises TypeError for non-class types
112+
pass
113+
114+
# Get schema from func_metadata, excluding context parameter
115+
func_arg_metadata = func_metadata(
116+
fn,
117+
skip_names=[context_kwarg] if context_kwarg is not None else [],
118+
)
119+
parameters = func_arg_metadata.arg_model.model_json_schema()
89120

90121
# Convert parameters to PromptArguments
91122
arguments: list[PromptArgument] = []
@@ -109,9 +140,14 @@ def from_function(
109140
description=description or fn.__doc__ or "",
110141
arguments=arguments,
111142
fn=fn,
143+
context_kwarg=context_kwarg,
112144
)
113145

114-
async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]:
146+
async def render(
147+
self,
148+
arguments: dict[str, Any] | None = None,
149+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
150+
) -> list[Message]:
115151
"""Render the prompt with arguments."""
116152
# Validate required arguments
117153
if self.arguments:
@@ -122,8 +158,13 @@ async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]
122158
raise ValueError(f"Missing required arguments: {missing}")
123159

124160
try:
161+
# Add context to arguments if needed
162+
call_args = arguments or {}
163+
if self.context_kwarg is not None and context is not None:
164+
call_args = {**call_args, self.context_kwarg: context}
165+
125166
# Call function and check if result is a coroutine
126-
result = self.fn(**(arguments or {}))
167+
result = self.fn(**call_args)
127168
if inspect.iscoroutine(result):
128169
result = await result
129170

src/mcp/server/fastmcp/prompts/manager.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
"""Prompt management functionality."""
22

3-
from typing import Any
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Any
46

57
from mcp.server.fastmcp.prompts.base import Message, Prompt
68
from mcp.server.fastmcp.utilities.logging import get_logger
79

10+
if TYPE_CHECKING:
11+
from mcp.server.fastmcp.server import Context
12+
from mcp.server.session import ServerSessionT
13+
from mcp.shared.context import LifespanContextT, RequestT
14+
815
logger = get_logger(__name__)
916

1017

@@ -39,10 +46,15 @@ def add_prompt(
3946
self._prompts[prompt.name] = prompt
4047
return prompt
4148

42-
async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]:
49+
async def render_prompt(
50+
self,
51+
name: str,
52+
arguments: dict[str, Any] | None = None,
53+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
54+
) -> list[Message]:
4355
"""Render a prompt by name with arguments."""
4456
prompt = self.get_prompt(name)
4557
if not prompt:
4658
raise ValueError(f"Unknown prompt: {name}")
4759

48-
return await prompt.render(arguments)
60+
return await prompt.render(arguments, context=context)

src/mcp/server/fastmcp/resources/resource_manager.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
"""Resource manager functionality."""
22

3+
from __future__ import annotations
4+
35
from collections.abc import Callable
4-
from typing import Any
6+
from typing import TYPE_CHECKING, Any
57

68
from pydantic import AnyUrl
79

810
from mcp.server.fastmcp.resources.base import Resource
911
from mcp.server.fastmcp.resources.templates import ResourceTemplate
1012
from mcp.server.fastmcp.utilities.logging import get_logger
1113

14+
if TYPE_CHECKING:
15+
from mcp.server.fastmcp.server import Context
16+
from mcp.server.session import ServerSessionT
17+
from mcp.shared.context import LifespanContextT, RequestT
18+
1219
logger = get_logger(__name__)
1320

1421

@@ -67,7 +74,11 @@ def add_template(
6774
self._templates[template.uri_template] = template
6875
return template
6976

70-
async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
77+
async def get_resource(
78+
self,
79+
uri: AnyUrl | str,
80+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
81+
) -> Resource | None:
7182
"""Get resource by URI, checking concrete resources first, then templates."""
7283
uri_str = str(uri)
7384
logger.debug("Getting resource", extra={"uri": uri_str})
@@ -80,7 +91,7 @@ async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
8091
for template in self._templates.values():
8192
if params := template.matches(uri_str):
8293
try:
83-
return await template.create_resource(uri_str, params)
94+
return await template.create_resource(uri_str, params, context=context)
8495
except Exception as e:
8596
raise ValueError(f"Error creating resource from template: {e}")
8697

src/mcp/server/fastmcp/resources/templates.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,17 @@
55
import inspect
66
import re
77
from collections.abc import Callable
8-
from typing import Any
8+
from typing import TYPE_CHECKING, Any, get_origin
99

10-
from pydantic import BaseModel, Field, TypeAdapter, validate_call
10+
from pydantic import BaseModel, Field, validate_call
1111

1212
from mcp.server.fastmcp.resources.types import FunctionResource, Resource
13+
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
14+
15+
if TYPE_CHECKING:
16+
from mcp.server.fastmcp.server import Context
17+
from mcp.server.session import ServerSessionT
18+
from mcp.shared.context import LifespanContextT, RequestT
1319

1420

1521
class ResourceTemplate(BaseModel):
@@ -22,6 +28,7 @@ class ResourceTemplate(BaseModel):
2228
mime_type: str = Field(default="text/plain", description="MIME type of the resource content")
2329
fn: Callable[..., Any] = Field(exclude=True)
2430
parameters: dict[str, Any] = Field(description="JSON schema for function parameters")
31+
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
2532

2633
@classmethod
2734
def from_function(
@@ -32,14 +39,36 @@ def from_function(
3239
title: str | None = None,
3340
description: str | None = None,
3441
mime_type: str | None = None,
42+
context_kwarg: str | None = None,
3543
) -> ResourceTemplate:
3644
"""Create a template from a function."""
3745
func_name = name or fn.__name__
3846
if func_name == "<lambda>":
3947
raise ValueError("You must provide a name for lambda functions")
4048

41-
# Get schema from TypeAdapter - will fail if function isn't properly typed
42-
parameters = TypeAdapter(fn).json_schema()
49+
# Find context parameter if it exists
50+
if context_kwarg is None:
51+
from mcp.server.fastmcp.server import Context
52+
53+
sig = inspect.signature(fn)
54+
for param_name, param in sig.parameters.items():
55+
if get_origin(param.annotation) is not None:
56+
continue
57+
if param.annotation is not inspect.Parameter.empty:
58+
try:
59+
if issubclass(param.annotation, Context):
60+
context_kwarg = param_name
61+
break
62+
except TypeError:
63+
# issubclass raises TypeError for non-class types
64+
pass
65+
66+
# Get schema from func_metadata, excluding context parameter
67+
func_arg_metadata = func_metadata(
68+
fn,
69+
skip_names=[context_kwarg] if context_kwarg is not None else [],
70+
)
71+
parameters = func_arg_metadata.arg_model.model_json_schema()
4372

4473
# ensure the arguments are properly cast
4574
fn = validate_call(fn)
@@ -52,6 +81,7 @@ def from_function(
5281
mime_type=mime_type or "text/plain",
5382
fn=fn,
5483
parameters=parameters,
84+
context_kwarg=context_kwarg,
5585
)
5686

5787
def matches(self, uri: str) -> dict[str, Any] | None:
@@ -63,9 +93,18 @@ def matches(self, uri: str) -> dict[str, Any] | None:
6393
return match.groupdict()
6494
return None
6595

66-
async def create_resource(self, uri: str, params: dict[str, Any]) -> Resource:
96+
async def create_resource(
97+
self,
98+
uri: str,
99+
params: dict[str, Any],
100+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
101+
) -> Resource:
67102
"""Create a resource from the template with the given parameters."""
68103
try:
104+
# Add context to params if needed
105+
if self.context_kwarg is not None and context is not None:
106+
params = {**params, self.context_kwarg: context}
107+
69108
# Call function and check if result is a coroutine
70109
result = self.fn(**params)
71110
if inspect.iscoroutine(result):

src/mcp/server/fastmcp/server.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,8 @@ async def list_resource_templates(self) -> list[MCPResourceTemplate]:
326326
async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]:
327327
"""Read a resource by URI."""
328328

329-
resource = await self._resource_manager.get_resource(uri)
329+
context = self.get_context()
330+
resource = await self._resource_manager.get_resource(uri, context=context)
330331
if not resource:
331332
raise ResourceError(f"Unknown resource: {uri}")
332333

@@ -514,9 +515,27 @@ def decorator(fn: AnyFunction) -> AnyFunction:
514515
has_func_params = bool(inspect.signature(fn).parameters)
515516

516517
if has_uri_params or has_func_params:
517-
# Validate that URI params match function params
518+
# Check for Context parameter to exclude from validation
519+
from typing import get_origin
520+
521+
sig = inspect.signature(fn)
522+
context_param = None
523+
for param_name, param in sig.parameters.items():
524+
if param.annotation is not inspect.Parameter.empty:
525+
# Check if it's a Context type
526+
if get_origin(param.annotation) is None:
527+
# Try to check if it's a Context without importing
528+
# (to avoid circular imports)
529+
annotation_str = str(param.annotation)
530+
if "Context" in annotation_str:
531+
context_param = param_name
532+
break
533+
534+
# Validate that URI params match function params (excluding context)
518535
uri_params = set(re.findall(r"{(\w+)}", uri))
519-
func_params = set(inspect.signature(fn).parameters.keys())
536+
func_params = set(sig.parameters.keys())
537+
if context_param:
538+
func_params.discard(context_param)
520539

521540
if uri_params != func_params:
522541
raise ValueError(
@@ -982,7 +1001,8 @@ async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -
9821001
if not prompt:
9831002
raise ValueError(f"Unknown prompt: {name}")
9841003

985-
messages = await prompt.render(arguments)
1004+
context = self.get_context()
1005+
messages = await prompt.render(arguments, context=context)
9861006

9871007
return GetPromptResult(
9881008
description=prompt.description,

tests/server/fastmcp/test_server.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,7 @@ def resource_with_context(name: str, ctx: Context[ServerSession, None]) -> str:
963963
try:
964964
request_id = ctx.request_id
965965
return f"Resource {name} - request_id: {request_id}"
966-
except:
966+
except (AttributeError, ValueError):
967967
# Context was injected but request context not available
968968
return f"Resource {name} - context injected"
969969

@@ -976,7 +976,7 @@ def resource_with_context(name: str, ctx: Context[ServerSession, None]) -> str:
976976

977977
# Test via client
978978
async with client_session(mcp._mcp_server) as client:
979-
result = await client.read_resource("resource://context/test")
979+
result = await client.read_resource(AnyUrl("resource://context/test"))
980980
assert len(result.contents) == 1
981981
content = result.contents[0]
982982
assert isinstance(content, TextResourceContents)
@@ -1003,7 +1003,7 @@ def resource_no_context(name: str) -> str:
10031003

10041004
# Test via client
10051005
async with client_session(mcp._mcp_server) as client:
1006-
result = await client.read_resource("resource://nocontext/test")
1006+
result = await client.read_resource(AnyUrl("resource://nocontext/test"))
10071007
assert len(result.contents) == 1
10081008
content = result.contents[0]
10091009
assert isinstance(content, TextResourceContents)
@@ -1029,7 +1029,7 @@ def resource_custom_ctx(id: str, my_ctx: Context[ServerSession, None]) -> str:
10291029

10301030
# Test via client
10311031
async with client_session(mcp._mcp_server) as client:
1032-
result = await client.read_resource("resource://custom/123")
1032+
result = await client.read_resource(AnyUrl("resource://custom/123"))
10331033
assert len(result.contents) == 1
10341034
content = result.contents[0]
10351035
assert isinstance(content, TextResourceContents)
@@ -1043,9 +1043,12 @@ async def test_prompt_with_context(self):
10431043
@mcp.prompt("prompt_with_ctx")
10441044
def prompt_with_context(text: str, ctx: Context[ServerSession, None]) -> str:
10451045
"""Prompt that expects context."""
1046-
if ctx and hasattr(ctx, "request_id"):
1047-
return f"Prompt '{text}' with context: {ctx.request_id}"
1048-
return f"Prompt '{text}' - no context"
1046+
assert ctx is not None
1047+
try:
1048+
request_id = ctx.request_id
1049+
return f"Prompt '{text}' with context: {request_id}"
1050+
except (AttributeError, ValueError):
1051+
return f"Prompt '{text}' - context injected"
10491052

10501053
# Check if prompt has context parameter detection
10511054
prompts = mcp._prompt_manager.list_prompts()
@@ -1065,7 +1068,7 @@ def prompt_with_context(text: str, ctx: Context[ServerSession, None]) -> str:
10651068
message = result.messages[0]
10661069
content = message.content
10671070
assert isinstance(content, TextContent)
1068-
if "with context:" in content.text:
1071+
if "context injected" in content.text or "with context:" in content.text:
10691072
# Context injection is working for prompts
10701073
assert has_context_kwarg, "Prompt should have context_kwarg attribute"
10711074
else:

0 commit comments

Comments
 (0)