From e2099edce5e7cbf9e58501d0e4456e38a278f745 Mon Sep 17 00:00:00 2001 From: "Kevin R. Thornton" Date: Mon, 5 Dec 2022 17:02:35 -0800 Subject: [PATCH] refactor: add sys::LLTreeSeq. * This is the new back end for TreeSequence. * Allows pushing lots of unsafe into sys.rs. --- src/sys.rs | 89 +++++++++++++++++++++++++++++++++++++++++++++++ src/trees.rs | 97 ++++++++++++++++++---------------------------------- 2 files changed, 123 insertions(+), 63 deletions(-) diff --git a/src/sys.rs b/src/sys.rs index 6a2d0b24d..93e44dcb0 100644 --- a/src/sys.rs +++ b/src/sys.rs @@ -6,6 +6,7 @@ use bindings::tsk_mutation_table_t; use bindings::tsk_node_table_t; use bindings::tsk_population_table_t; use bindings::tsk_site_table_t; +use std::ffi::CString; use std::ptr::NonNull; #[cfg(feature = "provenance")] @@ -47,6 +48,94 @@ basic_lltableref_impl!(LLIndividualTableRef, tsk_individual_table_t); #[cfg(feature = "provenance")] basic_lltableref_impl!(LLProvenanceTableRef, tsk_provenance_table_t); +#[repr(transparent)] +pub struct LLTreeSeq(bindings::tsk_treeseq_t); + +impl LLTreeSeq { + pub fn new( + tables: *mut bindings::tsk_table_collection_t, + flags: bindings::tsk_flags_t, + ) -> Result { + let mut inner = std::mem::MaybeUninit::::uninit(); + let mut flags = flags; + flags |= bindings::TSK_TAKE_OWNERSHIP; + let rv = unsafe { bindings::tsk_treeseq_init(inner.as_mut_ptr(), tables, flags) }; + handle_tsk_return_value!(rv, Self(unsafe { inner.assume_init() })) + } + + pub fn as_ref(&self) -> &bindings::tsk_treeseq_t { + &self.0 + } + + pub fn as_ptr(&self) -> *const bindings::tsk_treeseq_t { + &self.0 + } + + pub fn as_mut_ptr(&mut self) -> *mut bindings::tsk_treeseq_t { + &mut self.0 + } + + pub fn simplify( + &self, + samples: &[bindings::tsk_id_t], + options: bindings::tsk_flags_t, + idmap: *mut bindings::tsk_id_t, + ) -> Result { + // The output is an UNINITIALIZED treeseq, + // else we leak memory. + let mut ts = std::mem::MaybeUninit::::uninit(); + // SAFETY: samples is not null, idmap is allowed to be. + // self.as_ptr() is not null + let rv = unsafe { + bindings::tsk_treeseq_simplify( + self.as_ptr(), + samples.as_ptr(), + samples.len() as bindings::tsk_size_t, + options, + ts.as_mut_ptr(), + idmap, + ) + }; + let init = unsafe { ts.assume_init() }; + if rv < 0 { + // SAFETY: the ptr is not null + // and tsk_treeseq_free uses safe methods + // to clean up. + unsafe { bindings::tsk_treeseq_free(ts.as_mut_ptr()) }; + } + handle_tsk_return_value!(rv, Self(init)) + } + + pub fn dump( + &self, + filename: CString, + options: bindings::tsk_flags_t, + ) -> Result { + // SAFETY: self pointer is not null + let rv = unsafe { bindings::tsk_treeseq_dump(self.as_ptr(), filename.as_ptr(), options) }; + handle_tsk_return_value!(rv) + } + + pub fn num_trees(&self) -> bindings::tsk_size_t { + // SAFETY: self pointer is not null + unsafe { bindings::tsk_treeseq_get_num_trees(self.as_ptr()) } + } + + pub fn kc_distance(&self, other: &Self, lambda: f64) -> Result { + let mut kc: f64 = f64::NAN; + let kcp: *mut f64 = &mut kc; + // SAFETY: self pointer is not null + let code = unsafe { + bindings::tsk_treeseq_kc_distance(self.as_ptr(), other.as_ptr(), lambda, kcp) + }; + handle_tsk_return_value!(code, kc) + } + + pub fn num_samples(&self) -> bindings::tsk_size_t { + unsafe { bindings::tsk_treeseq_get_num_samples(self.as_ptr()) } + } +} + fn tsk_column_access_detail, L: Into, T: Copy>( row: R, column: *const T, diff --git a/src/trees.rs b/src/trees.rs index 24d146508..a91a70f3f 100644 --- a/src/trees.rs +++ b/src/trees.rs @@ -1,4 +1,3 @@ -use std::mem::MaybeUninit; use std::ops::Deref; use std::ops::DerefMut; @@ -13,7 +12,7 @@ use crate::TreeFlags; use crate::TreeInterface; use crate::TreeSequenceFlags; use crate::TskReturnValue; -use crate::{tsk_id_t, tsk_size_t, TableCollection}; +use crate::{tsk_id_t, TableCollection}; use ll_bindings::tsk_tree_free; use std::ptr::NonNull; @@ -185,7 +184,7 @@ impl streaming_iterator::DoubleEndedStreamingIterator for Tree { /// assert_eq!(treeseq.nodes_mut().num_rows(), 3); /// ``` pub struct TreeSequence { - pub(crate) inner: ll_bindings::tsk_treeseq_t, + pub(crate) inner: sys::LLTreeSeq, views: crate::table_views::TableViews, } @@ -194,7 +193,7 @@ unsafe impl Sync for TreeSequence {} impl Drop for TreeSequence { fn drop(&mut self) { - let rv = unsafe { ll_bindings::tsk_treeseq_free(&mut self.inner) }; + let rv = unsafe { ll_bindings::tsk_treeseq_free(self.as_mut_ptr()) }; assert_eq!(rv, 0); } } @@ -247,31 +246,24 @@ impl TreeSequence { tables: TableCollection, flags: F, ) -> Result { - let mut inner = std::mem::MaybeUninit::::uninit(); - let mut flags: u32 = flags.into().bits(); - flags |= ll_bindings::TSK_TAKE_OWNERSHIP; let raw_tables_ptr = tables.into_raw()?; - let rv = - unsafe { ll_bindings::tsk_treeseq_init(inner.as_mut_ptr(), raw_tables_ptr, flags) }; + let mut inner = sys::LLTreeSeq::new(raw_tables_ptr, flags.into().bits())?; let views = crate::table_views::TableViews::new_from_tree_sequence(inner.as_mut_ptr())?; - handle_tsk_return_value!(rv, { - let inner = unsafe { inner.assume_init() }; - Self { inner, views } - }) + Ok(Self { inner, views }) } fn as_ref(&self) -> &ll_bindings::tsk_treeseq_t { - &self.inner + self.inner.as_ref() } /// Pointer to the low-level C type. pub fn as_ptr(&self) -> *const ll_bindings::tsk_treeseq_t { - &self.inner + self.inner.as_ptr() } /// Mutable pointer to the low-level C type. pub fn as_mut_ptr(&mut self) -> *mut ll_bindings::tsk_treeseq_t { - &mut self.inner + self.inner.as_mut_ptr() } /// Dump the tree sequence to file. @@ -290,11 +282,7 @@ impl TreeSequence { let c_str = std::ffi::CString::new(filename).map_err(|_| { TskitError::LibraryError("call to ffi::Cstring::new failed".to_string()) })?; - let rv = unsafe { - ll_bindings::tsk_treeseq_dump(self.as_ptr(), c_str.as_ptr(), options.into().bits()) - }; - - handle_tsk_return_value!(rv) + self.inner.dump(c_str, options.into().bits()) } /// Load from a file. @@ -401,7 +389,7 @@ impl TreeSequence { /// Get the number of trees. pub fn num_trees(&self) -> SizeType { - unsafe { ll_bindings::tsk_treeseq_get_num_trees(self.as_ptr()) }.into() + self.inner.num_trees().into() } /// Calculate the average Kendall-Colijn (`K-C`) distance between @@ -416,17 +404,12 @@ impl TreeSequence { /// * `lambda` specifies the relative weight of topology and branch length. /// See [`TreeInterface::kc_distance`] for more details. pub fn kc_distance(&self, other: &TreeSequence, lambda: f64) -> Result { - let mut kc: f64 = f64::NAN; - let kcp: *mut f64 = &mut kc; - let code = unsafe { - ll_bindings::tsk_treeseq_kc_distance(self.as_ptr(), other.as_ptr(), lambda, kcp) - }; - handle_tsk_return_value!(code, kc) + self.inner.kc_distance(&other.inner, lambda) } // FIXME: document pub fn num_samples(&self) -> SizeType { - unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) }.into() + self.inner.num_samples().into() } /// Simplify tables and return a new tree sequence. @@ -448,42 +431,29 @@ impl TreeSequence { options: O, idmap: bool, ) -> Result<(Self, Option>), TskitError> { - // The output is an UNINITIALIZED treeseq, - // else we leak memory. - let mut ts = MaybeUninit::::uninit(); let mut output_node_map: Vec = vec![]; if idmap { output_node_map.resize(usize::try_from(self.nodes().num_rows())?, NodeId::NULL); } - let rv = unsafe { - ll_bindings::tsk_treeseq_simplify( - self.as_ptr(), - // NOTE: casting away const-ness: - samples.as_ptr().cast::(), - samples.len() as tsk_size_t, - options.into().bits(), - ts.as_mut_ptr(), - match idmap { - true => output_node_map.as_mut_ptr().cast::(), - false => std::ptr::null_mut(), - }, - ) + let llsamples = unsafe { + std::slice::from_raw_parts(samples.as_ptr().cast::(), samples.len()) }; - // TODO: is it possible that this can leak somehow? - handle_tsk_return_value!( - rv, - ( - { - let mut inner = unsafe { ts.assume_init() }; - let views = crate::table_views::TableViews::new_from_tree_sequence(&mut inner)?; - Self { inner, views } - }, - match idmap { - true => Some(output_node_map), - false => None, - } - ) - ) + let mut inner = self.inner.simplify( + llsamples, + options.into().bits(), + match idmap { + true => output_node_map.as_mut_ptr().cast::(), + false => std::ptr::null_mut(), + }, + )?; + let views = crate::table_views::TableViews::new_from_tree_sequence(inner.as_mut_ptr())?; + Ok(( + Self { inner, views }, + match idmap { + true => Some(output_node_map), + false => None, + }, + )) } #[cfg(feature = "provenance")] @@ -532,11 +502,11 @@ impl TreeSequence { let timestamp = humantime::format_rfc3339(std::time::SystemTime::now()).to_string(); let rv = unsafe { ll_bindings::tsk_provenance_table_add_row( - &mut (*self.inner.tables).provenances, + &mut (*self.inner.as_ref().tables).provenances, timestamp.as_ptr() as *mut i8, - timestamp.len() as tsk_size_t, + timestamp.len() as ll_bindings::tsk_size_t, record.as_ptr() as *mut i8, - record.len() as tsk_size_t, + record.len() as ll_bindings::tsk_size_t, ) }; handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv)) @@ -739,6 +709,7 @@ pub(crate) mod test_trees { #[test] fn test_iterate_samples_two_trees() { + use super::ll_bindings::tsk_size_t; let treeseq = treeseq_from_small_table_collection_two_trees(); assert_eq!(treeseq.num_trees(), 2); let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap();