Skip to content

Extensions Empty #274

@j-mendez

Description

@j-mendez

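When serving `StreamableHttpService` through hyper, the request extensions are not passed through to the server: inside a tool call, `context.request_context().extensions` is empty, so the `ReqHeaders` extractor below always falls back to an empty `HeaderMap`.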

use rmcp::{model::*, tool};
use serde::{Deserialize, Serialize};
use rmcp::Error as McpError;
use rmcp::handler::server::tool::{FromToolCallContextPart, ToolCallContext};
use hyper::HeaderMap;

/// Newtype wrapper around the HTTP request headers.
///
/// Tool methods take it as an argument (e.g. `ReqHeaders(headers): ReqHeaders`) and it is
/// filled in by its `FromToolCallContextPart` implementation.
#[derive(Debug)]
pub struct ReqHeaders(pub HeaderMap);

impl<'a, S> FromToolCallContextPart<'a, S> for ReqHeaders {
    fn from_tool_call_context_part(
        context: ToolCallContext<'a, S>,
    ) -> Result<(Self, ToolCallContext<'a, S>), McpError> {
        println!("{:?}", context.request_context().extensions.len());

        match context.request_context().extensions.get::<HeaderMap>() {
            Some(headers) => {
                
                Ok((ReqHeaders(headers.clone()), context))
            },
            None =>Ok((ReqHeaders(HeaderMap::new()), context)),
        }
    }
}

/// Stateless MCP tool server exposing the `ping` and `echo` tools.
///
/// `Default` is derived alongside `Counter::new()` so the type can be constructed
/// generically (and to satisfy clippy's `new_without_default`).
#[derive(Clone, Debug, Default)]
pub struct Counter;

#[tool(tool_box)]
impl Counter {
    pub fn new() -> Self {
        Self
    }

#[tool(description = "Return pong")]
pub async fn ping(
    &self,
     ReqHeaders(headers): ReqHeaders

) -> Result<CallToolResult, rmcp::Error> {
    println!("{:?}", headers);
    Ok(CallToolResult::success(vec![Content::text("pong")]))
}
    #[tool(description = "Repeat what you say")]
    fn echo(
        &self,
        #[tool(param)]
        #[schemars(description = "Repeat what you say")]
        saying: String,
    ) -> Result<CallToolResult, rmcp::Error> {
        Ok(CallToolResult::success(vec![Content::text(saying)]))
    }
}

#[tool(tool_box)]
impl rmcp::ServerHandler for Counter {
    /// Describes this server to clients.
    ///
    /// Tool support is advertised and a short instruction string is included. All other
    /// fields take their default values.
    fn get_info(&self) -> ServerInfo {
        let tool_capabilities = ServerCapabilities::builder().enable_tools().build();
        let instructions = Some(String::from("Ping tool."));
        ServerInfo {
            instructions,
            capabilities: tool_capabilities,
            ..Default::default()
        }
    }

    /// Handles the MCP `initialize` request.
    ///
    /// If an `axum::http::request::Parts` extension is present (presumably inserted by the
    /// HTTP transport), its headers and URI are logged at info level. The reply is always
    /// [`get_info`](Self::get_info).
    ///
    /// NOTE(review): this overrides the trait's default `initialize`. Confirm the default
    /// does no other bookkeeping (e.g. recording the client's info) that should be kept.
    async fn initialize(
        &self,
        _request: InitializeRequestParam,
        context: rmcp::service::RequestContext<rmcp::RoleServer>,
    ) -> Result<InitializeResult, rmcp::Error> {
        let http_parts = context.extensions.get::<axum::http::request::Parts>();
        if let Some(parts) = http_parts {
            tracing::info!(
                initialize_headers = ?parts.headers,
                initialize_uri = %parts.uri,
                "initialize from http server"
            );
        }
        Ok(self.get_info())
    }
}

/// The `Counter` streamable-HTTP service wrapped in `TowerToHyperService`, which adapts a
/// tower `Service` into a hyper `Service` (presumably so it can be handed to hyper's
/// connection builders — confirm at the use site).
pub type RmcpType =
    hyper_util::service::TowerToHyperService<rmcp::transport::StreamableHttpService<Counter>>;

Set up the server:

    // Streamable-HTTP configuration with stateful mode turned off; every other option
    // keeps its default value.
    let conf = {
        let mut http_conf = rmcp::transport::StreamableHttpServerConfig::default();
        http_conf.stateful_mode = false;
        http_conf
    };

    // The factory closure produces a fresh `Counter`. A session manager is still passed,
    // because the constructor takes one positionally.
    let session_manager = LocalSessionManager::default();
    let rmcp_service: StreamableHttpService<rmpc_server::Counter> = StreamableHttpService::new(
        || Ok(crate::rmpc_server::Counter::new()),
        session_manager.into(),
        conf,
    );

Inside the service passed to hyper's `serve_connection_with_upgrades`:

// Delegate the incoming request to the rmcp `StreamableHttpService` built above.
// NOTE(review): `request` is presumably the hyper request received by
// `serve_connection_with_upgrades` — confirm its body type matches what `handle` expects.
let mut response = rmcp_service.handle(request).await;

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions