44# license information.
55# -------------------------------------------------------------------------
66import inspect
7- from typing import cast , overload , Any , Optional , Dict , Mapping , List
7+ import logging
8+ from typing import cast , overload , Any , Optional , Dict , Mapping , List , Tuple
89from ._defaultfilters import TimeWindowFilter , TargetingFilter
910from ._featurefilters import FeatureFilter
1011from .._models import EvaluationEvent , Variant , TargetingContext
1516 FEATURE_FILTER_NAME ,
1617)
1718
19+ logger = logging .getLogger (__name__ )
20+
1821
1922class FeatureManager (FeatureManagerBase ):
2023 """
@@ -24,6 +27,8 @@ class FeatureManager(FeatureManagerBase):
2427 :keyword list[FeatureFilter] feature_filters: Custom filters to be used for evaluating feature flags.
2528 :keyword Callable[EvaluationEvent] on_feature_evaluated: Callback function to be called when a feature flag is
2629 evaluated.
30+ :keyword Callable[[], TargetingContext] targeting_context_accessor: Callback function to get the current targeting
31+ context if one isn't provided.
2732 """
2833
2934 def __init__ (self , configuration : Mapping [str , Any ], ** kwargs : Any ):
@@ -57,7 +62,7 @@ async def is_enabled(self, feature_flag_id: str, *args: Any, **kwargs: Any) -> b
5762 :return: True if the feature flag is enabled for the given context.
5863 :rtype: bool
5964 """
60- targeting_context = self ._build_targeting_context (args )
65+ targeting_context : TargetingContext = await self ._build_targeting_context_async (args )
6166
6267 result = await self ._check_feature (feature_flag_id , targeting_context , ** kwargs )
6368 if (
@@ -93,7 +98,7 @@ async def get_variant(self, feature_flag_id: str, *args: Any, **kwargs: Any) ->
9398 :return: Variant instance.
9499 :rtype: Variant
95100 """
96- targeting_context = self ._build_targeting_context (args )
101+ targeting_context : TargetingContext = await self ._build_targeting_context_async (args )
97102
98103 result = await self ._check_feature (feature_flag_id , targeting_context , ** kwargs )
99104 if (
@@ -109,6 +114,25 @@ async def get_variant(self, feature_flag_id: str, *args: Any, **kwargs: Any) ->
109114 self ._on_feature_evaluated (result )
110115 return result .variant
111116
117+ async def _build_targeting_context_async (self , args : Tuple [Any ]) -> TargetingContext :
118+ targeting_context = super ()._build_targeting_context (args )
119+ if targeting_context :
120+ return targeting_context
121+ if not targeting_context and self ._targeting_context_accessor and callable (self ._targeting_context_accessor ):
122+
123+ if inspect .iscoroutinefunction (self ._targeting_context_accessor ):
124+ # If a targeting_context_accessor is provided, return the TargetingContext from it
125+ targeting_context = await self ._targeting_context_accessor ()
126+ else :
127+ targeting_context = self ._targeting_context_accessor ()
128+ if targeting_context and isinstance (targeting_context , TargetingContext ):
129+ return targeting_context
130+ logger .warning (
131+ "targeting_context_accessor did not return a TargetingContext. Received type %s." ,
132+ type (targeting_context ),
133+ )
134+ return TargetingContext ()
135+
112136 async def _check_feature_filters (
113137 self , evaluation_event : EvaluationEvent , targeting_context : TargetingContext , ** kwargs : Any
114138 ) -> None :
0 commit comments