1313"""This module provides the JumpStart Curated Hub class."""
1414from __future__ import absolute_import
1515
16- from typing import Optional , Dict , Any
16+ from typing import Any , Dict , Optional
1717import boto3
1818from sagemaker .session import Session
1919from sagemaker .jumpstart .constants import (
2020 JUMPSTART_DEFAULT_REGION_NAME ,
2121)
2222
23- from sagemaker .jumpstart .types import HubDataType
24- import sagemaker .jumpstart .curated_hub . utils as hubutils
23+ from sagemaker .jumpstart .types import HubDescription , HubContentType , HubContentDescription
24+ import sagemaker .jumpstart .session_utils as session_utils
2525
2626
2727class CuratedHub :
2828 """Class for creating and managing a curated JumpStart hub"""
2929
3030 def __init__ (
3131 self ,
32- name : str ,
32+ hub_name : str ,
3333 region : str = JUMPSTART_DEFAULT_REGION_NAME ,
34- session : Optional [Session ] = None ,
34+ sagemaker_session : Optional [Session ] = None ,
3535 ):
36- self .name = name
37- if session .boto_region_name != region :
36+ self .hub_name = hub_name
37+ if sagemaker_session .boto_region_name != region :
3838 # TODO: Handle error
3939 pass
4040 self .region = region
41- self ._session = session or Session (boto3 .Session (region_name = region ))
41+ self ._sagemaker_session = sagemaker_session or Session (boto3 .Session (region_name = region ))
4242
4343 def create (
4444 self ,
@@ -50,32 +50,60 @@ def create(
5050 ) -> Dict [str , str ]:
5151 """Creates a hub with the given description"""
5252
53- return hubutils .create_hub (
54- hub_name = self .name ,
53+ bucket_name = session_utils .create_hub_bucket_if_it_does_not_exist (
54+ bucket_name , self ._sagemaker_session
55+ )
56+
57+ return self ._sagemaker_session .create_hub (
58+ hub_name = self .hub_name ,
5559 hub_description = description ,
5660 hub_display_name = display_name ,
5761 hub_search_keywords = search_keywords ,
5862 hub_bucket_name = bucket_name ,
5963 tags = tags ,
60- sagemaker_session = self ._session ,
6164 )
6265
63- def describe_model (self , model_name : str , model_version : str = "*" ) -> Dict [str , Any ]:
64- """Returns descriptive information about the Hub Model"""
66+ def describe (self ) -> HubDescription :
67+ """Returns descriptive information about the Hub"""
68+
69+ hub_description = self ._sagemaker_session .describe_hub (hub_name = self .hub_name )
70+
71+ return HubDescription (hub_description )
72+
73+ def list_models (self , ** kwargs ) -> Dict [str , Any ]:
74+ """Lists the models in this Curated Hub
6575
66- hub_content = hubutils . describe_hub_content (
67- hub_name = self . name ,
68- content_name = model_name ,
69- content_type = HubDataType . MODEL ,
70- content_version = model_version ,
71- sagemaker_session = self ._session ,
76+ **kwargs: Passed to invocation of ``Session:list_hub_contents``.
77+ """
78+ # TODO: Validate kwargs and fast-fail?
79+
80+ hub_content_summaries = self . _sagemaker_session . list_hub_contents (
81+ hub_name = self .hub_name , hub_content_type = HubContentType . MODEL , ** kwargs
7282 )
83+ # TODO: Handle pagination
84+ return hub_content_summaries
7385
74- return hub_content
86+ def describe_model (self , model_name : str , model_version : str = "*" ) -> HubContentDescription :
87+ """Returns descriptive information about the Hub Model"""
7588
76- def describe (self ) -> Dict [str , Any ]:
77- """Returns descriptive information about the Hub"""
89+ hub_content_description : Dict [str , Any ] = self ._sagemaker_session .describe_hub_content (
90+ hub_name = self .hub_name ,
91+ hub_content_name = model_name ,
92+ hub_content_version = model_version ,
93+ hub_content_type = HubContentType .MODEL ,
94+ )
95+
96+ return HubContentDescription (hub_content_description )
7897
79- hub_info = hubutils .describe_hub (hub_name = self .name , sagemaker_session = self ._session )
98+ def delete_model (self , model_name : str , model_version : str = "*" ) -> None :
99+ """Deletes a model from this CuratedHub."""
100+ return self ._sagemaker_session .delete_hub_content (
101+ hub_content_name = model_name ,
102+ hub_content_version = model_version ,
103+ hub_content_type = HubContentType .MODEL ,
104+ hub_name = self .hub_name ,
105+ )
80106
81- return hub_info
107+ def delete (self ) -> None :
108+ """Deletes this Curated Hub"""
109+ return self ._sagemaker_session .delete_hub (self .hub_name )
0 commit comments