Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/frame/go_away.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use crate::frame::{self, Error, Head, Kind, Reason, StreamId};
pub struct GoAway {
last_stream_id: StreamId,
error_code: Reason,
#[allow(unused)]
debug_data: Bytes,
}

Expand All @@ -21,6 +20,15 @@ impl GoAway {
}
}

#[doc(hidden)]
#[cfg(feature = "unstable")]
pub fn with_debug_data(self, debug_data: impl Into<Bytes>) -> Self {
Self {
debug_data: debug_data.into(),
..self
}
}

pub fn last_stream_id(&self) -> StreamId {
self.last_stream_id
}
Expand Down Expand Up @@ -52,9 +60,10 @@ impl GoAway {
pub fn encode<B: BufMut>(&self, dst: &mut B) {
tracing::trace!("encoding GO_AWAY; code={:?}", self.error_code);
let head = Head::new(Kind::GoAway, 0, StreamId::zero());
head.encode(8, dst);
head.encode(8 + self.debug_data.len(), dst);
dst.put_u32(self.last_stream_id.into());
dst.put_u32(self.error_code.into());
dst.put(self.debug_data.slice(..));
}
}

Expand Down
23 changes: 23 additions & 0 deletions src/proto/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,18 @@ where
self.go_away.go_away_now(frame);
}

#[doc(hidden)]
#[cfg(feature = "unstable")]
fn go_away_now_debug_data(&mut self) {
let last_processed_id = self.streams.last_processed_id();

let frame = frame::GoAway::new(last_processed_id, Reason::NO_ERROR)
.with_debug_data("something went wrong");

self.streams.send_go_away(last_processed_id);
self.go_away.go_away(frame);
}

fn go_away_from_user(&mut self, e: Reason) {
let last_processed_id = self.streams.last_processed_id();
let frame = frame::GoAway::new(last_processed_id, e);
Expand Down Expand Up @@ -576,6 +588,17 @@ where
// for a pong before proceeding.
self.inner.ping_pong.ping_shutdown();
}

#[doc(hidden)]
#[cfg(feature = "unstable")]
pub fn go_away_debug_data(&mut self) {
if self.inner.go_away.is_going_away() {
return;
}

self.inner.as_dyn().go_away_now_debug_data();
self.inner.ping_pong.ping_shutdown();
}
}

impl<T, P, B> Drop for Connection<T, P, B>
Expand Down
4 changes: 0 additions & 4 deletions src/proto/go_away.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ pub(super) struct GoAway {
/// were a `frame::GoAway`, it might appear like we eventually wanted to
/// serialize it. We **only** want to be able to look up these fields at a
/// later time.
///
/// (Technically, `frame::GoAway` should gain an opaque_debug_data field as
/// well, and we wouldn't want to save that here to accidentally dump in logs,
/// or waste struct space.)
#[derive(Debug)]
pub(crate) struct GoingAway {
/// Stores the highest stream ID of a GOAWAY that has been sent.
Expand Down
6 changes: 6 additions & 0 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,12 @@ where
self.connection.go_away_gracefully();
}

#[doc(hidden)]
#[cfg(feature = "unstable")]
pub fn debug_data_shutdown(&mut self) {
self.connection.go_away_debug_data();
}

/// Takes a `PingPong` instance from the connection.
///
/// # Note
Expand Down
7 changes: 7 additions & 0 deletions tests/h2-support/src/frames.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,13 @@ impl Mock<frame::GoAway> {
self.reason(frame::Reason::NO_ERROR)
}

pub fn data<I>(self, debug_data: I) -> Self
where
I: Into<Bytes>,
{
Mock(self.0.with_debug_data(debug_data.into()))
}

pub fn reason(self, reason: frame::Reason) -> Self {
Mock(frame::GoAway::new(self.0.last_stream_id(), reason))
}
Expand Down
35 changes: 35 additions & 0 deletions tests/h2-tests/tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,41 @@ async fn graceful_shutdown() {
join(client, srv).await;
}

#[tokio::test]
async fn go_away_sends_debug_data() {
h2_support::trace_init!();

let (io, mut client) = mock::new();

let client = async move {
let settings = client.assert_server_handshake().await;
assert_default_settings!(settings);
client
.send_frame(frames::headers(1).request("POST", "https://example.com/"))
.await;
client
.recv_frame(frames::go_away(1).no_error().data("something went wrong"))
.await;
};

let src = async move {
let mut srv = server::handshake(io).await.expect("handshake");
let (_req, _tx) = srv.next().await.unwrap().expect("server receives request");

srv.debug_data_shutdown();

let srv_fut = async move {
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};

srv_fut.await
};

join(client, src).await;
}

#[tokio::test]
async fn goaway_even_if_client_sent_goaway() {
h2_support::trace_init!();
Expand Down