Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 80 additions & 28 deletions crates/rmcp/src/service/client.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,73 @@
use futures::{SinkExt, StreamExt};
use futures::{SinkExt, Stream, StreamExt};
use thiserror::Error;

use super::*;
use crate::model::{
CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification,
CancelledNotificationParam, ClientInfo, ClientMessage, ClientNotification, ClientRequest,
ClientResult, CompleteRequest, CompleteRequestParam, CompleteResult, GetPromptRequest,
GetPromptRequestParam, GetPromptResult, InitializeRequest, InitializedNotification,
ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest,
ListToolsResult, PaginatedRequestParam, PaginatedRequestParamInner, ProgressNotification,
ProgressNotificationParam, ReadResourceRequest, ReadResourceRequestParam, ReadResourceResult,
RootsListChangedNotification, ServerInfo, ServerNotification, ServerRequest, ServerResult,
SetLevelRequest, SetLevelRequestParam, SubscribeRequest, SubscribeRequestParam,
UnsubscribeRequest, UnsubscribeRequestParam,
RequestId, RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification,
ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParam, SubscribeRequest,
SubscribeRequestParam, UnsubscribeRequest, UnsubscribeRequestParam,
};

/// It represents the error that may occur when serving the client.
///
/// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result<RunningService<RoleClient, S>, ClientError>`
#[derive(Error, Debug)]
pub enum ClientError {
#[error("expect initialized response, but received: {0:?}")]
ExpectedInitResponse(Option<ServerJsonRpcMessage>),

#[error("expect initialized result, but received: {0:?}")]
ExpectedInitResult(Option<ServerResult>),

#[error("conflict initialized response id: expected {0}, got {1}")]
ConflictInitResponseId(RequestId, RequestId),

#[error("connection closed: {0}")]
ConnectionClosed(String),

#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}

/// Helper function to get the next message from the stream
async fn expect_next_message<S>(
stream: &mut S,
context: &str,
) -> Result<ServerJsonRpcMessage, ClientError>
where
S: Stream<Item = ServerJsonRpcMessage> + Unpin,
{
stream
.next()
.await
.ok_or_else(|| ClientError::ConnectionClosed(context.to_string()))
.map_err(|e| ClientError::Io(std::io::Error::new(std::io::ErrorKind::Other, e)))
}

/// Helper function to expect a response from the stream
async fn expect_response<S>(
stream: &mut S,
context: &str,
) -> Result<(ServerResult, RequestId), ClientError>
where
S: Stream<Item = ServerJsonRpcMessage> + Unpin,
{
let msg = expect_next_message(stream, context).await?;

match msg {
ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => Ok((result, id)),
_ => Err(ClientError::ExpectedInitResponse(Some(msg))),
}
}

#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RoleClient;

Expand Down Expand Up @@ -74,6 +127,15 @@ where
let mut sink = Box::pin(sink);
let mut stream = Box::pin(stream);
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();

// Convert ClientError to std::io::Error, then to E
let handle_client_error = |e: ClientError| -> E {
match e {
ClientError::Io(io_err) => io_err.into(),
other => std::io::Error::new(std::io::ErrorKind::Other, format!("{}", other)).into(),
}
};

// service
let id = id_provider.next_request_id();
let init_request = InitializeRequest {
Expand All @@ -85,34 +147,24 @@ where
.into_json_rpc_message(),
)
.await?;
let (response, response_id) = stream
.next()

let (response, response_id) = expect_response(&mut stream, "initialize response")
.await
.ok_or(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"expect initialize response",
))?
.into_message()
.into_result()
.ok_or(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"expect initialize result",
))?;
.map_err(handle_client_error)?;

if id != response_id {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"conflict initialize response id",
)
.into());
return Err(handle_client_error(ClientError::ConflictInitResponseId(
id,
response_id,
)));
}
let response = response.map_err(std::io::Error::other)?;

let ServerResult::InitializeResult(initialize_result) = response else {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"expect initialize result",
)
.into());
return Err(handle_client_error(ClientError::ExpectedInitResult(Some(
response,
))));
};

// send notification
let notification = ClientMessage::Notification(ClientNotification::InitializedNotification(
InitializedNotification {
Expand Down
Loading