Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit e20e94c

Browse files
author
Sergey Vasilyev
committed
Simplify by replacing the self-made WeakCache with the builtin WeakValueDict
1 parent 47c070e commit e20e94c

File tree

2 files changed

+13
-32
lines changed

2 files changed

+13
-32
lines changed

data_diff/sqeleton/databases/_connect.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
from typing import Type, Optional, Union, Dict
1+
from typing import Hashable, MutableMapping, Type, Optional, Union, Dict
22
from itertools import zip_longest
33
from contextlib import suppress
4+
import weakref
45
import dsnparse
56
import toml
67

78
from runtype import dataclass
89

910
from ..abcs.mixins import AbstractMixin
10-
from ..utils import WeakCache, Self
11+
from ..utils import Self
1112
from .base import Database, ThreadedDatabase
1213
from .postgresql import PostgreSQL
1314
from .mysql import MySQL
@@ -93,11 +94,12 @@ def match_path(self, dsn):
9394

9495
class Connect:
9596
"""Provides methods for connecting to a supported database using a URL or connection dict."""
97+
conn_cache: MutableMapping[Hashable, Database]
9698

9799
def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
98100
self.database_by_scheme = database_by_scheme
99101
self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()}
100-
self.conn_cache = WeakCache()
102+
self.conn_cache = weakref.WeakValueDictionary()
101103

102104
def for_databases(self, *dbs):
103105
database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs}
@@ -262,9 +264,10 @@ def __call__(
262264
>>> connect({"driver": "mysql", "host": "localhost", "database": "db"})
263265
<data_diff.sqeleton.databases.mysql.MySQL object at ...>
264266
"""
267+
cache_key = self.__make_cache_key(db_conf)
265268
if shared:
266269
with suppress(KeyError):
267-
conn = self.conn_cache.get(db_conf)
270+
conn = self.conn_cache[cache_key]
268271
if not conn.is_closed:
269272
return conn
270273

@@ -276,5 +279,10 @@ def __call__(
276279
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")
277280

278281
if shared:
279-
self.conn_cache.add(db_conf, conn)
282+
self.conn_cache[cache_key] = conn
280283
return conn
284+
285+
def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable:
286+
if isinstance(db_conf, dict):
287+
return tuple(db_conf.items())
288+
return db_conf

data_diff/sqeleton/utils.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@
66
Any,
77
Sequence,
88
Dict,
9-
Hashable,
109
TypeVar,
11-
TYPE_CHECKING,
1210
List,
1311
)
1412
from abc import abstractmethod
15-
from weakref import ref
1613
import math
1714
import string
1815
import re
@@ -27,30 +24,6 @@
2724
Self = Any
2825

2926

30-
class WeakCache:
31-
def __init__(self):
32-
self._cache = {}
33-
34-
def _hashable_key(self, k: Union[dict, Hashable]) -> Hashable:
35-
if isinstance(k, dict):
36-
return tuple(k.items())
37-
return k
38-
39-
def add(self, key: Union[dict, Hashable], value: Any):
40-
key = self._hashable_key(key)
41-
self._cache[key] = ref(value)
42-
43-
def get(self, key: Union[dict, Hashable]) -> Any:
44-
key = self._hashable_key(key)
45-
46-
value = self._cache[key]()
47-
if value is None:
48-
del self._cache[key]
49-
raise KeyError(f"Key {key} not found, or no longer a valid reference")
50-
51-
return value
52-
53-
5427
def join_iter(joiner: Any, iterable: Iterable) -> Iterable:
5528
it = iter(iterable)
5629
try:

0 commit comments

Comments
 (0)