Skip to content

Commit 2c93206

Browse files
authored
refactor: add sys::LLTreeSeq. (#430)
* This is the new back end for TreeSequence. * Allows pushing lots of unsafe into sys.rs.
1 parent b27d686 commit 2c93206

File tree

2 files changed

+123
-63
lines changed

2 files changed

+123
-63
lines changed

src/sys.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use bindings::tsk_mutation_table_t;
66
use bindings::tsk_node_table_t;
77
use bindings::tsk_population_table_t;
88
use bindings::tsk_site_table_t;
9+
use std::ffi::CString;
910
use std::ptr::NonNull;
1011

1112
#[cfg(feature = "provenance")]
@@ -47,6 +48,94 @@ basic_lltableref_impl!(LLIndividualTableRef, tsk_individual_table_t);
4748
#[cfg(feature = "provenance")]
4849
basic_lltableref_impl!(LLProvenanceTableRef, tsk_provenance_table_t);
4950

51+
#[repr(transparent)]
52+
pub struct LLTreeSeq(bindings::tsk_treeseq_t);
53+
54+
impl LLTreeSeq {
55+
pub fn new(
56+
tables: *mut bindings::tsk_table_collection_t,
57+
flags: bindings::tsk_flags_t,
58+
) -> Result<Self, TskitError> {
59+
let mut inner = std::mem::MaybeUninit::<bindings::tsk_treeseq_t>::uninit();
60+
let mut flags = flags;
61+
flags |= bindings::TSK_TAKE_OWNERSHIP;
62+
let rv = unsafe { bindings::tsk_treeseq_init(inner.as_mut_ptr(), tables, flags) };
63+
handle_tsk_return_value!(rv, Self(unsafe { inner.assume_init() }))
64+
}
65+
66+
pub fn as_ref(&self) -> &bindings::tsk_treeseq_t {
67+
&self.0
68+
}
69+
70+
pub fn as_ptr(&self) -> *const bindings::tsk_treeseq_t {
71+
&self.0
72+
}
73+
74+
pub fn as_mut_ptr(&mut self) -> *mut bindings::tsk_treeseq_t {
75+
&mut self.0
76+
}
77+
78+
pub fn simplify(
79+
&self,
80+
samples: &[bindings::tsk_id_t],
81+
options: bindings::tsk_flags_t,
82+
idmap: *mut bindings::tsk_id_t,
83+
) -> Result<Self, TskitError> {
84+
// The output is an UNINITIALIZED treeseq,
85+
// else we leak memory.
86+
let mut ts = std::mem::MaybeUninit::<bindings::tsk_treeseq_t>::uninit();
87+
// SAFETY: samples is not null, idmap is allowed to be.
88+
// self.as_ptr() is not null
89+
let rv = unsafe {
90+
bindings::tsk_treeseq_simplify(
91+
self.as_ptr(),
92+
samples.as_ptr(),
93+
samples.len() as bindings::tsk_size_t,
94+
options,
95+
ts.as_mut_ptr(),
96+
idmap,
97+
)
98+
};
99+
let init = unsafe { ts.assume_init() };
100+
if rv < 0 {
101+
// SAFETY: the ptr is not null
102+
// and tsk_treeseq_free uses safe methods
103+
// to clean up.
104+
unsafe { bindings::tsk_treeseq_free(ts.as_mut_ptr()) };
105+
}
106+
handle_tsk_return_value!(rv, Self(init))
107+
}
108+
109+
pub fn dump(
110+
&self,
111+
filename: CString,
112+
options: bindings::tsk_flags_t,
113+
) -> Result<i32, TskitError> {
114+
// SAFETY: self pointer is not null
115+
let rv = unsafe { bindings::tsk_treeseq_dump(self.as_ptr(), filename.as_ptr(), options) };
116+
handle_tsk_return_value!(rv)
117+
}
118+
119+
pub fn num_trees(&self) -> bindings::tsk_size_t {
120+
// SAFETY: self pointer is not null
121+
unsafe { bindings::tsk_treeseq_get_num_trees(self.as_ptr()) }
122+
}
123+
124+
pub fn kc_distance(&self, other: &Self, lambda: f64) -> Result<f64, TskitError> {
125+
let mut kc: f64 = f64::NAN;
126+
let kcp: *mut f64 = &mut kc;
127+
// SAFETY: self pointer is not null
128+
let code = unsafe {
129+
bindings::tsk_treeseq_kc_distance(self.as_ptr(), other.as_ptr(), lambda, kcp)
130+
};
131+
handle_tsk_return_value!(code, kc)
132+
}
133+
134+
pub fn num_samples(&self) -> bindings::tsk_size_t {
135+
unsafe { bindings::tsk_treeseq_get_num_samples(self.as_ptr()) }
136+
}
137+
}
138+
50139
fn tsk_column_access_detail<R: Into<bindings::tsk_id_t>, L: Into<bindings::tsk_size_t>, T: Copy>(
51140
row: R,
52141
column: *const T,

src/trees.rs

Lines changed: 34 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::mem::MaybeUninit;
21
use std::ops::Deref;
32
use std::ops::DerefMut;
43

@@ -13,7 +12,7 @@ use crate::TreeFlags;
1312
use crate::TreeInterface;
1413
use crate::TreeSequenceFlags;
1514
use crate::TskReturnValue;
16-
use crate::{tsk_id_t, tsk_size_t, TableCollection};
15+
use crate::{tsk_id_t, TableCollection};
1716
use ll_bindings::tsk_tree_free;
1817
use std::ptr::NonNull;
1918

@@ -185,7 +184,7 @@ impl streaming_iterator::DoubleEndedStreamingIterator for Tree {
185184
/// assert_eq!(treeseq.nodes_mut().num_rows(), 3);
186185
/// ```
187186
pub struct TreeSequence {
188-
pub(crate) inner: ll_bindings::tsk_treeseq_t,
187+
pub(crate) inner: sys::LLTreeSeq,
189188
views: crate::table_views::TableViews,
190189
}
191190

@@ -194,7 +193,7 @@ unsafe impl Sync for TreeSequence {}
194193

195194
impl Drop for TreeSequence {
196195
fn drop(&mut self) {
197-
let rv = unsafe { ll_bindings::tsk_treeseq_free(&mut self.inner) };
196+
let rv = unsafe { ll_bindings::tsk_treeseq_free(self.as_mut_ptr()) };
198197
assert_eq!(rv, 0);
199198
}
200199
}
@@ -247,31 +246,24 @@ impl TreeSequence {
247246
tables: TableCollection,
248247
flags: F,
249248
) -> Result<Self, TskitError> {
250-
let mut inner = std::mem::MaybeUninit::<ll_bindings::tsk_treeseq_t>::uninit();
251-
let mut flags: u32 = flags.into().bits();
252-
flags |= ll_bindings::TSK_TAKE_OWNERSHIP;
253249
let raw_tables_ptr = tables.into_raw()?;
254-
let rv =
255-
unsafe { ll_bindings::tsk_treeseq_init(inner.as_mut_ptr(), raw_tables_ptr, flags) };
250+
let mut inner = sys::LLTreeSeq::new(raw_tables_ptr, flags.into().bits())?;
256251
let views = crate::table_views::TableViews::new_from_tree_sequence(inner.as_mut_ptr())?;
257-
handle_tsk_return_value!(rv, {
258-
let inner = unsafe { inner.assume_init() };
259-
Self { inner, views }
260-
})
252+
Ok(Self { inner, views })
261253
}
262254

263255
fn as_ref(&self) -> &ll_bindings::tsk_treeseq_t {
264-
&self.inner
256+
self.inner.as_ref()
265257
}
266258

267259
/// Pointer to the low-level C type.
268260
pub fn as_ptr(&self) -> *const ll_bindings::tsk_treeseq_t {
269-
&self.inner
261+
self.inner.as_ptr()
270262
}
271263

272264
/// Mutable pointer to the low-level C type.
273265
pub fn as_mut_ptr(&mut self) -> *mut ll_bindings::tsk_treeseq_t {
274-
&mut self.inner
266+
self.inner.as_mut_ptr()
275267
}
276268

277269
/// Dump the tree sequence to file.
@@ -290,11 +282,7 @@ impl TreeSequence {
290282
let c_str = std::ffi::CString::new(filename).map_err(|_| {
291283
TskitError::LibraryError("call to ffi::Cstring::new failed".to_string())
292284
})?;
293-
let rv = unsafe {
294-
ll_bindings::tsk_treeseq_dump(self.as_ptr(), c_str.as_ptr(), options.into().bits())
295-
};
296-
297-
handle_tsk_return_value!(rv)
285+
self.inner.dump(c_str, options.into().bits())
298286
}
299287

300288
/// Load from a file.
@@ -401,7 +389,7 @@ impl TreeSequence {
401389

402390
/// Get the number of trees.
403391
pub fn num_trees(&self) -> SizeType {
404-
unsafe { ll_bindings::tsk_treeseq_get_num_trees(self.as_ptr()) }.into()
392+
self.inner.num_trees().into()
405393
}
406394

407395
/// Calculate the average Kendall-Colijn (`K-C`) distance between
@@ -416,17 +404,12 @@ impl TreeSequence {
416404
/// * `lambda` specifies the relative weight of topology and branch length.
417405
/// See [`TreeInterface::kc_distance`] for more details.
418406
pub fn kc_distance(&self, other: &TreeSequence, lambda: f64) -> Result<f64, TskitError> {
419-
let mut kc: f64 = f64::NAN;
420-
let kcp: *mut f64 = &mut kc;
421-
let code = unsafe {
422-
ll_bindings::tsk_treeseq_kc_distance(self.as_ptr(), other.as_ptr(), lambda, kcp)
423-
};
424-
handle_tsk_return_value!(code, kc)
407+
self.inner.kc_distance(&other.inner, lambda)
425408
}
426409

427410
// FIXME: document
428411
pub fn num_samples(&self) -> SizeType {
429-
unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) }.into()
412+
self.inner.num_samples().into()
430413
}
431414

432415
/// Simplify tables and return a new tree sequence.
@@ -448,42 +431,29 @@ impl TreeSequence {
448431
options: O,
449432
idmap: bool,
450433
) -> Result<(Self, Option<Vec<NodeId>>), TskitError> {
451-
// The output is an UNINITIALIZED treeseq,
452-
// else we leak memory.
453-
let mut ts = MaybeUninit::<ll_bindings::tsk_treeseq_t>::uninit();
454434
let mut output_node_map: Vec<NodeId> = vec![];
455435
if idmap {
456436
output_node_map.resize(usize::try_from(self.nodes().num_rows())?, NodeId::NULL);
457437
}
458-
let rv = unsafe {
459-
ll_bindings::tsk_treeseq_simplify(
460-
self.as_ptr(),
461-
// NOTE: casting away const-ness:
462-
samples.as_ptr().cast::<tsk_id_t>(),
463-
samples.len() as tsk_size_t,
464-
options.into().bits(),
465-
ts.as_mut_ptr(),
466-
match idmap {
467-
true => output_node_map.as_mut_ptr().cast::<tsk_id_t>(),
468-
false => std::ptr::null_mut(),
469-
},
470-
)
438+
let llsamples = unsafe {
439+
std::slice::from_raw_parts(samples.as_ptr().cast::<tsk_id_t>(), samples.len())
471440
};
472-
// TODO: is it possible that this can leak somehow?
473-
handle_tsk_return_value!(
474-
rv,
475-
(
476-
{
477-
let mut inner = unsafe { ts.assume_init() };
478-
let views = crate::table_views::TableViews::new_from_tree_sequence(&mut inner)?;
479-
Self { inner, views }
480-
},
481-
match idmap {
482-
true => Some(output_node_map),
483-
false => None,
484-
}
485-
)
486-
)
441+
let mut inner = self.inner.simplify(
442+
llsamples,
443+
options.into().bits(),
444+
match idmap {
445+
true => output_node_map.as_mut_ptr().cast::<tsk_id_t>(),
446+
false => std::ptr::null_mut(),
447+
},
448+
)?;
449+
let views = crate::table_views::TableViews::new_from_tree_sequence(inner.as_mut_ptr())?;
450+
Ok((
451+
Self { inner, views },
452+
match idmap {
453+
true => Some(output_node_map),
454+
false => None,
455+
},
456+
))
487457
}
488458

489459
#[cfg(feature = "provenance")]
@@ -532,11 +502,11 @@ impl TreeSequence {
532502
let timestamp = humantime::format_rfc3339(std::time::SystemTime::now()).to_string();
533503
let rv = unsafe {
534504
ll_bindings::tsk_provenance_table_add_row(
535-
&mut (*self.inner.tables).provenances,
505+
&mut (*self.inner.as_ref().tables).provenances,
536506
timestamp.as_ptr() as *mut i8,
537-
timestamp.len() as tsk_size_t,
507+
timestamp.len() as ll_bindings::tsk_size_t,
538508
record.as_ptr() as *mut i8,
539-
record.len() as tsk_size_t,
509+
record.len() as ll_bindings::tsk_size_t,
540510
)
541511
};
542512
handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv))
@@ -739,6 +709,7 @@ pub(crate) mod test_trees {
739709

740710
#[test]
741711
fn test_iterate_samples_two_trees() {
712+
use super::ll_bindings::tsk_size_t;
742713
let treeseq = treeseq_from_small_table_collection_two_trees();
743714
assert_eq!(treeseq.num_trees(), 2);
744715
let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap();

0 commit comments

Comments
 (0)