diff --git a/crates/rmcp/src/model/meta.rs b/crates/rmcp/src/model/meta.rs index 010dd8ca..13a0b4f6 100644 --- a/crates/rmcp/src/model/meta.rs +++ b/crates/rmcp/src/model/meta.rs @@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use super::{ - ClientNotification, ClientRequest, Extensions, JsonObject, NumberOrString, ProgressToken, - ServerNotification, ServerRequest, + ClientNotification, ClientRequest, Extensions, JsonObject, JsonRpcMessage, NumberOrString, + ProgressToken, ServerNotification, ServerRequest, }; pub trait GetMeta { @@ -153,3 +153,42 @@ impl DerefMut for Meta { &mut self.0 } } + +impl JsonRpcMessage +where + Req: GetExtensions, + Noti: GetExtensions, +{ + pub fn insert_extension(&mut self, value: T) { + match self { + JsonRpcMessage::Request(json_rpc_request) => { + json_rpc_request.request.extensions_mut().insert(value); + } + JsonRpcMessage::Notification(json_rpc_notification) => { + json_rpc_notification + .notification + .extensions_mut() + .insert(value); + } + JsonRpcMessage::BatchRequest(json_rpc_batch_request_items) => { + for item in json_rpc_batch_request_items { + match item { + super::JsonRpcBatchRequestItem::Request(json_rpc_request) => { + json_rpc_request + .request + .extensions_mut() + .insert(value.clone()); + } + super::JsonRpcBatchRequestItem::Notification(json_rpc_notification) => { + json_rpc_notification + .notification + .extensions_mut() + .insert(value.clone()); + } + } + } + } + _ => {} + } + } +} diff --git a/crates/rmcp/src/transport/sse_server.rs b/crates/rmcp/src/transport/sse_server.rs index 5389ea9d..bbeecf19 100644 --- a/crates/rmcp/src/transport/sse_server.rs +++ b/crates/rmcp/src/transport/sse_server.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration}; use axum::{ Json, Router, extract::{Query, State}, - http::StatusCode, + http::{StatusCode, request::Parts}, response::{ Response, sse::{Event, KeepAlive, Sse}, @@ -64,7 +64,8 @@ pub struct PostEventQuery { async fn post_event_handler( State(app): State, Query(PostEventQuery { session_id }): Query, - Json(message): Json, + parts: Parts, + Json(mut message): Json, ) -> Result { tracing::debug!(session_id, ?message, "new client message"); let tx = { @@ -73,6 +74,7 @@ async fn post_event_handler( .ok_or(StatusCode::NOT_FOUND)? .clone() }; + message.insert_extension(parts); if tx.send(message).await.is_err() { tracing::error!("send message error"); return Err(StatusCode::GONE); diff --git a/crates/rmcp/src/transport/streamable_http_server/axum.rs b/crates/rmcp/src/transport/streamable_http_server/axum.rs index f01ecec2..385df578 100644 --- a/crates/rmcp/src/transport/streamable_http_server/axum.rs +++ b/crates/rmcp/src/transport/streamable_http_server/axum.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration}; use axum::{ Json, Router, extract::State, - http::{HeaderMap, HeaderValue, StatusCode}, + http::{HeaderMap, HeaderValue, StatusCode, request::Parts}, response::{ IntoResponse, Response, sse::{Event, KeepAlive, Sse}, @@ -68,11 +68,11 @@ fn receiver_as_stream( async fn post_handler( State(app): State, - header_map: HeaderMap, - Json(message): Json, + parts: Parts, + Json(mut message): Json, ) -> Result { use futures::StreamExt; - if let Some(session_id) = header_map.get(HEADER_SESSION_ID) { + if let Some(session_id) = parts.headers.get(HEADER_SESSION_ID).cloned() { let session_id = session_id .to_str() .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()).into_response())?; @@ -84,6 +84,8 @@ async fn post_handler( .ok_or((StatusCode::NOT_FOUND, "session not found").into_response())?; session.handle().clone() }; + // inject request part + message.insert_extension(parts); match &message { ClientJsonRpcMessage::Request(_) | ClientJsonRpcMessage::BatchRequest(_) => { let receiver = handle.establish_request_wise_channel().await.map_err(|e| { @@ -128,6 +130,8 @@ async fn post_handler( } else { // expect initialize message let session_id = session_id(); + // inject request part + message.insert_extension(parts); let (session, transport) = super::session::create_session(session_id.clone(), Default::default()); let Ok(_) = app.transport_tx.send(transport) else { diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index 7bed523a..12aa8a4a 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -17,6 +17,7 @@ pub struct StructRequest { pub struct Counter { counter: Arc>, } + #[tool(tool_box)] impl Counter { #[allow(dead_code)] @@ -194,4 +195,17 @@ impl ServerHandler for Counter { resource_templates: Vec::new(), }) } + + async fn initialize( + &self, + _request: InitializeRequestParam, + context: RequestContext, + ) -> Result { + if let Some(http_request_part) = context.extensions.get::() { + let initialize_headers = &http_request_part.headers; + let initialize_uri = &http_request_part.uri; + tracing::info!(?initialize_headers, %initialize_uri, "initialize from http server"); + } + Ok(self.get_info()) + } }