Skip to content

Commit 128cead

Browse files
authored
Type Checking (#36)
* disallow any generics * Fixing a few type: ignore * Fixing alias type hint * wrong Mapping * Update validate.yml
1 parent 789120f commit 128cead

File tree

14 files changed

+94
-72
lines changed

14 files changed

+94
-72
lines changed

.github/workflows/validate.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,8 @@ jobs:
2727
- name: Test with pytest
2828
run: |
2929
pytest tests --doctest-modules --cov-report=xml --cov-report=html
30+
- name: Run mypy
31+
run: |
32+
mypy featuremanagement
3033
- name: cspell-action
3134
uses: streetsidesoftware/[email protected]

featuremanagement/_defaultfilters.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class TimeWindowFilter(FeatureFilter):
4444
Feature Filter that determines if the current time is within the time window.
4545
"""
4646

47-
def evaluate(self, context: Mapping, **kwargs: Dict[str, Any]) -> bool:
47+
def evaluate(self, context: Mapping[Any, Any], **kwargs: Any) -> bool:
4848
"""
4949
Determine if the feature flag is enabled for the given context.
5050
@@ -87,7 +87,7 @@ def _is_targeted(context_id: str, rollout_percentage: int) -> bool:
8787
return percentage < rollout_percentage
8888

8989
def _target_group(
90-
self, target_user: Optional[str], target_group: str, group: Mapping, feature_flag_name: str
90+
self, target_user: Optional[str], target_group: str, group: Mapping[str, Any], feature_flag_name: str
9191
) -> bool:
9292
group_rollout_percentage = group.get(ROLLOUT_PERCENTAGE_KEY, 0)
9393
if not target_user:
@@ -96,7 +96,7 @@ def _target_group(
9696

9797
return self._is_targeted(audience_context_id, group_rollout_percentage)
9898

99-
def evaluate(self, context: Mapping, **kwargs: Dict[str, Any]) -> bool:
99+
def evaluate(self, context: Mapping[Any, Any], **kwargs: Any) -> bool:
100100
"""
101101
Determine if the feature flag is enabled for the given context.
102102
@@ -156,11 +156,11 @@ def evaluate(self, context: Mapping, **kwargs: Dict[str, Any]) -> bool:
156156
return self._is_targeted(context_id, default_rollout_percentage)
157157

158158
@staticmethod
159-
def _validate(groups: List, default_rollout_percentage: int) -> None:
159+
def _validate(groups: List[Dict[str, Any]], default_rollout_percentage: int) -> None:
160160
# Validate the audience settings
161161
if default_rollout_percentage < 0 or default_rollout_percentage > 100:
162162
raise TargetingException("DefaultRolloutPercentage must be between 0 and 100")
163163

164164
for group in groups:
165-
if group.get(ROLLOUT_PERCENTAGE_KEY) < 0 or group.get(ROLLOUT_PERCENTAGE_KEY) > 100:
165+
if group.get(ROLLOUT_PERCENTAGE_KEY, 0) < 0 or group.get(ROLLOUT_PERCENTAGE_KEY, 100) > 100:
166166
raise TargetingException("RolloutPercentage must be between 0 and 100")

featuremanagement/_featurefilters.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44
# license information.
55
# -------------------------------------------------------------------------
66
from abc import ABC, abstractmethod
7-
from typing import Mapping, Callable, Dict, Any
7+
from typing import Mapping, Callable, Any, Optional
8+
from typing_extensions import Self
89

910

1011
class FeatureFilter(ABC):
1112
"""
1213
Parent class for all feature filters.
1314
"""
1415

16+
_alias: Optional[str] = None
17+
1518
@abstractmethod
16-
def evaluate(self, context: Mapping, **kwargs: Dict[str, Any]) -> bool:
19+
def evaluate(self, context: Mapping[Any, Any], **kwargs: Any) -> bool:
1720
"""
1821
Determine if the feature flag is enabled for the given context.
1922
@@ -28,12 +31,12 @@ def name(self) -> str:
2831
:return: Name of the filter, or alias if it exists.
2932
:rtype: str
3033
"""
31-
if hasattr(self, "_alias"):
32-
return self._alias # type: ignore
34+
if hasattr(self, "_alias") and self._alias:
35+
return self._alias
3336
return self.__class__.__name__
3437

3538
@staticmethod
36-
def alias(alias: str) -> Callable:
39+
def alias(alias: str) -> Callable[..., Any]:
3740
"""
3841
Decorator to set the alias for the filter.
3942
@@ -42,7 +45,7 @@ def alias(alias: str) -> Callable:
4245
:rtype: Callable
4346
"""
4447

45-
def wrapper(cls) -> Any: # type: ignore
48+
def wrapper(cls: Self) -> Any:
4649
cls._alias = alias # pylint: disable=protected-access
4750
return cls
4851

featuremanagement/_featuremanager.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,20 @@ class FeatureManager(FeatureManagerBase):
2828
evaluated.
2929
"""
3030

31-
def __init__(self, configuration: Mapping, **kwargs: Dict[str, Any]):
31+
def __init__(self, configuration: Mapping[str, Any], **kwargs: Any):
3232
super().__init__(configuration, **kwargs)
33-
filters = [TimeWindowFilter(), TargetingFilter()] + cast(List, kwargs.pop(PROVIDED_FEATURE_FILTERS, []))
33+
self._filters: Dict[str, FeatureFilter] = {}
34+
filters = [TimeWindowFilter(), TargetingFilter()] + cast(
35+
List[FeatureFilter], kwargs.pop(PROVIDED_FEATURE_FILTERS, [])
36+
)
3437

3538
for feature_filter in filters:
3639
if not isinstance(feature_filter, FeatureFilter):
3740
raise ValueError("Custom filter must be a subclass of FeatureFilter")
3841
self._filters[feature_filter.name] = feature_filter
3942

4043
@overload # type: ignore
41-
def is_enabled(self, feature_flag_id: str, user_id: str, **kwargs: Dict[str, Any]) -> bool:
44+
def is_enabled(self, feature_flag_id: str, user_id: str, **kwargs: Any) -> bool:
4245
"""
4346
Determine if the feature flag is enabled for the given context.
4447
@@ -48,7 +51,7 @@ def is_enabled(self, feature_flag_id: str, user_id: str, **kwargs: Dict[str, Any
4851
:rtype: bool
4952
"""
5053

51-
def is_enabled(self, feature_flag_id: str, *args: Any, **kwargs: Dict[str, Any]) -> bool: # type: ignore
54+
def is_enabled(self, feature_flag_id: str, *args: Any, **kwargs: Any) -> bool:
5255
"""
5356
Determine if the feature flag is enabled for the given context.
5457
@@ -70,7 +73,7 @@ def is_enabled(self, feature_flag_id: str, *args: Any, **kwargs: Dict[str, Any])
7073
return result.enabled
7174

7275
@overload # type: ignore
73-
def get_variant(self, feature_flag_id: str, user_id: str, **kwargs: Dict[str, Any]) -> Optional[Variant]:
76+
def get_variant(self, feature_flag_id: str, user_id: str, **kwargs: Any) -> Optional[Variant]:
7477
"""
7578
Determine the variant for the given context.
7679
@@ -80,9 +83,7 @@ def get_variant(self, feature_flag_id: str, user_id: str, **kwargs: Dict[str, An
8083
:rtype: Variant
8184
"""
8285

83-
def get_variant( # type: ignore
84-
self, feature_flag_id: str, *args: Any, **kwargs: Dict[str, Any]
85-
) -> Optional[Variant]:
86+
def get_variant(self, feature_flag_id: str, *args: Any, **kwargs: Any) -> Optional[Variant]:
8687
"""
8788
Determine the variant for the given context.
8889
@@ -105,7 +106,7 @@ def get_variant( # type: ignore
105106
return result.variant
106107

107108
def _check_feature_filters(
108-
self, evaluation_event: EvaluationEvent, targeting_context: TargetingContext, **kwargs: Dict
109+
self, evaluation_event: EvaluationEvent, targeting_context: TargetingContext, **kwargs: Any
109110
) -> None:
110111
feature_flag = evaluation_event.feature
111112
if not feature_flag:
@@ -123,8 +124,8 @@ def _check_feature_filters(
123124

124125
for feature_filter in feature_filters:
125126
filter_name = feature_filter[FEATURE_FILTER_NAME]
126-
kwargs["user"] = targeting_context.user_id # type: ignore
127-
kwargs["groups"] = targeting_context.groups # type: ignore
127+
kwargs["user"] = targeting_context.user_id
128+
kwargs["groups"] = targeting_context.groups
128129
if filter_name not in self._filters:
129130
raise ValueError(f"Feature flag {feature_flag.name} has unknown filter {filter_name}")
130131
if feature_conditions.requirement_type == REQUIREMENT_TYPE_ALL:
@@ -136,7 +137,7 @@ def _check_feature_filters(
136137
break
137138

138139
def _check_feature(
139-
self, feature_flag_id: str, targeting_context: TargetingContext, **kwargs: Dict[str, Any]
140+
self, feature_flag_id: str, targeting_context: TargetingContext, **kwargs: Any
140141
) -> EvaluationEvent:
141142
"""
142143
Determine if the feature flag is enabled for the given context.

featuremanagement/_featuremanagerbase.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
# Licensed under the MIT License. See License.txt in the project root for
44
# license information.
55
# -------------------------------------------------------------------------
6-
from collections.abc import Mapping
76
import hashlib
87
from abc import ABC
9-
from typing import List, Optional, Dict, Tuple, Any
8+
from typing import List, Optional, Dict, Tuple, Any, Mapping
109
from ._models import FeatureFlag, Variant, VariantAssignmentReason, TargetingContext, EvaluationEvent, VariantReference
1110

1211

@@ -21,7 +20,7 @@
2120
FEATURE_FILTER_PARAMETERS = "parameters"
2221

2322

24-
def _get_feature_flag(configuration: Mapping, feature_flag_name: str) -> Optional[FeatureFlag]:
23+
def _get_feature_flag(configuration: Mapping[str, Any], feature_flag_name: str) -> Optional[FeatureFlag]:
2524
"""
2625
Gets the FeatureFlag json from the configuration, if it exists it gets converted to a FeatureFlag object.
2726
@@ -44,7 +43,7 @@ def _get_feature_flag(configuration: Mapping, feature_flag_name: str) -> Optiona
4443
return None
4544

4645

47-
def _list_feature_flag_names(configuration: Mapping) -> List[str]:
46+
def _list_feature_flag_names(configuration: Mapping[str, Any]) -> List[str]:
4847
"""
4948
List of all feature flag names.
5049
@@ -70,12 +69,11 @@ class FeatureManagerBase(ABC):
7069
Base class for Feature Manager. This class is responsible for all shared logic between the sync and async.
7170
"""
7271

73-
def __init__(self, configuration: Mapping, **kwargs: Dict[str, Any]):
74-
self._filters: Dict = {}
72+
def __init__(self, configuration: Mapping[str, Any], **kwargs: Any):
7573
if configuration is None or not isinstance(configuration, Mapping):
7674
raise AttributeError("Configuration must be a non-empty dictionary")
7775
self._configuration = configuration
78-
self._cache: Dict = {}
76+
self._cache: Dict[str, Optional[FeatureFlag]] = {}
7977
self._copy = configuration.get(FEATURE_MANAGEMENT_KEY)
8078
self._on_feature_evaluated = kwargs.pop("on_feature_evaluated", None)
8179

@@ -214,7 +212,7 @@ def _variant_name_to_variant(self, feature_flag: FeatureFlag, variant_name: Opti
214212
for variant_reference in feature_flag.variants:
215213
if variant_reference.name == variant_name:
216214
configuration = variant_reference.configuration_value
217-
if not configuration:
215+
if not configuration and variant_reference.configuration_reference:
218216
configuration = self._configuration.get(variant_reference.configuration_reference)
219217
return Variant(variant_reference.name, configuration)
220218
return None

featuremanagement/_models/_allocation.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the MIT License. See License.txt in the project root for
44
# license information.
55
# -------------------------------------------------------------------------
6-
from typing import cast, List, Optional, Mapping, Dict, Any
6+
from typing import cast, List, Optional, Mapping, Dict, Any, Union
77
from dataclasses import dataclass
88
from typing_extensions import Self
99
from ._constants import DEFAULT_WHEN_ENABLED, DEFAULT_WHEN_DISABLED, USER, GROUP, PERCENTILE, SEED
@@ -16,7 +16,7 @@ class UserAllocation:
1616
"""
1717

1818
variant: str
19-
users: list
19+
users: List[str]
2020

2121

2222
@dataclass
@@ -26,7 +26,7 @@ class GroupAllocation:
2626
"""
2727

2828
variant: str
29-
groups: list
29+
groups: List[str]
3030

3131

3232
class PercentileAllocation:
@@ -35,12 +35,12 @@ class PercentileAllocation:
3535
"""
3636

3737
def __init__(self) -> None:
38-
self._variant = None
38+
self._variant: Optional[str] = None
3939
self._percentile_from: int = 0
4040
self._percentile_to: Optional[int] = None
4141

4242
@classmethod
43-
def convert_from_json(cls, json: Mapping) -> Self:
43+
def convert_from_json(cls, json: Mapping[str, Union[str, int]]) -> Self:
4444
"""
4545
Convert a JSON object to PercentileAllocation.
4646
@@ -51,9 +51,21 @@ def convert_from_json(cls, json: Mapping) -> Self:
5151
if not json:
5252
raise ValueError("Percentile allocation is not valid.")
5353
user_allocation = cls()
54-
user_allocation._variant = json.get("variant")
55-
user_allocation._percentile_from = json.get("from", 0)
56-
user_allocation._percentile_to = json.get("to")
54+
55+
variant = json.get("variant")
56+
if not variant or not isinstance(variant, str):
57+
raise ValueError("Percentile allocation does not have a valid assigned variant.")
58+
user_allocation._variant = variant
59+
60+
percentile_from = json.get("from", 0)
61+
if not isinstance(percentile_from, int):
62+
raise ValueError("Percentile allocation does not have a valid starting percentile.")
63+
user_allocation._percentile_from = percentile_from
64+
65+
percentile_to = json.get("to")
66+
if not percentile_to or not isinstance(percentile_to, int):
67+
raise ValueError("Percentile allocation does not have a valid ending percentile.")
68+
user_allocation._percentile_to = percentile_to
5769
return user_allocation
5870

5971
@property
@@ -101,7 +113,7 @@ def __init__(self, feature_name: str) -> None:
101113
self._seed = "allocation\n" + feature_name
102114

103115
@classmethod
104-
def convert_from_json(cls, json: Dict, feature_name: str) -> Optional[Self]:
116+
def convert_from_json(cls, json: Dict[str, Any], feature_name: str) -> Optional[Self]:
105117
"""
106118
Convert a JSON object to Allocation.
107119

featuremanagement/_models/_feature_conditions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# license information.
55
# -------------------------------------------------------------------------
66
from collections.abc import Mapping
7-
from typing import List
7+
from typing import Any, Dict, List
88
from typing_extensions import Self
99
from ._constants import (
1010
FEATURE_FLAG_CLIENT_FILTERS,
@@ -22,7 +22,7 @@ class FeatureConditions:
2222

2323
def __init__(self) -> None:
2424
self._requirement_type = REQUIREMENT_TYPE_ANY
25-
self._client_filters: List[dict] = []
25+
self._client_filters: List[Dict[str, Any]] = []
2626

2727
@classmethod
2828
def convert_from_json(cls, feature_name: str, json_value: str) -> Self:
@@ -55,7 +55,7 @@ def requirement_type(self) -> str:
5555
return self._requirement_type
5656

5757
@property
58-
def client_filters(self) -> List[dict]:
58+
def client_filters(self) -> List[Dict[str, Any]]:
5959
"""
6060
Get the client filters for the feature flag.
6161

featuremanagement/_models/_feature_flag.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(self) -> None:
3232
self._telemetry: Telemetry = Telemetry()
3333

3434
@classmethod
35-
def convert_from_json(cls, json_value: Mapping) -> Self:
35+
def convert_from_json(cls, json_value: Mapping[str, Any]) -> Self:
3636
"""
3737
Convert a JSON object to FeatureFlag.
3838
@@ -56,7 +56,7 @@ def convert_from_json(cls, json_value: Mapping) -> Self:
5656
json_value.get(FEATURE_FLAG_ALLOCATION, None), feature_flag._id
5757
)
5858
if FEATURE_FLAG_VARIANTS in json_value:
59-
variants: List[Mapping] = json_value.get(FEATURE_FLAG_VARIANTS, [])
59+
variants: List[Mapping[str, Any]] = json_value.get(FEATURE_FLAG_VARIANTS, [])
6060
feature_flag._variants = []
6161
for variant in variants:
6262
if variant:

featuremanagement/_models/_telemetry.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Licensed under the MIT License. See License.txt in the project root for
44
# license information.
55
# -------------------------------------------------------------------------
6+
from typing import Dict
67
from dataclasses import dataclass, field
78

89

@@ -13,4 +14,4 @@ class Telemetry:
1314
"""
1415

1516
enabled: bool = False
16-
metadata: dict = field(default_factory=dict)
17+
metadata: Dict[str, str] = field(default_factory=dict)

featuremanagement/_models/_variant_reference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# license information.
55
# -------------------------------------------------------------------------
66
from dataclasses import dataclass
7-
from typing import Optional, Mapping
7+
from typing import Optional, Mapping, Any
88
from typing_extensions import Self
99
from ._constants import VARIANT_REFERENCE_NAME, CONFIGURATION_VALUE, CONFIGURATION_REFERENCE, STATUS_OVERRIDE
1010

@@ -22,7 +22,7 @@ def __init__(self) -> None:
2222
self._status_override = None
2323

2424
@classmethod
25-
def convert_from_json(cls, json: Mapping) -> Self:
25+
def convert_from_json(cls, json: Mapping[str, Any]) -> Self:
2626
"""
2727
Convert a JSON object to VariantReference.
2828

0 commit comments

Comments
 (0)