Skip to content

Commit 63b0372

Browse files
authored
Feat: tagging jumpstart models (#2860)
1 parent c03efb2 commit 63b0372

File tree

7 files changed

+1254
-3
lines changed

7 files changed

+1254
-3
lines changed

src/sagemaker/estimator.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@
4848
)
4949
from sagemaker.inputs import TrainingInput
5050
from sagemaker.job import _Job
51+
from sagemaker.jumpstart.utils import (
52+
add_jumpstart_tags,
53+
update_inference_tags_with_jumpstart_training_tags,
54+
)
5155
from sagemaker.local import LocalSession
5256
from sagemaker.model import (
5357
CONTAINER_LOG_LEVEL_PARAM_NAME,
@@ -442,7 +446,6 @@ def __init__(
442446
self.volume_kms_key = volume_kms_key
443447
self.max_run = max_run
444448
self.input_mode = input_mode
445-
self.tags = tags
446449
self.metric_definitions = metric_definitions
447450
self.model_uri = model_uri
448451
self.model_channel_name = model_channel_name
@@ -456,7 +459,9 @@ def __init__(
456459
self.entry_point = entry_point
457460
self.dependencies = dependencies
458461
self.uploaded_code = None
459-
462+
self.tags = add_jumpstart_tags(
463+
tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
464+
)
460465
if self.instance_type in ("local", "local_gpu"):
461466
if self.instance_type == "local_gpu" and self.instance_count > 1:
462467
raise RuntimeError("Distributed Training in Local GPU is not supported")
@@ -1203,6 +1208,10 @@ def deploy(
12031208

12041209
model.name = model_name
12051210

1211+
tags = update_inference_tags_with_jumpstart_training_tags(
1212+
inference_tags=tags, training_tags=self.tags
1213+
)
1214+
12061215
return model.deploy(
12071216
instance_type=instance_type,
12081217
initial_instance_count=initial_instance_count,

src/sagemaker/jumpstart/constants.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@
112112
}
113113
JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS}
114114

115+
JUMPSTART_BUCKET_NAME_SET = {region.content_bucket for region in JUMPSTART_LAUNCHED_REGIONS}
116+
115117
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2"
116118

117119
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
@@ -156,3 +158,12 @@ class VariableScope(str, Enum):
156158

157159
CONTAINER = "container"
158160
ALGORITHM = "algorithm"
161+
162+
163+
class JumpStartTag(str, Enum):
164+
"""Enum class for tag keys to apply to JumpStart models."""
165+
166+
INFERENCE_MODEL_URI = "aws-jumpstart-inference-model-uri"
167+
INFERENCE_SCRIPT_URI = "aws-jumpstart-inference-script-uri"
168+
TRAINING_MODEL_URI = "aws-jumpstart-training-model-uri"
169+
TRAINING_SCRIPT_URI = "aws-jumpstart-training-script-uri"

src/sagemaker/jumpstart/utils.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
from __future__ import absolute_import
1515
import logging
1616
from typing import Dict, List, Optional
17+
from urllib.parse import urlparse
1718
from packaging.version import Version
1819
import sagemaker
1920
from sagemaker.jumpstart import constants
2021
from sagemaker.jumpstart import accessors
22+
from sagemaker.s3 import parse_s3_url
2123
from sagemaker.jumpstart.exceptions import (
2224
DeprecatedJumpStartModelError,
2325
VulnerableJumpStartModelError,
@@ -150,6 +152,145 @@ def is_jumpstart_model_input(model_id: Optional[str], version: Optional[str]) ->
150152
return False
151153

152154

155+
def is_jumpstart_model_uri(uri: Optional[str]) -> bool:
156+
"""Returns True if URI corresponds to a JumpStart-hosted model.
157+
158+
Args:
159+
uri (Optional[str]): uri for inference/training job.
160+
"""
161+
162+
bucket = None
163+
if urlparse(uri).scheme == "s3":
164+
bucket, _ = parse_s3_url(uri)
165+
166+
return bucket in constants.JUMPSTART_BUCKET_NAME_SET
167+
168+
169+
def tag_key_in_array(tag_key: str, tag_array: List[Dict[str, str]]) -> bool:
170+
"""Returns True if ``tag_key`` is in the ``tag_array``.
171+
172+
Args:
173+
tag_key (str): the tag key to check if it's already in the ``tag_array``.
174+
tag_array (List[Dict[str, str]]): array of tags to check for ``tag_key``.
175+
"""
176+
for tag in tag_array:
177+
if tag_key == tag["Key"]:
178+
return True
179+
return False
180+
181+
182+
def get_tag_value(tag_key: str, tag_array: List[Dict[str, str]]) -> str:
183+
"""Return the value of a tag whose key matches the given ``tag_key``.
184+
185+
Args:
186+
tag_key (str): AWS tag for which to search.
187+
tag_array (List[Dict[str, str]]): List of AWS tags, each formatted as dicts.
188+
189+
Raises:
190+
KeyError: If the number of matches for the ``tag_key`` is not equal to 1.
191+
"""
192+
tag_values = [tag["Value"] for tag in tag_array if tag_key == tag["Key"]]
193+
if len(tag_values) != 1:
194+
raise KeyError(
195+
f"Cannot get value of tag for tag key '{tag_key}' -- found {len(tag_values)} "
196+
f"number of matches in the tag list."
197+
)
198+
199+
return tag_values[0]
200+
201+
202+
def add_single_jumpstart_tag(
203+
uri: str, tag_key: constants.JumpStartTag, curr_tags: Optional[List[Dict[str, str]]]
204+
) -> Optional[List]:
205+
"""Adds ``tag_key`` to ``curr_tags`` if ``uri`` corresponds to a JumpStart model.
206+
207+
Args:
208+
uri (str): URI which may correspond to a JumpStart model.
209+
tag_key (constants.JumpStartTag): Custom tag to apply to current tags if the URI
210+
corresponds to a JumpStart model.
211+
curr_tags (Optional[List]): Current tags associated with ``Estimator`` or ``Model``.
212+
"""
213+
if is_jumpstart_model_uri(uri):
214+
if curr_tags is None:
215+
curr_tags = []
216+
if not tag_key_in_array(tag_key, curr_tags):
217+
curr_tags.append(
218+
{
219+
"Key": tag_key,
220+
"Value": uri,
221+
}
222+
)
223+
return curr_tags
224+
225+
226+
def add_jumpstart_tags(
227+
tags: Optional[List[Dict[str, str]]] = None,
228+
inference_model_uri: Optional[str] = None,
229+
inference_script_uri: Optional[str] = None,
230+
training_model_uri: Optional[str] = None,
231+
training_script_uri: Optional[str] = None,
232+
) -> Optional[List[Dict[str, str]]]:
233+
"""Add custom tags to JumpStart models, return the updated tags.
234+
235+
No-op if this is not a JumpStart model related resource.
236+
237+
Args:
238+
tags (Optional[List[Dict[str,str]]): Current tags for JumpStart inference
239+
or training job. (Default: None).
240+
inference_model_uri (Optional[str]): S3 URI for inference model artifact.
241+
(Default: None).
242+
inference_script_uri (Optional[str]): S3 URI for inference script tarball.
243+
(Default: None).
244+
training_model_uri (Optional[str]): S3 URI for training model artifact.
245+
(Default: None).
246+
training_script_uri (Optional[str]): S3 URI for training script tarball.
247+
(Default: None).
248+
"""
249+
250+
if inference_model_uri:
251+
tags = add_single_jumpstart_tag(
252+
inference_model_uri, constants.JumpStartTag.INFERENCE_MODEL_URI, tags
253+
)
254+
255+
if inference_script_uri:
256+
tags = add_single_jumpstart_tag(
257+
inference_script_uri, constants.JumpStartTag.INFERENCE_SCRIPT_URI, tags
258+
)
259+
260+
if training_model_uri:
261+
tags = add_single_jumpstart_tag(
262+
training_model_uri, constants.JumpStartTag.TRAINING_MODEL_URI, tags
263+
)
264+
265+
if training_script_uri:
266+
tags = add_single_jumpstart_tag(
267+
training_script_uri, constants.JumpStartTag.TRAINING_SCRIPT_URI, tags
268+
)
269+
270+
return tags
271+
272+
273+
def update_inference_tags_with_jumpstart_training_tags(
274+
inference_tags: Optional[List[Dict[str, str]]], training_tags: Optional[List[Dict[str, str]]]
275+
) -> Optional[List[Dict[str, str]]]:
276+
"""Updates the tags for the ``sagemaker.model.Model.deploy`` command with any JumpStart tags.
277+
278+
Args:
279+
inference_tags (Optional[List[Dict[str, str]]]): Custom tags to appy to inference job.
280+
training_tags (Optional[List[Dict[str, str]]]): Tags from training job.
281+
"""
282+
if training_tags:
283+
for tag_key in constants.JumpStartTag:
284+
if tag_key_in_array(tag_key, training_tags):
285+
tag_value = get_tag_value(tag_key, training_tags)
286+
if inference_tags is None:
287+
inference_tags = []
288+
if not tag_key_in_array(tag_key, inference_tags):
289+
inference_tags.append({"Key": tag_key, "Value": tag_value})
290+
291+
return inference_tags
292+
293+
153294
def verify_model_region_and_return_specs(
154295
model_id: Optional[str],
155296
version: Optional[str],

src/sagemaker/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sagemaker.deprecations import removed_kwargs
3535
from sagemaker.predictor import PredictorBase
3636
from sagemaker.transformer import Transformer
37+
from sagemaker.jumpstart.utils import add_jumpstart_tags
3738

3839
LOGGER = logging.getLogger("sagemaker")
3940

@@ -985,6 +986,10 @@ def deploy(
985986
removed_kwargs("update_endpoint", kwargs)
986987
self._init_sagemaker_session_if_does_not_exist(instance_type)
987988

989+
tags = add_jumpstart_tags(
990+
tags=tags, inference_model_uri=self.model_data, inference_script_uri=self.source_dir
991+
)
992+
988993
if self.role is None:
989994
raise ValueError("Role can not be null for deploying a model")
990995

0 commit comments

Comments
 (0)