Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit e6b9ffc

Browse files
authored
Merge pull request #222 from datafold/aug30
Better error messages. Move some parsing to before the connects.
2 parents 7276935 + eca57f6 commit e6b9ffc

File tree

5 files changed

+63
-22
lines changed

5 files changed

+63
-22
lines changed

data_diff/__main__.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ def _main(
155155
logging.debug(f"Applied run configuration: {__conf__}")
156156
elif verbose:
157157
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, datefmt=DATE_FORMAT)
158+
else:
159+
logging.basicConfig(level=logging.WARNING, format=LOG_FORMAT, datefmt=DATE_FORMAT)
158160

159161
if limit and stats:
160162
logging.error("Cannot specify a limit when using the -s/--stats switch")
@@ -181,14 +183,6 @@ def _main(
181183
logging.error("Error: threads must be >= 1")
182184
return
183185

184-
db1 = connect(database1, threads1 or threads)
185-
db2 = connect(database2, threads2 or threads)
186-
dbs = db1, db2
187-
188-
if interactive:
189-
for db in dbs:
190-
db.enable_interactive()
191-
192186
start = time.monotonic()
193187

194188
try:
@@ -199,7 +193,7 @@ def _main(
199193
where=where,
200194
)
201195
except ParseError as e:
202-
logging.error("Error while parsing age expression: %s" % e)
196+
logging.error(f"Error while parsing age expression: {e}")
203197
return
204198

205199
differ = TableDiffer(
@@ -210,6 +204,25 @@ def _main(
210204
debug=debug,
211205
)
212206

207+
if database1 is None or database2 is None:
208+
logging.error(
209+
f"Error: Databases not specified. Got {database1} and {database2}. Use --help for more information."
210+
)
211+
return
212+
213+
try:
214+
db1 = connect(database1, threads1 or threads)
215+
db2 = connect(database2, threads2 or threads)
216+
except Exception as e:
217+
logging.error(e)
218+
return
219+
220+
dbs = db1, db2
221+
222+
if interactive:
223+
for db in dbs:
224+
db.enable_interactive()
225+
213226
table_names = table1, table2
214227
table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)]
215228

data_diff/databases/connect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,4 +202,4 @@ def connect(db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Datab
202202
return connect_to_uri(db_conf, thread_count)
203203
elif isinstance(db_conf, dict):
204204
return connect_with_dict(db_conf, thread_count)
205-
raise TypeError(db_conf)
205+
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")

tests/common.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,17 @@ def get_git_revision_short_hash() -> str:
6363
db.Trino: TEST_TRINO_CONN_STRING,
6464
}
6565

66-
for k, v in CONN_STRINGS.items():
67-
if v is None:
68-
logging.warn(f"Connection to {k} not configured")
69-
else:
70-
logging.info(f"Testing database: {k}")
7166

67+
def _print_used_dbs():
68+
used = {k.__name__ for k, v in CONN_STRINGS.items() if v is not None}
69+
unused = {k.__name__ for k, v in CONN_STRINGS.items() if v is None}
70+
71+
logging.info(f"Testing databases: {', '.join(used)}")
72+
if unused:
73+
logging.info(f"Connection not configured; skipping tests for: {', '.join(unused)}")
74+
75+
76+
_print_used_dbs()
7277
CONN_STRINGS = {k: v for k, v in CONN_STRINGS.items() if v is not None}
7378

7479

tests/test_database_types.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,18 @@
2626
_drop_table_if_exists,
2727
)
2828

29+
CONNS = None
2930

30-
CONNS = {k: db.connect.connect(v, N_THREADS) for k, v in CONN_STRINGS.items()}
3131

32-
CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'", None)
33-
oracle.SESSION_TIME_ZONE = postgresql.SESSION_TIME_ZONE = "UTC"
32+
def init_conns():
33+
global CONNS
34+
if CONNS is not None:
35+
return
36+
37+
CONNS = {k: db.connect.connect(v, N_THREADS) for k, v in CONN_STRINGS.items()}
38+
CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'", None)
39+
oracle.SESSION_TIME_ZONE = postgresql.SESSION_TIME_ZONE = "UTC"
40+
3441

3542
DATABASE_TYPES = {
3643
db.PostgreSQL: {
@@ -374,7 +381,7 @@ def __iter__(self):
374381
) in source_type_categories.items(): # int, datetime, ..
375382
for source_type in source_types:
376383
for target_type in target_type_categories[type_category]:
377-
if CONNS.get(source_db, False) and CONNS.get(target_db, False):
384+
if CONN_STRINGS.get(source_db, False) and CONN_STRINGS.get(target_db, False):
378385
type_pairs.append(
379386
(
380387
source_db,
@@ -518,6 +525,9 @@ def _create_table_with_indexes(conn, table, type):
518525
class TestDiffCrossDatabaseTables(unittest.TestCase):
519526
maxDiff = 10000
520527

528+
def setUp(self) -> None:
529+
init_conns()
530+
521531
def tearDown(self) -> None:
522532
if not BENCHMARK:
523533
_drop_table_if_exists(self.src_conn, self.src_table)

tests/test_diff_tables.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,17 @@
2020
N_THREADS,
2121
)
2222

23+
DATABASE_INSTANCES = None
2324
DATABASE_URIS = {k.__name__: v for k, v in CONN_STRINGS.items()}
24-
DATABASE_INSTANCES = {k.__name__: connect(v, N_THREADS) for k, v in CONN_STRINGS.items()}
25+
26+
27+
def init_instances():
28+
global DATABASE_INSTANCES
29+
if DATABASE_INSTANCES is not None:
30+
return
31+
32+
DATABASE_INSTANCES = {k.__name__: connect(v, N_THREADS) for k, v in CONN_STRINGS.items()}
33+
2534

2635
TEST_DATABASES = {x.__name__ for x in (db.MySQL, db.PostgreSQL, db.Oracle, db.Redshift, db.Snowflake, db.BigQuery)}
2736

@@ -56,6 +65,7 @@ def _get_text_type(conn):
5665
return "STRING"
5766
return "varchar(100)"
5867

68+
5969
def _get_float_type(conn):
6070
if isinstance(conn, db.BigQuery):
6171
return "FLOAT64"
@@ -79,6 +89,7 @@ class TestPerDatabase(unittest.TestCase):
7989

8090
def setUp(self):
8191
assert self.db_name
92+
init_instances()
8293

8394
self.connection = DATABASE_INSTANCES[self.db_name]
8495
if self.with_preql:
@@ -215,10 +226,12 @@ def setUp(self):
215226
float_type = _get_float_type(self.connection)
216227

217228
self.connection.query(
218-
f"create table {self.table_src}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", None
229+
f"create table {self.table_src}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)",
230+
None,
219231
)
220232
self.connection.query(
221-
f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)", None
233+
f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)",
234+
None,
222235
)
223236
# self.preql(
224237
# f"""

0 commit comments

Comments
 (0)