1515import datetime
1616from difflib import get_close_matches
1717import os
18- from typing import List , Optional , Tuple , Union
18+ from typing import Any , Dict , List , Optional , Tuple , Union
1919import json
2020import boto3
2121import botocore
2929 JUMPSTART_LOGGER ,
3030 MODEL_ID_LIST_WEB_URL ,
3131)
32+ from sagemaker .jumpstart .curated_hub .curated_hub import CuratedHub
3233from sagemaker .jumpstart .exceptions import get_wildcard_model_version_msg
3334from sagemaker .jumpstart .parameters import (
3435 JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS ,
3738 JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON ,
3839)
3940from sagemaker .jumpstart .types import (
40- JumpStartCachedS3ContentKey ,
41- JumpStartCachedS3ContentValue ,
41+ JumpStartCachedContentKey ,
42+ JumpStartCachedContentValue ,
4243 JumpStartModelHeader ,
4344 JumpStartModelSpecs ,
4445 JumpStartS3FileType ,
4546 JumpStartVersionedModelId ,
47+ HubDataType ,
4648)
4749from sagemaker .jumpstart import utils
4850from sagemaker .utilities .cache import LRUCache
@@ -95,7 +97,7 @@ def __init__(
9597 """
9698
9799 self ._region = region
98- self ._s3_cache = LRUCache [JumpStartCachedS3ContentKey , JumpStartCachedS3ContentValue ](
100+ self ._content_cache = LRUCache [JumpStartCachedContentKey , JumpStartCachedContentValue ](
99101 max_cache_items = max_s3_cache_items ,
100102 expiration_horizon = s3_cache_expiration_horizon ,
101103 retrieval_function = self ._retrieval_function ,
@@ -172,8 +174,8 @@ def _get_manifest_key_from_model_id_semantic_version(
172174
173175 model_id , version = key .model_id , key .version
174176
175- manifest = self ._s3_cache .get (
176- JumpStartCachedS3ContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
177+ manifest = self ._content_cache .get (
178+ JumpStartCachedContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
177179 )[0 ].formatted_content
178180
179181 sm_version = utils .get_sagemaker_version ()
@@ -301,50 +303,71 @@ def _get_json_file_from_local_override(
301303
def _retrieval_function(
    self,
    key: JumpStartCachedContentKey,
    value: Optional[JumpStartCachedContentValue],
) -> JumpStartCachedContentValue:
    """Return JumpStart content for the data type and identifier in ``key``.

    If a manifest file is being fetched, we only download the object if the md5 hash in
    ``head_object`` does not match the current md5 hash for the stored value. This prevents
    unnecessarily downloading the full manifest when it hasn't changed.

    Args:
        key (JumpStartCachedContentKey): key for which to fetch JumpStart content.
            ``key.data_type`` selects the branch below; ``key.id_info`` is passed as
            the s3 key for manifest/specs data and as the ARN for Hub data.
        value (Optional[JumpStartCachedContentValue]): Current value of old cached
            content. This is used for the manifest file, so that it is only
            downloaded when its content changes.

    Returns:
        JumpStartCachedContentValue: wrapper around the formatted content (plus the
        md5 hash for manifests).

    Raises:
        ValueError: if ``key.data_type`` is not one of the supported data types.
    """

    data_type, id_info = key.data_type, key.id_info

    if data_type == JumpStartS3FileType.MANIFEST:
        # Skip the download when the stored hash still matches the remote one.
        if value is not None and not self._is_local_metadata_mode():
            etag = self._get_json_md5_hash(id_info)
            if etag == value.md5_hash:
                return value
        formatted_body, etag = self._get_json_file(id_info, data_type)
        return JumpStartCachedContentValue(
            formatted_content=utils.get_formatted_manifest(formatted_body),
            md5_hash=etag,
        )

    if data_type == JumpStartS3FileType.SPECS:
        formatted_body, _ = self._get_json_file(id_info, data_type)
        model_specs = JumpStartModelSpecs(formatted_body)
        utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
        return JumpStartCachedContentValue(formatted_content=model_specs)

    if data_type == HubDataType.MODEL:
        hub_name, region, model_name, model_version = utils.extract_info_from_hub_content_arn(
            id_info
        )
        hub = CuratedHub(hub_name=hub_name, region=region)
        hub_content = hub.describe_model(model_name=model_name, model_version=model_version)
        # Build the specs first and log from them, as the SPECS branch does; the
        # raw content document is only the constructor input.
        model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True)
        utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
        return JumpStartCachedContentValue(formatted_content=model_specs)

    if data_type == HubDataType.HUB:
        # Only the hub name and region are needed to describe the hub itself.
        hub_name, region, _, _ = utils.extract_info_from_hub_content_arn(id_info)
        hub = CuratedHub(hub_name=hub_name, region=region)
        hub_info = hub.describe()
        return JumpStartCachedContentValue(formatted_content=hub_info)

    supported_types = [
        JumpStartS3FileType.MANIFEST,
        JumpStartS3FileType.SPECS,
        HubDataType.HUB,
        HubDataType.MODEL,
    ]
    # Adjacent f-strings (no comma) so ValueError receives one message string.
    raise ValueError(
        f"Bad value for key '{key}': must be in "
        f"{supported_types}"
    )
342365
def get_manifest(self) -> List[JumpStartModelHeader]:
    """Return entire JumpStart models manifest.

    Looks up the manifest entry in the content cache and returns the values of
    its formatted content as a list.

    Returns:
        List[JumpStartModelHeader]: one header per entry in the manifest mapping.
    """

    manifest_key = JumpStartCachedContentKey(
        JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key
    )
    # The cache lookup result is indexable; element 0 carries the cached value.
    cached_manifest = self._content_cache.get(manifest_key)[0]
    return list(cached_manifest.formatted_content.values())  # type: ignore
@@ -407,8 +430,8 @@ def _get_header_impl(
407430 JumpStartVersionedModelId (model_id , semantic_version_str )
408431 )[0 ]
409432
410- manifest = self ._s3_cache .get (
411- JumpStartCachedS3ContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
433+ manifest = self ._content_cache .get (
434+ JumpStartCachedContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
412435 )[0 ].formatted_content
413436 try :
414437 header = manifest [versioned_model_id ] # type: ignore
@@ -430,8 +453,8 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
430453
431454 header = self .get_header (model_id , semantic_version_str )
432455 spec_key = header .spec_key
433- specs , cache_hit = self ._s3_cache .get (
434- JumpStartCachedS3ContentKey (JumpStartS3FileType .SPECS , spec_key )
456+ specs , cache_hit = self ._content_cache .get (
457+ JumpStartCachedContentKey (JumpStartS3FileType .SPECS , spec_key )
435458 )
436459 if not cache_hit and "*" in semantic_version_str :
437460 JUMPSTART_LOGGER .warning (
@@ -443,7 +466,29 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
443466 )
444467 return specs .formatted_content
445468
def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
    """Return JumpStart-compatible specs for a given Hub model.

    Args:
        hub_model_arn (str): Arn for the Hub model to get specs for.

    Returns:
        JumpStartModelSpecs: the formatted content stored in the content cache
        under the ``HubDataType.MODEL`` key for this Arn.
    """

    cache_key = JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)
    # The second element of the lookup result is not needed here.
    cached_value, _ = self._content_cache.get(cache_key)
    return cached_value.formatted_content
480+
def get_hub(self, hub_arn: str) -> Dict[str, Any]:
    """Return descriptive info for a given Hub.

    Args:
        hub_arn (str): Arn for the Hub to get info for.

    Returns:
        Dict[str, Any]: the formatted content stored in the content cache under
        the ``HubDataType.HUB`` key for this Arn.
    """

    cache_key = JumpStartCachedContentKey(HubDataType.HUB, hub_arn)
    # The second element of the lookup result is not needed here.
    cached_value, _ = self._content_cache.get(cache_key)
    return cached_value.formatted_content
490+
def clear(self) -> None:
    """Clear the content cache and the model ID/semantic version manifest key cache."""
    # Only ``_content_cache`` is cleared; the ``_s3_cache`` attribute was renamed.
    self._content_cache.clear()
    self._model_id_semantic_version_manifest_key_cache.clear()
0 commit comments