2727
2828py_functools = import_helper .import_fresh_module ('functools' ,
2929 blocked = ['_functools' ])
30- c_functools = import_helper .import_fresh_module ('functools' )
30+ c_functools = import_helper .import_fresh_module ('functools' ,
31+ fresh = ['_functools' ])
3132
3233decimal = import_helper .import_fresh_module ('decimal' , fresh = ['_decimal' ])
3334
35+ _partial_types = [py_functools .partial ]
36+ if c_functools :
37+ _partial_types .append (c_functools .partial )
38+
39+
3440@contextlib .contextmanager
3541def replaced_module (name , replacement ):
3642 original_module = sys .modules [name ]
@@ -202,7 +208,7 @@ def test_repr(self):
202208 kwargs = {'a' : object (), 'b' : object ()}
203209 kwargs_reprs = ['a={a!r}, b={b!r}' .format_map (kwargs ),
204210 'b={b!r}, a={a!r}' .format_map (kwargs )]
205- if self .partial in ( c_functools . partial , py_functools . partial ) :
211+ if self .partial in _partial_types :
206212 name = 'functools.partial'
207213 else :
208214 name = self .partial .__name__
@@ -224,7 +230,7 @@ def test_repr(self):
224230 for kwargs_repr in kwargs_reprs ])
225231
226232 def test_recursive_repr (self ):
227- if self .partial in ( c_functools . partial , py_functools . partial ) :
233+ if self .partial in _partial_types :
228234 name = 'functools.partial'
229235 else :
230236 name = self .partial .__name__
@@ -251,7 +257,7 @@ def test_recursive_repr(self):
251257 f .__setstate__ ((capture , (), {}, {}))
252258
253259 def test_pickle (self ):
254- with self .AllowPickle ( ):
260+ with replaced_module ( 'functools' , self .module ):
255261 f = self .partial (signature , ['asdf' ], bar = [True ])
256262 f .attr = []
257263 for proto in range (pickle .HIGHEST_PROTOCOL + 1 ):
@@ -334,7 +340,7 @@ def test_setstate_subclasses(self):
334340 self .assertIs (type (r [0 ]), tuple )
335341
336342 def test_recursive_pickle (self ):
337- with self .AllowPickle ( ):
343+ with replaced_module ( 'functools' , self .module ):
338344 f = self .partial (capture )
339345 f .__setstate__ ((f , (), {}, {}))
340346 try :
@@ -388,14 +394,9 @@ def __getitem__(self, key):
388394@unittest .skipUnless (c_functools , 'requires the C _functools module' )
389395class TestPartialC (TestPartial , unittest .TestCase ):
390396 if c_functools :
397+ module = c_functools
391398 partial = c_functools .partial
392399
393- class AllowPickle :
394- def __enter__ (self ):
395- return self
396- def __exit__ (self , type , value , tb ):
397- return False
398-
399400 def test_attributes_unwritable (self ):
400401 # attributes should not be writable
401402 p = self .partial (capture , 1 , 2 , a = 10 , b = 20 )
@@ -438,15 +439,9 @@ def __str__(self):
438439
439440
440441class TestPartialPy (TestPartial , unittest .TestCase ):
442+ module = py_functools
441443 partial = py_functools .partial
442444
443- class AllowPickle :
444- def __init__ (self ):
445- self ._cm = replaced_module ("functools" , py_functools )
446- def __enter__ (self ):
447- return self ._cm .__enter__ ()
448- def __exit__ (self , type , value , tb ):
449- return self ._cm .__exit__ (type , value , tb )
450445
451446if c_functools :
452447 class CPartialSubclass (c_functools .partial ):
@@ -1860,9 +1855,10 @@ def test_staticmethod(x):
18601855def py_cached_func (x , y ):
18611856 return 3 * x + y
18621857
1863- @c_functools .lru_cache ()
1864- def c_cached_func (x , y ):
1865- return 3 * x + y
1858+ if c_functools :
1859+ @c_functools .lru_cache ()
1860+ def c_cached_func (x , y ):
1861+ return 3 * x + y
18661862
18671863
18681864class TestLRUPy (TestLRU , unittest .TestCase ):
@@ -1879,18 +1875,20 @@ def cached_staticmeth(x, y):
18791875 return 3 * x + y
18801876
18811877
1878+ @unittest .skipUnless (c_functools , 'requires the C _functools module' )
18821879class TestLRUC (TestLRU , unittest .TestCase ):
1883- module = c_functools
1884- cached_func = c_cached_func ,
1880+ if c_functools :
1881+ module = c_functools
1882+ cached_func = c_cached_func ,
18851883
1886- @module .lru_cache ()
1887- def cached_meth (self , x , y ):
1888- return 3 * x + y
1884+ @module .lru_cache ()
1885+ def cached_meth (self , x , y ):
1886+ return 3 * x + y
18891887
1890- @staticmethod
1891- @module .lru_cache ()
1892- def cached_staticmeth (x , y ):
1893- return 3 * x + y
1888+ @staticmethod
1889+ @module .lru_cache ()
1890+ def cached_staticmeth (x , y ):
1891+ return 3 * x + y
18941892
18951893
18961894class TestSingleDispatch (unittest .TestCase ):
0 commit comments