2121import botocore
2222from packaging .version import Version
2323from packaging .specifiers import SpecifierSet , InvalidSpecifier
24+ from sagemaker .session import Session
25+ from sagemaker .utilities .cache import LRUCache
2426from sagemaker .jumpstart .constants import (
2527 ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE ,
2628 ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE ,
2729 JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ,
2830 JUMPSTART_DEFAULT_REGION_NAME ,
2931 JUMPSTART_LOGGER ,
3032 MODEL_ID_LIST_WEB_URL ,
33+ DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
3134)
32- from sagemaker .jumpstart .curated_hub .curated_hub import CuratedHub
33- from sagemaker .jumpstart .curated_hub .utils import get_info_from_hub_resource_arn
3435from sagemaker .jumpstart .exceptions import get_wildcard_model_version_msg
3536from sagemaker .jumpstart .parameters import (
3637 JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS ,
3738 JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS ,
3839 JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON ,
3940 JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON ,
4041)
42+ from sagemaker .jumpstart import utils
4143from sagemaker .jumpstart .types import (
4244 JumpStartCachedContentKey ,
4345 JumpStartCachedContentValue ,
4446 JumpStartModelHeader ,
4547 JumpStartModelSpecs ,
4648 JumpStartS3FileType ,
4749 JumpStartVersionedModelId ,
50+ DescribeHubResponse ,
51+ DescribeHubContentsResponse ,
52+ HubType ,
4853 HubContentType ,
4954)
50- from sagemaker .jumpstart import utils
51- from sagemaker .utilities .cache import LRUCache
55+ from sagemaker .jumpstart .curated_hub import utils as hub_utils
5256
5357
5458class JumpStartModelsCache :
@@ -74,6 +78,7 @@ def __init__(
7478 s3_bucket_name : Optional [str ] = None ,
7579 s3_client_config : Optional [botocore .config .Config ] = None ,
7680 s3_client : Optional [boto3 .client ] = None ,
81+ sagemaker_session : Optional [Session ] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
7782 ) -> None : # fmt: on
7883 """Initialize a ``JumpStartModelsCache`` instance.
7984
@@ -95,6 +100,8 @@ def __init__(
95100 s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache.
96101 Default: None (no config).
97102 s3_client (Optional[boto3.client]): s3 client to use. Default: None.
103+ sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object,
104+ used for SageMaker interactions. Default: Session in region associated with boto3 session.
98105 """
99106
100107 self ._region = region
@@ -121,6 +128,7 @@ def __init__(
121128 if s3_client_config
122129 else boto3 .client ("s3" , region_name = self ._region )
123130 )
131+ self ._sagemaker_session = sagemaker_session
124132
125133 def set_region (self , region : str ) -> None :
126134 """Set region for cache. Clears cache after new region is set."""
@@ -340,32 +348,34 @@ def _retrieval_function(
340348 formatted_content = model_specs
341349 )
342350 if data_type == HubContentType .MODEL :
343- info = get_info_from_hub_resource_arn (
351+ hub_name , _ , model_name , model_version = hub_utils . get_info_from_hub_resource_arn (
344352 id_info
345353 )
346- hub = CuratedHub (hub_name = info .hub_name , region = info .region )
347- hub_content = hub .describe_model (
348- model_name = info .hub_content_name , model_version = info .hub_content_version
354+ hub_model_description : Dict [str , Any ] = self ._sagemaker_session .describe_hub_content (
355+ hub_name = hub_name ,
356+ hub_content_name = model_name ,
357+ hub_content_version = model_version ,
358+ hub_content_type = data_type
349359 )
360+
361+ model_specs = JumpStartModelSpecs (DescribeHubContentsResponse (hub_model_description ), is_hub_content = True )
362+
350363 utils .emit_logs_based_on_model_specs (
351- hub_content . content_document ,
364+ model_specs ,
352365 self .get_region (),
353366 self ._s3_client
354367 )
355- model_specs = JumpStartModelSpecs (hub_content .content_document , is_hub_content = True )
356368 return JumpStartCachedContentValue (
357369 formatted_content = model_specs
358370 )
359- if data_type == HubContentType .HUB :
360- info = get_info_from_hub_resource_arn (
361- id_info
362- )
363- hub = CuratedHub (hub_name = info .hub_name , region = info .region )
364- hub_info = hub .describe ()
365- return JumpStartCachedContentValue (formatted_content = hub_info )
371+ if data_type == HubType .HUB :
372+ hub_name , _ , _ , _ = hub_utils .get_info_from_hub_resource_arn (id_info )
373+ response : Dict [str , Any ] = self ._sagemaker_session .describe_hub (hub_name = hub_name )
374+ hub_description = DescribeHubResponse (response )
375+ return JumpStartCachedContentValue (formatted_content = DescribeHubResponse (hub_description ))
366376 raise ValueError (
367- f"Bad value for key '{ key } ': must be in" ,
368- f"{ [JumpStartS3FileType .MANIFEST , JumpStartS3FileType .SPECS , HubContentType .HUB , HubContentType .MODEL ]} "
377+ f"Bad value for key '{ key } ': must be in " ,
378+ f"{ [JumpStartS3FileType .MANIFEST , JumpStartS3FileType .SPECS , HubType .HUB , HubContentType .MODEL ]} "
369379 )
370380
371381 def get_manifest (self ) -> List [JumpStartModelHeader ]:
@@ -490,7 +500,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]:
490500 hub_arn (str): Arn for the Hub to get info for
491501 """
492502
493- details , _ = self ._content_cache .get (JumpStartCachedContentKey (HubContentType .HUB , hub_arn ))
503+ details , _ = self ._content_cache .get (JumpStartCachedContentKey (HubType .HUB , hub_arn ))
494504 return details .formatted_content
495505
496506 def clear (self ) -> None :
0 commit comments