|
3 | 3 | import sys |
4 | 4 | import warnings |
5 | 5 | from contextlib import contextmanager |
| 6 | +from types import FunctionType |
6 | 7 | from typing import Generator |
| 8 | +from typing import List |
7 | 9 | from typing import Optional |
| 10 | +from typing import Sequence |
| 11 | +from typing import Union |
8 | 12 |
|
9 | 13 | import pytest |
| 14 | +from _pytest.compat import TYPE_CHECKING |
10 | 15 | from _pytest.fixtures import fixture |
11 | 16 | from _pytest.pathlib import Path |
12 | 17 |
|
| 18 | +if TYPE_CHECKING: |
| 19 | + from typing import Type # noqa: F401 |
| 20 | + |
13 | 21 |
|
14 | 22 | @fixture |
15 | 23 | def monkeypatch(): |
@@ -264,20 +272,40 @@ def delenv(self, name: str, raising: bool = True) -> None: |
264 | 272 | """ |
265 | 273 | self.delitem(os.environ, name, raising=raising) |
266 | 274 |
|
267 | | - def mockimport(self, mocked_imports, err=ImportError): |
| 275 | + def mockimport( |
| 276 | + self, |
| 277 | + mocked_imports: Union[str, Sequence[str]], |
| 278 | + err: Union[FunctionType, "Type[BaseException]"] = ImportError, |
| 279 | + ): |
| 280 | + """Mock import with given error to be raised, or callable. |
| 281 | +
|
| 282 | + The callable gets called instead of :func:`python:__import__`. |
| 283 | +
|
| 284 | + This is considered to be **experimental**. |
| 285 | + """ |
268 | 286 | import inspect |
269 | 287 |
|
270 | 288 | if not hasattr(self, "_mocked_imports"): |
271 | 289 | self._mocked_imports = {} |
272 | 290 | self._orig_import = __import__ |
273 | 291 |
|
274 | | - def import_mock(name, *args, **kwargs): |
275 | | - if name in self._mocked_imports: |
276 | | - err = self._mocked_imports[name] |
277 | | - if inspect.isfunction(err): |
278 | | - return err(name, *args, **kwargs) |
279 | | - raise err |
280 | | - return self._orig_import(name, *args, **kwargs) |
| 292 | + def import_mock(*args, **kwargs): |
| 293 | + name = kwargs.get("name", args[0]) # type: str |
| 294 | + fromlist = kwargs.get( |
| 295 | + "fromlist", args[3] if len(args) > 3 else [] |
| 296 | + ) # type: List[str] |
| 297 | + if fromlist: |
| 298 | + req_names = ["{}.{}".format(name, x) for x in fromlist] |
| 299 | + else: |
| 300 | + req_names = [name] |
| 301 | + |
| 302 | + for _name in req_names: |
| 303 | + if _name in self._mocked_imports: |
| 304 | + err = self._mocked_imports[_name] |
| 305 | + if inspect.isfunction(err): |
| 306 | + return err(*args, **kwargs) |
| 307 | + raise err |
| 308 | + return self._orig_import(*args, **kwargs) |
281 | 309 |
|
282 | 310 | self.setattr("builtins.__import__", import_mock) |
283 | 311 |
|
|
0 commit comments