diff --git a/src/dispatcher.rs b/src/dispatcher.rs index d9c9d491..b98e32fe 100644 --- a/src/dispatcher.rs +++ b/src/dispatcher.rs @@ -38,6 +38,10 @@ pub(crate) fn register_callout(token_id: u32) { DISPATCHER.with(|dispatcher| dispatcher.register_callout(token_id)); } +pub(crate) fn register_grpc_callout(token_id: u32) { + DISPATCHER.with(|dispatcher| dispatcher.register_grpc_callout(token_id)); +} + struct NoopRoot; impl Context for NoopRoot {} @@ -52,6 +56,7 @@ struct Dispatcher { http_streams: RefCell>>, active_id: Cell, callouts: RefCell>, + grpc_callouts: RefCell>, } impl Dispatcher { @@ -65,6 +70,7 @@ impl Dispatcher { http_streams: RefCell::new(HashMap::new()), active_id: Cell::new(0), callouts: RefCell::new(HashMap::new()), + grpc_callouts: RefCell::new(HashMap::new()), } } @@ -91,6 +97,17 @@ impl Dispatcher { } } + fn register_grpc_callout(&self, token_id: u32) { + if self + .grpc_callouts + .borrow_mut() + .insert(token_id, self.active_id.get()) + .is_some() + { + panic!("duplicate token_id") + } + } + fn create_root_context(&self, context_id: u32) { let new_context = match self.new_root.get() { Some(f) => f(context_id), @@ -381,6 +398,50 @@ impl Dispatcher { root.on_http_call_response(token_id, num_headers, body_size, num_trailers) } } + + fn on_grpc_receive(&self, token_id: u32, response_size: usize) { + let context_id = self + .grpc_callouts + .borrow_mut() + .remove(&token_id) + .expect("invalid token_id"); + + if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + http_stream.on_grpc_call_response(token_id, 0, response_size); + } else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + stream.on_grpc_call_response(token_id, 0, response_size); + } else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + root.on_grpc_call_response(token_id, 0, response_size); + } + } + + fn on_grpc_close(&self, token_id: u32, status_code: u32) { + let context_id = self + .grpc_callouts + .borrow_mut() + .remove(&token_id) + .expect("invalid token_id"); + + if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + http_stream.on_grpc_call_response(token_id, status_code, 0); + } else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + stream.on_grpc_call_response(token_id, status_code, 0); + } else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + root.on_grpc_call_response(token_id, status_code, 0); + } + } } #[no_mangle] @@ -509,3 +570,13 @@ pub extern "C" fn proxy_on_http_call_response( dispatcher.on_http_call_response(token_id, num_headers, body_size, num_trailers) }) } + +#[no_mangle] +pub extern "C" fn proxy_on_grpc_receive(_context_id: u32, token_id: u32, response_size: usize) { + DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive(token_id, response_size)) +} + +#[no_mangle] +pub extern "C" fn proxy_on_grpc_close(_context_id: u32, token_id: u32, status_code: u32) { + DISPATCHER.with(|dispatcher| dispatcher.on_grpc_close(token_id, status_code)) +} diff --git a/src/hostcalls.rs b/src/hostcalls.rs index 8f37f2ae..6b4e960d 100644 --- a/src/hostcalls.rs +++ b/src/hostcalls.rs @@ -651,6 +651,73 @@ pub fn dispatch_http_call( } } +extern "C" { + fn proxy_grpc_call( + upstream_data: *const u8, + upstream_size: usize, + service_name_data: *const u8, + service_name_size: usize, + method_name_data: *const u8, + method_name_size: usize, + initial_metadata_data: *const u8, + initial_metadata_size: usize, + message_data_data: *const u8, + message_data_size: usize, + timeout: u32, + return_callout_id: *mut u32, + ) -> Status; +} + +pub fn dispatch_grpc_call( + upstream_name: &str, + service_name: &str, + method_name: &str, + initial_metadata: Vec<(&str, &[u8])>, + message: Option<&[u8]>, + timeout: Duration, +) -> Result { + let mut return_callout_id = 0; + let serialized_initial_metadata = utils::serialize_bytes_value_map(initial_metadata); + unsafe { + match proxy_grpc_call( + upstream_name.as_ptr(), + upstream_name.len(), + service_name.as_ptr(), + service_name.len(), + method_name.as_ptr(), + method_name.len(), + serialized_initial_metadata.as_ptr(), + serialized_initial_metadata.len(), + message.map_or(null(), |message| message.as_ptr()), + message.map_or(0, |message| message.len()), + timeout.as_millis() as u32, + &mut return_callout_id, + ) { + Status::Ok => { + dispatcher::register_grpc_callout(return_callout_id); + Ok(return_callout_id) + } + Status::ParseFailure => Err(Status::ParseFailure), + Status::InternalFailure => Err(Status::InternalFailure), + status => panic!("unexpected status: {}", status as u32), + } + } +} + +extern "C" { + fn proxy_grpc_cancel(token_id: u32) -> Status; +} + +pub fn cancel_grpc_call(token_id: u32) -> Result<(), Status> { + unsafe { + match proxy_grpc_cancel(token_id) { + Status::Ok => Ok(()), + Status::NotFound => Err(Status::NotFound), + status => panic!("unexpected status: {}", status as u32), + } + } +} + extern "C" { fn proxy_set_effective_context(context_id: u32) -> Status; } @@ -783,6 +850,26 @@ mod utils { bytes } + pub(super) fn serialize_bytes_value_map(map: Vec<(&str, &[u8])>) -> Bytes { + let mut size: usize = 4; + for (name, value) in &map { + size += name.len() + value.len() + 10; + } + let mut bytes: Bytes = Vec::with_capacity(size); + bytes.extend_from_slice(&map.len().to_le_bytes()); + for (name, value) in &map { + bytes.extend_from_slice(&name.len().to_le_bytes()); + bytes.extend_from_slice(&value.len().to_le_bytes()); + } + for (name, value) in &map { + bytes.extend_from_slice(&name.as_bytes()); + bytes.push(0); + bytes.extend_from_slice(&value); + bytes.push(0); + } + bytes + } + pub(super) fn deserialize_map(bytes: &[u8]) -> Vec<(String, String)> { let mut map = Vec::new(); if bytes.is_empty() { diff --git a/src/traits.rs b/src/traits.rs index 5b7fc4be..6baed9f1 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -90,6 +90,35 @@ pub trait Context { hostcalls::get_map(MapType::HttpCallResponseTrailers).unwrap() } + fn dispatch_grpc_call( + &self, + upstream_name: &str, + service_name: &str, + method_name: &str, + initial_metadata: Vec<(&str, &[u8])>, + message: Option<&[u8]>, + timeout: Duration, + ) -> Result { + hostcalls::dispatch_grpc_call( + upstream_name, + service_name, + method_name, + initial_metadata, + message, + timeout, + ) + } + + fn on_grpc_call_response(&mut self, _token_id: u32, _status_code: u32, _response_size: usize) {} + + fn get_grpc_call_response_body(&self, start: usize, max_size: usize) -> Option { + hostcalls::get_buffer(BufferType::GrpcReceiveBuffer, start, max_size).unwrap() + } + + fn cancel_grpc_call(&self, token_id: u32) -> Result<(), Status> { + hostcalls::cancel_grpc_call(token_id) + } + fn on_done(&mut self) -> bool { true } diff --git a/src/types.rs b/src/types.rs index 855a414b..efdce127 100644 --- a/src/types.rs +++ b/src/types.rs @@ -42,6 +42,7 @@ pub enum Status { Ok = 0, NotFound = 1, BadArgument = 2, + ParseFailure = 4, Empty = 7, CasMismatch = 8, InternalFailure = 10, @@ -62,6 +63,7 @@ pub enum BufferType { DownstreamData = 2, UpstreamData = 3, HttpCallResponseBody = 4, + GrpcReceiveBuffer = 5, } #[repr(u32)]