Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit e9a36de

Browse files
author
Sergey Vasilyev
committed
Simplify by replacing the self-made WeakCache with the builtin WeakValueDict
1 parent 240db97 commit e9a36de

File tree

3 files changed

+13
-53
lines changed

3 files changed

+13
-53
lines changed

data_diff/sqeleton/databases/_connect.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
from typing import Type, Optional, Union, Dict
1+
from typing import Hashable, MutableMapping, Type, Optional, Union, Dict
22
from itertools import zip_longest
33
from contextlib import suppress
4+
import weakref
45
import dsnparse
56
import toml
67

78
from runtype import dataclass
89
from typing_extensions import Self
910

1011
from ..abcs.mixins import AbstractMixin
11-
from ..utils import WeakCache
1212
from .base import Database, ThreadedDatabase
1313
from .postgresql import PostgreSQL
1414
from .mysql import MySQL
@@ -94,11 +94,12 @@ def match_path(self, dsn):
9494

9595
class Connect:
9696
"""Provides methods for connecting to a supported database using a URL or connection dict."""
97+
conn_cache: MutableMapping[Hashable, Database]
9798

9899
def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
99100
self.database_by_scheme = database_by_scheme
100101
self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()}
101-
self.conn_cache = WeakCache()
102+
self.conn_cache = weakref.WeakValueDictionary()
102103

103104
def for_databases(self, *dbs) -> Self:
104105
database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs}
@@ -263,9 +264,10 @@ def __call__(
263264
>>> connect({"driver": "mysql", "host": "localhost", "database": "db"})
264265
<data_diff.sqeleton.databases.mysql.MySQL object at ...>
265266
"""
267+
cache_key = self.__make_cache_key(db_conf)
266268
if shared:
267269
with suppress(KeyError):
268-
conn = self.conn_cache.get(db_conf)
270+
conn = self.conn_cache[cache_key]
269271
if not conn.is_closed:
270272
return conn
271273

@@ -277,5 +279,10 @@ def __call__(
277279
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")
278280

279281
if shared:
280-
self.conn_cache.add(db_conf, conn)
282+
self.conn_cache[cache_key] = conn
281283
return conn
284+
285+
def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable:
286+
if isinstance(db_conf, dict):
287+
return tuple(db_conf.items())
288+
return db_conf

data_diff/sqeleton/utils.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,10 @@
77
Any,
88
Sequence,
99
Dict,
10-
Hashable,
1110
TypeVar,
1211
List,
1312
)
1413
from abc import abstractmethod
15-
from weakref import ref
1614
import math
1715
import string
1816
import re
@@ -24,30 +22,6 @@
2422
# -- Common --
2523

2624

27-
class WeakCache:
28-
def __init__(self):
29-
self._cache = {}
30-
31-
def _hashable_key(self, k: Union[dict, Hashable]) -> Hashable:
32-
if isinstance(k, dict):
33-
return tuple(k.items())
34-
return k
35-
36-
def add(self, key: Union[dict, Hashable], value: Any):
37-
key = self._hashable_key(key)
38-
self._cache[key] = ref(value)
39-
40-
def get(self, key: Union[dict, Hashable]) -> Any:
41-
key = self._hashable_key(key)
42-
43-
value = self._cache[key]()
44-
if value is None:
45-
del self._cache[key]
46-
raise KeyError(f"Key {key} not found, or no longer a valid reference")
47-
48-
return value
49-
50-
5125
def join_iter(joiner: Any, iterable: Iterable) -> Iterable:
5226
it = iter(iterable)
5327
try:

tests/sqeleton/test_utils.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22

3-
from data_diff.sqeleton.utils import remove_passwords_in_dict, match_regexps, match_like, number_to_human, WeakCache
3+
from data_diff.sqeleton.utils import remove_passwords_in_dict, match_regexps, match_like, number_to_human
44

55

66
class TestUtils(unittest.TestCase):
@@ -81,24 +81,3 @@ def test_number_to_human(self):
8181
assert number_to_human(-1000) == "-1k"
8282
assert number_to_human(-1000000) == "-1m"
8383
assert number_to_human(-1000000000) == "-1b"
84-
85-
def test_weak_cache(self):
86-
# Create cache
87-
cache = WeakCache()
88-
89-
# Test adding and retrieving basic value
90-
o = {1, 2}
91-
cache.add("key", o)
92-
assert cache.get("key") is o
93-
94-
# Test adding and retrieving dict value
95-
cache.add({"key": "value"}, o)
96-
assert cache.get({"key": "value"}) is o
97-
98-
# Test deleting value when reference is lost
99-
del o
100-
try:
101-
cache.get({"key": "value"})
102-
assert False, "KeyError should have been raised"
103-
except KeyError:
104-
pass

0 commit comments

Comments
 (0)