Skip to content

Commit ffde091

Browse files
authored
perf: use Bookmark in forward simulation example (#318)
1 parent fa79e15 commit ffde091

File tree

1 file changed

+99
-5
lines changed

1 file changed

+99
-5
lines changed

examples/forward_simulation.rs

Lines changed: 99 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ use rand::rngs::StdRng;
2323
use rand::Rng;
2424
use rand::SeedableRng;
2525
use rand_distr::{Exp, Uniform};
26+
use tskit::TableAccess;
27+
use tskit::TskitTypeAccess;
2628

2729
struct SimParams {
2830
pub popsize: u32,
@@ -318,19 +320,46 @@ fn births(
318320
}
319321
}
320322

321-
fn simplify(alive: &mut [Diploid], tables: &mut tskit::TableCollection) {
323+
fn rotate_edge_table(mid: usize, tables: &mut tskit::TableCollection) {
324+
// NOTE: using unsafe here because we don't have
325+
// a rust API yet.
326+
let num_edges: usize = tables.edges().num_rows().try_into().unwrap();
327+
let parent =
328+
unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.parent, num_edges) };
329+
let child =
330+
unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.child, num_edges) };
331+
let left =
332+
unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.left, num_edges) };
333+
let right =
334+
unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.right, num_edges) };
335+
parent.rotate_left(mid);
336+
child.rotate_left(mid);
337+
left.rotate_left(mid);
338+
right.rotate_left(mid);
339+
}
340+
341+
fn simplify(
342+
bookmark: &tskit::types::Bookmark,
343+
alive: &mut [Diploid],
344+
tables: &mut tskit::TableCollection,
345+
) {
322346
let mut samples = vec![];
323347
for a in alive.iter() {
324348
assert!(a.node0 != a.node1);
325349
samples.push(a.node0);
326350
samples.push(a.node1);
327351
}
328352

329-
match tables.full_sort(tskit::TableSortOptions::default()) {
353+
match tables.sort(bookmark, tskit::TableSortOptions::default()) {
330354
Ok(_) => (),
331355
Err(e) => panic!("{}", e),
332356
}
333357

358+
if bookmark.offsets.edges > 0 {
359+
let mid: usize = bookmark.offsets.edges.try_into().unwrap();
360+
rotate_edge_table(mid, tables);
361+
}
362+
334363
match tables.simplify(
335364
&samples,
336365
tskit::SimplificationOptions::KEEP_INPUT_ROOTS,
@@ -351,6 +380,69 @@ fn simplify(alive: &mut [Diploid], tables: &mut tskit::TableCollection) {
351380
};
352381
}
353382

383+
fn update_bookmark(
384+
alive: &[Diploid],
385+
tables: &mut tskit::TableCollection,
386+
bookmark: &mut tskit::types::Bookmark,
387+
) -> Result<(), tskit::TskitError> {
388+
// get min/max time of alive nodes
389+
let mut most_recent_birth_time: f64 = f64::MAX;
390+
let mut most_ancient_birth_time: f64 = f64::MIN;
391+
392+
{
393+
let nodes = tables.nodes();
394+
for a in alive {
395+
for node in [a.node0, a.node1] {
396+
match nodes.time(node) {
397+
Ok(time) => {
398+
most_recent_birth_time = if time < most_recent_birth_time {
399+
time.into()
400+
} else {
401+
most_recent_birth_time
402+
};
403+
most_ancient_birth_time = if time > most_ancient_birth_time {
404+
time.into()
405+
} else {
406+
most_ancient_birth_time
407+
};
408+
}
409+
Err(e) => return Err(e),
410+
}
411+
}
412+
}
413+
}
414+
415+
// All alive nodes born at same time.
416+
if most_ancient_birth_time == most_recent_birth_time {
417+
bookmark.offsets.edges = tables.edges().num_rows().into();
418+
} else {
419+
// We have non-overlapping generations:
420+
// * Find the last node born at <= the max time
421+
// * Rotate the edge table there
422+
// * Set the bookmark to include the rotated nodes
423+
424+
// NOTE: we dip into unsafe here because we
425+
// don't yet have direct API support for these ops.
426+
427+
let num_nodes: usize = tables.nodes().num_rows().try_into().unwrap();
428+
let time = unsafe { std::slice::from_raw_parts((*tables.as_ptr()).nodes.time, num_nodes) };
429+
match time
430+
.iter()
431+
.enumerate()
432+
.find(|(_index, time)| **time > most_ancient_birth_time)
433+
{
434+
Some((index, _time)) => {
435+
rotate_edge_table(index, tables);
436+
let num_edges: usize = tables.edges().num_rows().try_into().unwrap();
437+
bookmark.offsets.edges = (num_edges - index).try_into().unwrap();
438+
}
439+
None => bookmark.offsets.edges = 0,
440+
}
441+
}
442+
443+
Ok(())
444+
}
445+
354446
fn runsim(params: &SimParams) -> tskit::TableCollection {
355447
let mut tables = match tskit::TableCollection::new(params.genome_length) {
356448
Ok(x) => x,
@@ -385,22 +477,24 @@ fn runsim(params: &SimParams) -> tskit::TableCollection {
385477
let mut parents: Vec<Parents> = vec![];
386478
let mut simplified: bool = false;
387479

480+
let mut bookmark = tskit::types::Bookmark::new();
388481
for step in (0..params.nsteps).rev() {
389482
parents.clear();
390483
death_and_parents(&alive, params, &mut parents, &mut rng);
391484
births(&parents, params, step, &mut tables, &mut alive, &mut rng);
392485
let remainder = step % params.simplification_interval;
393486
match step < params.nsteps && remainder == 0 {
394487
true => {
395-
simplify(&mut alive, &mut tables);
488+
simplify(&bookmark, &mut alive, &mut tables);
396489
simplified = true;
490+
update_bookmark(&alive, &mut tables, &mut bookmark).unwrap();
397491
}
398492
false => simplified = false,
399493
}
400494
}
401495

402496
if !simplified {
403-
simplify(&mut alive, &mut tables);
497+
simplify(&bookmark, &mut alive, &mut tables);
404498
}
405499

406500
tables
@@ -430,7 +524,7 @@ fn test_bad_genome_length() {
430524
#[test]
431525
fn test_nonoverlapping_generations() {
432526
let mut params = SimParams::new();
433-
params.nsteps = 100;
527+
params.nsteps = 500;
434528
params.xovers = 1e-3;
435529
params.validate().unwrap();
436530
runsim(&params);

0 commit comments

Comments
 (0)