@@ -24,6 +24,13 @@ async def __call__(
2424 ) -> types .ListRootsResult | types .ErrorData : ...
2525
2626
27+ class LoggingFnT (Protocol ):
28+ async def __call__ (
29+ self ,
30+ params : types .LoggingMessageNotificationParams ,
31+ ) -> None : ...
32+
33+
2734async def _default_sampling_callback (
2835 context : RequestContext ["ClientSession" , Any ],
2936 params : types .CreateMessageRequestParams ,
@@ -43,7 +50,15 @@ async def _default_list_roots_callback(
4350 )
4451
4552
46- ClientResponse = TypeAdapter (types .ClientResult | types .ErrorData )
53+ async def _default_logging_callback (
54+ params : types .LoggingMessageNotificationParams ,
55+ ) -> None :
56+ pass
57+
58+
59+ ClientResponse : TypeAdapter [types .ClientResult | types .ErrorData ] = TypeAdapter (
60+ types .ClientResult | types .ErrorData
61+ )
4762
4863
4964class ClientSession (
@@ -62,6 +77,7 @@ def __init__(
6277 read_timeout_seconds : timedelta | None = None ,
6378 sampling_callback : SamplingFnT | None = None ,
6479 list_roots_callback : ListRootsFnT | None = None ,
80+ logging_callback : LoggingFnT | None = None ,
6581 ) -> None :
6682 super ().__init__ (
6783 read_stream ,
@@ -72,20 +88,15 @@ def __init__(
7288 )
7389 self ._sampling_callback = sampling_callback or _default_sampling_callback
7490 self ._list_roots_callback = list_roots_callback or _default_list_roots_callback
91+ self ._logging_callback = logging_callback or _default_logging_callback
7592
7693 async def initialize (self ) -> types .InitializeResult :
77- sampling = (
78- types .SamplingCapability () if self ._sampling_callback is not None else None
79- )
80- roots = (
81- types .RootsCapability (
82- # TODO: Should this be based on whether we
83- # _will_ send notifications, or only whether
84- # they're supported?
85- listChanged = True ,
86- )
87- if self ._list_roots_callback is not None
88- else None
94+ sampling = types .SamplingCapability ()
95+ roots = types .RootsCapability (
96+ # TODO: Should this be based on whether we
97+ # _will_ send notifications, or only whether
98+ # they're supported?
99+ listChanged = True ,
89100 )
90101
91102 result = await self .send_request (
@@ -219,7 +230,7 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
219230 )
220231
221232 async def call_tool (
222- self , name : str , arguments : dict | None = None
233+ self , name : str , arguments : dict [ str , Any ] | None = None
223234 ) -> types .CallToolResult :
224235 """Send a tools/call request."""
225236 return await self .send_request (
@@ -258,7 +269,9 @@ async def get_prompt(
258269 )
259270
260271 async def complete (
261- self , ref : types .ResourceReference | types .PromptReference , argument : dict
272+ self ,
273+ ref : types .ResourceReference | types .PromptReference ,
274+ argument : dict [str , str ],
262275 ) -> types .CompleteResult :
263276 """Send a completion/complete request."""
264277 return await self .send_request (
@@ -323,3 +336,13 @@ async def _received_request(
323336 return await responder .respond (
324337 types .ClientResult (root = types .EmptyResult ())
325338 )
339+
340+ async def _received_notification (
341+ self , notification : types .ServerNotification
342+ ) -> None :
343+ """Handle notifications from the server."""
344+ match notification .root :
345+ case types .LoggingMessageNotification (params = params ):
346+ await self ._logging_callback (params )
347+ case _:
348+ pass
0 commit comments