diff --git a/src/proto/connection.rs b/src/proto/connection.rs
index 727643a65..637fac358 100644
--- a/src/proto/connection.rs
+++ b/src/proto/connection.rs
@@ -145,7 +145,9 @@ where
 
     /// connection flow control
    pub(crate) fn set_target_window_size(&mut self, size: WindowSize) {
-        self.inner.streams.set_target_connection_window_size(size);
+        let _res = self.inner.streams.set_target_connection_window_size(size);
+        // TODO: proper error handling
+        debug_assert!(_res.is_ok());
    }
 
    /// Send a new SETTINGS frame with an updated initial window size.
diff --git a/src/proto/mod.rs b/src/proto/mod.rs
index d71ee9c42..567d03060 100644
--- a/src/proto/mod.rs
+++ b/src/proto/mod.rs
@@ -30,7 +30,7 @@ pub type PingPayload = [u8; 8];
 pub type WindowSize = u32;
 
 // Constants
-pub const MAX_WINDOW_SIZE: WindowSize = (1 << 31) - 1;
+pub const MAX_WINDOW_SIZE: WindowSize = (1 << 31) - 1; // i32::MAX as u32
 pub const DEFAULT_REMOTE_RESET_STREAM_MAX: usize = 20;
 pub const DEFAULT_RESET_STREAM_MAX: usize = 10;
 pub const DEFAULT_RESET_STREAM_SECS: u64 = 30;
diff --git a/src/proto/streams/flow_control.rs b/src/proto/streams/flow_control.rs
index 73a7754db..57a935825 100644
--- a/src/proto/streams/flow_control.rs
+++ b/src/proto/streams/flow_control.rs
@@ -75,12 +75,12 @@ impl FlowControl {
         self.window_size > self.available
     }
 
-    pub fn claim_capacity(&mut self, capacity: WindowSize) {
-        self.available -= capacity;
+    pub fn claim_capacity(&mut self, capacity: WindowSize) -> Result<(), Reason> {
+        self.available.decrease_by(capacity)
     }
 
-    pub fn assign_capacity(&mut self, capacity: WindowSize) {
-        self.available += capacity;
+    pub fn assign_capacity(&mut self, capacity: WindowSize) -> Result<(), Reason> {
+        self.available.increase_by(capacity)
     }
 
     /// If a WINDOW_UPDATE frame should be sent, returns a positive number
@@ -136,22 +136,23 @@ impl FlowControl {
     ///
     /// This is called after receiving a SETTINGS frame with a lower
     /// INITIAL_WINDOW_SIZE value.
-    pub fn dec_send_window(&mut self, sz: WindowSize) {
+    pub fn dec_send_window(&mut self, sz: WindowSize) -> Result<(), Reason> {
         tracing::trace!(
             "dec_window; sz={}; window={}, available={}",
             sz,
             self.window_size,
             self.available
         );
-        // This should not be able to overflow `window_size` from the bottom.
-        self.window_size -= sz;
+        // Contrary to the old comment, this *can* overflow `window_size` from the bottom, so use a checked subtraction.
+        self.window_size.decrease_by(sz)?;
+        Ok(())
     }
 
     /// Decrement the recv-side window size.
     ///
     /// This is called after receiving a SETTINGS ACK frame with a lower
     /// INITIAL_WINDOW_SIZE value.
-    pub fn dec_recv_window(&mut self, sz: WindowSize) {
+    pub fn dec_recv_window(&mut self, sz: WindowSize) -> Result<(), Reason> {
         tracing::trace!(
             "dec_recv_window; sz={}; window={}, available={}",
             sz,
@@ -159,13 +160,14 @@ impl FlowControl {
             self.available
         );
         // This should not be able to overflow `window_size` from the bottom.
-        self.window_size -= sz;
-        self.available -= sz;
+        self.window_size.decrease_by(sz)?;
+        self.available.decrease_by(sz)?;
+        Ok(())
     }
 
     /// Decrements the window reflecting data has actually been sent. The caller
     /// must ensure that the window has capacity.
-    pub fn send_data(&mut self, sz: WindowSize) {
+    pub fn send_data(&mut self, sz: WindowSize) -> Result<(), Reason> {
         tracing::trace!(
             "send_data; sz={}; window={}; available={}",
             sz,
@@ -176,12 +178,13 @@ impl FlowControl {
         // If send size is zero it's meaningless to update flow control window
         if sz > 0 {
             // Ensure that the argument is correct
-            assert!(self.window_size >= sz as usize);
+            assert!(self.window_size.0 >= sz as i32);
 
             // Update values
-            self.window_size -= sz;
-            self.available -= sz;
+            self.window_size.decrease_by(sz)?;
+            self.available.decrease_by(sz)?;
         }
+        Ok(())
     }
 }
 
@@ -208,6 +211,29 @@ impl Window {
         assert!(self.0 >= 0, "negative Window");
         self.0 as WindowSize
     }
+
+    pub fn decrease_by(&mut self, other: WindowSize) -> Result<(), Reason> {
+        if let Some(v) = self.0.checked_sub(other as i32) {
+            self.0 = v;
+            Ok(())
+        } else {
+            Err(Reason::FLOW_CONTROL_ERROR)
+        }
+    }
+
+    pub fn increase_by(&mut self, other: WindowSize) -> Result<(), Reason> {
+        let other = self.add(other)?;
+        self.0 = other.0;
+        Ok(())
+    }
+
+    pub fn add(&self, other: WindowSize) -> Result<Self, Reason> {
+        if let Some(v) = self.0.checked_add(other as i32) {
+            Ok(Self(v))
+        } else {
+            Err(Reason::FLOW_CONTROL_ERROR)
+        }
+    }
 }
 
 impl PartialEq<usize> for Window {
@@ -230,25 +256,6 @@ impl PartialOrd<usize> for Window {
     }
 }
 
-impl ::std::ops::SubAssign<WindowSize> for Window {
-    fn sub_assign(&mut self, other: WindowSize) {
-        self.0 -= other as i32;
-    }
-}
-
-impl ::std::ops::Add<WindowSize> for Window {
-    type Output = Self;
-    fn add(self, other: WindowSize) -> Self::Output {
-        Window(self.0 + other as i32)
-    }
-}
-
-impl ::std::ops::AddAssign<WindowSize> for Window {
-    fn add_assign(&mut self, other: WindowSize) {
-        self.0 += other as i32;
-    }
-}
-
 impl fmt::Display for Window {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         fmt::Display::fmt(&self.0, f)
diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs
index 88204ddcc..35795fae4 100644
--- a/src/proto/streams/prioritize.rs
+++ b/src/proto/streams/prioritize.rs
@@ -87,7 +87,9 @@ impl Prioritize {
         flow.inc_window(config.remote_init_window_sz)
             .expect("invalid initial window size");
 
-        flow.assign_capacity(config.remote_init_window_sz);
+        // TODO: proper error handling
+        let _res = flow.assign_capacity(config.remote_init_window_sz);
+        debug_assert!(_res.is_ok());
 
         tracing::trace!("Prioritize::new; flow={:?}", flow);
 
@@ -253,7 +255,9 @@ impl Prioritize {
             if available as usize > capacity {
                 let diff = available - capacity as WindowSize;
 
-                stream.send_flow.claim_capacity(diff);
+                // TODO: proper error handling
+                let _res = stream.send_flow.claim_capacity(diff);
+                debug_assert!(_res.is_ok());
 
                 self.assign_connection_capacity(diff, stream, counts);
             }
@@ -324,7 +328,9 @@ impl Prioritize {
     pub fn reclaim_all_capacity(&mut self, stream: &mut store::Ptr, counts: &mut Counts) {
         let available = stream.send_flow.available().as_size();
         if available > 0 {
-            stream.send_flow.claim_capacity(available);
+            // TODO: proper error handling
+            let _res = stream.send_flow.claim_capacity(available);
+            debug_assert!(_res.is_ok());
             // Re-assign all capacity to the connection
             self.assign_connection_capacity(available, stream, counts);
         }
@@ -337,7 +343,9 @@ impl Prioritize {
         if stream.requested_send_capacity as usize > stream.buffered_send_data {
             let reserved =
                 stream.requested_send_capacity - stream.buffered_send_data as WindowSize;
-            stream.send_flow.claim_capacity(reserved);
+            // TODO: proper error handling
+            let _res = stream.send_flow.claim_capacity(reserved);
+            debug_assert!(_res.is_ok());
             self.assign_connection_capacity(reserved, stream, counts);
         }
     }
@@ -363,7 +371,9 @@ impl Prioritize {
         let span = tracing::trace_span!("assign_connection_capacity", inc);
         let _e = span.enter();
 
-        self.flow.assign_capacity(inc);
+        // TODO: proper error handling
+        let _res = self.flow.assign_capacity(inc);
+        debug_assert!(_res.is_ok());
 
         // Assign newly acquired capacity to streams pending capacity.
         while self.flow.available() > 0 {
@@ -443,7 +453,9 @@ impl Prioritize {
             stream.assign_capacity(assign, self.max_buffer_size);
 
             // Claim the capacity from the connection
-            self.flow.claim_capacity(assign);
+            // TODO: proper error handling
+            let _res = self.flow.claim_capacity(assign);
+            debug_assert!(_res.is_ok());
         }
 
         tracing::trace!(
@@ -763,12 +775,16 @@ impl Prioritize {
                     // Assign the capacity back to the connection that
                     // was just consumed from the stream in the previous
                     // line.
-                    self.flow.assign_capacity(len);
+                    // TODO: proper error handling
+                    let _res = self.flow.assign_capacity(len);
+                    debug_assert!(_res.is_ok());
                 });
 
                 let (eos, len) = tracing::trace_span!("updating connection flow")
                     .in_scope(|| {
-                        self.flow.send_data(len);
+                        // TODO: proper error handling
+                        let _res = self.flow.send_data(len);
+                        debug_assert!(_res.is_ok());
 
                         // Wrap the frame's data payload to ensure that the
                         // correct amount of data gets written.
diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs
index cd96dce2c..1d91b5e58 100644
--- a/src/proto/streams/recv.rs
+++ b/src/proto/streams/recv.rs
@@ -90,7 +90,7 @@ impl Recv {
         // settings
         flow.inc_window(DEFAULT_INITIAL_WINDOW_SIZE)
             .expect("invalid initial remote window size");
-        flow.assign_capacity(DEFAULT_INITIAL_WINDOW_SIZE);
+        flow.assign_capacity(DEFAULT_INITIAL_WINDOW_SIZE).unwrap();
 
         Recv {
             init_window_sz: config.local_init_window_sz,
@@ -354,7 +354,9 @@ impl Recv {
         self.in_flight_data -= capacity;
 
         // Assign capacity to connection
-        self.flow.assign_capacity(capacity);
+        // TODO: proper error handling
+        let _res = self.flow.assign_capacity(capacity);
+        debug_assert!(_res.is_ok());
 
         if self.flow.unclaimed_capacity().is_some() {
             if let Some(task) = task.take() {
@@ -382,7 +384,9 @@ impl Recv {
         stream.in_flight_recv_data -= capacity;
 
         // Assign capacity to stream
-        stream.recv_flow.assign_capacity(capacity);
+        // TODO: proper error handling
+        let _res = stream.recv_flow.assign_capacity(capacity);
+        debug_assert!(_res.is_ok());
 
         if stream.recv_flow.unclaimed_capacity().is_some() {
             // Queue the stream for sending the WINDOW_UPDATE frame.
@@ -428,7 +432,11 @@ impl Recv {
     ///
     /// The `task` is an optional parked task for the `Connection` that might
     /// be blocked on needing more window capacity.
-    pub fn set_target_connection_window(&mut self, target: WindowSize, task: &mut Option<Waker>) {
+    pub fn set_target_connection_window(
+        &mut self,
+        target: WindowSize,
+        task: &mut Option<Waker>,
+    ) -> Result<(), Reason> {
         tracing::trace!(
             "set_target_connection_window; target={}; available={}, reserved={}",
             target,
@@ -441,11 +449,15 @@ impl Recv {
         //
         // Update the flow controller with the difference between the new
         // target and the current target.
-        let current = (self.flow.available() + self.in_flight_data).checked_size();
+        let current = self
+            .flow
+            .available()
+            .add(self.in_flight_data)?
+            .checked_size();
         if target > current {
-            self.flow.assign_capacity(target - current);
+            self.flow.assign_capacity(target - current)?;
         } else {
-            self.flow.claim_capacity(current - target);
+            self.flow.claim_capacity(current - target)?;
         }
 
         // If changing the target capacity means we gained a bunch of capacity,
@@ -456,6 +468,7 @@ impl Recv {
                 task.wake();
             }
         }
+        Ok(())
     }
 
     pub(crate) fn apply_local_settings(
@@ -495,9 +508,13 @@ impl Recv {
                 let dec = old_sz - target;
                 tracing::trace!("decrementing all windows; dec={}", dec);
 
-                store.for_each(|mut stream| {
-                    stream.recv_flow.dec_recv_window(dec);
-                })
+                store.try_for_each(|mut stream| {
+                    stream
+                        .recv_flow
+                        .dec_recv_window(dec)
+                        .map_err(proto::Error::library_go_away)?;
+                    Ok::<_, proto::Error>(())
+                })?;
             }
             Ordering::Greater => {
                 // We must increase the (local) window on every open stream.
@@ -510,7 +527,10 @@ impl Recv {
                         .recv_flow
                         .inc_window(inc)
                         .map_err(proto::Error::library_go_away)?;
-                    stream.recv_flow.assign_capacity(inc);
+                    stream
+                        .recv_flow
+                        .assign_capacity(inc)
+                        .map_err(proto::Error::library_go_away)?;
                     Ok::<_, proto::Error>(())
                 })?;
             }
@@ -617,7 +637,10 @@ impl Recv {
         }
 
         // Update stream level flow control
-        stream.recv_flow.send_data(sz);
+        stream
+            .recv_flow
+            .send_data(sz)
+            .map_err(proto::Error::library_go_away)?;
 
         // Track the data as in-flight
         stream.in_flight_recv_data += sz;
@@ -658,7 +681,7 @@ impl Recv {
         }
 
         // Update connection level flow control
-        self.flow.send_data(sz);
+        self.flow.send_data(sz).map_err(Error::library_go_away)?;
 
         // Track the data as in-flight
         self.in_flight_data += sz;
diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs
index 20aba38d4..dcb5225c7 100644
--- a/src/proto/streams/send.rs
+++ b/src/proto/streams/send.rs
@@ -4,7 +4,7 @@ use super::{
 };
 use crate::codec::UserError;
 use crate::frame::{self, Reason};
-use crate::proto::{Error, Initiator};
+use crate::proto::{self, Error, Initiator};
 
 use bytes::Buf;
 use tokio::io::AsyncWrite;
@@ -458,10 +458,21 @@ impl Send {
         tracing::trace!("decrementing all windows; dec={}", dec);
 
         let mut total_reclaimed = 0;
-        store.for_each(|mut stream| {
+        store.try_for_each(|mut stream| {
             let stream = &mut *stream;
 
-            stream.send_flow.dec_send_window(dec);
+            tracing::trace!(
+                "decrementing stream window; id={:?}; decr={}; flow={:?}",
+                stream.id,
+                dec,
+                stream.send_flow
+            );
+
+            // TODO: this decrement can underflow based on received frames!
+            stream
+                .send_flow
+                .dec_send_window(dec)
+                .map_err(proto::Error::library_go_away)?;
 
             // It's possible that decreasing the window causes
             // `window_size` (the stream-specific window) to fall below
@@ -474,7 +485,10 @@ impl Send {
             let reclaimed = if available > window_size {
                 // Drop down to `window_size`.
                 let reclaim = available - window_size;
-                stream.send_flow.claim_capacity(reclaim);
+                stream
+                    .send_flow
+                    .claim_capacity(reclaim)
+                    .map_err(proto::Error::library_go_away)?;
                 total_reclaimed += reclaim;
                 reclaim
             } else {
@@ -492,7 +506,9 @@ impl Send {
             // TODO: Should this notify the producer when the capacity
             // of a stream is reduced? Maybe it should if the capacity
             // is reduced to zero, allowing the producer to stop work.
-        });
+
+            Ok::<_, proto::Error>(())
+        })?;
 
         self.prioritize
             .assign_connection_capacity(total_reclaimed, store, counts);
diff --git a/src/proto/streams/stream.rs b/src/proto/streams/stream.rs
index 2888d744b..43e313647 100644
--- a/src/proto/streams/stream.rs
+++ b/src/proto/streams/stream.rs
@@ -146,7 +146,9 @@ impl Stream {
         recv_flow
             .inc_window(init_recv_window)
             .expect("invalid initial receive window");
-        recv_flow.assign_capacity(init_recv_window);
+        // TODO: proper error handling?
+        let _res = recv_flow.assign_capacity(init_recv_window);
+        debug_assert!(_res.is_ok());
 
         send_flow
             .inc_window(init_send_window)
@@ -275,7 +277,9 @@ impl Stream {
     pub fn assign_capacity(&mut self, capacity: WindowSize, max_buffer_size: usize) {
         let prev_capacity = self.capacity(max_buffer_size);
         debug_assert!(capacity > 0);
-        self.send_flow.assign_capacity(capacity);
+        // TODO: proper error handling
+        let _res = self.send_flow.assign_capacity(capacity);
+        debug_assert!(_res.is_ok());
 
         tracing::trace!(
             " assigned capacity to stream; available={}; buffered={}; id={:?}; max_buffer_size={} prev={}",
@@ -294,7 +298,9 @@ impl Stream {
     pub fn send_data(&mut self, len: WindowSize, max_buffer_size: usize) {
         let prev_capacity = self.capacity(max_buffer_size);
 
-        self.send_flow.send_data(len);
+        // TODO: proper error handling
+        let _res = self.send_flow.send_data(len);
+        debug_assert!(_res.is_ok());
 
         // Decrement the stream's buffered data counter
         debug_assert!(self.buffered_send_data >= len as usize);
diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs
index d64e00970..c36a40d13 100644
--- a/src/proto/streams/streams.rs
+++ b/src/proto/streams/streams.rs
@@ -118,7 +118,7 @@ where
         }
     }
 
-    pub fn set_target_connection_window_size(&mut self, size: WindowSize) {
+    pub fn set_target_connection_window_size(&mut self, size: WindowSize) -> Result<(), Reason> {
         let mut me = self.inner.lock().unwrap();
         let me = &mut *me;
 
diff --git a/tests/h2-tests/tests/flow_control.rs b/tests/h2-tests/tests/flow_control.rs
index 5caa2ec3a..dbb933286 100644
--- a/tests/h2-tests/tests/flow_control.rs
+++ b/tests/h2-tests/tests/flow_control.rs
@@ -1858,3 +1858,139 @@ async fn poll_capacity_wakeup_after_window_update() {
 
     join(srv, h2).await;
 }
+
+#[tokio::test]
+async fn window_size_decremented_past_zero() {
+    h2_support::trace_init!();
+    let (io, mut client) = mock::new();
+
+    let client = async move {
+        // let _ = client.assert_server_handshake().await;
+
+        // preface
+        client.write_preface().await;
+
+        // the following http 2 bytes are fuzzer-generated
+        client.send_bytes(&[0, 0, 0, 4, 0, 0, 0, 0, 0]).await;
+        client
+            .send_bytes(&[
+                0, 0, 23, 1, 1, 0, 249, 255, 191, 131, 1, 1, 1, 70, 1, 1, 1, 1, 65, 1, 1, 65, 1, 1,
+                65, 1, 1, 1, 1, 1, 1, 190,
+            ])
+            .await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client
+            .send_bytes(&[
+                0, 0, 9, 247, 0, 121, 255, 255, 184, 1, 65, 1, 1, 1, 1, 1, 1, 190,
+            ])
+            .await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client
+            .send_bytes(&[0, 0, 3, 0, 1, 0, 249, 255, 191, 1, 1, 190])
+            .await;
+        client
+            .send_bytes(&[0, 0, 2, 50, 107, 0, 0, 0, 1, 0, 0])
+            .await;
+        client
+            .send_bytes(&[0, 0, 5, 2, 0, 0, 0, 0, 1, 128, 0, 55, 0, 0])
+            .await;
+        client
+            .send_bytes(&[
+                0, 0, 12, 4, 0, 0, 0, 0, 0, 126, 4, 39, 184, 171, 125, 33, 0, 3, 107, 50, 98,
+            ])
+            .await;
+        client
+            .send_bytes(&[0, 0, 6, 4, 0, 0, 0, 0, 0, 3, 4, 76, 255, 71, 131])
+            .await;
+        client
+            .send_bytes(&[
+                0, 0, 12, 4, 0, 0, 0, 0, 0, 0, 4, 39, 184, 171, 74, 33, 0, 3, 107, 50, 98,
+            ])
+            .await;
+        client
+            .send_bytes(&[
+                0, 0, 30, 4, 0, 0, 0, 0, 0, 0, 4, 56, 184, 171, 125, 65, 0, 35, 65, 65, 65, 61,
+                232, 87, 115, 89, 116, 0, 4, 0, 58, 33, 125, 33, 79, 3, 107, 49, 98,
+            ])
+            .await;
+        client
+            .send_bytes(&[
+                0, 0, 12, 4, 0, 0, 0, 0, 0, 0, 4, 39, 184, 171, 125, 33, 0, 3, 107, 50, 98,
+            ])
+            .await;
+        client.send_bytes(&[0, 0, 0, 4, 0, 0, 0, 0, 0]).await;
+        client
+            .send_bytes(&[
+                0, 0, 12, 4, 0, 0, 0, 0, 0, 126, 4, 39, 184, 171, 125, 33, 0, 3, 107, 50, 98,
+            ])
+            .await;
+        client
+            .send_bytes(&[
+                0, 0, 177, 1, 44, 0, 0, 0, 1, 67, 67, 67, 67, 67, 67, 131, 134, 5, 61, 67, 67, 67,
+                67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67,
+                67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67,
+                67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 115, 102, 1, 3, 48, 43,
+                101, 64, 31, 37, 99, 99, 97, 97, 97, 97, 49, 97, 54, 97, 97, 97, 97, 49, 97, 54,
+                97, 99, 54, 53, 53, 51, 53, 99, 99, 97, 97, 99, 97, 97, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                0, 0, 0,
+            ])
+            .await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client
+            .send_bytes(&[
+                0, 0, 12, 4, 0, 0, 0, 0, 0, 0, 4, 0, 58, 171, 125, 33, 79, 3, 107, 49, 98,
+            ])
+            .await;
+        client
+            .send_bytes(&[0, 0, 6, 4, 0, 0, 0, 0, 0, 0, 4, 87, 115, 89, 116])
+            .await;
+        client
+            .send_bytes(&[
+                0, 0, 12, 4, 0, 0, 0, 0, 0, 126, 4, 39, 184, 171, 125, 33, 0, 3, 107, 50, 98,
+            ])
+            .await;
+        client
+            .send_bytes(&[
+                0, 0, 129, 1, 44, 0, 0, 0, 1, 67, 67, 67, 67, 67, 67, 131, 134, 5, 18, 67, 67, 61,
+                67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 48, 54, 53, 55, 114, 1, 4, 97, 49, 51, 116,
+                64, 2, 117, 115, 4, 103, 101, 110, 116, 64, 8, 57, 111, 110, 116, 101, 110, 115,
+                102, 7, 43, 43, 49, 48, 48, 43, 101, 192, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+            ])
+            .await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client.send_bytes(&[0, 0, 0, 0, 0, 0, 0, 0, 1]).await;
+        client
+            .send_bytes(&[
+                0, 0, 12, 4, 0, 0, 0, 0, 0, 0, 4, 0, 58, 171, 125, 33, 79, 3, 107, 49, 98,
+            ])
+            .await;
+
+        // TODO: is CANCEL the right error code to expect here?
+        // client.recv_frame(frames::reset(1).protocol_error()).await;
+    };
+
+    let srv = async move {
+        let builder = server::Builder::new();
+        let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
+
+        // just keep it open
+        let res = poll_fn(move |cx| srv.poll_closed(cx)).await;
+        tracing::debug!("{:?}", res);
+    };
+
+    join(client, srv).await;
+}
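
Note (illustration only, not part of the patch): the core of the change above is replacing the unchecked `SubAssign`/`Add`/`AddAssign` arithmetic on `Window` with checked operations that report `Reason::FLOW_CONTROL_ERROR` instead of underflowing. The standalone sketch below uses simplified stand-ins for `Window` and `Reason` (the real types live in src/proto/streams/flow_control.rs and src/frame) to show the pattern in isolation.

// Sketch only: simplified stand-ins for the crate's `Window` and `Reason` types.
#[derive(Debug, PartialEq)]
struct Reason; // stands in for Reason::FLOW_CONTROL_ERROR

#[derive(Debug)]
struct Window(i32);

impl Window {
    // Checked subtraction: underflowing past i32::MIN becomes an error
    // instead of a debug-build panic or a release-build wrap.
    fn decrease_by(&mut self, other: u32) -> Result<(), Reason> {
        match self.0.checked_sub(other as i32) {
            Some(v) => {
                self.0 = v;
                Ok(())
            }
            None => Err(Reason),
        }
    }

    // Checked addition: growing past i32::MAX is likewise reported as an error.
    fn increase_by(&mut self, other: u32) -> Result<(), Reason> {
        match self.0.checked_add(other as i32) {
            Some(v) => {
                self.0 = v;
                Ok(())
            }
            None => Err(Reason),
        }
    }
}

fn main() {
    let mut w = Window(10);
    assert!(w.decrease_by(3).is_ok()); // 10 -> 7
    assert!(w.decrease_by(20).is_ok()); // 7 -> -13: a negative window is still representable
    w.0 = i32::MIN + 1;
    assert!(w.decrease_by(2).is_err()); // would pass i32::MIN: now an error, not a panic
    w.0 = i32::MAX;
    assert!(w.increase_by(1).is_err()); // would pass i32::MAX: also an error
}

With the old operator impls the same arithmetic would panic in a debug build or wrap in a release build, which is what the fuzzer-generated input in the new `window_size_decremented_past_zero` test provokes.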