diff --git a/src/_macros.rs b/src/_macros.rs index ee4967881..eb99bcf65 100644 --- a/src/_macros.rs +++ b/src/_macros.rs @@ -1104,6 +1104,29 @@ macro_rules! row_lending_iterator_get { }; } +macro_rules! optional_container_comparison { + ($lhs: expr, $rhs: expr) => { + if let Some(value) = &$lhs { + if let Some(ovalue) = &$rhs { + if value.len() != ovalue.len() { + return false; + } + if value.iter().zip(ovalue.iter()).any(|(a, b)| a != b) { + false + } else { + true + } + } else { + false + } + } else if $rhs.is_some() { + false + } else { + true + } + }; +} + #[cfg(test)] mod test { use crate::error::TskitError; diff --git a/src/edge_table.rs b/src/edge_table.rs index f18b2893f..296b453dd 100644 --- a/src/edge_table.rs +++ b/src/edge_table.rs @@ -65,6 +65,7 @@ impl Iterator for EdgeTableIterator { } /// Row of an [`EdgeTable`] +#[derive(Debug)] pub struct EdgeTableRowView<'a> { table: &'a EdgeTable, pub id: EdgeId, @@ -89,6 +90,41 @@ impl<'a> EdgeTableRowView<'a> { } } +impl<'a> PartialEq for EdgeTableRowView<'a> { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + && self.parent == other.parent + && self.child == other.child + && crate::util::partial_cmp_equal(&self.left, &other.left) + && crate::util::partial_cmp_equal(&self.right, &other.right) + && self.metadata == other.metadata + } +} + +impl<'a> Eq for EdgeTableRowView<'a> {} + +impl<'a> PartialEq for EdgeTableRowView<'a> { + fn eq(&self, other: &EdgeTableRow) -> bool { + self.id == other.id + && self.parent == other.parent + && self.child == other.child + && crate::util::partial_cmp_equal(&self.left, &other.left) + && crate::util::partial_cmp_equal(&self.right, &other.right) + && optional_container_comparison!(self.metadata, other.metadata) + } +} + +impl PartialEq> for EdgeTableRow { + fn eq(&self, other: &EdgeTableRowView) -> bool { + self.id == other.id + && self.parent == other.parent + && self.child == other.child + && crate::util::partial_cmp_equal(&self.left, &other.left) + && crate::util::partial_cmp_equal(&self.right, &other.right) + && optional_container_comparison!(self.metadata, other.metadata) + } +} + impl<'a> streaming_iterator::StreamingIterator for EdgeTableRowView<'a> { type Item = Self; @@ -110,6 +146,7 @@ impl<'a> streaming_iterator::StreamingIterator for EdgeTableRowView<'a> { /// by types implementing [`std::ops::Deref`] to /// [`crate::table_views::TableViews`] #[repr(transparent)] +#[derive(Debug)] pub struct EdgeTable { pub(crate) table_: NonNull, } diff --git a/src/individual_table.rs b/src/individual_table.rs index bc782b81b..8f52cb586 100644 --- a/src/individual_table.rs +++ b/src/individual_table.rs @@ -24,27 +24,11 @@ impl PartialEq for IndividualTableRow { && self.flags == other.flags && self.parents == other.parents && self.metadata == other.metadata - && match &self.location { - None => other.location.is_none(), - Some(a) => match &other.location { - None => false, - Some(b) => { - if a.len() != b.len() { - false - } else { - for (i, j) in a.iter().enumerate() { - if !crate::util::partial_cmp_equal(&b[i], j) { - return false; - } - } - true - } - } - }, - } + && self.location == other.location } } +#[derive(Debug)] pub struct IndividualTableRowView<'a> { table: &'a IndividualTable, pub id: IndividualId, @@ -67,6 +51,38 @@ impl<'a> IndividualTableRowView<'a> { } } +impl<'a> PartialEq for IndividualTableRowView<'a> { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + && self.flags == other.flags + && self.parents == other.parents + && self.metadata == other.metadata + && self.location == other.location + } +} + +impl<'a> Eq for IndividualTableRowView<'a> {} + +impl<'a> PartialEq for IndividualTableRowView<'a> { + fn eq(&self, other: &IndividualTableRow) -> bool { + self.id == other.id + && self.flags == other.flags + && optional_container_comparison!(self.parents, other.parents) + && optional_container_comparison!(self.metadata, other.metadata) + && optional_container_comparison!(self.location, other.location) + } +} + +impl PartialEq> for IndividualTableRow { + fn eq(&self, other: &IndividualTableRowView) -> bool { + self.id == other.id + && self.flags == other.flags + && optional_container_comparison!(self.parents, other.parents) + && optional_container_comparison!(self.metadata, other.metadata) + && optional_container_comparison!(self.location, other.location) + } +} + impl<'a> streaming_iterator::StreamingIterator for IndividualTableRowView<'a> { type Item = Self; @@ -86,6 +102,7 @@ impl<'a> streaming_iterator::StreamingIterator for IndividualTableRowView<'a> { /// These are not created directly but are accessed /// by types implementing [`std::ops::Deref`] to /// [`crate::table_views::TableViews`] +#[derive(Debug)] pub struct IndividualTable { table_: NonNull, } diff --git a/src/migration_table.rs b/src/migration_table.rs index 18f668607..7e3f4b339 100644 --- a/src/migration_table.rs +++ b/src/migration_table.rs @@ -73,6 +73,7 @@ impl Iterator for MigrationTableIterator { } } +#[derive(Debug)] pub struct MigrationTableRowView<'a> { table: &'a MigrationTable, pub id: MigrationId, @@ -101,6 +102,47 @@ impl<'a> MigrationTableRowView<'a> { } } +impl<'a> PartialEq for MigrationTableRowView<'a> { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + && self.node == other.node + && self.source == other.source + && self.dest == other.dest + && crate::util::partial_cmp_equal(&self.left, &other.left) + && crate::util::partial_cmp_equal(&self.right, &other.right) + && crate::util::partial_cmp_equal(&self.time, &other.time) + && self.metadata == other.metadata + } +} + +impl<'a> Eq for MigrationTableRowView<'a> {} + +impl<'a> PartialEq for MigrationTableRowView<'a> { + fn eq(&self, other: &MigrationTableRow) -> bool { + self.id == other.id + && self.node == other.node + && self.source == other.source + && self.dest == other.dest + && crate::util::partial_cmp_equal(&self.left, &other.left) + && crate::util::partial_cmp_equal(&self.right, &other.right) + && crate::util::partial_cmp_equal(&self.time, &other.time) + && optional_container_comparison!(self.metadata, other.metadata) + } +} + +impl PartialEq> for MigrationTableRow { + fn eq(&self, other: &MigrationTableRowView) -> bool { + self.id == other.id + && self.node == other.node + && self.source == other.source + && self.dest == other.dest + && crate::util::partial_cmp_equal(&self.left, &other.left) + && crate::util::partial_cmp_equal(&self.right, &other.right) + && crate::util::partial_cmp_equal(&self.time, &other.time) + && optional_container_comparison!(self.metadata, other.metadata) + } +} + impl<'a> streaming_iterator::StreamingIterator for MigrationTableRowView<'a> { type Item = Self; @@ -123,6 +165,7 @@ impl<'a> streaming_iterator::StreamingIterator for MigrationTableRowView<'a> { /// These are not created directly but are accessed /// by types implementing [`std::ops::Deref`] to /// [`crate::table_views::TableViews`] +#[derive(Debug)] pub struct MigrationTable { table_: NonNull, } diff --git a/src/mutation_table.rs b/src/mutation_table.rs index 6de31ef66..e78582d32 100644 --- a/src/mutation_table.rs +++ b/src/mutation_table.rs @@ -76,6 +76,7 @@ impl Iterator for MutationTableIterator { } } +#[derive(Debug)] pub struct MutationTableRowView<'a> { table: &'a MutationTable, pub id: MutationId, @@ -102,6 +103,44 @@ impl<'a> MutationTableRowView<'a> { } } +impl<'a> PartialEq for MutationTableRowView<'a> { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + && self.site == other.site + && self.node == other.node + && self.parent == other.parent + && crate::util::partial_cmp_equal(&self.time, &other.time) + && self.derived_state == other.derived_state + && self.metadata == other.metadata + } +} + +impl<'a> Eq for MutationTableRowView<'a> {} + +impl<'a> PartialEq for MutationTableRowView<'a> { + fn eq(&self, other: &MutationTableRow) -> bool { + self.id == other.id + && self.site == other.site + && self.node == other.node + && self.parent == other.parent + && crate::util::partial_cmp_equal(&self.time, &other.time) + && optional_container_comparison!(self.derived_state, other.derived_state) + && optional_container_comparison!(self.metadata, other.metadata) + } +} + +impl PartialEq> for MutationTableRow { + fn eq(&self, other: &MutationTableRowView) -> bool { + self.id == other.id + && self.site == other.site + && self.node == other.node + && self.parent == other.parent + && crate::util::partial_cmp_equal(&self.time, &other.time) + && optional_container_comparison!(self.derived_state, other.derived_state) + && optional_container_comparison!(self.metadata, other.metadata) + } +} + impl<'a> streaming_iterator::StreamingIterator for MutationTableRowView<'a> { type Item = Self; @@ -123,6 +162,7 @@ impl<'a> streaming_iterator::StreamingIterator for MutationTableRowView<'a> { /// These are not created directly but are accessed /// by types implementing [`std::ops::Deref`] to /// [`crate::table_views::TableViews`] +#[derive(Debug)] pub struct MutationTable { table_: NonNull, } diff --git a/src/node_table.rs b/src/node_table.rs index 272fd8298..023bce321 100644 --- a/src/node_table.rs +++ b/src/node_table.rs @@ -66,6 +66,7 @@ impl Iterator for NodeTableIterator { } } +#[derive(Debug)] pub struct NodeTableRowView<'a> { table: &'a NodeTable, pub id: NodeId, @@ -90,6 +91,41 @@ impl<'a> NodeTableRowView<'a> { } } +impl<'a> PartialEq for NodeTableRowView<'a> { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + && self.flags == other.flags + && self.population == other.population + && self.individual == other.individual + && crate::util::partial_cmp_equal(&self.time, &other.time) + && self.metadata == other.metadata + } +} + +impl<'a> Eq for NodeTableRowView<'a> {} + +impl<'a> PartialEq for NodeTableRowView<'a> { + fn eq(&self, other: &NodeTableRow) -> bool { + self.id == other.id + && self.flags == other.flags + && self.population == other.population + && self.individual == other.individual + && crate::util::partial_cmp_equal(&self.time, &other.time) + && optional_container_comparison!(self.metadata, other.metadata) + } +} + +impl PartialEq> for NodeTableRow { + fn eq(&self, other: &NodeTableRowView) -> bool { + self.id == other.id + && self.flags == other.flags + && self.population == other.population + && self.individual == other.individual + && crate::util::partial_cmp_equal(&self.time, &other.time) + && optional_container_comparison!(self.metadata, other.metadata) + } +} + impl<'a> streaming_iterator::StreamingIterator for NodeTableRowView<'a> { type Item = Self; @@ -110,6 +146,7 @@ impl<'a> streaming_iterator::StreamingIterator for NodeTableRowView<'a> { /// These are not created directly but are accessed /// by types implementing [`std::ops::Deref`] to /// [`crate::table_views::TableViews`] +#[derive(Debug)] pub struct NodeTable { table_: NonNull, } diff --git a/src/population_table.rs b/src/population_table.rs index d8b25a285..eb94200ff 100644 --- a/src/population_table.rs +++ b/src/population_table.rs @@ -61,6 +61,7 @@ impl Iterator for PopulationTableIterator { } } +#[derive(Debug)] pub struct PopulationTableRowView<'a> { table: &'a PopulationTable, pub id: PopulationId, @@ -77,6 +78,26 @@ impl<'a> PopulationTableRowView<'a> { } } +impl<'a> PartialEq for PopulationTableRowView<'a> { + fn eq(&self, other: &Self) -> bool { + self.id == other.id && self.metadata == other.metadata + } +} + +impl<'a> Eq for PopulationTableRowView<'a> {} + +impl<'a> PartialEq for PopulationTableRowView<'a> { + fn eq(&self, other: &PopulationTableRow) -> bool { + self.id == other.id && optional_container_comparison!(self.metadata, other.metadata) + } +} + +impl PartialEq> for PopulationTableRow { + fn eq(&self, other: &PopulationTableRowView) -> bool { + self.id == other.id && optional_container_comparison!(self.metadata, other.metadata) + } +} + impl<'a> streaming_iterator::StreamingIterator for PopulationTableRowView<'a> { type Item = Self; @@ -94,6 +115,7 @@ impl<'a> streaming_iterator::StreamingIterator for PopulationTableRowView<'a> { /// by types implementing [`std::ops::Deref`] to /// [`crate::table_views::TableViews`] #[repr(transparent)] +#[derive(Debug)] pub struct PopulationTable { table_: NonNull, } diff --git a/src/site_table.rs b/src/site_table.rs index 97f7da738..f33c396b5 100644 --- a/src/site_table.rs +++ b/src/site_table.rs @@ -61,6 +61,7 @@ impl Iterator for SiteTableIterator { } } +#[derive(Debug)] pub struct SiteTableRowView<'a> { table: &'a SiteTable, pub id: SiteId, @@ -81,6 +82,35 @@ impl<'a> SiteTableRowView<'a> { } } +impl<'a> PartialEq for SiteTableRowView<'a> { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + && crate::util::partial_cmp_equal(&self.position, &other.position) + && self.ancestral_state == other.ancestral_state + && self.metadata == other.metadata + } +} + +impl<'a> Eq for SiteTableRowView<'a> {} + +impl<'a> PartialEq for SiteTableRowView<'a> { + fn eq(&self, other: &SiteTableRow) -> bool { + self.id == other.id + && crate::util::partial_cmp_equal(&self.position, &other.position) + && optional_container_comparison!(self.ancestral_state, other.ancestral_state) + && optional_container_comparison!(self.metadata, other.metadata) + } +} + +impl PartialEq> for SiteTableRow { + fn eq(&self, other: &SiteTableRowView) -> bool { + self.id == other.id + && crate::util::partial_cmp_equal(&self.position, &other.position) + && optional_container_comparison!(self.ancestral_state, other.ancestral_state) + && optional_container_comparison!(self.metadata, other.metadata) + } +} + impl<'a> streaming_iterator::StreamingIterator for SiteTableRowView<'a> { type Item = Self; @@ -102,6 +132,7 @@ impl<'a> streaming_iterator::StreamingIterator for SiteTableRowView<'a> { /// These are not created directly but are accessed /// by types implementing [`std::ops::Deref`] to /// [`crate::table_views::TableViews`] +#[derive(Debug)] pub struct SiteTable { table_: NonNull, } diff --git a/tests/test_tables.rs b/tests/test_tables.rs index 315dcdce7..715a615e1 100644 --- a/tests/test_tables.rs +++ b/tests/test_tables.rs @@ -326,7 +326,8 @@ mod test_metadata_round_trips { let mut iter = tables.$table().iter(); while let Some(row) = lending_iter.next() { if let Some(row_from_iter) = iter.next() { - assert_eq!(row.id, row_from_iter.id); + assert_eq!(row, &row_from_iter); + assert_eq!(&row_from_iter, row); } if let Some(metadata) = row.metadata { assert_eq!(MyMetadata::decode(metadata).unwrap(), md); @@ -349,7 +350,8 @@ mod test_metadata_round_trips { let mut iter = tables.$table().iter(); while let Some(row) = lending_iter.next() { if let Some(row_from_iter) = iter.next() { - assert_eq!(row.id, row_from_iter.id); + assert_eq!(row, &row_from_iter); + assert_eq!(&row_from_iter, row); } if let Some(metadata) = row.metadata { assert_eq!(MyMetadata::decode(metadata).unwrap(), md);