Skip to content

Commit cdc7605

Browse files
authored
mockimport: handle "from foo import bar" (fromlist) (#265)
1 parent 9edb9fd commit cdc7605

File tree

2 files changed

+51
-8
lines changed

2 files changed

+51
-8
lines changed

src/_pytest/monkeypatch.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,21 @@
33
import sys
44
import warnings
55
from contextlib import contextmanager
6+
from types import FunctionType
67
from typing import Generator
8+
from typing import List
79
from typing import Optional
10+
from typing import Sequence
11+
from typing import Union
812

913
import pytest
14+
from _pytest.compat import TYPE_CHECKING
1015
from _pytest.fixtures import fixture
1116
from _pytest.pathlib import Path
1217

18+
if TYPE_CHECKING:
19+
from typing import Type # noqa: F401
20+
1321

1422
@fixture
1523
def monkeypatch():
@@ -264,20 +272,40 @@ def delenv(self, name: str, raising: bool = True) -> None:
264272
"""
265273
self.delitem(os.environ, name, raising=raising)
266274

267-
def mockimport(self, mocked_imports, err=ImportError):
275+
def mockimport(
276+
self,
277+
mocked_imports: Union[str, Sequence[str]],
278+
err: Union[FunctionType, "Type[BaseException]"] = ImportError,
279+
):
280+
"""Mock import with given error to be raised, or callable.
281+
282+
The callable gets called instead of :func:`python:__import__`.
283+
284+
This is considered to be **experimental**.
285+
"""
268286
import inspect
269287

270288
if not hasattr(self, "_mocked_imports"):
271289
self._mocked_imports = {}
272290
self._orig_import = __import__
273291

274-
def import_mock(name, *args, **kwargs):
275-
if name in self._mocked_imports:
276-
err = self._mocked_imports[name]
277-
if inspect.isfunction(err):
278-
return err(name, *args, **kwargs)
279-
raise err
280-
return self._orig_import(name, *args, **kwargs)
292+
def import_mock(*args, **kwargs):
293+
name = kwargs.get("name", args[0]) # type: str
294+
fromlist = kwargs.get(
295+
"fromlist", args[3] if len(args) > 3 else []
296+
) # type: List[str]
297+
if fromlist:
298+
req_names = ["{}.{}".format(name, x) for x in fromlist]
299+
else:
300+
req_names = [name]
301+
302+
for _name in req_names:
303+
if _name in self._mocked_imports:
304+
err = self._mocked_imports[_name]
305+
if inspect.isfunction(err):
306+
return err(*args, **kwargs)
307+
raise err
308+
return self._orig_import(*args, **kwargs)
281309

282310
self.setattr("builtins.__import__", import_mock)
283311

testing/test_monkeypatch.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,14 @@ def mockedimport(*args, **kwargs):
496496
(("os.foo", globals()), {"level": 42}),
497497
]
498498

499+
# With fromlist.
500+
calls[:] = []
501+
with pytest.raises(ImportError):
502+
from os import foo # noqa: F401
503+
assert calls == [
504+
(("os", globals(), None, ("foo",), 0), {}),
505+
]
506+
499507

500508
def test_mockimport_importlib(monkeypatch):
501509
"""importlib.import_module is not patched"""
@@ -511,3 +519,10 @@ def test_mockimport_already_imported(monkeypatch):
511519
monkeypatch.mockimport("os", TypeError)
512520
with pytest.raises(TypeError):
513521
import os # noqa: F401
522+
523+
524+
def test_mockimport_fromlist(monkeypatch):
525+
monkeypatch.mockimport(("os.foo",), TypeError)
526+
527+
with pytest.raises(TypeError):
528+
from os import foo # noqa: F401

0 commit comments

Comments
 (0)