Skip to content

Commit e6dd02d

Browse files
martinRenouakrishna1995
authored andcommitted
More tags formatting and add a test
1 parent b76e165 commit e6dd02d

31 files changed

+167
-119
lines changed

src/sagemaker/algorithm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def transformer(
392392
if self._is_marketplace():
393393
transform_env = None
394394

395-
tags = tags or self.tags
395+
tags = format_tags(tags) or self.tags
396396
else:
397397
raise RuntimeError("No finished training job found associated with this estimator")
398398

src/sagemaker/apiutils/_base_types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.apiutils import _boto_functions, _utils
17+
from sagemaker.utils import format_tags
1718

1819

1920
class ApiObject(object):
@@ -194,13 +195,13 @@ def _set_tags(self, resource_arn=None, tags=None):
194195
195196
Args:
196197
resource_arn (str): The arn of the Record
197-
tags (dict): An array of Tag objects that set to Record
198+
tags (Optional[Tags]): An array of Tag objects that set to Record
198199
199200
Returns:
200201
A list of key, value pair objects. i.e. [{"key":"value"}]
201202
"""
202203
tag_list = self.sagemaker_session.sagemaker_client.add_tags(
203-
ResourceArn=resource_arn, Tags=tags
204+
ResourceArn=resource_arn, Tags=format_tags(tags)
204205
)["Tags"]
205206
return tag_list
206207

src/sagemaker/automl/automl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ def deploy(
580580
be selected on each ``deploy``.
581581
endpoint_name (str): The name of the endpoint to create (default:
582582
None). If not specified, a unique endpoint name will be created.
583-
tags (List[dict[str, str]]): The list of tags to attach to this
583+
tags (Optional[Tags]): The list of tags to attach to this
584584
specific endpoint.
585585
wait (bool): Whether the call should wait until the deployment of
586586
model completes (default: True).
@@ -632,7 +632,7 @@ def deploy(
632632
deserializer=deserializer,
633633
endpoint_name=endpoint_name,
634634
kms_key=model_kms_key,
635-
tags=tags,
635+
tags=format_tags(tags),
636636
wait=wait,
637637
volume_size=volume_size,
638638
model_data_download_timeout=model_data_download_timeout,

src/sagemaker/base_predictor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
NumpySerializer,
5454
)
5555
from sagemaker.session import production_variant, Session
56-
from sagemaker.utils import name_from_base, stringify_object
56+
from sagemaker.utils import name_from_base, stringify_object, format_tags
5757

5858
from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME
5959

@@ -409,7 +409,7 @@ def update_endpoint(
409409
self.sagemaker_session.create_endpoint_config_from_existing(
410410
current_endpoint_config_name,
411411
new_endpoint_config_name,
412-
new_tags=tags,
412+
new_tags=format_tags(tags),
413413
new_kms_key=kms_key,
414414
new_data_capture_config_dict=data_capture_config_dict,
415415
new_production_variants=production_variants,

src/sagemaker/djl_inference/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from sagemaker.s3_utils import s3_path_join
3131
from sagemaker.serializers import JSONSerializer, BaseSerializer
3232
from sagemaker.session import Session
33-
from sagemaker.utils import _tmpdir, _create_or_update_code_dir
33+
from sagemaker.utils import _tmpdir, _create_or_update_code_dir, format_tags
3434
from sagemaker.workflow.entities import PipelineVariable
3535
from sagemaker.estimator import Estimator
3636
from sagemaker.s3 import S3Uploader
@@ -610,7 +610,7 @@ def deploy(
610610
default deserializer is set by the ``predictor_cls``.
611611
endpoint_name (str): The name of the endpoint to create (default:
612612
None). If not specified, a unique endpoint name will be created.
613-
tags (List[dict[str, str]]): The list of tags to attach to this
613+
tags (Optional[Tags]): The list of tags to attach to this
614614
specific endpoint.
615615
kms_key (str): The ARN of the KMS key that is used to encrypt the
616616
data on the storage volume attached to the instance hosting the
@@ -651,7 +651,7 @@ def deploy(
651651
serializer=serializer,
652652
deserializer=deserializer,
653653
endpoint_name=endpoint_name,
654-
tags=tags,
654+
tags=format_tags(tags),
655655
kms_key=kms_key,
656656
wait=wait,
657657
data_capture_config=data_capture_config,

src/sagemaker/experiments/experiment.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.apiutils import _base_types
2121
from sagemaker.experiments.trial import _Trial
2222
from sagemaker.experiments.trial_component import _TrialComponent
23+
from sagemaker.utils import format_tags
2324

2425

2526
class Experiment(_base_types.Record):
@@ -111,7 +112,7 @@ def create(
111112
manages interactions with Amazon SageMaker APIs and any other
112113
AWS services needed. If not specified, one is created using the
113114
default AWS configuration chain.
114-
tags (List[Dict[str, str]]): A list of tags to associate with the experiment
115+
tags (Optional[Tags]): A list of tags to associate with the experiment
115116
(default: None).
116117
117118
Returns:
@@ -122,7 +123,7 @@ def create(
122123
experiment_name=experiment_name,
123124
display_name=display_name,
124125
description=description,
125-
tags=tags,
126+
tags=format_tags(tags),
126127
sagemaker_session=sagemaker_session,
127128
)
128129

@@ -149,7 +150,7 @@ def _load_or_create(
149150
manages interactions with Amazon SageMaker APIs and any other
150151
AWS services needed. If not specified, one is created using the
151152
default AWS configuration chain.
152-
tags (List[Dict[str, str]]): A list of tags to associate with the experiment
153+
tags (Optional[Tags]): A list of tags to associate with the experiment
153154
(default: None). This is used only when the given `experiment_name` does not
154155
exist and a new experiment has to be created.
155156
@@ -161,7 +162,7 @@ def _load_or_create(
161162
experiment_name=experiment_name,
162163
display_name=display_name,
163164
description=description,
164-
tags=tags,
165+
tags=format_tags(tags),
165166
sagemaker_session=sagemaker_session,
166167
)
167168
except ClientError as ce:

src/sagemaker/experiments/trial.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sagemaker.apiutils import _base_types
1919
from sagemaker.experiments import _api_types
2020
from sagemaker.experiments.trial_component import _TrialComponent
21+
from sagemaker.utils import format_tags
2122

2223

2324
class _Trial(_base_types.Record):
@@ -101,7 +102,7 @@ def create(
101102
trial_name: (str): Name of the Trial.
102103
display_name (str): Name of the trial that will appear in UI,
103104
such as SageMaker Studio (default: None).
104-
tags (List[dict]): A list of tags to associate with the trial (default: None).
105+
tags (Optional[Tags]): A list of tags to associate with the trial (default: None).
105106
sagemaker_session (sagemaker.session.Session): Session object which
106107
manages interactions with Amazon SageMaker APIs and any other
107108
AWS services needed. If not specified, one is created using the
@@ -115,7 +116,7 @@ def create(
115116
trial_name=trial_name,
116117
experiment_name=experiment_name,
117118
display_name=display_name,
118-
tags=tags,
119+
tags=format_tags(tags),
119120
sagemaker_session=sagemaker_session,
120121
)
121122
return trial
@@ -259,7 +260,7 @@ def _load_or_create(
259260
display_name (str): Name of the trial that will appear in UI,
260261
such as SageMaker Studio (default: None). This is used only when the given
261262
`trial_name` does not exist and a new trial has to be created.
262-
tags (List[dict]): A list of tags to associate with the trial (default: None).
263+
tags (Optional[Tags]): A list of tags to associate with the trial (default: None).
263264
This is used only when the given `trial_name` does not exist and
264265
a new trial has to be created.
265266
sagemaker_session (sagemaker.session.Session): Session object which
@@ -275,7 +276,7 @@ def _load_or_create(
275276
experiment_name=experiment_name,
276277
trial_name=trial_name,
277278
display_name=display_name,
278-
tags=tags,
279+
tags=format_tags(tags),
279280
sagemaker_session=sagemaker_session,
280281
)
281282
except ClientError as ce:

src/sagemaker/experiments/trial_component.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.apiutils import _base_types
2121
from sagemaker.experiments import _api_types
2222
from sagemaker.experiments._api_types import TrialComponentSearchResult
23+
from sagemaker.utils import format_tags
2324

2425

2526
class _TrialComponent(_base_types.Record):
@@ -191,7 +192,7 @@ def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_se
191192
Args:
192193
trial_component_name (str): The name of the trial component.
193194
display_name (str): Display name of the trial component used by Studio (default: None).
194-
tags (List[Dict[str, str]]): Tags to add to the trial component (default: None).
195+
tags (Optional[Tags]): Tags to add to the trial component (default: None).
195196
sagemaker_session (sagemaker.session.Session): Session object which
196197
manages interactions with Amazon SageMaker APIs and any other
197198
AWS services needed. If not specified, one is created using the
@@ -204,7 +205,7 @@ def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_se
204205
cls._boto_create_method,
205206
trial_component_name=trial_component_name,
206207
display_name=display_name,
207-
tags=tags,
208+
tags=format_tags(tags),
208209
sagemaker_session=sagemaker_session,
209210
)
210211

@@ -316,7 +317,7 @@ def _load_or_create(
316317
display_name (str): Display name of the trial component used by Studio (default: None).
317318
This is used only when the given `trial_component_name` does not
318319
exist and a new trial component has to be created.
319-
tags (List[Dict[str, str]]): Tags to add to the trial component (default: None).
320+
tags (Optional[Tags]): Tags to add to the trial component (default: None).
320321
This is used only when the given `trial_component_name` does not
321322
exist and a new trial component has to be created.
322323
sagemaker_session (sagemaker.session.Session): Session object which
@@ -333,7 +334,7 @@ def _load_or_create(
333334
run_tc = _TrialComponent.create(
334335
trial_component_name=trial_component_name,
335336
display_name=display_name,
336-
tags=tags,
337+
tags=format_tags(tags),
337338
sagemaker_session=sagemaker_session,
338339
)
339340
except ClientError as ce:

src/sagemaker/huggingface/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from sagemaker.predictor import Predictor
3030
from sagemaker.serializers import JSONSerializer
3131
from sagemaker.session import Session
32-
from sagemaker.utils import to_string
32+
from sagemaker.utils import to_string, format_tags
3333
from sagemaker.workflow import is_pipeline_variable
3434
from sagemaker.workflow.entities import PipelineVariable
3535

@@ -255,7 +255,7 @@ def deploy(
255255
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
256256
endpoint_name (str): The name of the endpoint to create (default:
257257
None). If not specified, a unique endpoint name will be created.
258-
tags (List[dict[str, str]]): The list of tags to attach to this
258+
tags (Optional[Tags]): The list of tags to attach to this
259259
specific endpoint.
260260
kms_key (str): The ARN of the KMS key that is used to encrypt the
261261
data on the storage volume attached to the instance hosting the
@@ -319,7 +319,7 @@ def deploy(
319319
deserializer,
320320
accelerator_type,
321321
endpoint_name,
322-
tags,
322+
format_tags(tags),
323323
kms_key,
324324
wait,
325325
data_capture_config,

src/sagemaker/jumpstart/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def _create_sagemaker_model(
388388
attach to an endpoint for model loading and inference, for
389389
example, 'ml.eia1.medium'. If not specified, no Elastic
390390
Inference accelerator will be attached to the endpoint. (Default: None).
391-
tags (List[dict[str, str]]): Optional. The list of tags to add to
391+
tags (Optional[Tags]): Optional. The list of tags to add to
392392
the model. Example: >>> tags = [{'Key': 'tagname', 'Value':
393393
'tagvalue'}] For more information about tags, see
394394
https://boto3.amazonaws.com/v1/documentation
@@ -402,6 +402,8 @@ def _create_sagemaker_model(
402402
any so they are ignored.
403403
"""
404404

405+
tags = format_tags(tags)
406+
405407
# if the user inputs a model artifact uri, do not use model package arn to create
406408
# inference endpoint.
407409
if self.model_package_arn and not self._model_data_is_set:

0 commit comments

Comments
 (0)