Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 5449a04

Browse files
authored
Merge pull request #299 from datafold/nov17_tests
Refactor tests to use insert_rows_in_batches(), instead of internally…
2 parents 22022dd + 753ed4a commit 5449a04

File tree

9 files changed

+108 additions
-123 deletions

data_diff/diff_tables.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,8 @@ def _bisect_and_diff_tables(self, table1, table2, info_tree):
185185
raise NotImplementedError(f"Cannot use column of type {key_type} as a key")
186186
if not isinstance(key_type2, IKey):
187187
raise NotImplementedError(f"Cannot use column of type {key_type2} as a key")
188-
assert key_type.python_type is key_type2.python_type
188+
if key_type.python_type is not key_type2.python_type:
189+
raise TypeError(f"Incompatible key types: {key_type} and {key_type2}")
189190

190191
# Query min/max values
191192
key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2])

data_diff/sqeleton/databases/base.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
import threading
99
from abc import abstractmethod
1010
from uuid import UUID
11+
import decimal
1112

1213
from ..utils import is_uuid, safezip
13-
from ..queries import Expr, Compiler, table, Select, SKIP, Explain
14+
from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code
1415
from .database_types import (
1516
AbstractDatabase,
1617
AbstractDialect,
@@ -133,10 +134,15 @@ def _constant_value(self, v):
133134
elif isinstance(v, str):
134135
return f"'{v}'"
135136
elif isinstance(v, datetime):
136-
# TODO use self.timestamp_value
137-
return f"timestamp '{v}'"
137+
return self.timestamp_value(v)
138138
elif isinstance(v, UUID):
139139
return f"'{v}'"
140+
elif isinstance(v, decimal.Decimal):
141+
return str(v)
142+
elif isinstance(v, bytearray):
143+
return f"'{v.decode()}'"
144+
elif isinstance(v, Code):
145+
return v.code
140146
return repr(v)
141147

142148
def constant_values(self, rows) -> str:
@@ -334,7 +340,7 @@ def _process_table_schema(
334340
# Return a dict of form {name: type} after normalization
335341
return col_dict
336342

337-
def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], where: str = None, sample_size=32):
343+
def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], where: str = None, sample_size=64):
338344
"""Refine the types in the column dict, by querying the database for a sample of their values
339345
340346
'where' restricts the rows to be sampled.

data_diff/sqeleton/databases/clickhouse.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ThreadedDatabase,
99
import_helper,
1010
ConnectError,
11+
DbTime,
1112
)
1213
from .database_types import (
1314
ColType,
@@ -146,6 +147,10 @@ def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
146147

147148
return self.TYPE_CLASSES.get(type_repr)
148149

150+
# def timestamp_value(self, t: DbTime) -> str:
151+
# # return f"'{t}'"
152+
# return f"'{str(t)[:19]}'"
153+
149154

150155
class Clickhouse(ThreadedDatabase):
151156
dialect = Dialect()

data_diff/sqeleton/queries/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .compiler import Compiler
22
from .api import this, join, outerjoin, table, SKIP, sum_, avg, min_, max_, cte, commit, when, coalesce
3-
from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In
3+
from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In, Code
44
from .extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString

data_diff/sqeleton/queries/api.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,13 @@ def coalesce(*exprs):
8686
return Func("COALESCE", exprs)
8787

8888

89+
def insert_rows_in_batches(db, table: TablePath, rows, *, columns=None, batch_size=1024 * 8):
90+
assert batch_size > 0
91+
rows = list(rows)
92+
93+
while rows:
94+
batch, rows = rows[:batch_size], rows[batch_size:]
95+
db.query(table.insert_rows(batch, columns=columns))
96+
97+
8998
commit = Commit()

data_diff/sqeleton/queries/ast_classes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ def cast_to(self, to):
4343

4444
Expr = Union[ExprNode, str, bool, int, datetime, ArithString, None]
4545

46+
@dataclass
47+
class Code(ExprNode):
48+
code: str
49+
50+
def compile(self, c: Compiler) -> str:
51+
return self.code
4652

4753
def _expr_type(e: Expr) -> type:
4854
if isinstance(e, ExprNode):

tests/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from data_diff import tracking
1414
from data_diff import connect
1515
from data_diff.sqeleton.queries.api import table
16+
from data_diff.sqeleton.databases import Database
1617
from data_diff.query_utils import drop_table
1718

1819
tracking.disable_tracking()
@@ -85,7 +86,7 @@ def get_git_revision_short_hash() -> str:
8586
_database_instances = {}
8687

8788

88-
def get_conn(cls: type):
89+
def get_conn(cls: type) -> Database:
8990
if cls not in _database_instances:
9091
_database_instances[cls] = connect(CONN_STRINGS[cls], N_THREADS)
9192
return _database_instances[cls]

tests/test_database_types.py

Lines changed: 66 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from data_diff.query_utils import drop_table
1919
from data_diff.utils import accumulate
2020
from data_diff.sqeleton.utils import number_to_human
21-
from data_diff.sqeleton.queries import table, commit
21+
from data_diff.sqeleton.queries import table, commit, this, Code
22+
from data_diff.sqeleton.queries.api import insert_rows_in_batches
2223
from data_diff.hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD
2324
from data_diff.table_segment import TableSegment
2425
from .common import (
@@ -362,32 +363,25 @@ class PaginatedTable:
362363
# much memory.
363364
RECORDS_PER_BATCH = 1000000
364365

365-
def __init__(self, table, conn):
366-
self.table = table
366+
def __init__(self, table_path, conn):
367+
self.table_path = table_path
367368
self.conn = conn
368369

369370
def __iter__(self):
370-
iter = PaginatedTable(self.table, self.conn)
371-
iter.last_id = 0
372-
iter.values = []
373-
iter.value_index = 0
374-
return iter
375-
376-
def __next__(self) -> str:
377-
if self.value_index == len(self.values): # end of current batch
378-
query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC LIMIT {self.RECORDS_PER_BATCH}"
379-
if isinstance(self.conn, db.Oracle):
380-
query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC OFFSET 0 ROWS FETCH NEXT {self.RECORDS_PER_BATCH} ROWS ONLY"
381-
382-
self.values = self.conn.query(query, list)
383-
if len(self.values) == 0: # we must be done!
384-
raise StopIteration
385-
self.last_id = self.values[-1][0]
386-
self.value_index = 0
387-
388-
this_value = self.values[self.value_index]
389-
self.value_index += 1
390-
return this_value
371+
last_id = 0
372+
while True:
373+
query = (
374+
table(self.table_path)
375+
.select(this.id, this.col)
376+
.where(this.id > last_id)
377+
.order_by(this.id)
378+
.limit(self.RECORDS_PER_BATCH)
379+
)
380+
rows = self.conn.query(query, list)
381+
if not rows:
382+
break
383+
last_id = rows[-1][0]
384+
yield from rows
391385

392386

393387
class DateTimeFaker:
@@ -560,90 +554,42 @@ def expand_params(testcase_func, param_num, param):
560554
return name
561555

562556

563-
def _insert_to_table(conn, table, values, type):
564-
current_n_rows = conn.query(f"SELECT COUNT(*) FROM {table}", int)
557+
def _insert_to_table(conn, table_path, values, type):
558+
tbl = table(table_path)
559+
560+
current_n_rows = conn.query(tbl.count(), int)
565561
if current_n_rows == N_SAMPLES:
566562
assert BENCHMARK, "Table should've been deleted, or we should be in BENCHMARK mode"
567563
return
568564
elif current_n_rows > 0:
569-
conn.query(drop_table(table))
570-
_create_table_with_indexes(conn, table, type)
571-
572-
if BENCHMARK and N_SAMPLES > 10_000:
573-
description = f"{conn.name}: {table}"
574-
values = rich.progress.track(values, total=N_SAMPLES, description=description)
575-
576-
default_insertion_query = f"INSERT INTO {table} (id, col) VALUES "
577-
if isinstance(conn, db.Oracle):
578-
default_insertion_query = f"INSERT INTO {table} (id, col)"
579-
580-
batch_size = 8000
581-
if isinstance(conn, db.BigQuery):
582-
batch_size = 1000
583-
584-
insertion_query = default_insertion_query
585-
selects = []
586-
for j, sample in values:
587-
if re.search(r"(time zone|tz)", type):
588-
sample = sample.replace(tzinfo=timezone.utc)
565+
conn.query(drop_table(table_name))
566+
_create_table_with_indexes(conn, table_path, type)
589567

590-
if isinstance(sample, bytearray):
591-
value = f"'{sample.decode()}'"
568+
# if BENCHMARK and N_SAMPLES > 10_000:
569+
# description = f"{conn.name}: {table}"
570+
# values = rich.progress.track(values, total=N_SAMPLES, description=description)
592571

593-
elif type == "boolean":
594-
value = str(bool(sample))
572+
if type == "boolean":
573+
values = [(i, bool(sample)) for i, sample in values]
574+
elif re.search(r"(time zone|tz)", type):
575+
values = [(i, sample.replace(tzinfo=timezone.utc)) for i, sample in values]
595576

596-
elif isinstance(conn, db.Clickhouse):
597-
if type.startswith("DateTime64"):
598-
value = f"'{sample.replace(tzinfo=None)}'"
577+
if isinstance(conn, db.Clickhouse):
578+
if type.startswith("DateTime64"):
579+
values = [(i, f"{sample.replace(tzinfo=None)}") for i, sample in values]
599580

600-
elif type == "DateTime":
601-
sample = sample.replace(tzinfo=None)
602-
# Clickhouse's DateTime does not allow to store micro/milli/nano seconds
603-
value = f"'{str(sample)[:19]}'"
581+
elif type == "DateTime":
582+
# Clickhouse's DateTime does not allow to store micro/milli/nano seconds
583+
values = [(i, str(sample)[:19]) for i, sample in values]
604584

605-
elif type.startswith("Decimal"):
606-
precision = int(type[8:].rstrip(")").split(",")[1])
607-
value = round(sample, precision)
585+
elif type.startswith("Decimal("):
586+
precision = int(type[8:].rstrip(")").split(",")[1])
587+
values = [(i, round(sample, precision)) for i, sample in values]
588+
elif isinstance(conn, db.BigQuery) and type == "datetime":
589+
values = [(i, Code(f"cast(timestamp '{sample}' as datetime)")) for i, sample in values]
608590

609-
else:
610-
value = f"'{sample}'"
611-
612-
elif isinstance(sample, (float, Decimal, int)):
613-
value = str(sample)
614-
elif isinstance(sample, datetime) and isinstance(conn, (db.Presto, db.Oracle, db.Trino)):
615-
value = f"timestamp '{sample}'"
616-
elif isinstance(sample, datetime) and isinstance(conn, db.BigQuery) and type == "datetime":
617-
value = f"cast(timestamp '{sample}' as datetime)"
618-
619-
else:
620-
value = f"'{sample}'"
621-
622-
if isinstance(conn, db.Oracle):
623-
selects.append(f"SELECT {j}, {value} FROM dual")
624-
else:
625-
insertion_query += f"({j}, {value}),"
626-
627-
# Some databases want small batch sizes...
628-
# Need to also insert on the last row, might not divide cleanly!
629-
if j % batch_size == 0 or j == N_SAMPLES:
630-
if isinstance(conn, db.Oracle):
631-
insertion_query += " UNION ALL ".join(selects)
632-
conn.query(insertion_query, None)
633-
selects = []
634-
insertion_query = default_insertion_query
635-
else:
636-
conn.query(insertion_query[0:-1], None)
637-
insertion_query = default_insertion_query
638-
639-
if insertion_query != default_insertion_query:
640-
# Very bad, but this whole function needs to go
641-
if isinstance(conn, db.Oracle):
642-
insertion_query += " UNION ALL ".join(selects)
643-
conn.query(insertion_query, None)
644-
else:
645-
conn.query(insertion_query[0:-1], None)
646591

592+
insert_rows_in_batches(conn, tbl, values, columns=["id", "col"])
647593
conn.query(commit)
648594

649595

@@ -676,17 +622,27 @@ def _create_indexes(conn, table):
676622
raise (err)
677623

678624

679-
def _create_table_with_indexes(conn, table, type):
625+
def _create_table_with_indexes(conn, table_path, type_):
626+
table_name = ".".join(map(conn.dialect.quote, table_path))
627+
628+
tbl = table(
629+
table_path,
630+
schema={
631+
"id": int,
632+
"col": type_,
633+
},
634+
)
635+
680636
if isinstance(conn, db.Oracle):
681-
already_exists = conn.query(f"SELECT COUNT(*) from tab where tname='{table.upper()}'", int) > 0
637+
already_exists = conn.query(f"SELECT COUNT(*) from tab where tname='{table_name.upper()}'", int) > 0
682638
if not already_exists:
683-
conn.query(f"CREATE TABLE {table}(id int, col {type})", None)
639+
conn.query(tbl.create())
684640
elif isinstance(conn, db.Clickhouse):
685-
conn.query(f"CREATE TABLE {table}(id int, col {type}) engine = Memory;", None)
641+
conn.query(f"CREATE TABLE {table_name}(id int, col {type_}) engine = Memory;", None)
686642
else:
687-
conn.query(f"CREATE TABLE IF NOT EXISTS {table}(id int, col {type})", None)
643+
conn.query(tbl.create(if_not_exists=True))
688644

689-
_create_indexes(conn, table)
645+
_create_indexes(conn, table_name)
690646
conn.query(commit)
691647

692648

@@ -725,17 +681,15 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
725681

726682
self.src_table_path = src_table_path = src_conn.parse_table_name(src_table_name)
727683
self.dst_table_path = dst_table_path = dst_conn.parse_table_name(dst_table_name)
728-
self.src_table = src_table = ".".join(map(src_conn.dialect.quote, src_table_path))
729-
self.dst_table = dst_table = ".".join(map(dst_conn.dialect.quote, dst_table_path))
730684

731685
start = time.monotonic()
732686
if not BENCHMARK:
733687
drop_table(src_conn, src_table_path)
734-
_create_table_with_indexes(src_conn, src_table, source_type)
735-
_insert_to_table(src_conn, src_table, enumerate(sample_values, 1), source_type)
688+
_create_table_with_indexes(src_conn, src_table_path, source_type)
689+
_insert_to_table(src_conn, src_table_path, enumerate(sample_values, 1), source_type)
736690
insertion_source_duration = time.monotonic() - start
737691

738-
values_in_source = PaginatedTable(src_table, src_conn)
692+
values_in_source = PaginatedTable(src_table_path, src_conn)
739693
if source_db is db.Presto or source_db is db.Trino:
740694
if source_type.startswith("decimal"):
741695
values_in_source = ((a, Decimal(b)) for a, b in values_in_source)
@@ -745,8 +699,8 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
745699
start = time.monotonic()
746700
if not BENCHMARK:
747701
drop_table(dst_conn, dst_table_path)
748-
_create_table_with_indexes(dst_conn, dst_table, target_type)
749-
_insert_to_table(dst_conn, dst_table, values_in_source, target_type)
702+
_create_table_with_indexes(dst_conn, dst_table_path, target_type)
703+
_insert_to_table(dst_conn, dst_table_path, values_in_source, target_type)
750704
insertion_target_duration = time.monotonic() - start
751705

752706
if type_category == "uuid":
@@ -813,8 +767,8 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
813767
"rows": N_SAMPLES,
814768
"rows_human": number_to_human(N_SAMPLES),
815769
"name_human": f"{source_db.__name__}/{sanitize(source_type)} <-> {target_db.__name__}/{sanitize(target_type)}",
816-
"src_table": src_table[1:-1], # remove quotes
817-
"target_table": dst_table[1:-1],
770+
"src_table": src_table_path,
771+
"target_table": dst_table_path,
818772
"source_type": source_type,
819773
"target_type": target_type,
820774
"insertion_source_sec": round(insertion_source_duration, 3),

tests/test_diff_tables.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,10 @@ def test_left_table_empty(self):
731731

732732
class TestInfoTree(unittest.TestCase):
733733
def test_info_tree_root(self):
734-
self.ddb = get_conn(db.DuckDB)
734+
try:
735+
self.db = get_conn(db.DuckDB)
736+
except KeyError: # ddb not defined
737+
self.db = get_conn(db.MySQL)
735738

736739
table_suffix = random_table_suffix()
737740
self.table_src_name = f"src{table_suffix}"
@@ -750,10 +753,10 @@ def test_info_tree_root(self):
750753
self.table2.insert_rows([i] for i in range(2000)),
751754
]
752755
for q in queries:
753-
self.ddb.query(q)
756+
self.db.query(q)
754757

755-
ts1 = TableSegment(self.ddb, self.table1.path, ("id",))
756-
ts2 = TableSegment(self.ddb, self.table2.path, ("id",))
758+
ts1 = TableSegment(self.db, self.table1.path, ("id",))
759+
ts2 = TableSegment(self.db, self.table2.path, ("id",))
757760

758761
for differ in (HashDiffer(bisection_threshold=64), JoinDiffer(True)):
759762
diff_res = differ.diff_tables(ts1, ts2)

0 commit comments

Comments
 (0)