1313"""This module defines the JumpStartModelsCache class."""
1414from __future__ import absolute_import
1515import datetime
16- from typing import List , Optional
16+ from typing import Optional
1717import json
1818import boto3
19+ import botocore
1920import semantic_version
21+ from sagemaker .jumpstart .constants import (
22+ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ,
23+ JUMPSTART_DEFAULT_REGION_NAME ,
24+ )
25+ from sagemaker .jumpstart .parameters import (
26+ JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS ,
27+ JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS ,
28+ JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON ,
29+ JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON ,
30+ )
2031from sagemaker .jumpstart .types import (
2132 JumpStartCachedS3ContentKey ,
2233 JumpStartCachedS3ContentValue ,
2839from sagemaker .jumpstart import utils
2940from sagemaker .utilities .cache import LRUCache
3041
31- DEFAULT_REGION_NAME = boto3 .session .Session ().region_name
32-
33- DEFAULT_MAX_S3_CACHE_ITEMS = 20
34- DEFAULT_S3_CACHE_EXPIRATION_TIME = datetime .timedelta (hours = 6 )
35-
36- DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS = 20
37- DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_TIME = datetime .timedelta (hours = 6 )
38-
39- DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
40-
4142
4243class JumpStartModelsCache :
4344 """Class that implements a cache for JumpStart models manifests and specs.
@@ -48,78 +49,95 @@ class JumpStartModelsCache:
4849
4950 def __init__ (
5051 self ,
51- region : Optional [str ] = DEFAULT_REGION_NAME ,
52- max_s3_cache_items : Optional [int ] = DEFAULT_MAX_S3_CACHE_ITEMS ,
53- s3_cache_expiration_time : Optional [datetime .timedelta ] = DEFAULT_S3_CACHE_EXPIRATION_TIME ,
54- max_semantic_version_cache_items : Optional [int ] = DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS ,
55- semantic_version_cache_expiration_time : Optional [
52+ region : Optional [str ] = JUMPSTART_DEFAULT_REGION_NAME ,
53+ max_s3_cache_items : Optional [int ] = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS ,
54+ s3_cache_expiration_horizon : Optional [
5655 datetime .timedelta
57- ] = DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_TIME ,
58- manifest_file_s3_key : Optional [str ] = DEFAULT_MANIFEST_FILE_S3_KEY ,
59- bucket : Optional [str ] = None ,
56+ ] = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON ,
57+ max_semantic_version_cache_items : Optional [
58+ int
59+ ] = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS ,
60+ semantic_version_cache_expiration_horizon : Optional [
61+ datetime .timedelta
62+ ] = JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON ,
63+ manifest_file_s3_key : Optional [str ] = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ,
64+ s3_bucket_name : Optional [str ] = None ,
65+ s3_client_config : Optional [botocore .config .Config ] = None ,
6066 ) -> None :
6167 """Initialize a ``JumpStartModelsCache`` instance.
6268
6369 Args:
6470 region (Optional[str]): AWS region to associate with cache. Default: region associated
65- with botocore session.
66- max_s3_cache_items (Optional[int]): Maximum number of files to store in s3 cache.
71+ with boto3 session.
72+ max_s3_cache_items (Optional[int]): Maximum number of items to store in s3 cache.
6773 Default: 20.
68- s3_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold items in
69- s3 cache before invalidation. Default: 6 hours.
70- max_semantic_version_cache_items (Optional[int]): Maximum number of files to store in
74+ s3_cache_expiration_horizon (Optional[datetime.timedelta]): Maximum time to hold
75+ items in s3 cache before invalidation. Default: 6 hours.
76+ max_semantic_version_cache_items (Optional[int]): Maximum number of items to store in
7177 semantic version cache. Default: 20.
72- semantic_version_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to
73- hold items in semantic version cache before invalidation. Default: 6 hours.
74- bucket (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted
75- content bucket for region.
78+ semantic_version_cache_expiration_horizon (Optional[datetime.timedelta]):
79+ Maximum time to hold items in semantic version cache before invalidation.
80+ Default: 6 hours.
81+ s3_bucket_name (Optional[str]): S3 bucket to associate with cache.
82+ Default: JumpStart-hosted content bucket for region.
83+ s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache.
84+ Default: None (no config).
7685 """
7786
7887 self ._region = region
7988 self ._s3_cache = LRUCache [JumpStartCachedS3ContentKey , JumpStartCachedS3ContentValue ](
8089 max_cache_items = max_s3_cache_items ,
81- expiration_time = s3_cache_expiration_time ,
90+ expiration_horizon = s3_cache_expiration_horizon ,
8291 retrieval_function = self ._get_file_from_s3 ,
8392 )
8493 self ._model_id_semantic_version_manifest_key_cache = LRUCache [
8594 JumpStartVersionedModelId , JumpStartVersionedModelId
8695 ](
8796 max_cache_items = max_semantic_version_cache_items ,
88- expiration_time = semantic_version_cache_expiration_time ,
97+ expiration_horizon = semantic_version_cache_expiration_horizon ,
8998 retrieval_function = self ._get_manifest_key_from_model_id_semantic_version ,
9099 )
91100 self ._manifest_file_s3_key = manifest_file_s3_key
92- self ._bucket = (
93- utils .get_jumpstart_content_bucket (self ._region ) if bucket is None else bucket
101+ self .s3_bucket_name = (
102+ utils .get_jumpstart_content_bucket (self ._region )
103+ if s3_bucket_name is None
104+ else s3_bucket_name
105+ )
106+ self ._s3_client = (
107+ boto3 .client ("s3" , region_name = self ._region , config = s3_client_config )
108+ if s3_client_config
109+ else boto3 .client ("s3" , region_name = self ._region )
94110 )
95- self ._has_retried_cache_refresh = False
96111
97112 def set_region (self , region : str ) -> None :
98113 """Set region for cache. Clears cache after new region is set."""
99- self ._region = region
100- self .clear ()
114+ if region != self ._region :
115+ self ._region = region
116+ self .clear ()
101117
102118 def get_region (self ) -> str :
103119 """Return region for cache."""
104120 return self ._region
105121
106122 def set_manifest_file_s3_key (self , key : str ) -> None :
107123 """Set manifest file s3 key. Clears cache after new key is set."""
108- self ._manifest_file_s3_key = key
109- self .clear ()
124+ if key != self ._manifest_file_s3_key :
125+ self ._manifest_file_s3_key = key
126+ self .clear ()
110127
111128 def get_manifest_file_s3_key (self ) -> None :
112129 """Return manifest file s3 key for cache."""
113130 return self ._manifest_file_s3_key
114131
115- def set_bucket (self , bucket : str ) -> None :
132+ def set_s3_bucket_name (self , s3_bucket_name : str ) -> None :
116133 """Set s3 bucket used for cache."""
117- self ._bucket = bucket
118- self .clear ()
134+ if s3_bucket_name != self .s3_bucket_name :
135+ self .s3_bucket_name = s3_bucket_name
136+ self .clear ()
119137
120138 def get_bucket (self ) -> None :
121139 """Return bucket used for cache."""
122- return self ._bucket
140+ return self .s3_bucket_name
123141
124142 def _get_manifest_key_from_model_id_semantic_version (
125143 self ,
@@ -128,13 +146,18 @@ def _get_manifest_key_from_model_id_semantic_version(
128146 ) -> JumpStartVersionedModelId :
129147 """Return model id and version in manifest that matches semantic version/id.
130148
149+ Uses ``semantic_version`` to perform version comparison. The highest model version
150+ matching the semantic version is used, which is compatible with the SageMaker
151+ version.
152+
131153 Args:
132154 key (JumpStartVersionedModelId): Key for which to fetch versioned model id.
133155 value (Optional[JumpStartVersionedModelId]): Unused variable for current value of
134156 old cached model id/version.
135157
136158 Raises:
137- KeyError: If the semantic version is not found in the manifest.
159+ KeyError: If the semantic version is not found in the manifest, or is found but
160+ the SageMaker version needs to be upgraded in order for the model to be used.
138161 """
139162
140163 model_id , version = key .model_id , key .version
@@ -147,7 +170,7 @@ def _get_manifest_key_from_model_id_semantic_version(
147170
148171 versions_compatible_with_sagemaker = [
149172 semantic_version .Version (header .version )
150- for _ , header in manifest .items ()
173+ for header in manifest .values ()
151174 if header .model_id == model_id
152175 and semantic_version .Version (header .min_version ) <= semantic_version .Version (sm_version )
153176 ]
@@ -164,19 +187,19 @@ def _get_manifest_key_from_model_id_semantic_version(
164187
165188 versions_incompatible_with_sagemaker = [
166189 semantic_version .Version (header .version )
167- for _ , header in manifest .items ()
190+ for header in manifest .values ()
168191 if header .model_id == model_id
169192 ]
170193 sm_incompatible_model_version = spec .select (versions_incompatible_with_sagemaker )
171194 if sm_incompatible_model_version is not None :
172195 model_version_to_use_incompatible_with_sagemaker = str (sm_incompatible_model_version )
173196 sm_version_to_use = [
174197 header .min_version
175- for _ , header in manifest .items ()
198+ for header in manifest .values ()
176199 if header .model_id == model_id
177200 and header .version == model_version_to_use_incompatible_with_sagemaker
178201 ]
179- assert len (sm_version_to_use ) == 1
202+ assert len (sm_version_to_use ) == 1 # ``manifest`` dict should already enforce this
180203 sm_version_to_use = sm_version_to_use [0 ]
181204
182205 error_msg = (
@@ -187,7 +210,7 @@ def _get_manifest_key_from_model_id_semantic_version(
187210 f"{ model_version_to_use_incompatible_with_sagemaker } of { model_id } ."
188211 )
189212 raise KeyError (error_msg )
190- error_msg = f"Unable to find model manifest for { model_id } with version { version } "
213+ error_msg = f"Unable to find model manifest for { model_id } with version { version } . "
191214 raise KeyError (error_msg )
192215
193216 def _get_file_from_s3 (
@@ -210,33 +233,49 @@ def _get_file_from_s3(
210233
211234 file_type , s3_key = key .file_type , key .s3_key
212235
213- s3_client = boto3 .client ("s3" , region_name = self ._region )
214-
215236 if file_type == JumpStartS3FileType .MANIFEST :
216- etag = s3_client . head_object (Bucket = self ._bucket , Key = s3_key )["ETag" ]
237+ etag = self . _s3_client . head_object (Bucket = self .s3_bucket_name , Key = s3_key )["ETag" ]
217238 if value is not None and etag == value .md5_hash :
218239 return value
219- response = s3_client . get_object (Bucket = self ._bucket , Key = s3_key )
240+ response = self . _s3_client . get_object (Bucket = self .s3_bucket_name , Key = s3_key )
220241 formatted_body = json .loads (response ["Body" ].read ().decode ("utf-8" ))
221242 return JumpStartCachedS3ContentValue (
222243 formatted_file_content = utils .get_formatted_manifest (formatted_body ),
223244 md5_hash = etag ,
224245 )
225246 if file_type == JumpStartS3FileType .SPECS :
226- response = s3_client . get_object (Bucket = self ._bucket , Key = s3_key )
247+ response = self . _s3_client . get_object (Bucket = self .s3_bucket_name , Key = s3_key )
227248 formatted_body = json .loads (response ["Body" ].read ().decode ("utf-8" ))
228249 return JumpStartCachedS3ContentValue (
229250 formatted_file_content = JumpStartModelSpecs (formatted_body )
230251 )
231- raise RuntimeError (f"Bad value for key: { key } " )
252+ raise ValueError (
253+ f"Bad value for key '{ key } ': must be in { [JumpStartS3FileType .MANIFEST , JumpStartS3FileType .SPECS ]} "
254+ )
232255
233256 def get_header (
234257 self , model_id : str , semantic_version_str : Optional [str ] = None
235- ) -> List [JumpStartModelHeader ]:
236- """Return list of headers for a given JumpStart model id and semantic version.
258+ ) -> JumpStartModelHeader :
259+ """Return header for a given JumpStart model id and semantic version.
260+
261+ Args:
262+ model_id (str): model id for which to get a header.
263+ semantic_version_str (Optional[str]): The semantic version for which to get a
264+ header. If None, the highest compatible version is returned.
265+ """
266+
267+ return self ._get_header_impl (model_id , 0 , semantic_version_str )
268+
269+ def _get_header_impl (
270+ self , model_id : str , attempt : int , semantic_version_str : Optional [str ] = None
271+ ) -> JumpStartModelHeader :
272+ """Lower-level function to return header.
273+
274+ Allows a single retry if the cache is old.
237275
238276 Args:
239277 model_id (str): model id for which to get a header.
278+ attempt (int): attempt number at retrieving a header.
240279 semantic_version_str (Optional[str]): The semantic version for which to get a
241280 header. If None, the highest compatible version is returned.
242281 """
@@ -248,17 +287,12 @@ def get_header(
248287 JumpStartCachedS3ContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
249288 ).formatted_file_content
250289 try :
251- header = manifest [versioned_model_id ]
252- if self ._has_retried_cache_refresh :
253- self ._has_retried_cache_refresh = False
254- return header
290+ return manifest [versioned_model_id ]
255291 except KeyError :
256- if self ._has_retried_cache_refresh :
257- self ._has_retried_cache_refresh = False
292+ if attempt > 0 :
258293 raise
259294 self .clear ()
260- self ._has_retried_cache_refresh = True
261- return self .get_header (model_id , semantic_version )
295+ return self ._get_header_impl (model_id , attempt + 1 , semantic_version_str )
262296
263297 def get_specs (
264298 self , model_id : str , semantic_version_str : Optional [str ] = None
@@ -278,7 +312,6 @@ def get_specs(
278312 ).formatted_file_content
279313
280314 def clear (self ) -> None :
281- """Clears the model id/version and s3 cache and resets ``_has_retried_cache_refresh`` ."""
315+ """Clears the model id/version and s3 cache."""
282316 self ._s3_cache .clear ()
283317 self ._model_id_semantic_version_manifest_key_cache .clear ()
284- self ._has_retried_cache_refresh = False
0 commit comments