2626
2727py_functools = import_helper .import_fresh_module ('functools' ,
2828 blocked = ['_functools' ])
29- c_functools = import_helper .import_fresh_module ('functools' )
29+ c_functools = import_helper .import_fresh_module ('functools' ,
30+ fresh = ['_functools' ])
3031
3132decimal = import_helper .import_fresh_module ('decimal' , fresh = ['_decimal' ])
3233
34+ _partial_types = [py_functools .partial ]
35+ if c_functools :
36+ _partial_types .append (c_functools .partial )
37+
38+
3339@contextlib .contextmanager
3440def replaced_module (name , replacement ):
3541 original_module = sys .modules [name ]
@@ -201,7 +207,7 @@ def test_repr(self):
201207 kwargs = {'a' : object (), 'b' : object ()}
202208 kwargs_reprs = ['a={a!r}, b={b!r}' .format_map (kwargs ),
203209 'b={b!r}, a={a!r}' .format_map (kwargs )]
204- if self .partial in ( c_functools . partial , py_functools . partial ) :
210+ if self .partial in _partial_types :
205211 name = 'functools.partial'
206212 else :
207213 name = self .partial .__name__
@@ -223,7 +229,7 @@ def test_repr(self):
223229 for kwargs_repr in kwargs_reprs ])
224230
225231 def test_recursive_repr (self ):
226- if self .partial in ( c_functools . partial , py_functools . partial ) :
232+ if self .partial in _partial_types :
227233 name = 'functools.partial'
228234 else :
229235 name = self .partial .__name__
@@ -250,7 +256,7 @@ def test_recursive_repr(self):
250256 f .__setstate__ ((capture , (), {}, {}))
251257
252258 def test_pickle (self ):
253- with self .AllowPickle ( ):
259+ with replaced_module ( 'functools' , self .module ):
254260 f = self .partial (signature , ['asdf' ], bar = [True ])
255261 f .attr = []
256262 for proto in range (pickle .HIGHEST_PROTOCOL + 1 ):
@@ -333,7 +339,7 @@ def test_setstate_subclasses(self):
333339 self .assertIs (type (r [0 ]), tuple )
334340
335341 def test_recursive_pickle (self ):
336- with self .AllowPickle ( ):
342+ with replaced_module ( 'functools' , self .module ):
337343 f = self .partial (capture )
338344 f .__setstate__ ((f , (), {}, {}))
339345 try :
@@ -387,14 +393,9 @@ def __getitem__(self, key):
387393@unittest .skipUnless (c_functools , 'requires the C _functools module' )
388394class TestPartialC (TestPartial , unittest .TestCase ):
389395 if c_functools :
396+ module = c_functools
390397 partial = c_functools .partial
391398
392- class AllowPickle :
393- def __enter__ (self ):
394- return self
395- def __exit__ (self , type , value , tb ):
396- return False
397-
398399 def test_attributes_unwritable (self ):
399400 # attributes should not be writable
400401 p = self .partial (capture , 1 , 2 , a = 10 , b = 20 )
@@ -437,15 +438,9 @@ def __str__(self):
437438
438439
439440class TestPartialPy (TestPartial , unittest .TestCase ):
441+ module = py_functools
440442 partial = py_functools .partial
441443
442- class AllowPickle :
443- def __init__ (self ):
444- self ._cm = replaced_module ("functools" , py_functools )
445- def __enter__ (self ):
446- return self ._cm .__enter__ ()
447- def __exit__ (self , type , value , tb ):
448- return self ._cm .__exit__ (type , value , tb )
449444
450445if c_functools :
451446 class CPartialSubclass (c_functools .partial ):
@@ -1872,9 +1867,10 @@ def orig(): ...
18721867def py_cached_func (x , y ):
18731868 return 3 * x + y
18741869
1875- @c_functools .lru_cache ()
1876- def c_cached_func (x , y ):
1877- return 3 * x + y
1870+ if c_functools :
1871+ @c_functools .lru_cache ()
1872+ def c_cached_func (x , y ):
1873+ return 3 * x + y
18781874
18791875
18801876class TestLRUPy (TestLRU , unittest .TestCase ):
@@ -1891,18 +1887,20 @@ def cached_staticmeth(x, y):
18911887 return 3 * x + y
18921888
18931889
1890+ @unittest .skipUnless (c_functools , 'requires the C _functools module' )
18941891class TestLRUC (TestLRU , unittest .TestCase ):
1895- module = c_functools
1896- cached_func = c_cached_func ,
1892+ if c_functools :
1893+ module = c_functools
1894+ cached_func = c_cached_func ,
18971895
1898- @module .lru_cache ()
1899- def cached_meth (self , x , y ):
1900- return 3 * x + y
1896+ @module .lru_cache ()
1897+ def cached_meth (self , x , y ):
1898+ return 3 * x + y
19011899
1902- @staticmethod
1903- @module .lru_cache ()
1904- def cached_staticmeth (x , y ):
1905- return 3 * x + y
1900+ @staticmethod
1901+ @module .lru_cache ()
1902+ def cached_staticmeth (x , y ):
1903+ return 3 * x + y
19061904
19071905
19081906class TestSingleDispatch (unittest .TestCase ):
0 commit comments