From cba903440189cf0b140aa338d6d0d3ed6f3fd41e Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 2 Oct 2023 18:37:15 +0000 Subject: [PATCH 01/13] feat: jumpstart default payloads --- src/sagemaker/base_predictor.py | 43 ++++- src/sagemaker/jumpstart/accessors.py | 44 +++++ src/sagemaker/jumpstart/artifacts/__init__.py | 3 + src/sagemaker/jumpstart/artifacts/payloads.py | 84 +++++++++ src/sagemaker/jumpstart/model.py | 23 +++ src/sagemaker/jumpstart/payload_utils.py | 134 ++++++++++++++ src/sagemaker/jumpstart/types.py | 66 ++++++- src/sagemaker/payloads.py | 116 ++++++++++++ tests/unit/sagemaker/jumpstart/constants.py | 167 ++++++++++++++++++ .../sagemaker/jumpstart/test_accessors.py | 48 +++++ .../sagemaker/jumpstart/test_payload_utils.py | 63 +++++++ .../sagemaker/jumpstart/test_predictor.py | 46 +++++ 12 files changed, 828 insertions(+), 9 deletions(-) create mode 100644 src/sagemaker/jumpstart/artifacts/payloads.py create mode 100644 src/sagemaker/jumpstart/payload_utils.py create mode 100644 src/sagemaker/payloads.py create mode 100644 tests/unit/sagemaker/jumpstart/test_payload_utils.py diff --git a/src/sagemaker/base_predictor.py b/src/sagemaker/base_predictor.py index 46983e0983..f02c1d70fd 100644 --- a/src/sagemaker/base_predictor.py +++ b/src/sagemaker/base_predictor.py @@ -32,6 +32,9 @@ StreamDeserializer, StringDeserializer, ) +from sagemaker.jumpstart.payload_utils import PayloadSerializer +from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket from sagemaker.model_monitor import ( DataCaptureConfig, DefaultModelMonitor, @@ -201,20 +204,42 @@ def _create_request_args( custom_attributes=None, ): """Placeholder docstring""" + + js_accept = None + + if isinstance(data, JumpStartSerializablePayload): + s3_client = self.sagemaker_session.s3_client + region = self.sagemaker_session._region_name + bucket = get_jumpstart_content_bucket(region) + + js_serialized_data = 
PayloadSerializer( + bucket=bucket, region=region, s3_client=s3_client + ).serialize(data) + js_content_type = data.content_type + js_accept = data.accept + args = dict(initial_args) if initial_args else {} if "EndpointName" not in args: args["EndpointName"] = self.endpoint_name if "ContentType" not in args: - args["ContentType"] = ( - self.content_type - if isinstance(self.content_type, str) - else ", ".join(self.content_type) - ) + if isinstance(data, JumpStartSerializablePayload): + args["ContentType"] = js_content_type + else: + args["ContentType"] = ( + self.content_type + if isinstance(self.content_type, str) + else ", ".join(self.content_type) + ) if "Accept" not in args: - args["Accept"] = self.accept if isinstance(self.accept, str) else ", ".join(self.accept) + if isinstance(data, JumpStartSerializablePayload) and js_accept: + args["Accept"] = js_accept + else: + args["Accept"] = ( + self.accept if isinstance(self.accept, str) else ", ".join(self.accept) + ) if target_model: args["TargetModel"] = target_model @@ -228,7 +253,11 @@ def _create_request_args( if custom_attributes: args["CustomAttributes"] = custom_attributes - data = self.serializer.serialize(data) + data = ( + self.serializer.serialize(data) + if not isinstance(data, JumpStartSerializablePayload) + else js_serialized_data + ) args["Body"] = data return args diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 8117606299..1d3e2abb0d 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. 
"""This module contains accessors related to SageMaker JumpStart.""" from __future__ import absolute_import +import functools from typing import Any, Dict, List, Optional import boto3 @@ -37,6 +38,49 @@ def get_sagemaker_version() -> str: return SageMakerSettings._parsed_sagemaker_version +class JumpStartS3Accessor(object): + """Static class for storing and retrieving auxilliary s3 artifacts.""" + + @functools.cache + @staticmethod + def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3.client: + """Returns default s3 client associated with the region. + + Result is cached so multiple clients in memory are not created. + """ + return boto3.client("s3", region_name=region) + + @functools.lru_cache + @staticmethod + def get_object_cached( + bucket: str, + key: str, + region: str = JUMPSTART_DEFAULT_REGION_NAME, + s3_client: Optional[boto3.client] = None, + ) -> bytes: + """Returns s3 object located at the bucket and key. + + Requests are cached so that the same s3 request is never made more + than once, unless a different region or client is used. 
+ """ + return JumpStartS3Accessor.get_object( + bucket=bucket, key=key, region=region, s3_client=s3_client + ) + + @staticmethod + def get_object( + bucket: str, + key: str, + region: str = JUMPSTART_DEFAULT_REGION_NAME, + s3_client: Optional[boto3.client] = None, + ) -> bytes: + """Returns s3 object located at the bucket and key.""" + if s3_client is None: + s3_client = JumpStartS3Accessor._get_default_s3_client(region) + + return s3_client.get_object(Bucket=bucket, Key=key)["Body"].read() + + class JumpStartModelsAccessor(object): """Static class for storing the JumpStart models cache.""" diff --git a/src/sagemaker/jumpstart/artifacts/__init__.py b/src/sagemaker/jumpstart/artifacts/__init__.py index ec44077a64..f57bfcbf46 100644 --- a/src/sagemaker/jumpstart/artifacts/__init__.py +++ b/src/sagemaker/jumpstart/artifacts/__init__.py @@ -61,3 +61,6 @@ _retrieve_model_package_arn, _retrieve_model_package_model_artifact_s3_uri, ) +from sagemaker.jumpstart.artifacts.payloads import ( # noqa: F401 + _retrieve_default_payloads, +) diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py new file mode 100644 index 0000000000..ee573859c3 --- /dev/null +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -0,0 +1,84 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module contains functions for obtaining JumpStart payloads.""" +from __future__ import absolute_import +from copy import deepcopy +from typing import Dict, Optional +from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + JUMPSTART_DEFAULT_REGION_NAME, +) +from sagemaker.jumpstart.enums import ( + JumpStartScriptScope, +) +from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.utils import ( + verify_model_region_and_return_specs, +) +from sagemaker.session import Session + + +def _retrieve_default_payloads( + model_id: str, + model_version: str, + region: Optional[str], + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> Optional[Dict[str, JumpStartSerializablePayload]]: + """Returns default payloads. + + Args: + model_id (str): JumpStart model ID of the JumpStart model for which to + get default payloads. + model_version (str): Version of the JumpStart model for which to retrieve the + default resource name. + region (Optional[str]): Region for which to retrieve the + default resource name. + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). If False, raises + an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). 
+ Returns: + str: the default payload. + """ + + if region is None: + region = JUMPSTART_DEFAULT_REGION_NAME + + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=JumpStartScriptScope.INFERENCE, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, + ) + + default_payloads = model_specs.default_payloads + + if default_payloads: + for payload in default_payloads.values(): + payload.accept = getattr( + payload, "accept", model_specs.predictor_specs.default_accept_type + ) + + return deepcopy(default_payloads) if default_payloads else None diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index ab060ea454..9c9a08654d 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -16,6 +16,7 @@ import re from typing import Dict, List, Optional, Union +from sagemaker import payloads from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer from sagemaker.base_serializers import BaseSerializer @@ -28,6 +29,7 @@ get_deploy_kwargs, get_init_kwargs, ) +from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import is_valid_model_id from sagemaker.utils import stringify_object from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model @@ -312,6 +314,27 @@ def _is_valid_model_id_hook(): super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) + def retrieve_default_payload(self) -> JumpStartSerializablePayload: + """Returns default payload associated with the model. + + Payload can be directly used with the `sagemaker.predictor.Predictor.predict(...)` function. 
+ """ + sample_payloads: Optional[List[JumpStartSerializablePayload]] = payloads.retrieve_samples( + model_id=self.model_id, + model_version=self.model_version, + region=self.region, + tolerate_deprecated_model=self.tolerate_deprecated_model, + tolerate_vulnerable_model=self.tolerate_vulnerable_model, + sagemaker_session=self.sagemaker_session, + ) + + if sample_payloads is None or len(sample_payloads) == 0: + raise NotImplementedError( + f"No default payload supported for model ID '{self.model_id}'." + ) + + return sample_payloads[0] + def _create_sagemaker_model( self, instance_type=None, diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py new file mode 100644 index 0000000000..31b784f409 --- /dev/null +++ b/src/sagemaker/jumpstart/payload_utils.py @@ -0,0 +1,134 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module stores payload utilities for SageMaker JumpStart.""" +from __future__ import absolute_import +import base64 +import json +from typing import Any, Optional, Union +import re +import boto3 + +from sagemaker.jumpstart.accessors import JumpStartS3Accessor +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.enums import MIMEType +from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket + +S3_BYTES_REGEX = r"^\$s3<(?P<s3_key>[a-zA-Z0-9-_/.]+)>$" +S3_B64_STR_REGEX = r"\$s3_b64<(?P<s3_key>[a-zA-Z0-9-_/.]+)>" + + +class PayloadSerializer: + """Utility class for serializing payloads associated with JumpStart models. + + Many JumpStart models embed byte-streams into payloads corresponding to images, sounds, + and other content types which require downloading from S3. + """ + + def __init__( + self, + bucket: str = get_jumpstart_content_bucket(), + region: str = JUMPSTART_DEFAULT_REGION_NAME, + s3_client: Optional[boto3.client] = None, + ) -> None: + """Initializes PayloadSerializer object.""" + self.bucket = bucket + self.region = region + self.s3_client = s3_client + + def get_bytes_payload_with_s3_references( + self, + payload_str: str, + ) -> bytes: + """Returns bytes object corresponding to referenced s3 object. + + Raises: + ValueError: If the raw bytes payload is not formatted correctly.
+ """ + s3_keys = re.compile(S3_BYTES_REGEX).findall(payload_str) + if len(s3_keys) != 1: + raise ValueError(f"Invalid bytes payload: {payload_str}") + + s3_key = s3_keys[0] + serialized_s3_object = JumpStartS3Accessor.get_object_cached( + bucket=self.bucket, key=s3_key, region=self.region, s3_client=self.s3_client + ) + + return serialized_s3_object + + def embed_s3_references_in_str_payload( + self, + payload: str, + ) -> str: + """Embeds s3 references in string payloads.""" + return self._embed_s3_b64_references_in_str_payload(payload_body=payload) + + def _embed_s3_b64_references_in_str_payload( + self, + payload_body: str, + ) -> str: + """Performs base 64 encoding of payloads embedded in a payload. + + This is required so that byte-valued payloads can be transmitted efficiently + as a utf-8 encoded string. + """ + + s3_keys = re.compile(S3_B64_STR_REGEX).findall(payload_body) + for s3_key in s3_keys: + b64_encoded_string = base64.b64encode( + bytearray( + JumpStartS3Accessor.get_object_cached( + bucket=self.bucket, key=s3_key, region=self.region, s3_client=self.s3_client + ) + ) + ).decode() + payload_body = payload_body.replace(f"$s3_b64<{s3_key}>", b64_encoded_string) + return payload_body + + def embed_s3_references_in_json_payload( + self, payload_body: Union[list, dict, str, int, float] + ) -> Union[list, dict, str, int, float]: + """Finds all s3 references in payload and embeds serialized s3 data. + + S3 bucket is assumed to be the default JumpStart content bucket. If no s3 references + are found, the payload is returned un-modified. 
+ """ + if isinstance(payload_body, str): + return self.embed_s3_references_in_str_payload(payload_body) + if isinstance(payload_body, (float, int)): + return payload_body + if isinstance(payload_body, list): + return [self.embed_s3_references_in_json_payload(item) for item in payload_body] + if isinstance(payload_body, dict): + return { + key: self.embed_s3_references_in_json_payload(value) + for key, value in payload_body.items() + } + raise ValueError(f"Payload has unrecognized type: {type(payload_body)}") + + def serialize(self, payload: JumpStartSerializablePayload) -> Any: + """Returns payload bytes that can be inputted to inference endpoint.""" + content_type = MIMEType.from_suffixed_type(payload.content_type) + body = payload.body + + if content_type in {MIMEType.JSON, MIMEType.LIST_TEXT, MIMEType.X_TEXT}: + body = self.embed_s3_references_in_json_payload(body) + else: + body = self.get_bytes_payload_with_s3_references(body) + + if isinstance(body, dict): + body = json.dumps(body) + elif not isinstance(body, str) and not isinstance(body, bytes): + raise ValueError(f"Default payload '{body}' has unrecognized type: {type(body)}") + + return body diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index e8b717b7c7..1c2a7226e4 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -31,6 +31,8 @@ class JumpStartDataHolderType: __slots__: List[str] = [] + _non_serializable_slots: List[str] = [] + def __eq__(self, other: Any) -> bool: """Returns True if ``other`` is of the same type and has all attributes equal. 
@@ -69,7 +71,11 @@ def __str__(self) -> str: {'content_bucket': 'bucket', 'region_name': 'us-west-2'}" """ - att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + att_dict = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in self._non_serializable_slots + } return f"{type(self).__name__}: {str(att_dict)}" def __repr__(self) -> str: @@ -79,7 +85,11 @@ def __repr__(self) -> str: {'content_bucket': 'bucket', 'region_name': 'us-west-2'}" """ - att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + att_dict = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in self._non_serializable_slots + } return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}" @@ -312,6 +322,49 @@ def to_json(self) -> Dict[str, Any]: return json_obj +class JumpStartSerializablePayload(JumpStartDataHolderType): + """Data class for JumpStart serialized payload specs.""" + + __slots__ = [ + "raw_payload", + "content_type", + "accept", + "body", + ] + + _non_serializable_slots = ["raw_payload"] + + def __init__(self, spec: Optional[Dict[str, Any]]): + """Initializes a JumpStartSerializablePayload object from its json representation. + + Args: + spec (Dict[str, Any]): Dictionary representation of payload specs. + """ + self.from_json(spec) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of serializable + payload specs. 
+ """ + + if json_obj is None: + return + + self.raw_payload = json_obj + self.content_type = json_obj["content_type"] + self.body = json_obj["body"] + accept = json_obj.get("accept") + if accept: + self.accept = accept + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of JumpStartDefaultPayloads object.""" + return deepcopy(self.raw_payload) + + class JumpStartInstanceTypeVariants(JumpStartDataHolderType): """Data class for JumpStart instance type variants.""" @@ -468,6 +521,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "hosting_use_script_uri", "hosting_instance_type_variants", "training_instance_type_variants", + "default_payloads", ] def __init__(self, spec: Dict[str, Any]): @@ -536,6 +590,14 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if "predictor_specs" in json_obj else None ) + self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( + { + alias: JumpStartSerializablePayload(payload) + for alias, payload in json_obj["default_payloads"].items() + } + if json_obj.get("default_payloads") + else None + ) self.inference_volume_size: Optional[int] = json_obj.get("inference_volume_size") self.inference_enable_network_isolation: bool = json_obj.get( "inference_enable_network_isolation", False diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py new file mode 100644 index 0000000000..fce0b86d71 --- /dev/null +++ b/src/sagemaker/payloads.py @@ -0,0 +1,116 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +"""Utilities related to payloads of pretrained machine learning models.""" +from __future__ import absolute_import + +import logging +from typing import Dict, List, Optional + +from sagemaker.jumpstart import utils as jumpstart_utils +from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.payload_utils import PayloadSerializer +from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.session import Session + + +logger = logging.getLogger(__name__) + + +def retrieve_samples( + region: Optional[str] = None, + model_id: Optional[str] = None, + model_version: Optional[str] = None, + serialize: bool = True, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> Optional[List[str]]: + """Retrieves the compatible payloads for the model matching the given arguments. + + Args: + region (str): The AWS Region for which to retrieve the Jumpstart model payloads. + model_id (str): The model ID of the JumpStart model for which to retrieve + the model payloads. + model_version (str): The version of the JumpStart model for which to retrieve + the model payloads. + serialize (bool): Whether to serialize byte-stream valued payloads by downloading + binary files from s3 and applying encoding, or to keep payload in pre-serialized + state. Set this option to False if you want to avoid the s3 download or if you + want to inspect the payload in a human-readable form. (Default: True). + tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model + specifications should be tolerated without raising an exception. If ``False``, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities.
(Default: False). + tolerate_deprecated_model (bool): ``True`` if deprecated versions of model + specifications should be tolerated without raising an exception. If ``False``, raises + an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + Returns: + List[JumpStartSerializablePayload]: Sample payloads, or None if the model has none. + + Raises: + NotImplementedError: If the scope is not supported. + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. + """ + if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): + raise ValueError( + "Must specify JumpStart `model_id` and `model_version` when retrieving model URIs."
+ ) + + unserialized_payload_dict: Optional[ + Dict[str, JumpStartSerializablePayload] + ] = artifacts._retrieve_default_payloads( + model_id, + model_version, # type: ignore + region, + tolerate_vulnerable_model, + tolerate_deprecated_model, + sagemaker_session=sagemaker_session, + ) + + if unserialized_payload_dict is None: + return None + + unserialized_payloads: List[JumpStartSerializablePayload] = list( + unserialized_payload_dict.values() + ) + + if not serialize: + return unserialized_payloads + + payload_serializer = PayloadSerializer(region=region, s3_client=sagemaker_session.s3_client) + + serialized_payloads: List[JumpStartSerializablePayload] = [] + + for payload in unserialized_payloads: + + serialized_body = payload_serializer.serialize(payload) + + serialized_payloads.append( + JumpStartSerializablePayload( + { + "content_type": payload.content_type, + "body": serialized_body, + "accept": payload.accept, + } + ) + ) + + return serialized_payloads diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index f5cc4fbb58..b65167165c 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1620,6 +1620,172 @@ }, }, }, + "default_payloads": { + "model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16", + "url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth", + "version": "1.0.0", + "min_sdk_version": "2.144.0", + "training_supported": False, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "djl-deepspeed", + "framework_version": "0.21.0", + "py_version": "py38", + "huggingface_transformers_version": "4.17", + }, + "hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st" + "able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", + "hosting_script_key": "source-directory-tarballs/stabilityai/inference/depth2img/v1.0.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": 
"stabilityai-infer/prepack/v1.0.0/" + "infer-prepack-model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.0", + "inference_vulnerable": False, + "inference_dependencies": [ + "accelerate==0.18.0", + "diffusers==0.14.0", + "fsspec==2023.4.0", + "huggingface-hub==0.14.1", + "transformers==4.26.1", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.g5.8xlarge", + "supported_inference_instance_types": [ + "ml.g5.8xlarge", + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.16xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + 
"ml.g4dn.2xlarge", + "ml.g4dn.4xlarge", + "ml.g4dn.8xlarge", + "ml.g4dn.16xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_enable_network_isolation": True, + "validation_supported": False, + "fine_tuning_supported": False, + "resource_name_base": "sd-1-5-controlnet-1-1-fp16", + "default_payloads": { + "Dog": { + "content_type": "application/json", + "body": { + "prompt": "a dog", + "num_images_per_prompt": 2, + "num_inference_steps": 20, + "guidance_scale": 7.5, + "seed": 43, + "eta": 0.7, + "image": "$s3_b64", + }, + } + }, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/d" + "jl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p2": {"regional_properties": 
{"image_uri": "$alias_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + }, + }, + }, "predictor-specs-model": { "model_id": "huggingface-text2text-flan-t5-xxl-fp16", "url": "https://huggingface.co/google/flan-t5-xxl", @@ -3190,6 +3356,7 @@ "min_sdk_version": "2.49.0", "training_supported": True, "incremental_training_supported": True, + "default_payloads": None, "hosting_ecr_specs": { "framework": "pytorch", "framework_version": "1.5.0", diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 2de0351103..31a76701a1 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -11,6 +11,8 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
from __future__ import absolute_import +from io import BytesIO +from unittest import TestCase from mock.mock import Mock, patch import pytest @@ -134,3 +136,49 @@ def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): # necessary because accessors is a static module reload(accessors) + + +class TestS3Accessor(TestCase): + + bucket = "bucket" + key = "key" + + @patch("sagemaker.jumpstart.accessors.boto3.client") + def test_get_object(self, mocked_boto3_client): + + # required due to static class + reload(accessors) + + mocked_boto3_client.return_value = Mock() + + mocked_boto3_client.return_value.get_object.return_value = {"Body": BytesIO(b"s3-object")} + + response = accessors.JumpStartS3Accessor.get_object(bucket=self.bucket, key=self.key) + + self.assertEqual(response, b"s3-object") + + mocked_boto3_client.assert_called_once_with("s3", region_name="us-west-2") + mocked_boto3_client.return_value.get_object.assert_called_once_with( + Bucket=self.bucket, Key=self.key + ) + + @patch("sagemaker.jumpstart.accessors.boto3.client") + def test_get_object_cached(self, mocked_boto3_client): + + # required due to static class + reload(accessors) + + mocked_boto3_client.return_value = Mock() + + mocked_boto3_client.return_value.get_object.return_value = {"Body": BytesIO(b"s3-object")} + + response = accessors.JumpStartS3Accessor.get_object_cached(bucket=self.bucket, key=self.key) + response = accessors.JumpStartS3Accessor.get_object_cached(bucket=self.bucket, key=self.key) + + self.assertEqual(response, b"s3-object") + + # only a single s3 call should be made when identical requests are made + mocked_boto3_client.assert_called_once_with("s3", region_name="us-west-2") + mocked_boto3_client.return_value.get_object.assert_called_once_with( + Bucket=self.bucket, Key=self.key + ) diff --git a/tests/unit/sagemaker/jumpstart/test_payload_utils.py b/tests/unit/sagemaker/jumpstart/test_payload_utils.py new file mode 100644 index 0000000000..5f78c311ce --- /dev/null +++ 
b/tests/unit/sagemaker/jumpstart/test_payload_utils.py @@ -0,0 +1,63 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import base64 +from unittest import TestCase +from mock.mock import patch + +from sagemaker.jumpstart.payload_utils import PayloadSerializer +from sagemaker.jumpstart.types import JumpStartSerializablePayload + + +class TestPayloadSerializer(TestCase): + + payload_serializer = PayloadSerializer() + + @patch("sagemaker.jumpstart.payload_utils.JumpStartS3Accessor.get_object_cached") + def test_serialize_bytes_payload(self, mock_get_object_cached): + + mock_get_object_cached.return_value = "7897" + payload = JumpStartSerializablePayload( + { + "content_type": "audio/wav", + "body": "$s3", + } + ) + serialized_payload = self.payload_serializer.serialize(payload) + self.assertEqual(serialized_payload, "7897") + + @patch("sagemaker.jumpstart.payload_utils.JumpStartS3Accessor.get_object_cached") + def test_serialize_json_payload(self, mock_get_object_cached): + + mock_get_object_cached.return_value = base64.b64decode("encodedimage") + payload = JumpStartSerializablePayload( + { + "content_type": "application/json", + "body": { + "prompt": "a dog", + "num_images_per_prompt": 2, + "num_inference_steps": 20, + "guidance_scale": 7.5, + "seed": 43, + "eta": 0.7, + "image": "$s3_b64", + }, + } + ) + serialized_payload = self.payload_serializer.serialize(payload) + self.assertEqual( + serialized_payload, + 
'{"prompt": "a dog", "num_images_per_prompt": 2, ' + '"num_inference_steps": 20, "guidance_scale": 7.5, "seed": ' + '43, "eta": 0.7, "image": "encodedimage"}', + ) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 0758b54f29..80002e02bf 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -1,10 +1,13 @@ from __future__ import absolute_import +import base64 +from unittest import mock from unittest.mock import patch from sagemaker.deserializers import JSONDeserializer from sagemaker.jumpstart.enums import MIMEType from sagemaker import predictor +from sagemaker.jumpstart.model import JumpStartModel from sagemaker.jumpstart.utils import verify_model_region_and_return_specs @@ -35,3 +38,46 @@ def test_jumpstart_predictor_support( assert isinstance(js_predictor.deserializer, JSONDeserializer) assert js_predictor.accept == MIMEType.JSON + + +@patch("sagemaker.jumpstart.payload_utils.JumpStartS3Accessor.get_object_cached") +@patch("sagemaker.jumpstart.model.is_valid_model_id") +@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_serializable_payload_with_predictor( + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_is_valid_model_id, + patched_get_object_cached, +): + + patched_get_object_cached.return_value = base64.b64decode("encodedimage") + patched_is_valid_model_id.return_value = True + + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs + patched_get_model_specs.side_effect = get_special_model_spec + + model_id, model_version = "default_payloads", "*" + + js_predictor = predictor.retrieve_default( + endpoint_name="blah", model_id=model_id, model_version=model_version + ) + + default_payload = JumpStartModel( + model_id=model_id, 
model_version=model_version + ).retrieve_default_payload() + + invoke_endpoint_mock = mock.Mock() + + js_predictor.sagemaker_session.sagemaker_runtime_client.invoke_endpoint = invoke_endpoint_mock + js_predictor._handle_response = mock.Mock() + + js_predictor.predict(default_payload) + + invoke_endpoint_mock.assert_called_once_with( + EndpointName="blah", + ContentType="application/json", + Accept="application/json", + Body='{"prompt": "a dog", "num_images_per_prompt": 2, "num_inference_steps": 20, ' + '"guidance_scale": 7.5, "seed": 43, "eta": 0.7, "image": "encodedimage"}', + ) From 6082e592a6c159ae504490b05badc516f737ff4e Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 2 Oct 2023 18:43:29 +0000 Subject: [PATCH 02/13] chore: use lru cache for python compatibility --- src/sagemaker/jumpstart/accessors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 1d3e2abb0d..90af3b40f6 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -41,7 +41,7 @@ def get_sagemaker_version() -> str: class JumpStartS3Accessor(object): """Static class for storing and retrieving auxilliary s3 artifacts.""" - @functools.cache + @functools.lru_cache @staticmethod def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3.client: """Returns default s3 client associated with the region. 
From bf1306523bea6c459b08b6843a9a4530034d751b Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 2 Oct 2023 18:50:36 +0000 Subject: [PATCH 03/13] fix: lru cache annotation --- src/sagemaker/jumpstart/accessors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 90af3b40f6..99462f02a4 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -41,7 +41,7 @@ def get_sagemaker_version() -> str: class JumpStartS3Accessor(object): """Static class for storing and retrieving auxilliary s3 artifacts.""" - @functools.lru_cache + @functools.lru_cache() @staticmethod def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3.client: """Returns default s3 client associated with the region. @@ -50,7 +50,7 @@ def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3 """ return boto3.client("s3", region_name=region) - @functools.lru_cache + @functools.lru_cache() @staticmethod def get_object_cached( bucket: str, From 9750dc59ae429ccf4795948e0ca3d699eb4ac2aa Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 2 Oct 2023 19:13:08 +0000 Subject: [PATCH 04/13] fix: switch order of annotations --- src/sagemaker/jumpstart/accessors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 99462f02a4..12c5cb7c4a 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -41,8 +41,8 @@ def get_sagemaker_version() -> str: class JumpStartS3Accessor(object): """Static class for storing and retrieving auxilliary s3 artifacts.""" - @functools.lru_cache() @staticmethod + @functools.lru_cache() def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3.client: """Returns default s3 client associated with the region. 
@@ -50,8 +50,8 @@ def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3 """ return boto3.client("s3", region_name=region) - @functools.lru_cache() @staticmethod + @functools.lru_cache() def get_object_cached( bucket: str, key: str, From 5dc4aac637d42317f3dfed08ebe8ca0a2b8394e2 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 2 Oct 2023 19:35:30 +0000 Subject: [PATCH 05/13] fix: docstring, add integ test --- src/sagemaker/payloads.py | 4 +-- .../jumpstart/model/test_jumpstart_model.py | 25 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index fce0b86d71..9645318720 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -45,8 +45,8 @@ def retrieve_samples( model_version (str): The version of the JumpStart model for which to retrieve the model payloads. serialize (bool): Whether to serialize byte-stream valued payloads by downloading - binary files from s3 and applying encoding, or to keep payload in pre-serialized . - state. Set this option to False if you only want to avoid s3 download of it you + binary files from s3 and applying encoding, or to keep payload in pre-serialized + state. Set this option to False if you want to avoid s3 downloads or if you want to inspect the payload in a human-readable form. (Default: True). tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications should be tolerated without raising an exception. 
If ``False``, raises an diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index bf39805897..a646472c6d 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -86,6 +86,31 @@ def test_prepacked_jumpstart_model(setup): assert response is not None +def test_default_payload_jumpstart_model(setup): + + ## DO NOT COMMIT THIS LINE + os.environ.update({"AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE": "jumpstart-cache-alpha-us-west-2"}) + + model_id = "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16" + + model = JumpStartModel( + model_id=model_id, + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + default_payload = model.retrieve_default_payload() + + # uses ml.g5.8xlarge instance + predictor = model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + response = predictor.predict(default_payload) + + assert response is not None + + @pytest.mark.skipif( tests.integ.test_region() not in GATED_INFERENCE_MODEL_SUPPORTED_REGIONS, reason=f"JumpStart gated inference models unavailable in {tests.integ.test_region()}.", From ca20d872ac0742ba5f194aace9d0448cd991eabe Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 2 Oct 2023 20:30:43 +0000 Subject: [PATCH 06/13] fix: flake8 --- tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index a646472c6d..b8aa4c7886 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -88,7 +88,7 @@ def test_prepacked_jumpstart_model(setup): def 
test_default_payload_jumpstart_model(setup): - ## DO NOT COMMIT THIS LINE + # DO NOT COMMIT THIS LINE os.environ.update({"AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE": "jumpstart-cache-alpha-us-west-2"}) model_id = "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16" From b2a637464cc2a1de0b12252786a0bc5dd53d6383 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 3 Oct 2023 21:20:14 +0000 Subject: [PATCH 07/13] fix: incorrect words --- src/sagemaker/jumpstart/types.py | 2 +- src/sagemaker/payloads.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 1c2a7226e4..3a048fd826 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -361,7 +361,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: self.accept = accept def to_json(self) -> Dict[str, Any]: - """Returns json representation of JumpStartDefaultPayloads object.""" + """Returns json representation of JumpStartSerializablePayload object.""" return deepcopy(self.raw_payload) diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index 9645318720..def2303275 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -71,7 +71,7 @@ def retrieve_samples( """ if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): raise ValueError( - "Must specify JumpStart `model_id` and `model_version` when retrieving model URIs." + "Must specify JumpStart `model_id` and `model_version` when retrieving payloads." 
) unserialized_payload_dict: Optional[ From 214e458c5600fa384c374df8d4108e60d6289c12 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 4 Oct 2023 15:42:40 +0000 Subject: [PATCH 08/13] chore: address PR comments --- src/sagemaker/base_predictor.py | 26 +++++++++-------- src/sagemaker/jumpstart/accessors.py | 17 +++++++---- src/sagemaker/jumpstart/artifacts/payloads.py | 2 +- src/sagemaker/jumpstart/model.py | 12 ++++---- src/sagemaker/jumpstart/payload_utils.py | 29 ++++++++++++------- src/sagemaker/jumpstart/types.py | 3 ++ src/sagemaker/payloads.py | 4 +-- .../jumpstart/model/test_jumpstart_model.py | 25 ---------------- 8 files changed, 56 insertions(+), 62 deletions(-) diff --git a/src/sagemaker/base_predictor.py b/src/sagemaker/base_predictor.py index f02c1d70fd..49106f02c0 100644 --- a/src/sagemaker/base_predictor.py +++ b/src/sagemaker/base_predictor.py @@ -14,7 +14,7 @@ from __future__ import print_function, absolute_import import abc -from typing import Any, Tuple +from typing import Any, Optional, Tuple, Union from sagemaker.deprecations import ( deprecated_class, @@ -205,18 +205,20 @@ def _create_request_args( ): """Placeholder docstring""" - js_accept = None + jumpstart_serialized_data: Optional[Union[str, bytes]] = None + jumpstart_accept: Optional[str] = None + jumpstart_content_type: Optional[str] = None if isinstance(data, JumpStartSerializablePayload): s3_client = self.sagemaker_session.s3_client region = self.sagemaker_session._region_name bucket = get_jumpstart_content_bucket(region) - js_serialized_data = PayloadSerializer( + jumpstart_serialized_data = PayloadSerializer( bucket=bucket, region=region, s3_client=s3_client ).serialize(data) - js_content_type = data.content_type - js_accept = data.accept + jumpstart_content_type = data.content_type + jumpstart_accept = data.accept args = dict(initial_args) if initial_args else {} @@ -224,8 +226,8 @@ def _create_request_args( args["EndpointName"] = self.endpoint_name if "ContentType" not in 
args: - if isinstance(data, JumpStartSerializablePayload): - args["ContentType"] = js_content_type + if isinstance(data, JumpStartSerializablePayload) and jumpstart_content_type: + args["ContentType"] = jumpstart_content_type else: args["ContentType"] = ( self.content_type @@ -234,8 +236,8 @@ def _create_request_args( ) if "Accept" not in args: - if isinstance(data, JumpStartSerializablePayload) and js_accept: - args["Accept"] = js_accept + if isinstance(data, JumpStartSerializablePayload) and jumpstart_accept: + args["Accept"] = jumpstart_accept else: args["Accept"] = ( self.accept if isinstance(self.accept, str) else ", ".join(self.accept) @@ -254,9 +256,9 @@ def _create_request_args( args["CustomAttributes"] = custom_attributes data = ( - self.serializer.serialize(data) - if not isinstance(data, JumpStartSerializablePayload) - else js_serialized_data + jumpstart_serialized_data + if isinstance(data, JumpStartSerializablePayload) and jumpstart_serialized_data + else self.serializer.serialize(data) ) args["Body"] = data diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 12c5cb7c4a..4b626fda69 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -39,12 +39,19 @@ def get_sagemaker_version() -> str: class JumpStartS3Accessor(object): - """Static class for storing and retrieving auxilliary s3 artifacts.""" + """Static class for storing and retrieving auxilliary S3 artifacts.""" + + @staticmethod + def clear_cache() -> None: + """Clears LRU caches associated with S3 client and retrieved objects.""" + + JumpStartS3Accessor._get_default_s3_client.cache_clear() + JumpStartS3Accessor.get_object_cached.cache_clear() @staticmethod @functools.lru_cache() def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3.client: - """Returns default s3 client associated with the region. + """Returns default S3 client associated with the region. 
Result is cached so multiple clients in memory are not created. """ @@ -58,9 +65,9 @@ def get_object_cached( region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None, ) -> bytes: - """Returns s3 object located at the bucket and key. + """Returns S3 object located at the bucket and key. - Requests are cached so that the same s3 request is never made more + Requests are cached so that the same S3 request is never made more than once, unless a different region or client is used. """ return JumpStartS3Accessor.get_object( @@ -74,7 +81,7 @@ def get_object( region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None, ) -> bytes: - """Returns s3 object located at the bucket and key.""" + """Returns S3 object located at the bucket and key.""" if s3_client is None: s3_client = JumpStartS3Accessor._get_default_s3_client(region) diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index ee573859c3..63d0390295 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
-"""This module contains functions for obtaining JumpStart payloads.""" +"""This module contains functions for obtaining example payloads for JumpStart models.""" from __future__ import absolute_import from copy import deepcopy from typing import Dict, Optional diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 9c9a08654d..81dc1b2bd9 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -315,11 +315,11 @@ def _is_valid_model_id_hook(): super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) def retrieve_default_payload(self) -> JumpStartSerializablePayload: - """Returns default payload associated with the model. + """Returns the default payload associated with the model. Payload can be directly used with the `sagemaker.predictor.Predictor.predict(...)` function. """ - sample_payloads: Optional[List[JumpStartSerializablePayload]] = payloads.retrieve_samples( + payload_options: Optional[List[JumpStartSerializablePayload]] = payloads.retrieve_options( model_id=self.model_id, model_version=self.model_version, region=self.region, @@ -328,12 +328,10 @@ def retrieve_default_payload(self) -> JumpStartSerializablePayload: sagemaker_session=self.sagemaker_session, ) - if sample_payloads is None or len(sample_payloads) == 0: - raise NotImplementedError( - f"No default payload supported for model ID '{self.model_id}'." - ) + if payload_options is None or len(payload_options) == 0: + return None - return sample_payloads[0] + return payload_options[0] def _create_sagemaker_model( self, diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py index 31b784f409..67ba1dfb0b 100644 --- a/src/sagemaker/jumpstart/payload_utils.py +++ b/src/sagemaker/jumpstart/payload_utils.py @@ -10,11 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. 
See the License for the specific # language governing permissions and limitations under the License. -"""This module stores payload utilities for SageMaker JumpStart.""" +"""This module stores stores inference payload utilities for JumpStart models.""" from __future__ import absolute_import import base64 import json -from typing import Any, Optional, Union +from typing import Optional, Union import re import boto3 @@ -50,14 +50,14 @@ def get_bytes_payload_with_s3_references( self, payload_str: str, ) -> bytes: - """Returns bytes object corresponding to referenced s3 object. + """Returns bytes object corresponding to referenced S3 object. Raises: ValueError: If the raw bytes payload is not formatted correctly. """ s3_keys = re.compile(S3_BYTES_REGEX).findall(payload_str) if len(s3_keys) != 1: - raise ValueError(f"Invalid bytes payload: {payload_str}") + raise ValueError("Invalid bytes payload.") s3_key = s3_keys[0] serialized_s3_object = JumpStartS3Accessor.get_object_cached( @@ -70,7 +70,10 @@ def embed_s3_references_in_str_payload( self, payload: str, ) -> str: - """Embeds s3 references in string payloads.""" + """Inserts serialized S3 content into string payload. + + If no S3 content is embedded in payload, original string is returned. + """ return self._embed_s3_b64_references_in_str_payload(payload_body=payload) def _embed_s3_b64_references_in_str_payload( @@ -98,10 +101,12 @@ def _embed_s3_b64_references_in_str_payload( def embed_s3_references_in_json_payload( self, payload_body: Union[list, dict, str, int, float] ) -> Union[list, dict, str, int, float]: - """Finds all s3 references in payload and embeds serialized s3 data. + """Finds all S3 references in payload and embeds serialized S3 data. - S3 bucket is assumed to be the default JumpStart content bucket. If no s3 references - are found, the payload is returned un-modified. + If no S3 references are found, the payload is returned un-modified. 
+ + Raises: + ValueError: If the payload has an unrecognized type. """ if isinstance(payload_body, str): return self.embed_s3_references_in_str_payload(payload_body) @@ -116,8 +121,12 @@ def embed_s3_references_in_json_payload( } raise ValueError(f"Payload has unrecognized type: {type(payload_body)}") - def serialize(self, payload: JumpStartSerializablePayload) -> Any: - """Returns payload bytes that can be inputted to inference endpoint.""" + def serialize(self, payload: JumpStartSerializablePayload) -> Union[str, bytes]: + """Returns payload string or bytes that can be inputted to inference endpoint. + + Raises: + ValueError: If the payload has an unrecognized type. + """ content_type = MIMEType.from_suffixed_type(payload.content_type) body = payload.body diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 3a048fd826..a6863687e7 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -348,6 +348,9 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of serializable payload specs. + + Raises: + KeyError: If the dictionary is missing keys. 
""" if json_obj is None: diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index def2303275..0878a86bad 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) -def retrieve_samples( +def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, @@ -78,7 +78,7 @@ def retrieve_samples( Dict[str, JumpStartSerializablePayload] ] = artifacts._retrieve_default_payloads( model_id, - model_version, # type: ignore + model_version, region, tolerate_vulnerable_model, tolerate_deprecated_model, diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index b8aa4c7886..bf39805897 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -86,31 +86,6 @@ def test_prepacked_jumpstart_model(setup): assert response is not None -def test_default_payload_jumpstart_model(setup): - - # DO NOT COMMIT THIS LINE - os.environ.update({"AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE": "jumpstart-cache-alpha-us-west-2"}) - - model_id = "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16" - - model = JumpStartModel( - model_id=model_id, - role=get_sm_session().get_caller_identity_arn(), - sagemaker_session=get_sm_session(), - ) - - default_payload = model.retrieve_default_payload() - - # uses ml.g5.8xlarge instance - predictor = model.deploy( - tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], - ) - - response = predictor.predict(default_payload) - - assert response is not None - - @pytest.mark.skipif( tests.integ.test_region() not in GATED_INFERENCE_MODEL_SUPPORTED_REGIONS, reason=f"JumpStart gated inference models unavailable in {tests.integ.test_region()}.", From aa3163e839fc56804bd1bb00d4b685357dad2634 Mon Sep 17 00:00:00 2001 From: Evan 
Kravitz Date: Wed, 4 Oct 2023 23:15:36 +0000 Subject: [PATCH 09/13] chore: address PR comments on payload caching --- src/sagemaker/jumpstart/accessors.py | 48 +++++++++++++++---- src/sagemaker/jumpstart/payload_utils.py | 6 +-- src/sagemaker/payloads.py | 4 +- .../sagemaker/jumpstart/test_accessors.py | 44 +++++++++++++++-- .../sagemaker/jumpstart/test_payload_utils.py | 4 +- .../sagemaker/jumpstart/test_predictor.py | 2 +- 6 files changed, 89 insertions(+), 19 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 4b626fda69..88b77b8560 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -38,15 +38,20 @@ def get_sagemaker_version() -> str: return SageMakerSettings._parsed_sagemaker_version -class JumpStartS3Accessor(object): - """Static class for storing and retrieving auxilliary S3 artifacts.""" +class JumpStartS3PayloadAccessor(object): + """Static class for storing and retrieving S3 payload artifacts.""" + + MAX_CACHE_SIZE_BYTES = int(100 * 1e6) + MAX_PAYLOAD_SIZE_BYTES = int(6 * 1e6) + + CACHE_SIZE = MAX_CACHE_SIZE_BYTES // MAX_PAYLOAD_SIZE_BYTES @staticmethod def clear_cache() -> None: """Clears LRU caches associated with S3 client and retrieved objects.""" - JumpStartS3Accessor._get_default_s3_client.cache_clear() - JumpStartS3Accessor.get_object_cached.cache_clear() + JumpStartS3PayloadAccessor._get_default_s3_client.cache_clear() + JumpStartS3PayloadAccessor.get_object_cached.cache_clear() @staticmethod @functools.lru_cache() @@ -58,7 +63,7 @@ def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3 return boto3.client("s3", region_name=region) @staticmethod - @functools.lru_cache() + @functools.lru_cache(maxsize=CACHE_SIZE) def get_object_cached( bucket: str, key: str, @@ -70,10 +75,23 @@ def get_object_cached( Requests are cached so that the same S3 request is never made more than once, unless a different region or client is used. 
""" - return JumpStartS3Accessor.get_object( + return JumpStartS3PayloadAccessor.get_object( bucket=bucket, key=key, region=region, s3_client=s3_client ) + @staticmethod + def _get_object_size_bytes( + bucket: str, + key: str, + region: str = JUMPSTART_DEFAULT_REGION_NAME, + s3_client: Optional[boto3.client] = None, + ) -> bytes: + """Returns size in bytes of S3 object using S3.HeadObject operation.""" + if s3_client is None: + s3_client = JumpStartS3PayloadAccessor._get_default_s3_client(region) + + return s3_client.head_object(Bucket=bucket, Key=key)["ContentLength"] + @staticmethod def get_object( bucket: str, @@ -81,9 +99,23 @@ def get_object( region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None, ) -> bytes: - """Returns S3 object located at the bucket and key.""" + """Returns S3 object located at the bucket and key. + + Raises: + ValueError: The object size is too large. + """ if s3_client is None: - s3_client = JumpStartS3Accessor._get_default_s3_client(region) + s3_client = JumpStartS3PayloadAccessor._get_default_s3_client(region) + + object_size_bytes = JumpStartS3PayloadAccessor._get_object_size_bytes( + bucket=bucket, key=key, region=region, s3_client=s3_client + ) + if object_size_bytes > JumpStartS3PayloadAccessor.MAX_PAYLOAD_SIZE_BYTES: + raise ValueError( + f"s3://{bucket}/{key} has size of {object_size_bytes} bytes, " + "which exceeds maximum allowed size of " + f"{JumpStartS3PayloadAccessor.MAX_PAYLOAD_SIZE_BYTES} bytes." 
+ ) return s3_client.get_object(Bucket=bucket, Key=key)["Body"].read() diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py index 67ba1dfb0b..60232271ed 100644 --- a/src/sagemaker/jumpstart/payload_utils.py +++ b/src/sagemaker/jumpstart/payload_utils.py @@ -18,7 +18,7 @@ import re import boto3 -from sagemaker.jumpstart.accessors import JumpStartS3Accessor +from sagemaker.jumpstart.accessors import JumpStartS3PayloadAccessor from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.enums import MIMEType from sagemaker.jumpstart.types import JumpStartSerializablePayload @@ -60,7 +60,7 @@ def get_bytes_payload_with_s3_references( raise ValueError("Invalid bytes payload.") s3_key = s3_keys[0] - serialized_s3_object = JumpStartS3Accessor.get_object_cached( + serialized_s3_object = JumpStartS3PayloadAccessor.get_object_cached( bucket=self.bucket, key=s3_key, region=self.region, s3_client=self.s3_client ) @@ -90,7 +90,7 @@ def _embed_s3_b64_references_in_str_payload( for s3_key in s3_keys: b64_encoded_string = base64.b64encode( bytearray( - JumpStartS3Accessor.get_object_cached( + JumpStartS3PayloadAccessor.get_object_cached( bucket=self.bucket, key=s3_key, region=self.region, s3_client=self.s3_client ) ) diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index 0878a86bad..20962af18d 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -31,7 +31,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, - serialize: bool = True, + serialize: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -47,7 +47,7 @@ def retrieve_options( serialize (bool): Whether to serialize byte-stream valued payloads by downloading binary files from s3 and applying encoding, or to keep payload in 
pre-serialized state. Set this option to False if you want to avoid s3 downloads or if you - want to inspect the payload in a human-readable form. (Default: True). + want to inspect the payload in a human-readable form. (Default: False). tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications should be tolerated without raising an exception. If ``False``, raises an exception if the script used by this version of the model has dependencies with known diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 31a76701a1..97427be1ae 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -152,8 +152,9 @@ def test_get_object(self, mocked_boto3_client): mocked_boto3_client.return_value = Mock() mocked_boto3_client.return_value.get_object.return_value = {"Body": BytesIO(b"s3-object")} + mocked_boto3_client.return_value.head_object.return_value = {"ContentLength": 1} - response = accessors.JumpStartS3Accessor.get_object(bucket=self.bucket, key=self.key) + response = accessors.JumpStartS3PayloadAccessor.get_object(bucket=self.bucket, key=self.key) self.assertEqual(response, b"s3-object") @@ -161,6 +162,9 @@ def test_get_object(self, mocked_boto3_client): mocked_boto3_client.return_value.get_object.assert_called_once_with( Bucket=self.bucket, Key=self.key ) + mocked_boto3_client.return_value.head_object.assert_called_once_with( + Bucket=self.bucket, Key=self.key + ) @patch("sagemaker.jumpstart.accessors.boto3.client") def test_get_object_cached(self, mocked_boto3_client): @@ -171,9 +175,14 @@ def test_get_object_cached(self, mocked_boto3_client): mocked_boto3_client.return_value = Mock() mocked_boto3_client.return_value.get_object.return_value = {"Body": BytesIO(b"s3-object")} + mocked_boto3_client.return_value.head_object.return_value = {"ContentLength": 1} - response = 
accessors.JumpStartS3Accessor.get_object_cached(bucket=self.bucket, key=self.key) - response = accessors.JumpStartS3Accessor.get_object_cached(bucket=self.bucket, key=self.key) + response = accessors.JumpStartS3PayloadAccessor.get_object_cached( + bucket=self.bucket, key=self.key + ) + response = accessors.JumpStartS3PayloadAccessor.get_object_cached( + bucket=self.bucket, key=self.key + ) self.assertEqual(response, b"s3-object") @@ -182,3 +191,32 @@ def test_get_object_cached(self, mocked_boto3_client): mocked_boto3_client.return_value.get_object.assert_called_once_with( Bucket=self.bucket, Key=self.key ) + mocked_boto3_client.return_value.head_object.assert_called_once_with( + Bucket=self.bucket, Key=self.key + ) + + @patch("sagemaker.jumpstart.accessors.boto3.client") + def test_get_object_limit_exceeded(self, mocked_boto3_client): + + # required due to static class + reload(accessors) + + mocked_boto3_client.return_value = Mock() + + mocked_boto3_client.return_value.get_object.return_value = {"Body": BytesIO(b"s3-object")} + mocked_boto3_client.return_value.head_object.return_value = {"ContentLength": 1e99} + + with pytest.raises(ValueError) as e: + accessors.JumpStartS3PayloadAccessor.get_object(bucket=self.bucket, key=self.key) + + self.assertEqual( + str(e.value), + "s3://bucket/key has size of 1e+99 bytes, which " + "exceeds maximum allowed size of 6000000 bytes.", + ) + + mocked_boto3_client.assert_called_once_with("s3", region_name="us-west-2") + mocked_boto3_client.return_value.get_object.assert_not_called() + mocked_boto3_client.return_value.head_object.assert_called_once_with( + Bucket=self.bucket, Key=self.key + ) diff --git a/tests/unit/sagemaker/jumpstart/test_payload_utils.py b/tests/unit/sagemaker/jumpstart/test_payload_utils.py index 5f78c311ce..687c9154df 100644 --- a/tests/unit/sagemaker/jumpstart/test_payload_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_payload_utils.py @@ -23,7 +23,7 @@ class TestPayloadSerializer(TestCase): 
payload_serializer = PayloadSerializer() - @patch("sagemaker.jumpstart.payload_utils.JumpStartS3Accessor.get_object_cached") + @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") def test_serialize_bytes_payload(self, mock_get_object_cached): mock_get_object_cached.return_value = "7897" @@ -36,7 +36,7 @@ def test_serialize_bytes_payload(self, mock_get_object_cached): serialized_payload = self.payload_serializer.serialize(payload) self.assertEqual(serialized_payload, "7897") - @patch("sagemaker.jumpstart.payload_utils.JumpStartS3Accessor.get_object_cached") + @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") def test_serialize_json_payload(self, mock_get_object_cached): mock_get_object_cached.return_value = base64.b64decode("encodedimage") diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 80002e02bf..fd3d213650 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -40,7 +40,7 @@ def test_jumpstart_predictor_support( assert js_predictor.accept == MIMEType.JSON -@patch("sagemaker.jumpstart.payload_utils.JumpStartS3Accessor.get_object_cached") +@patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") @patch("sagemaker.jumpstart.model.is_valid_model_id") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") From 9bfb96202e5724e02fd98cec9077b458ba4acab5 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 5 Oct 2023 14:07:51 +0000 Subject: [PATCH 10/13] chore: update unit test to check for payload value --- tests/unit/sagemaker/jumpstart/test_predictor.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 
fd3d213650..9182d2c80d 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -72,6 +72,13 @@ def test_jumpstart_serializable_payload_with_predictor( js_predictor.sagemaker_session.sagemaker_runtime_client.invoke_endpoint = invoke_endpoint_mock js_predictor._handle_response = mock.Mock() + assert str(default_payload) == ( + "JumpStartSerializablePayload: {'content_type': 'application/json', 'accept': 'application/json'" + ", 'body': {'prompt': 'a dog', 'num_images_per_prompt': 2, 'num_inference_steps':" + " 20, 'guidance_scale': 7.5, 'seed': 43, 'eta': 0.7, 'image':" + " '$s3_b64'}}" + ) + js_predictor.predict(default_payload) invoke_endpoint_mock.assert_called_once_with( From 89980b9baf8306c847c58837252f98095f009670 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 6 Oct 2023 16:08:35 +0000 Subject: [PATCH 11/13] chore: address git comments --- src/sagemaker/jumpstart/model.py | 14 +++--- src/sagemaker/jumpstart/payload_utils.py | 2 +- src/sagemaker/payloads.py | 62 +++++++++++++++++++++++- 3 files changed, 69 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 81dc1b2bd9..0f0d504541 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -318,8 +318,15 @@ def retrieve_default_payload(self) -> JumpStartSerializablePayload: """Returns the default payload associated with the model. Payload can be directly used with the `sagemaker.predictor.Predictor.predict(...)` function. + + Raises: + NotImplementedError: If the scope is not supported. + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. 
""" - payload_options: Optional[List[JumpStartSerializablePayload]] = payloads.retrieve_options( + return payloads.retrieve_example( model_id=self.model_id, model_version=self.model_version, region=self.region, @@ -328,11 +335,6 @@ def retrieve_default_payload(self) -> JumpStartSerializablePayload: sagemaker_session=self.sagemaker_session, ) - if payload_options is None or len(payload_options) == 0: - return None - - return payload_options[0] - def _create_sagemaker_model( self, instance_type=None, diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py index 60232271ed..4aa3bafb08 100644 --- a/src/sagemaker/jumpstart/payload_utils.py +++ b/src/sagemaker/jumpstart/payload_utils.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""This module stores stores inference payload utilities for JumpStart models.""" +"""This module stores inference payload utilities for JumpStart models.""" from __future__ import absolute_import import base64 import json diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index 20962af18d..834c4a543e 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) -def retrieve_options( +def retrieve_all_examples( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, @@ -60,7 +60,7 @@ def retrieve_options( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: - str: The model artifact S3 URI for the corresponding model. + Optional[List[str]]: List of payloads or None. Raises: NotImplementedError: If the scope is not supported. 
@@ -114,3 +114,61 @@ def retrieve_options( ) return serialized_payloads + + +def retrieve_example( + region: Optional[str] = None, + model_id: Optional[str] = None, + model_version: Optional[str] = None, + serialize: bool = False, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> Optional[str]: + """Retrieves a single compatible payload for the model matching the given arguments. + + Args: + region (str): The AWS Region for which to retrieve the Jumpstart model payloads. + model_id (str): The model ID of the JumpStart model for which to retrieve + the model payload. + model_version (str): The version of the JumpStart model for which to retrieve + the model payload. + serialize (bool): Whether to serialize byte-stream valued payloads by downloading + binary files from s3 and applying encoding, or to keep payload in pre-serialized + state. Set this option to False if you want to avoid s3 downloads or if you + want to inspect the payload in a human-readable form. (Default: False). + tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model + specifications should be tolerated without raising an exception. If ``False``, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): ``True`` if deprecated versions of model + specifications should be tolerated without raising an exception. If ``False``, raises + an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + Returns: + Optional[str]: A single default payload or None. 
+ + Raises: + NotImplementedError: If the scope is not supported. + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. + """ + example_payloads: Optional[List[str]] = retrieve_all_examples( + region=region, + model_id=model_id, + model_version=model_version, + serialize=serialize, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, + ) + + if example_payloads is None or len(example_payloads) == 0: + return None + + return example_payloads[0] From fc900e7df8cde4700598c33e0624fd3836655305 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 9 Oct 2023 13:43:19 +0000 Subject: [PATCH 12/13] chore: improve docstring --- src/sagemaker/jumpstart/artifacts/payloads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 63d0390295..20914a0645 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
-"""This module contains functions for obtaining example payloads for JumpStart models.""" +"""This module contains functions to obtain JumpStart model payloads.""" from __future__ import absolute_import from copy import deepcopy from typing import Dict, Optional From e5a7058a22d4608f52d3419ecdec3ef23f986f66 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 10 Oct 2023 16:55:45 +0000 Subject: [PATCH 13/13] chore: address comments and fix typing/docstrings --- src/sagemaker/jumpstart/artifacts/__init__.py | 2 +- src/sagemaker/jumpstart/artifacts/payloads.py | 13 ++++++----- src/sagemaker/jumpstart/model.py | 23 +++++++++++++++++-- src/sagemaker/payloads.py | 12 +++++----- .../sagemaker/jumpstart/test_predictor.py | 2 +- 5 files changed, 36 insertions(+), 16 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/__init__.py b/src/sagemaker/jumpstart/artifacts/__init__.py index f57bfcbf46..4393a15402 100644 --- a/src/sagemaker/jumpstart/artifacts/__init__.py +++ b/src/sagemaker/jumpstart/artifacts/__init__.py @@ -62,5 +62,5 @@ _retrieve_model_package_model_artifact_s3_uri, ) from sagemaker.jumpstart.artifacts.payloads import ( # noqa: F401 - _retrieve_default_payloads, + _retrieve_example_payloads, ) diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 20914a0645..3ea2c16f80 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -28,7 +28,7 @@ from sagemaker.session import Session -def _retrieve_default_payloads( +def _retrieve_example_payloads( model_id: str, model_version: str, region: Optional[str], @@ -36,15 +36,15 @@ def _retrieve_default_payloads( tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Optional[Dict[str, JumpStartSerializablePayload]]: - """Returns default payloads. + """Returns example payloads. 
Args: model_id (str): JumpStart model ID of the JumpStart model for which to - get default payloads. + get example payloads. model_version (str): Version of the JumpStart model for which to retrieve the - default resource name. + example payloads. region (Optional[str]): Region for which to retrieve the - default resource name. + example payloads. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -57,7 +57,8 @@ def _retrieve_default_payloads( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: - str: the default payload. + Optional[Dict[str, JumpStartSerializablePayload]]: dictionary mapping payload aliases + to the serializable payload object. """ if region is None: diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 0f0d504541..95a4bb3b99 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -314,8 +314,27 @@ def _is_valid_model_id_hook(): super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) - def retrieve_default_payload(self) -> JumpStartSerializablePayload: - """Returns the default payload associated with the model. + def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: + """Returns all example payloads associated with the model. + + Raises: + NotImplementedError: If the scope is not supported. + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. 
+ """ + return payloads.retrieve_all_examples( + model_id=self.model_id, + model_version=self.model_version, + region=self.region, + tolerate_deprecated_model=self.tolerate_deprecated_model, + tolerate_vulnerable_model=self.tolerate_vulnerable_model, + sagemaker_session=self.sagemaker_session, + ) + + def retrieve_example_payload(self) -> JumpStartSerializablePayload: + """Returns the example payload associated with the model. Payload can be directly used with the `sagemaker.predictor.Predictor.predict(...)` function. diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index 834c4a543e..52d633ed4e 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -35,7 +35,7 @@ def retrieve_all_examples( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Optional[List[str]]: +) -> Optional[List[JumpStartSerializablePayload]]: """Retrieves the compatible payloads for the model matching the given arguments. Args: @@ -60,7 +60,7 @@ def retrieve_all_examples( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: - Optional[List[str]]: List of payloads or None. + Optional[List[JumpStartSerializablePayload]]: List of payloads or None. Raises: NotImplementedError: If the scope is not supported. 
@@ -76,7 +76,7 @@ def retrieve_all_examples( unserialized_payload_dict: Optional[ Dict[str, JumpStartSerializablePayload] - ] = artifacts._retrieve_default_payloads( + ] = artifacts._retrieve_example_payloads( model_id, model_version, region, @@ -124,7 +124,7 @@ def retrieve_example( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Optional[str]: +) -> Optional[JumpStartSerializablePayload]: """Retrieves a single compatible payload for the model matching the given arguments. Args: @@ -149,7 +149,7 @@ def retrieve_example( specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: - Optional[str]: A single default payload or None. + Optional[JumpStartSerializablePayload]: A single default payload or None. Raises: NotImplementedError: If the scope is not supported. @@ -158,7 +158,7 @@ def retrieve_example( known security vulnerabilities. DeprecatedJumpStartModelError: If the version of the model is deprecated. """ - example_payloads: Optional[List[str]] = retrieve_all_examples( + example_payloads: Optional[List[JumpStartSerializablePayload]] = retrieve_all_examples( region=region, model_id=model_id, model_version=model_version, diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 9182d2c80d..4c2cd5b123 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -65,7 +65,7 @@ def test_jumpstart_serializable_payload_with_predictor( default_payload = JumpStartModel( model_id=model_id, model_version=model_version - ).retrieve_default_payload() + ).retrieve_example_payload() invoke_endpoint_mock = mock.Mock()