@@ -56,6 +56,41 @@ def _check_access_error(self, err: ClientError):
5656 "https://docs.aws.amazon.com/sagemaker/latest/dg/modelcollections-permissions.html"
5757 )
5858
59+ def _add_model_group (self , model_package_group , tag_rule_key , tag_rule_value ):
60+ """To add a model package group to a collection
61+
62+ Args:
63+ model_package_group (str): The name of the model package group
64+ tag_rule_key (str): The tag key of the corresponing collection to be added into
65+ tag_rule_value (str): The tag value of the corresponing collection to be added into
66+ """
67+ model_group_details = self .sagemaker_session .sagemaker_client .describe_model_package_group (
68+ ModelPackageGroupName = model_package_group
69+ )
70+ self .sagemaker_session .sagemaker_client .add_tags (
71+ ResourceArn = model_group_details ["ModelPackageGroupArn" ],
72+ Tags = [
73+ {
74+ "Key" : tag_rule_key ,
75+ "Value" : tag_rule_value ,
76+ }
77+ ],
78+ )
79+
80+ def _remove_model_group (self , model_package_group , tag_rule_key ):
81+ """To remove a model package group from a collection
82+
83+ Args:
84+ model_package_group (str): The name of the model package group
85+ tag_rule_key (str): The tag key of the corresponing collection to be removed from
86+ """
87+ model_group_details = self .sagemaker_session .sagemaker_client .describe_model_package_group (
88+ ModelPackageGroupName = model_package_group
89+ )
90+ self .sagemaker_session .sagemaker_client .delete_tags (
91+ ResourceArn = model_group_details ["ModelPackageGroupArn" ], TagKeys = [tag_rule_key ]
92+ )
93+
5994 def create (self , collection_name : str , parent_collection_name : str = None ):
6095 """Creates a collection
6196
@@ -65,38 +100,22 @@ def create(self, collection_name: str, parent_collection_name: str = None):
65100 To be None if the collection is to be created on the root level
66101 """
67102
68- tag_rule_key = f"sagemaker:collection-path:{ time .time ()} "
103+ tag_rule_key = f"sagemaker:collection-path:{ int ( time .time () * 1000 )} "
69104 tags_on_collection = {
70105 "sagemaker:collection" : "true" ,
71106 "sagemaker:collection-path:root" : "true" ,
72107 }
73108 tag_rule_values = [collection_name ]
74109
75110 if parent_collection_name is not None :
76- try :
77- group_query = self .sagemaker_session .get_resource_group_query (
78- group = parent_collection_name
79- )
80- except ClientError as e :
81- error_code = e .response ["Error" ]["Code" ]
82-
83- if error_code == "NotFoundException" :
84- raise ValueError (f"Cannot find collection: { parent_collection_name } " )
85- self ._check_access_error (err = e )
86- raise
87- if group_query .get ("GroupQuery" ):
88- parent_tag_rule_query = json .loads (
89- group_query ["GroupQuery" ].get ("ResourceQuery" , {}).get ("Query" , "" )
90- )
91- parent_tag_rule = parent_tag_rule_query .get ("TagFilters" , [])[0 ]
92- if not parent_tag_rule :
93- raise "Invalid parent_collection_name"
94- parent_tag_value = parent_tag_rule ["Values" ][0 ]
95- tags_on_collection = {
96- parent_tag_rule ["Key" ]: parent_tag_value ,
97- "sagemaker:collection" : "true" ,
98- }
99- tag_rule_values = [f"{ parent_tag_value } /{ collection_name } " ]
111+ parent_tag_rules = self ._get_collection_tag_rule (collection_name = parent_collection_name )
112+ parent_tag_rule_key = parent_tag_rules ["tag_rule_key" ]
113+ parent_tag_value = parent_tag_rules ["tag_rule_value" ]
114+ tags_on_collection = {
115+ parent_tag_rule_key : parent_tag_value ,
116+ "sagemaker:collection" : "true" ,
117+ }
118+ tag_rule_values = [f"{ parent_tag_value } /{ collection_name } " ]
100119 try :
101120 resource_filters = [
102121 "AWS::SageMaker::ModelPackageGroup" ,
@@ -122,7 +141,6 @@ def create(self, collection_name: str, parent_collection_name: str = None):
122141 "Name" : collection_create_response ["Group" ]["Name" ],
123142 "Arn" : collection_create_response ["Group" ]["GroupArn" ],
124143 }
125-
126144 except ClientError as e :
127145 message = e .response ["Error" ]["Message" ]
128146 error_code = e .response ["Error" ]["Code" ]
@@ -134,7 +152,7 @@ def create(self, collection_name: str, parent_collection_name: str = None):
134152 raise
135153
136154 def delete (self , collections : List [str ]):
137- """Deletes a lits of collection
155+ """Deletes a list of collection.
138156
139157 Args:
140158 collections (List[str]): List of collections to be deleted
@@ -152,6 +170,8 @@ def delete(self, collections: List[str]):
152170 "Values" : ["AWS::ResourceGroups::Group" , "AWS::SageMaker::ModelPackageGroup" ],
153171 },
154172 ]
173+
174+ # loops over the list of collection and deletes one at a time.
155175 for collection in collections :
156176 try :
157177 collection_details = self .sagemaker_session .list_group_resources (
@@ -180,3 +200,140 @@ def delete(self, collections: List[str]):
180200 "deleted_collections" : deleted_collection ,
181201 "delete_collection_failures" : delete_collection_failures ,
182202 }
203+
204+ def _get_collection_tag_rule (self , collection_name : str ):
205+ """Returns the tag rule key and value for a collection"""
206+
207+ if collection_name is not None :
208+ try :
209+ group_query = self .sagemaker_session .get_resource_group_query (group = collection_name )
210+ except ClientError as e :
211+ error_code = e .response ["Error" ]["Code" ]
212+
213+ if error_code == "NotFoundException" :
214+ raise ValueError (f"Cannot find collection: { collection_name } " )
215+ self ._check_access_error (err = e )
216+ raise
217+ if group_query .get ("GroupQuery" ):
218+ tag_rule_query = json .loads (
219+ group_query ["GroupQuery" ].get ("ResourceQuery" , {}).get ("Query" , "" )
220+ )
221+ tag_rule = tag_rule_query .get ("TagFilters" , [])[0 ]
222+ if not tag_rule :
223+ raise "Unsupported parent_collection_name"
224+ tag_rule_value = tag_rule ["Values" ][0 ]
225+ tag_rule_key = tag_rule ["Key" ]
226+
227+ return {
228+ "tag_rule_key" : tag_rule_key ,
229+ "tag_rule_value" : tag_rule_value ,
230+ }
231+ raise ValueError ("Collection name is required" )
232+
233+ def add_model_groups (self , collection_name : str , model_groups : List [str ]):
234+ """To add list of model package groups to a collection
235+
236+ Args:
237+ collection_name (str): The name of the collection
238+ model_groups List[str]: Model pckage group names list to be added into the collection
239+ """
240+ if len (model_groups ) > 10 :
241+ raise Exception ("Model groups can have a maximum length of 10" )
242+ tag_rules = self ._get_collection_tag_rule (collection_name = collection_name )
243+ tag_rule_key = tag_rules ["tag_rule_key" ]
244+ tag_rule_value = tag_rules ["tag_rule_value" ]
245+
246+ add_groups_success = []
247+ add_groups_failure = []
248+ if tag_rule_key is not None and tag_rule_value is not None :
249+ for model_group in model_groups :
250+ try :
251+ self ._add_model_group (
252+ model_package_group = model_group ,
253+ tag_rule_key = tag_rule_key ,
254+ tag_rule_value = tag_rule_value ,
255+ )
256+ add_groups_success .append (model_group )
257+ except ClientError as e :
258+ self ._check_access_error (err = e )
259+ message = e .response ["Error" ]["Message" ]
260+ add_groups_failure .append (
261+ {
262+ "model_group" : model_group ,
263+ "failure_reason" : message ,
264+ }
265+ )
266+ return {
267+ "added_groups" : add_groups_success ,
268+ "failure" : add_groups_failure ,
269+ }
270+
271+ def remove_model_groups (self , collection_name : str , model_groups : List [str ]):
272+ """To remove list of model package groups from a collection
273+
274+ Args:
275+ collection_name (str): The name of the collection
276+ model_groups List[str]: Model package group names list to be removed
277+ """
278+
279+ if len (model_groups ) > 10 :
280+ raise Exception ("Model groups can have a maximum length of 10" )
281+ tag_rules = self ._get_collection_tag_rule (collection_name = collection_name )
282+
283+ tag_rule_key = tag_rules ["tag_rule_key" ]
284+ tag_rule_value = tag_rules ["tag_rule_value" ]
285+
286+ remove_groups_success = []
287+ remove_groups_failure = []
288+ if tag_rule_key is not None and tag_rule_value is not None :
289+ for model_group in model_groups :
290+ try :
291+ self ._remove_model_group (
292+ model_package_group = model_group ,
293+ tag_rule_key = tag_rule_key ,
294+ )
295+ remove_groups_success .append (model_group )
296+ except ClientError as e :
297+ self ._check_access_error (err = e )
298+ message = e .response ["Error" ]["Message" ]
299+ remove_groups_failure .append (
300+ {
301+ "model_group" : model_group ,
302+ "failure_reason" : message ,
303+ }
304+ )
305+ return {
306+ "removed_groups" : remove_groups_success ,
307+ "failure" : remove_groups_failure ,
308+ }
309+
310+ def move_model_group (
311+ self , source_collection_name : str , model_group : str , destination_collection_name : str
312+ ):
313+ """To move a model package group from one collection to another
314+
315+ Args:
316+ source_collection_name (str): Collection name of the source
317+ model_group (str): Model package group names which is to be moved
318+ destination_collection_name (str): Collection name of the destination
319+ """
320+ remove_details = self .remove_model_groups (
321+ collection_name = source_collection_name , model_groups = [model_group ]
322+ )
323+ if len (remove_details ["failure" ]) == 1 :
324+ raise Exception (remove_details ["failure" ][0 ]["failure" ])
325+
326+ added_details = self .add_model_groups (
327+ collection_name = destination_collection_name , model_groups = [model_group ]
328+ )
329+
330+ if len (added_details ["failure" ]) == 1 :
331+ # adding the model group back to the source collection in case of an add failure
332+ self .add_model_groups (
333+ collection_name = source_collection_name , model_groups = [model_group ]
334+ )
335+ raise Exception (added_details ["failure" ][0 ]["failure" ])
336+
337+ return {
338+ "moved_success" : model_group ,
339+ }
0 commit comments