diff --git a/examples/haploid_wright_fisher.rs b/examples/haploid_wright_fisher.rs index 1d10b7fa8..27053a20d 100644 --- a/examples/haploid_wright_fisher.rs +++ b/examples/haploid_wright_fisher.rs @@ -8,12 +8,30 @@ use proptest::prelude::*; use rand::distributions::Distribution; use rand::SeedableRng; +fn rotate_edges(bookmark: &tskit::types::Bookmark, tables: &mut tskit::TableCollection) { + let num_edges = tables.edges().num_rows().as_usize(); + let left = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.left, num_edges) }; + let right = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.right, num_edges) }; + let parent = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.parent, num_edges) }; + let child = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.child, num_edges) }; + let mid = bookmark.edges().as_usize(); + left.rotate_left(mid); + right.rotate_left(mid); + parent.rotate_left(mid); + child.rotate_left(mid); +} + // ANCHOR: haploid_wright_fisher fn simulate( seed: u64, popsize: usize, num_generations: i32, simplify_interval: i32, + update_bookmark: bool, ) -> Result { if popsize == 0 { return Err(anyhow::Error::msg("popsize must be > 0")); @@ -46,6 +64,7 @@ fn simulate( let parent_picker = rand::distributions::Uniform::new(0, popsize); let breakpoint_generator = rand::distributions::Uniform::new(0.0, 1.0); let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + let mut bookmark = tskit::types::Bookmark::new(); for birth_time in (0..num_generations).rev() { for c in children.iter_mut() { @@ -64,7 +83,10 @@ fn simulate( } if birth_time % simplify_interval == 0 { - tables.full_sort(tskit::TableSortOptions::default())?; + tables.sort(&bookmark, tskit::TableSortOptions::default())?; + if update_bookmark { + rotate_edges(&bookmark, &mut tables); + } if let Some(idmap) = tables.simplify(children, tskit::SimplificationOptions::default(), true)? { @@ -73,6 +95,9 @@ fn simulate( *o = idmap[usize::try_from(*o)?]; } } + if update_bookmark { + bookmark.set_edges(tables.edges().num_rows()); + } } std::mem::swap(&mut parents, &mut children); } @@ -91,6 +116,8 @@ struct SimParams { num_generations: i32, simplify_interval: i32, treefile: Option, + #[clap(short, long, help = "Use bookmark to avoid sorting entire edge table.")] + bookmark: bool, } fn main() -> Result<()> { @@ -100,6 +127,7 @@ fn main() -> Result<()> { params.popsize, params.num_generations, params.simplify_interval, + params.bookmark, )?; if let Some(treefile) = ¶ms.treefile { @@ -114,8 +142,9 @@ proptest! { #[test] fn test_simulate_proptest(seed in any::(), num_generations in 50..100i32, - simplify_interval in 1..100i32) { - let ts = simulate(seed, 100, num_generations, simplify_interval).unwrap(); + simplify_interval in 1..100i32, + bookmark in proptest::bool::ANY) { + let ts = simulate(seed, 100, num_generations, simplify_interval, bookmark).unwrap(); // stress test the branch length fn b/c it is not a trivial // wrapper around the C API.