Skip to content

Commit f3fe656

Browse files
committed
fix with reviews
Signed-off-by: Shikugawa <[email protected]>
1 parent 0035f21 commit f3fe656

File tree

3 files changed

+115
-79
lines changed

3 files changed

+115
-79
lines changed

src/dispatcher.rs

Lines changed: 69 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ struct Dispatcher {
6161
active_id: Cell<u32>,
6262
callouts: RefCell<HashMap<u32, u32>>,
6363
grpc_callouts: RefCell<HashMap<u32, u32>>,
64-
grpc_stream_tokens: RefCell<HashMap<u32, u32>>,
64+
grpc_streams: RefCell<HashMap<u32, u32>>,
6565
}
6666

6767
impl Dispatcher {
@@ -76,7 +76,7 @@ impl Dispatcher {
7676
active_id: Cell::new(0),
7777
callouts: RefCell::new(HashMap::new()),
7878
grpc_callouts: RefCell::new(HashMap::new()),
79-
grpc_stream_tokens: RefCell::new(HashMap::new()),
79+
grpc_streams: RefCell::new(HashMap::new()),
8080
}
8181
}
8282

@@ -103,9 +103,9 @@ impl Dispatcher {
103103
}
104104
}
105105

106-
fn register_grpc_stream_tokens(&self, token_id: u32) {
106+
fn register_grpc_stream(&self, token_id: u32) {
107107
if self
108-
.grpc_stream_tokens
108+
.grpc_streams
109109
.borrow_mut()
110110
.insert(token_id, self.active_id.get())
111111
.is_some()
@@ -416,6 +416,29 @@ impl Dispatcher {
416416
}
417417
}
418418

419+
fn on_grpc_receive_initial_metadata(&self, token_id: u32, headers: u32) {
420+
let context_id = self
421+
.grpc_streams
422+
.borrow_mut()
423+
.get(&token_id)
424+
.expect("invalid token_id")
425+
.clone();
426+
427+
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
428+
self.active_id.set(context_id);
429+
hostcalls::set_effective_context(context_id).unwrap();
430+
http_stream.on_grpc_stream_initial_metadata(token_id, headers);
431+
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
432+
self.active_id.set(context_id);
433+
hostcalls::set_effective_context(context_id).unwrap();
434+
stream.on_grpc_stream_initial_metadata(token_id, headers);
435+
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
436+
self.active_id.set(context_id);
437+
hostcalls::set_effective_context(context_id).unwrap();
438+
root.on_grpc_stream_initial_metadata(token_id, headers);
439+
}
440+
}
441+
419442
fn on_grpc_receive(&self, token_id: u32, response_size: usize) {
420443
if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) {
421444
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
@@ -431,26 +454,49 @@ impl Dispatcher {
431454
hostcalls::set_effective_context(context_id).unwrap();
432455
root.on_grpc_call_response(token_id, 0, response_size);
433456
}
434-
} else if let Some(context_id) = self.grpc_callouts.borrow_mut().get(&token_id) {
457+
} else if let Some(context_id) = self.grpc_streams.borrow_mut().get(&token_id) {
435458
let context_id = context_id.clone();
436459
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
437460
self.active_id.set(context_id);
438461
hostcalls::set_effective_context(context_id).unwrap();
439-
http_stream.on_grpc_stream_receive_body(token_id, response_size);
462+
http_stream.on_grpc_stream_message(token_id, response_size);
440463
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
441464
self.active_id.set(context_id);
442465
hostcalls::set_effective_context(context_id).unwrap();
443-
stream.on_grpc_stream_receive_body(token_id, response_size);
466+
stream.on_grpc_stream_message(token_id, response_size);
444467
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
445468
self.active_id.set(context_id);
446469
hostcalls::set_effective_context(context_id).unwrap();
447-
root.on_grpc_stream_receive_body(token_id, response_size);
470+
root.on_grpc_stream_message(token_id, response_size);
448471
}
449472
} else {
450473
panic!("invalid token_id")
451474
}
452475
}
453476

477+
fn on_grpc_receive_trailing_metadata(&self, token_id: u32, trailers: u32) {
478+
let context_id = self
479+
.grpc_streams
480+
.borrow_mut()
481+
.get(&token_id)
482+
.expect("invalid token_id")
483+
.clone();
484+
485+
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
486+
self.active_id.set(context_id);
487+
hostcalls::set_effective_context(context_id).unwrap();
488+
http_stream.on_grpc_stream_initial_metadata(token_id, trailers);
489+
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
490+
self.active_id.set(context_id);
491+
hostcalls::set_effective_context(context_id).unwrap();
492+
stream.on_grpc_stream_initial_metadata(token_id, trailers);
493+
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
494+
self.active_id.set(context_id);
495+
hostcalls::set_effective_context(context_id).unwrap();
496+
root.on_grpc_stream_initial_metadata(token_id, trailers);
497+
}
498+
}
499+
454500
fn on_grpc_close(&self, token_id: u32, status_code: u32) {
455501
if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) {
456502
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
@@ -466,7 +512,7 @@ impl Dispatcher {
466512
hostcalls::set_effective_context(context_id).unwrap();
467513
root.on_grpc_call_response(token_id, status_code, 0);
468514
}
469-
} else if let Some(context_id) = self.grpc_stream_tokens.borrow_mut().remove(&token_id) {
515+
} else if let Some(context_id) = self.grpc_streams.borrow_mut().remove(&token_id) {
470516
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
471517
self.active_id.set(context_id);
472518
hostcalls::set_effective_context(context_id).unwrap();
@@ -484,52 +530,6 @@ impl Dispatcher {
484530
panic!("invalid token_id")
485531
}
486532
}
487-
488-
fn on_grpc_receive_initial_metadata(&self, token_id: u32, headers: u32) {
489-
let context_id = self
490-
.grpc_stream_tokens
491-
.borrow_mut()
492-
.get(&token_id)
493-
.expect("invalid token_id")
494-
.clone();
495-
496-
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
497-
self.active_id.set(context_id);
498-
hostcalls::set_effective_context(context_id).unwrap();
499-
http_stream.on_grpc_stream_receive_initial_metadata(token_id, headers);
500-
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
501-
self.active_id.set(context_id);
502-
hostcalls::set_effective_context(context_id).unwrap();
503-
stream.on_grpc_stream_receive_initial_metadata(token_id, headers);
504-
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
505-
self.active_id.set(context_id);
506-
hostcalls::set_effective_context(context_id).unwrap();
507-
root.on_grpc_stream_receive_initial_metadata(token_id, headers);
508-
}
509-
}
510-
511-
fn on_grpc_receive_trailing_metadata(&self, token_id: u32, trailers: u32) {
512-
let context_id = self
513-
.grpc_stream_tokens
514-
.borrow_mut()
515-
.get(&token_id)
516-
.expect("invalid token_id")
517-
.clone();
518-
519-
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
520-
self.active_id.set(context_id);
521-
hostcalls::set_effective_context(context_id).unwrap();
522-
http_stream.on_grpc_stream_receive_trailing_metadata(token_id, trailers);
523-
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
524-
self.active_id.set(context_id);
525-
hostcalls::set_effective_context(context_id).unwrap();
526-
stream.on_grpc_stream_receive_trailing_metadata(token_id, trailers);
527-
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
528-
self.active_id.set(context_id);
529-
hostcalls::set_effective_context(context_id).unwrap();
530-
root.on_grpc_stream_receive_trailing_metadata(token_id, trailers);
531-
}
532-
}
533533
}
534534

535535
#[no_mangle]
@@ -659,11 +659,25 @@ pub extern "C" fn proxy_on_http_call_response(
659659
})
660660
}
661661

662+
#[no_mangle]
663+
pub extern "C" fn proxy_on_grpc_receive_initial_metadata(
664+
_context_id: u32,
665+
token_id: u32,
666+
headers: u32,
667+
) {
668+
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive_initial_metadata(token_id, headers))
669+
}
670+
662671
#[no_mangle]
663672
pub extern "C" fn proxy_on_grpc_receive(_context_id: u32, token_id: u32, response_size: usize) {
664673
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive(token_id, response_size))
665674
}
666675

676+
#[no_mangle]
677+
pub extern "C" fn proxy_on_grpc_trailing_metadata(_context_id: u32, token_id: u32, trailers: u32) {
678+
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive_trailing_metadata(token_id, trailers))
679+
}
680+
667681
#[no_mangle]
668682
pub extern "C" fn proxy_on_grpc_close(_context_id: u32, token_id: u32, status_code: u32) {
669683
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_close(token_id, status_code))

src/hostcalls.rs

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,24 @@ pub fn get_map(map_type: MapType) -> Result<Vec<(String, String)>, Status> {
177177
}
178178
}
179179

180+
pub fn get_map_bytes(map_type: MapType) -> Result<Vec<(String, Vec<u8>)>, Status> {
181+
unsafe {
182+
let mut return_data: *mut u8 = null_mut();
183+
let mut return_size: usize = 0;
184+
match proxy_get_header_map_pairs(map_type, &mut return_data, &mut return_size) {
185+
Status::Ok => {
186+
if !return_data.is_null() {
187+
let serialized_map = Vec::from_raw_parts(return_data, return_size, return_size);
188+
Ok(utils::deserialize_bytes_map(&serialized_map))
189+
} else {
190+
Ok(Vec::new())
191+
}
192+
}
193+
status => panic!("unexpected status: {}", status as u32),
194+
}
195+
}
196+
}
197+
180198
extern "C" {
181199
fn proxy_set_header_map_pairs(
182200
map_type: MapType,
@@ -677,7 +695,7 @@ pub fn dispatch_grpc_call(
677695
timeout: Duration,
678696
) -> Result<u32, Status> {
679697
let mut return_callout_id = 0;
680-
let serialized_initial_metadata = utils::serialize_bytes_value_map(initial_metadata);
698+
let serialized_initial_metadata = utils::serialize_bytes_map(initial_metadata);
681699
unsafe {
682700
match proxy_grpc_call(
683701
upstream_name.as_ptr(),
@@ -718,14 +736,14 @@ extern "C" {
718736
) -> Status;
719737
}
720738

721-
pub fn create_grpc_stream(
739+
pub fn open_grpc_stream(
722740
upstream_name: &str,
723741
service_name: &str,
724742
method_name: &str,
725743
initial_metadata: Vec<(&str, &[u8])>,
726744
) -> Result<u32, Status> {
727745
let mut return_stream_id = 0;
728-
let serialized_initial_metadata = utils::serialize_bytes_value_map(initial_metadata);
746+
let serialized_initial_metadata = utils::serialize_bytes_map(initial_metadata);
729747
unsafe {
730748
match proxy_grpc_stream(
731749
upstream_name.as_ptr(),
@@ -758,7 +776,7 @@ extern "C" {
758776
) -> Status;
759777
}
760778

761-
pub fn grpc_stream_send(
779+
pub fn send_grpc_stream_message(
762780
token: u32,
763781
message: Option<&[u8]>,
764782
end_stream: bool,
@@ -782,7 +800,7 @@ extern "C" {
782800
fn proxy_grpc_cancel(token_id: u32) -> Status;
783801
}
784802

785-
pub fn grpc_call_cancel(token_id: u32) -> Result<(), Status> {
803+
pub fn cancel_grpc_call(token_id: u32) -> Result<(), Status> {
786804
unsafe {
787805
match proxy_grpc_cancel(token_id) {
788806
Status::Ok => Ok(()),
@@ -796,7 +814,7 @@ extern "C" {
796814
fn proxy_grpc_close(token_id: u32) -> Status;
797815
}
798816

799-
pub fn grpc_stream_close(token_id: u32) -> Result<(), Status> {
817+
pub fn close_grpc_stream(token_id: u32) -> Result<(), Status> {
800818
unsafe {
801819
match proxy_grpc_close(token_id) {
802820
Status::Ok => Ok(()),
@@ -938,7 +956,7 @@ mod utils {
938956
bytes
939957
}
940958

941-
pub(super) fn serialize_bytes_value_map(map: Vec<(&str, &[u8])>) -> Bytes {
959+
pub(super) fn serialize_bytes_map(map: Vec<(&str, &[u8])>) -> Bytes {
942960
let mut size: usize = 4;
943961
for (name, value) in &map {
944962
size += name.len() + value.len() + 10;

src/traits.rs

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -119,45 +119,49 @@ pub trait Context {
119119
hostcalls::cancel_grpc_call(token_id)
120120
}
121121

122-
fn create_grpc_stream(
122+
fn open_grpc_stream(
123123
&self,
124124
cluster_name: &str,
125125
service_name: &str,
126126
method_name: &str,
127127
initial_metadata: Vec<(&str, &[u8])>,
128128
) -> Result<u32, Status> {
129-
hostcalls::create_grpc_stream(cluster_name, service_name, method_name, initial_metadata)
129+
hostcalls::open_grpc_stream(cluster_name, service_name, method_name, initial_metadata)
130130
}
131131

132-
fn grpc_stream_send(
132+
fn on_grpc_stream_initial_metadata(&mut self, _token_id: u32, _num_elements: u32) {}
133+
134+
fn get_grpc_stream_initial_metadata(&self) -> Vec<(String, Vec<u8>)> {
135+
hostcalls::get_map_bytes(MapType::GrpcReceiveInitialMetadata).unwrap()
136+
}
137+
138+
fn send_grpc_stream_message(
133139
&self,
134140
token_id: u32,
135141
message: Option<&[u8]>,
136142
end_stream: bool,
137143
) -> Result<(), Status> {
138-
hostcalls::grpc_stream_send(token_id, message, end_stream)
144+
hostcalls::send_grpc_stream_message(token_id, message, end_stream)
139145
}
140146

141-
fn grpc_stream_close(&self, token_id: u32) -> Result<(), Status> {
142-
hostcalls::grpc_stream_close(token_id)
143-
}
144-
145-
fn on_grpc_stream_receive_initial_metadata(&mut self, _token_id: u32, _headers: u32) {}
147+
fn on_grpc_stream_message(&mut self, _token_id: u32, _message_size: usize) {}
146148

147-
fn on_grpc_stream_receive_trailing_metadata(&mut self, _token_id: u32, _trailers: u32) {}
148-
149-
fn on_grpc_stream_receive_body(&mut self, _token_id: u32, _response_size: usize) {}
149+
fn get_grpc_stream_message(&mut self, start: usize, max_size: usize) -> Option<Bytes> {
150+
hostcalls::get_buffer(BufferType::GrpcReceiveBuffer, start, max_size).unwrap()
151+
}
150152

151-
fn on_grpc_stream_close(&mut self, _token_id: u32, _status_code: u32) {}
153+
fn on_grpc_stream_trailing_metadata(&mut self, _token_id: u32, _num_elements: u32) {}
152154

153-
fn get_grpc_call_initial_metadata(&self) -> Vec<(String, Vec<u8>)> {
154-
hostcalls::get_map_bytes(MapType::GrpcReceiveInitialMetadata).unwrap()
155+
fn get_grpc_stream_trailing_metadata(&self) -> Vec<(String, Vec<u8>)> {
156+
hostcalls::get_map_bytes(MapType::GrpcReceiveTrailingMetadata).unwrap()
155157
}
156158

157-
fn get_grpc_call_trailing_metadata(&self) -> Vec<(String, Vec<u8>)> {
158-
hostcalls::get_map_bytes(MapType::GrpcReceiveTrailingMetadata).unwrap()
159+
fn close_grpc_stream(&self, token_id: u32) -> Result<(), Status> {
160+
hostcalls::close_grpc_stream(token_id)
159161
}
160162

163+
fn on_grpc_stream_close(&mut self, _token_id: u32, _status_code: u32) {}
164+
161165
fn on_done(&mut self) -> bool {
162166
true
163167
}

0 commit comments

Comments
 (0)