diff --git a/examples/forward_simulation.rs b/examples/forward_simulation.rs index 03574ef3e..20d999269 100644 --- a/examples/forward_simulation.rs +++ b/examples/forward_simulation.rs @@ -23,6 +23,8 @@ use rand::rngs::StdRng; use rand::Rng; use rand::SeedableRng; use rand_distr::{Exp, Uniform}; +use tskit::TableAccess; +use tskit::TskitTypeAccess; struct SimParams { pub popsize: u32, @@ -318,7 +320,29 @@ fn births( } } -fn simplify(alive: &mut [Diploid], tables: &mut tskit::TableCollection) { +fn rotate_edge_table(mid: usize, tables: &mut tskit::TableCollection) { + // NOTE: using unsafe here because we don't have + // a rust API yet. + let num_edges: usize = tables.edges().num_rows().try_into().unwrap(); + let parent = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.parent, num_edges) }; + let child = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.child, num_edges) }; + let left = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.left, num_edges) }; + let right = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.right, num_edges) }; + parent.rotate_left(mid); + child.rotate_left(mid); + left.rotate_left(mid); + right.rotate_left(mid); +} + +fn simplify( + bookmark: &tskit::types::Bookmark, + alive: &mut [Diploid], + tables: &mut tskit::TableCollection, +) { let mut samples = vec![]; for a in alive.iter() { assert!(a.node0 != a.node1); @@ -326,11 +350,16 @@ fn simplify(alive: &mut [Diploid], tables: &mut tskit::TableCollection) { samples.push(a.node1); } - match tables.full_sort(tskit::TableSortOptions::default()) { + match tables.sort(bookmark, tskit::TableSortOptions::default()) { Ok(_) => (), Err(e) => panic!("{}", e), } + if bookmark.offsets.edges > 0 { + let mid: usize = bookmark.offsets.edges.try_into().unwrap(); + rotate_edge_table(mid, tables); + } + match tables.simplify( &samples, tskit::SimplificationOptions::KEEP_INPUT_ROOTS, @@ -351,6 +380,69 @@ fn simplify(alive: &mut [Diploid], tables: &mut tskit::TableCollection) { }; } +fn update_bookmark( + alive: &[Diploid], + tables: &mut tskit::TableCollection, + bookmark: &mut tskit::types::Bookmark, +) -> Result<(), tskit::TskitError> { + // get min/max time of alive nodes + let mut most_recent_birth_time: f64 = f64::MAX; + let mut most_ancient_birth_time: f64 = f64::MIN; + + { + let nodes = tables.nodes(); + for a in alive { + for node in [a.node0, a.node1] { + match nodes.time(node) { + Ok(time) => { + most_recent_birth_time = if time < most_recent_birth_time { + time.into() + } else { + most_recent_birth_time + }; + most_ancient_birth_time = if time > most_ancient_birth_time { + time.into() + } else { + most_ancient_birth_time + }; + } + Err(e) => return Err(e), + } + } + } + } + + // All alive nodes born at same time. + if most_ancient_birth_time == most_recent_birth_time { + bookmark.offsets.edges = tables.edges().num_rows().into(); + } else { + // We have non-overlapping generations: + // * Find the last node born at <= the max time + // * Rotate the edge table there + // * Set the bookmark to include the rotated nodes + + // NOTE: we dip into unsafe here because we + // don't yet have direct API support for these ops. + + let num_nodes: usize = tables.nodes().num_rows().try_into().unwrap(); + let time = unsafe { std::slice::from_raw_parts((*tables.as_ptr()).nodes.time, num_nodes) }; + match time + .iter() + .enumerate() + .find(|(_index, time)| **time > most_ancient_birth_time) + { + Some((index, _time)) => { + rotate_edge_table(index, tables); + let num_edges: usize = tables.edges().num_rows().try_into().unwrap(); + bookmark.offsets.edges = (num_edges - index).try_into().unwrap(); + } + None => bookmark.offsets.edges = 0, + } + } + + Ok(()) +} + fn runsim(params: &SimParams) -> tskit::TableCollection { let mut tables = match tskit::TableCollection::new(params.genome_length) { Ok(x) => x, @@ -385,6 +477,7 @@ fn runsim(params: &SimParams) -> tskit::TableCollection { let mut parents: Vec = vec![]; let mut simplified: bool = false; + let mut bookmark = tskit::types::Bookmark::new(); for step in (0..params.nsteps).rev() { parents.clear(); death_and_parents(&alive, params, &mut parents, &mut rng); @@ -392,15 +485,16 @@ fn runsim(params: &SimParams) -> tskit::TableCollection { let remainder = step % params.simplification_interval; match step < params.nsteps && remainder == 0 { true => { - simplify(&mut alive, &mut tables); + simplify(&bookmark, &mut alive, &mut tables); simplified = true; + update_bookmark(&alive, &mut tables, &mut bookmark).unwrap(); } false => simplified = false, } } if !simplified { - simplify(&mut alive, &mut tables); + simplify(&bookmark, &mut alive, &mut tables); } tables @@ -430,7 +524,7 @@ fn test_bad_genome_length() { #[test] fn test_nonoverlapping_generations() { let mut params = SimParams::new(); - params.nsteps = 100; + params.nsteps = 500; params.xovers = 1e-3; params.validate().unwrap(); runsim(¶ms);