Skip to content

Commit 061aaf5

Browse files
committed
Add assert_equal methods
1 parent 4ecb8cb commit 061aaf5

File tree

4 files changed

+454
-10
lines changed

4 files changed

+454
-10
lines changed

python/CHANGELOG.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
- Improve display of tables when ``print``ed, limiting lines set via
3232
``tskit.set_print_options`` (:user:`benjeffery`,:issue:`1270`, :pr:`1300`).
3333

34+
- Add ``Table.assert_equals`` and ``TableCollection.assert_equals`` which give an exact
35+
report of any differences. (:user:`benjeffery`, :issue:`1076`, :pr:`1328`)
36+
3437
**Fixes**
3538

3639
- Tree sequences were not properly init'd after unpickling

python/tests/test_tables.py

Lines changed: 277 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import pickle
3333
import platform
3434
import random
35+
import re
3536
import struct
3637
import time
3738
import unittest
@@ -73,7 +74,7 @@ def get_input(self, n):
7374
class CharColumn(Column):
7475
def get_input(self, n):
7576
rng = np.random.RandomState(42)
76-
return rng.randint(low=0, high=127, size=n, dtype=np.int8)
77+
return rng.randint(low=65, high=122, size=n, dtype=np.int8)
7778

7879

7980
class DoubleColumn(Column):
@@ -1009,7 +1010,158 @@ def test_set_with_optional_properties(self, codec):
10091010
assert md == row.metadata
10101011

10111012

1012-
class TestIndividualTable(CommonTestsMixin, MetadataTestsMixin):
1013+
class AssertEqualsMixin:
    """
    Tests of ``assert_equals`` shared by the per-table test classes.

    The host class is expected to provide ``table_class`` and
    ``make_transposed_input_data``; neither is defined in this mixin.
    """

    # Columns whose values are compared as text in the error messages.
    _assert_equals_text_columns = (
        "timestamp",
        "record",
        "ancestral_state",
        "derived_state",
    )

    @pytest.fixture
    def test_rows(self):
        """
        Return ten rows of input data, each a dict of ``add_row`` keyword args.

        The fixture has the default (function) scope. A ``scope`` value given
        as a default argument of the fixture function is ignored by pytest, so
        it is deliberately not declared here.
        """
        test_rows = self.make_transposed_input_data(10)
        # Annoyingly we have to tweak some types as once added to a row and then
        # put in an error message things come out differently
        for row in test_rows:
            for col in self._assert_equals_text_columns:
                if col in row:
                    row[col] = bytes(row[col]).decode("ascii")
        return test_rows

    def _assert_equals_table_from_rows(self, rows):
        """Return a new ``table_class`` instance with ``rows`` added in order."""
        table = self.table_class()
        for row in rows:
            table.add_row(**row)
        return table

    def _assert_equals_changed_table(self, test_rows, column_names):
        """
        Return a five-row table equal to ``table1`` except in row 4, where the
        given columns take their values from ``test_rows[5]``.
        """
        changed_row = {
            **test_rows[4],
            **{name: test_rows[5][name] for name in column_names},
        }
        return self._assert_equals_table_from_rows([*test_rows[:4], changed_row])

    @pytest.fixture
    def table1(self, test_rows):
        """Return a table holding the first five test rows."""
        return self._assert_equals_table_from_rows(test_rows[:5])

    def test_equal(self, table1, test_rows):
        """Two tables built from the same rows compare equal."""
        table2 = self._assert_equals_table_from_rows(test_rows[:5])
        table1.assert_equals(table2)

    def test_type(self, table1):
        """Comparing against a non-table reports both types."""
        with pytest.raises(
            AssertionError,
            # Escaped so the dots in the class path are matched literally.
            match=re.escape(
                f"Types differ: self={type(table1)} other=<class 'int'>"
            ),
        ):
            table1.assert_equals(42)

    def test_metadata_schema(self, table1):
        """A differing metadata schema is reported unless metadata is ignored."""
        if not hasattr(table1, "metadata_schema"):
            return
        table2 = table1.copy()
        table2.metadata_schema = tskit.MetadataSchema({"codec": "json"})
        with pytest.raises(
            AssertionError,
            # Without escaping, "[('codec', 'json')]" is a regex character
            # class and only a prefix of the message would be checked.
            match=re.escape(
                f"{type(table1).__name__} metadata schemas differ: self=None "
                "other=OrderedDict([('codec', 'json')])"
            ),
        ):
            table1.assert_equals(table2)
        table1.assert_equals(table2, ignore_metadata=True)

    def test_row_changes(self, table1, test_rows):
        """Changing one or two columns of row 4 is reported column by column."""
        column_names = list(test_rows[0])
        for column_name in column_names:
            table2 = self._assert_equals_changed_table(test_rows, [column_name])
            with pytest.raises(
                AssertionError,
                match=re.escape(
                    f"{type(table1).__name__} row 4 differs:\n"
                    f"self.{column_name}={test_rows[4][column_name]} "
                    f"other.{column_name}={test_rows[5][column_name]}"
                ),
            ):
                table1.assert_equals(table2)

        # Two columns differ, as we don't know the order in the error message
        # test for both independently
        for column_name, column_name2 in zip(column_names[:-1], column_names[1:]):
            table2 = self._assert_equals_changed_table(
                test_rows, [column_name, column_name2]
            )
            for name in (column_name, column_name2):
                with pytest.raises(
                    AssertionError,
                    match=re.escape(
                        f"self.{name}={test_rows[4][name]} "
                        f"other.{name}={test_rows[5][name]}"
                    ),
                ):
                    table1.assert_equals(table2)

    def test_num_rows(self, table1, test_rows):
        """A differing row count is reported with both counts."""
        table2 = self._assert_equals_table_from_rows(test_rows[:4])
        with pytest.raises(
            AssertionError,
            match=f"{type(table1).__name__} number of rows differ: self=5 other=4",
        ):
            table1.assert_equals(table2)

    def test_metadata(self, table1, test_rows):
        """A metadata difference is reported unless ``ignore_metadata`` is set."""
        if "metadata" not in test_rows[0]:
            return
        table2 = self._assert_equals_changed_table(test_rows, ["metadata"])
        with pytest.raises(
            AssertionError,
            match=re.escape(
                f"{type(table1).__name__} row 4 differs:\n"
                f"self.metadata={test_rows[4]['metadata']} "
                f"other.metadata={test_rows[5]['metadata']}"
            ),
        ):
            table1.assert_equals(table2)
        table1.assert_equals(table2, ignore_metadata=True)

    def test_timestamp(self, table1, test_rows):
        """A timestamp difference is reported unless ``ignore_timestamps`` is set."""
        if "timestamp" not in test_rows[0]:
            return
        table2 = self._assert_equals_changed_table(test_rows, ["timestamp"])
        with pytest.raises(
            AssertionError,
            match=re.escape(
                f"{type(table1).__name__} row 4 differs:\n"
                f"self.timestamp={test_rows[4]['timestamp']} "
                f"other.timestamp={test_rows[5]['timestamp']}"
            ),
        ):
            table1.assert_equals(table2)
        table1.assert_equals(table2, ignore_timestamps=True)
1162+
1163+
1164+
class TestIndividualTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin):
10131165
columns = [UInt32Column("flags")]
10141166
ragged_list_columns = [
10151167
(DoubleColumn("location"), UInt32Column("location_offset")),
@@ -1138,7 +1290,7 @@ def test_various_not_equals(self):
11381290
assert a == b
11391291

11401292

1141-
class TestNodeTable(CommonTestsMixin, MetadataTestsMixin):
1293+
class TestNodeTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin):
11421294

11431295
columns = [
11441296
UInt32Column("flags"),
@@ -1219,7 +1371,7 @@ def test_add_row_bad_data(self):
12191371
t.add_row(metadata=123)
12201372

12211373

1222-
class TestEdgeTable(CommonTestsMixin, MetadataTestsMixin):
1374+
class TestEdgeTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin):
12231375

12241376
columns = [
12251377
DoubleColumn("left"),
@@ -1267,7 +1419,7 @@ def test_add_row_bad_data(self):
12671419
t.add_row(0, 0, 0, 0, metadata=123)
12681420

12691421

1270-
class TestSiteTable(CommonTestsMixin, MetadataTestsMixin):
1422+
class TestSiteTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin):
12711423
columns = [DoubleColumn("position")]
12721424
ragged_list_columns = [
12731425
(CharColumn("ancestral_state"), UInt32Column("ancestral_state_offset")),
@@ -1322,7 +1474,7 @@ def test_packset_ancestral_state(self):
13221474
assert np.array_equal(table.ancestral_state_offset, ancestral_state_offset)
13231475

13241476

1325-
class TestMutationTable(CommonTestsMixin, MetadataTestsMixin):
1477+
class TestMutationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin):
13261478
columns = [
13271479
Int32Column("site"),
13281480
Int32Column("node"),
@@ -1394,7 +1546,7 @@ def test_packset_derived_state(self):
13941546
assert np.array_equal(table.derived_state_offset, derived_state_offset)
13951547

13961548

1397-
class TestMigrationTable(CommonTestsMixin, MetadataTestsMixin):
1549+
class TestMigrationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin):
13981550
columns = [
13991551
DoubleColumn("left"),
14001552
DoubleColumn("right"),
@@ -1445,7 +1597,7 @@ def test_add_row_bad_data(self):
14451597
t.add_row(0, 0, 0, 0, 0, 0, metadata=123)
14461598

14471599

1448-
class TestProvenanceTable(CommonTestsMixin):
1600+
class TestProvenanceTable(CommonTestsMixin, AssertEqualsMixin):
14491601
columns = []
14501602
ragged_list_columns = [
14511603
(CharColumn("timestamp"), UInt32Column("timestamp_offset")),
@@ -1496,7 +1648,7 @@ def test_packset_record(self):
14961648
assert t[1].record == "BBBB"
14971649

14981650

1499-
class TestPopulationTable(CommonTestsMixin, MetadataTestsMixin):
1651+
class TestPopulationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin):
15001652
metadata_mandatory = True
15011653
columns = []
15021654
ragged_list_columns = [(CharColumn("metadata"), UInt32Column("metadata_offset"))]
@@ -3307,6 +3459,122 @@ def test_equals_population_metadata(self, ts_fixture):
33073459
assert t1.equals(t2, ignore_metadata=True)
33083460

33093461

3462+
class TestTableCollectionAssertEquals:
    """
    Tests of ``TableCollection.assert_equals`` and its ``ignore_*`` options.

    ``t1`` and ``t2`` are separate objects dumped from the same ``ts_fixture``;
    each test changes ``t2`` and checks the message raised by
    ``t1.assert_equals(t2)``.
    """

    @pytest.fixture
    def t1(self, ts_fixture):
        """Return the reference tables, which the tests leave unmodified."""
        return ts_fixture.dump_tables()

    @pytest.fixture
    def t2(self, ts_fixture):
        """Return a second set of tables for the tests to modify."""
        return ts_fixture.dump_tables()

    def test_equal(self, t1, t2):
        """Distinct but identical collections compare equal."""
        assert t1 is not t2
        t1.assert_equals(t2)

    def test_type(self, t1):
        """Comparing against a non-collection reports both types."""
        with pytest.raises(
            AssertionError,
            match=re.escape(
                "Types differ: self=<class 'tskit.tables.TableCollection'> "
                "other=<class 'int'>"
            ),
        ):
            t1.assert_equals(42)

    def test_sequence_length(self, t1, t2):
        """A differing sequence length is reported with both values."""
        t2.sequence_length = 42
        with pytest.raises(
            AssertionError,
            # Escaped so the decimal points are matched literally.
            match=re.escape("Sequence Length differs: self=1.0 other=42.0"),
        ):
            t1.assert_equals(t2)

    def test_metadata_schema(self, t1, t2):
        """A top-level schema difference can be ignored by either metadata flag."""
        t2.metadata_schema = tskit.MetadataSchema(None)
        with pytest.raises(
            AssertionError,
            match=re.escape(
                "Metadata schemas differ: self=OrderedDict([('codec', 'json')]) "
                "other=None"
            ),
        ):
            t1.assert_equals(t2)
        t1.assert_equals(t2, ignore_metadata=True)
        t1.assert_equals(t2, ignore_ts_metadata=True)

    def test_metadata(self, t1, t2):
        """A top-level metadata difference can be ignored by either metadata flag."""
        t2.metadata = {"foo": "bar"}
        with pytest.raises(
            AssertionError,
            match=re.escape(
                "Metadata differs: self=Test metadata other={'foo': 'bar'}"
            ),
        ):
            t1.assert_equals(t2)
        t1.assert_equals(t2, ignore_metadata=True)
        t1.assert_equals(t2, ignore_ts_metadata=True)

    @pytest.mark.parametrize("table_name", tskit.TableCollection(1).name_map)
    def test_tables(self, t1, t2, table_name):
        """Emptying any one table is reported as a row-count difference."""
        table = getattr(t2, table_name)
        table.truncate(0)
        with pytest.raises(
            AssertionError,
            match=f"{type(table).__name__} number of rows differ: "
            f"self={len(getattr(t1, table_name))} other=0",
        ):
            t1.assert_equals(t2)

    @pytest.mark.parametrize("table_name", tskit.TableCollection(1).name_map)
    def test_ignore_metadata(self, t1, t2, table_name):
        """A per-table schema difference is skipped with ``ignore_metadata``."""
        table = getattr(t2, table_name)
        # Tables without a metadata schema have nothing to check here.
        if not hasattr(table, "metadata_schema"):
            return
        table.metadata_schema = tskit.MetadataSchema(None)
        with pytest.raises(
            AssertionError,
            match=re.escape(
                f"{type(table).__name__} metadata schemas differ: "
                "self=OrderedDict([('codec', 'json')]) other=None"
            ),
        ):
            t1.assert_equals(t2)
        t1.assert_equals(t2, ignore_metadata=True)

    def test_ignore_provenance(self, t1, t2):
        """Only ``ignore_provenance`` hides a provenance row-count difference."""
        t2.provenances.truncate(0)
        expected = "ProvenanceTable number of rows differ: self=1 other=0"
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2)
        # Ignoring timestamps alone must not hide a missing provenance row.
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2, ignore_timestamps=True)

        t1.assert_equals(t2, ignore_provenance=True)

    def test_ignore_timestamps(self, t1, t2):
        """A changed timestamp is hidden by either provenance-related flag."""
        table = t2.provenances
        timestamp = table.timestamp
        # Overwrite the first character of the first timestamp.
        timestamp[0] = ord("F")
        table.set_columns(
            timestamp=timestamp,
            timestamp_offset=table.timestamp_offset,
            record=table.record,
            record_offset=table.record_offset,
        )
        with pytest.raises(
            AssertionError,
            # The original timestamp value is not known here, hence the
            # wildcard; the literal dots are escaped.
            match=r"ProvenanceTable row 0 differs:\n"
            r"self\.timestamp=.* other\.timestamp=F",
        ):
            t1.assert_equals(t2)
        t1.assert_equals(t2, ignore_provenance=True)
        t1.assert_equals(t2, ignore_timestamps=True)
3576+
3577+
33103578
class TestTableCollectionMethodSignatures:
33113579
tc = msprime.simulate(10, random_seed=1234).dump_tables()
33123580

python/tests/test_tree_stats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ def test_wright_fisher_simplified(self):
860860
)
861861
tables.sort()
862862
ts = tables.tree_sequence().simplify()
863-
ts = msprime.mutate(ts, rate=0.01, random_seed=42)
863+
ts = tsutil.jukes_cantor(ts, 10, 0.01, seed=1)
864864
assert ts.num_sites > 0
865865
self.verify(ts)
866866

0 commit comments

Comments
 (0)