1010# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
13+ """This module defines the JumpStartModelsCache class."""
14+ from __future__ import absolute_import
1315import datetime
1416from typing import List , Optional
17+ import json
18+ import boto3
19+ import semantic_version
1520from sagemaker .jumpstart .types import (
1621 JumpStartCachedS3ContentKey ,
1722 JumpStartCachedS3ContentValue ,
1823 JumpStartModelHeader ,
1924 JumpStartModelSpecs ,
20- JumpStartModelSpecs ,
2125 JumpStartS3FileType ,
2226 JumpStartVersionedModelId ,
2327)
2428from sagemaker .jumpstart import utils
2529from sagemaker .utilities .cache import LRUCache
26- import boto3
27- import json
28- import semantic_version
29-
3030
3131DEFAULT_REGION_NAME = boto3 .session .Session ().region_name
3232
4141
4242class JumpStartModelsCache :
4343 """Class that implements a cache for JumpStart models manifests and specs.
44+
4445 The manifest and specs associated with JumpStart models provide the information necessary
4546 for launching JumpStart models from the SageMaker SDK.
4647 """
@@ -62,15 +63,16 @@ def __init__(
6263 Args:
6364 region (Optional[str]): AWS region to associate with cache. Default: region associated
6465 with botocore session.
65- max_s3_cache_items (Optional[int]): Maximum number of files to store in s3 cache. Default: 20.
66- s3_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold items in s3
67- cache before invalidation. Default: 6 hours.
66+ max_s3_cache_items (Optional[int]): Maximum number of files to store in s3 cache.
67+ Default: 20.
68+ s3_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold items in
69+ s3 cache before invalidation. Default: 6 hours.
6870 max_semantic_version_cache_items (Optional[int]): Maximum number of files to store in
6971 semantic version cache. Default: 20.
70- semantic_version_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold
71- items in semantic version cache before invalidation. Default: 6 hours.
72- bucket (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted content
73- bucket for region.
72+ semantic_version_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to
73+ hold items in semantic version cache before invalidation. Default: 6 hours.
74+ bucket (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted
75+ content bucket for region.
7476 """
7577
7678 self ._region = region
@@ -120,15 +122,16 @@ def get_bucket(self) -> None:
120122 return self ._bucket
121123
122124 def _get_manifest_key_from_model_id_semantic_version (
123- self , key : JumpStartVersionedModelId , value : Optional [JumpStartVersionedModelId ]
125+ self ,
126+ key : JumpStartVersionedModelId ,
127+ value : Optional [JumpStartVersionedModelId ], # pylint: disable=W0613
124128 ) -> JumpStartVersionedModelId :
125- """Return model id and version in manifest that matches semantic version/id
126- from customer request.
129+ """Return model id and version in manifest that matches semantic version/id.
127130
128131 Args:
129132 key (JumpStartVersionedModelId): Key for which to fetch versioned model id.
130- value (Optional[JumpStartVersionedModelId]): Unused variable for current value of old cached
131- model id/version.
133+ value (Optional[JumpStartVersionedModelId]): Unused variable for current value of
134+ old cached model id/version.
132135
133136 Raises:
134137 KeyError: If the semantic version is not found in the manifest.
@@ -158,42 +161,42 @@ def _get_manifest_key_from_model_id_semantic_version(
158161 sm_compatible_model_version = spec .select (versions_compatible_with_sagemaker )
159162 if sm_compatible_model_version is not None :
160163 return JumpStartVersionedModelId (model_id , str (sm_compatible_model_version ))
161- else :
162- versions_incompatible_with_sagemaker = [
163- semantic_version .Version (header .version )
164+
165+ versions_incompatible_with_sagemaker = [
166+ semantic_version .Version (header .version )
167+ for _ , header in manifest .items ()
168+ if header .model_id == model_id
169+ ]
170+ sm_incompatible_model_version = spec .select (versions_incompatible_with_sagemaker )
171+ if sm_incompatible_model_version is not None :
172+ model_version_to_use_incompatible_with_sagemaker = str (sm_incompatible_model_version )
173+ sm_version_to_use = [
174+ header .min_version
164175 for _ , header in manifest .items ()
165176 if header .model_id == model_id
177+ and header .version == model_version_to_use_incompatible_with_sagemaker
166178 ]
167- sm_incompatible_model_version = spec .select (versions_incompatible_with_sagemaker )
168- if sm_incompatible_model_version is not None :
169- model_version_to_use_incompatible_with_sagemaker = str (
170- sm_incompatible_model_version
171- )
172- sm_version_to_use = [
173- header .min_version
174- for _ , header in manifest .items ()
175- if header .model_id == model_id
176- and header .version == model_version_to_use_incompatible_with_sagemaker
177- ]
178- assert len (sm_version_to_use ) == 1
179- sm_version_to_use = sm_version_to_use [0 ]
180-
181- error_msg = (
182- f"Unable to find model manifest for { model_id } with version { version } compatible with your SageMaker version ({ sm_version } ). "
183- f"Consider upgrading your SageMaker library to at least version { sm_version_to_use } so you can use version "
184- f"{ model_version_to_use_incompatible_with_sagemaker } of { model_id } ."
185- )
186- raise KeyError (error_msg )
187- else :
188- error_msg = f"Unable to find model manifest for { model_id } with version { version } "
189- raise KeyError (error_msg )
179+ assert len (sm_version_to_use ) == 1
180+ sm_version_to_use = sm_version_to_use [0 ]
181+
182+ error_msg = (
183+ f"Unable to find model manifest for { model_id } with version { version } "
184+ f"compatible with your SageMaker version ({ sm_version } ). "
185+ f"Consider upgrading your SageMaker library to at least version "
186+ f"{ sm_version_to_use } so you can use version "
187+ f"{ model_version_to_use_incompatible_with_sagemaker } of { model_id } ."
188+ )
189+ raise KeyError (error_msg )
190+ error_msg = f"Unable to find model manifest for { model_id } with version { version } "
191+ raise KeyError (error_msg )
190192
191193 def _get_file_from_s3 (
192194 self ,
193195 key : JumpStartCachedS3ContentKey ,
194196 value : Optional [JumpStartCachedS3ContentValue ],
195197 ) -> JumpStartCachedS3ContentValue :
196198 """Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``.
199+
197200 If a manifest file is being fetched, we only download the object if the md5 hash in
198201 ``head_object`` does not match the current md5 hash for the stored value. This prevents
199202 unnecessarily downloading the full manifest when it hasn't changed.
@@ -228,18 +231,18 @@ def _get_file_from_s3(
228231 raise RuntimeError (f"Bad value for key: { key } " )
229232
230233 def get_header (
231- self , model_id : str , semantic_version : Optional [str ] = None
234+ self , model_id : str , semantic_version_str : Optional [str ] = None
232235 ) -> List [JumpStartModelHeader ]:
233236 """Return list of headers for a given JumpStart model id and semantic version.
234237
235238 Args:
236239 model_id (str): model id for which to get a header.
237- semantic_version (Optional[str]): The semantic version for which to get a header.
238- If None, the highest compatible version is returned.
240+ semantic_version_str (Optional[str]): The semantic version for which to get a
241+ header. If None, the highest compatible version is returned.
239242 """
240243
241244 versioned_model_id = self ._model_id_semantic_version_manifest_key_cache .get (
242- JumpStartVersionedModelId (model_id , semantic_version )
245+ JumpStartVersionedModelId (model_id , semantic_version_str )
243246 )
244247 manifest = self ._s3_cache .get (
245248 JumpStartCachedS3ContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
@@ -258,16 +261,17 @@ def get_header(
258261 return self .get_header (model_id , semantic_version )
259262
260263 def get_specs (
261- self , model_id : str , semantic_version : Optional [str ] = None
264+ self , model_id : str , semantic_version_str : Optional [str ] = None
262265 ) -> JumpStartModelSpecs :
263266 """Return specs for a given JumpStart model id and semantic version.
264267
265268 Args:
266269 model_id (str): model id for which to get specs.
267- semantic_version (Optional[str]): The semantic version for which to get specs.
268- If None, the highest compatible version is returned.
270+ semantic_version_str (Optional[str]): The semantic version for which to get
271+ specs. If None, the highest compatible version is returned.
269272 """
270- header = self .get_header (model_id , semantic_version )
273+
274+ header = self .get_header (model_id , semantic_version_str )
271275 spec_key = header .spec_key
272276 return self ._s3_cache .get (
273277 JumpStartCachedS3ContentKey (JumpStartS3FileType .SPECS , spec_key )
0 commit comments