diff --git a/crates/rmcp/src/handler/client.rs b/crates/rmcp/src/handler/client.rs index 13bd9677..3e91a7e2 100644 --- a/crates/rmcp/src/handler/client.rs +++ b/crates/rmcp/src/handler/client.rs @@ -53,14 +53,6 @@ impl Service for H { Ok(()) } - fn get_peer(&self) -> Option> { - self.get_peer() - } - - fn set_peer(&mut self, peer: Peer) { - self.set_peer(peer); - } - fn get_info(&self) -> ::Info { self.get_info() } diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 808c9604..98199282 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -89,14 +89,6 @@ impl Service for H { Ok(()) } - fn get_peer(&self) -> Option> { - self.get_peer() - } - - fn set_peer(&mut self, peer: Peer) { - self.set_peer(peer); - } - fn get_info(&self) -> ::Info { self.get_info() } diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index cb902fc3..84d62beb 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -36,8 +36,10 @@ use tracing::instrument; pub enum ServiceError { #[error("Mcp error: {0}")] McpError(McpError), - #[error("Transport error: {0}")] - Transport(std::io::Error), + #[error("Transport send error: {0}")] + TransportSend(Box), + #[error("Transport closed")] + TransportClosed, #[error("Unexpected response type")] UnexpectedResponse, #[error("task cancelled for reason {}", reason.as_deref().unwrap_or(""))] @@ -98,8 +100,6 @@ pub trait Service: Send + Sync + 'static { &self, notification: R::PeerNot, ) -> impl Future> + Send + '_; - fn get_peer(&self) -> Option>; - fn set_peer(&mut self, peer: Peer); fn get_info(&self) -> R::Info; } @@ -148,14 +148,6 @@ impl Service for Box> { DynService::handle_notification(self.as_ref(), notification) } - fn get_peer(&self) -> Option> { - DynService::get_peer(self.as_ref()) - } - - fn set_peer(&mut self, peer: Peer) { - DynService::set_peer(self.as_mut(), peer) - } - fn get_info(&self) -> R::Info { DynService::get_info(self.as_ref()) } @@ -168,8 +160,6 @@ pub trait DynService: Send + Sync { context: RequestContext, ) -> BoxFuture>; fn handle_notification(&self, notification: R::PeerNot) -> BoxFuture>; - fn get_peer(&self) -> Option>; - fn set_peer(&mut self, peer: Peer); fn get_info(&self) -> R::Info; } @@ -184,12 +174,6 @@ impl> DynService for S { fn handle_notification(&self, notification: R::PeerNot) -> BoxFuture> { Box::pin(self.handle_notification(notification)) } - fn get_peer(&self) -> Option> { - self.get_peer() - } - fn set_peer(&mut self, peer: Peer) { - self.set_peer(peer) - } fn get_info(&self) -> R::Info { self.get_info() } @@ -255,9 +239,7 @@ impl RequestHandle { pub async fn await_response(self) -> Result { if let Some(timeout) = self.options.timeout { let timeout_result = tokio::time::timeout(timeout, async move { - self.rx - .await - .map_err(|_e| ServiceError::Transport(std::io::Error::other("disconnected")))? + self.rx.await.map_err(|_e| ServiceError::TransportClosed)? }) .await; match timeout_result { @@ -278,9 +260,7 @@ impl RequestHandle { } } } else { - self.rx - .await - .map_err(|_e| ServiceError::Transport(std::io::Error::other("disconnected")))? + self.rx.await.map_err(|_e| ServiceError::TransportClosed)? } } @@ -373,12 +353,8 @@ impl Peer { responder, }) .await - .map_err(|_m| { - ServiceError::Transport(std::io::Error::other("disconnected: receiver dropped")) - })?; - receiver.await.map_err(|_e| { - ServiceError::Transport(std::io::Error::other("disconnected: responder dropped")) - })? + .map_err(|_m| ServiceError::TransportClosed)?; + receiver.await.map_err(|_e| ServiceError::TransportClosed)? } pub async fn send_request(&self, request: R::Req) -> Result { self.send_request_with_option(request, PeerRequestOptions::no_options()) @@ -416,7 +392,7 @@ impl Peer { responder, }) .await - .map_err(|_m| ServiceError::Transport(std::io::Error::other("disconnected")))?; + .map_err(|_m| ServiceError::TransportClosed)?; Ok(RequestHandle { id, rx: receiver, @@ -428,6 +404,10 @@ impl Peer { pub fn peer_info(&self) -> &R::PeerInfo { &self.info } + + pub fn is_transport_closed(&self) -> bool { + self.tx.is_closed() + } } #[derive(Debug)] @@ -518,7 +498,7 @@ where #[instrument(skip_all)] async fn serve_inner( - mut service: S, + service: S, transport: T, peer: Peer, mut peer_rx: tokio::sync::mpsc::Receiver>, @@ -540,7 +520,6 @@ where tracing::info!(?peer_info, "Service initialized as server"); } - service.set_peer(peer.clone()); let mut local_responder_pool = HashMap::>>::new(); let mut local_ct_pool = HashMap::::new(); @@ -631,8 +610,7 @@ where Event::SendTaskResult(SendTaskResult::Request { id, result }) => { if let Err(e) = result { if let Some(responder) = local_responder_pool.remove(&id) { - let _ = responder - .send(Err(ServiceError::Transport(std::io::Error::other(e)))); + let _ = responder.send(Err(ServiceError::TransportSend(Box::new(e)))); } } } @@ -642,7 +620,7 @@ where cancellation_param, }) => { let response = if let Err(e) = result { - Err(ServiceError::Transport(std::io::Error::other(e))) + Err(ServiceError::TransportSend(Box::new(e))) } else { Ok(()) }; diff --git a/crates/rmcp/src/service/tower.rs b/crates/rmcp/src/service/tower.rs index 84984d49..da30f15d 100644 --- a/crates/rmcp/src/service/tower.rs +++ b/crates/rmcp/src/service/tower.rs @@ -2,12 +2,11 @@ use std::{future::poll_fn, marker::PhantomData}; use tower_service::Service as TowerService; -use crate::service::{Peer, RequestContext, Service, ServiceRole}; +use crate::service::{RequestContext, Service, ServiceRole}; pub struct TowerHandler { pub service: S, pub info: R::Info, - pub peer: Option>, role: PhantomData, } @@ -17,7 +16,6 @@ impl TowerHandler { service, role: PhantomData, info, - peer: None, } } } @@ -48,14 +46,6 @@ where std::future::ready(Ok(())) } - fn get_peer(&self) -> Option> { - self.peer.clone() - } - - fn set_peer(&mut self, peer: Peer) { - self.peer = Some(peer); - } - fn get_info(&self) -> R::Info { self.info.clone() }