1717from urllib .parse import urlparse
1818from packaging .version import Version
1919import sagemaker
20- from sagemaker .jumpstart import constants
20+ from sagemaker .jumpstart import constants , enums
2121from sagemaker .jumpstart import accessors
2222from sagemaker .s3 import parse_s3_url
2323from sagemaker .jumpstart .exceptions import (
@@ -200,13 +200,13 @@ def get_tag_value(tag_key: str, tag_array: List[Dict[str, str]]) -> str:
200200
201201
202202def add_single_jumpstart_tag (
203- uri : str , tag_key : constants .JumpStartTag , curr_tags : Optional [List [Dict [str , str ]]]
203+ uri : str , tag_key : enums .JumpStartTag , curr_tags : Optional [List [Dict [str , str ]]]
204204) -> Optional [List ]:
205205 """Adds ``tag_key`` to ``curr_tags`` if ``uri`` corresponds to a JumpStart model.
206206
207207 Args:
208208 uri (str): URI which may correspond to a JumpStart model.
209- tag_key (constants .JumpStartTag): Custom tag to apply to current tags if the URI
209+ tag_key (enums .JumpStartTag): Custom tag to apply to current tags if the URI
210210 corresponds to a JumpStart model.
211211 curr_tags (Optional[List]): Current tags associated with ``Estimator`` or ``Model``.
212212 """
@@ -249,22 +249,22 @@ def add_jumpstart_tags(
249249
250250 if inference_model_uri :
251251 tags = add_single_jumpstart_tag (
252- inference_model_uri , constants .JumpStartTag .INFERENCE_MODEL_URI , tags
252+ inference_model_uri , enums .JumpStartTag .INFERENCE_MODEL_URI , tags
253253 )
254254
255255 if inference_script_uri :
256256 tags = add_single_jumpstart_tag (
257- inference_script_uri , constants .JumpStartTag .INFERENCE_SCRIPT_URI , tags
257+ inference_script_uri , enums .JumpStartTag .INFERENCE_SCRIPT_URI , tags
258258 )
259259
260260 if training_model_uri :
261261 tags = add_single_jumpstart_tag (
262- training_model_uri , constants .JumpStartTag .TRAINING_MODEL_URI , tags
262+ training_model_uri , enums .JumpStartTag .TRAINING_MODEL_URI , tags
263263 )
264264
265265 if training_script_uri :
266266 tags = add_single_jumpstart_tag (
267- training_script_uri , constants .JumpStartTag .TRAINING_SCRIPT_URI , tags
267+ training_script_uri , enums .JumpStartTag .TRAINING_SCRIPT_URI , tags
268268 )
269269
270270 return tags
@@ -280,7 +280,7 @@ def update_inference_tags_with_jumpstart_training_tags(
280280 training_tags (Optional[List[Dict[str, str]]]): Tags from training job.
281281 """
282282 if training_tags :
283- for tag_key in constants .JumpStartTag :
283+ for tag_key in enums .JumpStartTag :
284284 if tag_key_in_array (tag_key , training_tags ):
285285 tag_value = get_tag_value (tag_key , training_tags )
286286 if inference_tags is None :
0 commit comments