From 3d4f67f88de54f9457a2a2b0d5ec3e88d58a7b4e Mon Sep 17 00:00:00 2001 From: Payton Staub Date: Fri, 18 Jun 2021 12:13:19 -0700 Subject: [PATCH 1/2] Add tags argument to RegisterModel step --- src/sagemaker/model.py | 3 +++ src/sagemaker/session.py | 3 +++ src/sagemaker/workflow/_utils.py | 4 ++++ src/sagemaker/workflow/step_collections.py | 5 +++++ tests/unit/sagemaker/workflow/test_step_collections.py | 2 ++ 5 files changed, 17 insertions(+) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index b2f482a4d7..eeb14518dd 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -195,6 +195,7 @@ def _get_model_package_args( marketplace_cert=False, approval_status=None, description=None, + tags=None, ): """Get arguments for session.create_model_package method. @@ -250,6 +251,8 @@ def _get_model_package_args( model_package_args["approval_status"] = approval_status if description is not None: model_package_args["description"] = description + if tags is not None: + model_package_args["tags"] = tags return model_package_args def _init_sagemaker_session_if_does_not_exist(self, instance_type): diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 901d61f086..980a720ac1 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2724,6 +2724,7 @@ def _get_create_model_package_request( marketplace_cert=False, approval_status="PendingManualApproval", description=None, + tags=None, ): """Get request dictionary for CreateModelPackage API. @@ -2761,6 +2762,8 @@ def _get_create_model_package_request( request_dict["ModelPackageGroupName"] = model_package_group_name if description is not None: request_dict["ModelPackageDescription"] = description + if tags is not None: + request_dict["Tags"] = tags if model_metrics: request_dict["ModelMetrics"] = model_metrics if metadata_properties: diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index ceed4e0dec..a2ab24e3da 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -225,6 +225,7 @@ def __init__( compile_model_family=None, description=None, depends_on: List[str] = None, + tags=None, **kwargs, ): """Constructor of a register model step. @@ -264,6 +265,7 @@ def __init__( self.inference_instances = inference_instances self.transform_instances = transform_instances self.model_package_group_name = model_package_group_name + self.tags = tags self.model_metrics = model_metrics self.metadata_properties = metadata_properties self.approval_status = approval_status @@ -324,10 +326,12 @@ def arguments(self) -> RequestType: metadata_properties=self.metadata_properties, approval_status=self.approval_status, description=self.description, + tags=self.tags, ) request_dict = model.sagemaker_session._get_create_model_package_request( **model_package_args ) + # these are not available in the workflow service and will cause rejection if "CertifyForMarketplace" in request_dict: request_dict.pop("CertifyForMarketplace") diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index dd8f32b7fc..6a32ce465f 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -67,6 +67,7 @@ def __init__( image_uri=None, compile_model_family=None, description=None, + tags=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -94,6 +95,9 @@ def __init__( compile_model_family (str): The instance family for the compiled model. If specified, a compiled model is used (default: None). description (str): Model Package description (default: None). + tags (List[dict[str, str]]): The list of tags to attach to the model package group. Note + that tags will only be applied to newly created model package groups; if the + name of an existing group is passed to "model_package_group_name", tags will not be applied. **kwargs: additional arguments to `create_model`. """ steps: List[Step] = [] @@ -134,6 +138,7 @@ def __init__( image_uri=image_uri, compile_model_family=compile_model_family, description=description, + tags=tags, **kwargs, ) if not repack_model: diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 9719e13aec..d1086aec9e 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -182,6 +182,7 @@ def test_register_model(estimator, model_metrics): approval_status="Approved", description="description", depends_on=["TestStep"], + tags=[{"Key": "myKey", "Value": "myValue"}] ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -210,6 +211,7 @@ def test_register_model(estimator, model_metrics): }, "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", + "Tags": [{"Key": "myKey", "Value": "myValue"}] }, }, ] From 3b4cf9686e0470d6bb0d8fd227cb69afc03c8883 Mon Sep 17 00:00:00 2001 From: Payton Staub Date: Fri, 18 Jun 2021 12:27:14 -0700 Subject: [PATCH 2/2] Pylint, black-format --- src/sagemaker/workflow/step_collections.py | 5 +++-- tests/unit/sagemaker/workflow/test_step_collections.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 6a32ce465f..6ee048c0b2 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -96,8 +96,9 @@ def __init__( specified, a compiled model is used (default: None). description (str): Model Package description (default: None). tags (List[dict[str, str]]): The list of tags to attach to the model package group. Note - that tags will only be applied to newly created model package groups; if the - name of an existing group is passed to "model_package_group_name", tags will not be applied. + that tags will only be applied to newly created model package groups; if the + name of an existing group is passed to "model_package_group_name", + tags will not be applied. **kwargs: additional arguments to `create_model`. """ steps: List[Step] = [] diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index d1086aec9e..9ca14f4aaf 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -182,7 +182,7 @@ def test_register_model(estimator, model_metrics): approval_status="Approved", description="description", depends_on=["TestStep"], - tags=[{"Key": "myKey", "Value": "myValue"}] + tags=[{"Key": "myKey", "Value": "myValue"}], ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -211,7 +211,7 @@ def test_register_model(estimator, model_metrics): }, "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", - "Tags": [{"Key": "myKey", "Value": "myValue"}] + "Tags": [{"Key": "myKey", "Value": "myValue"}], }, }, ]