Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions openfeature/flag_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from openfeature._backports.strenum import StrEnum
from openfeature.exception import ErrorCode

if typing.TYPE_CHECKING: # resolves a circular dependency in type annotations
from openfeature.hook import Hook
if typing.TYPE_CHECKING: # pragma: no cover
# resolves a circular dependency in type annotations
from openfeature.hook import Hook, HookHints


class FlagType(StrEnum):
Expand Down Expand Up @@ -48,7 +49,7 @@ class FlagEvaluationDetails(typing.Generic[T_co]):
@dataclass
class FlagEvaluationOptions:
hooks: typing.List[Hook] = field(default_factory=list)
hook_hints: dict = field(default_factory=dict)
hook_hints: HookHints = field(default_factory=dict)


U_co = typing.TypeVar("U_co", covariant=True)
Expand Down
55 changes: 42 additions & 13 deletions openfeature/hook/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import typing
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING

Expand All @@ -20,24 +20,53 @@ class HookType(Enum):
ERROR = "error"


@dataclass
class HookContext:
flag_key: str
flag_type: FlagType
default_value: typing.Any
evaluation_context: EvaluationContext
client_metadata: typing.Optional[ClientMetadata] = None
provider_metadata: typing.Optional[Metadata] = None
def __init__( # noqa: PLR0913
self,
flag_key: str,
flag_type: FlagType,
default_value: typing.Any,
evaluation_context: EvaluationContext,
client_metadata: typing.Optional[ClientMetadata] = None,
provider_metadata: typing.Optional[Metadata] = None,
):
self.flag_key = flag_key
self.flag_type = flag_type
self.default_value = default_value
self.evaluation_context = evaluation_context
self.client_metadata = client_metadata
self.provider_metadata = provider_metadata

def __setattr__(self, key: str, value: typing.Any) -> None:
if hasattr(self, key) and key in ("flag_key", "flag_type", "default_value"):
if hasattr(self, key) and key in (
"flag_key",
"flag_type",
"default_value",
"client_metadata",
"provider_metadata",
):
raise AttributeError(f"Attribute {key!r} is immutable")
super().__setattr__(key, value)


# https://openfeature.dev/specification/sections/hooks/#requirement-421
HookHints = typing.Mapping[
str,
typing.Union[
bool,
int,
float,
str,
datetime,
typing.List[typing.Any],
typing.Dict[str, typing.Any],
],
]


class Hook:
def before(
self, hook_context: HookContext, hints: dict
self, hook_context: HookContext, hints: HookHints
) -> typing.Optional[EvaluationContext]:
"""
Runs before flag is resolved.
Expand All @@ -54,7 +83,7 @@ def after(
self,
hook_context: HookContext,
details: FlagEvaluationDetails[typing.Any],
hints: dict,
hints: HookHints,
) -> None:
"""
Runs after a flag is resolved.
Expand All @@ -67,7 +96,7 @@ def after(
pass

def error(
self, hook_context: HookContext, exception: Exception, hints: dict
self, hook_context: HookContext, exception: Exception, hints: HookHints
) -> None:
"""
Run when evaluation encounters an error. Errors thrown will be swallowed.
Expand All @@ -78,7 +107,7 @@ def error(
"""
pass

def finally_after(self, hook_context: HookContext, hints: dict) -> None:
def finally_after(self, hook_context: HookContext, hints: HookHints) -> None:
"""
Run after flag evaluation, including any error processing.
This will always run. Errors will be swallowed.
Expand Down
10 changes: 5 additions & 5 deletions openfeature/hook/hook_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from openfeature.evaluation_context import EvaluationContext
from openfeature.flag_evaluation import FlagEvaluationDetails, FlagType
from openfeature.hook import Hook, HookContext, HookType
from openfeature.hook import Hook, HookContext, HookHints, HookType

logger = logging.getLogger("openfeature")

Expand All @@ -14,7 +14,7 @@ def error_hooks(
hook_context: HookContext,
exception: Exception,
hooks: typing.List[Hook],
hints: typing.Optional[typing.Mapping] = None,
hints: typing.Optional[HookHints] = None,
) -> None:
kwargs = {"hook_context": hook_context, "exception": exception, "hints": hints}
_execute_hooks(
Expand All @@ -26,7 +26,7 @@ def after_all_hooks(
flag_type: FlagType,
hook_context: HookContext,
hooks: typing.List[Hook],
hints: typing.Optional[typing.Mapping] = None,
hints: typing.Optional[HookHints] = None,
) -> None:
kwargs = {"hook_context": hook_context, "hints": hints}
_execute_hooks(
Expand All @@ -39,7 +39,7 @@ def after_hooks(
hook_context: HookContext,
details: FlagEvaluationDetails[typing.Any],
hooks: typing.List[Hook],
hints: typing.Optional[typing.Mapping] = None,
hints: typing.Optional[HookHints] = None,
) -> None:
kwargs = {"hook_context": hook_context, "details": details, "hints": hints}
_execute_hooks_unchecked(
Expand All @@ -51,7 +51,7 @@ def before_hooks(
flag_type: FlagType,
hook_context: HookContext,
hooks: typing.List[Hook],
hints: typing.Optional[typing.Mapping] = None,
hints: typing.Optional[HookHints] = None,
) -> EvaluationContext:
kwargs = {"hook_context": hook_context, "hints": hints}
executed_hooks = _execute_hooks_unchecked(
Expand Down
14 changes: 10 additions & 4 deletions tests/hook/test_hook_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,14 @@ def test_hook_context_has_immutable_and_mutable_fields():

4.1.3 - The "flag key", "flag type", and "default value" properties MUST be immutable.
4.1.4.1 - The evaluation context MUST be mutable only within the before hook.
4.2.2.2 - The client "metadata" field in the "hook context" MUST be immutable.
4.2.2.3 - The provider "metadata" field in the "hook context" MUST be immutable.
"""

# Given
hook_context = HookContext("flag_key", FlagType.BOOLEAN, True, EvaluationContext())
hook_context = HookContext(
"flag_key", FlagType.BOOLEAN, True, EvaluationContext(), ClientMetadata("name")
)

# When
with pytest.raises(AttributeError):
Expand All @@ -52,18 +56,20 @@ def test_hook_context_has_immutable_and_mutable_fields():
hook_context.flag_type = FlagType.STRING
with pytest.raises(AttributeError):
hook_context.default_value = "new_value"
with pytest.raises(AttributeError):
hook_context.client_metadata = ClientMetadata("new_name")
with pytest.raises(AttributeError):
hook_context.provider_metadata = Metadata("name")

hook_context.evaluation_context = EvaluationContext("targeting_key")
hook_context.client_metadata = ClientMetadata("name")
hook_context.provider_metadata = Metadata("name")

# Then
assert hook_context.flag_key == "flag_key"
assert hook_context.flag_type is FlagType.BOOLEAN
assert hook_context.default_value is True
assert hook_context.evaluation_context.targeting_key == "targeting_key"
assert hook_context.client_metadata.name == "name"
assert hook_context.provider_metadata.name == "name"
assert hook_context.provider_metadata is None


def test_error_hooks_run_error_method(mock_hook):
Expand Down