From 1c966e7f15a46cc015e7acfde4650b45fee55168 Mon Sep 17 00:00:00 2001 From: Scott McMurray Date: Thu, 24 Nov 2022 03:12:05 -0800 Subject: [PATCH 1/3] Extract the logic for `TrustedLen` to a named method that can be called directly --- library/alloc/src/vec/mod.rs | 35 ++++++++++++++++++++++++++++ library/alloc/src/vec/spec_extend.rs | 34 ++------------------------- 2 files changed, 37 insertions(+), 32 deletions(-) diff --git a/library/alloc/src/vec/mod.rs b/library/alloc/src/vec/mod.rs index 766006939fa48..9da728b34c3f0 100644 --- a/library/alloc/src/vec/mod.rs +++ b/library/alloc/src/vec/mod.rs @@ -2870,6 +2870,41 @@ impl Vec { } } + // specific extend for `TrustedLen` iterators, called both by the specializations + // and internal places where resolving specialization makes compilation slower + #[cfg(not(no_global_oom_handling))] + fn extend_trusted(&mut self, iterator: impl iter::TrustedLen) { + let (low, high) = iterator.size_hint(); + if let Some(additional) = high { + debug_assert_eq!( + low, + additional, + "TrustedLen iterator's size hint is not exact: {:?}", + (low, high) + ); + self.reserve(additional); + unsafe { + let mut ptr = self.as_mut_ptr().add(self.len()); + let mut local_len = SetLenOnDrop::new(&mut self.len); + iterator.for_each(move |element| { + ptr::write(ptr, element); + ptr = ptr.add(1); + // Since the loop executes user code which can panic we have to bump the pointer + // after each step. + // NB can't overflow since we would have had to alloc the address space + local_len.increment_len(1); + }); + } + } else { + // Per TrustedLen contract a `None` upper bound means that the iterator length + // truly exceeds usize::MAX, which would eventually lead to a capacity overflow anyway. + // Since the other branch already panics eagerly (via `reserve()`) we do the same here. + // This avoids additional codegen for a fallback code path which would eventually + // panic anyway. 
+ panic!("capacity overflow"); + } + } + /// Creates a splicing iterator that replaces the specified range in the vector /// with the given `replace_with` iterator and yields the removed items. /// `replace_with` does not need to be the same length as `range`. diff --git a/library/alloc/src/vec/spec_extend.rs b/library/alloc/src/vec/spec_extend.rs index 1ea9c827afd70..56065ce565bfc 100644 --- a/library/alloc/src/vec/spec_extend.rs +++ b/library/alloc/src/vec/spec_extend.rs @@ -1,9 +1,8 @@ use crate::alloc::Allocator; use core::iter::TrustedLen; -use core::ptr::{self}; use core::slice::{self}; -use super::{IntoIter, SetLenOnDrop, Vec}; +use super::{IntoIter, Vec}; // Specialization trait used for Vec::extend pub(super) trait SpecExtend { @@ -24,36 +23,7 @@ where I: TrustedLen, { default fn spec_extend(&mut self, iterator: I) { - // This is the case for a TrustedLen iterator. - let (low, high) = iterator.size_hint(); - if let Some(additional) = high { - debug_assert_eq!( - low, - additional, - "TrustedLen iterator's size hint is not exact: {:?}", - (low, high) - ); - self.reserve(additional); - unsafe { - let mut ptr = self.as_mut_ptr().add(self.len()); - let mut local_len = SetLenOnDrop::new(&mut self.len); - iterator.for_each(move |element| { - ptr::write(ptr, element); - ptr = ptr.add(1); - // Since the loop executes user code which can panic we have to bump the pointer - // after each step. - // NB can't overflow since we would have had to alloc the address space - local_len.increment_len(1); - }); - } - } else { - // Per TrustedLen contract a `None` upper bound means that the iterator length - // truly exceeds usize::MAX, which would eventually lead to a capacity overflow anyway. - // Since the other branch already panics eagerly (via `reserve()`) we do the same here. - // This avoids additional codegen for a fallback code path which would eventually - // panic anyway. 
- panic!("capacity overflow"); - } + self.extend_trusted(iterator) } } From a8954f1f6a2245794ff41db9b1cba33a76e7d5f9 Mon Sep 17 00:00:00 2001 From: Scott McMurray Date: Thu, 24 Nov 2022 03:12:54 -0800 Subject: [PATCH 2/3] Stop peeling the last iteration of the loop in `Vec::resize_with` --- library/alloc/src/vec/mod.rs | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/library/alloc/src/vec/mod.rs b/library/alloc/src/vec/mod.rs index 9da728b34c3f0..94ebcb863a421 100644 --- a/library/alloc/src/vec/mod.rs +++ b/library/alloc/src/vec/mod.rs @@ -2163,7 +2163,7 @@ impl Vec { { let len = self.len(); if new_len > len { - self.extend_with(new_len - len, ExtendFunc(f)); + self.extend_trusted(iter::repeat_with(f).take(new_len - len)); } else { self.truncate(new_len); } @@ -2491,16 +2491,6 @@ impl ExtendWith for ExtendElement { } } -struct ExtendFunc(F); -impl T> ExtendWith for ExtendFunc { - fn next(&mut self) -> T { - (self.0)() - } - fn last(mut self) -> T { - (self.0)() - } -} - impl Vec { #[cfg(not(no_global_oom_handling))] /// Extend the vector by `n` values, using the given generator. 
From 9d68a1a74c65245c9ae7b5db2c552c995697e8ef Mon Sep 17 00:00:00 2001 From: Scott McMurray Date: Thu, 24 Nov 2022 19:14:19 -0800 Subject: [PATCH 3/3] Tune RepeatWith::try_fold and Take::for_each and Vec::extend_trusted --- library/alloc/src/vec/mod.rs | 9 ++++----- library/alloc/src/vec/set_len_on_drop.rs | 5 +++++ library/core/src/iter/adapters/take.rs | 21 +++++++++++++++++++- library/core/src/iter/sources/repeat_with.rs | 17 ++++++++++++++++ library/core/tests/iter/adapters/take.rs | 20 +++++++++++++++++++ src/test/codegen/repeat-trusted-len.rs | 7 +++++++ 6 files changed, 73 insertions(+), 6 deletions(-) diff --git a/library/alloc/src/vec/mod.rs b/library/alloc/src/vec/mod.rs index 94ebcb863a421..e147af2ce39c6 100644 --- a/library/alloc/src/vec/mod.rs +++ b/library/alloc/src/vec/mod.rs @@ -2874,13 +2874,12 @@ impl Vec { ); self.reserve(additional); unsafe { - let mut ptr = self.as_mut_ptr().add(self.len()); + let ptr = self.as_mut_ptr(); let mut local_len = SetLenOnDrop::new(&mut self.len); iterator.for_each(move |element| { - ptr::write(ptr, element); - ptr = ptr.add(1); - // Since the loop executes user code which can panic we have to bump the pointer - // after each step. + ptr::write(ptr.add(local_len.current_len()), element); + // Since the loop executes user code which can panic we have to update + // the length every step to correctly drop what we've written. 
// NB can't overflow since we would have had to alloc the address space local_len.increment_len(1); }); diff --git a/library/alloc/src/vec/set_len_on_drop.rs b/library/alloc/src/vec/set_len_on_drop.rs index 8b66bc8121296..6ce5a3a9f54eb 100644 --- a/library/alloc/src/vec/set_len_on_drop.rs +++ b/library/alloc/src/vec/set_len_on_drop.rs @@ -18,6 +18,11 @@ impl<'a> SetLenOnDrop<'a> { pub(super) fn increment_len(&mut self, increment: usize) { self.local_len += increment; } + + #[inline] + pub(super) fn current_len(&self) -> usize { + self.local_len + } } impl Drop for SetLenOnDrop<'_> { diff --git a/library/core/src/iter/adapters/take.rs b/library/core/src/iter/adapters/take.rs index 58a0b9d7bbe99..d947c7b0e3013 100644 --- a/library/core/src/iter/adapters/take.rs +++ b/library/core/src/iter/adapters/take.rs @@ -75,7 +75,6 @@ where #[inline] fn try_fold(&mut self, init: Acc, fold: Fold) -> R where - Self: Sized, Fold: FnMut(Acc, Self::Item) -> R, R: Try, { @@ -100,6 +99,26 @@ where impl_fold_via_try_fold! { fold -> try_fold } + #[inline] + fn for_each(mut self, f: F) { + // The default implementation would use a unit accumulator, so we can + // avoid a stateful closure by folding over the remaining number + // of items we wish to return instead. 
+ fn check<'a, Item>( + mut action: impl FnMut(Item) + 'a, + ) -> impl FnMut(usize, Item) -> Option + 'a { + move |more, x| { + action(x); + more.checked_sub(1) + } + } + + let remaining = self.n; + if remaining > 0 { + self.iter.try_fold(remaining - 1, check(f)); + } + } + #[inline] #[rustc_inherit_overflow_checks] fn advance_by(&mut self, n: usize) -> Result<(), usize> { diff --git a/library/core/src/iter/sources/repeat_with.rs b/library/core/src/iter/sources/repeat_with.rs index 6f62662d88066..ab2d0472b4701 100644 --- a/library/core/src/iter/sources/repeat_with.rs +++ b/library/core/src/iter/sources/repeat_with.rs @@ -1,4 +1,5 @@ use crate::iter::{FusedIterator, TrustedLen}; +use crate::ops::Try; /// Creates a new iterator that repeats elements of type `A` endlessly by /// applying the provided closure, the repeater, `F: FnMut() -> A`. @@ -89,6 +90,22 @@ impl A> Iterator for RepeatWith { fn size_hint(&self) -> (usize, Option) { (usize::MAX, None) } + + #[inline] + fn try_fold(&mut self, mut init: Acc, mut fold: Fold) -> R + where + Fold: FnMut(Acc, Self::Item) -> R, + R: Try, + { + // This override isn't strictly needed, but avoids the need to optimize + // away the `next`-always-returns-`Some` and emphasizes that the `?` + // is the only way to exit the loop. 
+ + loop { + let item = (self.repeater)(); + init = fold(init, item)?; + } + } } #[stable(feature = "iterator_repeat_with", since = "1.28.0")] diff --git a/library/core/tests/iter/adapters/take.rs b/library/core/tests/iter/adapters/take.rs index bfb659f0a8378..3e26b43a2ede5 100644 --- a/library/core/tests/iter/adapters/take.rs +++ b/library/core/tests/iter/adapters/take.rs @@ -146,3 +146,23 @@ fn test_take_try_folds() { assert_eq!(iter.try_for_each(Err), Err(2)); assert_eq!(iter.try_for_each(Err), Ok(())); } + +#[test] +fn test_byref_take_consumed_items() { + let mut inner = 10..90; + + let mut count = 0; + inner.by_ref().take(0).for_each(|_| count += 1); + assert_eq!(count, 0); + assert_eq!(inner, 10..90); + + let mut count = 0; + inner.by_ref().take(10).for_each(|_| count += 1); + assert_eq!(count, 10); + assert_eq!(inner, 20..90); + + let mut count = 0; + inner.by_ref().take(100).for_each(|_| count += 1); + assert_eq!(count, 70); + assert_eq!(inner, 90..90); +} diff --git a/src/test/codegen/repeat-trusted-len.rs b/src/test/codegen/repeat-trusted-len.rs index 7aebd3ec7df0a..87c8fe1354d76 100644 --- a/src/test/codegen/repeat-trusted-len.rs +++ b/src/test/codegen/repeat-trusted-len.rs @@ -11,3 +11,10 @@ pub fn repeat_take_collect() -> Vec { // CHECK: call void @llvm.memset.{{.+}}({{i8\*|ptr}} {{.*}}align 1{{.*}} %{{[0-9]+}}, i8 42, i{{[0-9]+}} 100000, i1 false) iter::repeat(42).take(100000).collect() } + +// CHECK-LABEL: @repeat_with_take_collect +#[no_mangle] +pub fn repeat_with_take_collect() -> Vec { +// CHECK: call void @llvm.memset.{{.+}}({{i8\*|ptr}} {{.*}}align 1{{.*}} %{{[0-9]+}}, i8 13, i{{[0-9]+}} 12345, i1 false) + iter::repeat_with(|| 13).take(12345).collect() +}