diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index f77e1ae231..0acad85436 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from urllib.parse import urlparse from packaging.version import Version import sagemaker @@ -277,7 +277,7 @@ def get_jumpstart_base_name_if_jumpstart_model( def add_jumpstart_tags( tags: Optional[List[Dict[str, str]]] = None, - inference_model_uri: Optional[str] = None, + inference_model_uri: Optional[Union[str, dict]] = None, inference_script_uri: Optional[str] = None, training_model_uri: Optional[str] = None, training_script_uri: Optional[str] = None, @@ -289,7 +289,7 @@ def add_jumpstart_tags( Args: tags (Optional[List[Dict[str,str]]): Current tags for JumpStart inference or training job. (Default: None). - inference_model_uri (Optional[str]): S3 URI for inference model artifact. + inference_model_uri (Optional[Union[dict, str]]): S3 URI for inference model artifact. (Default: None). inference_script_uri (Optional[str]): S3 URI for inference script tarball. (Default: None). @@ -302,6 +302,10 @@ def add_jumpstart_tags( "The URI (%s) is a pipeline variable which is only interpreted at execution time. " "As a result, the JumpStart resources will not be tagged." ) + + if isinstance(inference_model_uri, dict): + inference_model_uri = inference_model_uri.get("S3DataSource", {}).get("S3Uri", None) + if inference_model_uri: if is_pipeline_variable(inference_model_uri): logging.warning(warn_msg, "inference_model_uri") diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 018fa35b85..6454901cdf 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1345,7 +1345,9 @@ def deploy( tags = add_jumpstart_tags( tags=tags, - inference_model_uri=self.model_data if isinstance(self.model_data, str) else None, + inference_model_uri=self.model_data + if isinstance(self.model_data, (str, dict)) + else None, inference_script_uri=self.source_dir, ) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 3ddb1b10e8..8857641dc4 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -216,6 +216,34 @@ def test_add_jumpstart_tags_inference(): inference_script_uri=inference_script_uri, ) == [{"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": inference_model_uri}] + tags = [] + inference_model_uri = {"S3DataSource": {"S3Uri": random_jumpstart_s3_uri("random_key")}} + inference_script_uri = "dfsdfs" + assert utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) == [ + { + "Key": JumpStartTag.INFERENCE_MODEL_URI.value, + "Value": inference_model_uri["S3DataSource"]["S3Uri"], + } + ] + + tags = [] + inference_model_uri = {"S3DataSource": {"S3Uri": random_jumpstart_s3_uri("random_key/prefix/")}} + inference_script_uri = "dfsdfs" + assert utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) == [ + { + "Key": JumpStartTag.INFERENCE_MODEL_URI.value, + "Value": inference_model_uri["S3DataSource"]["S3Uri"], + } + ] + tags = [{"Key": "some", "Value": "tag"}] inference_model_uri = random_jumpstart_s3_uri("random_key") inference_script_uri = "dfsdfs"