1616from sagemaker import image_uris
1717from sagemaker .jumpstart .constants import (
1818 JUMPSTART_DEFAULT_REGION_NAME ,
19- INFERENCE ,
20- TRAINING ,
21- SUPPORTED_JUMPSTART_SCOPES ,
2219)
2320from sagemaker .jumpstart .enums import (
21+ JumpStartScriptScope ,
2422 ModelFramework ,
2523 VariableScope ,
2624)
27- from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
25+ from sagemaker .jumpstart .utils import (
26+ get_jumpstart_content_bucket ,
27+ verify_model_region_and_return_specs ,
28+ )
2829from sagemaker .jumpstart import accessors as jumpstart_accessors
2930
3031
@@ -42,6 +43,8 @@ def _retrieve_image_uri(
4243 distribution : Optional [str ],
4344 base_framework_version : Optional [str ],
4445 training_compiler_config : Optional [str ],
46+ tolerate_vulnerable_model : bool ,
47+ tolerate_deprecated_model : bool ,
4548):
4649 """Retrieves the container image URI for JumpStart models.
4750
@@ -74,40 +77,38 @@ def _retrieve_image_uri(
7477 distribution (dict): A dictionary with information on how to run distributed training
7578 training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
7679 A configuration class for the SageMaker Training Compiler.
80+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
81+ specifications should be tolerated (exception not raised). If False, raises an
82+ exception if the script used by this version of the model has dependencies with known
83+ security vulnerabilities.
84+ tolerate_deprecated_model (bool): True if deprecated versions of model
85+ specifications should be tolerated (exception not raised). If False, raises
86+ an exception if the version of the model is deprecated.
7787
7888 Returns:
7989 str: the ECR URI for the corresponding SageMaker Docker image.
8090
8191 Raises:
8292 ValueError: If the combination of arguments specified is not supported.
93+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
94+ known security vulnerabilities.
95+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
8396 """
8497 if region is None :
8598 region = JUMPSTART_DEFAULT_REGION_NAME
8699
87- assert region is not None
88-
89- if image_scope is None :
90- raise ValueError (
91- "Must specify `image_scope` argument to retrieve image uri for JumpStart models."
92- )
93- if image_scope not in SUPPORTED_JUMPSTART_SCOPES :
94- raise ValueError (
95- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
96- )
97-
98- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
99- region = region , model_id = model_id , version = model_version
100+ model_specs = verify_model_region_and_return_specs (
101+ model_id = model_id ,
102+ version = model_version ,
103+ scope = image_scope ,
104+ region = region ,
105+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
106+ tolerate_deprecated_model = tolerate_deprecated_model ,
100107 )
101108
102- if image_scope == INFERENCE :
109+ if image_scope == JumpStartScriptScope . INFERENCE :
103110 ecr_specs = model_specs .hosting_ecr_specs
104- elif image_scope == TRAINING :
105- if not model_specs .training_supported :
106- raise ValueError (
107- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
108- "does not support training."
109- )
110- assert model_specs .training_ecr_specs is not None
111+ elif image_scope == JumpStartScriptScope .TRAINING :
111112 ecr_specs = model_specs .training_ecr_specs
112113
113114 if framework is not None and framework != ecr_specs .framework :
@@ -130,11 +131,11 @@ def _retrieve_image_uri(
130131
131132 base_framework_version_override : Optional [str ] = None
132133 version_override : Optional [str ] = None
133- if ecr_specs .framework == ModelFramework .HUGGINGFACE . value :
134+ if ecr_specs .framework == ModelFramework .HUGGINGFACE :
134135 base_framework_version_override = ecr_specs .framework_version
135136 version_override = ecr_specs .huggingface_transformers_version
136137
137- if image_scope == TRAINING :
138+ if image_scope == JumpStartScriptScope . TRAINING :
138139 return image_uris .get_training_image_uri (
139140 region = region ,
140141 framework = ecr_specs .framework ,
@@ -170,6 +171,8 @@ def _retrieve_model_uri(
170171 model_version : str ,
171172 model_scope : Optional [str ],
172173 region : Optional [str ],
174+ tolerate_vulnerable_model : bool ,
175+ tolerate_deprecated_model : bool ,
173176):
174177 """Retrieves the model artifact S3 URI for the model matching the given arguments.
175178
@@ -181,40 +184,37 @@ def _retrieve_model_uri(
181184 model_scope (str): The model type, i.e. what it is used for.
182185 Valid values: "training" and "inference".
183186 region (str): Region for which to retrieve model S3 URI.
187+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
188+ specifications should be tolerated (exception not raised). If False, raises an
189+ exception if the script used by this version of the model has dependencies with known
190+ security vulnerabilities.
191+ tolerate_deprecated_model (bool): True if deprecated versions of model
192+ specifications should be tolerated (exception not raised). If False, raises
193+ an exception if the version of the model is deprecated.
184194 Returns:
185195 str: the model artifact S3 URI for the corresponding model.
186196
187197 Raises:
188198 ValueError: If the combination of arguments specified is not supported.
199+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
200+ known security vulnerabilities.
201+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
189202 """
190203 if region is None :
191204 region = JUMPSTART_DEFAULT_REGION_NAME
192205
193- assert region is not None
194-
195- if model_scope is None :
196- raise ValueError (
197- "Must specify `model_scope` argument to retrieve model "
198- "artifact uri for JumpStart models."
199- )
200-
201- if model_scope not in SUPPORTED_JUMPSTART_SCOPES :
202- raise ValueError (
203- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
204- )
205-
206- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
207- region = region , model_id = model_id , version = model_version
206+ model_specs = verify_model_region_and_return_specs (
207+ model_id = model_id ,
208+ version = model_version ,
209+ scope = model_scope ,
210+ region = region ,
211+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
212+ tolerate_deprecated_model = tolerate_deprecated_model ,
208213 )
209- if model_scope == INFERENCE :
214+
215+ if model_scope == JumpStartScriptScope .INFERENCE :
210216 model_artifact_key = model_specs .hosting_artifact_key
211- elif model_scope == TRAINING :
212- if not model_specs .training_supported :
213- raise ValueError (
214- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
215- "does not support training."
216- )
217- assert model_specs .training_artifact_key is not None
217+ elif model_scope == JumpStartScriptScope .TRAINING :
218218 model_artifact_key = model_specs .training_artifact_key
219219
220220 bucket = get_jumpstart_content_bucket (region )
@@ -229,6 +229,8 @@ def _retrieve_script_uri(
229229 model_version : str ,
230230 script_scope : Optional [str ],
231231 region : Optional [str ],
232+ tolerate_vulnerable_model : bool ,
233+ tolerate_deprecated_model : bool ,
232234):
233235 """Retrieves the script S3 URI associated with the model matching the given arguments.
234236
@@ -240,40 +242,37 @@ def _retrieve_script_uri(
240242 script_scope (str): The script type, i.e. what it is used for.
241243 Valid values: "training" and "inference".
242244 region (str): Region for which to retrieve model script S3 URI.
245+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
246+ specifications should be tolerated (exception not raised). If False, raises an
247+ exception if the script used by this version of the model has dependencies with known
248+ security vulnerabilities.
249+ tolerate_deprecated_model (bool): True if deprecated versions of model
250+ specifications should be tolerated (exception not raised). If False, raises
251+ an exception if the version of the model is deprecated.
243252 Returns:
244253 str: the model script URI for the corresponding model.
245254
246255 Raises:
247256 ValueError: If the combination of arguments specified is not supported.
257+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
258+ known security vulnerabilities.
259+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
248260 """
249261 if region is None :
250262 region = JUMPSTART_DEFAULT_REGION_NAME
251263
252- assert region is not None
253-
254- if script_scope is None :
255- raise ValueError (
256- "Must specify `script_scope` argument to retrieve model script uri for "
257- "JumpStart models."
258- )
259-
260- if script_scope not in SUPPORTED_JUMPSTART_SCOPES :
261- raise ValueError (
262- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
263- )
264-
265- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
266- region = region , model_id = model_id , version = model_version
264+ model_specs = verify_model_region_and_return_specs (
265+ model_id = model_id ,
266+ version = model_version ,
267+ scope = script_scope ,
268+ region = region ,
269+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
270+ tolerate_deprecated_model = tolerate_deprecated_model ,
267271 )
268- if script_scope == INFERENCE :
272+
273+ if script_scope == JumpStartScriptScope .INFERENCE :
269274 model_script_key = model_specs .hosting_script_key
270- elif script_scope == TRAINING :
271- if not model_specs .training_supported :
272- raise ValueError (
273- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
274- "does not support training."
275- )
276- assert model_specs .training_script_key is not None
275+ elif script_scope == JumpStartScriptScope .TRAINING :
277276 model_script_key = model_specs .training_script_key
278277
279278 bucket = get_jumpstart_content_bucket (region )
@@ -311,8 +310,6 @@ def _retrieve_default_hyperparameters(
311310 if region is None :
312311 region = JUMPSTART_DEFAULT_REGION_NAME
313312
314- assert region is not None
315-
316313 model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
317314 region = region , model_id = model_id , version = model_version
318315 )
0 commit comments