From e3f0312c3e44960e400357a78d3204721cddbe1d Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 5 Oct 2023 14:30:20 +0000 Subject: [PATCH 1/2] fix: js tagging s3 prefix --- src/sagemaker/jumpstart/utils.py | 10 +++++++--- src/sagemaker/model.py | 4 +++- tests/unit/sagemaker/jumpstart/test_utils.py | 10 ++++++++++ 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index f77e1ae231..0acad85436 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from urllib.parse import urlparse from packaging.version import Version import sagemaker @@ -277,7 +277,7 @@ def get_jumpstart_base_name_if_jumpstart_model( def add_jumpstart_tags( tags: Optional[List[Dict[str, str]]] = None, - inference_model_uri: Optional[str] = None, + inference_model_uri: Optional[Union[str, dict]] = None, inference_script_uri: Optional[str] = None, training_model_uri: Optional[str] = None, training_script_uri: Optional[str] = None, @@ -289,7 +289,7 @@ def add_jumpstart_tags( Args: tags (Optional[List[Dict[str,str]]): Current tags for JumpStart inference or training job. (Default: None). - inference_model_uri (Optional[str]): S3 URI for inference model artifact. + inference_model_uri (Optional[Union[dict, str]]): S3 URI for inference model artifact. (Default: None). inference_script_uri (Optional[str]): S3 URI for inference script tarball. (Default: None). @@ -302,6 +302,10 @@ def add_jumpstart_tags( "The URI (%s) is a pipeline variable which is only interpreted at execution time. " "As a result, the JumpStart resources will not be tagged." ) + + if isinstance(inference_model_uri, dict): + inference_model_uri = inference_model_uri.get("S3DataSource", {}).get("S3Uri", None) + if inference_model_uri: if is_pipeline_variable(inference_model_uri): logging.warning(warn_msg, "inference_model_uri") diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 018fa35b85..6454901cdf 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1345,7 +1345,9 @@ def deploy( tags = add_jumpstart_tags( tags=tags, - inference_model_uri=self.model_data if isinstance(self.model_data, str) else None, + inference_model_uri=self.model_data + if isinstance(self.model_data, (str, dict)) + else None, inference_script_uri=self.source_dir, ) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 3ddb1b10e8..6d148163e7 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -216,6 +216,16 @@ def test_add_jumpstart_tags_inference(): inference_script_uri=inference_script_uri, ) == [{"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": inference_model_uri}] + + tags = [] + inference_model_uri = {"S3DataSource": {"S3Uri": random_jumpstart_s3_uri("random_key")}} + inference_script_uri = "dfsdfs" + assert utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) == [{"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": inference_model_uri["S3DataSource"]["S3Uri"]}] + tags = [{"Key": "some", "Value": "tag"}] inference_model_uri = random_jumpstart_s3_uri("random_key") inference_script_uri = "dfsdfs" From 7effc34956916f85c6c08540fc2d49f009a8b275 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 5 Oct 2023 15:05:15 +0000 Subject: [PATCH 2/2] chore: add additional unit test and fix formatting --- tests/unit/sagemaker/jumpstart/test_utils.py | 22 ++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 6d148163e7..8857641dc4 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -216,7 +216,6 @@ def test_add_jumpstart_tags_inference(): inference_script_uri=inference_script_uri, ) == [{"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": inference_model_uri}] - tags = [] inference_model_uri = {"S3DataSource": {"S3Uri": random_jumpstart_s3_uri("random_key")}} inference_script_uri = "dfsdfs" @@ -224,7 +223,26 @@ def test_add_jumpstart_tags_inference(): tags=tags, inference_model_uri=inference_model_uri, inference_script_uri=inference_script_uri, - ) == [{"Key": JumpStartTag.INFERENCE_MODEL_URI.value, "Value": inference_model_uri["S3DataSource"]["S3Uri"]}] + ) == [ + { + "Key": JumpStartTag.INFERENCE_MODEL_URI.value, + "Value": inference_model_uri["S3DataSource"]["S3Uri"], + } + ] + + tags = [] + inference_model_uri = {"S3DataSource": {"S3Uri": random_jumpstart_s3_uri("random_key/prefix/")}} + inference_script_uri = "dfsdfs" + assert utils.add_jumpstart_tags( + tags=tags, + inference_model_uri=inference_model_uri, + inference_script_uri=inference_script_uri, + ) == [ + { + "Key": JumpStartTag.INFERENCE_MODEL_URI.value, + "Value": inference_model_uri["S3DataSource"]["S3Uri"], + } + ] tags = [{"Key": "some", "Value": "tag"}] inference_model_uri = random_jumpstart_s3_uri("random_key")