From 4508b985364a685c34eebb691374a2c4a3e96908 Mon Sep 17 00:00:00 2001 From: JP-Ellis Date: Thu, 2 Nov 2023 08:29:10 +1100 Subject: [PATCH 1/3] fix(typing): improve decorator type hinting The type hinting for the most commonly used decorators were incomplete, resulting in decorated functions being obscured. This makes use of the special type variable `ParamSpec` which allows the type hinting a view into the parameters of a function. As ``ParamSpec` was introduced in Python 3.10, `ParamSpec` is imported from the `typing_extensions` module instead of the standard library. I have also taken the opportunity to fix other instances of `Callable` type hints missing their arguments. Signed-off-by: JP-Ellis --- src/pytest_bdd/plugin.py | 23 +++++++++++++++-------- src/pytest_bdd/reporting.py | 10 ++++++++-- src/pytest_bdd/scenario.py | 20 ++++++++++++-------- src/pytest_bdd/steps.py | 22 ++++++++++++---------- src/pytest_bdd/utils.py | 2 +- 5 files changed, 48 insertions(+), 29 deletions(-) diff --git a/src/pytest_bdd/plugin.py b/src/pytest_bdd/plugin.py index 486cdf87e..ccee01138 100644 --- a/src/pytest_bdd/plugin.py +++ b/src/pytest_bdd/plugin.py @@ -1,16 +1,15 @@ """Pytest plugin entry point. Used for any fixtures needed.""" from __future__ import annotations -from typing import TYPE_CHECKING, Callable, cast +from typing import TYPE_CHECKING, Any, Callable, Generator, TypeVar, cast import pytest +from typing_extensions import ParamSpec from . import cucumber_json, generation, gherkin_terminal_reporter, given, reporting, then, when from .utils import CONFIG_STACK if TYPE_CHECKING: - from typing import Any, Generator - from _pytest.config import Config, PytestPluginManager from _pytest.config.argparsing import Parser from _pytest.fixtures import FixtureRequest @@ -21,6 +20,10 @@ from .parser import Feature, Scenario, Step +P = ParamSpec("P") +T = TypeVar("T") + + def pytest_addhooks(pluginmanager: PytestPluginManager) -> None: """Register plugin hooks.""" from pytest_bdd import hooks @@ -93,7 +96,7 @@ def pytest_bdd_step_error( feature: Feature, scenario: Scenario, step: Step, - step_func: Callable, + step_func: Callable[..., Any], step_func_args: dict, exception: Exception, ) -> None: @@ -102,7 +105,11 @@ def pytest_bdd_step_error( @pytest.hookimpl(tryfirst=True) def pytest_bdd_before_step( - request: FixtureRequest, feature: Feature, scenario: Scenario, step: Step, step_func: Callable + request: FixtureRequest, + feature: Feature, + scenario: Scenario, + step: Step, + step_func: Callable[..., Any], ) -> None: reporting.before_step(request, feature, scenario, step, step_func) @@ -113,7 +120,7 @@ def pytest_bdd_after_step( feature: Feature, scenario: Scenario, step: Step, - step_func: Callable, + step_func: Callable[..., Any], step_func_args: dict[str, Any], ) -> None: reporting.after_step(request, feature, scenario, step, step_func, step_func_args) @@ -123,7 +130,7 @@ def pytest_cmdline_main(config: Config) -> int | None: return generation.cmdline_main(config) -def pytest_bdd_apply_tag(tag: str, function: Callable) -> Callable: +def pytest_bdd_apply_tag(tag: str, function: Callable[P, T]) -> Callable[P, T]: mark = getattr(pytest.mark, tag) marked = mark(function) - return cast(Callable, marked) + return cast(Callable[P, T], marked) diff --git a/src/pytest_bdd/reporting.py b/src/pytest_bdd/reporting.py index 26e1cb0e2..95254f648 100644 --- a/src/pytest_bdd/reporting.py +++ b/src/pytest_bdd/reporting.py @@ -155,7 +155,7 @@ def step_error( feature: Feature, scenario: Scenario, step: Step, - step_func: Callable, + step_func: Callable[..., Any], step_func_args: dict, exception: Exception, ) -> None: @@ -163,7 +163,13 @@ def step_error( request.node.__scenario_report__.fail() -def before_step(request: FixtureRequest, feature: Feature, scenario: Scenario, step: Step, step_func: Callable) -> None: +def before_step( + request: FixtureRequest, + feature: Feature, + scenario: Scenario, + step: Step, + step_func: Callable[..., Any], +) -> None: """Store step start time.""" request.node.__scenario_report__.add_step_report(StepReport(step=step)) diff --git a/src/pytest_bdd/scenario.py b/src/pytest_bdd/scenario.py index df7c029c0..7a231ef50 100644 --- a/src/pytest_bdd/scenario.py +++ b/src/pytest_bdd/scenario.py @@ -16,11 +16,12 @@ import logging import os import re -from typing import TYPE_CHECKING, Callable, Iterator, cast +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, TypeVar, cast import pytest from _pytest.fixtures import FixtureDef, FixtureManager, FixtureRequest, call_fixture_func from _pytest.nodes import iterparentnodeids +from typing_extensions import ParamSpec from . import exceptions from .feature import get_feature, get_features @@ -28,12 +29,12 @@ from .utils import CONFIG_STACK, get_args, get_caller_module_locals, get_caller_module_path if TYPE_CHECKING: - from typing import Any, Iterable - from _pytest.mark.structures import ParameterSet from .parser import Feature, Scenario, ScenarioTemplate, Step +P = ParamSpec("P") +T = TypeVar("T") logger = logging.getLogger(__name__) @@ -197,14 +198,14 @@ def _execute_scenario(feature: Feature, scenario: Scenario, request: FixtureRequ def _get_scenario_decorator( feature: Feature, feature_name: str, templated_scenario: ScenarioTemplate, scenario_name: str -) -> Callable[[Callable], Callable]: +) -> Callable[[Callable[P, T]], Callable[P, T]]: # HACK: Ideally we would use `def decorator(fn)`, but we want to return a custom exception # when the decorator is misused. # Pytest inspect the signature to determine the required fixtures, and in that case it would look # for a fixture called "fn" that doesn't exist (if it exists then it's even worse). # It will error with a "fixture 'fn' not found" message instead. # We can avoid this hack by using a pytest hook and check for misuse instead. - def decorator(*args: Callable) -> Callable: + def decorator(*args: Callable[P, T]) -> Callable[P, T]: if not args: raise exceptions.ScenarioIsDecoratorOnly( "scenario function can only be used as a decorator. Refer to the documentation." @@ -236,7 +237,7 @@ def scenario_wrapper(request: FixtureRequest, _pytest_bdd_example: dict[str, str scenario_wrapper.__doc__ = f"{feature_name}: {scenario_name}" scenario_wrapper.__scenario__ = templated_scenario - return cast(Callable, scenario_wrapper) + return cast(Callable[P, T], scenario_wrapper) return decorator @@ -254,8 +255,11 @@ def collect_example_parametrizations( def scenario( - feature_name: str, scenario_name: str, encoding: str = "utf-8", features_base_dir=None -) -> Callable[[Callable], Callable]: + feature_name: str, + scenario_name: str, + encoding: str = "utf-8", + features_base_dir: str | None = None, +) -> Callable[[Callable[P, T]], Callable[P, T]]: """Scenario decorator. :param str feature_name: Feature file name. Absolute or relative to the configured feature base path. diff --git a/src/pytest_bdd/steps.py b/src/pytest_bdd/steps.py index b3d8be6cf..83c54e4c0 100644 --- a/src/pytest_bdd/steps.py +++ b/src/pytest_bdd/steps.py @@ -43,13 +43,15 @@ def _(article): import pytest from _pytest.fixtures import FixtureDef, FixtureRequest +from typing_extensions import ParamSpec from .parser import Step from .parsers import StepParser, get_parser from .types import GIVEN, THEN, WHEN from .utils import get_caller_module_locals -TCallable = TypeVar("TCallable", bound=Callable[..., Any]) +P = ParamSpec("P") +T = TypeVar("T") @enum.unique @@ -74,10 +76,10 @@ def get_step_fixture_name(step: Step) -> str: def given( name: str | StepParser, - converters: dict[str, Callable] | None = None, + converters: dict[str, Callable[[Any], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, -) -> Callable: +) -> Callable[[Callable[P, T]], Callable[P, T]]: """Given step decorator. :param name: Step name or a parser object. @@ -93,10 +95,10 @@ def given( def when( name: str | StepParser, - converters: dict[str, Callable] | None = None, + converters: dict[str, Callable[[Any], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, -) -> Callable: +) -> Callable[[Callable[P, T]], Callable[P, T]]: """When step decorator. :param name: Step name or a parser object. @@ -112,10 +114,10 @@ def when( def then( name: str | StepParser, - converters: dict[str, Callable] | None = None, + converters: dict[str, Callable[[Any], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, -) -> Callable: +) -> Callable[[Callable[P, T]], Callable[P, T]]: """Then step decorator. :param name: Step name or a parser object. @@ -132,10 +134,10 @@ def then( def step( name: str | StepParser, type_: Literal["given", "when", "then"] | None = None, - converters: dict[str, Callable] | None = None, + converters: dict[str, Callable[[Any], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, -) -> Callable[[TCallable], TCallable]: +) -> Callable[[Callable[P, T]], Callable[P, T]]: """Generic step decorator. :param name: Step name as in the feature file. @@ -155,7 +157,7 @@ def step( if converters is None: converters = {} - def decorator(func: TCallable) -> TCallable: + def decorator(func: Callable[P, T]) -> Callable[P, T]: parser = get_parser(name) context = StepFunctionContext( diff --git a/src/pytest_bdd/utils.py b/src/pytest_bdd/utils.py index 355407821..eb243e5d2 100644 --- a/src/pytest_bdd/utils.py +++ b/src/pytest_bdd/utils.py @@ -19,7 +19,7 @@ CONFIG_STACK: list[Config] = [] -def get_args(func: Callable) -> list[str]: +def get_args(func: Callable[..., Any]) -> list[str]: """Get a list of argument names for a function. :param func: The function to inspect. From 9c60589fb9324ef79a8416ae5cc3cbc84f78a34a Mon Sep 17 00:00:00 2001 From: Alessio Bogon <778703+youtux@users.noreply.github.com> Date: Sat, 2 Dec 2023 21:57:07 +0100 Subject: [PATCH 2/3] fix type for `StepFunctionContext.converters` --- src/pytest_bdd/steps.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pytest_bdd/steps.py b/src/pytest_bdd/steps.py index 83c54e4c0..98f116d63 100644 --- a/src/pytest_bdd/steps.py +++ b/src/pytest_bdd/steps.py @@ -65,7 +65,7 @@ class StepFunctionContext: type: Literal["given", "when", "then"] | None step_func: Callable[..., Any] parser: StepParser - converters: dict[str, Callable[..., Any]] = field(default_factory=dict) + converters: dict[str, Callable[[str], Any]] = field(default_factory=dict) target_fixture: str | None = None @@ -76,7 +76,7 @@ def get_step_fixture_name(step: Step) -> str: def given( name: str | StepParser, - converters: dict[str, Callable[[Any], Any]] | None = None, + converters: dict[str, Callable[[str], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, ) -> Callable[[Callable[P, T]], Callable[P, T]]: @@ -95,7 +95,7 @@ def given( def when( name: str | StepParser, - converters: dict[str, Callable[[Any], Any]] | None = None, + converters: dict[str, Callable[[str], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, ) -> Callable[[Callable[P, T]], Callable[P, T]]: @@ -114,7 +114,7 @@ def when( def then( name: str | StepParser, - converters: dict[str, Callable[[Any], Any]] | None = None, + converters: dict[str, Callable[[str], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, ) -> Callable[[Callable[P, T]], Callable[P, T]]: @@ -134,7 +134,7 @@ def then( def step( name: str | StepParser, type_: Literal["given", "when", "then"] | None = None, - converters: dict[str, Callable[[Any], Any]] | None = None, + converters: dict[str, Callable[[str], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, ) -> Callable[[Callable[P, T]], Callable[P, T]]: From ebd76f5b4dcb1704d248ddede5208c1733cb9f3b Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Sat, 2 Dec 2023 20:57:17 +0000 Subject: [PATCH 3/3] 'Refactored by Sourcery' --- src/pytest_bdd/reporting.py | 5 +---- src/pytest_bdd/scenario.py | 13 +++++-------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/pytest_bdd/reporting.py b/src/pytest_bdd/reporting.py index 95254f648..fecbdf950 100644 --- a/src/pytest_bdd/reporting.py +++ b/src/pytest_bdd/reporting.py @@ -63,10 +63,7 @@ def duration(self) -> float: :return: Step execution duration. :rtype: float """ - if self.stopped is None: - return 0 - - return self.stopped - self.started + return 0 if self.stopped is None else self.stopped - self.started class ScenarioReport: diff --git a/src/pytest_bdd/scenario.py b/src/pytest_bdd/scenario.py index 7a231ef50..d64b3f61a 100644 --- a/src/pytest_bdd/scenario.py +++ b/src/pytest_bdd/scenario.py @@ -47,7 +47,7 @@ def find_fixturedefs_for_step(step: Step, fixturemanager: FixtureManager, nodeid """Find the fixture defs that can parse a step.""" # happens to be that _arg2fixturedefs is changed during the iteration so we use a copy fixture_def_by_name = list(fixturemanager._arg2fixturedefs.items()) - for i, (fixturename, fixturedefs) in enumerate(fixture_def_by_name): + for fixturename, fixturedefs in fixture_def_by_name: for pos, fixturedef in enumerate(fixturedefs): step_func_context = getattr(fixturedef.func, "_pytest_bdd_step_context", None) if step_func_context is None: @@ -245,14 +245,11 @@ def scenario_wrapper(request: FixtureRequest, _pytest_bdd_example: dict[str, str def collect_example_parametrizations( templated_scenario: ScenarioTemplate, ) -> list[ParameterSet] | None: - # We need to evaluate these iterators and store them as lists, otherwise - # we won't be able to do the cartesian product later (the second iterator will be consumed) - contexts = list(templated_scenario.examples.as_contexts()) - if not contexts: + if contexts := list(templated_scenario.examples.as_contexts()): + return [pytest.param(context, id="-".join(context.values())) for context in contexts] + else: return None - return [pytest.param(context, id="-".join(context.values())) for context in contexts] - def scenario( feature_name: str, @@ -267,7 +264,7 @@ def scenario( :param str encoding: Feature file encoding. """ __tracebackhide__ = True - scenario_name = str(scenario_name) + scenario_name = scenario_name caller_module_path = get_caller_module_path() # Get the feature