From c8f5ff867d2fee553c4c0c6401aae170be3ece19 Mon Sep 17 00:00:00 2001
From: Mario Pastorelli
Date: Wed, 5 Mar 2025 00:03:04 +0100
Subject: [PATCH] Add `std::io::Seek` instance for `std::io::Take`

---
 library/std/src/io/mod.rs   |  60 +++++++++++++++-
 library/std/src/io/tests.rs | 120 ++++++++++++++++++++++++++++++++++++
 library/std/src/lib.rs      |   1 +
 3 files changed, 180 insertions(+), 1 deletion(-)

diff --git a/library/std/src/io/mod.rs b/library/std/src/io/mod.rs
index 96fac4f6bde68..03f5f838311a9 100644
--- a/library/std/src/io/mod.rs
+++ b/library/std/src/io/mod.rs
@@ -1214,7 +1214,7 @@ pub trait Read {
     where
         Self: Sized,
     {
-        Take { inner: self, limit }
+        Take { inner: self, len: limit, limit }
     }
 }
 
@@ -2830,6 +2830,7 @@ impl<T, U> SizeHint for Chain<T, U> {
 #[derive(Debug)]
 pub struct Take<T> {
     inner: T,
+    len: u64,
     limit: u64,
 }
 
@@ -2864,6 +2865,12 @@ impl<T> Take<T> {
         self.limit
     }
 
+    /// Returns the number of bytes read so far.
+    #[unstable(feature = "seek_io_take_position", issue = "97227")]
+    pub fn position(&self) -> u64 {
+        self.len - self.limit
+    }
+
     /// Sets the number of bytes that can be read before this instance will
     /// return EOF. This is the same as constructing a new `Take` instance, so
     /// the amount of bytes read and the previous limit value don't matter when
@@ -2889,6 +2896,7 @@ impl<T> Take<T> {
     /// ```
     #[stable(feature = "take_set_limit", since = "1.27.0")]
     pub fn set_limit(&mut self, limit: u64) {
+        self.len = limit;
         self.limit = limit;
     }
 
@@ -3076,6 +3084,56 @@ impl<T> SizeHint for Take<T> {
     }
 }
 
+#[stable(feature = "seek_io_take", since = "CURRENT_RUSTC_VERSION")]
+impl<T: Seek> Seek for Take<T> {
+    fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
+        let new_position = match pos {
+            SeekFrom::Start(v) => Some(v),
+            SeekFrom::Current(v) => self.position().checked_add_signed(v),
+            SeekFrom::End(v) => self.len.checked_add_signed(v),
+        };
+        // Seeking is confined to `0..=len`; a target outside that window is
+        // rejected before the inner reader is touched.
+        let new_position = match new_position {
+            Some(v) if v <= self.len => v,
+            _ => return Err(ErrorKind::InvalidInput.into()),
+        };
+        // The distance to the target may not fit in an `i64`, so the inner
+        // reader is moved in steps of `i64::MAX`/`i64::MIN` until the remaining
+        // distance is representable as a single relative seek.
+        while new_position != self.position() {
+            if let Some(offset) = new_position.checked_signed_diff(self.position()) {
+                self.inner.seek_relative(offset)?;
+                // `offset as u64` keeps the two's-complement bit pattern, so the
+                // wrapping subtraction grows `limit` again for negative offsets.
+                self.limit = self.limit.wrapping_sub(offset as u64);
+                break;
+            }
+            let offset = if new_position > self.position() { i64::MAX } else { i64::MIN };
+            self.inner.seek_relative(offset)?;
+            self.limit = self.limit.wrapping_sub(offset as u64);
+        }
+        Ok(new_position)
+    }
+
+    fn stream_len(&mut self) -> Result<u64> {
+        Ok(self.len)
+    }
+
+    fn stream_position(&mut self) -> Result<u64> {
+        Ok(self.position())
+    }
+
+    fn seek_relative(&mut self, offset: i64) -> Result<()> {
+        if !self.position().checked_add_signed(offset).is_some_and(|p| p <= self.len) {
+            return Err(ErrorKind::InvalidInput.into());
+        }
+        self.inner.seek_relative(offset)?;
+        self.limit = self.limit.wrapping_sub(offset as u64);
+        Ok(())
+    }
+}
+
 /// An iterator over `u8` values of a reader.
 ///
 /// This struct is generally created by calling [`bytes`] on a reader.
diff --git a/library/std/src/io/tests.rs b/library/std/src/io/tests.rs
index fd962b0415c7d..b22988d4a8a9d 100644
--- a/library/std/src/io/tests.rs
+++ b/library/std/src/io/tests.rs
@@ -416,6 +416,126 @@ fn seek_position() -> io::Result<()> {
     Ok(())
 }
 
+#[test]
+fn take_seek() -> io::Result<()> {
+    let mut buf = Cursor::new(b"0123456789");
+    buf.set_position(2);
+    let mut take = buf.by_ref().take(4);
+    let mut buf1 = [0u8; 1];
+    let mut buf2 = [0u8; 2];
+    assert_eq!(take.position(), 0);
+
+    assert_eq!(take.seek(SeekFrom::Start(0))?, 0);
+    take.read_exact(&mut buf2)?;
+    assert_eq!(buf2, [b'2', b'3']);
+    assert_eq!(take.seek(SeekFrom::Start(1))?, 1);
+    take.read_exact(&mut buf2)?;
+    assert_eq!(buf2, [b'3', b'4']);
+    assert_eq!(take.seek(SeekFrom::Start(2))?, 2);
+    take.read_exact(&mut buf2)?;
+    assert_eq!(buf2, [b'4', b'5']);
+    assert_eq!(take.seek(SeekFrom::Start(3))?, 3);
+    take.read_exact(&mut buf1)?;
+    assert_eq!(buf1, [b'5']);
+    assert_eq!(take.seek(SeekFrom::Start(4))?, 4);
+    assert_eq!(take.read(&mut buf1)?, 0);
+
+    assert_eq!(take.seek(SeekFrom::End(0))?, 4);
+    assert_eq!(take.seek(SeekFrom::End(-1))?, 3);
+    take.read_exact(&mut buf1)?;
+    assert_eq!(buf1, [b'5']);
+    assert_eq!(take.seek(SeekFrom::End(-2))?, 2);
+    take.read_exact(&mut buf2)?;
+    assert_eq!(buf2, [b'4', b'5']);
+    assert_eq!(take.seek(SeekFrom::End(-3))?, 1);
+    take.read_exact(&mut buf2)?;
+    assert_eq!(buf2, [b'3', b'4']);
+    assert_eq!(take.seek(SeekFrom::End(-4))?, 0);
+    take.read_exact(&mut buf2)?;
+    assert_eq!(buf2, [b'2', b'3']);
+
+    assert_eq!(take.seek(SeekFrom::Current(0))?, 2);
+    take.read_exact(&mut buf2)?;
+    assert_eq!(buf2, [b'4', b'5']);
+
+    assert_eq!(take.seek(SeekFrom::Current(-3))?, 1);
+    take.read_exact(&mut buf2)?;
+    assert_eq!(buf2, [b'3', b'4']);
+
+    assert_eq!(take.seek(SeekFrom::Current(-1))?, 2);
+    take.read_exact(&mut buf2)?;
+    assert_eq!(buf2, [b'4', b'5']);
+
+    assert_eq!(take.seek(SeekFrom::Current(-4))?, 0);
+    take.read_exact(&mut buf2)?;
+    assert_eq!(buf2, [b'2', b'3']);
+
+    assert_eq!(take.seek(SeekFrom::Current(2))?, 4);
+    assert_eq!(take.read(&mut buf1)?, 0);
+
+    Ok(())
+}
+
+#[test]
+fn take_seek_error() {
+    let buf = Cursor::new(b"0123456789");
+    let mut take = buf.take(2);
+    assert!(take.seek(SeekFrom::Start(3)).is_err());
+    assert!(take.seek(SeekFrom::End(1)).is_err());
+    assert!(take.seek(SeekFrom::End(-3)).is_err());
+    assert!(take.seek(SeekFrom::Current(-1)).is_err());
+    assert!(take.seek(SeekFrom::Current(3)).is_err());
+}
+
+struct ExampleHugeRangeOfZeroes {
+    position: u64,
+}
+
+impl Read for ExampleHugeRangeOfZeroes {
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+        let max = buf.len();
+        for i in 0..max {
+            if self.position == u64::MAX {
+                return Ok(i);
+            }
+            self.position += 1;
+            buf[i] = 0;
+        }
+        Ok(max)
+    }
+}
+
+impl Seek for ExampleHugeRangeOfZeroes {
+    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
+        match pos {
+            io::SeekFrom::Start(i) => self.position = i,
+            io::SeekFrom::End(i) if i >= 0 => self.position = u64::MAX,
+            io::SeekFrom::End(i) => self.position = u64::MAX - i.unsigned_abs(),
+            io::SeekFrom::Current(i) => {
+                self.position = if i >= 0 {
+                    self.position.saturating_add(i.unsigned_abs())
+                } else {
+                    self.position.saturating_sub(i.unsigned_abs())
+                };
+            }
+        }
+        Ok(self.position)
+    }
+}
+
+#[test]
+fn take_seek_big_offsets() -> io::Result<()> {
+    let inner = ExampleHugeRangeOfZeroes { position: 1 };
+    let mut take = inner.take(u64::MAX - 2);
+    assert_eq!(take.seek(io::SeekFrom::Start(u64::MAX - 2))?, u64::MAX - 2);
+    assert_eq!(take.inner.position, u64::MAX - 1);
+    assert_eq!(take.seek(io::SeekFrom::Start(0))?, 0);
+    assert_eq!(take.inner.position, 1);
+    assert_eq!(take.seek(io::SeekFrom::End(-1))?, u64::MAX - 3);
+    assert_eq!(take.inner.position, u64::MAX - 2);
+    Ok(())
+}
+
 // A simple example reader which uses the default implementation of
 // read_to_end.
 struct ExampleSliceReader<'a> {
diff --git a/library/std/src/lib.rs b/library/std/src/lib.rs
index ca04a381271b2..ef41b47384d61 100644
--- a/library/std/src/lib.rs
+++ b/library/std/src/lib.rs
@@ -325,6 +325,7 @@
 #![feature(try_blocks)]
 #![feature(try_trait_v2)]
 #![feature(type_alias_impl_trait)]
+#![feature(unsigned_signed_diff)]
 // tidy-alphabetical-end
 //
 // Library features (core):