diff --git a/doc/api/prep_data/feature_store.rst b/doc/api/prep_data/feature_store.rst index ee6350b8da..278574e400 100644 --- a/doc/api/prep_data/feature_store.rst +++ b/doc/api/prep_data/feature_store.rst @@ -75,6 +75,14 @@ Inputs :members: :show-inheritance: +.. autoclass:: sagemaker.feature_store.inputs.ThroughputConfig + :members: + :show-inheritance: + +.. autoclass:: sagemaker.feature_store.inputs.ThroughputConfigUpdate + :members: + :show-inheritance: + .. autoclass:: sagemaker.feature_store.inputs.OnlineStoreConfig :members: :show-inheritance: @@ -99,6 +107,10 @@ Inputs :members: :show-inheritance: +.. autoclass:: sagemaker.feature_store.inputs.ThroughputModeEnum + :members: + :show-inheritance: + .. autoclass:: sagemaker.feature_store.inputs.ResourceEnum :members: :show-inheritance: diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index 0e503e192d..9ffb0ea9da 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -64,6 +64,8 @@ TtlDuration, OnlineStoreConfigUpdate, OnlineStoreStorageTypeEnum, + ThroughputConfig, + ThroughputConfigUpdate, ) from sagemaker.utils import resolve_value_from_config, format_tags, Tags @@ -541,6 +543,7 @@ def create( tags: Optional[Tags] = None, table_format: TableFormatEnum = None, online_store_storage_type: OnlineStoreStorageTypeEnum = None, + throughput_config: ThroughputConfig = None, ) -> Dict[str, Any]: """Create a SageMaker FeatureStore FeatureGroup. @@ -570,6 +573,8 @@ def create( table_format (TableFormatEnum): format of the offline store table (default: None). online_store_storage_type (OnlineStoreStorageTypeEnum): storage type for the online store (default: None). + throughput_config (ThroughputConfig): throughput configuration of the + feature group (default: None). Returns: Response dict from service. 
@@ -618,6 +623,9 @@ def create( ) create_feature_store_args.update({"online_store_config": online_store_config.to_dict()}) + if throughput_config: + create_feature_store_args.update({"throughput_config": throughput_config.to_dict()}) + # offline store configuration if s3_uri: s3_storage_config = S3StorageConfig(s3_uri=s3_uri) @@ -656,17 +664,17 @@ def update( self, feature_additions: Sequence[FeatureDefinition] = None, online_store_config: OnlineStoreConfigUpdate = None, + throughput_config: ThroughputConfigUpdate = None, ) -> Dict[str, Any]: """Update a FeatureGroup and add new features from the given feature definitions. Args: feature_additions (Sequence[Dict[str, str]): list of feature definitions to be updated. online_store_config (OnlineStoreConfigUpdate): online store config to be updated. - + throughput_config (ThroughputConfigUpdate): target throughput configuration Returns: Response dict from service. """ - if feature_additions is None: feature_additions_parameter = None else: @@ -679,10 +687,15 @@ def update( else: online_store_config_parameter = online_store_config.to_dict() + throughput_config_parameter = ( + None if throughput_config is None else throughput_config.to_dict() + ) + return self.sagemaker_session.update_feature_group( feature_group_name=self.name, feature_additions=feature_additions_parameter, online_store_config=online_store_config_parameter, + throughput_config=throughput_config_parameter, ) def update_feature_metadata( diff --git a/src/sagemaker/feature_store/inputs.py b/src/sagemaker/feature_store/inputs.py index ed61117ead..aaff977d3c 100644 --- a/src/sagemaker/feature_store/inputs.py +++ b/src/sagemaker/feature_store/inputs.py @@ -453,3 +453,79 @@ class ExpirationTimeResponseEnum(Enum): DISABLED = "Disabled" ENABLED = "Enabled" + + +class ThroughputModeEnum(Enum): + """Enum of throughput modes supported by feature group. + + Throughput mode of feature group can be ON_DEMAND or PROVISIONED. 
+ """ + + ON_DEMAND = "OnDemand" + PROVISIONED = "Provisioned" + + +@attr.s +class ThroughputConfig(Config): + """Throughput configuration of the feature group. + + Throughput configuration can be ON_DEMAND, or PROVISIONED with valid values for + read and write capacity units. ON_DEMAND works best for less predictable traffic, + while PROVISIONED works best for consistent and predictable traffic. + + Attributes: + mode (ThroughputModeEnum): Throughput mode + provisioned_read_capacity_units (int): For provisioned feature groups, this indicates + the read throughput you are billed for and can consume without throttling. + provisioned_write_capacity_units (int): For provisioned feature groups, this indicates + the write throughput you are billed for and can consume without throttling. + """ + + mode: ThroughputModeEnum = attr.ib(default=None) + provisioned_read_capacity_units: int = attr.ib(default=None) + provisioned_write_capacity_units: int = attr.ib(default=None) + + def to_dict(self) -> Dict[str, Any]: + """Construct a dictionary based on the attributes provided. + + Returns: + dict represents the attributes. + """ + return Config.construct_dict( + ThroughputMode=self.mode.value if self.mode else None, + ProvisionedReadCapacityUnits=self.provisioned_read_capacity_units, + ProvisionedWriteCapacityUnits=self.provisioned_write_capacity_units, + ) + + +@attr.s +class ThroughputConfigUpdate(Config): + """Target throughput configuration for the feature group. + + Target throughput configuration can be ON_DEMAND, or PROVISIONED with valid values for + read and write capacity units. ON_DEMAND works best for less predictable traffic, + while PROVISIONED works best for consistent and predictable traffic. + + Attributes: + mode (ThroughputModeEnum): Target throughput mode + provisioned_read_capacity_units (int): For provisioned feature groups, this indicates + the read throughput you are billed for and can consume without throttling. 
+ provisioned_write_capacity_units (int): For provisioned feature groups, this indicates + the write throughput you are billed for and can consume without throttling. + """ + + mode: ThroughputModeEnum = attr.ib(default=None) + provisioned_read_capacity_units: int = attr.ib(default=None) + provisioned_write_capacity_units: int = attr.ib(default=None) + + def to_dict(self) -> Dict[str, Any]: + """Construct a dictionary based on the attributes provided. + + Returns: + dict represents the attributes. + """ + return Config.construct_dict( + ThroughputMode=self.mode.value if self.mode else None, + ProvisionedReadCapacityUnits=self.provisioned_read_capacity_units, + ProvisionedWriteCapacityUnits=self.provisioned_write_capacity_units, + ) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index ac1bf6e343..8f2753a7cf 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -5679,6 +5679,7 @@ def create_feature_group( role_arn: str = None, online_store_config: Dict[str, str] = None, offline_store_config: Dict[str, str] = None, + throughput_config: Dict[str, Any] = None, description: str = None, tags: Optional[Tags] = None, ) -> Dict[str, Any]: @@ -5694,6 +5695,8 @@ def create_feature_group( feature online store. offline_store_config (Dict[str, str]): dict contains configuration of the feature offline store. + throughput_config (Dict[str, Any]): dict contains throughput configuration + for the feature group. description (str): description of the FeatureGroup. tags (Optional[Tags]): tags for labeling a FeatureGroup. 
@@ -5729,6 +5732,7 @@ def create_feature_group( kwargs, OnlineStoreConfig=inferred_online_store_from_config, OfflineStoreConfig=inferred_offline_store_from_config, + ThroughputConfig=throughput_config, Description=description, Tags=tags, ) @@ -5757,28 +5761,32 @@ def update_feature_group( feature_group_name: str, feature_additions: Sequence[Dict[str, str]] = None, online_store_config: Dict[str, any] = None, + throughput_config: Dict[str, Any] = None, ) -> Dict[str, Any]: """Update a FeatureGroup - either adding new features from the given feature definitions - or updating online store config + Supports modifications like adding new features from the given feature definitions, + updating online store and throughput configurations. Args: feature_group_name (str): name of the FeatureGroup to update. feature_additions (Sequence[Dict[str, str]): list of feature definitions to be updated. + online_store_config (Dict[str, Any]): updates to online store config + throughput_config (Dict[str, Any]): target throughput configuration of the feature group Returns: Response dict from service. 
""" + update_req = {"FeatureGroupName": feature_group_name} + if online_store_config is not None: + update_req["OnlineStoreConfig"] = online_store_config - if feature_additions is None: - return self.sagemaker_client.update_feature_group( - FeatureGroupName=feature_group_name, - OnlineStoreConfig=online_store_config, - ) + if throughput_config is not None: + update_req["ThroughputConfig"] = throughput_config - return self.sagemaker_client.update_feature_group( - FeatureGroupName=feature_group_name, FeatureAdditions=feature_additions - ) + if feature_additions is not None: + update_req["FeatureAdditions"] = feature_additions + + return self.sagemaker_client.update_feature_group(**update_req) def list_feature_groups( self, diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index f7190d2122..319d492e83 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -43,6 +43,9 @@ TtlDuration, OnlineStoreConfigUpdate, OnlineStoreStorageTypeEnum, + ThroughputConfig, + ThroughputModeEnum, + ThroughputConfigUpdate, ) from sagemaker.feature_store.dataset_builder import ( JoinTypeEnum, @@ -410,6 +413,78 @@ def test_create_feature_group_standard_storage_type( assert storage_type == "Standard" +def test_throughput_create_as_provisioned_and_update_to_ondemand( + feature_store_session, + role, + feature_group_name, + pandas_data_frame, +): + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group.load_feature_definitions(data_frame=pandas_data_frame) + with cleanup_feature_group(feature_group): + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature3", + role_arn=role, + enable_online_store=True, + throughput_config=ThroughputConfig(ThroughputModeEnum.PROVISIONED, 4000, 4000), + ) + _wait_for_feature_group_create(feature_group) + + tp_config = feature_group.describe().get("ThroughputConfig") + mode = 
tp_config.get("ThroughputMode") + rcu = tp_config.get("ProvisionedReadCapacityUnits") + wcu = tp_config.get("ProvisionedWriteCapacityUnits") + assert mode == ThroughputModeEnum.PROVISIONED.value + assert rcu == 4000 + assert wcu == 4000 + + feature_group.update(throughput_config=ThroughputConfigUpdate(ThroughputModeEnum.ON_DEMAND)) + _wait_for_feature_group_update(feature_group) + + tp_config = feature_group.describe().get("ThroughputConfig") + mode = tp_config.get("ThroughputMode") + assert mode == ThroughputModeEnum.ON_DEMAND.value + + +def test_throughput_create_as_ondemand_and_update_to_provisioned( + feature_store_session, + role, + feature_group_name, + pandas_data_frame, +): + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group.load_feature_definitions(data_frame=pandas_data_frame) + with cleanup_feature_group(feature_group): + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature3", + role_arn=role, + enable_online_store=True, + throughput_config=ThroughputConfig(ThroughputModeEnum.ON_DEMAND), + ) + _wait_for_feature_group_create(feature_group) + + tp_config = feature_group.describe().get("ThroughputConfig") + mode = tp_config.get("ThroughputMode") + assert mode == ThroughputModeEnum.ON_DEMAND.value + + feature_group.update( + throughput_config=ThroughputConfigUpdate(ThroughputModeEnum.PROVISIONED, 100, 200) + ) + _wait_for_feature_group_update(feature_group) + + tp_config = feature_group.describe().get("ThroughputConfig") + mode = tp_config.get("ThroughputMode") + rcu = tp_config.get("ProvisionedReadCapacityUnits") + wcu = tp_config.get("ProvisionedWriteCapacityUnits") + assert mode == ThroughputModeEnum.PROVISIONED.value + assert rcu == 100 + assert wcu == 200 + + def test_ttl_duration( feature_store_session, role, diff --git a/tests/unit/sagemaker/feature_store/test_feature_group.py b/tests/unit/sagemaker/feature_store/test_feature_group.py 
index c3499e3f51..394ecb25b3 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_group.py +++ b/tests/unit/sagemaker/feature_store/test_feature_group.py @@ -40,6 +40,9 @@ TtlDuration, OnlineStoreConfigUpdate, OnlineStoreStorageTypeEnum, + ThroughputModeEnum, + ThroughputConfig, + ThroughputConfigUpdate, ) from tests.unit import SAGEMAKER_CONFIG_FEATURE_GROUP @@ -305,6 +308,63 @@ def test_feature_store_create_with_in_memory_collection_types( ) +def test_feature_store_create_in_provisioned_throughput_mode( + sagemaker_session_mock, role_arn, feature_group_dummy_definitions +): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature2", + role_arn=role_arn, + enable_online_store=True, + throughput_config=ThroughputConfig(ThroughputModeEnum.PROVISIONED, 1000, 2000), + ) + sagemaker_session_mock.create_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + record_identifier_name="feature1", + event_time_feature_name="feature2", + feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], + role_arn=role_arn, + description=None, + tags=None, + online_store_config={"EnableOnlineStore": True}, + throughput_config={ + "ThroughputMode": "Provisioned", + "ProvisionedReadCapacityUnits": 1000, + "ProvisionedWriteCapacityUnits": 2000, + }, + ) + + +def test_feature_store_create_in_ondemand_throughput_mode( + sagemaker_session_mock, role_arn, feature_group_dummy_definitions +): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature2", + role_arn=role_arn, + enable_online_store=True, + 
throughput_config=ThroughputConfig(ThroughputModeEnum.ON_DEMAND), + ) + + sagemaker_session_mock.create_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + record_identifier_name="feature1", + event_time_feature_name="feature2", + feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], + role_arn=role_arn, + description=None, + tags=None, + online_store_config={"EnableOnlineStore": True}, + throughput_config={"ThroughputMode": "OnDemand"}, + ) + + def test_feature_store_delete(sagemaker_session_mock): feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) feature_group.delete() @@ -327,6 +387,35 @@ def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_defini sagemaker_session_mock.update_feature_group.assert_called_with( feature_group_name="MyFeatureGroup", feature_additions=[fd.to_dict() for fd in feature_group_dummy_definitions], + throughput_config=None, + online_store_config=None, + ) + + +def test_feature_store_throughput_update_to_provisioned(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.update( + throughput_config=ThroughputConfigUpdate(ThroughputModeEnum.PROVISIONED, 999, 777) + ) + sagemaker_session_mock.update_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_additions=None, + throughput_config={ + "ThroughputMode": "Provisioned", + "ProvisionedReadCapacityUnits": 999, + "ProvisionedWriteCapacityUnits": 777, + }, + online_store_config=None, + ) + + +def test_feature_store_throughput_update_to_ondemand(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.update(throughput_config=ThroughputConfigUpdate(ThroughputModeEnum.ON_DEMAND)) + sagemaker_session_mock.update_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + 
feature_additions=None, + throughput_config={"ThroughputMode": "OnDemand"}, online_store_config=None, ) @@ -341,6 +430,7 @@ def test_feature_store_update_with_ttl_duration(sagemaker_session_mock): feature_group_name="MyFeatureGroup", feature_additions=None, online_store_config=online_store_config.to_dict(), + throughput_config=None, ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index c51dcaaea5..6ee2cc9af5 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -5134,7 +5134,7 @@ def test_feature_group_describe(sagemaker_session): ) -def test_feature_group_update(sagemaker_session, feature_group_dummy_definitions): +def test_feature_group_feature_additions_update(sagemaker_session, feature_group_dummy_definitions): sagemaker_session.update_feature_group( feature_group_name="MyFeatureGroup", feature_additions=feature_group_dummy_definitions, @@ -5145,6 +5145,32 @@ def test_feature_group_update(sagemaker_session, feature_group_dummy_definitions ) +def test_feature_group_online_store_config_update(sagemaker_session): + os_conf_update = {"TtlDuration": {"Unit": "Seconds", "Value": 123}} + sagemaker_session.update_feature_group( + feature_group_name="MyFeatureGroup", + online_store_config=os_conf_update, + ) + sagemaker_session.sagemaker_client.update_feature_group.assert_called_with( + FeatureGroupName="MyFeatureGroup", OnlineStoreConfig=os_conf_update + ) + + +def test_feature_group_throughput_config_update(sagemaker_session): + tp_update = { + "ThroughputMode": "Provisioned", + "ProvisionedReadCapacityUnits": 123, + "ProvisionedWriteCapacityUnits": 456, + } + sagemaker_session.update_feature_group( + feature_group_name="MyFeatureGroup", + throughput_config=tp_update, + ) + sagemaker_session.sagemaker_client.update_feature_group.assert_called_with( + FeatureGroupName="MyFeatureGroup", ThroughputConfig=tp_update + ) + + def test_feature_metadata_update(sagemaker_session): parameter_additions = [ {