From 17ab3342e8f5ccb61a925db928964a2d610d3bbc Mon Sep 17 00:00:00 2001 From: Keshav Chandak Date: Wed, 12 Apr 2023 08:31:39 +0000 Subject: [PATCH 1/2] Feat: Add/Remove model package group from collection --- src/sagemaker/collection.py | 211 ++++++++++++++++++++++++++++----- tests/integ/test_collection.py | 106 +++++++++++++++++ tests/unit/test_collection.py | 52 +++++++- 3 files changed, 341 insertions(+), 28 deletions(-) diff --git a/src/sagemaker/collection.py b/src/sagemaker/collection.py index 384119eb1e..5953d742cc 100644 --- a/src/sagemaker/collection.py +++ b/src/sagemaker/collection.py @@ -56,6 +56,41 @@ def _check_access_error(self, err: ClientError): "https://docs.aws.amazon.com/sagemaker/latest/dg/modelcollections-permissions.html" ) + def _add_model_group(self, model_package_group, tag_rule_key, tag_rule_value): + """To add a model package group to a collection + + Args: + model_package_group (str): The name of the model package group + tag_rule_key (str): The tag key of the corresponding collection to be added into + tag_rule_value (str): The tag value of the corresponding collection to be added into + """ + model_group_details = self.sagemaker_session.sagemaker_client.describe_model_package_group( + ModelPackageGroupName=model_package_group + ) + self.sagemaker_session.sagemaker_client.add_tags( + ResourceArn=model_group_details["ModelPackageGroupArn"], + Tags=[ + { + "Key": tag_rule_key, + "Value": tag_rule_value, + } + ], + ) + + def _remove_model_group(self, model_package_group, tag_rule_key): + """To remove a model package group from a collection + + Args: + model_package_group (str): The name of the model package group + tag_rule_key (str): The tag key of the corresponding collection to be removed from + """ + model_group_details = self.sagemaker_session.sagemaker_client.describe_model_package_group( + ModelPackageGroupName=model_package_group + ) + self.sagemaker_session.sagemaker_client.delete_tags( + 
ResourceArn=model_group_details["ModelPackageGroupArn"], TagKeys=[tag_rule_key] + ) + def create(self, collection_name: str, parent_collection_name: str = None): """Creates a collection @@ -65,7 +100,7 @@ def create(self, collection_name: str, parent_collection_name: str = None): To be None if the collection is to be created on the root level """ - tag_rule_key = f"sagemaker:collection-path:{time.time()}" + tag_rule_key = f"sagemaker:collection-path:{int(time.time() * 1000)}" tags_on_collection = { "sagemaker:collection": "true", "sagemaker:collection-path:root": "true", @@ -73,30 +108,14 @@ def create(self, collection_name: str, parent_collection_name: str = None): tag_rule_values = [collection_name] if parent_collection_name is not None: - try: - group_query = self.sagemaker_session.get_resource_group_query( - group=parent_collection_name - ) - except ClientError as e: - error_code = e.response["Error"]["Code"] - - if error_code == "NotFoundException": - raise ValueError(f"Cannot find collection: {parent_collection_name}") - self._check_access_error(err=e) - raise - if group_query.get("GroupQuery"): - parent_tag_rule_query = json.loads( - group_query["GroupQuery"].get("ResourceQuery", {}).get("Query", "") - ) - parent_tag_rule = parent_tag_rule_query.get("TagFilters", [])[0] - if not parent_tag_rule: - raise "Invalid parent_collection_name" - parent_tag_value = parent_tag_rule["Values"][0] - tags_on_collection = { - parent_tag_rule["Key"]: parent_tag_value, - "sagemaker:collection": "true", - } - tag_rule_values = [f"{parent_tag_value}/{collection_name}"] + parent_tag_rules = self._get_collection_tag_rule(collection_name=parent_collection_name) + parent_tag_rule_key = parent_tag_rules["tag_rule_key"] + parent_tag_value = parent_tag_rules["tag_rule_value"] + tags_on_collection = { + parent_tag_rule_key: parent_tag_value, + "sagemaker:collection": "true", + } + tag_rule_values = [f"{parent_tag_value}/{collection_name}"] try: resource_filters = [ 
"AWS::SageMaker::ModelPackageGroup", @@ -122,7 +141,6 @@ def create(self, collection_name: str, parent_collection_name: str = None): "Name": collection_create_response["Group"]["Name"], "Arn": collection_create_response["Group"]["GroupArn"], } - except ClientError as e: message = e.response["Error"]["Message"] error_code = e.response["Error"]["Code"] @@ -134,7 +152,7 @@ def create(self, collection_name: str, parent_collection_name: str = None): raise def delete(self, collections: List[str]): - """Deletes a lits of collection + """Deletes a list of collection. Args: collections (List[str]): List of collections to be deleted @@ -152,6 +170,8 @@ def delete(self, collections: List[str]): "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], }, ] + + # loops over the list of collection and deletes one at a time. for collection in collections: try: collection_details = self.sagemaker_session.list_group_resources( @@ -180,3 +200,140 @@ def delete(self, collections: List[str]): "deleted_collections": deleted_collection, "delete_collection_failures": delete_collection_failures, } + + def _get_collection_tag_rule(self, collection_name: str): + """Returns the tag rule key and value for a collection""" + + if collection_name is not None: + try: + group_query = self.sagemaker_session.get_resource_group_query(group=collection_name) + except ClientError as e: + error_code = e.response["Error"]["Code"] + + if error_code == "NotFoundException": + raise ValueError(f"Cannot find collection: {collection_name}") + self._check_access_error(err=e) + raise + if group_query.get("GroupQuery"): + tag_rule_query = json.loads( + group_query["GroupQuery"].get("ResourceQuery", {}).get("Query", "") + ) + tag_rule = next(iter(tag_rule_query.get("TagFilters") or []), None) + if not tag_rule: + raise ValueError("Unsupported parent_collection_name") + tag_rule_value = tag_rule["Values"][0] + tag_rule_key = tag_rule["Key"] + + return { + "tag_rule_key": tag_rule_key, + "tag_rule_value": tag_rule_value, + 
} + raise ValueError("Collection name is required") + + def add_model_groups(self, collection_name: str, model_groups: List[str]): + """To add list of model package groups to a collection + + Args: + collection_name (str): The name of the collection + model_groups List[str]: Model package group names list to be added into the collection + """ + if len(model_groups) > 10: + raise Exception("Model groups can have a maximum length of 10") + tag_rules = self._get_collection_tag_rule(collection_name=collection_name) + tag_rule_key = tag_rules["tag_rule_key"] + tag_rule_value = tag_rules["tag_rule_value"] + + add_groups_success = [] + add_groups_failure = [] + if tag_rule_key is not None and tag_rule_value is not None: + for model_group in model_groups: + try: + self._add_model_group( + model_package_group=model_group, + tag_rule_key=tag_rule_key, + tag_rule_value=tag_rule_value, + ) + add_groups_success.append(model_group) + except ClientError as e: + self._check_access_error(err=e) + message = e.response["Error"]["Message"] + add_groups_failure.append( + { + "model_group": model_group, + "failure_reason": message, + } + ) + return { + "added_groups": add_groups_success, + "failure": add_groups_failure, + } + + def remove_model_groups(self, collection_name: str, model_groups: List[str]): + """To remove list of model package groups from a collection + + Args: + collection_name (str): The name of the collection + model_groups List[str]: Model package group names list to be removed + """ + + if len(model_groups) > 10: + raise Exception("Model groups can have a maximum length of 10") + tag_rules = self._get_collection_tag_rule(collection_name=collection_name) + + tag_rule_key = tag_rules["tag_rule_key"] + tag_rule_value = tag_rules["tag_rule_value"] + + remove_groups_success = [] + remove_groups_failure = [] + if tag_rule_key is not None and tag_rule_value is not None: + for model_group in model_groups: + try: + self._remove_model_group( + model_package_group=model_group, + 
tag_rule_key=tag_rule_key, + ) + remove_groups_success.append(model_group) + except ClientError as e: + self._check_access_error(err=e) + message = e.response["Error"]["Message"] + remove_groups_failure.append( + { + "model_group": model_group, + "failure_reason": message, + } + ) + return { + "removed_groups": remove_groups_success, + "failure": remove_groups_failure, + } + + def move_model_group( + self, source_collection_name: str, model_group: str, destination_collection_name: str + ): + """To move a model package group from one collection to another + + Args: + source_collection_name (str): Collection name of the source + model_group (str): Name of the model package group to be moved + destination_collection_name (str): Collection name of the destination + """ + remove_details = self.remove_model_groups( + collection_name=source_collection_name, model_groups=[model_group] + ) + if len(remove_details["failure"]) == 1: + raise Exception(remove_details["failure"][0]["failure_reason"]) + + added_details = self.add_model_groups( + collection_name=destination_collection_name, model_groups=[model_group] + ) + + if len(added_details["failure"]) == 1: + # adding the model group back to the source collection in case of an add failure + self.add_model_groups( + collection_name=source_collection_name, model_groups=[model_group] + ) + raise Exception(added_details["failure"][0]["failure_reason"]) + + return { + "moved_success": model_group, + } diff --git a/tests/integ/test_collection.py b/tests/integ/test_collection.py index d1dfcec0ec..91d82ef579 100644 --- a/tests/integ/test_collection.py +++ b/tests/integ/test_collection.py @@ -60,3 +60,109 @@ def test_create_collection_nested_success(sagemaker_session): delete_response = collection.delete([child_collection_name, collection_name]) assert len(delete_response["deleted_collections"]) == 2 assert len(delete_response["delete_collection_failures"]) == 0 + + +def test_add_remove_model_groups_in_collection_success(sagemaker_session): + 
model_group_name = unique_name_from_base("test-model-group") + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + collection = Collection(sagemaker_session) + collection_name = unique_name_from_base("test-collection") + collection.create(collection_name) + model_groups = [] + model_groups.append(model_group_name) + add_response = collection.add_model_groups( + collection_name=collection_name, model_groups=model_groups + ) + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + + assert len(add_response["failure"]) == 0 + assert len(add_response["added_groups"]) == 1 + assert len(collection_details["Resources"]) == 1 + + remove_response = collection.remove_model_groups( + collection_name=collection_name, model_groups=model_groups + ) + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + assert len(remove_response["failure"]) == 0 + assert len(remove_response["removed_groups"]) == 1 + assert len(collection_details["Resources"]) == 0 + + delete_response = collection.delete([collection_name]) + assert len(delete_response["deleted_collections"]) == 1 + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) + + +def test_move_model_groups_in_collection_success(sagemaker_session): + model_group_name = unique_name_from_base("test-model-group") + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + collection = Collection(sagemaker_session) + source_collection_name = unique_name_from_base("test-collection-source") + destination_collection_name = unique_name_from_base("test-collection-destination") + 
collection.create(source_collection_name) + collection.create(destination_collection_name) + model_groups = [] + model_groups.append(model_group_name) + add_response = collection.add_model_groups( + collection_name=source_collection_name, model_groups=model_groups + ) + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=source_collection_name, filters=collection_filter + ) + + assert len(add_response["failure"]) == 0 + assert len(add_response["added_groups"]) == 1 + assert len(collection_details["Resources"]) == 1 + + move_response = collection.move_model_group( + source_collection_name=source_collection_name, + model_group=model_group_name, + destination_collection_name=destination_collection_name, + ) + + assert move_response["moved_success"] == model_group_name + + collection_details = sagemaker_session.list_group_resources( + group=destination_collection_name, filters=collection_filter + ) + + assert len(collection_details["Resources"]) == 1 + + collection_details = sagemaker_session.list_group_resources( + group=source_collection_name, filters=collection_filter + ) + assert len(collection_details["Resources"]) == 0 + + remove_response = collection.remove_model_groups( + collection_name=destination_collection_name, model_groups=model_groups + ) + + assert len(remove_response["failure"]) == 0 + assert len(remove_response["removed_groups"]) == 1 + + delete_response = collection.delete([source_collection_name, destination_collection_name]) + assert len(delete_response["deleted_collections"]) == 2 + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index 81941c0a9c..f62c141958 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -26,6 +26,9 @@ 
{"Key": "sagemaker:collection-path:1676120428.4811652", "Values": ["test-collection-k"]} ], } +DESCRIBE_MODEL_PACKAGE_GROUP = { + "ModelPackageGroupArn": "arn:aws:resource-groups:us-west-2:205984106344:group/group}" +} CREATE_COLLECTION_RESPONSE = { "Group": { "GroupArn": f"arn:aws:resource-groups:us-west-2:205984106344:group/{COLLECTION_NAME}", @@ -38,6 +41,14 @@ "Tags": {"sagemaker:collection-path:root": "true"}, } +GROUP_QUERY_RESPONSE = { + "GroupQuery": { + "ResourceQuery": { + "Query": '{"TagFilters": [{"Key": "key", "Values": ["value"]}]}', + } + } +} + @pytest.fixture() def sagemaker_session(): @@ -55,7 +66,12 @@ def sagemaker_session(): ) session_mock.delete_resource_group = Mock(name="delete_resource_group", return_value=True) session_mock.list_group_resources = Mock(name="list_group_resources", return_value={}) - + session_mock.get_resource_group_query = Mock( + name="get_resource_group_query", return_value=GROUP_QUERY_RESPONSE + ) + session_mock.sagemaker_client.describe_model_package_group = Mock( + name="describe_model_package_group", return_value=DESCRIBE_MODEL_PACKAGE_GROUP + ) return session_mock @@ -81,3 +97,37 @@ def test_delete_collection_failure_when_collection_is_not_empty(sagemaker_sessio delete_response = collection.delete(collections=[COLLECTION_NAME]) assert len(delete_response["deleted_collections"]) == 0 assert len(delete_response["delete_collection_failures"]) == 1 + + +def test_add_model_groups_success(sagemaker_session): + collection = Collection(sagemaker_session) + add_response = collection.add_model_groups( + collection_name=[COLLECTION_NAME], model_groups=["test-model-group"] + ) + assert len(add_response["added_groups"]) == 1 + assert len(add_response["failure"]) == 0 + + +def test_remove_model_groups_success(sagemaker_session): + collection = Collection(sagemaker_session) + add_response = collection.remove_model_groups( + collection_name=[COLLECTION_NAME], model_groups=["test-model-group"] + ) + assert 
len(add_response["removed_groups"]) == 1 + assert len(add_response["failure"]) == 0 + + +def test_add_and_remove_model_groups_limit(sagemaker_session): + collection = Collection(sagemaker_session) + model_groups = [] + for i in range(11): + model_groups.append(f"test-model-group{i}") + try: + collection.add_model_groups(collection_name=[COLLECTION_NAME], model_groups=model_groups) + except Exception as e: + assert "Model groups can have a maximum length of 10" in str(e) + + try: + collection.remove_model_groups(collection_name=[COLLECTION_NAME], model_groups=model_groups) + except Exception as e: + assert "Model groups can have a maximum length of 10" in str(e) From dd77850faac7828642301a79c0347b4416a7c631 Mon Sep 17 00:00:00 2001 From: Keshav Chandak Date: Wed, 12 Apr 2023 08:39:06 +0000 Subject: [PATCH 2/2] Feat: list model collection --- src/sagemaker/collection.py | 125 ++++++++++++++++++++++++++++++++- src/sagemaker/session.py | 40 ++++++++++- tests/integ/test_collection.py | 30 ++++++++ 3 files changed, 192 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/collection.py b/src/sagemaker/collection.py index 5953d742cc..7703b14b4d 100644 --- a/src/sagemaker/collection.py +++ b/src/sagemaker/collection.py @@ -147,7 +147,6 @@ def create(self, collection_name: str, parent_collection_name: str = None): if error_code == "BadRequestException" and "group already exists" in message: raise ValueError("Collection with the given name already exists") - self._check_access_error(err=e) raise @@ -337,3 +336,127 @@ def move_model_group( return { "moved_success": model_group, } + + def _convert_tag_collection_response(self, tag_collections: List[str]): + """Converts collection response from tag api to collection list response + + Args: + tag_collections List[dict]: Collections list response from tag api + """ + collection_details = [] + for collection in tag_collections: + collection_arn = collection["ResourceARN"] + collection_name = collection_arn.split("group/")[1] + 
collection_details.append( + { + "Name": collection_name, + "Arn": collection_arn, + "Type": "Collection", + } + ) + return collection_details + + def _convert_group_resource_response( + self, group_resource_details: List[dict], is_model_group: bool = False + ): + """Converts collection response from resource group api to collection list response + + Args: + group_resource_details (List[dict]): Collections list response from resource group api + is_model_group (bool): If the response is of collection or model group type + """ + collection_details = [] + if group_resource_details["Resources"]: + for resource_group in group_resource_details["Resources"]: + collection_arn = resource_group["Identifier"]["ResourceArn"] + collection_name = collection_arn.split("group/")[1] + collection_details.append( + { + "Name": collection_name, + "Arn": collection_arn, + "Type": resource_group["Identifier"]["ResourceType"] + if is_model_group + else "Collection", + } + ) + return collection_details + + def _get_full_list_resource(self, collection_name, collection_filter): + """Iterates over the full resource group list and returns the appended paginated response + + Args: + collection_name (str): Name of the collection to get the details + collection_filter (dict): Filter details to be passed to get the resource list + + """ + list_group_response = self.sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + next_token = list_group_response.get("NextToken") + while next_token is not None: + + paginated_group_response = self.sagemaker_session.list_group_resources( + group=collection_name, + filters=collection_filter, + next_token=next_token, + ) + list_group_response["Resources"] = ( + list_group_response["Resources"] + paginated_group_response["Resources"] + ) + list_group_response["ResourceIdentifiers"] = ( + list_group_response["ResourceIdentifiers"] + + paginated_group_response["ResourceIdentifiers"] + ) + next_token = 
paginated_group_response.get("NextToken") + + return list_group_response + + def list_collection(self, collection_name: str = None): + """To list all the collections and the content of a collection + + In case there is no collection_name, it lists all the collections on the root level + + Args: + collection_name (str): The name of the collection to list the contents of + """ + collection_content = [] + if collection_name is None: + tag_filters = [ + { + "Key": "sagemaker:collection-path:root", + "Values": ["true"], + }, + ] + resource_type_filters = ["resource-groups:group"] + tag_collections = self.sagemaker_session.get_tagging_resources( + tag_filters=tag_filters, resource_type_filters=resource_type_filters + ) + + return self._convert_tag_collection_response(tag_collections) + + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group"], + }, + ] + list_group_response = self._get_full_list_resource( + collection_name=collection_name, collection_filter=collection_filter + ) + collection_content = self._convert_group_resource_response(list_group_response) + + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::SageMaker::ModelPackageGroup"], + }, + ] + list_group_response = self._get_full_list_resource( + collection_name=collection_name, collection_filter=collection_filter + ) + + collection_content = collection_content + self._convert_group_resource_response( + list_group_response, True + ) + + return collection_content diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 9dfe43cfb3..c0e9112d7c 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -205,6 +205,7 @@ def __init__( self.s3_resource = None self.s3_client = None self.resource_groups_client = None + self.resource_group_tagging_client = None self.config = None self.lambda_client = None self.settings = settings @@ -3962,7 +3963,7 @@ def delete_model(self, model_name): LOGGER.info("Deleting model with name: %s", 
model_name) self.sagemaker_client.delete_model(ModelName=model_name) - def list_group_resources(self, group, filters): + def list_group_resources(self, group, filters, next_token: str = ""): """To list group resources with given filters Args: @@ -3972,7 +3973,9 @@ def list_group_resources(self, group, filters): self.resource_groups_client = self.resource_groups_client or self.boto_session.client( "resource-groups" ) - return self.resource_groups_client.list_group_resources(Group=group, Filters=filters) + return self.resource_groups_client.list_group_resources( + Group=group, Filters=filters, NextToken=next_token + ) def delete_resource_group(self, group): """To delete a resource group @@ -3996,6 +3999,39 @@ def get_resource_group_query(self, group): ) return self.resource_groups_client.get_group_query(Group=group) + def get_tagging_resources(self, tag_filters, resource_type_filters): + """To list the complete resources for a particular resource group tag + + tag_filters: filters for the tag + resource_type_filters: resource filter for the tag + """ + self.resource_group_tagging_client = ( + self.resource_group_tagging_client + or self.boto_session.client("resourcegroupstaggingapi") + ) + resource_list = [] + + try: + resource_tag_response = self.resource_group_tagging_client.get_resources( + TagFilters=tag_filters, ResourceTypeFilters=resource_type_filters + ) + + resource_list = resource_list + resource_tag_response["ResourceTagMappingList"] + + next_token = resource_tag_response.get("PaginationToken") + while next_token is not None and next_token != "": + resource_tag_response = self.resource_group_tagging_client.get_resources( + TagFilters=tag_filters, + ResourceTypeFilters=resource_type_filters, + PaginationToken=next_token, + ) + resource_list = resource_list + resource_tag_response["ResourceTagMappingList"] + next_token = resource_tag_response.get("PaginationToken") + + return resource_list + except ClientError as error: + raise error + def create_group(self, name, 
resource_query, tags): """To create a AWS Resource Group diff --git a/tests/integ/test_collection.py b/tests/integ/test_collection.py index 91d82ef579..2ee1d90e34 100644 --- a/tests/integ/test_collection.py +++ b/tests/integ/test_collection.py @@ -166,3 +166,33 @@ def test_move_model_groups_in_collection_success(sagemaker_session): sagemaker_session.sagemaker_client.delete_model_package_group( ModelPackageGroupName=model_group_name ) + + +def test_list_collection_success(sagemaker_session): + model_group_name = unique_name_from_base("test-model-group") + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + collection = Collection(sagemaker_session) + collection_name = unique_name_from_base("test-collection") + collection.create(collection_name) + model_groups = [] + model_groups.append(model_group_name) + collection.add_model_groups(collection_name=collection_name, model_groups=model_groups) + child_collection_name = unique_name_from_base("test-collection") + collection.create(parent_collection_name=collection_name, collection_name=child_collection_name) + root_collections = collection.list_collection() + is_collection_found = False + for root_collection in root_collections: + if root_collection["Name"] == collection_name: + is_collection_found = True + assert is_collection_found + + collection_content = collection.list_collection(collection_name) + assert len(collection_content) == 2 + + collection.remove_model_groups(collection_name=collection_name, model_groups=model_groups) + collection.delete([child_collection_name, collection_name]) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + )