Skip to content

Commit 86cf577

Browse files
committed
fix: linting, mypy, logical issues for jumpstart models
1 parent 84bf597 commit 86cf577

File tree

10 files changed

+135
-94
lines changed

10 files changed

+135
-94
lines changed

src/sagemaker/image_uris.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def retrieve(
9292
)
9393
if image_scope is None:
9494
raise ValueError(
95-
"Must specify `image_scope` argument to retrieve image uri for " "JumpStart models."
95+
"Must specify `image_scope` argument to retrieve image uri for JumpStart models."
9696
)
9797
if image_scope == "inference":
9898
ecr_specs = model_specs.hosting_ecr_specs
@@ -103,7 +103,7 @@ def retrieve(
103103
else:
104104
raise ValueError("JumpStart models only support inference and training.")
105105

106-
if framework != None and framework != ecr_specs.framework:
106+
if framework is not None and framework != ecr_specs.framework:
107107
raise ValueError(
108108
f"Bad value for container framework for JumpStart model: '{framework}'."
109109
)

src/sagemaker/jumpstart/accessors.py

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -40,41 +40,78 @@ class JumpStartModelsCache(object):
4040
_cache: Optional[cache.JumpStartModelsCache] = None
4141
_curr_region = JUMPSTART_DEFAULT_REGION_NAME
4242

43-
_cache_kwargs = {}
43+
_cache_kwargs: Dict[str, Any] = {}
4444

45+
@staticmethod
4546
def _validate_region_cache_kwargs(
46-
cache_kwargs: Dict[str, Any] = {}, region: Optional[str] = None
47-
):
48-
if region is not None and "region" in cache_kwargs:
49-
if region != cache_kwargs["region"]:
47+
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
48+
) -> Dict[str, Any]:
49+
"""Returns cache_kwargs with region argument removed if present.
50+
51+
Raises:
52+
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
53+
54+
Args:
55+
cache_kwargs (Optional[Dict[str, Any]]): cache kwargs to validate.
56+
region (str): The region to validate along with the kwargs.
57+
"""
58+
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
59+
assert isinstance(cache_kwargs_dict, dict)
60+
if region is not None and "region" in cache_kwargs_dict:
61+
if region != cache_kwargs_dict["region"]:
5062
raise ValueError(
51-
f"Inconsistent region definitions: {region}, {cache_kwargs['region']}"
63+
f"Inconsistent region definitions: {region}, {cache_kwargs_dict['region']}"
5264
)
53-
del cache_kwargs["region"]
54-
return cache_kwargs
65+
del cache_kwargs_dict["region"]
66+
return cache_kwargs_dict
5567

5668
@staticmethod
5769
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
70+
"""Returns model header from JumpStart models cache.
71+
72+
Args:
73+
region (str): region for which to retrieve header.
74+
model_id (str): model id to retrieve.
75+
version (str): semantic version to retrieve for the model id.
76+
"""
5877
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(
5978
JumpStartModelsCache._cache_kwargs, region
6079
)
61-
if JumpStartModelsCache._cache == None or region != JumpStartModelsCache._curr_region:
80+
if JumpStartModelsCache._cache is None or region != JumpStartModelsCache._curr_region:
6281
JumpStartModelsCache._cache = cache.JumpStartModelsCache(region=region, **cache_kwargs)
6382
JumpStartModelsCache._curr_region = region
83+
assert JumpStartModelsCache._cache is not None
6484
return JumpStartModelsCache._cache.get_header(model_id, version)
6585

6686
@staticmethod
6787
def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelSpecs:
88+
"""Returns model specs from JumpStart models cache.
89+
90+
Args:
91+
region (str): region for which to retrieve header.
92+
model_id (str): model id to retrieve.
93+
version (str): semantic version to retrieve for the model id.
94+
"""
6895
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(
6996
JumpStartModelsCache._cache_kwargs, region
7097
)
71-
if JumpStartModelsCache._cache == None or region != JumpStartModelsCache._curr_region:
98+
if JumpStartModelsCache._cache is None or region != JumpStartModelsCache._curr_region:
7299
JumpStartModelsCache._cache = cache.JumpStartModelsCache(region=region, **cache_kwargs)
73100
JumpStartModelsCache._curr_region = region
101+
assert JumpStartModelsCache._cache is not None
74102
return JumpStartModelsCache._cache.get_specs(model_id, version)
75103

76104
@staticmethod
77105
def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
106+
"""Sets cache kwargs. Clears the cache.
107+
108+
Raises:
109+
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
110+
111+
Args:
112+
cache_kwargs (str): cache kwargs to validate.
113+
region (str): The region to validate along with the kwargs.
114+
"""
78115
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(cache_kwargs, region)
79116
JumpStartModelsCache._cache_kwargs = cache_kwargs
80117
if region is None:
@@ -88,15 +125,15 @@ def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
88125
)
89126

90127
@staticmethod
91-
def reset_cache(cache_kwargs: Dict[str, Any] = {}, region: str = None) -> None:
92-
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(cache_kwargs, region)
93-
JumpStartModelsCache._cache_kwargs = cache_kwargs
94-
if region is None:
95-
JumpStartModelsCache._cache = cache.JumpStartModelsCache(
96-
**JumpStartModelsCache._cache_kwargs
97-
)
98-
else:
99-
JumpStartModelsCache._curr_region = region
100-
JumpStartModelsCache._cache = cache.JumpStartModelsCache(
101-
region=region, **JumpStartModelsCache._cache_kwargs
102-
)
128+
def reset_cache(cache_kwargs: Dict[str, Any] = None, region: str = None) -> None:
129+
"""Resets cache, optionally allowing cache kwargs to be passed to the new cache.
130+
131+
Raises:
132+
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
133+
134+
Args:
135+
cache_kwargs (str): cache kwargs to validate.
136+
region (str): The region to validate along with the kwargs.
137+
"""
138+
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
139+
JumpStartModelsCache.set_cache_kwargs(cache_kwargs_dict, region)

src/sagemaker/jumpstart/cache.py

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -48,37 +48,37 @@ class JumpStartModelsCache:
4848
for launching JumpStart models from the SageMaker SDK.
4949
"""
5050

51+
# fmt: off
5152
def __init__(
5253
self,
53-
region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME,
54-
max_s3_cache_items: Optional[int] = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
55-
s3_cache_expiration_horizon: Optional[
56-
datetime.timedelta
57-
] = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
58-
max_semantic_version_cache_items: Optional[
59-
int
60-
] = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
61-
semantic_version_cache_expiration_horizon: Optional[
62-
datetime.timedelta
63-
] = JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
64-
manifest_file_s3_key: Optional[str] = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
54+
region: str = JUMPSTART_DEFAULT_REGION_NAME,
55+
max_s3_cache_items: int = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
56+
s3_cache_expiration_horizon: datetime.timedelta =
57+
JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
58+
max_semantic_version_cache_items: int =
59+
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
60+
semantic_version_cache_expiration_horizon: datetime.timedelta =
61+
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
62+
manifest_file_s3_key: str =
63+
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
6564
s3_bucket_name: Optional[str] = None,
6665
s3_client_config: Optional[botocore.config.Config] = None,
67-
) -> None:
66+
) -> None: # fmt: on
6867
"""Initialize a ``JumpStartModelsCache`` instance.
6968
7069
Args:
71-
region (Optional[str]): AWS region to associate with cache. Default: region associated
70+
region (str): AWS region to associate with cache. Default: region associated
7271
with boto3 session.
73-
max_s3_cache_items (Optional[int]): Maximum number of items to store in s3 cache.
72+
max_s3_cache_items (int): Maximum number of items to store in s3 cache.
7473
Default: 20.
75-
s3_cache_expiration_horizon (Optional[datetime.timedelta]): Maximum time to hold
74+
s3_cache_expiration_horizon (datetime.timedelta): Maximum time to hold
7675
items in s3 cache before invalidation. Default: 6 hours.
77-
max_semantic_version_cache_items (Optional[int]): Maximum number of items to store in
76+
max_semantic_version_cache_items (int): Maximum number of items to store in
7877
semantic version cache. Default: 20.
79-
semantic_version_cache_expiration_horizon (Optional[datetime.timedelta]):
78+
semantic_version_cache_expiration_horizon (datetime.timedelta):
8079
Maximum time to hold items in semantic version cache before invalidation.
8180
Default: 6 hours.
81+
manifest_file_s3_key (str): The key in S3 corresponding to the sdk metadata manifest.
8282
s3_bucket_name (Optional[str]): S3 bucket to associate with cache.
8383
Default: JumpStart-hosted content bucket for region.
8484
s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache.
@@ -126,7 +126,7 @@ def set_manifest_file_s3_key(self, key: str) -> None:
126126
self._manifest_file_s3_key = key
127127
self.clear()
128128

129-
def get_manifest_file_s3_key(self) -> None:
129+
def get_manifest_file_s3_key(self) -> str:
130130
"""Return manifest file s3 key for cache."""
131131
return self._manifest_file_s3_key
132132

@@ -136,7 +136,7 @@ def set_s3_bucket_name(self, s3_bucket_name: str) -> None:
136136
self.s3_bucket_name = s3_bucket_name
137137
self.clear()
138138

139-
def get_bucket(self) -> None:
139+
def get_bucket(self) -> str:
140140
"""Return bucket used for cache."""
141141
return self.s3_bucket_name
142142

@@ -166,6 +166,7 @@ def _get_manifest_key_from_model_id_semantic_version(
166166
manifest = self._s3_cache.get(
167167
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
168168
).formatted_content
169+
assert isinstance(manifest, dict)
169170

170171
sm_version = utils.get_sagemaker_version()
171172

@@ -191,16 +192,16 @@ def _get_manifest_key_from_model_id_semantic_version(
191192

192193
if sm_incompatible_model_version is not None:
193194
model_version_to_use_incompatible_with_sagemaker = sm_incompatible_model_version
194-
sm_version_to_use = [
195+
sm_version_to_use_list = [
195196
header.min_version
196197
for header in manifest.values()
197198
if header.model_id == model_id
198199
and header.version == model_version_to_use_incompatible_with_sagemaker
199200
]
200-
if len(sm_version_to_use) != 1:
201+
if len(sm_version_to_use_list) != 1:
201202
# ``manifest`` dict should already enforce this
202203
raise RuntimeError("Found more than one incompatible SageMaker version to use.")
203-
sm_version_to_use = sm_version_to_use[0]
204+
sm_version_to_use = sm_version_to_use_list[0]
204205

205206
error_msg = (
206207
f"Unable to find model manifest for {model_id} with version {version} "
@@ -258,9 +259,12 @@ def _get_file_from_s3(
258259
def get_manifest(self) -> List[JumpStartModelHeader]:
259260
"""Return entire JumpStart models manifest."""
260261

261-
return self._s3_cache.get(
262+
manifest_dict = self._s3_cache.get(
262263
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
263-
).formatted_content.values()
264+
).formatted_content
265+
assert isinstance(manifest_dict, dict)
266+
manifest = list(manifest_dict.values())
267+
return manifest
264268

265269
def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader:
266270
"""Return header for a given JumpStart model id and semantic version.
@@ -277,30 +281,30 @@ def _select_version(
277281
self,
278282
semantic_version_str: str,
279283
available_versions: List[Version],
280-
) -> Optional[Version]:
281-
"""Utility to select appropriate version from available version given
282-
a semantic version with which to filter.
284+
) -> Optional[str]:
285+
"""Utility to select appropriate version from available versions.
283286
284287
Args:
285288
semantic_version_str (str): the semantic version for which to filter
286289
available versions.
287290
available_versions (List[Version]): list of available versions.
288291
"""
289292
if semantic_version_str == "*":
290-
if len(available_versions) is 0:
293+
if len(available_versions) == 0:
291294
return None
292-
else:
293-
return str(max(available_versions))
294-
else:
295-
spec = SpecifierSet(f"=={semantic_version_str}")
296-
available_versions = list(spec.filter(available_versions))
297-
return str(available_versions[0]) if available_versions != [] else None
295+
return str(max(available_versions))
296+
297+
spec = SpecifierSet(f"=={semantic_version_str}")
298+
available_versions_filtered = list(spec.filter(available_versions))
299+
return (
300+
str(available_versions_filtered[0]) if available_versions_filtered != [] else None
301+
)
298302

299303
def _get_header_impl(
300304
self,
301305
model_id: str,
302306
semantic_version_str: str,
303-
attempt: Optional[int] = 0,
307+
attempt: int = 0,
304308
) -> JumpStartModelHeader:
305309
"""Lower-level function to return header.
306310
@@ -310,7 +314,7 @@ def _get_header_impl(
310314
model_id (str): model id for which to get a header.
311315
semantic_version_str (str): The semantic version for which to get a
312316
header.
313-
attempt (Optional[int]): attempt number at retrieving a header.
317+
attempt (int): attempt number at retrieving a header.
314318
"""
315319

316320
versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get(
@@ -320,7 +324,10 @@ def _get_header_impl(
320324
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
321325
).formatted_content
322326
try:
323-
return manifest[versioned_model_id]
327+
assert isinstance(manifest, dict)
328+
header = manifest[versioned_model_id]
329+
assert isinstance(header, JumpStartModelHeader)
330+
return header
324331
except KeyError:
325332
if attempt > 0:
326333
raise
@@ -338,9 +345,11 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
338345

339346
header = self.get_header(model_id, semantic_version_str)
340347
spec_key = header.spec_key
341-
return self._s3_cache.get(
348+
specs = self._s3_cache.get(
342349
JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key)
343350
).formatted_content
351+
assert isinstance(specs, JumpStartModelSpecs)
352+
return specs
344353

345354
def clear(self) -> None:
346355
"""Clears the model id/version and s3 cache."""

src/sagemaker/jumpstart/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,15 +274,15 @@ def __init__(
274274
self,
275275
formatted_content: Union[
276276
Dict[JumpStartVersionedModelId, JumpStartModelHeader],
277-
List[JumpStartModelSpecs],
277+
JumpStartModelSpecs,
278278
],
279279
md5_hash: Optional[str] = None,
280280
) -> None:
281281
"""Instantiates JumpStartCachedS3ContentValue object.
282282
283283
Args:
284284
formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader],
285-
List[JumpStartModelSpecs]]):
285+
JumpStartModelSpecs]):
286286
Formatted content for model specs and mappings from
287287
versioned model ids to specs.
288288
md5_hash (str): md5_hash for stored file content from s3.

src/sagemaker/model_uris.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
"""Functions for generating S3 model artifact URIs for pre-built SageMaker models."""
1414
from __future__ import absolute_import
1515

16-
import json
1716
import logging
18-
import os
19-
import re
2017
from typing import Optional
2118

2219
from sagemaker.jumpstart import utils as jumpstart_utils
@@ -56,13 +53,15 @@ def retrieve(
5653
)
5754
if model_scope is None:
5855
raise ValueError(
59-
"Must specify `model_scope` argument to retrieve model artifact uri for JumpStart models."
56+
"Must specify `model_scope` argument to retrieve model "
57+
"artifact uri for JumpStart models."
6058
)
6159
if model_scope == "inference":
6260
model_artifact_key = model_specs.hosting_artifact_key
6361
elif model_scope == "training":
6462
if not model_specs.training_supported:
6563
raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
64+
assert model_specs.training_artifact_key is not None
6665
model_artifact_key = model_specs.training_artifact_key
6766
else:
6867
raise ValueError("JumpStart models only support inference and training.")

src/sagemaker/script_uris.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
"""Functions for generating S3 model script URIs for pre-built SageMaker models."""
1414
from __future__ import absolute_import
1515

16-
import json
1716
import logging
18-
import os
19-
import re
2017

2118
from sagemaker.jumpstart import utils as jumpstart_utils
2219
from sagemaker.jumpstart import accessors as jumpstart_accessors

0 commit comments

Comments
 (0)