Skip to content

Commit 5e0eb1a

Browse files
authored
feature: TreeSequence::edge_differences_iter (#410)
Returns a lending iterator providing further sub-Iterators over edge differences.
1 parent 8ad03af commit 5e0eb1a

File tree

3 files changed

+319
-0
lines changed

3 files changed

+319
-0
lines changed

src/edge_differences.rs

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
use crate::NodeId;
2+
use crate::Position;
3+
use crate::TreeSequence;
4+
5+
use crate::bindings;
6+
7+
#[repr(transparent)]
8+
struct LLEdgeDifferenceIterator(bindings::tsk_diff_iter_t);
9+
10+
impl std::ops::Deref for LLEdgeDifferenceIterator {
11+
type Target = bindings::tsk_diff_iter_t;
12+
13+
fn deref(&self) -> &Self::Target {
14+
&self.0
15+
}
16+
}
17+
18+
impl std::ops::DerefMut for LLEdgeDifferenceIterator {
19+
fn deref_mut(&mut self) -> &mut Self::Target {
20+
&mut self.0
21+
}
22+
}
23+
24+
impl Drop for LLEdgeDifferenceIterator {
25+
fn drop(&mut self) {
26+
unsafe { bindings::tsk_diff_iter_free(&mut self.0) };
27+
}
28+
}
29+
30+
impl LLEdgeDifferenceIterator {
31+
pub fn new_from_treeseq(treeseq: &TreeSequence, flags: bindings::tsk_flags_t) -> Option<Self> {
32+
let mut inner = std::mem::MaybeUninit::<bindings::tsk_diff_iter_t>::uninit();
33+
match unsafe { bindings::tsk_diff_iter_init(inner.as_mut_ptr(), treeseq.as_ptr(), flags) } {
34+
x if x < 0 => None,
35+
_ => Some(Self(unsafe { inner.assume_init() })),
36+
}
37+
}
38+
}
39+
40+
/// Marker type for edge insertion.
///
/// Implements `Debug`, `Clone` and `Copy` so that the implementations
/// derived on types generic over this marker (whose derives require the
/// same traits of the type parameter) are actually usable.
#[derive(Debug, Clone, Copy)]
pub struct Insertion {}
42+
43+
/// Marker type for edge removal.
///
/// Implements `Debug`, `Clone` and `Copy` so that the implementations
/// derived on types generic over this marker (whose derives require the
/// same traits of the type parameter) are actually usable.
#[derive(Debug, Clone, Copy)]
pub struct Removal {}
45+
46+
mod private {
47+
pub trait EdgeDifferenceIteration {}
48+
49+
impl EdgeDifferenceIteration for super::Insertion {}
50+
impl EdgeDifferenceIteration for super::Removal {}
51+
}
52+
53+
struct LLEdgeList<T: private::EdgeDifferenceIteration> {
54+
inner: bindings::tsk_edge_list_t,
55+
marker: std::marker::PhantomData<T>,
56+
}
57+
58+
macro_rules! build_lledgelist {
59+
($name: ident, $generic: ty) => {
60+
type $name = LLEdgeList<$generic>;
61+
62+
impl Default for $name {
63+
fn default() -> Self {
64+
Self {
65+
inner: bindings::tsk_edge_list_t {
66+
head: std::ptr::null_mut(),
67+
tail: std::ptr::null_mut(),
68+
},
69+
marker: std::marker::PhantomData::<$generic> {},
70+
}
71+
}
72+
}
73+
};
74+
}
75+
76+
build_lledgelist!(LLEdgeInsertionList, Insertion);
77+
build_lledgelist!(LLEdgeRemovalList, Removal);
78+
79+
/// Concrete type implementing [`Iterator`] over [`EdgeInsertion`] or [`EdgeRemoval`].
80+
/// Created by [`EdgeDifferencesIterator::edge_insertions`] or
81+
/// [`EdgeDifferencesIterator::edge_removals`], respectively.
82+
pub struct EdgeDifferences<'a, T: private::EdgeDifferenceIteration> {
83+
inner: &'a LLEdgeList<T>,
84+
current: *mut bindings::tsk_edge_list_node_t,
85+
}
86+
87+
impl<'a, T: private::EdgeDifferenceIteration> EdgeDifferences<'a, T> {
88+
fn new(inner: &'a LLEdgeList<T>) -> Self {
89+
Self {
90+
inner,
91+
current: std::ptr::null_mut(),
92+
}
93+
}
94+
}
95+
96+
/// An edge difference. Edge insertions and removals are differentiated by
97+
/// marker types [`Insertion`] and [`Removal`], respectively.
98+
#[derive(Debug, Copy, Clone)]
99+
pub struct EdgeDifference<T: private::EdgeDifferenceIteration> {
100+
left: Position,
101+
right: Position,
102+
parent: NodeId,
103+
child: NodeId,
104+
marker: std::marker::PhantomData<T>,
105+
}
106+
107+
impl<T: private::EdgeDifferenceIteration> EdgeDifference<T> {
108+
fn new<P: Into<Position>, N: Into<NodeId>>(left: P, right: P, parent: N, child: N) -> Self {
109+
Self {
110+
left: left.into(),
111+
right: right.into(),
112+
parent: parent.into(),
113+
child: child.into(),
114+
marker: std::marker::PhantomData::<T> {},
115+
}
116+
}
117+
118+
pub fn left(&self) -> Position {
119+
self.left
120+
}
121+
pub fn right(&self) -> Position {
122+
self.right
123+
}
124+
pub fn parent(&self) -> NodeId {
125+
self.parent
126+
}
127+
pub fn child(&self) -> NodeId {
128+
self.child
129+
}
130+
}
131+
132+
impl<T> std::fmt::Display for EdgeDifference<T>
133+
where
134+
T: private::EdgeDifferenceIteration,
135+
{
136+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137+
write!(
138+
f,
139+
"left: {}, right: {}, parent: {}, child: {}",
140+
self.left(),
141+
self.right(),
142+
self.parent(),
143+
self.child()
144+
)
145+
}
146+
}
147+
148+
/// Type alias for [`EdgeDifference<Insertion>`]
149+
pub type EdgeInsertion = EdgeDifference<Insertion>;
150+
/// Type alias for [`EdgeDifference<Removal>`]
151+
pub type EdgeRemoval = EdgeDifference<Removal>;
152+
153+
impl<'a, T> Iterator for EdgeDifferences<'a, T>
154+
where
155+
T: private::EdgeDifferenceIteration,
156+
{
157+
type Item = EdgeDifference<T>;
158+
159+
fn next(&mut self) -> Option<Self::Item> {
160+
if self.current.is_null() {
161+
self.current = self.inner.inner.head;
162+
} else {
163+
self.current = unsafe { *self.current }.next;
164+
}
165+
if self.current.is_null() {
166+
None
167+
} else {
168+
let left = unsafe { (*self.current).edge.left };
169+
let right = unsafe { (*self.current).edge.right };
170+
let parent = unsafe { (*self.current).edge.parent };
171+
let child = unsafe { (*self.current).edge.child };
172+
Some(Self::Item::new(left, right, parent, child))
173+
}
174+
}
175+
}
176+
177+
/// Manages iteration over trees to obtain
178+
/// edge differences.
179+
pub struct EdgeDifferencesIterator {
180+
inner: LLEdgeDifferenceIterator,
181+
insertion: LLEdgeInsertionList,
182+
removal: LLEdgeRemovalList,
183+
left: f64,
184+
right: f64,
185+
advanced: i32,
186+
}
187+
188+
impl EdgeDifferencesIterator {
189+
// NOTE: will return None if tskit-c cannot
190+
// allocate memory for internal structures.
191+
pub(crate) fn new_from_treeseq(
192+
treeseq: &TreeSequence,
193+
flags: bindings::tsk_flags_t,
194+
) -> Option<Self> {
195+
LLEdgeDifferenceIterator::new_from_treeseq(treeseq, flags).map(|inner| Self {
196+
inner,
197+
insertion: LLEdgeInsertionList::default(),
198+
removal: LLEdgeRemovalList::default(),
199+
left: f64::default(),
200+
right: f64::default(),
201+
advanced: 0,
202+
})
203+
}
204+
205+
fn advance_tree(&mut self) {
206+
// SAFETY: our tree sequence is guaranteed
207+
// to be valid and own its tables.
208+
self.advanced = unsafe {
209+
bindings::tsk_diff_iter_next(
210+
&mut self.inner.0,
211+
&mut self.left,
212+
&mut self.right,
213+
&mut self.removal.inner,
214+
&mut self.insertion.inner,
215+
)
216+
};
217+
}
218+
219+
pub fn left(&self) -> Position {
220+
self.left.into()
221+
}
222+
223+
pub fn right(&self) -> Position {
224+
self.right.into()
225+
}
226+
227+
pub fn interval(&self) -> (Position, Position) {
228+
(self.left(), self.right())
229+
}
230+
231+
pub fn edge_removals(&self) -> impl Iterator<Item = EdgeRemoval> + '_ {
232+
EdgeDifferences::<Removal>::new(&self.removal)
233+
}
234+
235+
pub fn edge_insertions(&self) -> impl Iterator<Item = EdgeInsertion> + '_ {
236+
EdgeDifferences::<Insertion>::new(&self.insertion)
237+
}
238+
}
239+
240+
impl streaming_iterator::StreamingIterator for EdgeDifferencesIterator {
241+
type Item = EdgeDifferencesIterator;
242+
243+
fn advance(&mut self) {
244+
self.advance_tree()
245+
}
246+
247+
fn get(&self) -> Option<&Self::Item> {
248+
if self.advanced > 0 {
249+
Some(self)
250+
} else {
251+
None
252+
}
253+
}
254+
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
pub mod bindings;
8080

8181
mod _macros; // Starts w/_ to be sorted at front by rustfmt!
82+
mod edge_differences;
8283
mod edge_table;
8384
pub mod error;
8485
mod flags;
@@ -427,6 +428,7 @@ impl_time_position_arithmetic!(Position, Time);
427428
/// "Null" identifier value.
428429
pub(crate) const TSK_NULL: tsk_id_t = -1;
429430

431+
pub use edge_differences::*;
430432
pub use edge_table::{EdgeTable, EdgeTableRow, OwningEdgeTable};
431433
pub use error::TskitError;
432434
pub use flags::*;

src/trees.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,19 @@ impl TreeSequence {
538538
}
539539

540540
delegate_table_view_api!();
541+
542+
/// Build a lending iterator over edge differences.
543+
///
544+
/// # Returns
545+
///
546+
/// * None if the `C` back end is unable to allocate
547+
/// needed memory
548+
/// * `Some(iterator)` otherwise.
549+
pub fn edge_differences_iter(
550+
&self,
551+
) -> Option<crate::edge_differences::EdgeDifferencesIterator> {
552+
crate::edge_differences::EdgeDifferencesIterator::new_from_treeseq(self, 0)
553+
}
541554
}
542555

543556
impl TryFrom<TableCollection> for TreeSequence {
@@ -861,6 +874,56 @@ pub(crate) mod test_trees {
861874
panic!("Expected a tree.");
862875
}
863876
}
877+
878+
// TODO: use trybuild to add tests that the iterator
// lifetime is indeed coupled to that of the treeseq
#[test]
fn test_edge_diffs_lending_iterator_num_trees() {
    // Part 1: replaying removals and insertions into a parent array must
    // reproduce the parent array of each tree, visited in the same order.
    {
        let treeseq = treeseq_from_small_table_collection_two_trees();
        let num_nodes: usize = treeseq.nodes().num_rows().try_into().unwrap();
        // One slot more than the number of nodes, so that the vector can
        // be compared against `parent_array()` below.
        let mut parents = vec![NodeId::NULL; num_nodes + 1];
        let mut ediff_iter = treeseq
            .edge_differences_iter()
            .expect("expected an edge differences iterator");
        let mut tree_iter = treeseq.tree_iterator(0).unwrap();
        let mut ntrees = 0;
        while let Some(diffs) = ediff_iter.next() {
            let tree = tree_iter.next().unwrap();

            // An edge leaving the tree detaches its child.
            for edge_out in diffs.edge_removals() {
                let child = usize::try_from(edge_out.child()).unwrap();
                parents[child] = NodeId::NULL;
            }

            // An edge entering the tree attaches its child to its parent.
            for edge_in in diffs.edge_insertions() {
                let child: usize = edge_in.child().try_into().unwrap();
                parents[child] = edge_in.parent();
            }

            assert_eq!(tree.parent_array(), &parents);
            ntrees += 1;
        }
        assert_eq!(ntrees, 2);
    }

    // Part 2: the interval reported at each step.
    {
        let treeseq = treeseq_from_small_table_collection_two_trees();
        let mut ediff_iter = treeseq.edge_differences_iter().unwrap();

        let mut ntrees = 0;
        while let Some(diffs) = ediff_iter.next() {
            let observed = diffs.interval();
            if ntrees == 0 {
                assert_eq!(observed, (0.0.into(), 500.0.into()));
            } else {
                assert_eq!(observed, (500.0.into(), 1000.0.into()));
            }
            ntrees += 1;
        }
        assert_eq!(ntrees, 2);
    }
}
864927
}
865928

866929
#[cfg(test)]

0 commit comments

Comments
 (0)