Skip to content

Commit 0035f21

Browse files
committed
fix
2 parents 1d9377b + 30066a7 commit 0035f21

File tree

4 files changed

+310
-144
lines changed

4 files changed

+310
-144
lines changed

src/dispatcher.rs

Lines changed: 136 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ pub(crate) fn register_callout(token_id: u32) {
3838
DISPATCHER.with(|dispatcher| dispatcher.register_callout(token_id));
3939
}
4040

41+
pub(crate) fn register_grpc_callout(token_id: u32) {
42+
DISPATCHER.with(|dispatcher| dispatcher.register_grpc_callout(token_id));
43+
}
44+
4145
pub(crate) fn register_grpc_stream(token_id: u32) {
4246
DISPATCHER.with(|dispatcher| dispatcher.register_grpc_stream(token_id));
4347
}
@@ -56,7 +60,8 @@ struct Dispatcher {
5660
http_streams: RefCell<HashMap<u32, Box<dyn HttpContext>>>,
5761
active_id: Cell<u32>,
5862
callouts: RefCell<HashMap<u32, u32>>,
59-
grpc_stream_callouts: RefCell<HashMap<u32, u32>>,
63+
grpc_callouts: RefCell<HashMap<u32, u32>>,
64+
grpc_stream_tokens: RefCell<HashMap<u32, u32>>,
6065
}
6166

6267
impl Dispatcher {
@@ -70,7 +75,8 @@ impl Dispatcher {
7075
http_streams: RefCell::new(HashMap::new()),
7176
active_id: Cell::new(0),
7277
callouts: RefCell::new(HashMap::new()),
73-
grpc_stream_callouts: RefCell::new(HashMap::new()),
78+
grpc_callouts: RefCell::new(HashMap::new()),
79+
grpc_stream_tokens: RefCell::new(HashMap::new()),
7480
}
7581
}
7682

@@ -97,9 +103,20 @@ impl Dispatcher {
97103
}
98104
}
99105

100-
fn register_grpc_stream(&self, token_id: u32) {
106+
fn register_grpc_stream_tokens(&self, token_id: u32) {
101107
if self
102-
.grpc_stream_callouts
108+
.grpc_stream_tokens
109+
.borrow_mut()
110+
.insert(token_id, self.active_id.get())
111+
.is_some()
112+
{
113+
panic!("duplicate token_id")
114+
}
115+
}
116+
117+
fn register_grpc_callout(&self, token_id: u32) {
118+
if self
119+
.grpc_callouts
103120
.borrow_mut()
104121
.insert(token_id, self.active_id.get())
105122
.is_some()
@@ -187,99 +204,6 @@ impl Dispatcher {
187204
}
188205
}
189206

190-
fn on_grpc_receive_initial_metadata(&self, token_id: u32, headers: u32) {
191-
let context_id = self
192-
.grpc_stream_callouts
193-
.borrow_mut()
194-
.get(&token_id)
195-
.expect("invalid token_id")
196-
.clone();
197-
198-
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
199-
self.active_id.set(context_id);
200-
hostcalls::set_effective_context(context_id).unwrap();
201-
http_stream.on_grpc_receive_initial_metadata(token_id, headers);
202-
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
203-
self.active_id.set(context_id);
204-
hostcalls::set_effective_context(context_id).unwrap();
205-
stream.on_grpc_receive_initial_metadata(token_id, headers);
206-
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
207-
self.active_id.set(context_id);
208-
hostcalls::set_effective_context(context_id).unwrap();
209-
root.on_grpc_receive_initial_metadata(token_id, headers);
210-
}
211-
}
212-
213-
fn on_grpc_receive_trailing_metadata(&self, token_id: u32, trailers: u32) {
214-
let context_id = self
215-
.grpc_stream_callouts
216-
.borrow_mut()
217-
.get(&token_id)
218-
.expect("invalid token_id")
219-
.clone();
220-
221-
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
222-
self.active_id.set(context_id);
223-
hostcalls::set_effective_context(context_id).unwrap();
224-
http_stream.on_grpc_receive_trailing_metadata(token_id, trailers);
225-
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
226-
self.active_id.set(context_id);
227-
hostcalls::set_effective_context(context_id).unwrap();
228-
stream.on_grpc_receive_trailing_metadata(token_id, trailers);
229-
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
230-
self.active_id.set(context_id);
231-
hostcalls::set_effective_context(context_id).unwrap();
232-
root.on_grpc_receive_trailing_metadata(token_id, trailers);
233-
}
234-
}
235-
236-
fn on_grpc_receive(&self, token_id: u32, response_size: usize) {
237-
// TODO(shikugawa): migrate with gRPC callout tokens
238-
let context_id = self
239-
.grpc_stream_callouts
240-
.borrow_mut()
241-
.get(&token_id)
242-
.expect("invalid token_id")
243-
.clone();
244-
245-
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
246-
self.active_id.set(context_id);
247-
hostcalls::set_effective_context(context_id).unwrap();
248-
http_stream.on_grpc_receive(token_id, response_size);
249-
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
250-
self.active_id.set(context_id);
251-
hostcalls::set_effective_context(context_id).unwrap();
252-
stream.on_grpc_receive(token_id, response_size);
253-
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
254-
self.active_id.set(context_id);
255-
hostcalls::set_effective_context(context_id).unwrap();
256-
root.on_grpc_receive(token_id, response_size);
257-
}
258-
}
259-
260-
fn on_grpc_close(&self, token_id: u32, status_code: u32) {
261-
// TODO(shikugawa): migrate with gRPC callout tokens
262-
let context_id = self
263-
.grpc_stream_callouts
264-
.borrow_mut()
265-
.remove(&token_id)
266-
.expect("invalid token_id");
267-
268-
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
269-
self.active_id.set(context_id);
270-
hostcalls::set_effective_context(context_id).unwrap();
271-
http_stream.on_grpc_close(token_id, status_code);
272-
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
273-
self.active_id.set(context_id);
274-
hostcalls::set_effective_context(context_id).unwrap();
275-
stream.on_grpc_close(token_id, status_code);
276-
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
277-
self.active_id.set(context_id);
278-
hostcalls::set_effective_context(context_id).unwrap();
279-
root.on_grpc_close(token_id, status_code);
280-
}
281-
}
282-
283207
fn on_done(&self, context_id: u32) -> bool {
284208
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
285209
self.active_id.set(context_id);
@@ -491,6 +415,121 @@ impl Dispatcher {
491415
root.on_http_call_response(token_id, num_headers, body_size, num_trailers)
492416
}
493417
}
418+
419+
fn on_grpc_receive(&self, token_id: u32, response_size: usize) {
420+
if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) {
421+
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
422+
self.active_id.set(context_id);
423+
hostcalls::set_effective_context(context_id).unwrap();
424+
http_stream.on_grpc_call_response(token_id, 0, response_size);
425+
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
426+
self.active_id.set(context_id);
427+
hostcalls::set_effective_context(context_id).unwrap();
428+
stream.on_grpc_call_response(token_id, 0, response_size);
429+
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
430+
self.active_id.set(context_id);
431+
hostcalls::set_effective_context(context_id).unwrap();
432+
root.on_grpc_call_response(token_id, 0, response_size);
433+
}
434+
} else if let Some(context_id) = self.grpc_callouts.borrow_mut().get(&token_id) {
435+
let context_id = context_id.clone();
436+
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
437+
self.active_id.set(context_id);
438+
hostcalls::set_effective_context(context_id).unwrap();
439+
http_stream.on_grpc_stream_receive_body(token_id, response_size);
440+
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
441+
self.active_id.set(context_id);
442+
hostcalls::set_effective_context(context_id).unwrap();
443+
stream.on_grpc_stream_receive_body(token_id, response_size);
444+
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
445+
self.active_id.set(context_id);
446+
hostcalls::set_effective_context(context_id).unwrap();
447+
root.on_grpc_stream_receive_body(token_id, response_size);
448+
}
449+
} else {
450+
panic!("invalid token_id")
451+
}
452+
}
453+
454+
fn on_grpc_close(&self, token_id: u32, status_code: u32) {
455+
if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) {
456+
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
457+
self.active_id.set(context_id);
458+
hostcalls::set_effective_context(context_id).unwrap();
459+
http_stream.on_grpc_call_response(token_id, status_code, 0);
460+
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
461+
self.active_id.set(context_id);
462+
hostcalls::set_effective_context(context_id).unwrap();
463+
stream.on_grpc_call_response(token_id, status_code, 0);
464+
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
465+
self.active_id.set(context_id);
466+
hostcalls::set_effective_context(context_id).unwrap();
467+
root.on_grpc_call_response(token_id, status_code, 0);
468+
}
469+
} else if let Some(context_id) = self.grpc_stream_tokens.borrow_mut().remove(&token_id) {
470+
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
471+
self.active_id.set(context_id);
472+
hostcalls::set_effective_context(context_id).unwrap();
473+
http_stream.on_grpc_stream_close(token_id, status_code)
474+
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
475+
self.active_id.set(context_id);
476+
hostcalls::set_effective_context(context_id).unwrap();
477+
stream.on_grpc_stream_close(token_id, status_code)
478+
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
479+
self.active_id.set(context_id);
480+
hostcalls::set_effective_context(context_id).unwrap();
481+
root.on_grpc_stream_close(token_id, status_code)
482+
}
483+
} else {
484+
panic!("invalid token_id")
485+
}
486+
}
487+
488+
fn on_grpc_receive_initial_metadata(&self, token_id: u32, headers: u32) {
489+
let context_id = self
490+
.grpc_stream_tokens
491+
.borrow_mut()
492+
.get(&token_id)
493+
.expect("invalid token_id")
494+
.clone();
495+
496+
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
497+
self.active_id.set(context_id);
498+
hostcalls::set_effective_context(context_id).unwrap();
499+
http_stream.on_grpc_stream_receive_initial_metadata(token_id, headers);
500+
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
501+
self.active_id.set(context_id);
502+
hostcalls::set_effective_context(context_id).unwrap();
503+
stream.on_grpc_stream_receive_initial_metadata(token_id, headers);
504+
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
505+
self.active_id.set(context_id);
506+
hostcalls::set_effective_context(context_id).unwrap();
507+
root.on_grpc_stream_receive_initial_metadata(token_id, headers);
508+
}
509+
}
510+
511+
fn on_grpc_receive_trailing_metadata(&self, token_id: u32, trailers: u32) {
512+
let context_id = self
513+
.grpc_stream_tokens
514+
.borrow_mut()
515+
.get(&token_id)
516+
.expect("invalid token_id")
517+
.clone();
518+
519+
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
520+
self.active_id.set(context_id);
521+
hostcalls::set_effective_context(context_id).unwrap();
522+
http_stream.on_grpc_stream_receive_trailing_metadata(token_id, trailers);
523+
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
524+
self.active_id.set(context_id);
525+
hostcalls::set_effective_context(context_id).unwrap();
526+
stream.on_grpc_stream_receive_trailing_metadata(token_id, trailers);
527+
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
528+
self.active_id.set(context_id);
529+
hostcalls::set_effective_context(context_id).unwrap();
530+
root.on_grpc_stream_receive_trailing_metadata(token_id, trailers);
531+
}
532+
}
494533
}
495534

496535
#[no_mangle]
@@ -620,24 +659,6 @@ pub extern "C" fn proxy_on_http_call_response(
620659
})
621660
}
622661

623-
#[no_mangle]
624-
pub extern "C" fn proxy_on_grpc_receive_initial_metadata(
625-
_context_id: u32,
626-
token_id: u32,
627-
headers: u32,
628-
) {
629-
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive_initial_metadata(token_id, headers))
630-
}
631-
632-
#[no_mangle]
633-
pub extern "C" fn proxy_on_grpc_receive_trailing_metadata(
634-
_context_id: u32,
635-
token_id: u32,
636-
trailers: u32,
637-
) {
638-
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive_trailing_metadata(token_id, trailers))
639-
}
640-
641662
#[no_mangle]
642663
pub extern "C" fn proxy_on_grpc_receive(_context_id: u32, token_id: u32, response_size: usize) {
643664
DISPATCHER.with(|dispatcher| dispatcher.on_grpc_receive(token_id, response_size))

0 commit comments

Comments
 (0)