@@ -48,37 +48,37 @@ class JumpStartModelsCache:
4848 for launching JumpStart models from the SageMaker SDK.
4949 """
5050
51+ # fmt: off
5152 def __init__ (
5253 self ,
53- region : Optional [str ] = JUMPSTART_DEFAULT_REGION_NAME ,
54- max_s3_cache_items : Optional [int ] = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS ,
55- s3_cache_expiration_horizon : Optional [
56- datetime .timedelta
57- ] = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON ,
58- max_semantic_version_cache_items : Optional [
59- int
60- ] = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS ,
61- semantic_version_cache_expiration_horizon : Optional [
62- datetime .timedelta
63- ] = JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON ,
64- manifest_file_s3_key : Optional [str ] = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ,
54+ region : str = JUMPSTART_DEFAULT_REGION_NAME ,
55+ max_s3_cache_items : int = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS ,
56+ s3_cache_expiration_horizon : datetime .timedelta =
57+ JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON ,
58+ max_semantic_version_cache_items : int =
59+ JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS ,
60+ semantic_version_cache_expiration_horizon : datetime .timedelta =
61+ JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON ,
62+ manifest_file_s3_key : str =
63+ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ,
6564 s3_bucket_name : Optional [str ] = None ,
6665 s3_client_config : Optional [botocore .config .Config ] = None ,
67- ) -> None :
66+ ) -> None : # fmt: on
6867 """Initialize a ``JumpStartModelsCache`` instance.
6968
7069 Args:
71- region (Optional[ str] ): AWS region to associate with cache. Default: region associated
70+ region (str): AWS region to associate with cache. Default: region associated
7271 with boto3 session.
73- max_s3_cache_items (Optional[ int] ): Maximum number of items to store in s3 cache.
72+ max_s3_cache_items (int): Maximum number of items to store in s3 cache.
7473 Default: 20.
75- s3_cache_expiration_horizon (Optional[ datetime.timedelta] ): Maximum time to hold
74+ s3_cache_expiration_horizon (datetime.timedelta): Maximum time to hold
7675 items in s3 cache before invalidation. Default: 6 hours.
77- max_semantic_version_cache_items (Optional[ int] ): Maximum number of items to store in
76+ max_semantic_version_cache_items (int): Maximum number of items to store in
7877 semantic version cache. Default: 20.
79- semantic_version_cache_expiration_horizon (Optional[ datetime.timedelta] ):
78+ semantic_version_cache_expiration_horizon (datetime.timedelta):
8079 Maximum time to hold items in semantic version cache before invalidation.
8180 Default: 6 hours.
81+ manifest_file_s3_key (str): The key in S3 corresponding to the sdk metadata manifest.
8282 s3_bucket_name (Optional[str]): S3 bucket to associate with cache.
8383 Default: JumpStart-hosted content bucket for region.
8484 s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache.
@@ -126,7 +126,7 @@ def set_manifest_file_s3_key(self, key: str) -> None:
126126 self ._manifest_file_s3_key = key
127127 self .clear ()
128128
129- def get_manifest_file_s3_key (self ) -> None :
129+ def get_manifest_file_s3_key (self ) -> str :
130130 """Return manifest file s3 key for cache."""
131131 return self ._manifest_file_s3_key
132132
@@ -136,7 +136,7 @@ def set_s3_bucket_name(self, s3_bucket_name: str) -> None:
136136 self .s3_bucket_name = s3_bucket_name
137137 self .clear ()
138138
139- def get_bucket (self ) -> None :
139+ def get_bucket (self ) -> str :
140140 """Return bucket used for cache."""
141141 return self .s3_bucket_name
142142
@@ -166,6 +166,7 @@ def _get_manifest_key_from_model_id_semantic_version(
166166 manifest = self ._s3_cache .get (
167167 JumpStartCachedS3ContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
168168 ).formatted_content
169+ assert isinstance (manifest , dict )
169170
170171 sm_version = utils .get_sagemaker_version ()
171172
@@ -191,16 +192,16 @@ def _get_manifest_key_from_model_id_semantic_version(
191192
192193 if sm_incompatible_model_version is not None :
193194 model_version_to_use_incompatible_with_sagemaker = sm_incompatible_model_version
194- sm_version_to_use = [
195+ sm_version_to_use_list = [
195196 header .min_version
196197 for header in manifest .values ()
197198 if header .model_id == model_id
198199 and header .version == model_version_to_use_incompatible_with_sagemaker
199200 ]
200- if len (sm_version_to_use ) != 1 :
201+ if len (sm_version_to_use_list ) != 1 :
201202 # ``manifest`` dict should already enforce this
202203 raise RuntimeError ("Found more than one incompatible SageMaker version to use." )
203- sm_version_to_use = sm_version_to_use [0 ]
204+ sm_version_to_use = sm_version_to_use_list [0 ]
204205
205206 error_msg = (
206207 f"Unable to find model manifest for { model_id } with version { version } "
@@ -258,9 +259,12 @@ def _get_file_from_s3(
258259 def get_manifest (self ) -> List [JumpStartModelHeader ]:
259260 """Return entire JumpStart models manifest."""
260261
261- return self ._s3_cache .get (
262+ manifest_dict = self ._s3_cache .get (
262263 JumpStartCachedS3ContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
263- ).formatted_content .values ()
264+ ).formatted_content
265+ assert isinstance (manifest_dict , dict )
266+ manifest = list (manifest_dict .values ())
267+ return manifest
264268
265269 def get_header (self , model_id : str , semantic_version_str : str ) -> JumpStartModelHeader :
266270 """Return header for a given JumpStart model id and semantic version.
@@ -277,30 +281,30 @@ def _select_version(
277281 self ,
278282 semantic_version_str : str ,
279283 available_versions : List [Version ],
280- ) -> Optional [Version ]:
281- """Utility to select appropriate version from available version given
282- a semantic version with which to filter.
284+ ) -> Optional [str ]:
285+ """Utility to select appropriate version from available versions.
283286
284287 Args:
285288 semantic_version_str (str): the semantic version for which to filter
286289 available versions.
287290 available_versions (List[Version]): list of available versions.
288291 """
289292 if semantic_version_str == "*" :
290- if len (available_versions ) is 0 :
293+ if len (available_versions ) == 0 :
291294 return None
292- else :
293- return str (max (available_versions ))
294- else :
295- spec = SpecifierSet (f"=={ semantic_version_str } " )
296- available_versions = list (spec .filter (available_versions ))
297- return str (available_versions [0 ]) if available_versions != [] else None
295+ return str (max (available_versions ))
296+
297+ spec = SpecifierSet (f"=={ semantic_version_str } " )
298+ available_versions_filtered = list (spec .filter (available_versions ))
299+ return (
300+ str (available_versions_filtered [0 ]) if available_versions_filtered != [] else None
301+ )
298302
299303 def _get_header_impl (
300304 self ,
301305 model_id : str ,
302306 semantic_version_str : str ,
303- attempt : Optional [ int ] = 0 ,
307+ attempt : int = 0 ,
304308 ) -> JumpStartModelHeader :
305309 """Lower-level function to return header.
306310
@@ -310,7 +314,7 @@ def _get_header_impl(
310314 model_id (str): model id for which to get a header.
311315 semantic_version_str (str): The semantic version for which to get a
312316 header.
313- attempt (Optional[ int] ): attempt number at retrieving a header.
317+ attempt (int): attempt number at retrieving a header.
314318 """
315319
316320 versioned_model_id = self ._model_id_semantic_version_manifest_key_cache .get (
@@ -320,7 +324,10 @@ def _get_header_impl(
320324 JumpStartCachedS3ContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
321325 ).formatted_content
322326 try :
323- return manifest [versioned_model_id ]
327+ assert isinstance (manifest , dict )
328+ header = manifest [versioned_model_id ]
329+ assert isinstance (header , JumpStartModelHeader )
330+ return header
324331 except KeyError :
325332 if attempt > 0 :
326333 raise
@@ -338,9 +345,11 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
338345
339346 header = self .get_header (model_id , semantic_version_str )
340347 spec_key = header .spec_key
341- return self ._s3_cache .get (
348+ specs = self ._s3_cache .get (
342349 JumpStartCachedS3ContentKey (JumpStartS3FileType .SPECS , spec_key )
343350 ).formatted_content
351+ assert isinstance (specs , JumpStartModelSpecs )
352+ return specs
344353
345354 def clear (self ) -> None :
346355 """Clears the model id/version and s3 cache."""
0 commit comments