Skip to content

Commit 76271ea

Browse files
committed
Merge remote-tracking branch 'upstream/master-jumpstart' into feat/hyperparameter-validation
2 parents c6eb14d + c03efb2 commit 76271ea

File tree

17 files changed

+532
-124
lines changed

17 files changed

+532
-124
lines changed

src/sagemaker/environment_variables.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,8 +48,4 @@ def retrieve_default(
4848
"Must specify `model_id` and `model_version` when retrieving environment variables."
4949
)
5050

51-
# mypy type checking requires these assertions
52-
assert model_id is not None
53-
assert model_version is not None
54-
5551
return artifacts._retrieve_default_environment_variables(model_id, model_version, region)

src/sagemaker/image_uris.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,8 @@ def retrieve(
4545
training_compiler_config=None,
4646
model_id=None,
4747
model_version=None,
48+
tolerate_vulnerable_model=False,
49+
tolerate_deprecated_model=False,
4850
) -> str:
4951
"""Retrieves the ECR URI for the Docker image matching the given arguments.
5052
@@ -79,19 +81,26 @@ def retrieve(
7981
(default: None).
8082
model_version (str): Version of the JumpStart model for which to retrieve the
8183
image URI (default: None).
84+
tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications
85+
should be tolerated (exception not raised). If False, raises an exception if
86+
the script used by this version of the model has dependencies with known security
87+
vulnerabilities. (Default: False).
88+
tolerate_deprecated_model (bool): True if deprecated versions of model specifications
89+
should be tolerated (exception not raised). If False, raises an exception
90+
if the version of the model is deprecated. (Default: False).
8291
8392
Returns:
8493
str: the ECR URI for the corresponding SageMaker Docker image.
8594
8695
Raises:
96+
NotImplementedError: If the scope is not supported.
8797
ValueError: If the combination of arguments specified is not supported.
98+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
99+
known security vulnerabilities.
100+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
88101
"""
89102
if is_jumpstart_model_input(model_id, model_version):
90103

91-
# adding assert statements to satisfy mypy type checker
92-
assert model_id is not None
93-
assert model_version is not None
94-
95104
return artifacts._retrieve_image_uri(
96105
model_id,
97106
model_version,
@@ -106,6 +115,8 @@ def retrieve(
106115
distribution,
107116
base_framework_version,
108117
training_compiler_config,
118+
tolerate_vulnerable_model,
119+
tolerate_deprecated_model,
109120
)
110121

111122
if training_compiler_config is None:

src/sagemaker/jumpstart/accessors.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,6 @@ def _validate_and_mutate_region_cache_kwargs(
5656
region (str): The region to validate along with the kwargs.
5757
"""
5858
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
59-
assert isinstance(cache_kwargs_dict, dict)
6059
if region is not None and "region" in cache_kwargs_dict:
6160
if region != cache_kwargs_dict["region"]:
6261
raise ValueError(
@@ -92,8 +91,7 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
9291
JumpStartModelsAccessor._cache_kwargs, region
9392
)
9493
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
95-
assert JumpStartModelsAccessor._cache is not None
96-
return JumpStartModelsAccessor._cache.get_header(
94+
return JumpStartModelsAccessor._cache.get_header( # type: ignore
9795
model_id=model_id, semantic_version_str=version
9896
)
9997

@@ -110,8 +108,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
110108
JumpStartModelsAccessor._cache_kwargs, region
111109
)
112110
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
113-
assert JumpStartModelsAccessor._cache is not None
114-
return JumpStartModelsAccessor._cache.get_specs(
111+
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
115112
model_id=model_id, semantic_version_str=version
116113
)
117114

src/sagemaker/jumpstart/artifacts.py

Lines changed: 72 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -16,15 +16,16 @@
1616
from sagemaker import image_uris
1717
from sagemaker.jumpstart.constants import (
1818
JUMPSTART_DEFAULT_REGION_NAME,
19-
INFERENCE,
20-
TRAINING,
21-
SUPPORTED_JUMPSTART_SCOPES,
2219
)
2320
from sagemaker.jumpstart.enums import (
21+
JumpStartScriptScope,
2422
ModelFramework,
2523
VariableScope,
2624
)
27-
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
25+
from sagemaker.jumpstart.utils import (
26+
get_jumpstart_content_bucket,
27+
verify_model_region_and_return_specs,
28+
)
2829
from sagemaker.jumpstart import accessors as jumpstart_accessors
2930

3031

@@ -42,6 +43,8 @@ def _retrieve_image_uri(
4243
distribution: Optional[str],
4344
base_framework_version: Optional[str],
4445
training_compiler_config: Optional[str],
46+
tolerate_vulnerable_model: bool,
47+
tolerate_deprecated_model: bool,
4548
):
4649
"""Retrieves the container image URI for JumpStart models.
4750
@@ -74,40 +77,38 @@ def _retrieve_image_uri(
7477
distribution (dict): A dictionary with information on how to run distributed training.
7578
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
7679
A configuration class for the SageMaker Training Compiler.
80+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
81+
specifications should be tolerated (exception not raised). If False, raises an
82+
exception if the script used by this version of the model has dependencies with known
83+
security vulnerabilities.
84+
tolerate_deprecated_model (bool): True if deprecated versions of model
85+
specifications should be tolerated (exception not raised). If False, raises
86+
an exception if the version of the model is deprecated.
7787
7888
Returns:
7989
str: the ECR URI for the corresponding SageMaker Docker image.
8090
8191
Raises:
8292
ValueError: If the combination of arguments specified is not supported.
93+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
94+
known security vulnerabilities.
95+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
8396
"""
8497
if region is None:
8598
region = JUMPSTART_DEFAULT_REGION_NAME
8699

87-
assert region is not None
88-
89-
if image_scope is None:
90-
raise ValueError(
91-
"Must specify `image_scope` argument to retrieve image uri for JumpStart models."
92-
)
93-
if image_scope not in SUPPORTED_JUMPSTART_SCOPES:
94-
raise ValueError(
95-
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
96-
)
97-
98-
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
99-
region=region, model_id=model_id, version=model_version
100+
model_specs = verify_model_region_and_return_specs(
101+
model_id=model_id,
102+
version=model_version,
103+
scope=image_scope,
104+
region=region,
105+
tolerate_vulnerable_model=tolerate_vulnerable_model,
106+
tolerate_deprecated_model=tolerate_deprecated_model,
100107
)
101108

102-
if image_scope == INFERENCE:
109+
if image_scope == JumpStartScriptScope.INFERENCE:
103110
ecr_specs = model_specs.hosting_ecr_specs
104-
elif image_scope == TRAINING:
105-
if not model_specs.training_supported:
106-
raise ValueError(
107-
f"JumpStart model ID '{model_id}' and version '{model_version}' "
108-
"does not support training."
109-
)
110-
assert model_specs.training_ecr_specs is not None
111+
elif image_scope == JumpStartScriptScope.TRAINING:
111112
ecr_specs = model_specs.training_ecr_specs
112113

113114
if framework is not None and framework != ecr_specs.framework:
@@ -130,11 +131,11 @@ def _retrieve_image_uri(
130131

131132
base_framework_version_override: Optional[str] = None
132133
version_override: Optional[str] = None
133-
if ecr_specs.framework == ModelFramework.HUGGINGFACE.value:
134+
if ecr_specs.framework == ModelFramework.HUGGINGFACE:
134135
base_framework_version_override = ecr_specs.framework_version
135136
version_override = ecr_specs.huggingface_transformers_version
136137

137-
if image_scope == TRAINING:
138+
if image_scope == JumpStartScriptScope.TRAINING:
138139
return image_uris.get_training_image_uri(
139140
region=region,
140141
framework=ecr_specs.framework,
@@ -170,6 +171,8 @@ def _retrieve_model_uri(
170171
model_version: str,
171172
model_scope: Optional[str],
172173
region: Optional[str],
174+
tolerate_vulnerable_model: bool,
175+
tolerate_deprecated_model: bool,
173176
):
174177
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
175178
@@ -181,40 +184,37 @@ def _retrieve_model_uri(
181184
model_scope (str): The model type, i.e. what it is used for.
182185
Valid values: "training" and "inference".
183186
region (str): Region for which to retrieve model S3 URI.
187+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
188+
specifications should be tolerated (exception not raised). If False, raises an
189+
exception if the script used by this version of the model has dependencies with known
190+
security vulnerabilities.
191+
tolerate_deprecated_model (bool): True if deprecated versions of model
192+
specifications should be tolerated (exception not raised). If False, raises
193+
an exception if the version of the model is deprecated.
184194
Returns:
185195
str: the model artifact S3 URI for the corresponding model.
186196
187197
Raises:
188198
ValueError: If the combination of arguments specified is not supported.
199+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
200+
known security vulnerabilities.
201+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
189202
"""
190203
if region is None:
191204
region = JUMPSTART_DEFAULT_REGION_NAME
192205

193-
assert region is not None
194-
195-
if model_scope is None:
196-
raise ValueError(
197-
"Must specify `model_scope` argument to retrieve model "
198-
"artifact uri for JumpStart models."
199-
)
200-
201-
if model_scope not in SUPPORTED_JUMPSTART_SCOPES:
202-
raise ValueError(
203-
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
204-
)
205-
206-
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
207-
region=region, model_id=model_id, version=model_version
206+
model_specs = verify_model_region_and_return_specs(
207+
model_id=model_id,
208+
version=model_version,
209+
scope=model_scope,
210+
region=region,
211+
tolerate_vulnerable_model=tolerate_vulnerable_model,
212+
tolerate_deprecated_model=tolerate_deprecated_model,
208213
)
209-
if model_scope == INFERENCE:
214+
215+
if model_scope == JumpStartScriptScope.INFERENCE:
210216
model_artifact_key = model_specs.hosting_artifact_key
211-
elif model_scope == TRAINING:
212-
if not model_specs.training_supported:
213-
raise ValueError(
214-
f"JumpStart model ID '{model_id}' and version '{model_version}' "
215-
"does not support training."
216-
)
217-
assert model_specs.training_artifact_key is not None
217+
elif model_scope == JumpStartScriptScope.TRAINING:
218218
model_artifact_key = model_specs.training_artifact_key
219219

220220
bucket = get_jumpstart_content_bucket(region)
@@ -229,6 +229,8 @@ def _retrieve_script_uri(
229229
model_version: str,
230230
script_scope: Optional[str],
231231
region: Optional[str],
232+
tolerate_vulnerable_model: bool,
233+
tolerate_deprecated_model: bool,
232234
):
233235
"""Retrieves the script S3 URI associated with the model matching the given arguments.
234236
@@ -240,40 +242,37 @@ def _retrieve_script_uri(
240242
script_scope (str): The script type, i.e. what it is used for.
241243
Valid values: "training" and "inference".
242244
region (str): Region for which to retrieve model script S3 URI.
245+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
246+
specifications should be tolerated (exception not raised). If False, raises an
247+
exception if the script used by this version of the model has dependencies with known
248+
security vulnerabilities.
249+
tolerate_deprecated_model (bool): True if deprecated versions of model
250+
specifications should be tolerated (exception not raised). If False, raises
251+
an exception if the version of the model is deprecated.
243252
Returns:
244253
str: the model script URI for the corresponding model.
245254
246255
Raises:
247256
ValueError: If the combination of arguments specified is not supported.
257+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
258+
known security vulnerabilities.
259+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
248260
"""
249261
if region is None:
250262
region = JUMPSTART_DEFAULT_REGION_NAME
251263

252-
assert region is not None
253-
254-
if script_scope is None:
255-
raise ValueError(
256-
"Must specify `script_scope` argument to retrieve model script uri for "
257-
"JumpStart models."
258-
)
259-
260-
if script_scope not in SUPPORTED_JUMPSTART_SCOPES:
261-
raise ValueError(
262-
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
263-
)
264-
265-
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
266-
region=region, model_id=model_id, version=model_version
264+
model_specs = verify_model_region_and_return_specs(
265+
model_id=model_id,
266+
version=model_version,
267+
scope=script_scope,
268+
region=region,
269+
tolerate_vulnerable_model=tolerate_vulnerable_model,
270+
tolerate_deprecated_model=tolerate_deprecated_model,
267271
)
268-
if script_scope == INFERENCE:
272+
273+
if script_scope == JumpStartScriptScope.INFERENCE:
269274
model_script_key = model_specs.hosting_script_key
270-
elif script_scope == TRAINING:
271-
if not model_specs.training_supported:
272-
raise ValueError(
273-
f"JumpStart model ID '{model_id}' and version '{model_version}' "
274-
"does not support training."
275-
)
276-
assert model_specs.training_script_key is not None
275+
elif script_scope == JumpStartScriptScope.TRAINING:
277276
model_script_key = model_specs.training_script_key
278277

279278
bucket = get_jumpstart_content_bucket(region)
@@ -311,8 +310,6 @@ def _retrieve_default_hyperparameters(
311310
if region is None:
312311
region = JUMPSTART_DEFAULT_REGION_NAME
313312

314-
assert region is not None
315-
316313
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
317314
region=region, model_id=model_id, version=model_version
318315
)

0 commit comments

Comments
 (0)