From c288b35dcf372d06d43c94424e2942d619a96c03 Mon Sep 17 00:00:00 2001 From: Hao Zhu Date: Wed, 10 May 2023 19:42:07 -0700 Subject: [PATCH] feature: Add support for SageMaker Serverless inference Provisioned Concurrency feature --- setup.py | 2 +- .../serverless/serverless_inference_config.py | 11 ++++++- tests/integ/test_serverless_inference.py | 20 +++++++++-- .../test_serverless_inference_config.py | 33 +++++++++++++++++++ 4 files changed, 62 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index ad4118a80a..ee7c8268e3 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ def read_requirements(filename): # Declare minimal set for installation required_packages = [ "attrs>=20.3.0,<23", - "boto3>=1.26.28,<2.0", + "boto3>=1.26.131,<2.0", "cloudpickle==2.2.1", "google-pasta", "numpy>=1.9.0,<2.0", diff --git a/src/sagemaker/serverless/serverless_inference_config.py b/src/sagemaker/serverless/serverless_inference_config.py index adc98a319a..c170c27c5a 100644 --- a/src/sagemaker/serverless/serverless_inference_config.py +++ b/src/sagemaker/serverless/serverless_inference_config.py @@ -12,10 +12,11 @@ # language governing permissions and limitations under the License. """This module contains code related to the ServerlessInferenceConfig class. -Codes are used for configuring async inference endpoint. Use it when deploying +Codes are used for configuring serverless inference endpoint. Use it when deploying the model to the endpoints. """ from __future__ import print_function, absolute_import +from typing import Optional class ServerlessInferenceConfig(object): @@ -29,6 +30,7 @@ def __init__( self, memory_size_in_mb: int = 2048, max_concurrency: int = 5, + provisioned_concurrency: Optional[int] = None, ): """Initialize a ServerlessInferenceConfig object for serverless inference configuration. @@ -40,9 +42,13 @@ def __init__( max_concurrency (int): Optional. The maximum number of concurrent invocations your serverless endpoint can process. 
If no value is provided, Amazon SageMaker will choose the default value for you. (Default: 5) + provisioned_concurrency (int): Optional. The provisioned concurrency of your + serverless endpoint. If no value is provided, Amazon SageMaker will not + apply provisioned concurrency to your Serverless endpoint. (Default: None) """ self.memory_size_in_mb = memory_size_in_mb self.max_concurrency = max_concurrency + self.provisioned_concurrency = provisioned_concurrency def _to_request_dict(self): """Generates a request dictionary using the parameters provided to the class.""" @@ -51,4 +57,7 @@ def _to_request_dict(self): "MaxConcurrency": self.max_concurrency, } + if self.provisioned_concurrency is not None: + request_dict["ProvisionedConcurrency"] = self.provisioned_concurrency + return request_dict diff --git a/tests/integ/test_serverless_inference.py b/tests/integ/test_serverless_inference.py index 40b1ace147..9a3a8d05ea 100644 --- a/tests/integ/test_serverless_inference.py +++ b/tests/integ/test_serverless_inference.py @@ -44,10 +44,11 @@ def test_serverless_walkthrough(sagemaker_session, cpu_instance_type, training_s pca.extra_components = 5 pca.fit(pca.record_set(training_set[0][:100]), job_name=job_name) - with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session): + serverless_name = unique_name_from_base("pca-serverless") + with timeout_and_delete_endpoint_by_name(serverless_name, sagemaker_session): predictor_serverless = pca.deploy( - endpoint_name=job_name, serverless_inference_config=ServerlessInferenceConfig() + endpoint_name=serverless_name, serverless_inference_config=ServerlessInferenceConfig() ) result = predictor_serverless.predict(training_set[0][:5]) @@ -55,3 +56,18 @@ def test_serverless_walkthrough(sagemaker_session, cpu_instance_type, training_s assert len(result) == 5 for record in result: assert record.label["projection"] is not None + + # Test out Serverless Provisioned Concurrency endpoint happy case + serverless_pc_name = 
unique_name_from_base("pca-serverless-pc") + with timeout_and_delete_endpoint_by_name(serverless_pc_name, sagemaker_session): + + predictor_serverless_pc = pca.deploy( + endpoint_name=serverless_pc_name, + serverless_inference_config=ServerlessInferenceConfig(provisioned_concurrency=1), + ) + + result = predictor_serverless_pc.predict(training_set[0][:5]) + + assert len(result) == 5 + for record in result: + assert record.label["projection"] is not None diff --git a/tests/unit/sagemaker/serverless/test_serverless_inference_config.py b/tests/unit/sagemaker/serverless/test_serverless_inference_config.py index fab80748a4..bae679c5cb 100644 --- a/tests/unit/sagemaker/serverless/test_serverless_inference_config.py +++ b/tests/unit/sagemaker/serverless/test_serverless_inference_config.py @@ -16,12 +16,19 @@ DEFAULT_MEMORY_SIZE_IN_MB = 2048 DEFAULT_MAX_CONCURRENCY = 5 +DEFAULT_PROVISIONED_CONCURRENCY = 5 DEFAULT_REQUEST_DICT = { "MemorySizeInMB": DEFAULT_MEMORY_SIZE_IN_MB, "MaxConcurrency": DEFAULT_MAX_CONCURRENCY, } +PROVISIONED_CONCURRENCY_REQUEST_DICT = { + "MemorySizeInMB": DEFAULT_MEMORY_SIZE_IN_MB, + "MaxConcurrency": DEFAULT_MAX_CONCURRENCY, + "ProvisionedConcurrency": DEFAULT_PROVISIONED_CONCURRENCY, +} + def test_init(): serverless_inference_config = ServerlessInferenceConfig() @@ -29,8 +36,34 @@ def test_init(): assert serverless_inference_config.memory_size_in_mb == DEFAULT_MEMORY_SIZE_IN_MB assert serverless_inference_config.max_concurrency == DEFAULT_MAX_CONCURRENCY + serverless_provisioned_concurrency_inference_config = ServerlessInferenceConfig( + provisioned_concurrency=DEFAULT_PROVISIONED_CONCURRENCY + ) + + assert ( + serverless_provisioned_concurrency_inference_config.memory_size_in_mb + == DEFAULT_MEMORY_SIZE_IN_MB + ) + assert ( + serverless_provisioned_concurrency_inference_config.max_concurrency + == DEFAULT_MAX_CONCURRENCY + ) + assert ( + serverless_provisioned_concurrency_inference_config.provisioned_concurrency + == 
DEFAULT_PROVISIONED_CONCURRENCY + ) + def test_to_request_dict(): serverless_inference_config_dict = ServerlessInferenceConfig()._to_request_dict() assert serverless_inference_config_dict == DEFAULT_REQUEST_DICT + + serverless_provisioned_concurrency_inference_config_dict = ServerlessInferenceConfig( + provisioned_concurrency=DEFAULT_PROVISIONED_CONCURRENCY + )._to_request_dict() + + assert ( + serverless_provisioned_concurrency_inference_config_dict + == PROVISIONED_CONCURRENCY_REQUEST_DICT + )