Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit c905430

Browse files
authored
Merge pull request #315 from datafold/adjust_pr314
Adjustments for PR #314
2 parents 779892d + d304e1a commit c905430

File tree

2 files changed

+54
-7
lines changed

2 files changed

+54
-7
lines changed

data_diff/hashdiff_tables.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,18 @@
2525

2626

2727
def diff_sets(a: set, b: set) -> Iterator:
28-
s1 = set(a)
29-
s2 = set(b)
30-
d = defaultdict(list)
28+
sa = set(a)
29+
sb = set(b)
3130

3231
# The first item is always the key (see TableDiffer.relevant_columns)
33-
for i in s1 - s2:
34-
d[i[0]].append(("-", i))
35-
for i in s2 - s1:
36-
d[i[0]].append(("+", i))
32+
# TODO update when we add compound keys to hashdiff
33+
d = defaultdict(list)
34+
for row in a:
35+
if row not in sb:
36+
d[row[0]].append(("-", row))
37+
for row in b:
38+
if row not in sa:
39+
d[row[0]].append(("+", row))
3740

3841
for _k, v in sorted(d.items(), key=lambda i: i[0]):
3942
yield from v

tests/test_diff_tables.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,3 +745,47 @@ def test_info_tree_root(self):
745745
assert info_tree.info.is_diff
746746
assert info_tree.info.diff_count == 1000
747747
self.assertEqual(info_tree.info.rowcounts, {1: 1000, 2: 2000})
748+
749+
750+
class TestDuplicateTables(DiffTestCase):
751+
db_cls = db.MySQL
752+
753+
src_schema = {"id": int, "data": str}
754+
dst_schema = {"id": int, "data": str}
755+
756+
def setUp(self):
757+
"""
758+
table 1:
759+
(12, 'ABCDE'),
760+
(12, 'ABCDE');
761+
table 2:
762+
(4,'ABCDEF'),
763+
(4,'ABCDE'),
764+
(4,'ABCDE'),
765+
(6,'ABCDE'),
766+
(6,'ABCDE'),
767+
(6,'ABCDE');
768+
"""
769+
770+
super().setUp()
771+
772+
src_values = [(12, "ABCDE"), (12, "ABCDE")]
773+
dst_values = [(4, "ABCDEF"), (4, "ABCDE"), (4, "ABCDE"), (6, "ABCDE"), (6, "ABCDE"), (6, "ABCDE")]
774+
775+
self.diffs = [("-", (str(r[0]), r[1])) for r in src_values] + [("+", (str(r[0]), r[1])) for r in dst_values]
776+
777+
self.connection.query([self.src_table.insert_rows(src_values), self.dst_table.insert_rows(dst_values), commit])
778+
779+
self.a = _table_segment(
780+
self.connection, self.table_src_path, "id", extra_columns=("data",), case_sensitive=False
781+
)
782+
self.b = _table_segment(
783+
self.connection, self.table_dst_path, "id", extra_columns=("data",), case_sensitive=False
784+
)
785+
786+
def test_duplicates(self):
787+
"""If there are duplicates in data, we want to return them as well"""
788+
789+
differ = HashDiffer(bisection_factor=2, bisection_threshold=4)
790+
diff = list(differ.diff_tables(self.a, self.b))
791+
self.assertEqual(diff, self.diffs)

0 commit comments

Comments
 (0)