Skip to content

Commit 0859d82

Browse files
authored
feat(transport): support customizing Channel's async executor (#935)
1 parent 01e5be5 commit 0859d82

File tree

6 files changed

+96
-10
lines changed

6 files changed

+96
-10
lines changed

tonic/src/transport/channel/endpoint.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@ use super::Channel;
44
use super::ClientTlsConfig;
55
#[cfg(feature = "tls")]
66
use crate::transport::service::TlsConnector;
7-
use crate::transport::Error;
7+
use crate::transport::{service::SharedExec, Error, Executor};
88
use bytes::Bytes;
99
use http::{uri::Uri, HeaderValue};
1010
use std::{
1111
convert::{TryFrom, TryInto},
1212
fmt,
13+
future::Future,
14+
pin::Pin,
1315
str::FromStr,
1416
time::Duration,
1517
};
1618
use tower::make::MakeConnection;
19+
// use crate::transport::E
1720

1821
/// Channel builder.
1922
///
@@ -37,6 +40,7 @@ pub struct Endpoint {
3740
pub(crate) http2_keep_alive_while_idle: Option<bool>,
3841
pub(crate) connect_timeout: Option<Duration>,
3942
pub(crate) http2_adaptive_window: Option<bool>,
43+
pub(crate) executor: SharedExec,
4044
}
4145

4246
impl Endpoint {
@@ -263,6 +267,17 @@ impl Endpoint {
263267
}
264268
}
265269

270+
/// Sets the executor used to spawn async tasks.
271+
///
272+
/// Uses `tokio::spawn` by default.
273+
pub fn executor<E>(mut self, executor: E) -> Self
274+
where
275+
E: Executor<Pin<Box<dyn Future<Output = ()> + Send>>> + Send + Sync + 'static,
276+
{
277+
self.executor = SharedExec::new(executor);
278+
self
279+
}
280+
266281
/// Create a channel from this config.
267282
pub async fn connect(&self) -> Result<Channel, Error> {
268283
let mut http = hyper::client::connect::HttpConnector::new();
@@ -396,6 +411,7 @@ impl From<Uri> for Endpoint {
396411
http2_keep_alive_while_idle: None,
397412
connect_timeout: None,
398413
http2_adaptive_window: None,
414+
executor: SharedExec::tokio(),
399415
}
400416
}
401417
}

tonic/src/transport/channel/mod.rs

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ pub use endpoint::Endpoint;
99
#[cfg(feature = "tls")]
1010
pub use tls::ClientTlsConfig;
1111

12-
use super::service::{Connection, DynamicServiceStream};
12+
use super::service::{Connection, DynamicServiceStream, SharedExec};
1313
use crate::body::BoxBody;
14+
use crate::transport::Executor;
1415
use bytes::Bytes;
1516
use http::{
1617
uri::{InvalidUri, Uri},
@@ -124,10 +125,26 @@ impl Channel {
124125
pub fn balance_channel<K>(capacity: usize) -> (Self, Sender<Change<K, Endpoint>>)
125126
where
126127
K: Hash + Eq + Send + Clone + 'static,
128+
{
129+
Self::balance_channel_with_executor(capacity, SharedExec::tokio())
130+
}
131+
132+
/// Balance a list of [`Endpoint`]'s.
133+
///
134+
/// This creates a [`Channel`] that will listen to a stream of change events and will add or remove provided endpoints.
135+
///
136+
/// The [`Channel`] will use the given executor to spawn async tasks.
137+
pub fn balance_channel_with_executor<K, E>(
138+
capacity: usize,
139+
executor: E,
140+
) -> (Self, Sender<Change<K, Endpoint>>)
141+
where
142+
K: Hash + Eq + Send + Clone + 'static,
143+
E: Executor<Pin<Box<dyn Future<Output = ()> + Send>>> + Send + Sync + 'static,
127144
{
128145
let (tx, rx) = channel(capacity);
129146
let list = DynamicServiceStream::new(rx);
130-
(Self::balance(list, DEFAULT_BUFFER_SIZE), tx)
147+
(Self::balance(list, DEFAULT_BUFFER_SIZE, executor), tx)
131148
}
132149

133150
pub(crate) fn new<C>(connector: C, endpoint: Endpoint) -> Self
@@ -138,9 +155,11 @@ impl Channel {
138155
C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static,
139156
{
140157
let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE);
158+
let executor = endpoint.executor.clone();
141159

142160
let svc = Connection::lazy(connector, endpoint);
143-
let svc = Buffer::new(Either::A(svc), buffer_size);
161+
let (svc, worker) = Buffer::pair(Either::A(svc), buffer_size);
162+
executor.execute(Box::pin(worker));
144163

145164
Channel { svc }
146165
}
@@ -153,25 +172,29 @@ impl Channel {
153172
C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static,
154173
{
155174
let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE);
175+
let executor = endpoint.executor.clone();
156176

157177
let svc = Connection::connect(connector, endpoint)
158178
.await
159179
.map_err(super::Error::from_source)?;
160-
let svc = Buffer::new(Either::A(svc), buffer_size);
180+
let (svc, worker) = Buffer::pair(Either::A(svc), buffer_size);
181+
executor.execute(Box::pin(worker));
161182

162183
Ok(Channel { svc })
163184
}
164185

165-
pub(crate) fn balance<D>(discover: D, buffer_size: usize) -> Self
186+
pub(crate) fn balance<D, E>(discover: D, buffer_size: usize, executor: E) -> Self
166187
where
167188
D: Discover<Service = Connection> + Unpin + Send + 'static,
168189
D::Error: Into<crate::Error>,
169190
D::Key: Hash + Send + Clone,
191+
E: Executor<futures_core::future::BoxFuture<'static, ()>> + Send + Sync + 'static,
170192
{
171193
let svc = Balance::new(discover);
172194

173195
let svc = BoxService::new(svc);
174-
let svc = Buffer::new(Either::B(svc), buffer_size);
196+
let (svc, worker) = Buffer::pair(Either::B(svc), buffer_size);
197+
executor.execute(Box::pin(worker));
175198

176199
Channel { svc }
177200
}

tonic/src/transport/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,12 @@ pub use self::error::Error;
9999
#[doc(inline)]
100100
pub use self::server::{NamedService, Server};
101101
#[doc(inline)]
102-
pub use self::service::TimeoutExpired;
102+
pub use self::service::grpc_timeout::TimeoutExpired;
103103
pub use self::tls::Certificate;
104104
pub use hyper::{Body, Uri};
105105

106+
pub(crate) use self::service::executor::Executor;
107+
106108
#[cfg(feature = "tls")]
107109
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
108110
pub use self::channel::ClientTlsConfig;

tonic/src/transport/service/connection.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ impl Connection {
3939
.http2_initial_connection_window_size(endpoint.init_connection_window_size)
4040
.http2_only(true)
4141
.http2_keep_alive_interval(endpoint.http2_keep_alive_interval)
42+
.executor(endpoint.executor.clone())
4243
.clone();
4344

4445
if let Some(val) = endpoint.http2_keep_alive_timeout {
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
use futures_core::future::BoxFuture;
2+
use std::{future::Future, sync::Arc};
3+
4+
pub(crate) use hyper::rt::Executor;
5+
6+
#[derive(Copy, Clone)]
7+
struct TokioExec;
8+
9+
impl<F> Executor<F> for TokioExec
10+
where
11+
F: Future + Send + 'static,
12+
F::Output: Send + 'static,
13+
{
14+
fn execute(&self, fut: F) {
15+
tokio::spawn(fut);
16+
}
17+
}
18+
19+
#[derive(Clone)]
20+
pub(crate) struct SharedExec {
21+
inner: Arc<dyn Executor<BoxFuture<'static, ()>> + Send + Sync + 'static>,
22+
}
23+
24+
impl SharedExec {
25+
pub(crate) fn new<E>(exec: E) -> Self
26+
where
27+
E: Executor<BoxFuture<'static, ()>> + Send + Sync + 'static,
28+
{
29+
Self {
30+
inner: Arc::new(exec),
31+
}
32+
}
33+
34+
pub(crate) fn tokio() -> Self {
35+
Self::new(TokioExec)
36+
}
37+
}
38+
39+
impl Executor<BoxFuture<'static, ()>> for SharedExec {
40+
fn execute(&self, fut: BoxFuture<'static, ()>) {
41+
self.inner.execute(fut)
42+
}
43+
}

tonic/src/transport/service/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ mod add_origin;
22
mod connection;
33
mod connector;
44
mod discover;
5-
mod grpc_timeout;
5+
pub(crate) mod executor;
6+
pub(crate) mod grpc_timeout;
67
mod io;
78
mod reconnect;
89
mod router;
@@ -14,11 +15,11 @@ pub(crate) use self::add_origin::AddOrigin;
1415
pub(crate) use self::connection::Connection;
1516
pub(crate) use self::connector::connector;
1617
pub(crate) use self::discover::DynamicServiceStream;
18+
pub(crate) use self::executor::SharedExec;
1719
pub(crate) use self::grpc_timeout::GrpcTimeout;
1820
pub(crate) use self::io::ServerIo;
1921
#[cfg(feature = "tls")]
2022
pub(crate) use self::tls::{TlsAcceptor, TlsConnector};
2123
pub(crate) use self::user_agent::UserAgent;
2224

23-
pub use self::grpc_timeout::TimeoutExpired;
2425
pub use self::router::Routes;

0 commit comments

Comments
 (0)