1414from __future__ import absolute_import
1515import logging
1616import os
17- from typing import Any , Dict , List , Optional
17+ from typing import Any , Dict , List , Optional , Union
1818from urllib .parse import urlparse
1919from packaging .version import Version
2020import sagemaker
@@ -277,7 +277,7 @@ def get_jumpstart_base_name_if_jumpstart_model(
277277
278278def add_jumpstart_tags (
279279 tags : Optional [List [Dict [str , str ]]] = None ,
280- inference_model_uri : Optional [str ] = None ,
280+ inference_model_uri : Optional [Union [ str , dict ] ] = None ,
281281 inference_script_uri : Optional [str ] = None ,
282282 training_model_uri : Optional [str ] = None ,
283283 training_script_uri : Optional [str ] = None ,
@@ -289,7 +289,7 @@ def add_jumpstart_tags(
289289 Args:
290290 tags (Optional[List[Dict[str,str]]): Current tags for JumpStart inference
291291 or training job. (Default: None).
292- inference_model_uri (Optional[str]): S3 URI for inference model artifact.
292+ inference_model_uri (Optional[Union[dict, str] ]): S3 URI for inference model artifact.
293293 (Default: None).
294294 inference_script_uri (Optional[str]): S3 URI for inference script tarball.
295295 (Default: None).
@@ -302,6 +302,10 @@ def add_jumpstart_tags(
302302 "The URI (%s) is a pipeline variable which is only interpreted at execution time. "
303303 "As a result, the JumpStart resources will not be tagged."
304304 )
305+
306+ if isinstance (inference_model_uri , dict ):
307+ inference_model_uri = inference_model_uri .get ("S3DataSource" , {}).get ("S3Uri" , None )
308+
305309 if inference_model_uri :
306310 if is_pipeline_variable (inference_model_uri ):
307311 logging .warning (warn_msg , "inference_model_uri" )
0 commit comments