Skip to content

Commit cba9034

Browse files
committed
feat: jumpstart default payloads
1 parent ddd06bb commit cba9034

File tree

12 files changed

+828
-9
lines changed

12 files changed

+828
-9
lines changed

src/sagemaker/base_predictor.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
StreamDeserializer,
3333
StringDeserializer,
3434
)
35+
from sagemaker.jumpstart.payload_utils import PayloadSerializer
36+
from sagemaker.jumpstart.types import JumpStartSerializablePayload
37+
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
3538
from sagemaker.model_monitor import (
3639
DataCaptureConfig,
3740
DefaultModelMonitor,
@@ -201,20 +204,42 @@ def _create_request_args(
201204
custom_attributes=None,
202205
):
203206
"""Placeholder docstring"""
207+
208+
js_accept = None
209+
210+
if isinstance(data, JumpStartSerializablePayload):
211+
s3_client = self.sagemaker_session.s3_client
212+
region = self.sagemaker_session._region_name
213+
bucket = get_jumpstart_content_bucket(region)
214+
215+
js_serialized_data = PayloadSerializer(
216+
bucket=bucket, region=region, s3_client=s3_client
217+
).serialize(data)
218+
js_content_type = data.content_type
219+
js_accept = data.accept
220+
204221
args = dict(initial_args) if initial_args else {}
205222

206223
if "EndpointName" not in args:
207224
args["EndpointName"] = self.endpoint_name
208225

209226
if "ContentType" not in args:
210-
args["ContentType"] = (
211-
self.content_type
212-
if isinstance(self.content_type, str)
213-
else ", ".join(self.content_type)
214-
)
227+
if isinstance(data, JumpStartSerializablePayload):
228+
args["ContentType"] = js_content_type
229+
else:
230+
args["ContentType"] = (
231+
self.content_type
232+
if isinstance(self.content_type, str)
233+
else ", ".join(self.content_type)
234+
)
215235

216236
if "Accept" not in args:
217-
args["Accept"] = self.accept if isinstance(self.accept, str) else ", ".join(self.accept)
237+
if isinstance(data, JumpStartSerializablePayload) and js_accept:
238+
args["Accept"] = js_accept
239+
else:
240+
args["Accept"] = (
241+
self.accept if isinstance(self.accept, str) else ", ".join(self.accept)
242+
)
218243

219244
if target_model:
220245
args["TargetModel"] = target_model
@@ -228,7 +253,11 @@ def _create_request_args(
228253
if custom_attributes:
229254
args["CustomAttributes"] = custom_attributes
230255

231-
data = self.serializer.serialize(data)
256+
data = (
257+
self.serializer.serialize(data)
258+
if not isinstance(data, JumpStartSerializablePayload)
259+
else js_serialized_data
260+
)
232261

233262
args["Body"] = data
234263
return args

src/sagemaker/jumpstart/accessors.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains accessors related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
15+
import functools
1516
from typing import Any, Dict, List, Optional
1617
import boto3
1718

@@ -37,6 +38,49 @@ def get_sagemaker_version() -> str:
3738
return SageMakerSettings._parsed_sagemaker_version
3839

3940

41+
class JumpStartS3Accessor(object):
42+
"""Static class for storing and retrieving auxilliary s3 artifacts."""
43+
44+
@functools.cache
45+
@staticmethod
46+
def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3.client:
47+
"""Returns default s3 client associated with the region.
48+
49+
Result is cached so multiple clients in memory are not created.
50+
"""
51+
return boto3.client("s3", region_name=region)
52+
53+
@functools.lru_cache
54+
@staticmethod
55+
def get_object_cached(
56+
bucket: str,
57+
key: str,
58+
region: str = JUMPSTART_DEFAULT_REGION_NAME,
59+
s3_client: Optional[boto3.client] = None,
60+
) -> bytes:
61+
"""Returns s3 object located at the bucket and key.
62+
63+
Requests are cached so that the same s3 request is never made more
64+
than once, unless a different region or client is used.
65+
"""
66+
return JumpStartS3Accessor.get_object(
67+
bucket=bucket, key=key, region=region, s3_client=s3_client
68+
)
69+
70+
@staticmethod
71+
def get_object(
72+
bucket: str,
73+
key: str,
74+
region: str = JUMPSTART_DEFAULT_REGION_NAME,
75+
s3_client: Optional[boto3.client] = None,
76+
) -> bytes:
77+
"""Returns s3 object located at the bucket and key."""
78+
if s3_client is None:
79+
s3_client = JumpStartS3Accessor._get_default_s3_client(region)
80+
81+
return s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
82+
83+
4084
class JumpStartModelsAccessor(object):
4185
"""Static class for storing the JumpStart models cache."""
4286

src/sagemaker/jumpstart/artifacts/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,6 @@
6161
_retrieve_model_package_arn,
6262
_retrieve_model_package_model_artifact_s3_uri,
6363
)
64+
from sagemaker.jumpstart.artifacts.payloads import ( # noqa: F401
65+
_retrieve_default_payloads,
66+
)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains functions for obtaining JumpStart payloads."""
14+
from __future__ import absolute_import
15+
from copy import deepcopy
16+
from typing import Dict, Optional
17+
from sagemaker.jumpstart.constants import (
18+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
19+
JUMPSTART_DEFAULT_REGION_NAME,
20+
)
21+
from sagemaker.jumpstart.enums import (
22+
JumpStartScriptScope,
23+
)
24+
from sagemaker.jumpstart.types import JumpStartSerializablePayload
25+
from sagemaker.jumpstart.utils import (
26+
verify_model_region_and_return_specs,
27+
)
28+
from sagemaker.session import Session
29+
30+
31+
def _retrieve_default_payloads(
32+
model_id: str,
33+
model_version: str,
34+
region: Optional[str],
35+
tolerate_vulnerable_model: bool = False,
36+
tolerate_deprecated_model: bool = False,
37+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
38+
) -> Optional[Dict[str, JumpStartSerializablePayload]]:
39+
"""Returns default payloads.
40+
41+
Args:
42+
model_id (str): JumpStart model ID of the JumpStart model for which to
43+
get default payloads.
44+
model_version (str): Version of the JumpStart model for which to retrieve the
45+
default resource name.
46+
region (Optional[str]): Region for which to retrieve the
47+
default resource name.
48+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
49+
specifications should be tolerated (exception not raised). If False, raises an
50+
exception if the script used by this version of the model has dependencies with known
51+
security vulnerabilities. (Default: False).
52+
tolerate_deprecated_model (bool): True if deprecated versions of model
53+
specifications should be tolerated (exception not raised). If False, raises
54+
an exception if the version of the model is deprecated. (Default: False).
55+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
56+
object, used for SageMaker interactions. If not
57+
specified, one is created using the default AWS configuration
58+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
59+
Returns:
60+
str: the default payload.
61+
"""
62+
63+
if region is None:
64+
region = JUMPSTART_DEFAULT_REGION_NAME
65+
66+
model_specs = verify_model_region_and_return_specs(
67+
model_id=model_id,
68+
version=model_version,
69+
scope=JumpStartScriptScope.INFERENCE,
70+
region=region,
71+
tolerate_vulnerable_model=tolerate_vulnerable_model,
72+
tolerate_deprecated_model=tolerate_deprecated_model,
73+
sagemaker_session=sagemaker_session,
74+
)
75+
76+
default_payloads = model_specs.default_payloads
77+
78+
if default_payloads:
79+
for payload in default_payloads.values():
80+
payload.accept = getattr(
81+
payload, "accept", model_specs.predictor_specs.default_accept_type
82+
)
83+
84+
return deepcopy(default_payloads) if default_payloads else None

src/sagemaker/jumpstart/model.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import re
1717

1818
from typing import Dict, List, Optional, Union
19+
from sagemaker import payloads
1920
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
2021
from sagemaker.base_deserializers import BaseDeserializer
2122
from sagemaker.base_serializers import BaseSerializer
@@ -28,6 +29,7 @@
2829
get_deploy_kwargs,
2930
get_init_kwargs,
3031
)
32+
from sagemaker.jumpstart.types import JumpStartSerializablePayload
3133
from sagemaker.jumpstart.utils import is_valid_model_id
3234
from sagemaker.utils import stringify_object
3335
from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model
@@ -312,6 +314,27 @@ def _is_valid_model_id_hook():
312314

313315
super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())
314316

317+
def retrieve_default_payload(self) -> JumpStartSerializablePayload:
318+
"""Returns default payload associated with the model.
319+
320+
Payload can be directly used with the `sagemaker.predictor.Predictor.predict(...)` function.
321+
"""
322+
sample_payloads: Optional[List[JumpStartSerializablePayload]] = payloads.retrieve_samples(
323+
model_id=self.model_id,
324+
model_version=self.model_version,
325+
region=self.region,
326+
tolerate_deprecated_model=self.tolerate_deprecated_model,
327+
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
328+
sagemaker_session=self.sagemaker_session,
329+
)
330+
331+
if sample_payloads is None or len(sample_payloads) == 0:
332+
raise NotImplementedError(
333+
f"No default payload supported for model ID '{self.model_id}'."
334+
)
335+
336+
return sample_payloads[0]
337+
315338
def _create_sagemaker_model(
316339
self,
317340
instance_type=None,

0 commit comments

Comments
 (0)