Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def read_requirements(filename):
# Declare minimal set for installation
required_packages = [
"attrs>=20.3.0,<23",
"boto3>=1.26.28,<2.0",
"boto3>=1.26.131,<2.0",
"cloudpickle==2.2.1",
"google-pasta",
"numpy>=1.9.0,<2.0",
Expand Down
11 changes: 10 additions & 1 deletion src/sagemaker/serverless/serverless_inference_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# language governing permissions and limitations under the License.
"""This module contains code related to the ServerlessInferenceConfig class.

Codes are used for configuring async inference endpoint. Use it when deploying
Codes are used for configuring serverless inference endpoint. Use it when deploying
the model to the endpoints.
"""
from __future__ import print_function, absolute_import
from typing import Optional


class ServerlessInferenceConfig(object):
Expand All @@ -29,6 +30,7 @@ def __init__(
self,
memory_size_in_mb: int = 2048,
max_concurrency: int = 5,
provisioned_concurrency: Optional[int] = None,
):
"""Initialize a ServerlessInferenceConfig object for serverless inference configuration.

Expand All @@ -40,9 +42,13 @@ def __init__(
max_concurrency (int): Optional. The maximum number of concurrent invocations
your serverless endpoint can process. If no value is provided, Amazon
SageMaker will choose the default value for you. (Default: 5)
provisioned_concurrency (int): Optional. The provisioned concurrency of your
serverless endpoint. If no value is provided, Amazon SageMaker will not
apply provisioned concurrency to your Serverless endpoint. (Default: None)
"""
self.memory_size_in_mb = memory_size_in_mb
self.max_concurrency = max_concurrency
self.provisioned_concurrency = provisioned_concurrency

def _to_request_dict(self):
"""Generates a request dictionary using the parameters provided to the class."""
Expand All @@ -51,4 +57,7 @@ def _to_request_dict(self):
"MaxConcurrency": self.max_concurrency,
}

if self.provisioned_concurrency is not None:
request_dict["ProvisionedConcurrency"] = self.provisioned_concurrency

return request_dict
20 changes: 18 additions & 2 deletions tests/integ/test_serverless_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,30 @@ def test_serverless_walkthrough(sagemaker_session, cpu_instance_type, training_s
pca.extra_components = 5
pca.fit(pca.record_set(training_set[0][:100]), job_name=job_name)

with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
serverless_name = unique_name_from_base("pca-serverless")
with timeout_and_delete_endpoint_by_name(serverless_name, sagemaker_session):

predictor_serverless = pca.deploy(
endpoint_name=job_name, serverless_inference_config=ServerlessInferenceConfig()
endpoint_name=serverless_name, serverless_inference_config=ServerlessInferenceConfig()
)

result = predictor_serverless.predict(training_set[0][:5])

assert len(result) == 5
for record in result:
assert record.label["projection"] is not None

# Test out Serverless Provisioned Concurrency endpoint happy case
serverless_pc_name = unique_name_from_base("pca-serverless-pc")
with timeout_and_delete_endpoint_by_name(serverless_pc_name, sagemaker_session):

predictor_serverless_pc = pca.deploy(
endpoint_name=serverless_pc_name,
serverless_inference_config=ServerlessInferenceConfig(provisioned_concurrency=1),
)

result = predictor_serverless_pc.predict(training_set[0][:5])

assert len(result) == 5
for record in result:
assert record.label["projection"] is not None
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,54 @@

# Expected defaults mirrored from ServerlessInferenceConfig.
DEFAULT_MEMORY_SIZE_IN_MB = 2048
DEFAULT_MAX_CONCURRENCY = 5
DEFAULT_PROVISIONED_CONCURRENCY = 5

# Request payload produced when no provisioned concurrency is configured.
DEFAULT_REQUEST_DICT = {
    "MemorySizeInMB": DEFAULT_MEMORY_SIZE_IN_MB,
    "MaxConcurrency": DEFAULT_MAX_CONCURRENCY,
}

# Same payload with the optional ProvisionedConcurrency field present.
PROVISIONED_CONCURRENCY_REQUEST_DICT = {
    **DEFAULT_REQUEST_DICT,
    "ProvisionedConcurrency": DEFAULT_PROVISIONED_CONCURRENCY,
}


def test_init():
    """Constructor stores defaults, and an explicit provisioned_concurrency is kept."""
    default_config = ServerlessInferenceConfig()
    assert default_config.memory_size_in_mb == DEFAULT_MEMORY_SIZE_IN_MB
    assert default_config.max_concurrency == DEFAULT_MAX_CONCURRENCY

    pc_config = ServerlessInferenceConfig(
        provisioned_concurrency=DEFAULT_PROVISIONED_CONCURRENCY
    )
    assert pc_config.memory_size_in_mb == DEFAULT_MEMORY_SIZE_IN_MB
    assert pc_config.max_concurrency == DEFAULT_MAX_CONCURRENCY
    assert pc_config.provisioned_concurrency == DEFAULT_PROVISIONED_CONCURRENCY


def test_to_request_dict():
    """_to_request_dict emits ProvisionedConcurrency only when it was configured."""
    assert ServerlessInferenceConfig()._to_request_dict() == DEFAULT_REQUEST_DICT

    pc_request_dict = ServerlessInferenceConfig(
        provisioned_concurrency=DEFAULT_PROVISIONED_CONCURRENCY
    )._to_request_dict()
    assert pc_request_dict == PROVISIONED_CONCURRENCY_REQUEST_DICT