Skip to content

Commit 2cdfbd9

Browse files
committed
refactor: improve safety of tsk array access
1 parent 5d09e58 commit 2cdfbd9

File tree

10 files changed

+127
-87
lines changed

10 files changed

+127
-87
lines changed

src/_macros.rs

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -28,73 +28,69 @@ macro_rules! panic_on_tskit_error {
2828
}
2929

3030
macro_rules! unsafe_tsk_column_access {
31-
($i: expr, $lo: expr, $hi: expr, $array: expr) => {{
31+
($i: expr, $lo: expr, $hi: expr, $owner: expr, $array: ident) => {{
3232
if $i < $lo || ($i as $crate::tsk_size_t) >= $hi {
3333
None
3434
} else {
35-
Some(unsafe { *$array.offset($i as isize) })
35+
debug_assert!(!($owner).$array.is_null());
36+
if !$owner.$array.is_null() {
37+
// SAFETY: array is not null
38+
// and we did our best effort
39+
// on bounds checking
40+
Some(unsafe { *$owner.$array.offset($i as isize) })
41+
} else {
42+
None
43+
}
3644
}
3745
}};
38-
($i: expr, $lo: expr, $hi: expr, $array: expr, $output_id_type: expr) => {{
46+
($i: expr, $lo: expr, $hi: expr, $owner: expr, $array: ident, $output_id_type: expr) => {{
3947
if $i < $lo || ($i as $crate::tsk_size_t) >= $hi {
4048
None
4149
} else {
42-
Some($output_id_type(unsafe { *$array.offset($i as isize) }))
50+
debug_assert!(!($owner).$array.is_null());
51+
if !$owner.$array.is_null() {
52+
// SAFETY: array is not null
53+
// and we did our best effort
54+
// on bounds checking
55+
unsafe { Some($output_id_type(*($owner.$array.offset($i as isize)))) }
56+
} else {
57+
None
58+
}
4359
}
4460
}};
4561
}
4662

4763
macro_rules! unsafe_tsk_column_access_and_map_into {
48-
($i: expr, $lo: expr, $hi: expr, $array: expr) => {{
49-
unsafe_tsk_column_access!($i, $lo, $hi, $array).map(|v| v.into())
64+
($i: expr, $lo: expr, $hi: expr, $owner: expr, $array: ident) => {{
65+
unsafe_tsk_column_access!($i, $lo, $hi, $owner, $array).map(|v| v.into())
5066
}};
5167
}
5268

5369
macro_rules! unsafe_tsk_ragged_column_access {
54-
($i: expr, $lo: expr, $hi: expr, $array: expr, $offset_array: expr, $offset_array_len: expr) => {{
70+
($i: expr, $lo: expr, $hi: expr, $owner: expr, $array: ident, $offset_array: ident, $offset_array_len: ident, $output_id_type: ty) => {{
5571
let i = $crate::SizeType::try_from($i).ok()?;
5672
if $i < $lo || i >= $hi {
5773
None
58-
} else if $offset_array_len == 0 {
74+
} else if $owner.$offset_array_len == 0 {
5975
None
6076
} else {
61-
let start = unsafe { *$offset_array.offset($i as isize) };
62-
let stop = if i < $hi {
63-
unsafe { *$offset_array.offset(($i + 1) as isize) }
64-
} else {
65-
$offset_array_len as tsk_size_t
66-
};
67-
if start == stop {
68-
None
69-
} else {
70-
let mut buffer = vec![];
71-
for i in start..stop {
72-
buffer.push(unsafe { *$array.offset(i as isize) });
73-
}
74-
Some(buffer)
77+
debug_assert!(!$owner.$array.is_null());
78+
if $owner.$array.is_null() {
79+
return None;
7580
}
76-
}
77-
}};
78-
79-
($i: expr, $lo: expr, $hi: expr, $array: expr, $offset_array: expr, $offset_array_len: expr, $output_id_type: ty) => {{
80-
let i = $crate::SizeType::try_from($i).ok()?;
81-
if $i < $lo || i >= $hi {
82-
None
83-
} else if $offset_array_len == 0 {
84-
None
85-
} else {
86-
let start = unsafe { *$offset_array.offset($i as isize) };
81+
// SAFETY: we have checked bounds and ensured not null
82+
let start = unsafe { *$owner.$offset_array.offset($i as isize) };
8783
let stop = if i < $hi {
88-
unsafe { *$offset_array.offset(($i + 1) as isize) }
84+
unsafe { *$owner.$offset_array.offset(($i + 1) as isize) }
8985
} else {
90-
$offset_array_len as tsk_size_t
86+
$owner.$offset_array_len as tsk_size_t
9187
};
9288
if start == stop {
9389
None
9490
} else {
9591
Some(unsafe {
9692
std::slice::from_raw_parts(
97-
$array.offset(start as isize) as *const $output_id_type,
93+
$owner.$array.offset(start as isize) as *const $output_id_type,
9894
stop as usize - start as usize,
9995
)
10096
})
@@ -107,25 +103,27 @@ macro_rules! unsafe_tsk_ragged_column_access {
107103
// to pass clippy checks
108104
#[allow(unused_macros)]
109105
macro_rules! unsafe_tsk_ragged_char_column_access {
110-
($i: expr, $lo: expr, $hi: expr, $array: expr, $offset_array: expr, $offset_array_len: expr) => {{
106+
($i: expr, $lo: expr, $hi: expr, $owner: expr, $array: ident, $offset_array: ident, $offset_array_len: ident) => {{
111107
let i = $crate::SizeType::try_from($i)?;
112108
if $i < $lo || i >= $hi {
113109
Err(TskitError::IndexError {})
114-
} else if $offset_array_len == 0 {
110+
} else if $owner.$offset_array_len == 0 {
115111
Ok(None)
116112
} else {
117-
let start = unsafe { *$offset_array.offset($i as isize) };
113+
assert!(!$owner.$array.is_null());
114+
assert!(!$owner.$offset_array.is_null());
115+
let start = unsafe { *$owner.$offset_array.offset($i as isize) };
118116
let stop = if i < $hi {
119-
unsafe { *$offset_array.offset(($i + 1) as isize) }
117+
unsafe { *$owner.$offset_array.offset(($i + 1) as isize) }
120118
} else {
121-
$offset_array_len as tsk_size_t
119+
$owner.$offset_array_len as tsk_size_t
122120
};
123121
if start == stop {
124122
Ok(None)
125123
} else {
126124
let mut buffer = String::new();
127125
for i in start..stop {
128-
buffer.push(unsafe { *$array.offset(i as isize) as u8 as char });
126+
buffer.push(unsafe { *$owner.$array.offset(i as isize) as u8 as char });
129127
}
130128
Ok(Some(buffer))
131129
}

src/edge_table.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,14 @@ impl<'a> EdgeTable<'a> {
8989
/// * `Some(parent)` if `row` is valid.
9090
/// * `None` otherwise.
9191
pub fn parent<E: Into<EdgeId> + Copy>(&'a self, row: E) -> Option<NodeId> {
92-
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.parent, NodeId)
92+
unsafe_tsk_column_access!(
93+
row.into().0,
94+
0,
95+
self.num_rows(),
96+
self.table_,
97+
parent,
98+
NodeId
99+
)
93100
}
94101

95102
/// Return the ``child`` value from row ``row`` of the table.
@@ -99,7 +106,7 @@ impl<'a> EdgeTable<'a> {
99106
/// * `Some(child)` if `row` is valid.
100107
/// * `None` otherwise.
101108
pub fn child<E: Into<EdgeId> + Copy>(&'a self, row: E) -> Option<NodeId> {
102-
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.child, NodeId)
109+
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_, child, NodeId)
103110
}
104111

105112
/// Return the ``left`` value from row ``row`` of the table.
@@ -109,7 +116,14 @@ impl<'a> EdgeTable<'a> {
109116
/// * `Some(position)` if `row` is valid.
110117
/// * `None` otherwise.
111118
pub fn left<E: Into<EdgeId> + Copy>(&'a self, row: E) -> Option<Position> {
112-
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.left, Position)
119+
unsafe_tsk_column_access!(
120+
row.into().0,
121+
0,
122+
self.num_rows(),
123+
self.table_,
124+
left,
125+
Position
126+
)
113127
}
114128

115129
/// Return the ``right`` value from row ``row`` of the table.
@@ -119,7 +133,7 @@ impl<'a> EdgeTable<'a> {
119133
/// * `Some(position)` if `row` is valid.
120134
/// * `None` otherwise.
121135
pub fn right<E: Into<EdgeId> + Copy>(&'a self, row: E) -> Option<Position> {
122-
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_.right)
136+
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_, right)
123137
}
124138

125139
pub fn metadata<T: metadata::MetadataRoundtrip>(

src/individual_table.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ impl<'a> IndividualTable<'a> {
107107
/// * `Some(flags)` if `row` is valid.
108108
/// * `None` otherwise.
109109
pub fn flags<I: Into<IndividualId> + Copy>(&self, row: I) -> Option<IndividualFlags> {
110-
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_.flags)
110+
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_, flags)
111111
}
112112

113113
/// Return the locations for a given row.
@@ -121,9 +121,10 @@ impl<'a> IndividualTable<'a> {
121121
row.into().0,
122122
0,
123123
self.num_rows(),
124-
self.table_.location,
125-
self.table_.location_offset,
126-
self.table_.location_length,
124+
self.table_,
125+
location,
126+
location_offset,
127+
location_length,
127128
Location
128129
)
129130
}
@@ -139,9 +140,10 @@ impl<'a> IndividualTable<'a> {
139140
row.into().0,
140141
0,
141142
self.num_rows(),
142-
self.table_.parents,
143-
self.table_.parents_offset,
144-
self.table_.parents_length,
143+
self.table_,
144+
parents,
145+
parents_offset,
146+
parents_length,
145147
IndividualId
146148
)
147149
}

src/metadata.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,15 @@ pub(crate) fn char_column_to_slice<T: Sized>(
269269
Some(x) => x,
270270
None => return None,
271271
};
272+
debug_assert!(!column.is_null());
273+
debug_assert!(!column_offset.is_null());
274+
if column.is_null() {
275+
return None;
276+
}
277+
if column_offset.is_null() {
278+
return None;
279+
}
280+
// SAFETY: not null and best effort bounds check
272281
let start = unsafe { *column_offset.offset(row_isize) };
273282
let stop = if (row as tsk_size_t) < num_rows {
274283
unsafe { *column_offset.offset(row_isize + 1) }

src/migration_table.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ impl<'a> MigrationTable<'a> {
9898
/// * `Some(position)` if `row` is valid.
9999
/// * `None` otherwise.
100100
pub fn left<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Option<Position> {
101-
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_.left)
101+
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_, left)
102102
}
103103

104104
/// Return the right coordinate for a given row.
@@ -108,7 +108,7 @@ impl<'a> MigrationTable<'a> {
108108
/// * `Some(positions)` if `row` is valid.
109109
/// * `None` otherwise.
110110
pub fn right<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Option<Position> {
111-
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_.right)
111+
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_, right)
112112
}
113113

114114
/// Return the node for a given row.
@@ -118,7 +118,7 @@ impl<'a> MigrationTable<'a> {
118118
/// * `Some(node)` if `row` is valid.
119119
/// * `None` otherwise.
120120
pub fn node<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Option<NodeId> {
121-
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.node, NodeId)
121+
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_, node, NodeId)
122122
}
123123

124124
/// Return the source population for a given row.
@@ -132,7 +132,8 @@ impl<'a> MigrationTable<'a> {
132132
row.into().0,
133133
0,
134134
self.num_rows(),
135-
self.table_.source,
135+
self.table_,
136+
source,
136137
PopulationId
137138
)
138139
}
@@ -148,7 +149,8 @@ impl<'a> MigrationTable<'a> {
148149
row.into().0,
149150
0,
150151
self.num_rows(),
151-
self.table_.dest,
152+
self.table_,
153+
dest,
152154
PopulationId
153155
)
154156
}
@@ -160,7 +162,7 @@ impl<'a> MigrationTable<'a> {
160162
/// * `Some(time)` if `row` is valid.
161163
/// * `None` otherwise.
162164
pub fn time<M: Into<MigrationId> + Copy>(&'a self, row: M) -> Option<Time> {
163-
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_.time)
165+
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_, time)
164166
}
165167

166168
/// Return the metadata for a given row.

src/mutation_table.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ impl<'a> MutationTable<'a> {
100100
/// Will return `None` (the signature is `Option`, not a `Result` carrying [``IndexError``](crate::TskitError::IndexError))
101101
/// if ``row`` is out of range.
102102
pub fn site<M: Into<MutationId> + Copy>(&'a self, row: M) -> Option<SiteId> {
103-
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.site, SiteId)
103+
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_, site, SiteId)
104104
}
105105

106106
/// Return the ``node`` value from row ``row`` of the table.
@@ -110,7 +110,7 @@ impl<'a> MutationTable<'a> {
110110
/// Will return `None` (the signature is `Option`, not a `Result` carrying [``IndexError``](crate::TskitError::IndexError))
111111
/// if ``row`` is out of range.
112112
pub fn node<M: Into<MutationId> + Copy>(&'a self, row: M) -> Option<NodeId> {
113-
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.node, NodeId)
113+
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_, node, NodeId)
114114
}
115115

116116
/// Return the ``parent`` value from row ``row`` of the table.
@@ -124,7 +124,8 @@ impl<'a> MutationTable<'a> {
124124
row.into().0,
125125
0,
126126
self.num_rows(),
127-
self.table_.parent,
127+
self.table_,
128+
parent,
128129
MutationId
129130
)
130131
}
@@ -136,7 +137,7 @@ impl<'a> MutationTable<'a> {
136137
/// Will return `None` (the signature is `Option`, not a `Result` carrying [``IndexError``](crate::TskitError::IndexError))
137138
/// if ``row`` is out of range.
138139
pub fn time<M: Into<MutationId> + Copy>(&'a self, row: M) -> Option<Time> {
139-
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.time, Time)
140+
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_, time, Time)
140141
}
141142

142143
/// Get the ``derived_state`` value from row ``row`` of the table.

src/node_table.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ impl<'a> NodeTable<'a> {
105105
/// # }
106106
/// ```
107107
pub fn time<N: Into<NodeId> + Copy>(&'a self, row: N) -> Option<Time> {
108-
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_.time, Time)
108+
unsafe_tsk_column_access!(row.into().0, 0, self.num_rows(), self.table_, time, Time)
109109
}
110110

111111
/// Return the ``flags`` value from row ``row`` of the table.
@@ -130,7 +130,7 @@ impl<'a> NodeTable<'a> {
130130
/// # }
131131
/// ```
132132
pub fn flags<N: Into<NodeId> + Copy>(&'a self, row: N) -> Option<NodeFlags> {
133-
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_.flags)
133+
unsafe_tsk_column_access_and_map_into!(row.into().0, 0, self.num_rows(), self.table_, flags)
134134
}
135135

136136
/// Mutable access to node flags.
@@ -275,7 +275,8 @@ impl<'a> NodeTable<'a> {
275275
row.into().0,
276276
0,
277277
self.num_rows(),
278-
self.table_.population,
278+
self.table_,
279+
population,
279280
PopulationId
280281
)
281282
}
@@ -320,7 +321,8 @@ impl<'a> NodeTable<'a> {
320321
row.into().0,
321322
0,
322323
self.num_rows(),
323-
self.table_.individual,
324+
self.table_,
325+
individual,
324326
IndividualId
325327
)
326328
}

0 commit comments

Comments
 (0)