diff --git a/src/sagemaker/base_predictor.py b/src/sagemaker/base_predictor.py index 46983e0983..49106f02c0 100644 --- a/src/sagemaker/base_predictor.py +++ b/src/sagemaker/base_predictor.py @@ -14,7 +14,7 @@ from __future__ import print_function, absolute_import import abc -from typing import Any, Tuple +from typing import Any, Optional, Tuple, Union from sagemaker.deprecations import ( deprecated_class, @@ -32,6 +32,9 @@ StreamDeserializer, StringDeserializer, ) +from sagemaker.jumpstart.payload_utils import PayloadSerializer +from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket from sagemaker.model_monitor import ( DataCaptureConfig, DefaultModelMonitor, @@ -201,20 +204,44 @@ def _create_request_args( custom_attributes=None, ): """Placeholder docstring""" + + jumpstart_serialized_data: Optional[Union[str, bytes]] = None + jumpstart_accept: Optional[str] = None + jumpstart_content_type: Optional[str] = None + + if isinstance(data, JumpStartSerializablePayload): + s3_client = self.sagemaker_session.s3_client + region = self.sagemaker_session._region_name + bucket = get_jumpstart_content_bucket(region) + + jumpstart_serialized_data = PayloadSerializer( + bucket=bucket, region=region, s3_client=s3_client + ).serialize(data) + jumpstart_content_type = data.content_type + jumpstart_accept = data.accept + args = dict(initial_args) if initial_args else {} if "EndpointName" not in args: args["EndpointName"] = self.endpoint_name if "ContentType" not in args: - args["ContentType"] = ( - self.content_type - if isinstance(self.content_type, str) - else ", ".join(self.content_type) - ) + if isinstance(data, JumpStartSerializablePayload) and jumpstart_content_type: + args["ContentType"] = jumpstart_content_type + else: + args["ContentType"] = ( + self.content_type + if isinstance(self.content_type, str) + else ", ".join(self.content_type) + ) if "Accept" not in args: - args["Accept"] = self.accept if isinstance(self.accept, str) else ", ".join(self.accept) + if isinstance(data, JumpStartSerializablePayload) and jumpstart_accept: + args["Accept"] = jumpstart_accept + else: + args["Accept"] = ( + self.accept if isinstance(self.accept, str) else ", ".join(self.accept) + ) if target_model: args["TargetModel"] = target_model @@ -228,7 +255,11 @@ def _create_request_args( if custom_attributes: args["CustomAttributes"] = custom_attributes - data = self.serializer.serialize(data) + data = ( + jumpstart_serialized_data + if isinstance(data, JumpStartSerializablePayload) and jumpstart_serialized_data + else self.serializer.serialize(data) + ) args["Body"] = data return args diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 8117606299..88b77b8560 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. """This module contains accessors related to SageMaker JumpStart.""" from __future__ import absolute_import +import functools from typing import Any, Dict, List, Optional import boto3 @@ -37,6 +38,88 @@ def get_sagemaker_version() -> str: return SageMakerSettings._parsed_sagemaker_version +class JumpStartS3PayloadAccessor(object): + """Static class for storing and retrieving S3 payload artifacts.""" + + MAX_CACHE_SIZE_BYTES = int(100 * 1e6) + MAX_PAYLOAD_SIZE_BYTES = int(6 * 1e6) + + CACHE_SIZE = MAX_CACHE_SIZE_BYTES // MAX_PAYLOAD_SIZE_BYTES + + @staticmethod + def clear_cache() -> None: + """Clears LRU caches associated with S3 client and retrieved objects.""" + + JumpStartS3PayloadAccessor._get_default_s3_client.cache_clear() + JumpStartS3PayloadAccessor.get_object_cached.cache_clear() + + @staticmethod + @functools.lru_cache() + def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3.client: + """Returns default S3 client associated with the region. + + Result is cached so multiple clients in memory are not created. + """ + return boto3.client("s3", region_name=region) + + @staticmethod + @functools.lru_cache(maxsize=CACHE_SIZE) + def get_object_cached( + bucket: str, + key: str, + region: str = JUMPSTART_DEFAULT_REGION_NAME, + s3_client: Optional[boto3.client] = None, + ) -> bytes: + """Returns S3 object located at the bucket and key. + + Requests are cached so that the same S3 request is never made more + than once, unless a different region or client is used. + """ + return JumpStartS3PayloadAccessor.get_object( + bucket=bucket, key=key, region=region, s3_client=s3_client + ) + + @staticmethod + def _get_object_size_bytes( + bucket: str, + key: str, + region: str = JUMPSTART_DEFAULT_REGION_NAME, + s3_client: Optional[boto3.client] = None, + ) -> bytes: + """Returns size in bytes of S3 object using S3.HeadObject operation.""" + if s3_client is None: + s3_client = JumpStartS3PayloadAccessor._get_default_s3_client(region) + + return s3_client.head_object(Bucket=bucket, Key=key)["ContentLength"] + + @staticmethod + def get_object( + bucket: str, + key: str, + region: str = JUMPSTART_DEFAULT_REGION_NAME, + s3_client: Optional[boto3.client] = None, + ) -> bytes: + """Returns S3 object located at the bucket and key. + + Raises: + ValueError: The object size is too large. + """ + if s3_client is None: + s3_client = JumpStartS3PayloadAccessor._get_default_s3_client(region) + + object_size_bytes = JumpStartS3PayloadAccessor._get_object_size_bytes( + bucket=bucket, key=key, region=region, s3_client=s3_client + ) + if object_size_bytes > JumpStartS3PayloadAccessor.MAX_PAYLOAD_SIZE_BYTES: + raise ValueError( + f"s3://{bucket}/{key} has size of {object_size_bytes} bytes, " + "which exceeds maximum allowed size of " + f"{JumpStartS3PayloadAccessor.MAX_PAYLOAD_SIZE_BYTES} bytes." + ) + + return s3_client.get_object(Bucket=bucket, Key=key)["Body"].read() + + class JumpStartModelsAccessor(object): """Static class for storing the JumpStart models cache.""" diff --git a/src/sagemaker/jumpstart/artifacts/__init__.py b/src/sagemaker/jumpstart/artifacts/__init__.py index ec44077a64..4393a15402 100644 --- a/src/sagemaker/jumpstart/artifacts/__init__.py +++ b/src/sagemaker/jumpstart/artifacts/__init__.py @@ -61,3 +61,6 @@ _retrieve_model_package_arn, _retrieve_model_package_model_artifact_s3_uri, ) +from sagemaker.jumpstart.artifacts.payloads import ( # noqa: F401 + _retrieve_example_payloads, +) diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py new file mode 100644 index 0000000000..3ea2c16f80 --- /dev/null +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -0,0 +1,85 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains functions to obtain JumpStart model payloads.""" +from __future__ import absolute_import +from copy import deepcopy +from typing import Dict, Optional +from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + JUMPSTART_DEFAULT_REGION_NAME, +) +from sagemaker.jumpstart.enums import ( + JumpStartScriptScope, +) +from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.utils import ( + verify_model_region_and_return_specs, +) +from sagemaker.session import Session + + +def _retrieve_example_payloads( + model_id: str, + model_version: str, + region: Optional[str], + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> Optional[Dict[str, JumpStartSerializablePayload]]: + """Returns example payloads. + + Args: + model_id (str): JumpStart model ID of the JumpStart model for which to + get example payloads. + model_version (str): Version of the JumpStart model for which to retrieve the + example payloads. + region (Optional[str]): Region for which to retrieve the + example payloads. + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). If False, raises + an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + Returns: + Optional[Dict[str, JumpStartSerializablePayload]]: dictionary mapping payload aliases + to the serializable payload object. + """ + + if region is None: + region = JUMPSTART_DEFAULT_REGION_NAME + + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=JumpStartScriptScope.INFERENCE, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, + ) + + default_payloads = model_specs.default_payloads + + if default_payloads: + for payload in default_payloads.values(): + payload.accept = getattr( + payload, "accept", model_specs.predictor_specs.default_accept_type + ) + + return deepcopy(default_payloads) if default_payloads else None diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index ab060ea454..95a4bb3b99 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -16,6 +16,7 @@ import re from typing import Dict, List, Optional, Union +from sagemaker import payloads from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer from sagemaker.base_serializers import BaseSerializer @@ -28,6 +29,7 @@ get_deploy_kwargs, get_init_kwargs, ) +from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import is_valid_model_id from sagemaker.utils import stringify_object from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model @@ -312,6 +314,46 @@ def _is_valid_model_id_hook(): super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) + def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: + """Returns all example payloads associated with the model. + + Raises: + NotImplementedError: If the scope is not supported. + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. + """ + return payloads.retrieve_all_examples( + model_id=self.model_id, + model_version=self.model_version, + region=self.region, + tolerate_deprecated_model=self.tolerate_deprecated_model, + tolerate_vulnerable_model=self.tolerate_vulnerable_model, + sagemaker_session=self.sagemaker_session, + ) + + def retrieve_example_payload(self) -> JumpStartSerializablePayload: + """Returns the example payload associated with the model. + + Payload can be directly used with the `sagemaker.predictor.Predictor.predict(...)` function. + + Raises: + NotImplementedError: If the scope is not supported. + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. + """ + return payloads.retrieve_example( + model_id=self.model_id, + model_version=self.model_version, + region=self.region, + tolerate_deprecated_model=self.tolerate_deprecated_model, + tolerate_vulnerable_model=self.tolerate_vulnerable_model, + sagemaker_session=self.sagemaker_session, + ) + def _create_sagemaker_model( self, instance_type=None, diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py new file mode 100644 index 0000000000..4aa3bafb08 --- /dev/null +++ b/src/sagemaker/jumpstart/payload_utils.py @@ -0,0 +1,143 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores inference payload utilities for JumpStart models.""" +from __future__ import absolute_import +import base64 +import json +from typing import Optional, Union +import re +import boto3 + +from sagemaker.jumpstart.accessors import JumpStartS3PayloadAccessor +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.enums import MIMEType +from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket + +S3_BYTES_REGEX = r"^\$s3<(?P[a-zA-Z0-9-_/.]+)>$" +S3_B64_STR_REGEX = r"\$s3_b64<(?P[a-zA-Z0-9-_/.]+)>" + + +class PayloadSerializer: + """Utility class for serializing payloads associated with JumpStart models. + + Many JumpStart models embed byte-streams into payloads corresponding to images, sounds, + and other content types which require downloading from S3. + """ + + def __init__( + self, + bucket: str = get_jumpstart_content_bucket(), + region: str = JUMPSTART_DEFAULT_REGION_NAME, + s3_client: Optional[boto3.client] = None, + ) -> None: + """Initializes PayloadSerializer object.""" + self.bucket = bucket + self.region = region + self.s3_client = s3_client + + def get_bytes_payload_with_s3_references( + self, + payload_str: str, + ) -> bytes: + """Returns bytes object corresponding to referenced S3 object. + + Raises: + ValueError: If the raw bytes payload is not formatted correctly. + """ + s3_keys = re.compile(S3_BYTES_REGEX).findall(payload_str) + if len(s3_keys) != 1: + raise ValueError("Invalid bytes payload.") + + s3_key = s3_keys[0] + serialized_s3_object = JumpStartS3PayloadAccessor.get_object_cached( + bucket=self.bucket, key=s3_key, region=self.region, s3_client=self.s3_client + ) + + return serialized_s3_object + + def embed_s3_references_in_str_payload( + self, + payload: str, + ) -> str: + """Inserts serialized S3 content into string payload. + + If no S3 content is embedded in payload, original string is returned. + """ + return self._embed_s3_b64_references_in_str_payload(payload_body=payload) + + def _embed_s3_b64_references_in_str_payload( + self, + payload_body: str, + ) -> str: + """Performs base 64 encoding of payloads embedded in a payload. + + This is required so that byte-valued payloads can be transmitted efficiently + as a utf-8 encoded string. + """ + + s3_keys = re.compile(S3_B64_STR_REGEX).findall(payload_body) + for s3_key in s3_keys: + b64_encoded_string = base64.b64encode( + bytearray( + JumpStartS3PayloadAccessor.get_object_cached( + bucket=self.bucket, key=s3_key, region=self.region, s3_client=self.s3_client + ) + ) + ).decode() + payload_body = payload_body.replace(f"$s3_b64<{s3_key}>", b64_encoded_string) + return payload_body + + def embed_s3_references_in_json_payload( + self, payload_body: Union[list, dict, str, int, float] + ) -> Union[list, dict, str, int, float]: + """Finds all S3 references in payload and embeds serialized S3 data. + + If no S3 references are found, the payload is returned un-modified. + + Raises: + ValueError: If the payload has an unrecognized type. + """ + if isinstance(payload_body, str): + return self.embed_s3_references_in_str_payload(payload_body) + if isinstance(payload_body, (float, int)): + return payload_body + if isinstance(payload_body, list): + return [self.embed_s3_references_in_json_payload(item) for item in payload_body] + if isinstance(payload_body, dict): + return { + key: self.embed_s3_references_in_json_payload(value) + for key, value in payload_body.items() + } + raise ValueError(f"Payload has unrecognized type: {type(payload_body)}") + + def serialize(self, payload: JumpStartSerializablePayload) -> Union[str, bytes]: + """Returns payload string or bytes that can be inputted to inference endpoint. + + Raises: + ValueError: If the payload has an unrecognized type. + """ + content_type = MIMEType.from_suffixed_type(payload.content_type) + body = payload.body + + if content_type in {MIMEType.JSON, MIMEType.LIST_TEXT, MIMEType.X_TEXT}: + body = self.embed_s3_references_in_json_payload(body) + else: + body = self.get_bytes_payload_with_s3_references(body) + + if isinstance(body, dict): + body = json.dumps(body) + elif not isinstance(body, str) and not isinstance(body, bytes): + raise ValueError(f"Default payload '{body}' has unrecognized type: {type(body)}") + + return body diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index e8b717b7c7..a6863687e7 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -31,6 +31,8 @@ class JumpStartDataHolderType: __slots__: List[str] = [] + _non_serializable_slots: List[str] = [] + def __eq__(self, other: Any) -> bool: """Returns True if ``other`` is of the same type and has all attributes equal. @@ -69,7 +71,11 @@ def __str__(self) -> str: {'content_bucket': 'bucket', 'region_name': 'us-west-2'}" """ - att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + att_dict = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in self._non_serializable_slots + } return f"{type(self).__name__}: {str(att_dict)}" def __repr__(self) -> str: @@ -79,7 +85,11 @@ def __repr__(self) -> str: {'content_bucket': 'bucket', 'region_name': 'us-west-2'}" """ - att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + att_dict = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in self._non_serializable_slots + } return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}" @@ -312,6 +322,52 @@ def to_json(self) -> Dict[str, Any]: return json_obj +class JumpStartSerializablePayload(JumpStartDataHolderType): + """Data class for JumpStart serialized payload specs.""" + + __slots__ = [ + "raw_payload", + "content_type", + "accept", + "body", + ] + + _non_serializable_slots = ["raw_payload"] + + def __init__(self, spec: Optional[Dict[str, Any]]): + """Initializes a JumpStartSerializablePayload object from its json representation. + + Args: + spec (Dict[str, Any]): Dictionary representation of payload specs. + """ + self.from_json(spec) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of serializable + payload specs. + + Raises: + KeyError: If the dictionary is missing keys. + """ + + if json_obj is None: + return + + self.raw_payload = json_obj + self.content_type = json_obj["content_type"] + self.body = json_obj["body"] + accept = json_obj.get("accept") + if accept: + self.accept = accept + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of JumpStartSerializablePayload object.""" + return deepcopy(self.raw_payload) + + class JumpStartInstanceTypeVariants(JumpStartDataHolderType): """Data class for JumpStart instance type variants.""" @@ -468,6 +524,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "hosting_use_script_uri", "hosting_instance_type_variants", "training_instance_type_variants", + "default_payloads", ] def __init__(self, spec: Dict[str, Any]): @@ -536,6 +593,14 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if "predictor_specs" in json_obj else None ) + self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( + { + alias: JumpStartSerializablePayload(payload) + for alias, payload in json_obj["default_payloads"].items() + } + if json_obj.get("default_payloads") + else None + ) self.inference_volume_size: Optional[int] = json_obj.get("inference_volume_size") self.inference_enable_network_isolation: bool = json_obj.get( "inference_enable_network_isolation", False diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py new file mode 100644 index 0000000000..52d633ed4e --- /dev/null +++ b/src/sagemaker/payloads.py @@ -0,0 +1,174 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Utilities related to payloads of pretrained machine learning models.""" +from __future__ import absolute_import + +import logging +from typing import Dict, List, Optional + +from sagemaker.jumpstart import utils as jumpstart_utils +from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.payload_utils import PayloadSerializer +from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.session import Session + + +logger = logging.getLogger(__name__) + + +def retrieve_all_examples( + region: Optional[str] = None, + model_id: Optional[str] = None, + model_version: Optional[str] = None, + serialize: bool = False, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> Optional[List[JumpStartSerializablePayload]]: + """Retrieves the compatible payloads for the model matching the given arguments. + + Args: + region (str): The AWS Region for which to retrieve the Jumpstart model payloads. + model_id (str): The model ID of the JumpStart model for which to retrieve + the model payloads. + model_version (str): The version of the JumpStart model for which to retrieve + the model payloads. + serialize (bool): Whether to serialize byte-stream valued payloads by downloading + binary files from s3 and applying encoding, or to keep payload in pre-serialized + state. Set this option to False if you want to avoid s3 downloads or if you + want to inspect the payload in a human-readable form. (Default: False). + tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model + specifications should be tolerated without raising an exception. If ``False``, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): ``True`` if deprecated versions of model + specifications should be tolerated without raising an exception. If ``False``, raises + an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + Returns: + Optional[List[JumpStartSerializablePayload]]: List of payloads or None. + + Raises: + NotImplementedError: If the scope is not supported. + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. + """ + if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): + raise ValueError( + "Must specify JumpStart `model_id` and `model_version` when retrieving payloads." + ) + + unserialized_payload_dict: Optional[ + Dict[str, JumpStartSerializablePayload] + ] = artifacts._retrieve_example_payloads( + model_id, + model_version, + region, + tolerate_vulnerable_model, + tolerate_deprecated_model, + sagemaker_session=sagemaker_session, + ) + + if unserialized_payload_dict is None: + return None + + unserialized_payloads: List[JumpStartSerializablePayload] = list( + unserialized_payload_dict.values() + ) + + if not serialize: + return unserialized_payloads + + payload_serializer = PayloadSerializer(region=region, s3_client=sagemaker_session.s3_client) + + serialized_payloads: List[JumpStartSerializablePayload] = [] + + for payload in unserialized_payloads: + + serialized_body = payload_serializer.serialize(payload) + + serialized_payloads.append( + JumpStartSerializablePayload( + { + "content_type": payload.content_type, + "body": serialized_body, + "accept": payload.accept, + } + ) + ) + + return serialized_payloads + + +def retrieve_example( + region: Optional[str] = None, + model_id: Optional[str] = None, + model_version: Optional[str] = None, + serialize: bool = False, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> Optional[JumpStartSerializablePayload]: + """Retrieves a single compatible payload for the model matching the given arguments. + + Args: + region (str): The AWS Region for which to retrieve the Jumpstart model payloads. + model_id (str): The model ID of the JumpStart model for which to retrieve + the model payload. + model_version (str): The version of the JumpStart model for which to retrieve + the model payload. + serialize (bool): Whether to serialize byte-stream valued payloads by downloading + binary files from s3 and applying encoding, or to keep payload in pre-serialized + state. Set this option to False if you want to avoid s3 downloads or if you + want to inspect the payload in a human-readable form. (Default: False). + tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model + specifications should be tolerated without raising an exception. If ``False``, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): ``True`` if deprecated versions of model + specifications should be tolerated without raising an exception. If ``False``, raises + an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + Returns: + Optional[JumpStartSerializablePayload]: A single default payload or None. + + Raises: + NotImplementedError: If the scope is not supported. + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. + """ + example_payloads: Optional[List[JumpStartSerializablePayload]] = retrieve_all_examples( + region=region, + model_id=model_id, + model_version=model_version, + serialize=serialize, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, + ) + + if example_payloads is None or len(example_payloads) == 0: + return None + + return example_payloads[0] diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index f5cc4fbb58..b65167165c 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1620,6 +1620,172 @@ }, }, }, + "default_payloads": { + "model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16", + "url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth", + "version": "1.0.0", + "min_sdk_version": "2.144.0", + "training_supported": False, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "djl-deepspeed", + "framework_version": "0.21.0", + "py_version": "py38", + "huggingface_transformers_version": "4.17", + }, + "hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st" + "able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", + "hosting_script_key": "source-directory-tarballs/stabilityai/inference/depth2img/v1.0.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "stabilityai-infer/prepack/v1.0.0/" + "infer-prepack-model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.0", + "inference_vulnerable": False, + "inference_dependencies": [ + "accelerate==0.18.0", + "diffusers==0.14.0", + "fsspec==2023.4.0", + "huggingface-hub==0.14.1", + "transformers==4.26.1", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.g5.8xlarge", + "supported_inference_instance_types": [ + "ml.g5.8xlarge", + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.16xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.g4dn.2xlarge", + "ml.g4dn.4xlarge", + "ml.g4dn.8xlarge", + "ml.g4dn.16xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_enable_network_isolation": True, + "validation_supported": False, + "fine_tuning_supported": False, + "resource_name_base": "sd-1-5-controlnet-1-1-fp16", + "default_payloads": { + "Dog": { + "content_type": "application/json", + "body": { + "prompt": "a dog", + "num_images_per_prompt": 2, + "num_inference_steps": 20, + "guidance_scale": 7.5, + "seed": 43, + "eta": 0.7, + "image": "$s3_b64", + }, + } + }, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/d" + "jl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + }, + }, + }, "predictor-specs-model": { "model_id": "huggingface-text2text-flan-t5-xxl-fp16", "url": "https://huggingface.co/google/flan-t5-xxl", @@ -3190,6 +3356,7 @@ "min_sdk_version": "2.49.0", "training_supported": True, "incremental_training_supported": True, + "default_payloads": None, "hosting_ecr_specs": { "framework": "pytorch", "framework_version": "1.5.0", diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 2de0351103..97427be1ae 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -11,6 +11,8 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from io import BytesIO +from unittest import TestCase from mock.mock import Mock, patch import pytest @@ -134,3 +136,87 @@ def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): # necessary because accessors is a static module reload(accessors) + + +class TestS3Accessor(TestCase): + + bucket = "bucket" + key = "key" + + @patch("sagemaker.jumpstart.accessors.boto3.client") + def test_get_object(self, mocked_boto3_client): + + # required due to static class + reload(accessors) + + mocked_boto3_client.return_value = Mock() + + mocked_boto3_client.return_value.get_object.return_value = {"Body": BytesIO(b"s3-object")} + mocked_boto3_client.return_value.head_object.return_value = {"ContentLength": 1} + + response = accessors.JumpStartS3PayloadAccessor.get_object(bucket=self.bucket, key=self.key) + + self.assertEqual(response, b"s3-object") + + mocked_boto3_client.assert_called_once_with("s3", region_name="us-west-2") + mocked_boto3_client.return_value.get_object.assert_called_once_with( + Bucket=self.bucket, Key=self.key + ) + mocked_boto3_client.return_value.head_object.assert_called_once_with( + Bucket=self.bucket, Key=self.key + ) + + @patch("sagemaker.jumpstart.accessors.boto3.client") + def test_get_object_cached(self, mocked_boto3_client): + + # required due to static class + reload(accessors) + + mocked_boto3_client.return_value = Mock() + + mocked_boto3_client.return_value.get_object.return_value = {"Body": BytesIO(b"s3-object")} + mocked_boto3_client.return_value.head_object.return_value = {"ContentLength": 1} + + response = accessors.JumpStartS3PayloadAccessor.get_object_cached( + bucket=self.bucket, key=self.key + ) + response = accessors.JumpStartS3PayloadAccessor.get_object_cached( + bucket=self.bucket, key=self.key + ) + + self.assertEqual(response, b"s3-object") + + # only a single s3 call should be made when identical requests are made + mocked_boto3_client.assert_called_once_with("s3", region_name="us-west-2") + mocked_boto3_client.return_value.get_object.assert_called_once_with( + Bucket=self.bucket, Key=self.key + ) + mocked_boto3_client.return_value.head_object.assert_called_once_with( + Bucket=self.bucket, Key=self.key + ) + + @patch("sagemaker.jumpstart.accessors.boto3.client") + def test_get_object_limit_exceeded(self, mocked_boto3_client): + + # required due to static class + reload(accessors) + + mocked_boto3_client.return_value = Mock() + + mocked_boto3_client.return_value.get_object.return_value = {"Body": BytesIO(b"s3-object")} + mocked_boto3_client.return_value.head_object.return_value = {"ContentLength": 1e99} + + with pytest.raises(ValueError) as e: + accessors.JumpStartS3PayloadAccessor.get_object(bucket=self.bucket, key=self.key) + + self.assertEqual( + str(e.value), + "s3://bucket/key has size of 1e+99 bytes, which " + "exceeds maximum allowed size of 6000000 bytes.", + ) + + mocked_boto3_client.assert_called_once_with("s3", region_name="us-west-2") + mocked_boto3_client.return_value.get_object.assert_not_called() + mocked_boto3_client.return_value.head_object.assert_called_once_with( + Bucket=self.bucket, Key=self.key + ) diff --git a/tests/unit/sagemaker/jumpstart/test_payload_utils.py b/tests/unit/sagemaker/jumpstart/test_payload_utils.py new file mode 100644 index 0000000000..687c9154df --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/test_payload_utils.py @@ -0,0 +1,63 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import base64 +from unittest import TestCase +from mock.mock import patch + +from sagemaker.jumpstart.payload_utils import PayloadSerializer +from sagemaker.jumpstart.types import JumpStartSerializablePayload + + +class TestPayloadSerializer(TestCase): + + payload_serializer = PayloadSerializer() + + @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") + def test_serialize_bytes_payload(self, mock_get_object_cached): + + mock_get_object_cached.return_value = "7897" + payload = JumpStartSerializablePayload( + { + "content_type": "audio/wav", + "body": "$s3", + } + ) + serialized_payload = self.payload_serializer.serialize(payload) + self.assertEqual(serialized_payload, "7897") + + @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") + def test_serialize_json_payload(self, mock_get_object_cached): + + mock_get_object_cached.return_value = base64.b64decode("encodedimage") + payload = JumpStartSerializablePayload( + { + "content_type": "application/json", + "body": { + "prompt": "a dog", + "num_images_per_prompt": 2, + "num_inference_steps": 20, + "guidance_scale": 7.5, + "seed": 43, + "eta": 0.7, + "image": "$s3_b64", + }, + } + ) + serialized_payload = self.payload_serializer.serialize(payload) + self.assertEqual( + serialized_payload, + '{"prompt": "a dog", "num_images_per_prompt": 2, ' + '"num_inference_steps": 20, "guidance_scale": 7.5, "seed": ' + '43, "eta": 0.7, "image": "encodedimage"}', + ) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 0758b54f29..4c2cd5b123 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -1,10 +1,13 @@ from __future__ import absolute_import +import base64 +from unittest import mock from unittest.mock import patch from sagemaker.deserializers import JSONDeserializer from sagemaker.jumpstart.enums import MIMEType from sagemaker import predictor +from sagemaker.jumpstart.model import JumpStartModel from sagemaker.jumpstart.utils import verify_model_region_and_return_specs @@ -35,3 +38,53 @@ def test_jumpstart_predictor_support( assert isinstance(js_predictor.deserializer, JSONDeserializer) assert js_predictor.accept == MIMEType.JSON + + +@patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") +@patch("sagemaker.jumpstart.model.is_valid_model_id") +@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_serializable_payload_with_predictor( + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_is_valid_model_id, + patched_get_object_cached, +): + + patched_get_object_cached.return_value = base64.b64decode("encodedimage") + patched_is_valid_model_id.return_value = True + + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs + patched_get_model_specs.side_effect = get_special_model_spec + + model_id, model_version = "default_payloads", "*" + + js_predictor = predictor.retrieve_default( + endpoint_name="blah", model_id=model_id, model_version=model_version + ) + + default_payload = JumpStartModel( + model_id=model_id, model_version=model_version + ).retrieve_example_payload() + + invoke_endpoint_mock = mock.Mock() + + js_predictor.sagemaker_session.sagemaker_runtime_client.invoke_endpoint = invoke_endpoint_mock + js_predictor._handle_response = mock.Mock() + + assert str(default_payload) == ( + "JumpStartSerializablePayload: {'content_type': 'application/json', 'accept': 'application/json'" + ", 'body': {'prompt': 'a dog', 'num_images_per_prompt': 2, 'num_inference_steps':" + " 20, 'guidance_scale': 7.5, 'seed': 43, 'eta': 0.7, 'image':" + " '$s3_b64'}}" + ) + + js_predictor.predict(default_payload) + + invoke_endpoint_mock.assert_called_once_with( + EndpointName="blah", + ContentType="application/json", + Accept="application/json", + Body='{"prompt": "a dog", "num_images_per_prompt": 2, "num_inference_steps": 20, ' + '"guidance_scale": 7.5, "seed": 43, "eta": 0.7, "image": "encodedimage"}', + )