1
- from typing import Type , Optional , Union , Dict
1
+ from typing import Hashable , MutableMapping , Type , Optional , Union , Dict
2
2
from itertools import zip_longest
3
3
from contextlib import suppress
4
+ import weakref
4
5
import dsnparse
5
6
import toml
6
7
7
8
from runtype import dataclass
8
9
9
10
from ..abcs .mixins import AbstractMixin
10
- from ..utils import WeakCache , Self
11
+ from ..utils import Self
11
12
from .base import Database , ThreadedDatabase
12
13
from .postgresql import PostgreSQL
13
14
from .mysql import MySQL
@@ -93,11 +94,12 @@ def match_path(self, dsn):
93
94
94
95
class Connect :
95
96
"""Provides methods for connecting to a supported database using a URL or connection dict."""
97
+ conn_cache : MutableMapping [Hashable , Database ]
96
98
97
99
def __init__ (self , database_by_scheme : Dict [str , Database ] = DATABASE_BY_SCHEME ):
98
100
self .database_by_scheme = database_by_scheme
99
101
self .match_uri_path = {name : MatchUriPath (cls ) for name , cls in database_by_scheme .items ()}
100
- self .conn_cache = WeakCache ()
102
+ self .conn_cache = weakref . WeakValueDictionary ()
101
103
102
104
def for_databases (self , * dbs ):
103
105
database_by_scheme = {k : db for k , db in self .database_by_scheme .items () if k in dbs }
@@ -262,9 +264,10 @@ def __call__(
262
264
>>> connect({"driver": "mysql", "host": "localhost", "database": "db"})
263
265
<data_diff.sqeleton.databases.mysql.MySQL object at ...>
264
266
"""
267
+ cache_key = self .__make_cache_key (db_conf )
265
268
if shared :
266
269
with suppress (KeyError ):
267
- conn = self .conn_cache . get ( db_conf )
270
+ conn = self .conn_cache [ cache_key ]
268
271
if not conn .is_closed :
269
272
return conn
270
273
@@ -276,5 +279,10 @@ def __call__(
276
279
raise TypeError (f"db configuration must be a URI string or a dictionary. Instead got '{ db_conf } '." )
277
280
278
281
if shared :
279
- self .conn_cache . add ( db_conf , conn )
282
+ self .conn_cache [ cache_key ] = conn
280
283
return conn
284
+
285
+ def __make_cache_key (self , db_conf : Union [str , dict ]) -> Hashable :
286
+ if isinstance (db_conf , dict ):
287
+ return tuple (db_conf .items ())
288
+ return db_conf
0 commit comments