From 06a2e721096becc8dedf412a9b39713f0d158eaa Mon Sep 17 00:00:00 2001 From: tison Date: Thu, 16 Oct 2025 23:05:12 +0800 Subject: [PATCH] Implement VecDeque::extract_if Signed-off-by: tison --- .../src/collections/vec_deque/extract_if.rs | 149 ++++++++++ .../alloc/src/collections/vec_deque/mod.rs | 94 ++++++ .../alloc/src/collections/vec_deque/tests.rs | 272 +++++++++++++++++- 3 files changed, 514 insertions(+), 1 deletion(-) create mode 100644 library/alloc/src/collections/vec_deque/extract_if.rs diff --git a/library/alloc/src/collections/vec_deque/extract_if.rs b/library/alloc/src/collections/vec_deque/extract_if.rs new file mode 100644 index 0000000000000..bed7d46482cf4 --- /dev/null +++ b/library/alloc/src/collections/vec_deque/extract_if.rs @@ -0,0 +1,149 @@ +use core::ops::{Range, RangeBounds}; +use core::{fmt, ptr, slice}; + +use super::VecDeque; +use crate::alloc::{Allocator, Global}; + +/// An iterator which uses a closure to determine if an element should be removed. +/// +/// This struct is created by [`VecDeque::extract_if`]. +/// See its documentation for more. +/// +/// # Example +/// +/// ``` +/// #![feature(vec_deque_extract_if)] +/// +/// use std::collections::vec_deque::ExtractIf; +/// use std::collections::vec_deque::VecDeque; +/// +/// let mut v = VecDeque::from([0, 1, 2]); +/// let iter: ExtractIf<'_, _, _> = v.extract_if(.., |x| *x % 2 == 0); +/// ``` +#[unstable(feature = "vec_deque_extract_if", issue = "147750")] +#[must_use = "iterators are lazy and do nothing unless consumed"] +pub struct ExtractIf< + 'a, + T, + F, + #[unstable(feature = "allocator_api", issue = "32838")] A: Allocator = Global, +> { + vec: &'a mut VecDeque, + /// The index of the item that will be inspected by the next call to `next`. + idx: usize, + /// Elements at and beyond this point will be retained. Must be equal or smaller than `old_len`. + end: usize, + /// The number of items that have been drained (removed) thus far. 
+ del: usize, + /// The original length of `vec` prior to draining. + old_len: usize, + /// The filter test predicate. + pred: F, +} + +impl<'a, T, F, A: Allocator> ExtractIf<'a, T, F, A> { + pub(super) fn new>( + vec: &'a mut VecDeque, + pred: F, + range: R, + ) -> Self { + let old_len = vec.len(); + let Range { start, end } = slice::range(range, ..old_len); + + // Guard against the deque getting leaked (leak amplification) + vec.len = 0; + ExtractIf { vec, idx: start, del: 0, end, old_len, pred } + } + + /// Returns a reference to the underlying allocator. + #[unstable(feature = "allocator_api", issue = "32838")] + #[inline] + pub fn allocator(&self) -> &A { + self.vec.allocator() + } +} + +#[unstable(feature = "vec_deque_extract_if", issue = "147750")] +impl Iterator for ExtractIf<'_, T, F, A> +where + F: FnMut(&mut T) -> bool, +{ + type Item = T; + + fn next(&mut self) -> Option { + while self.idx < self.end { + let i = self.idx; + // SAFETY: + // We know that `i < self.end` from the loop guard and that `self.end <= self.old_len` + // from the validity of `Self`. Therefore `i` points to an element within `vec`. + // + // Additionally, the i-th element is valid because each element is visited at most once + // and it is the first time we access vec[i]. + // + // Note: we can't use `vec.get_mut(i).unwrap()` here since the precondition for that + // function is that i < vec.len, but we've set vec's length to zero. + let idx = self.vec.to_physical_idx(i); + let cur = unsafe { &mut *self.vec.ptr().add(idx) }; + let drained = (self.pred)(cur); + // Update the index *after* the predicate is called. If the index + // is updated prior and the predicate panics, the element at this + // index would be leaked. + self.idx += 1; + if drained { + self.del += 1; + // SAFETY: We never touch this element again after returning it. 
+ return Some(unsafe { ptr::read(cur) }); + } else if self.del > 0 { + let hole_slot = self.vec.to_physical_idx(i - self.del); + // SAFETY: `self.del` > 0, so the hole slot must not overlap with current element. + // We use copy for move, and never touch this element again. + unsafe { self.vec.wrap_copy(idx, hole_slot, 1) }; + } + } + None + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.end - self.idx)) + } +} + +#[unstable(feature = "vec_deque_extract_if", issue = "147750")] +impl Drop for ExtractIf<'_, T, F, A> { + fn drop(&mut self) { + if self.del > 0 { + let src = self.vec.to_physical_idx(self.idx); + let dst = self.vec.to_physical_idx(self.idx - self.del); + let len = self.old_len - self.idx; + // SAFETY: Trailing unchecked items must be valid since we never touch them. + unsafe { self.vec.wrap_copy(src, dst, len) }; + } + self.vec.len = self.old_len - self.del; + } +} + +#[unstable(feature = "vec_deque_extract_if", issue = "147750")] +impl fmt::Debug for ExtractIf<'_, T, F, A> +where + T: fmt::Debug, + A: Allocator, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let peek = if self.idx < self.end { + let idx = self.vec.to_physical_idx(self.idx); + // This has to use pointer arithmetic as `self.vec[self.idx]` or + // `self.vec.get_unchecked(self.idx)` wouldn't work since we + // temporarily set the length of `self.vec` to zero. + // + // SAFETY: + // Since `self.idx` is smaller than `self.end` and `self.end` is + // no greater than `self.old_len`, `idx` is valid for indexing the + // buffer. Also, per the invariant of `self.idx`, this element + // has not been inspected/moved out yet. 
+ Some(unsafe { &*self.vec.ptr().add(idx) }) + } else { + None + }; + f.debug_struct("ExtractIf").field("peek", &peek).finish_non_exhaustive() + } +} diff --git a/library/alloc/src/collections/vec_deque/mod.rs b/library/alloc/src/collections/vec_deque/mod.rs index ac619a42d356d..3bdb7415f0c2a 100644 --- a/library/alloc/src/collections/vec_deque/mod.rs +++ b/library/alloc/src/collections/vec_deque/mod.rs @@ -32,6 +32,11 @@ pub use self::drain::Drain; mod drain; +#[unstable(feature = "vec_deque_extract_if", issue = "147750")] +pub use self::extract_if::ExtractIf; + +mod extract_if; + #[stable(feature = "rust1", since = "1.0.0")] pub use self::iter_mut::IterMut; @@ -542,6 +547,95 @@ impl VecDeque { } debug_assert!(self.head < self.capacity() || self.capacity() == 0); } + + /// Creates an iterator which uses a closure to determine if an element in the range should be removed. + /// + /// If the closure returns `true`, the element is removed from the deque and yielded. If the closure + /// returns `false`, or panics, the element remains in the deque and will not be yielded. + /// + /// Only elements that fall in the provided range are considered for extraction, but any elements + /// after the range will still have to be moved if any element has been extracted. + /// + /// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating + /// or the iteration short-circuits, then the remaining elements will be retained. + /// Use [`retain_mut`] with a negated predicate if you do not need the returned iterator. 
+ /// + /// [`retain_mut`]: VecDeque::retain_mut + /// + /// Using this method is equivalent to the following code: + /// + /// ``` + /// #![feature(vec_deque_extract_if)] + /// # use std::collections::VecDeque; + /// # let some_predicate = |x: &mut i32| { *x % 2 == 1 }; + /// # let mut deq: VecDeque<_> = (0..10).collect(); + /// # let mut deq2 = deq.clone(); + /// # let range = 1..5; + /// let mut i = range.start; + /// let end_items = deq.len() - range.end; + /// # let mut extracted = vec![]; + /// + /// while i < deq.len() - end_items { + /// if some_predicate(&mut deq[i]) { + /// let val = deq.remove(i).unwrap(); + /// // your code here + /// # extracted.push(val); + /// } else { + /// i += 1; + /// } + /// } + /// + /// # let extracted2: Vec<_> = deq2.extract_if(range, some_predicate).collect(); + /// # assert_eq!(deq, deq2); + /// # assert_eq!(extracted, extracted2); + /// ``` + /// + /// But `extract_if` is easier to use. `extract_if` is also more efficient, + /// because it can backshift the elements of the deque in bulk. + /// + /// The iterator also lets you mutate the value of each element in the + /// closure, regardless of whether you choose to keep or remove it. + /// + /// # Panics + /// + /// If `range` is out of bounds. 
+ /// + /// # Examples + /// + /// Splitting a deque into even and odd values, reusing the original deque: + /// + /// ``` + /// #![feature(vec_deque_extract_if)] + /// use std::collections::VecDeque; + /// + /// let mut numbers = VecDeque::from([1, 2, 3, 4, 5, 6, 8, 9, 11, 13, 14, 15]); + /// + /// let evens = numbers.extract_if(.., |x| *x % 2 == 0).collect::>(); + /// let odds = numbers; + /// + /// assert_eq!(evens, VecDeque::from([2, 4, 6, 8, 14])); + /// assert_eq!(odds, VecDeque::from([1, 3, 5, 9, 11, 13, 15])); + /// ``` + /// + /// Using the range argument to only process a part of the deque: + /// + /// ``` + /// #![feature(vec_deque_extract_if)] + /// use std::collections::VecDeque; + /// + /// let mut items = VecDeque::from([0, 0, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 2]); + /// let ones = items.extract_if(7.., |x| *x == 1).collect::>(); + /// assert_eq!(items, VecDeque::from([0, 0, 0, 0, 0, 0, 0, 2, 2, 2])); + /// assert_eq!(ones.len(), 3); + /// ``` + #[unstable(feature = "vec_deque_extract_if", issue = "147750")] + pub fn extract_if(&mut self, range: R, filter: F) -> ExtractIf<'_, T, F, A> + where + F: FnMut(&mut T) -> bool, + R: RangeBounds, + { + ExtractIf::new(self, filter, range) + } } impl VecDeque { diff --git a/library/alloc/src/collections/vec_deque/tests.rs b/library/alloc/src/collections/vec_deque/tests.rs index 2501534e95080..dc50cc34d9dac 100644 --- a/library/alloc/src/collections/vec_deque/tests.rs +++ b/library/alloc/src/collections/vec_deque/tests.rs @@ -1,6 +1,8 @@ -use core::iter::TrustedLen; +use std::iter::TrustedLen; +use std::panic::{AssertUnwindSafe, catch_unwind}; use super::*; +use crate::testing::crash_test::{CrashTestDummy, Panic}; use crate::testing::macros::struct_with_counted_drop; #[bench] @@ -1161,3 +1163,271 @@ fn issue_80303() { assert_eq!(vda, vdb); assert_eq!(hash_code(vda), hash_code(vdb)); } + +#[test] +fn extract_if_test() { + let mut m: VecDeque = VecDeque::from([1, 2, 3, 4, 5, 6]); + let deleted = m.extract_if(.., |v| 
*v < 4).collect::>(); + + assert_eq!(deleted, &[1, 2, 3]); + assert_eq!(m, &[4, 5, 6]); +} + +#[test] +fn drain_to_empty_test() { + let mut m: VecDeque = VecDeque::from([1, 2, 3, 4, 5, 6]); + let deleted = m.extract_if(.., |_| true).collect::>(); + + assert_eq!(deleted, &[1, 2, 3, 4, 5, 6]); + assert_eq!(m, &[]); +} + +#[test] +fn extract_if_empty() { + let mut list: VecDeque = VecDeque::new(); + + { + let mut iter = list.extract_if(.., |_| true); + assert_eq!(iter.size_hint(), (0, Some(0))); + assert_eq!(iter.next(), None); + assert_eq!(iter.size_hint(), (0, Some(0))); + assert_eq!(iter.next(), None); + assert_eq!(iter.size_hint(), (0, Some(0))); + } + + assert_eq!(list.len(), 0); + assert_eq!(list, vec![]); +} + +#[test] +fn extract_if_zst() { + let mut list: VecDeque<_> = [(), (), (), (), ()].into_iter().collect(); + let initial_len = list.len(); + let mut count = 0; + + { + let mut iter = list.extract_if(.., |_| true); + assert_eq!(iter.size_hint(), (0, Some(initial_len))); + while let Some(_) = iter.next() { + count += 1; + assert_eq!(iter.size_hint(), (0, Some(initial_len - count))); + } + assert_eq!(iter.size_hint(), (0, Some(0))); + assert_eq!(iter.next(), None); + assert_eq!(iter.size_hint(), (0, Some(0))); + } + + assert_eq!(count, initial_len); + assert_eq!(list.len(), 0); + assert_eq!(list, vec![]); +} + +#[test] +fn extract_if_false() { + let mut list: VecDeque<_> = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_iter().collect(); + + let initial_len = list.len(); + let mut count = 0; + + { + let mut iter = list.extract_if(.., |_| false); + assert_eq!(iter.size_hint(), (0, Some(initial_len))); + for _ in iter.by_ref() { + count += 1; + } + assert_eq!(iter.size_hint(), (0, Some(0))); + assert_eq!(iter.next(), None); + assert_eq!(iter.size_hint(), (0, Some(0))); + } + + assert_eq!(count, 0); + assert_eq!(list.len(), initial_len); + assert_eq!(list, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); +} + +#[test] +fn extract_if_true() { + let mut list: VecDeque<_> = [1, 2, 3, 4, 
5, 6, 7, 8, 9, 10].into_iter().collect(); + + let initial_len = list.len(); + let mut count = 0; + + { + let mut iter = list.extract_if(.., |_| true); + assert_eq!(iter.size_hint(), (0, Some(initial_len))); + while let Some(_) = iter.next() { + count += 1; + assert_eq!(iter.size_hint(), (0, Some(initial_len - count))); + } + assert_eq!(iter.size_hint(), (0, Some(0))); + assert_eq!(iter.next(), None); + assert_eq!(iter.size_hint(), (0, Some(0))); + } + + assert_eq!(count, initial_len); + assert_eq!(list.len(), 0); + assert_eq!(list, vec![]); +} + +#[test] +fn extract_if_non_contiguous() { + let mut list = + [1, 2, 4, 6, 7, 9, 11, 13, 15, 17, 18, 20, 22, 24, 26, 27, 29, 31, 33, 34, 35, 36, 37, 39] + .into_iter() + .collect::>(); + list.rotate_left(3); + + assert!(!list.is_contiguous()); + assert_eq!( + list, + [6, 7, 9, 11, 13, 15, 17, 18, 20, 22, 24, 26, 27, 29, 31, 33, 34, 35, 36, 37, 39, 1, 2, 4] + ); + + let removed = list.extract_if(.., |x| *x % 2 == 0).collect::>(); + assert_eq!(removed.len(), 10); + assert_eq!(removed, vec![6, 18, 20, 22, 24, 26, 34, 36, 2, 4]); + + assert_eq!(list.len(), 14); + assert_eq!(list, vec![7, 9, 11, 13, 15, 17, 27, 29, 31, 33, 35, 37, 39, 1]); +} + +#[test] +fn extract_if_complex() { + { + // [+xxx++++++xxxxx++++x+x++] + let mut list = [ + 1, 2, 4, 6, 7, 9, 11, 13, 15, 17, 18, 20, 22, 24, 26, 27, 29, 31, 33, 34, 35, 36, 37, + 39, + ] + .into_iter() + .collect::>(); + + let removed = list.extract_if(.., |x| *x % 2 == 0).collect::>(); + assert_eq!(removed.len(), 10); + assert_eq!(removed, vec![2, 4, 6, 18, 20, 22, 24, 26, 34, 36]); + + assert_eq!(list.len(), 14); + assert_eq!(list, vec![1, 7, 9, 11, 13, 15, 17, 27, 29, 31, 33, 35, 37, 39]); + } + + { + // [xxx++++++xxxxx++++x+x++] + let mut list = + [2, 4, 6, 7, 9, 11, 13, 15, 17, 18, 20, 22, 24, 26, 27, 29, 31, 33, 34, 35, 36, 37, 39] + .into_iter() + .collect::>(); + + let removed = list.extract_if(.., |x| *x % 2 == 0).collect::>(); + assert_eq!(removed.len(), 10); + 
assert_eq!(removed, vec![2, 4, 6, 18, 20, 22, 24, 26, 34, 36]); + + assert_eq!(list.len(), 13); + assert_eq!(list, vec![7, 9, 11, 13, 15, 17, 27, 29, 31, 33, 35, 37, 39]); + } + + { + // [xxx++++++xxxxx++++x+x] + let mut list = + [2, 4, 6, 7, 9, 11, 13, 15, 17, 18, 20, 22, 24, 26, 27, 29, 31, 33, 34, 35, 36] + .into_iter() + .collect::>(); + + let removed = list.extract_if(.., |x| *x % 2 == 0).collect::>(); + assert_eq!(removed.len(), 10); + assert_eq!(removed, vec![2, 4, 6, 18, 20, 22, 24, 26, 34, 36]); + + assert_eq!(list.len(), 11); + assert_eq!(list, vec![7, 9, 11, 13, 15, 17, 27, 29, 31, 33, 35]); + } + + { + // [xxxxxxxxxx+++++++++++] + let mut list = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19] + .into_iter() + .collect::>(); + + let removed = list.extract_if(.., |x| *x % 2 == 0).collect::>(); + assert_eq!(removed.len(), 10); + assert_eq!(removed, vec![2, 4, 6, 8, 10, 12, 14, 16, 18, 20]); + + assert_eq!(list.len(), 10); + assert_eq!(list, vec![1, 3, 5, 7, 9, 11, 13, 15, 17, 19]); + } + + { + // [+++++++++++xxxxxxxxxx] + let mut list = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20] + .into_iter() + .collect::>(); + + let removed = list.extract_if(.., |x| *x % 2 == 0).collect::>(); + assert_eq!(removed.len(), 10); + assert_eq!(removed, vec![2, 4, 6, 8, 10, 12, 14, 16, 18, 20]); + + assert_eq!(list.len(), 10); + assert_eq!(list, vec![1, 3, 5, 7, 9, 11, 13, 15, 17, 19]); + } +} + +#[test] +#[cfg_attr(not(panic = "unwind"), ignore = "test requires unwinding support")] +fn extract_if_drop_panic_leak() { + let d0 = CrashTestDummy::new(0); + let d1 = CrashTestDummy::new(1); + let d2 = CrashTestDummy::new(2); + let d3 = CrashTestDummy::new(3); + let d4 = CrashTestDummy::new(4); + let d5 = CrashTestDummy::new(5); + let d6 = CrashTestDummy::new(6); + let d7 = CrashTestDummy::new(7); + let mut q = VecDeque::new(); + q.push_back(d3.spawn(Panic::Never)); + q.push_back(d4.spawn(Panic::Never)); + 
q.push_back(d5.spawn(Panic::Never)); + q.push_back(d6.spawn(Panic::Never)); + q.push_back(d7.spawn(Panic::Never)); + q.push_front(d2.spawn(Panic::Never)); + q.push_front(d1.spawn(Panic::InDrop)); + q.push_front(d0.spawn(Panic::Never)); + + catch_unwind(AssertUnwindSafe(|| q.extract_if(.., |_| true).for_each(drop))).unwrap_err(); + + assert_eq!(d0.dropped(), 1); + assert_eq!(d1.dropped(), 1); + assert_eq!(d2.dropped(), 0); + assert_eq!(d3.dropped(), 0); + assert_eq!(d4.dropped(), 0); + assert_eq!(d5.dropped(), 0); + assert_eq!(d6.dropped(), 0); + assert_eq!(d7.dropped(), 0); + drop(q); + assert_eq!(d2.dropped(), 1); + assert_eq!(d3.dropped(), 1); + assert_eq!(d4.dropped(), 1); + assert_eq!(d5.dropped(), 1); + assert_eq!(d6.dropped(), 1); + assert_eq!(d7.dropped(), 1); +} + +#[test] +#[cfg_attr(not(panic = "unwind"), ignore = "test requires unwinding support")] +fn extract_if_pred_panic_leak() { + struct_with_counted_drop!(D(u32), DROPS); + + let mut q = VecDeque::new(); + q.push_back(D(3)); + q.push_back(D(4)); + q.push_back(D(5)); + q.push_back(D(6)); + q.push_back(D(7)); + q.push_front(D(2)); + q.push_front(D(1)); + q.push_front(D(0)); + + _ = catch_unwind(AssertUnwindSafe(|| { + q.extract_if(.., |item| if item.0 >= 2 { panic!() } else { true }).for_each(drop) + })); + + assert_eq!(DROPS.get(), 2); // 0 and 1 + assert_eq!(q.len(), 6); +}