Skip to content

Commit d0cb160

Browse files
authored
Merge pull request #6580 from blueyed/typing-testdir-init
typing: Testdir.__init__
2 parents aa318e9 + ad0f4f0 commit d0cb160

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

src/_pytest/config/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@
4848
from typing import Type
4949

5050

51+
_PluggyPlugin = object
52+
"""A type to represent plugin objects.
53+
Plugins can be any namespace, so we can't narrow it down much, but we use an
54+
alias to make the intent clear.
55+
Ideally this type would be provided by pluggy itself."""
56+
57+
5158
hookimpl = HookimplMarker("pytest")
5259
hookspec = HookspecMarker("pytest")
5360

src/_pytest/pytester.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,17 @@
2929
from _pytest.capture import MultiCapture
3030
from _pytest.capture import SysCapture
3131
from _pytest.compat import TYPE_CHECKING
32+
from _pytest.config import _PluggyPlugin
3233
from _pytest.fixtures import FixtureRequest
3334
from _pytest.main import ExitCode
3435
from _pytest.main import Session
3536
from _pytest.monkeypatch import MonkeyPatch
37+
from _pytest.nodes import Collector
38+
from _pytest.nodes import Item
3639
from _pytest.pathlib import Path
40+
from _pytest.python import Module
3741
from _pytest.reports import TestReport
42+
from _pytest.tmpdir import TempdirFactory
3843

3944
if TYPE_CHECKING:
4045
from typing import Type
@@ -534,13 +539,15 @@ class Testdir:
534539
class TimeoutExpired(Exception):
535540
pass
536541

537-
def __init__(self, request, tmpdir_factory):
542+
def __init__(self, request: FixtureRequest, tmpdir_factory: TempdirFactory) -> None:
538543
self.request = request
539-
self._mod_collections = WeakKeyDictionary()
544+
self._mod_collections = (
545+
WeakKeyDictionary()
546+
) # type: WeakKeyDictionary[Module, List[Union[Item, Collector]]]
540547
name = request.function.__name__
541548
self.tmpdir = tmpdir_factory.mktemp(name, numbered=True)
542549
self.test_tmproot = tmpdir_factory.mktemp("tmp-" + name, numbered=True)
543-
self.plugins = []
550+
self.plugins = [] # type: List[Union[str, _PluggyPlugin]]
544551
self._cwd_snapshot = CwdSnapshot()
545552
self._sys_path_snapshot = SysPathsSnapshot()
546553
self._sys_modules_snapshot = self.__take_sys_modules_snapshot()
@@ -1064,7 +1071,9 @@ def getmodulecol(self, source, configargs=(), withinit=False):
10641071
self.config = config = self.parseconfigure(path, *configargs)
10651072
return self.getnode(config, path)
10661073

1067-
def collect_by_name(self, modcol, name):
1074+
def collect_by_name(
1075+
self, modcol: Module, name: str
1076+
) -> Optional[Union[Item, Collector]]:
10681077
"""Return the collection node for name from the module collection.
10691078
10701079
This will search a module collection node for a collection node
@@ -1073,13 +1082,13 @@ def collect_by_name(self, modcol, name):
10731082
:param modcol: a module collection node; see :py:meth:`getmodulecol`
10741083
10751084
:param name: the name of the node to return
1076-
10771085
"""
10781086
if modcol not in self._mod_collections:
10791087
self._mod_collections[modcol] = list(modcol.collect())
10801088
for colitem in self._mod_collections[modcol]:
10811089
if colitem.name == name:
10821090
return colitem
1091+
return None
10831092

10841093
def popen(
10851094
self,

testing/test_collection.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from _pytest.main import _in_venv
1010
from _pytest.main import ExitCode
1111
from _pytest.main import Session
12+
from _pytest.pytester import Testdir
1213

1314

1415
class TestCollector:
@@ -18,7 +19,7 @@ def test_collect_versus_item(self):
1819
assert not issubclass(Collector, Item)
1920
assert not issubclass(Item, Collector)
2021

21-
def test_check_equality(self, testdir):
22+
def test_check_equality(self, testdir: Testdir) -> None:
2223
modcol = testdir.getmodulecol(
2324
"""
2425
def test_pass(): pass
@@ -40,12 +41,15 @@ def test_fail(): assert 0
4041
assert fn1 != fn3
4142

4243
for fn in fn1, fn2, fn3:
43-
assert fn != 3
44+
assert isinstance(fn, pytest.Function)
45+
assert fn != 3 # type: ignore[comparison-overlap] # noqa: F821
4446
assert fn != modcol
45-
assert fn != [1, 2, 3]
46-
assert [1, 2, 3] != fn
47+
assert fn != [1, 2, 3] # type: ignore[comparison-overlap] # noqa: F821
48+
assert [1, 2, 3] != fn # type: ignore[comparison-overlap] # noqa: F821
4749
assert modcol != fn
4850

51+
assert testdir.collect_by_name(modcol, "doesnotexist") is None
52+
4953
def test_getparent(self, testdir):
5054
modcol = testdir.getmodulecol(
5155
"""

0 commit comments

Comments
 (0)