Skip to content

Improve macro safety #362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 40 additions & 42 deletions src/_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,73 +28,69 @@ macro_rules! panic_on_tskit_error {
}

macro_rules! unsafe_tsk_column_access {
($i: expr, $lo: expr, $hi: expr, $array: expr) => {{
($i: expr, $lo: expr, $hi: expr, $owner: expr, $array: ident) => {{
if $i < $lo || ($i as $crate::tsk_size_t) >= $hi {
None
} else {
Some(unsafe { *$array.offset($i as isize) })
debug_assert!(!($owner).$array.is_null());
if !$owner.$array.is_null() {
// SAFETY: array is not null
// and we did our best effort
// on bounds checking
Some(unsafe { *$owner.$array.offset($i as isize) })
} else {
None
}
}
}};
($i: expr, $lo: expr, $hi: expr, $array: expr, $output_id_type: expr) => {{
($i: expr, $lo: expr, $hi: expr, $owner: expr, $array: ident, $output_id_type: expr) => {{
if $i < $lo || ($i as $crate::tsk_size_t) >= $hi {
None
} else {
Some($output_id_type(unsafe { *$array.offset($i as isize) }))
debug_assert!(!($owner).$array.is_null());
if !$owner.$array.is_null() {
// SAFETY: array is not null
// and we did our best effort
// on bounds checking
unsafe { Some($output_id_type(*($owner.$array.offset($i as isize)))) }
} else {
None
}
}
}};
}

macro_rules! unsafe_tsk_column_access_and_map_into {
($i: expr, $lo: expr, $hi: expr, $array: expr) => {{
unsafe_tsk_column_access!($i, $lo, $hi, $array).map(|v| v.into())
($i: expr, $lo: expr, $hi: expr, $owner: expr, $array: ident) => {{
unsafe_tsk_column_access!($i, $lo, $hi, $owner, $array).map(|v| v.into())
}};
}

macro_rules! unsafe_tsk_ragged_column_access {
($i: expr, $lo: expr, $hi: expr, $array: expr, $offset_array: expr, $offset_array_len: expr) => {{
($i: expr, $lo: expr, $hi: expr, $owner: expr, $array: ident, $offset_array: ident, $offset_array_len: ident, $output_id_type: ty) => {{
let i = $crate::SizeType::try_from($i).ok()?;
if $i < $lo || i >= $hi {
None
} else if $offset_array_len == 0 {
} else if $owner.$offset_array_len == 0 {
None
} else {
let start = unsafe { *$offset_array.offset($i as isize) };
let stop = if i < $hi {
unsafe { *$offset_array.offset(($i + 1) as isize) }
} else {
$offset_array_len as tsk_size_t
};
if start == stop {
None
} else {
let mut buffer = vec![];
for i in start..stop {
buffer.push(unsafe { *$array.offset(i as isize) });
}
Some(buffer)
debug_assert!(!$owner.$array.is_null());
if $owner.$array.is_null() {
return None;
}
}
}};

($i: expr, $lo: expr, $hi: expr, $array: expr, $offset_array: expr, $offset_array_len: expr, $output_id_type: ty) => {{
let i = $crate::SizeType::try_from($i).ok()?;
if $i < $lo || i >= $hi {
None
} else if $offset_array_len == 0 {
None
} else {
let start = unsafe { *$offset_array.offset($i as isize) };
// SAFETY: we have checked bounds and ensured not null
let start = unsafe { *$owner.$offset_array.offset($i as isize) };
let stop = if i < $hi {
unsafe { *$offset_array.offset(($i + 1) as isize) }
unsafe { *$owner.$offset_array.offset(($i + 1) as isize) }
} else {
$offset_array_len as tsk_size_t
$owner.$offset_array_len as tsk_size_t
};
if start == stop {
None
} else {
Some(unsafe {
std::slice::from_raw_parts(
$array.offset(start as isize) as *const $output_id_type,
$owner.$array.offset(start as isize) as *const $output_id_type,
stop as usize - start as usize,
)
})
Expand All @@ -107,25 +103,27 @@ macro_rules! unsafe_tsk_ragged_column_access {
// to pass clippy checks
#[allow(unused_macros)]
macro_rules! unsafe_tsk_ragged_char_column_access {
($i: expr, $lo: expr, $hi: expr, $array: expr, $offset_array: expr, $offset_array_len: expr) => {{
($i: expr, $lo: expr, $hi: expr, $owner: expr, $array: ident, $offset_array: ident, $offset_array_len: ident) => {{
let i = $crate::SizeType::try_from($i)?;
if $i < $lo || i >= $hi {
Err(TskitError::IndexError {})
} else if $offset_array_len == 0 {
} else if $owner.$offset_array_len == 0 {
Ok(None)
} else {
let start = unsafe { *$offset_array.offset($i as isize) };
assert!(!$owner.$array.is_null());
assert!(!$owner.$offset_array.is_null());
let start = unsafe { *$owner.$offset_array.offset($i as isize) };
let stop = if i < $hi {
unsafe { *$offset_array.offset(($i + 1) as isize) }
unsafe { *$owner.$offset_array.offset(($i + 1) as isize) }
} else {
$offset_array_len as tsk_size_t
$owner.$offset_array_len as tsk_size_t
};
if start == stop {
Ok(None)
} else {
let mut buffer = String::new();
for i in start..stop {
buffer.push(unsafe { *$array.offset(i as isize) as u8 as char });
buffer.push(unsafe { *$owner.$array.offset(i as isize) as u8 as char });
}
Ok(Some(buffer))
}
Expand Down
22 changes: 18 additions & 4 deletions src/edge_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,14 @@ impl<'a> EdgeTable<'a> {
/// * `Some(parent)` if `u` is valid.
/// * `None` otherwise.
pub fn parent<E: Into<EdgeId> + Copy>(&'a self, row: E) -> Option<NodeId> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.parent, NodeId)
unsafe_tsk_column_access!(
row.into().0,
0,
self.num_rows(),
self.table_,
parent,
NodeId
)
}

/// Return the ``child`` value from row ``row`` of the table.
Expand All @@ -99,7 +106,7 @@ impl<'a> EdgeTable<'a> {
/// * `Some(child)` if `u` is valid.
/// * `None` otherwise.
pub fn child<E: Into<EdgeId> + Copy>(&'a self, row: E) -> Option<NodeId> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.child, NodeId)
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_, child, NodeId)
}

/// Return the ``left`` value from row ``row`` of the table.
Expand All @@ -109,7 +116,14 @@ impl<'a> EdgeTable<'a> {
/// * `Some(position)` if `u` is valid.
/// * `None` otherwise.
pub fn left<E: Into<EdgeId> + Copy>(&'a self, row: E) -> Option<Position> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.left, Position)
unsafe_tsk_column_access!(
row.into().0,
0,
self.num_rows(),
self.table_,
left,
Position
)
}

/// Return the ``right`` value from row ``row`` of the table.
Expand All @@ -119,7 +133,7 @@ impl<'a> EdgeTable<'a> {
/// * `Some(position)` if `u` is valid.
/// * `None` otherwise.
pub fn right<E: Into<EdgeId> + Copy>(&'a self, row: E) -> Option<Position> {
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_.right)
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_, right)
}

pub fn metadata<T: metadata::MetadataRoundtrip>(
Expand Down
16 changes: 9 additions & 7 deletions src/individual_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl<'a> IndividualTable<'a> {
/// * `Some(flags)` if `row` is valid.
/// * `None` otherwise.
pub fn flags<I: Into<IndividualId> + Copy>(&self, row: I) -> Option<IndividualFlags> {
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_.flags)
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_, flags)
}

/// Return the locations for a given row.
Expand All @@ -121,9 +121,10 @@ impl<'a> IndividualTable<'a> {
row.into().0,
0,
self.num_rows(),
self.table_.location,
self.table_.location_offset,
self.table_.location_length,
self.table_,
location,
location_offset,
location_length,
Location
)
}
Expand All @@ -139,9 +140,10 @@ impl<'a> IndividualTable<'a> {
row.into().0,
0,
self.num_rows(),
self.table_.parents,
self.table_.parents_offset,
self.table_.parents_length,
self.table_,
parents,
parents_offset,
parents_length,
IndividualId
)
}
Expand Down
9 changes: 9 additions & 0 deletions src/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,15 @@ pub(crate) fn char_column_to_slice<T: Sized>(
Some(x) => x,
None => return None,
};
debug_assert!(!column.is_null());
debug_assert!(!column_offset.is_null());
if column.is_null() {
return None;
}
if column_offset.is_null() {
return None;
}
// SAFETY: not null and best effort bounds check
let start = unsafe { *column_offset.offset(row_isize) };
let stop = if (row as tsk_size_t) < num_rows {
unsafe { *column_offset.offset(row_isize + 1) }
Expand Down
14 changes: 8 additions & 6 deletions src/migration_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ impl<'a> MigrationTable<'a> {
/// * `Some(position)` if `row` is valid.
/// * `None` otherwise.
pub fn left<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Option<Position> {
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_.left)
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_, left)
}

/// Return the right coordinate for a given row.
Expand All @@ -108,7 +108,7 @@ impl<'a> MigrationTable<'a> {
/// * `Some(positions)` if `row` is valid.
/// * `None` otherwise.
pub fn right<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Option<Position> {
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_.right)
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_, right)
}

/// Return the node for a given row.
Expand All @@ -118,7 +118,7 @@ impl<'a> MigrationTable<'a> {
/// * `Some(node)` if `row` is valid.
/// * `None` otherwise.
pub fn node<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Option<NodeId> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.node, NodeId)
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_, node, NodeId)
}

/// Return the source population for a given row.
Expand All @@ -132,7 +132,8 @@ impl<'a> MigrationTable<'a> {
row.into().0,
0,
self.num_rows(),
self.table_.source,
self.table_,
source,
PopulationId
)
}
Expand All @@ -148,7 +149,8 @@ impl<'a> MigrationTable<'a> {
row.into().0,
0,
self.num_rows(),
self.table_.dest,
self.table_,
dest,
PopulationId
)
}
Expand All @@ -160,7 +162,7 @@ impl<'a> MigrationTable<'a> {
/// * `Some(time)` if `row` is valid.
/// * `None` otherwise.
pub fn time<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Option<Time> {
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_.time)
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_, time)
}

/// Return the metadata for a given row.
Expand Down
9 changes: 5 additions & 4 deletions src/mutation_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl<'a> MutationTable<'a> {
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn site<M: Into<MutationId> + Copy>(&'a self, row: M) -> Option<SiteId> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.site, SiteId)
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_, site, SiteId)
}

/// Return the ``node`` value from row ``row`` of the table.
Expand All @@ -110,7 +110,7 @@ impl<'a> MutationTable<'a> {
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn node<M: Into<MutationId> + Copy>(&'a self, row: M) -> Option<NodeId> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.node, NodeId)
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_, node, NodeId)
}

/// Return the ``parent`` value from row ``row`` of the table.
Expand All @@ -124,7 +124,8 @@ impl<'a> MutationTable<'a> {
row.into().0,
0,
self.num_rows(),
self.table_.parent,
self.table_,
parent,
MutationId
)
}
Expand All @@ -136,7 +137,7 @@ impl<'a> MutationTable<'a> {
/// Will return [``IndexError``](crate::TskitError::IndexError)
/// if ``row`` is out of range.
pub fn time<M: Into<MutationId> + Copy>(&'a self, row: M) -> Option<Time> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.time, Time)
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_, time, Time)
}

/// Get the ``derived_state`` value from row ``row`` of the table.
Expand Down
10 changes: 6 additions & 4 deletions src/node_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl<'a> NodeTable<'a> {
/// # }
/// ```
pub fn time<N: Into<NodeId> + Copy>(&'a self, row: N) -> Option<Time> {
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.time, Time)
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_, time, Time)
}

/// Return the ``flags`` value from row ``row`` of the table.
Expand All @@ -130,7 +130,7 @@ impl<'a> NodeTable<'a> {
/// # }
/// ```
pub fn flags<N: Into<NodeId> + Copy>(&'a self, row: N) -> Option<NodeFlags> {
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_.flags)
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_, flags)
}

/// Mutable access to node flags.
Expand Down Expand Up @@ -275,7 +275,8 @@ impl<'a> NodeTable<'a> {
row.into().0,
0,
self.num_rows(),
self.table_.population,
self.table_,
population,
PopulationId
)
}
Expand Down Expand Up @@ -320,7 +321,8 @@ impl<'a> NodeTable<'a> {
row.into().0,
0,
self.num_rows(),
self.table_.individual,
self.table_,
individual,
IndividualId
)
}
Expand Down
Loading