7 changes: 6 additions & 1 deletion data_diff/databases/base.py
@@ -200,9 +200,14 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType]):

fields = [self.normalize_uuid(c, ColType_UUID()) for c in text_columns]
samples_by_row = self.query(Select(fields, TableName(table_path), limit=16), list)
if not samples_by_row:
logger.warning(f"Table {table_path} is empty.")
return

samples_by_col = list(zip(*samples_by_row))

for col_name, samples in safezip(text_columns, samples_by_col):
uuid_samples = list(filter(is_uuid, samples))
uuid_samples = [s for s in samples if s and is_uuid(s)]

if uuid_samples:
if len(uuid_samples) != len(samples):
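The base.py change above does two things: it returns early with a warning when the sampled table is empty, and it skips NULL and empty samples before running UUID detection. Below is a minimal, standalone sketch of the second point; the is_uuid helper here is a hypothetical stand-in, not data_diff's actual implementation.

import uuid

def is_uuid(value) -> bool:
    # Hypothetical stand-in for data_diff's UUID check; the real helper may behave differently.
    try:
        uuid.UUID(value)
        return True
    except (TypeError, ValueError, AttributeError):
        return False

samples = ["110e8400-e29b-11d4-a716-446655440000", None, ""]

# Old form: every sample, including None, is handed to is_uuid.
old = list(filter(is_uuid, samples))

# New form: falsy samples (NULLs from the database, empty strings) are skipped,
# so is_uuid never sees them at all.
new = [s for s in samples if s and is_uuid(s)]

print(old, new)  # both keep only the valid UUID string in this sketch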
10 changes: 4 additions & 6 deletions data_diff/databases/database_types.py
@@ -53,8 +53,11 @@ class FractionalType(NumericType):
class Float(FractionalType):
pass

class IKey(ABC):
"Interface for ColType, for using a column as a key in data-diff"
python_type: type

class Decimal(FractionalType):
class Decimal(FractionalType, IKey): # Snowflake may use Decimal as a key
@property
def python_type(self) -> type:
if self.precision == 0:
@@ -66,11 +69,6 @@ class StringType(ColType):
pass


class IKey(ABC):
"Interface for ColType, for using a column as a key in data-diff"
python_type: type


class ColType_UUID(StringType, IKey):
python_type = ArithUUID

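Moving IKey above Decimal lets Decimal implement the key interface directly (per the inline comment, Snowflake may expose key columns as decimals). The sketch below shows roughly how the pieces fit together; the class hierarchy is reduced to stubs, and the non-zero-precision branch of python_type is an assumption, since the diff truncates it.

from abc import ABC
from dataclasses import dataclass

# Stubbed-down hierarchy; the real data_diff classes carry more fields and methods.
class ColType:
    pass

class NumericType(ColType):
    pass

class FractionalType(NumericType):
    pass

class IKey(ABC):
    "Interface for ColType, for using a column as a key in data-diff"
    python_type: type

@dataclass
class Decimal(FractionalType, IKey):  # Snowflake may use Decimal as a key
    precision: int

    @property
    def python_type(self) -> type:
        # A zero-precision decimal holds whole numbers, so it maps to int.
        # The float fallback is an assumption; the diff hides that branch.
        return int if self.precision == 0 else float

key_type = Decimal(precision=0)
assert isinstance(key_type, IKey)
assert key_type.python_type is int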
7 changes: 7 additions & 0 deletions data_diff/diff_tables.py
@@ -18,6 +18,7 @@
from .databases.base import Database
from .databases.database_types import (
ArithUUID,
IKey,
NumericType,
PrecisionType,
StringType,
@@ -212,6 +213,8 @@ def count_and_checksum(self) -> Tuple[int, int]:
"We recommend increasing --bisection-factor or decreasing --threads."
)

if count:
assert checksum, (count, checksum)
return count or 0, checksum if checksum is None else int(checksum)

def query_key_range(self) -> Tuple[int, int]:
@@ -306,6 +309,10 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:

key_type = table1._schema[table1.key_column]
key_type2 = table2._schema[table2.key_column]
if not isinstance(key_type, IKey):
raise NotImplementedError(f"Cannot use column of type {key_type} as a key")
if not isinstance(key_type2, IKey):
raise NotImplementedError(f"Cannot use column of type {key_type2} as a key")
assert key_type.python_type is key_type2.python_type

# We add 1 because our ranges are exclusive of the end (like in Python)
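The new guard in diff_tables turns an unsupported key column type into an explicit NotImplementedError instead of a more obscure failure later in the run. A minimal sketch of the check in isolation, using hypothetical stand-in types rather than data_diff's real ones:

class IKey:
    "Marker for column types usable as a data-diff key (simplified stand-in)."
    python_type: type

class ColType_UUID(IKey):
    python_type = str  # the real class uses ArithUUID

class Text:
    "A hypothetical column type that does not implement IKey."

def check_key(key_type) -> None:
    # Mirrors the shape of the guard added in TableDiffer.diff_tables.
    if not isinstance(key_type, IKey):
        raise NotImplementedError(f"Cannot use column of type {key_type} as a key")

check_key(ColType_UUID())  # passes silently
try:
    check_key(Text())
except NotImplementedError as exc:
    print(exc)  # Cannot use column of type <...Text object...> as a key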
3 changes: 2 additions & 1 deletion data_diff/sql.py
@@ -121,9 +121,10 @@ class Checksum(Sql):

def compile(self, c: Compiler):
if len(self.exprs) > 1:
compiled_exprs = ", ".join(map(c.compile, self.exprs))
compiled_exprs = ", ".join(f"coalesce({c.compile(expr)}, '<null>')" for expr in self.exprs)
expr = f"concat({compiled_exprs})"
else:
# No need to coalesce - safe to assume the key cannot be null
(expr,) = self.exprs
expr = c.compile(expr)
md5 = c.database.md5_to_int(expr)
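The effect of the Checksum.compile change is easiest to see in the SQL it emits: every non-key expression is wrapped in coalesce(..., '<null>') before concatenation, so a single NULL column can no longer collapse the whole row into NULL. Here is a simplified sketch of that compile step; the real Compiler and Sql classes are richer, and compile_checksum below is an illustrative stand-in.

from typing import List

def compile_checksum(exprs: List[str]) -> str:
    # Simplified stand-in for Checksum.compile; exprs are already-compiled SQL snippets.
    if len(exprs) > 1:
        compiled_exprs = ", ".join(f"coalesce({e}, '<null>')" for e in exprs)
        return f"concat({compiled_exprs})"
    # A single expression is the key, which is assumed to be non-null - no coalesce needed.
    (expr,) = exprs
    return expr

print(compile_checksum(["id", "c1", "c2"]))
# concat(coalesce(id, '<null>'), coalesce(c1, '<null>'), coalesce(c2, '<null>'))
print(compile_checksum(["id"]))
# id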
186 changes: 186 additions & 0 deletions tests/test_diff_tables.py
@@ -321,3 +321,189 @@ def test_table_segment(self):
self.assertRaises(ValueError, self.table.replace, min_update=late, max_update=early)

self.assertRaises(ValueError, self.table.replace, min_key=10, max_key=0)


class TestTableUUID(TestWithConnection):
def setUp(self):
super().setUp()

queries = [
f"DROP TABLE IF EXISTS {self.table_src}",
f"DROP TABLE IF EXISTS {self.table_dst}",
f"CREATE TABLE {self.table_src}(id varchar(100), comment varchar(1000))",
]
for i in range(10):
uuid_value = uuid.uuid1(i)
queries.append(f"INSERT INTO {self.table_src} VALUES ('{uuid_value}', '{uuid_value}')")

self.null_uuid = uuid.uuid1(32132131)
queries += [
f"CREATE TABLE {self.table_dst} AS SELECT * FROM {self.table_src}",

f"INSERT INTO {self.table_src} VALUES ('{self.null_uuid}', NULL)",

"COMMIT"
]

for query in queries:
self.connection.query(query, None)

self.a = TableSegment(self.connection, (self.table_src,), "id", "comment")
self.b = TableSegment(self.connection, (self.table_dst,), "id", "comment")

def test_uuid_column_with_nulls(self):
differ = TableDiffer()
diff = list(differ.diff_tables(self.a, self.b))
self.assertEqual(diff, [("-", (str(self.null_uuid), None))])


class TestTableNullRowChecksum(TestWithConnection):
def setUp(self):
super().setUp()

self.null_uuid = uuid.uuid1(1)
queries = [
f"DROP TABLE IF EXISTS {self.table_src}",
f"DROP TABLE IF EXISTS {self.table_dst}",
f"CREATE TABLE {self.table_src}(id varchar(100), comment varchar(1000))",

f"INSERT INTO {self.table_src} VALUES ('{uuid.uuid1(1)}', '1')",

f"CREATE TABLE {self.table_dst} AS SELECT * FROM {self.table_src}",

# Add a row where a column has a NULL value
f"INSERT INTO {self.table_src} VALUES ('{self.null_uuid}', NULL)",

"COMMIT"
]

for query in queries:
self.connection.query(query, None)

self.a = TableSegment(self.connection, (self.table_src,), "id", "comment")
self.b = TableSegment(self.connection, (self.table_dst,), "id", "comment")

def test_uuid_columns_with_nulls(self):
"""
Here we test the case where, in one segment, one or more columns contain only NULL values. For example,
Table A:
| id   | value     |
|------|-----------|
| pk_1 | 'value_1' |
| pk_2 | NULL      |

Table B:
| id   | value     |
|------|-----------|
| pk_1 | 'value_1' |

We can choose a bisection factor and bisection threshold (2 and 3 for this example, respectively) such that
one segment consists only of ('pk_2', NULL). When these values are cast to strings and concatenated, some
databases (for example, MySQL) return NULL. As a result, all subsequent operations such as substring, sum,
etc. also return NULL, which leads to an incorrect diff: ('pk_2', NULL) should be in the diff results, but
it is not. This test helps detect such cases.
"""

differ = TableDiffer(bisection_factor=2, bisection_threshold=3)
diff = list(differ.diff_tables(self.a, self.b))
self.assertEqual(diff, [("-", (str(self.null_uuid), None))])


class TestConcatMultipleColumnWithNulls(TestWithConnection):
def setUp(self):
super().setUp()

queries = [
f"DROP TABLE IF EXISTS {self.table_src}",
f"DROP TABLE IF EXISTS {self.table_dst}",
f"CREATE TABLE {self.table_src}(id varchar(100), c1 varchar(100), c2 varchar(100))",
f"CREATE TABLE {self.table_dst}(id varchar(100), c1 varchar(100), c2 varchar(100))",
]

self.diffs = []
for i in range(0, 8):
pk = uuid.uuid1(i)
table_src_c1_val = str(i)
table_dst_c1_val = str(i) + "-different"

queries.append(f"INSERT INTO {self.table_src} VALUES ('{pk}', '{table_src_c1_val}', NULL)")
queries.append(f"INSERT INTO {self.table_dst} VALUES ('{pk}', '{table_dst_c1_val}', NULL)")

self.diffs.append(("-", (str(pk), table_src_c1_val, None)))
self.diffs.append(("+", (str(pk), table_dst_c1_val, None)))

queries.append("COMMIT")

for query in queries:
self.connection.query(query, None)

self.a = TableSegment(self.connection, (self.table_src,), "id", extra_columns=("c1", "c2"))
self.b = TableSegment(self.connection, (self.table_dst,), "id", extra_columns=("c1", "c2"))

def test_tables_are_different(self):
"""
Here we test the case where, in one segment, one or more columns contain only NULL values. For example,
Table A:
| id   | c1 | c2   |
|------|----|------|
| pk_1 | 1  | NULL |
| pk_2 | 2  | NULL |
...
| pk_n | n  | NULL |

Table B:
| id   | c1     | c2   |
|------|--------|------|
| pk_1 | 1-diff | NULL |
| pk_2 | 2-diff | NULL |
...
| pk_n | n-diff | NULL |

To calculate a checksum, we concatenate the string values of each row. If both tables have a column with NULL
values, this can lead to concat(pk_i, i, NULL) == concat(pk_i, i-diff, NULL), so different rows produce the
same checksum input. This test handles such cases.
"""

differ = TableDiffer(bisection_factor=2, bisection_threshold=4)
diff = list(differ.diff_tables(self.a, self.b))
self.assertEqual(diff, self.diffs)


class TestTableTableEmpty(TestWithConnection):
def setUp(self):
super().setUp()

self.null_uuid = uuid.uuid1(1)
queries = [
f"DROP TABLE IF EXISTS {self.table_src}",
f"DROP TABLE IF EXISTS {self.table_dst}",
f"CREATE TABLE {self.table_src}(id varchar(100), comment varchar(1000))",
f"CREATE TABLE {self.table_dst}(id varchar(100), comment varchar(1000))",
]

self.diffs = [(uuid.uuid1(i), i) for i in range(100)]
for pk, value in self.diffs:
queries.append(f"INSERT INTO {self.table_src} VALUES ('{pk}', '{value}')")

queries.append("COMMIT")

for query in queries:
self.connection.query(query, None)

self.a = TableSegment(self.connection, (self.table_src,), "id", "comment")
self.b = TableSegment(self.connection, (self.table_dst,), "id", "comment")

def test_right_table_empty(self):
differ = TableDiffer()
self.assertRaises(ValueError, differ.diff_tables, self.a, self.b)

def test_left_table_empty(self):
queries = [
f"INSERT INTO {self.table_dst} SELECT id, comment FROM {self.table_src}",
f"TRUNCATE {self.table_src}",
"COMMIT"
]
for query in queries:
self.connection.query(query, None)

differ = TableDiffer()
self.assertRaises(ValueError, differ.diff_tables, self.a, self.b)
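
The new tests above all target one failure mode: in databases where concat with a NULL argument yields NULL (MySQL, for instance), two different rows that each carry a NULL column can end up with the same NULL checksum input, hiding a real difference. The standalone sketch below simulates that effect in Python with an md5-based checksum and a MySQL-like concat; it is an illustration, not part of the change set.

import hashlib
from typing import Optional

def mysql_like_concat(*parts: Optional[str]) -> Optional[str]:
    # MySQL's CONCAT returns NULL if any argument is NULL.
    return None if any(p is None for p in parts) else "".join(parts)

def checksum(value: Optional[str]) -> Optional[int]:
    return None if value is None else int(hashlib.md5(value.encode()).hexdigest(), 16)

row_src = ("pk_1", "1", None)            # row in table_src
row_dst = ("pk_1", "1-different", None)  # same key in table_dst, different c1

# Without coalescing, both rows concatenate to NULL, the checksums "agree",
# and the difference in c1 goes undetected.
print(checksum(mysql_like_concat(*row_src)) == checksum(mysql_like_concat(*row_dst)))  # True

# With the '<null>' sentinel from the sql.py change, the rows stay distinct.
def coalesce(value: Optional[str]) -> str:
    return value if value is not None else "<null>"

print(
    checksum(mysql_like_concat(*map(coalesce, row_src)))
    == checksum(mysql_like_concat(*map(coalesce, row_dst)))
)  # False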