diff --git a/src/tree_interface.rs b/src/tree_interface.rs index 271102414..7ac16da50 100644 --- a/src/tree_interface.rs +++ b/src/tree_interface.rs @@ -10,6 +10,7 @@ use crate::TskitError; use crate::TskitTypeAccess; use std::ptr::NonNull; +#[derive(Debug)] pub struct TreeInterface { non_owned_pointer: NonNull, num_nodes: tsk_size_t, diff --git a/src/trees.rs b/src/trees.rs index 56ce4e2b7..d38040b99 100644 --- a/src/trees.rs +++ b/src/trees.rs @@ -1,6 +1,5 @@ use std::mem::MaybeUninit; -use std::ops::Deref; -use std::ops::DerefMut; +use std::ops::{Deref, DerefMut}; use crate::bindings as ll_bindings; use crate::error::TskitError; @@ -22,7 +21,6 @@ use crate::TreeSequenceFlags; use crate::TskReturnValue; use crate::TskitTypeAccess; use crate::{tsk_id_t, tsk_size_t, TableCollection}; -use ll_bindings::tsk_tree_free; use std::ptr::NonNull; /// A Tree. @@ -35,13 +33,106 @@ pub struct Tree { advanced: bool, } -impl Drop for Tree { - fn drop(&mut self) { - let rv = unsafe { tsk_tree_free(&mut self.inner) }; +pub struct TreeIterator { + tree: ll_bindings::tsk_tree_t, + current_tree: i32, + advanced: bool, + num_nodes: tsk_size_t, + array_len: tsk_size_t, + flags: TreeFlags, +} + +impl TreeIterator { + // FIXME: init if fallible! + fn new(treeseq: &TreeSequence, num_nodes: u64) -> Self { + let mut tree = MaybeUninit::::uninit(); + let rv = unsafe { ll_bindings::tsk_tree_init(tree.as_mut_ptr(), treeseq.as_ptr(), 0) }; assert_eq!(rv, 0); + let tree = unsafe { tree.assume_init() }; + + Self { + tree, + current_tree: -1, + advanced: false, + num_nodes, + array_len: num_nodes + 1, + flags: 0.into(), + } + } + fn item(&mut self) -> NonOwningTree { + NonOwningTree::new( + NonNull::from(&mut self.tree), + self.num_nodes, + self.array_len, + self.flags, + ) + } +} + +#[derive(Debug)] +#[repr(transparent)] +pub struct NonOwningTree { + api: TreeInterface, +} + +impl Deref for NonOwningTree { + type Target = TreeInterface; + + fn deref(&self) -> &Self::Target { + &self.api } } +impl NonOwningTree { + fn new( + tree: NonNull, + num_nodes: tsk_size_t, + array_len: tsk_size_t, + flags: TreeFlags, + ) -> Self { + let api = TreeInterface::new(tree, num_nodes, array_len, flags); + Self { api } + } +} + +impl Iterator for TreeIterator { + type Item = NonOwningTree; + + fn next(&mut self) -> Option { + let rv = if self.current_tree == 0 { + unsafe { ll_bindings::tsk_tree_first(&mut self.tree) } + } else { + unsafe { ll_bindings::tsk_tree_next(&mut self.tree) } + }; + if rv == 0 { + self.advanced = false; + self.current_tree += 1; + } else if rv == 1 { + self.advanced = true; + self.current_tree += 1; + } else if rv < 0 { + panic_on_tskit_error!(rv); + } + if self.advanced { + Some(self.item()) + } else { + None + } + } +} + +impl Drop for TreeIterator { + fn drop(&mut self) { + unsafe { ll_bindings::tsk_tree_free(&mut self.tree) }; + } +} + +// Trait defining iteration over nodes. +trait NodeIterator { + fn next_node(&mut self); + fn current_node(&mut self) -> Option; +} + impl Deref for Tree { type Target = TreeInterface; fn deref(&self) -> &Self::Target { @@ -49,6 +140,13 @@ impl Deref for Tree { } } +impl Drop for Tree { + fn drop(&mut self) { + let rv = unsafe { ll_bindings::tsk_tree_free(&mut self.inner) }; + assert_eq!(rv, 0); + } +} + impl DerefMut for Tree { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.api @@ -341,6 +439,13 @@ impl TreeSequence { Ok(tree) } + /// Return an iterator over the trees. + pub fn trees(&self) -> impl Iterator> + '_ { + // TODO: should have a safe wrapper for this + let num_nodes = unsafe { (*(*self.as_ptr()).tables).nodes.num_rows }; + TreeIterator::new(self, num_nodes) + } + /// Get the list of samples as a vector. /// # Panics /// @@ -651,6 +756,37 @@ pub(crate) mod test_trees { } } + #[test] + fn test_new_trees_iterator() { + let treeseq = treeseq_from_small_table_collection(); + for tree in treeseq.trees() { + for n in tree.traverse_nodes(NodeTraversalOrder::Preorder) { + for p in tree.parents(n).unwrap() { + println!("{:?}", p); + } + } + } + + // NOTE: the following blocks make cargo valgrind crash, + // which is a sign of badness. + panic!("cargo valgrind will fail here"); + + // This is a safety sticking point: + // we cannot collect the iterable itself b/c + // the underlying tree memory is re-used. + let i = treeseq.trees(); + let v = Vec::<_>::from_iter(i); + assert_eq!(v.len(), 2); + for i in v { + println!("{:?}", i.parent_array()); + } + + let v = treeseq.trees().collect::>(); + for i in v { + println!("{:?}", i.parent_array()); + } + } + #[should_panic] #[test] fn test_num_tracked_samples_not_tracking_samples() {