Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 1faa22f

Browse files
committed
connect(): Added support for shared connection; Database.is_closed property
1 parent daf2d94 commit 1faa22f

File tree

9 files changed

+69
-15
lines changed

9 files changed

+69
-15
lines changed

data_diff/sqeleton/abcs/database_types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,11 @@ def close(self):
284284
"Close connection(s) to the database instance. Querying will stop functioning."
285285
...
286286

287+
@property
288+
@abstractmethod
289+
def is_closed(self) -> bool:
290+
"Return whether or not the connection has been closed"
291+
287292
@abstractmethod
288293
def _normalize_table_path(self, path: DbPath) -> DbPath:
289294
...

data_diff/sqeleton/databases/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ class Database(AbstractDatabase):
251251
CONNECT_URI_KWPARAMS = []
252252

253253
_interactive = False
254+
is_closed = False
254255

255256
@property
256257
def name(self):
@@ -440,6 +441,10 @@ def _query_conn(self, conn, sql_code: Union[str, ThreadLocalInterpreter]) -> lis
440441
callback = partial(self._query_cursor, c)
441442
return apply_query(callback, sql_code)
442443

444+
def close(self):
445+
self.is_closed = True
446+
return super().close()
447+
443448

444449
class ThreadedDatabase(Database):
445450
"""Access the database through singleton threads.
@@ -476,6 +481,7 @@ def create_connection(self):
476481
...
477482

478483
def close(self):
484+
super().close()
479485
self._queue.shutdown()
480486

481487
@property

data_diff/sqeleton/databases/bigquery.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):
141141
return apply_query(self._query_atom, sql_code)
142142

143143
def close(self):
144+
super().close()
144145
self._client.close()
145146

146147
def select_table_schema(self, path: DbPath) -> str:

data_diff/sqeleton/databases/connect.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from typing import Type, List, Optional, Union, Dict
22
from itertools import zip_longest
33
import dsnparse
4+
from contextlib import suppress
45

56
from runtype import dataclass
67

8+
from ..utils import WeakCache
79
from .base import Database, ThreadedDatabase
810
from .postgresql import PostgreSQL
911
from .mysql import MySQL
@@ -19,12 +21,13 @@
1921
from .duckdb import DuckDB
2022

2123

24+
2225
@dataclass
2326
class MatchUriPath:
2427
database_cls: Type[Database]
2528
params: List[str]
2629
kwparams: List[str] = []
27-
help_str: str
30+
help_str: str = "<unspecified>"
2831

2932
def __post_init__(self):
3033
assert self.params == self.database_cls.CONNECT_URI_PARAMS, self.params
@@ -101,6 +104,7 @@ def __init__(self, database_by_scheme: Dict[str, Database]):
101104
name: MatchUriPath(cls, cls.CONNECT_URI_PARAMS, cls.CONNECT_URI_KWPARAMS, help_str=cls.CONNECT_URI_HELP)
102105
for name, cls in database_by_scheme.items()
103106
}
107+
self.conn_cache = WeakCache()
104108

105109
def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Database:
106110
"""Connect to the given database uri
@@ -200,7 +204,7 @@ def _connection_created(self, db):
200204
"Nop function to be overridden by subclasses."
201205
return db
202206

203-
def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Database:
207+
def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1, shared: bool = True) -> Database:
204208
"""Connect to a database using the given database configuration.
205209
206210
Configuration can be given either as a URI string, or as a dict of {option: value}.
@@ -235,8 +239,19 @@ def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -
235239
>>> connect({"driver": "mysql", "host": "localhost", "database": "db"})
236240
<data_diff.databases.mysql.MySQL object at 0x0000025DB3F94820>
237241
"""
242+
if shared:
243+
with suppress(KeyError):
244+
conn = self.conn_cache.get(db_conf)
245+
if not conn.is_closed:
246+
return conn
247+
238248
if isinstance(db_conf, str):
239-
return self.connect_to_uri(db_conf, thread_count)
249+
conn = self.connect_to_uri(db_conf, thread_count)
240250
elif isinstance(db_conf, dict):
241-
return self.connect_with_dict(db_conf, thread_count)
242-
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")
251+
conn = self.connect_with_dict(db_conf, thread_count)
252+
else:
253+
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")
254+
255+
if shared:
256+
self.conn_cache.add(db_conf, conn)
257+
return conn

data_diff/sqeleton/databases/duckdb.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):
132132
return self._query_conn(self._conn, sql_code)
133133

134134
def close(self):
135+
super().close()
135136
self._conn.close()
136137

137138
def create_connection(self):

data_diff/sqeleton/databases/presto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def _query(self, sql_code: str) -> list:
171171
return query_cursor(c, sql_code)
172172

173173
def close(self):
174+
super().close()
174175
self._conn.close()
175176

176177
def select_table_schema(self, path: DbPath) -> str:

data_diff/sqeleton/databases/snowflake.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def __init__(self, *, schema: str, **kw):
134134
self.default_schema = schema
135135

136136
def close(self):
137+
super().close()
137138
self._conn.close()
138139

139140
def _query(self, sql_code: Union[str, ThreadLocalInterpreter]):

data_diff/sqeleton/utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Union, Dict, Any, Hashable
2+
from weakref import ref
13
from typing import TypeVar
24
from typing import Iterable, Iterator, MutableMapping, Union, Any, Sequence, Dict
35
from abc import abstractmethod
@@ -9,6 +11,30 @@
911
# -- Common --
1012

1113

14+
class WeakCache:
15+
def __init__(self):
16+
self._cache = {}
17+
18+
def _hashable_key(self, k: Union[dict, Hashable]) -> Hashable:
19+
if isinstance(k, dict):
20+
return tuple(k.items())
21+
return k
22+
23+
def add(self, key: Union[dict, Hashable], value: Any):
24+
key = self._hashable_key(key)
25+
self._cache[key] = ref(value)
26+
27+
def get(self, key: Union[dict, Hashable]) -> Any:
28+
key = self._hashable_key(key)
29+
30+
value = self._cache[key]()
31+
if value is None:
32+
del self._cache[key]
33+
raise KeyError(f"Key {key} not found, or no longer a valid reference")
34+
35+
return value
36+
37+
1238
def join_iter(joiner: Any, iterable: Iterable) -> Iterable:
1339
it = iter(iterable)
1440
try:

tests/test_api.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import arrow
22
from datetime import datetime
33

4-
from data_diff import diff_tables, connect_to_table
4+
from data_diff import diff_tables, connect_to_table, Algorithm
55
from data_diff.databases import MySQL
66
from data_diff.sqeleton.queries import table, commit
77

@@ -36,13 +36,17 @@ def setUp(self) -> None:
3636
)
3737

3838
def test_api(self):
39+
# test basic
3940
t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name)
4041
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, (self.table_dst_name,))
41-
diff = list(diff_tables(t1, t2))
42+
diff = list(diff_tables(t1, t2, algorithm=Algorithm.JOINDIFF))
4243
assert len(diff) == 1
4344

44-
t1.database.close()
45-
t2.database.close()
45+
# test algorithm
46+
# (also tests shared connection on connect_to_table)
47+
for algo in (Algorithm.HASHDIFF, Algorithm.JOINDIFF):
48+
diff = list(diff_tables(t1, t2, algorithm=algo))
49+
assert len(diff) == 1
4650

4751
# test where
4852
diff_id = diff[0][1][0]
@@ -53,9 +57,6 @@ def test_api(self):
5357
diff = list(diff_tables(t1, t2))
5458
assert len(diff) == 0
5559

56-
t1.database.close()
57-
t2.database.close()
58-
5960
def test_api_get_stats_dict(self):
6061
# XXX Likely to change in the future
6162
expected_dict = {
@@ -76,6 +77,3 @@ def test_api_get_stats_dict(self):
7677
self.assertEqual(expected_dict, output)
7778
self.assertIsNotNone(diff)
7879
assert len(list(diff)) == 1
79-
80-
t1.database.close()
81-
t2.database.close()

0 commit comments

Comments
 (0)