55import contextlib
66import types
77import importlib
8+ import inspect
9+ import warnings
10+ import itertools
811
9- from typing import Union , Optional
12+ from typing import Union , Optional , cast
1013from .abc import ResourceReader , Traversable
1114
1215from ._adapters import wrap_spec
1316
1417Package = Union [types .ModuleType , str ]
18+ Anchor = Package
1519
1620
17- def files (package ):
18- # type: (Package) -> Traversable
21+ def package_to_anchor (func ):
1922 """
20- Get a Traversable resource from a package
23+ Replace 'package' parameter as 'anchor' and warn about the change.
24+
25+ Other errors should fall through.
26+
27+ >>> files('a', 'b')
28+ Traceback (most recent call last):
29+ TypeError: files() takes from 0 to 1 positional arguments but 2 were given
30+ """
31+ undefined = object ()
32+
33+ @functools .wraps (func )
34+ def wrapper (anchor = undefined , package = undefined ):
35+ if package is not undefined :
36+ if anchor is not undefined :
37+ return func (anchor , package )
38+ warnings .warn (
39+ "First parameter to files is renamed to 'anchor'" ,
40+ DeprecationWarning ,
41+ stacklevel = 2 ,
42+ )
43+ return func (package )
44+ elif anchor is undefined :
45+ return func ()
46+ return func (anchor )
47+
48+ return wrapper
49+
50+
51+ @package_to_anchor
52+ def files (anchor : Optional [Anchor ] = None ) -> Traversable :
53+ """
54+ Get a Traversable resource for an anchor.
2155 """
22- return from_package (get_package ( package ))
56+ return from_package (resolve ( anchor ))
2357
2458
25- def get_resource_reader (package ):
26- # type: (types.ModuleType) -> Optional[ResourceReader]
59+ def get_resource_reader (package : types .ModuleType ) -> Optional [ResourceReader ]:
2760 """
2861 Return the package's loader if it's a ResourceReader.
2962 """
@@ -39,24 +72,39 @@ def get_resource_reader(package):
3972 return reader (spec .name ) # type: ignore
4073
4174
42- def resolve (cand ):
43- # type: (Package) -> types.ModuleType
44- return cand if isinstance (cand , types .ModuleType ) else importlib .import_module (cand )
75+ @functools .singledispatch
76+ def resolve (cand : Optional [Anchor ]) -> types .ModuleType :
77+ return cast (types .ModuleType , cand )
78+
79+
80+ @resolve .register
81+ def _ (cand : str ) -> types .ModuleType :
82+ return importlib .import_module (cand )
83+
4584
85+ @resolve .register
86+ def _ (cand : None ) -> types .ModuleType :
87+ return resolve (_infer_caller ().f_globals ['__name__' ])
4688
47- def get_package (package ):
48- # type: (Package) -> types.ModuleType
49- """Take a package name or module object and return the module.
5089
51- Raise an exception if the resolved module is not a package.
90+ def _infer_caller ():
5291 """
53- resolved = resolve (package )
54- if wrap_spec (resolved ).submodule_search_locations is None :
55- raise TypeError (f'{ package !r} is not a package' )
56- return resolved
92+ Walk the stack and find the frame of the first caller not in this module.
93+ """
94+
95+ def is_this_file (frame_info ):
96+ return frame_info .filename == __file__
97+
98+ def is_wrapper (frame_info ):
99+ return frame_info .function == 'wrapper'
100+
101+ not_this_file = itertools .filterfalse (is_this_file , inspect .stack ())
102+ # also exclude 'wrapper' due to singledispatch in the call stack
103+ callers = itertools .filterfalse (is_wrapper , not_this_file )
104+ return next (callers ).frame
57105
58106
59- def from_package (package ):
107+ def from_package (package : types . ModuleType ):
60108 """
61109 Return a Traversable object for the given package.
62110
0 commit comments