|
32 | 32 | import pickle
|
33 | 33 | import platform
|
34 | 34 | import random
|
| 35 | +import re |
35 | 36 | import struct
|
36 | 37 | import time
|
37 | 38 | import unittest
|
@@ -73,7 +74,7 @@ def get_input(self, n):
|
class CharColumn(Column):
    """Column whose generated input is an array of int8 character codes."""

    def get_input(self, n):
        """
        Return ``n`` int8 values drawn from the half-open range [65, 122).

        The generator is seeded with 42 on every call, so the same ``n``
        always produces the same array.
        """
        generator = np.random.RandomState(42)
        codes = generator.randint(low=65, high=122, size=n, dtype=np.int8)
        return codes
77 | 78 |
|
78 | 79 |
|
79 | 80 | class DoubleColumn(Column):
|
@@ -1009,7 +1010,158 @@ def test_set_with_optional_properties(self, codec):
|
1009 | 1010 | assert md == row.metadata
|
1010 | 1011 |
|
1011 | 1012 |
|
1012 |
| -class TestIndividualTable(CommonTestsMixin, MetadataTestsMixin): |
class AssertEqualsMixin:
    """
    Shared ``assert_equals`` tests for the per-table test classes.

    The host class must supply ``table_class`` and
    ``make_transposed_input_data``; both are used here but defined elsewhere.
    """

    @pytest.fixture
    def test_rows(self):
        """
        Return ten rows of input data, each a mapping of column name to value.

        Annoyingly we have to tweak some types: once a value has been added to
        a row and then put in an error message it comes out differently, so
        the character columns are decoded to ``str`` up front. The tests below
        interpolate these values into the expected messages.
        """
        test_rows = self.make_transposed_input_data(10)
        text_columns = ("timestamp", "record", "ancestral_state", "derived_state")
        for row in test_rows:
            for col in text_columns:
                if col in row:
                    row[col] = bytes(row[col]).decode("ascii")
        return test_rows

    @pytest.fixture
    def table1(self, test_rows):
        """Return a table holding the first five test rows."""
        return self._build_table(test_rows[:5])

    def _build_table(self, rows):
        """Return a new ``table_class`` instance with ``rows`` added in order."""
        table = self.table_class()
        for row in rows:
            table.add_row(**row)
        return table

    def _table_with_row_4_changed(self, test_rows, *column_names):
        """
        Return a table built from the first five test rows, except that in
        row 4 each named column takes its value from ``test_rows[5]``.
        """
        changed_row = dict(test_rows[4])
        for name in column_names:
            changed_row[name] = test_rows[5][name]
        return self._build_table([*test_rows[:4], changed_row])

    @staticmethod
    def _column_diff(test_rows, column_name):
        """
        Return the per-column fragment expected in the error message when
        row 4 of ``other`` carries the ``test_rows[5]`` value for the column.
        """
        return (
            f"self.{column_name}={test_rows[4][column_name]} "
            f"other.{column_name}={test_rows[5][column_name]}"
        )

    def test_equal(self, table1, test_rows):
        table2 = self._build_table(test_rows[:5])
        table1.assert_equals(table2)

    def test_type(self, table1):
        # The message is literal text, so escape it: the class path contains
        # "." which would otherwise match any character.
        with pytest.raises(
            AssertionError,
            match=re.escape(
                f"Types differ: self={type(table1)} other=<class 'int'>"
            ),
        ):
            table1.assert_equals(42)

    def test_metadata_schema(self, table1):
        # Tables without a metadata schema have nothing to check here.
        if not hasattr(table1, "metadata_schema"):
            return
        table2 = table1.copy()
        table2.metadata_schema = tskit.MetadataSchema({"codec": "json"})
        # Escaped because "([(" and ")])" are regex syntax; unescaped they
        # form a group around a character class rather than literal text.
        with pytest.raises(
            AssertionError,
            match=re.escape(
                f"{type(table1).__name__} metadata schemas differ: self=None "
                "other=OrderedDict([('codec', 'json')])"
            ),
        ):
            table1.assert_equals(table2)
        table1.assert_equals(table2, ignore_metadata=True)

    def test_row_changes(self, table1, test_rows):
        column_names = list(test_rows[0])

        # One column differs in row 4.
        for column_name in column_names:
            table2 = self._table_with_row_4_changed(test_rows, column_name)
            with pytest.raises(
                AssertionError,
                match=re.escape(
                    f"{type(table1).__name__} row 4 differs:\n"
                    + self._column_diff(test_rows, column_name)
                ),
            ):
                table1.assert_equals(table2)

        # Two columns differ, as we don't know the order in the error message
        # test for both independently
        for column_name, column_name2 in zip(column_names[:-1], column_names[1:]):
            table2 = self._table_with_row_4_changed(
                test_rows, column_name, column_name2
            )
            for name in (column_name, column_name2):
                with pytest.raises(
                    AssertionError,
                    match=re.escape(self._column_diff(test_rows, name)),
                ):
                    table1.assert_equals(table2)

    def test_num_rows(self, table1, test_rows):
        table2 = self._build_table(test_rows[:4])
        with pytest.raises(
            AssertionError,
            match=re.escape(
                f"{type(table1).__name__} number of rows differ: self=5 other=4"
            ),
        ):
            table1.assert_equals(table2)

    def test_metadata(self, table1, test_rows):
        # Only applies to tables whose rows carry a metadata column.
        if "metadata" not in test_rows[0]:
            return
        table2 = self._table_with_row_4_changed(test_rows, "metadata")
        with pytest.raises(
            AssertionError,
            match=re.escape(
                f"{type(table1).__name__} row 4 differs:\n"
                + self._column_diff(test_rows, "metadata")
            ),
        ):
            table1.assert_equals(table2)
        table1.assert_equals(table2, ignore_metadata=True)

    def test_timestamp(self, table1, test_rows):
        # Only applies to tables whose rows carry a timestamp column.
        if "timestamp" not in test_rows[0]:
            return
        table2 = self._table_with_row_4_changed(test_rows, "timestamp")
        with pytest.raises(
            AssertionError,
            match=re.escape(
                f"{type(table1).__name__} row 4 differs:\n"
                + self._column_diff(test_rows, "timestamp")
            ),
        ):
            table1.assert_equals(table2)
        table1.assert_equals(table2, ignore_timestamps=True)
| 1162 | + |
| 1163 | + |
| 1164 | +class TestIndividualTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): |
1013 | 1165 | columns = [UInt32Column("flags")]
|
1014 | 1166 | ragged_list_columns = [
|
1015 | 1167 | (DoubleColumn("location"), UInt32Column("location_offset")),
|
@@ -1138,7 +1290,7 @@ def test_various_not_equals(self):
|
1138 | 1290 | assert a == b
|
1139 | 1291 |
|
1140 | 1292 |
|
1141 |
| -class TestNodeTable(CommonTestsMixin, MetadataTestsMixin): |
| 1293 | +class TestNodeTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): |
1142 | 1294 |
|
1143 | 1295 | columns = [
|
1144 | 1296 | UInt32Column("flags"),
|
@@ -1219,7 +1371,7 @@ def test_add_row_bad_data(self):
|
1219 | 1371 | t.add_row(metadata=123)
|
1220 | 1372 |
|
1221 | 1373 |
|
1222 |
| -class TestEdgeTable(CommonTestsMixin, MetadataTestsMixin): |
| 1374 | +class TestEdgeTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): |
1223 | 1375 |
|
1224 | 1376 | columns = [
|
1225 | 1377 | DoubleColumn("left"),
|
@@ -1267,7 +1419,7 @@ def test_add_row_bad_data(self):
|
1267 | 1419 | t.add_row(0, 0, 0, 0, metadata=123)
|
1268 | 1420 |
|
1269 | 1421 |
|
1270 |
| -class TestSiteTable(CommonTestsMixin, MetadataTestsMixin): |
| 1422 | +class TestSiteTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): |
1271 | 1423 | columns = [DoubleColumn("position")]
|
1272 | 1424 | ragged_list_columns = [
|
1273 | 1425 | (CharColumn("ancestral_state"), UInt32Column("ancestral_state_offset")),
|
@@ -1322,7 +1474,7 @@ def test_packset_ancestral_state(self):
|
1322 | 1474 | assert np.array_equal(table.ancestral_state_offset, ancestral_state_offset)
|
1323 | 1475 |
|
1324 | 1476 |
|
1325 |
| -class TestMutationTable(CommonTestsMixin, MetadataTestsMixin): |
| 1477 | +class TestMutationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): |
1326 | 1478 | columns = [
|
1327 | 1479 | Int32Column("site"),
|
1328 | 1480 | Int32Column("node"),
|
@@ -1394,7 +1546,7 @@ def test_packset_derived_state(self):
|
1394 | 1546 | assert np.array_equal(table.derived_state_offset, derived_state_offset)
|
1395 | 1547 |
|
1396 | 1548 |
|
1397 |
| -class TestMigrationTable(CommonTestsMixin, MetadataTestsMixin): |
| 1549 | +class TestMigrationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): |
1398 | 1550 | columns = [
|
1399 | 1551 | DoubleColumn("left"),
|
1400 | 1552 | DoubleColumn("right"),
|
@@ -1445,7 +1597,7 @@ def test_add_row_bad_data(self):
|
1445 | 1597 | t.add_row(0, 0, 0, 0, 0, 0, metadata=123)
|
1446 | 1598 |
|
1447 | 1599 |
|
1448 |
| -class TestProvenanceTable(CommonTestsMixin): |
| 1600 | +class TestProvenanceTable(CommonTestsMixin, AssertEqualsMixin): |
1449 | 1601 | columns = []
|
1450 | 1602 | ragged_list_columns = [
|
1451 | 1603 | (CharColumn("timestamp"), UInt32Column("timestamp_offset")),
|
@@ -1496,7 +1648,7 @@ def test_packset_record(self):
|
1496 | 1648 | assert t[1].record == "BBBB"
|
1497 | 1649 |
|
1498 | 1650 |
|
1499 |
| -class TestPopulationTable(CommonTestsMixin, MetadataTestsMixin): |
| 1651 | +class TestPopulationTable(CommonTestsMixin, MetadataTestsMixin, AssertEqualsMixin): |
1500 | 1652 | metadata_mandatory = True
|
1501 | 1653 | columns = []
|
1502 | 1654 | ragged_list_columns = [(CharColumn("metadata"), UInt32Column("metadata_offset"))]
|
@@ -3307,6 +3459,122 @@ def test_equals_population_metadata(self, ts_fixture):
|
3307 | 3459 | assert t1.equals(t2, ignore_metadata=True)
|
3308 | 3460 |
|
3309 | 3461 |
|
class TestTableCollectionAssertEquals:
    """
    Tests of ``assert_equals`` on table collections dumped from ``ts_fixture``:
    the error raised for each kind of difference, and the ``ignore_*`` options
    that suppress it.
    """

    @pytest.fixture
    def t1(self, ts_fixture):
        """Table collection dumped from ``ts_fixture``; left unmodified."""
        tables = ts_fixture.dump_tables()
        return tables

    @pytest.fixture
    def t2(self, ts_fixture):
        """Second dump of ``ts_fixture``; the tests modify this one."""
        tables = ts_fixture.dump_tables()
        return tables

    def test_equal(self, t1, t2):
        assert t1 is not t2
        t1.assert_equals(t2)

    def test_type(self, t1):
        expected = re.escape(
            "Types differ: self=<class 'tskit.tables.TableCollection'> "
            "other=<class 'int'>"
        )
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(42)

    def test_sequence_length(self, t1, t2):
        t2.sequence_length = 42
        expected = "Sequence Length differs: self=1.0 other=42.0"
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2)

    def test_metadata_schema(self, t1, t2):
        t2.metadata_schema = tskit.MetadataSchema(None)
        expected = re.escape(
            "Metadata schemas differ: self=OrderedDict([('codec', 'json')]) "
            "other=None"
        )
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2)
        # Either flag is enough to skip the top-level schema comparison.
        t1.assert_equals(t2, ignore_metadata=True)
        t1.assert_equals(t2, ignore_ts_metadata=True)

    def test_metadata(self, t1, t2):
        t2.metadata = {"foo": "bar"}
        expected = re.escape(
            "Metadata differs: self=Test metadata other={'foo': 'bar'}"
        )
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2)
        # Either flag is enough to skip the top-level metadata comparison.
        t1.assert_equals(t2, ignore_metadata=True)
        t1.assert_equals(t2, ignore_ts_metadata=True)

    @pytest.mark.parametrize("table_name", tskit.TableCollection(1).name_map)
    def test_tables(self, t1, t2, table_name):
        # Empty one table in t2 and expect the row-count mismatch to be reported.
        emptied = getattr(t2, table_name)
        emptied.truncate(0)
        original_rows = len(getattr(t1, table_name))
        expected = (
            f"{type(emptied).__name__} number of rows differ: "
            f"self={original_rows} other=0"
        )
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2)

    @pytest.mark.parametrize("table_name", tskit.TableCollection(1).name_map)
    def test_ignore_metadata(self, t1, t2, table_name):
        table = getattr(t2, table_name)
        # Tables without a metadata schema have nothing to check here.
        if not hasattr(table, "metadata_schema"):
            return
        table.metadata_schema = tskit.MetadataSchema(None)
        expected = re.escape(
            f"{type(table).__name__} metadata schemas differ: "
            f"self=OrderedDict([('codec', 'json')]) other=None"
        )
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2)
        t1.assert_equals(t2, ignore_metadata=True)

    def test_ignore_provenance(self, t1, t2):
        t2.provenances.truncate(0)
        expected = "ProvenanceTable number of rows differ: self=1 other=0"
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2)
        # Ignoring timestamps alone does not hide a missing provenance row.
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2, ignore_timestamps=True)

        t1.assert_equals(t2, ignore_provenance=True)

    def test_ignore_timestamps(self, t1, t2):
        # Overwrite the first timestamp character in t2's provenance table,
        # leaving the offsets and the record column as they were.
        provenances = t2.provenances
        altered = provenances.timestamp
        altered[0] = ord("F")
        provenances.set_columns(
            timestamp=altered,
            timestamp_offset=provenances.timestamp_offset,
            record=provenances.record,
            record_offset=provenances.record_offset,
        )
        expected = (
            "ProvenanceTable row 0 differs:\n"
            "self.timestamp=.* other.timestamp=F.*"
        )
        with pytest.raises(AssertionError, match=expected):
            t1.assert_equals(t2)
        t1.assert_equals(t2, ignore_provenance=True)
        t1.assert_equals(t2, ignore_timestamps=True)
| 3577 | + |
3310 | 3578 | class TestTableCollectionMethodSignatures:
|
3311 | 3579 | tc = msprime.simulate(10, random_seed=1234).dump_tables()
|
3312 | 3580 |
|
|
0 commit comments