9898 to_string ,
9999 check_and_get_run_experiment_config ,
100100 resolve_value_from_config ,
101+ format_tags ,
102+ Tags ,
101103)
102104from sagemaker .workflow import is_pipeline_variable
103105from sagemaker .workflow .entities import PipelineVariable
@@ -144,7 +146,7 @@ def __init__(
144146 output_kms_key : Optional [Union [str , PipelineVariable ]] = None ,
145147 base_job_name : Optional [str ] = None ,
146148 sagemaker_session : Optional [Session ] = None ,
147- tags : Optional [List [ Dict [ str , Union [ str , PipelineVariable ]]] ] = None ,
149+ tags : Optional [Tags ] = None ,
148150 subnets : Optional [List [Union [str , PipelineVariable ]]] = None ,
149151 security_group_ids : Optional [List [Union [str , PipelineVariable ]]] = None ,
150152 model_uri : Optional [str ] = None ,
@@ -269,8 +271,8 @@ def __init__(
269271 manages interactions with Amazon SageMaker APIs and any other
270272 AWS services needed. If not specified, the estimator creates one
271273 using the default AWS configuration chain.
272- tags (list[dict[str, str] or list[dict[str, PipelineVariable] ]):
273- List of tags for labeling a training job. For more, see
274+ tags (Optional[Tags ]):
275+ Tags for labeling a training job. For more, see
274276 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
275277 subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not
276278 specified training job will be created without VPC config.
@@ -601,6 +603,7 @@ def __init__(
601603 else :
602604 self .sagemaker_session = sagemaker_session or Session ()
603605
606+ tags = format_tags (tags )
604607 self .tags = (
605608 add_jumpstart_uri_tags (
606609 tags = tags , training_model_uri = self .model_uri , training_script_uri = self .source_dir
@@ -1347,7 +1350,7 @@ def compile_model(
13471350 framework = None ,
13481351 framework_version = None ,
13491352 compile_max_run = 15 * 60 ,
1350- tags = None ,
1353+ tags : Optional [ Tags ] = None ,
13511354 target_platform_os = None ,
13521355 target_platform_arch = None ,
13531356 target_platform_accelerator = None ,
@@ -1373,7 +1376,7 @@ def compile_model(
13731376 compile_max_run (int): Timeout in seconds for compilation (default:
13741377 15 * 60). After this amount of time Amazon SageMaker Neo
13751378 terminates the compilation job regardless of its current status.
1376- tags (list[dict]): List of tags for labeling a compilation job. For
1379+ tags (list[dict]): Tags for labeling a compilation job. For
13771380 more, see
13781381 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
13791382 target_platform_os (str): Target Platform OS, for example: 'LINUX'.
@@ -1415,7 +1418,7 @@ def compile_model(
14151418 input_shape ,
14161419 output_path ,
14171420 self .role ,
1418- tags ,
1421+ format_tags ( tags ) ,
14191422 self ._compilation_job_name (),
14201423 compile_max_run ,
14211424 framework = framework ,
@@ -1527,7 +1530,7 @@ def deploy(
15271530 model_name = None ,
15281531 kms_key = None ,
15291532 data_capture_config = None ,
1530- tags = None ,
1533+ tags : Optional [ Tags ] = None ,
15311534 serverless_inference_config = None ,
15321535 async_inference_config = None ,
15331536 volume_size = None ,
@@ -1596,8 +1599,10 @@ def deploy(
15961599 empty object passed through, will use pre-defined values in
15971600 ``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
15981601 instance based endpoint if it's None. (default: None)
1599- tags(List[dict[str, str]] ): Optional. The list of tags to attach to this specific
1602+ tags(Optional[Tags] ): Optional. Tags to attach to this specific
16001603 endpoint. Example:
1604+ >>> tags = {'tagname', 'tagvalue'}
1605+ Or
16011606 >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
16021607 For more information about tags, see
16031608 https://boto3.amazonaws.com/v1/documentation\
@@ -1659,7 +1664,7 @@ def deploy(
16591664 model .name = model_name
16601665
16611666 tags = update_inference_tags_with_jumpstart_training_tags (
1662- inference_tags = tags , training_tags = self .tags
1667+ inference_tags = format_tags ( tags ) , training_tags = self .tags
16631668 )
16641669
16651670 return model .deploy (
@@ -2007,7 +2012,7 @@ def transformer(
20072012 env = None ,
20082013 max_concurrent_transforms = None ,
20092014 max_payload = None ,
2010- tags = None ,
2015+ tags : Optional [ Tags ] = None ,
20112016 role = None ,
20122017 volume_kms_key = None ,
20132018 vpc_config_override = vpc_utils .VPC_CONFIG_DEFAULT ,
@@ -2041,7 +2046,7 @@ def transformer(
20412046 to be made to each individual transform container at one time.
20422047 max_payload (int): Maximum size of the payload in a single HTTP
20432048 request to the container in MB.
2044- tags (list[dict ]): List of tags for labeling a transform job. If
2049+ tags (Optional[Tags ]): Tags for labeling a transform job. If
20452050 none specified, then the tags used for the training job are used
20462051 for the transform job.
20472052 role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
@@ -2068,7 +2073,7 @@ def transformer(
20682073 model. If not specified, the estimator generates a default job name
20692074 based on the training image name and current timestamp.
20702075 """
2071- tags = tags or self .tags
2076+ tags = format_tags ( tags ) or self .tags
20722077 model_name = self ._get_or_create_name (model_name )
20732078
20742079 if self .latest_training_job is None :
@@ -2661,7 +2666,7 @@ def __init__(
26612666 base_job_name : Optional [str ] = None ,
26622667 sagemaker_session : Optional [Session ] = None ,
26632668 hyperparameters : Optional [Dict [str , Union [str , PipelineVariable ]]] = None ,
2664- tags : Optional [List [ Dict [ str , Union [ str , PipelineVariable ]]] ] = None ,
2669+ tags : Optional [Tags ] = None ,
26652670 subnets : Optional [List [Union [str , PipelineVariable ]]] = None ,
26662671 security_group_ids : Optional [List [Union [str , PipelineVariable ]]] = None ,
26672672 model_uri : Optional [str ] = None ,
@@ -2790,7 +2795,7 @@ def __init__(
27902795 hyperparameters. SageMaker rejects the training job request and returns an
27912796 validation error for detected credentials, if such user input is found.
27922797
2793- tags (list[dict[str, str] or list[dict[str, PipelineVariable]] ): List of tags for
2798+ tags (Optional[Tags] ): Tags for
27942799 labeling a training job. For more, see
27952800 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
27962801 subnets (list[str] or list[PipelineVariable]): List of subnet ids.
@@ -3071,7 +3076,7 @@ def __init__(
30713076 output_kms_key ,
30723077 base_job_name ,
30733078 sagemaker_session ,
3074- tags ,
3079+ format_tags ( tags ) ,
30753080 subnets ,
30763081 security_group_ids ,
30773082 model_uri = model_uri ,
@@ -3702,7 +3707,7 @@ def transformer(
37023707 env = None ,
37033708 max_concurrent_transforms = None ,
37043709 max_payload = None ,
3705- tags = None ,
3710+ tags : Optional [ Tags ] = None ,
37063711 role = None ,
37073712 model_server_workers = None ,
37083713 volume_kms_key = None ,
@@ -3738,7 +3743,7 @@ def transformer(
37383743 to be made to each individual transform container at one time.
37393744 max_payload (int): Maximum size of the payload in a single HTTP
37403745 request to the container in MB.
3741- tags (list[dict ]): List of tags for labeling a transform job. If
3746+ tags (Optional[Tags ]): Tags for labeling a transform job. If
37423747 none specified, then the tags used for the training job are used
37433748 for the transform job.
37443749 role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
@@ -3777,7 +3782,7 @@ def transformer(
37773782 SageMaker Batch Transform job.
37783783 """
37793784 role = role or self .role
3780- tags = tags or self .tags
3785+ tags = format_tags ( tags ) or self .tags
37813786 model_name = self ._get_or_create_name (model_name )
37823787
37833788 if self .latest_training_job is not None :
0 commit comments