Skip to content
This repository was archived by the owner on Nov 15, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion frame/bags-list/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,19 @@ impl<T: Config<I>, I: 'static> Pallet<T, I> {

impl<T: Config<I>, I: 'static> SortedListProvider<T::AccountId> for Pallet<T, I> {
type Error = ListError;

type Score = T::Score;

fn iter() -> Box<dyn Iterator<Item = T::AccountId>> {
Box::new(List::<T, I>::iter().map(|n| n.id().clone()))
}

fn iter_from(
start: &T::AccountId,
) -> Result<Box<dyn Iterator<Item = T::AccountId>>, Self::Error> {
let iter = List::<T, I>::iter_from(start)?;
Ok(Box::new(iter.map(|n| n.id().clone())))
}

fn count() -> u32 {
ListNodes::<T, I>::count()
}
Expand Down
29 changes: 29 additions & 0 deletions frame/bags-list/src/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ use sp_std::{
pub enum ListError {
/// A duplicate id has been detected.
Duplicate,
/// Given node id was not found.
NodeNotFound,
}

#[cfg(test)]
Expand Down Expand Up @@ -244,6 +246,33 @@ impl<T: Config<I>, I: 'static> List<T, I> {
iter.filter_map(Bag::get).flat_map(|bag| bag.iter())
}

/// Same as `iter`, but we start from a specific node.
///
/// All items after this node are returned, excluding `start` itself.
pub(crate) fn iter_from(
start: &T::AccountId,
) -> Result<impl Iterator<Item = Node<T, I>>, ListError> {
// We chain two iterators:
// 1. from the given `start` till the end of the bag
// 2. all the bags that come after `start`'s bag.

let start_node = Node::<T, I>::get(start).ok_or(ListError::NodeNotFound)?;
let start_node_upper = start_node.bag_upper;
let start_bag = sp_std::iter::successors(start_node.next(), |prev| prev.next());

let thresholds = T::BagThresholds::get();
let idx = thresholds.partition_point(|&threshold| start_node_upper > threshold);
let leftover_bags = thresholds
.into_iter()
.take(idx)
.copied()
.rev()
.filter_map(Bag::get)
.flat_map(|bag| bag.iter());

Ok(start_bag.chain(leftover_bags))
}

/// Insert several ids into the appropriate bags in the list. Continues with insertions
/// if duplicates are detected.
///
Expand Down
21 changes: 21 additions & 0 deletions frame/bags-list/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,27 @@ mod sorted_list_provider {
});
}

#[test]
fn iter_from_works() {
ExtBuilder::default().add_ids(vec![(5, 5), (6, 15)]).build_and_execute(|| {
// given
assert_eq!(
List::<Runtime>::get_bags(),
vec![(10, vec![1, 5]), (20, vec![6]), (1000, vec![2, 3, 4])]
);

assert_eq!(BagsList::iter_from(&2).unwrap().collect::<Vec<_>>(), vec![3, 4, 6, 1, 5]);
assert_eq!(BagsList::iter_from(&3).unwrap().collect::<Vec<_>>(), vec![4, 6, 1, 5]);
assert_eq!(BagsList::iter_from(&4).unwrap().collect::<Vec<_>>(), vec![6, 1, 5]);
assert_eq!(BagsList::iter_from(&6).unwrap().collect::<Vec<_>>(), vec![1, 5]);
assert_eq!(BagsList::iter_from(&1).unwrap().collect::<Vec<_>>(), vec![5]);
assert!(BagsList::iter_from(&5).unwrap().collect::<Vec<_>>().is_empty());
assert!(BagsList::iter_from(&7).is_err());

assert_storage_noop!(assert!(BagsList::iter_from(&8).is_err()));
});
}

#[test]
fn count_works() {
ExtBuilder::default().build_and_execute(|| {
Expand Down
5 changes: 5 additions & 0 deletions frame/election-provider-support/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,11 @@ pub trait SortedListProvider<AccountId> {
/// An iterator over the list, which can have `take` called on it.
fn iter() -> Box<dyn Iterator<Item = AccountId>>;

/// Returns an iterator over the list, starting right after from the given voter.
///
/// May return an error if `start` is invalid.
fn iter_from(start: &AccountId) -> Result<Box<dyn Iterator<Item = AccountId>>, Self::Error>;

/// The current count of ids in the list.
fn count() -> u32;

Expand Down
18 changes: 17 additions & 1 deletion frame/staking/src/pallet/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1308,14 +1308,30 @@ impl<T: Config> SortedListProvider<T::AccountId> for UseNominatorsAndValidatorsM
type Error = ();
type Score = VoteWeight;

/// Returns iterator over voter list, which can have `take` called on it.
fn iter() -> Box<dyn Iterator<Item = T::AccountId>> {
Box::new(
Validators::<T>::iter()
.map(|(v, _)| v)
.chain(Nominators::<T>::iter().map(|(n, _)| n)),
)
}
fn iter_from(
start: &T::AccountId,
) -> Result<Box<dyn Iterator<Item = T::AccountId>>, Self::Error> {
if Validators::<T>::contains_key(start) {
let start_key = Validators::<T>::hashed_key_for(start);
Ok(Box::new(
Validators::<T>::iter_from(start_key)
.map(|(n, _)| n)
.chain(Nominators::<T>::iter().map(|(x, _)| x)),
))
} else if Nominators::<T>::contains_key(start) {
let start_key = Nominators::<T>::hashed_key_for(start);
Ok(Box::new(Nominators::<T>::iter_from(start_key).map(|(n, _)| n)))
} else {
Err(())
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you maybe also add a minimal test here?

fn count() -> u32 {
Nominators::<T>::count().saturating_add(Validators::<T>::count())
}
Expand Down