Skip to content

Commit 9232805

Browse files
committed
add an error type for the sub-classes
1 parent 2948106 commit 9232805

File tree

1 file changed

+44
-26
lines changed
  • crates/rmcp/src/transport

1 file changed

+44
-26
lines changed

crates/rmcp/src/transport/sse.rs

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ const MIME_TYPE: &str = "text/event-stream";
1414
const HEADER_LAST_EVENT_ID: &str = "Last-Event-ID";
1515

1616
#[derive(Error, Debug)]
17-
pub enum SseTransportError {
17+
pub enum SseTransportError<E: std::error::Error + Send + Sync + 'static> {
1818
#[error("SSE error: {0}")]
1919
Sse(#[from] SseError),
2020
#[error("IO error: {0}")]
2121
Io(#[from] std::io::Error),
22-
#[error("Reqwest error: {0}")]
23-
Reqwest(#[from] reqwest::Error),
22+
#[error("Transport error: {0}")]
23+
Transport(E),
2424
#[error("unexpected end of stream")]
2525
UnexpectedEndOfStream,
2626
#[error("Url error: {0}")]
@@ -29,13 +29,13 @@ pub enum SseTransportError {
2929
UnexpectedContentType(Option<HeaderValue>),
3030
}
3131

32-
enum SseTransportState {
32+
enum SseTransportState<E: std::error::Error + Send + Sync + 'static> {
3333
Connected(BoxStream<'static, Result<Sse, SseError>>),
3434
Retrying {
3535
times: usize,
3636
fut: BoxFuture<
3737
'static,
38-
Result<BoxStream<'static, Result<Sse, SseError>>, SseTransportError>,
38+
Result<BoxStream<'static, Result<Sse, SseError>>, SseTransportError<E>>,
3939
>,
4040
},
4141
Fatal {
@@ -60,17 +60,23 @@ impl Default for SseTransportRetryConfig {
6060
}
6161
}
6262

63-
pub trait SseClient: Clone + Send + Sync + 'static {
63+
impl From<reqwest::Error> for SseTransportError<reqwest::Error> {
64+
fn from(e: reqwest::Error) -> Self {
65+
SseTransportError::Transport(e)
66+
}
67+
}
68+
69+
pub trait SseClient<E: std::error::Error + Send + Sync>: Clone + Send + Sync + 'static {
6470
fn connect(
6571
&self,
6672
last_event_id: Option<String>,
67-
) -> BoxFuture<'static, Result<BoxStream<'static, Result<Sse, SseError>>, SseTransportError>>;
73+
) -> BoxFuture<'static, Result<BoxStream<'static, Result<Sse, SseError>>, SseTransportError<E>>>;
6874

6975
fn post(
7076
&self,
7177
endpoint: &str,
7278
message: ClientJsonRpcMessage,
73-
) -> BoxFuture<'static, Result<(), SseTransportError>>;
79+
) -> BoxFuture<'static, Result<(), SseTransportError<E>>>;
7480
}
7581

7682
pub struct RetryConfig {
@@ -84,7 +90,7 @@ pub struct ReqwestSseClient {
8490
sse_url: Url,
8591
}
8692
impl ReqwestSseClient {
87-
pub fn new<U>(url: U) -> Result<Self, SseTransportError>
93+
pub fn new<U>(url: U) -> Result<Self, SseTransportError<reqwest::Error>>
8894
where
8995
U: IntoUrl,
9096
{
@@ -95,7 +101,10 @@ impl ReqwestSseClient {
95101
})
96102
}
97103

98-
pub async fn new_with_timeout<U>(url: U, timeout: Duration) -> Result<Self, SseTransportError>
104+
pub async fn new_with_timeout<U>(
105+
url: U,
106+
timeout: Duration,
107+
) -> Result<Self, SseTransportError<reqwest::Error>>
99108
where
100109
U: IntoUrl,
101110
{
@@ -109,7 +118,10 @@ impl ReqwestSseClient {
109118
})
110119
}
111120

112-
pub async fn new_with_client<U>(url: U, client: HttpClient) -> Result<Self, SseTransportError>
121+
pub async fn new_with_client<U>(
122+
url: U,
123+
client: HttpClient,
124+
) -> Result<Self, SseTransportError<reqwest::Error>>
113125
where
114126
U: IntoUrl,
115127
{
@@ -121,12 +133,14 @@ impl ReqwestSseClient {
121133
}
122134
}
123135

124-
impl SseClient for ReqwestSseClient {
136+
impl SseClient<reqwest::Error> for ReqwestSseClient {
125137
fn connect(
126138
&self,
127139
last_event_id: Option<String>,
128-
) -> BoxFuture<'static, Result<BoxStream<'static, Result<Sse, SseError>>, SseTransportError>>
129-
{
140+
) -> BoxFuture<
141+
'static,
142+
Result<BoxStream<'static, Result<Sse, SseError>>, SseTransportError<reqwest::Error>>,
143+
> {
130144
let client = self.http_client.clone();
131145
let sse_url = self.sse_url.as_ref().to_string();
132146
let last_event_id = last_event_id.clone();
@@ -157,7 +171,7 @@ impl SseClient for ReqwestSseClient {
157171
&self,
158172
session_id: &str,
159173
message: ClientJsonRpcMessage,
160-
) -> BoxFuture<'static, Result<(), SseTransportError>> {
174+
) -> BoxFuture<'static, Result<(), SseTransportError<reqwest::Error>>> {
161175
let client = self.http_client.clone();
162176
let sse_url = self.sse_url.clone();
163177
let session_id = session_id.to_string();
@@ -179,19 +193,21 @@ impl SseClient for ReqwestSseClient {
179193
/// Call [`SseTransport::start`] to create a new transport from url.
180194
///
181195
/// Call [`SseTransport::start_with_client`] to create a new transport with a customized reqwest client.
182-
pub struct SseTransport<C: SseClient> {
196+
pub struct SseTransport<C: SseClient<E>, E: std::error::Error + Send + Sync + 'static> {
183197
client: Arc<C>,
184-
state: SseTransportState,
198+
state: SseTransportState<E>,
185199
last_event_id: Option<String>,
186200
recommended_retry_duration_ms: Option<u64>,
187201
session_id: String,
188202
#[allow(clippy::type_complexity)]
189-
request_queue: VecDeque<tokio::sync::oneshot::Receiver<Result<(), SseTransportError>>>,
203+
request_queue: VecDeque<tokio::sync::oneshot::Receiver<Result<(), SseTransportError<E>>>>,
190204
pub retry_config: SseTransportRetryConfig,
191205
}
192206

193-
impl SseTransport<ReqwestSseClient> {
194-
pub async fn start<U>(url: U) -> Result<SseTransport<ReqwestSseClient>, SseTransportError>
207+
impl SseTransport<ReqwestSseClient, reqwest::Error> {
208+
pub async fn start<U>(
209+
url: U,
210+
) -> Result<SseTransport<ReqwestSseClient, reqwest::Error>, SseTransportError<reqwest::Error>>
195211
where
196212
U: IntoUrl,
197213
{
@@ -200,8 +216,8 @@ impl SseTransport<ReqwestSseClient> {
200216
}
201217
}
202218

203-
impl<C: SseClient> SseTransport<C> {
204-
pub async fn start_with_client(client: C) -> Result<Self, SseTransportError> {
219+
impl<C: SseClient<E>, E: std::error::Error + Send + Sync + 'static> SseTransport<C, E> {
220+
pub async fn start_with_client(client: C) -> Result<Self, SseTransportError<E>> {
205221
let mut event_stream = client.connect(None).await?;
206222
let mut last_event_id = None;
207223
let mut retry = None;
@@ -233,7 +249,7 @@ impl<C: SseClient> SseTransport<C> {
233249

234250
fn retry_connection(
235251
&self,
236-
) -> BoxFuture<'static, Result<BoxStream<'static, Result<Sse, SseError>>, SseTransportError>>
252+
) -> BoxFuture<'static, Result<BoxStream<'static, Result<Sse, SseError>>, SseTransportError<E>>>
237253
{
238254
let retry_duration = {
239255
let recommended_retry_duration = self
@@ -250,7 +266,7 @@ impl<C: SseClient> SseTransport<C> {
250266
}
251267
}
252268

253-
impl<C: SseClient> Stream for SseTransport<C> {
269+
impl<C: SseClient<E>, E: std::error::Error + Send + Sync + 'static> Stream for SseTransport<C, E> {
254270
type Item = ServerJsonRpcMessage;
255271

256272
fn poll_next(
@@ -323,8 +339,10 @@ impl<C: SseClient> Stream for SseTransport<C> {
323339
}
324340
}
325341

326-
impl<C: SseClient> Sink<ClientJsonRpcMessage> for SseTransport<C> {
327-
type Error = SseTransportError;
342+
impl<C: SseClient<E>, E: std::error::Error + Send + Sync + 'static> Sink<ClientJsonRpcMessage>
343+
for SseTransport<C, E>
344+
{
345+
type Error = SseTransportError<E>;
328346

329347
fn poll_ready(
330348
mut self: std::pin::Pin<&mut Self>,

0 commit comments

Comments
 (0)