Skip to content

Commit 4407867

Browse files
Exempt numpy from these changes
1 parent 7297dde commit 4407867

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

astroid/interpreter/_import/spec.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,14 @@ def find_module(
160160
pass
161161
submodule_path = sys.path
162162

163+
# We're looping on pyi first because if a pyi exists there's probably a reason
164+
# (i.e. the code is hard or impossible to parse), so we take pyi into account
165+
# But we're not quite ready to do this for numpy
166+
suffixes = (".pyi", ".py", importlib.machinery.BYTECODE_SUFFIXES[0])
167+
numpy_suffixes = (".py", ".pyi", importlib.machinery.BYTECODE_SUFFIXES[0])
163168
for entry in submodule_path:
164169
package_directory = os.path.join(entry, modname)
165-
# We're looping on pyi first because if a pyi exists there's probably a reason
166-
# (i.e. the code is hard or impossible to parse), so we take pyi into account
167-
for suffix in (".pyi", ".py", importlib.machinery.BYTECODE_SUFFIXES[0]):
170+
for suffix in numpy_suffixes if "numpy" in entry else suffixes:
168171
package_file_name = "__init__" + suffix
169172
file_path = os.path.join(package_directory, package_file_name)
170173
if os.path.isfile(file_path):

astroid/modutils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def get_source_file(filename: str, include_no_ext: bool = False) -> str:
499499
base, orig_ext = os.path.splitext(filename)
500500
if orig_ext == ".pyi" and os.path.exists(f"{base}{orig_ext}"):
501501
return f"{base}{orig_ext}"
502-
for ext in PY_SOURCE_EXTS:
502+
for ext in PY_SOURCE_EXTS if "numpy" not in filename else reversed(PY_SOURCE_EXTS):
503503
source_path = f"{base}.{ext}"
504504
if os.path.exists(source_path):
505505
return source_path
@@ -671,7 +671,8 @@ def _has_init(directory: str) -> str | None:
671671
else return None.
672672
"""
673673
mod_or_pack = os.path.join(directory, "__init__")
674-
for ext in (*PY_SOURCE_EXTS, "pyc", "pyo"):
674+
exts = reversed(PY_SOURCE_EXTS) if "numpy" in directory else PY_SOURCE_EXTS
675+
for ext in (*exts, "pyc", "pyo"):
675676
if os.path.exists(mod_or_pack + "." + ext):
676677
return mod_or_pack + "." + ext
677678
return None

0 commit comments

Comments
 (0)