1818 JUMPSTART_DEFAULT_REGION_NAME ,
1919 INFERENCE ,
2020 TRAINING ,
21- SUPPORTED_JUMPSTART_SCOPES ,
2221 ModelFramework ,
2322 VariableScope ,
2423)
25- from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
24+ from sagemaker .jumpstart .utils import (
25+ get_jumpstart_content_bucket ,
26+ verify_model_region_and_return_specs ,
27+ )
2628from sagemaker .jumpstart import accessors as jumpstart_accessors
2729
2830
@@ -40,6 +42,8 @@ def _retrieve_image_uri(
4042 distribution : Optional [str ],
4143 base_framework_version : Optional [str ],
4244 training_compiler_config : Optional [str ],
45+ tolerate_vulnerable_model : Optional [bool ],
46+ tolerate_deprecated_model : Optional [bool ],
4347):
4448 """Retrieves the container image URI for JumpStart models.
4549
@@ -72,39 +76,36 @@ def _retrieve_image_uri(
7276 distribution (dict): A dictionary with information on how to run distributed training
7377 training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
7478 A configuration class for the SageMaker Training Compiler.
79+ tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
80+ not thrown). False if these models should throw an exception.
81+ tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
82+ not thrown). False if these models should throw an exception.
7583
7684 Returns:
7785 str: the ECR URI for the corresponding SageMaker Docker image.
7886
7987 Raises:
8088 ValueError: If the combination of arguments specified is not supported.
89+ VulnerableJumpStartModelError: If the model is vulnerable.
90+ DeprecatedJumpStartModelError: If the model is deprecated.
8191 """
8292 if region is None :
8393 region = JUMPSTART_DEFAULT_REGION_NAME
8494
8595 assert region is not None
8696
87- if image_scope is None :
88- raise ValueError (
89- "Must specify `image_scope` argument to retrieve image uri for JumpStart models."
90- )
91- if image_scope not in SUPPORTED_JUMPSTART_SCOPES :
92- raise ValueError (
93- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
94- )
95-
96- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
97- region = region , model_id = model_id , version = model_version
97+ model_specs = verify_model_region_and_return_specs (
98+ model_id = model_id ,
99+ version = model_version ,
100+ scope = image_scope ,
101+ region = region ,
102+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
103+ tolerate_deprecated_model = tolerate_deprecated_model ,
98104 )
99105
100106 if image_scope == INFERENCE :
101107 ecr_specs = model_specs .hosting_ecr_specs
102108 elif image_scope == TRAINING :
103- if not model_specs .training_supported :
104- raise ValueError (
105- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
106- "does not support training."
107- )
108109 assert model_specs .training_ecr_specs is not None
109110 ecr_specs = model_specs .training_ecr_specs
110111
@@ -168,6 +169,8 @@ def _retrieve_model_uri(
168169 model_version : str ,
169170 model_scope : Optional [str ],
170171 region : Optional [str ],
172+ tolerate_vulnerable_model : Optional [bool ],
173+ tolerate_deprecated_model : Optional [bool ],
171174):
172175 """Retrieves the model artifact S3 URI for the model matching the given arguments.
173176
@@ -179,39 +182,35 @@ def _retrieve_model_uri(
179182 model_scope (str): The model type, i.e. what it is used for.
180183 Valid values: "training" and "inference".
181184 region (str): Region for which to retrieve model S3 URI.
185+ tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
186+ not thrown). False if these models should throw an exception.
187+ tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
188+ not thrown). False if these models should throw an exception.
182189 Returns:
183190 str: the model artifact S3 URI for the corresponding model.
184191
185192 Raises:
186193 ValueError: If the combination of arguments specified is not supported.
194+ VulnerableJumpStartModelError: If the model is vulnerable.
195+ DeprecatedJumpStartModelError: If the model is deprecated.
187196 """
188197 if region is None :
189198 region = JUMPSTART_DEFAULT_REGION_NAME
190199
191200 assert region is not None
192201
193- if model_scope is None :
194- raise ValueError (
195- "Must specify `model_scope` argument to retrieve model "
196- "artifact uri for JumpStart models."
197- )
198-
199- if model_scope not in SUPPORTED_JUMPSTART_SCOPES :
200- raise ValueError (
201- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
202- )
203-
204- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
205- region = region , model_id = model_id , version = model_version
202+ model_specs = verify_model_region_and_return_specs (
203+ model_id = model_id ,
204+ version = model_version ,
205+ scope = model_scope ,
206+ region = region ,
207+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
208+ tolerate_deprecated_model = tolerate_deprecated_model ,
206209 )
210+
207211 if model_scope == INFERENCE :
208212 model_artifact_key = model_specs .hosting_artifact_key
209213 elif model_scope == TRAINING :
210- if not model_specs .training_supported :
211- raise ValueError (
212- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
213- "does not support training."
214- )
215214 assert model_specs .training_artifact_key is not None
216215 model_artifact_key = model_specs .training_artifact_key
217216
@@ -227,6 +226,8 @@ def _retrieve_script_uri(
227226 model_version : str ,
228227 script_scope : Optional [str ],
229228 region : Optional [str ],
229+ tolerate_vulnerable_model : Optional [bool ],
230+ tolerate_deprecated_model : Optional [bool ],
230231):
231232 """Retrieves the script S3 URI associated with the model matching the given arguments.
232233
@@ -238,39 +239,35 @@ def _retrieve_script_uri(
238239 script_scope (str): The script type, i.e. what it is used for.
239240 Valid values: "training" and "inference".
240241 region (str): Region for which to retrieve model script S3 URI.
242+ tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
243+ not thrown). False if these models should throw an exception.
244+ tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
245+ not thrown). False if these models should throw an exception.
241246 Returns:
242247 str: the model script URI for the corresponding model.
243248
244249 Raises:
245250 ValueError: If the combination of arguments specified is not supported.
251+ VulnerableJumpStartModelError: If the model is vulnerable.
252+ DeprecatedJumpStartModelError: If the model is deprecated.
246253 """
247254 if region is None :
248255 region = JUMPSTART_DEFAULT_REGION_NAME
249256
250257 assert region is not None
251258
252- if script_scope is None :
253- raise ValueError (
254- "Must specify `script_scope` argument to retrieve model script uri for "
255- "JumpStart models."
256- )
257-
258- if script_scope not in SUPPORTED_JUMPSTART_SCOPES :
259- raise ValueError (
260- f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
261- )
262-
263- model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
264- region = region , model_id = model_id , version = model_version
259+ model_specs = verify_model_region_and_return_specs (
260+ model_id = model_id ,
261+ version = model_version ,
262+ scope = script_scope ,
263+ region = region ,
264+ tolerate_vulnerable_model = tolerate_vulnerable_model ,
265+ tolerate_deprecated_model = tolerate_deprecated_model ,
265266 )
267+
266268 if script_scope == INFERENCE :
267269 model_script_key = model_specs .hosting_script_key
268270 elif script_scope == TRAINING :
269- if not model_specs .training_supported :
270- raise ValueError (
271- f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
272- "does not support training."
273- )
274271 assert model_specs .training_script_key is not None
275272 model_script_key = model_specs .training_script_key
276273
0 commit comments