1616from sagemaker import image_uris
1717from sagemaker .jumpstart .constants import (
1818 JUMPSTART_DEFAULT_REGION_NAME ,
19- INFERENCE ,
20- TRAINING ,
21- SUPPORTED_JUMPSTART_SCOPES ,
19+ JumpStartScriptScope ,
2220 ModelFramework ,
2321 VariableScope ,
2422)
25- from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
23+ from sagemaker .jumpstart .utils import (
24+ get_jumpstart_content_bucket ,
25+ verify_model_region_and_return_specs ,
26+ )
2627from sagemaker .jumpstart import accessors as jumpstart_accessors
2728
2829
@@ -40,6 +41,8 @@ def _retrieve_image_uri(
4041 distribution : Optional [str ],
4142 base_framework_version : Optional [str ],
4243 training_compiler_config : Optional [str ],
44+ tolerate_vulnerable_model : bool ,
45+ tolerate_deprecated_model : bool ,
4346):
4447 """Retrieves the container image URI for JumpStart models.
4548
@@ -72,40 +75,38 @@ def _retrieve_image_uri(
7275 distribution (dict): A dictionary with information on how to run distributed training
7376 training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
7477 A configuration class for the SageMaker Training Compiler.
78+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
79+ specifications should be tolerated (exception not raised). If False, raises an
80+ exception if the script used by this version of the model has dependencies with known
81+ security vulnerabilities.
82+ tolerate_deprecated_model (bool): True if deprecated versions of model
83+ specifications should be tolerated (exception not raised). If False, raises
84+ an exception if the version of the model is deprecated.
7585
7686 Returns:
7787 str: the ECR URI for the corresponding SageMaker Docker image.
7888
7989 Raises:
8090 ValueError: If the combination of arguments specified is not supported.
91+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
92+ known security vulnerabilities.
93+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
8194 """
8295 if region is None :
8396 region = JUMPSTART_DEFAULT_REGION_NAME
8497
85- assert region is not None
86-
87- if image_scope is None :
88- raise ValueError (
89- "Must specify `image_scope` argument to retrieve image uri for JumpStart models."
90- )
91- if image_scope not in SUPPORTED_JUMPSTART_SCOPES :
92- raise ValueError (
93- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
94- )
95-
96- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
97- region = region , model_id = model_id , version = model_version
98+ model_specs = verify_model_region_and_return_specs (
99+ model_id = model_id ,
100+ version = model_version ,
101+ scope = image_scope ,
102+ region = region ,
103+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
104+ tolerate_deprecated_model = tolerate_deprecated_model ,
98105 )
99106
100- if image_scope == INFERENCE :
107+ if image_scope == JumpStartScriptScope . INFERENCE :
101108 ecr_specs = model_specs .hosting_ecr_specs
102- elif image_scope == TRAINING :
103- if not model_specs .training_supported :
104- raise ValueError (
105- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
106- "does not support training."
107- )
108- assert model_specs .training_ecr_specs is not None
109+ elif image_scope == JumpStartScriptScope .TRAINING :
109110 ecr_specs = model_specs .training_ecr_specs
110111
111112 if framework is not None and framework != ecr_specs .framework :
@@ -128,11 +129,11 @@ def _retrieve_image_uri(
128129
129130 base_framework_version_override : Optional [str ] = None
130131 version_override : Optional [str ] = None
131- if ecr_specs .framework == ModelFramework .HUGGINGFACE . value :
132+ if ecr_specs .framework == ModelFramework .HUGGINGFACE :
132133 base_framework_version_override = ecr_specs .framework_version
133134 version_override = ecr_specs .huggingface_transformers_version
134135
135- if image_scope == TRAINING :
136+ if image_scope == JumpStartScriptScope . TRAINING :
136137 return image_uris .get_training_image_uri (
137138 region = region ,
138139 framework = ecr_specs .framework ,
@@ -168,6 +169,8 @@ def _retrieve_model_uri(
168169 model_version : str ,
169170 model_scope : Optional [str ],
170171 region : Optional [str ],
172+ tolerate_vulnerable_model : bool ,
173+ tolerate_deprecated_model : bool ,
171174):
172175 """Retrieves the model artifact S3 URI for the model matching the given arguments.
173176
@@ -179,40 +182,37 @@ def _retrieve_model_uri(
179182 model_scope (str): The model type, i.e. what it is used for.
180183 Valid values: "training" and "inference".
181184 region (str): Region for which to retrieve model S3 URI.
185+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
186+ specifications should be tolerated (exception not raised). If False, raises an
187+ exception if the script used by this version of the model has dependencies with known
188+ security vulnerabilities.
189+ tolerate_deprecated_model (bool): True if deprecated versions of model
190+ specifications should be tolerated (exception not raised). If False, raises
191+ an exception if the version of the model is deprecated.
182192 Returns:
183193 str: the model artifact S3 URI for the corresponding model.
184194
185195 Raises:
186196 ValueError: If the combination of arguments specified is not supported.
197+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
198+ known security vulnerabilities.
199+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
187200 """
188201 if region is None :
189202 region = JUMPSTART_DEFAULT_REGION_NAME
190203
191- assert region is not None
192-
193- if model_scope is None :
194- raise ValueError (
195- "Must specify `model_scope` argument to retrieve model "
196- "artifact uri for JumpStart models."
197- )
198-
199- if model_scope not in SUPPORTED_JUMPSTART_SCOPES :
200- raise ValueError (
201- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
202- )
203-
204- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
205- region = region , model_id = model_id , version = model_version
204+ model_specs = verify_model_region_and_return_specs (
205+ model_id = model_id ,
206+ version = model_version ,
207+ scope = model_scope ,
208+ region = region ,
209+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
210+ tolerate_deprecated_model = tolerate_deprecated_model ,
206211 )
207- if model_scope == INFERENCE :
212+
213+ if model_scope == JumpStartScriptScope .INFERENCE :
208214 model_artifact_key = model_specs .hosting_artifact_key
209- elif model_scope == TRAINING :
210- if not model_specs .training_supported :
211- raise ValueError (
212- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
213- "does not support training."
214- )
215- assert model_specs .training_artifact_key is not None
215+ elif model_scope == JumpStartScriptScope .TRAINING :
216216 model_artifact_key = model_specs .training_artifact_key
217217
218218 bucket = get_jumpstart_content_bucket (region )
@@ -227,6 +227,8 @@ def _retrieve_script_uri(
227227 model_version : str ,
228228 script_scope : Optional [str ],
229229 region : Optional [str ],
230+ tolerate_vulnerable_model : bool ,
231+ tolerate_deprecated_model : bool ,
230232):
231233 """Retrieves the script S3 URI associated with the model matching the given arguments.
232234
@@ -238,40 +240,37 @@ def _retrieve_script_uri(
238240 script_scope (str): The script type, i.e. what it is used for.
239241 Valid values: "training" and "inference".
240242 region (str): Region for which to retrieve model script S3 URI.
243+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
244+ specifications should be tolerated (exception not raised). If False, raises an
245+ exception if the script used by this version of the model has dependencies with known
246+ security vulnerabilities.
247+ tolerate_deprecated_model (bool): True if deprecated versions of model
248+ specifications should be tolerated (exception not raised). If False, raises
249+ an exception if the version of the model is deprecated.
241250 Returns:
242251 str: the model script URI for the corresponding model.
243252
244253 Raises:
245254 ValueError: If the combination of arguments specified is not supported.
255+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
256+ known security vulnerabilities.
257+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
246258 """
247259 if region is None :
248260 region = JUMPSTART_DEFAULT_REGION_NAME
249261
250- assert region is not None
251-
252- if script_scope is None :
253- raise ValueError (
254- "Must specify `script_scope` argument to retrieve model script uri for "
255- "JumpStart models."
256- )
257-
258- if script_scope not in SUPPORTED_JUMPSTART_SCOPES :
259- raise ValueError (
260- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
261- )
262-
263- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
264- region = region , model_id = model_id , version = model_version
262+ model_specs = verify_model_region_and_return_specs (
263+ model_id = model_id ,
264+ version = model_version ,
265+ scope = script_scope ,
266+ region = region ,
267+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
268+ tolerate_deprecated_model = tolerate_deprecated_model ,
265269 )
266- if script_scope == INFERENCE :
270+
271+ if script_scope == JumpStartScriptScope .INFERENCE :
267272 model_script_key = model_specs .hosting_script_key
268- elif script_scope == TRAINING :
269- if not model_specs .training_supported :
270- raise ValueError (
271- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
272- "does not support training."
273- )
274- assert model_specs .training_script_key is not None
273+ elif script_scope == JumpStartScriptScope .TRAINING :
275274 model_script_key = model_specs .training_script_key
276275
277276 bucket = get_jumpstart_content_bucket (region )
@@ -309,8 +308,6 @@ def _retrieve_default_hyperparameters(
309308 if region is None :
310309 region = JUMPSTART_DEFAULT_REGION_NAME
311310
312- assert region is not None
313-
314311 model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
315312 region = region , model_id = model_id , version = model_version
316313 )
0 commit comments