diff --git a/frame/bags-list/src/lib.rs b/frame/bags-list/src/lib.rs index aa9f1c80dfdbc..94553433e230d 100644 --- a/frame/bags-list/src/lib.rs +++ b/frame/bags-list/src/lib.rs @@ -271,13 +271,19 @@ impl, I: 'static> Pallet { impl, I: 'static> SortedListProvider for Pallet { type Error = ListError; - type Score = T::Score; fn iter() -> Box> { Box::new(List::::iter().map(|n| n.id().clone())) } + fn iter_from( + start: &T::AccountId, + ) -> Result>, Self::Error> { + let iter = List::::iter_from(start)?; + Ok(Box::new(iter.map(|n| n.id().clone()))) + } + fn count() -> u32 { ListNodes::::count() } diff --git a/frame/bags-list/src/list/mod.rs b/frame/bags-list/src/list/mod.rs index 4e1287458bcb4..db8c06a38d674 100644 --- a/frame/bags-list/src/list/mod.rs +++ b/frame/bags-list/src/list/mod.rs @@ -42,6 +42,8 @@ use sp_std::{ pub enum ListError { /// A duplicate id has been detected. Duplicate, + /// Given node id was not found. + NodeNotFound, } #[cfg(test)] @@ -244,6 +246,33 @@ impl, I: 'static> List { iter.filter_map(Bag::get).flat_map(|bag| bag.iter()) } + /// Same as `iter`, but we start from a specific node. + /// + /// All items after this node are returned, excluding `start` itself. + pub(crate) fn iter_from( + start: &T::AccountId, + ) -> Result>, ListError> { + // We chain two iterators: + // 1. from the given `start` till the end of the bag + // 2. all the bags that come after `start`'s bag. + + let start_node = Node::::get(start).ok_or(ListError::NodeNotFound)?; + let start_node_upper = start_node.bag_upper; + let start_bag = sp_std::iter::successors(start_node.next(), |prev| prev.next()); + + let thresholds = T::BagThresholds::get(); + let idx = thresholds.partition_point(|&threshold| start_node_upper > threshold); + let leftover_bags = thresholds + .into_iter() + .take(idx) + .copied() + .rev() + .filter_map(Bag::get) + .flat_map(|bag| bag.iter()); + + Ok(start_bag.chain(leftover_bags)) + } + /// Insert several ids into the appropriate bags in the list. Continues with insertions /// if duplicates are detected. /// diff --git a/frame/bags-list/src/tests.rs b/frame/bags-list/src/tests.rs index 0d6ba4721b9a2..941623229dc27 100644 --- a/frame/bags-list/src/tests.rs +++ b/frame/bags-list/src/tests.rs @@ -458,6 +458,27 @@ mod sorted_list_provider { }); } + #[test] + fn iter_from_works() { + ExtBuilder::default().add_ids(vec![(5, 5), (6, 15)]).build_and_execute(|| { + // given + assert_eq!( + List::::get_bags(), + vec![(10, vec![1, 5]), (20, vec![6]), (1000, vec![2, 3, 4])] + ); + + assert_eq!(BagsList::iter_from(&2).unwrap().collect::>(), vec![3, 4, 6, 1, 5]); + assert_eq!(BagsList::iter_from(&3).unwrap().collect::>(), vec![4, 6, 1, 5]); + assert_eq!(BagsList::iter_from(&4).unwrap().collect::>(), vec![6, 1, 5]); + assert_eq!(BagsList::iter_from(&6).unwrap().collect::>(), vec![1, 5]); + assert_eq!(BagsList::iter_from(&1).unwrap().collect::>(), vec![5]); + assert!(BagsList::iter_from(&5).unwrap().collect::>().is_empty()); + assert!(BagsList::iter_from(&7).is_err()); + + assert_storage_noop!(assert!(BagsList::iter_from(&8).is_err())); + }); + } + #[test] fn count_works() { ExtBuilder::default().build_and_execute(|| { diff --git a/frame/election-provider-support/src/lib.rs b/frame/election-provider-support/src/lib.rs index 453cef8956fe5..19735cf6035ac 100644 --- a/frame/election-provider-support/src/lib.rs +++ b/frame/election-provider-support/src/lib.rs @@ -441,6 +441,11 @@ pub trait SortedListProvider { /// An iterator over the list, which can have `take` called on it. fn iter() -> Box>; + /// Returns an iterator over the list, starting right after from the given voter. + /// + /// May return an error if `start` is invalid. + fn iter_from(start: &AccountId) -> Result>, Self::Error>; + /// The current count of ids in the list. fn count() -> u32; diff --git a/frame/staking/src/pallet/impls.rs b/frame/staking/src/pallet/impls.rs index 9d5a3ed484184..90f19c6badd8f 100644 --- a/frame/staking/src/pallet/impls.rs +++ b/frame/staking/src/pallet/impls.rs @@ -1308,7 +1308,6 @@ impl SortedListProvider for UseNominatorsAndValidatorsM type Error = (); type Score = VoteWeight; - /// Returns iterator over voter list, which can have `take` called on it. fn iter() -> Box> { Box::new( Validators::::iter() @@ -1316,6 +1315,23 @@ impl SortedListProvider for UseNominatorsAndValidatorsM .chain(Nominators::::iter().map(|(n, _)| n)), ) } + fn iter_from( + start: &T::AccountId, + ) -> Result>, Self::Error> { + if Validators::::contains_key(start) { + let start_key = Validators::::hashed_key_for(start); + Ok(Box::new( + Validators::::iter_from(start_key) + .map(|(n, _)| n) + .chain(Nominators::::iter().map(|(x, _)| x)), + )) + } else if Nominators::::contains_key(start) { + let start_key = Nominators::::hashed_key_for(start); + Ok(Box::new(Nominators::::iter_from(start_key).map(|(n, _)| n))) + } else { + Err(()) + } + } fn count() -> u32 { Nominators::::count().saturating_add(Validators::::count()) }