Skip to content

refactor tree iteration #296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/tree_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::TskitError;
use crate::TskitTypeAccess;
use std::ptr::NonNull;

#[derive(Debug)]
pub struct TreeInterface {
non_owned_pointer: NonNull<ll_bindings::tsk_tree_t>,
num_nodes: tsk_size_t,
Expand Down
148 changes: 142 additions & 6 deletions src/trees.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::ops::DerefMut;
use std::ops::{Deref, DerefMut};

use crate::bindings as ll_bindings;
use crate::error::TskitError;
Expand All @@ -22,7 +21,6 @@ use crate::TreeSequenceFlags;
use crate::TskReturnValue;
use crate::TskitTypeAccess;
use crate::{tsk_id_t, tsk_size_t, TableCollection};
use ll_bindings::tsk_tree_free;
use std::ptr::NonNull;

/// A Tree.
Expand All @@ -35,20 +33,120 @@ pub struct Tree {
advanced: bool,
}

impl Drop for Tree {
fn drop(&mut self) {
let rv = unsafe { tsk_tree_free(&mut self.inner) };
pub struct TreeIterator {
tree: ll_bindings::tsk_tree_t,
current_tree: i32,
advanced: bool,
num_nodes: tsk_size_t,
array_len: tsk_size_t,
flags: TreeFlags,
}

impl TreeIterator {
// FIXME: init if fallible!
fn new(treeseq: &TreeSequence, num_nodes: u64) -> Self {
let mut tree = MaybeUninit::<ll_bindings::tsk_tree_t>::uninit();
let rv = unsafe { ll_bindings::tsk_tree_init(tree.as_mut_ptr(), treeseq.as_ptr(), 0) };
assert_eq!(rv, 0);
let tree = unsafe { tree.assume_init() };

Self {
tree,
current_tree: -1,
advanced: false,
num_nodes,
array_len: num_nodes + 1,
flags: 0.into(),
}
}
fn item(&mut self) -> NonOwningTree {
NonOwningTree::new(
NonNull::from(&mut self.tree),
self.num_nodes,
self.array_len,
self.flags,
)
}
}

#[derive(Debug)]
#[repr(transparent)]
pub struct NonOwningTree {
api: TreeInterface,
}

impl Deref for NonOwningTree {
type Target = TreeInterface;

fn deref(&self) -> &Self::Target {
&self.api
}
}

impl NonOwningTree {
fn new(
tree: NonNull<ll_bindings::tsk_tree_t>,
num_nodes: tsk_size_t,
array_len: tsk_size_t,
flags: TreeFlags,
) -> Self {
let api = TreeInterface::new(tree, num_nodes, array_len, flags);
Self { api }
}
}

impl Iterator for TreeIterator {
type Item = NonOwningTree;

fn next(&mut self) -> Option<Self::Item> {
let rv = if self.current_tree == 0 {
unsafe { ll_bindings::tsk_tree_first(&mut self.tree) }
} else {
unsafe { ll_bindings::tsk_tree_next(&mut self.tree) }
};
if rv == 0 {
self.advanced = false;
self.current_tree += 1;
} else if rv == 1 {
self.advanced = true;
self.current_tree += 1;
} else if rv < 0 {
panic_on_tskit_error!(rv);
}
if self.advanced {
Some(self.item())
} else {
None
}
}
}

impl Drop for TreeIterator {
fn drop(&mut self) {
unsafe { ll_bindings::tsk_tree_free(&mut self.tree) };
}
}

// Trait defining iteration over nodes.
trait NodeIterator {
fn next_node(&mut self);
fn current_node(&mut self) -> Option<NodeId>;
}

impl Deref for Tree {
type Target = TreeInterface;
fn deref(&self) -> &Self::Target {
&self.api
}
}

impl Drop for Tree {
fn drop(&mut self) {
let rv = unsafe { ll_bindings::tsk_tree_free(&mut self.inner) };
assert_eq!(rv, 0);
}
}

impl DerefMut for Tree {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.api
Expand Down Expand Up @@ -341,6 +439,13 @@ impl TreeSequence {
Ok(tree)
}

/// Return an iterator over the trees.
pub fn trees(&self) -> impl Iterator<Item = impl Deref<Target = TreeInterface>> + '_ {
// TODO: should have a safe wrapper for this
let num_nodes = unsafe { (*(*self.as_ptr()).tables).nodes.num_rows };
TreeIterator::new(self, num_nodes)
}

/// Get the list of samples as a vector.
/// # Panics
///
Expand Down Expand Up @@ -651,6 +756,37 @@ pub(crate) mod test_trees {
}
}

#[test]
fn test_new_trees_iterator() {
let treeseq = treeseq_from_small_table_collection();
for tree in treeseq.trees() {
for n in tree.traverse_nodes(NodeTraversalOrder::Preorder) {
for p in tree.parents(n).unwrap() {
println!("{:?}", p);
}
}
}

// NOTE: the following blocks make cargo valgrind crash,
// which is a sign of badness.
panic!("cargo valgrind will fail here");

// This is a safety sticking point:
// we cannot collect the iterable itself b/c
// the underlying tree memory is re-used.
let i = treeseq.trees();
let v = Vec::<_>::from_iter(i);
assert_eq!(v.len(), 2);
for i in v {
println!("{:?}", i.parent_array());
}

let v = treeseq.trees().collect::<Vec<_>>();
for i in v {
println!("{:?}", i.parent_array());
}
}

#[should_panic]
#[test]
fn test_num_tracked_samples_not_tracking_samples() {
Expand Down