@@ -75,18 +75,22 @@ def _retrieve_image_uri(
7575 distribution (dict): A dictionary with information on how to run distributed training
7676 training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
7777 A configuration class for the SageMaker Training Compiler.
78- tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
79- not raised). False if these models should raise an exception.
80- tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
81- not raised). False if these models should raise an exception.
78+ tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
79+ specifications should be tolerated (exception not raised). False or None, raises an
80+ exception if the script used by this version of the model has dependencies with known
81+ security vulnerabilities.
82+ tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
83+ specifications should be tolerated (exception not raised). False or None, raises
84+ an exception if the version of the model is deprecated.
8285
8386 Returns:
8487 str: the ECR URI for the corresponding SageMaker Docker image.
8588
8689 Raises:
8790 ValueError: If the combination of arguments specified is not supported.
88- VulnerableJumpStartModelError: If the model is vulnerable.
89- DeprecatedJumpStartModelError: If the model is deprecated.
91+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
92+ known security vulnerabilities.
93+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
9094 """
9195 if region is None :
9296 region = JUMPSTART_DEFAULT_REGION_NAME
@@ -102,9 +106,9 @@ def _retrieve_image_uri(
102106 tolerate_deprecated_model = tolerate_deprecated_model ,
103107 )
104108
105- if image_scope == JumpStartScriptScope .INFERENCE . value :
109+ if image_scope == JumpStartScriptScope .INFERENCE :
106110 ecr_specs = model_specs .hosting_ecr_specs
107- elif image_scope == JumpStartScriptScope .TRAINING . value :
111+ elif image_scope == JumpStartScriptScope .TRAINING :
108112 assert model_specs .training_ecr_specs is not None
109113 ecr_specs = model_specs .training_ecr_specs
110114
@@ -128,11 +132,11 @@ def _retrieve_image_uri(
128132
129133 base_framework_version_override : Optional [str ] = None
130134 version_override : Optional [str ] = None
131- if ecr_specs .framework == ModelFramework .HUGGINGFACE . value :
135+ if ecr_specs .framework == ModelFramework .HUGGINGFACE :
132136 base_framework_version_override = ecr_specs .framework_version
133137 version_override = ecr_specs .huggingface_transformers_version
134138
135- if image_scope == JumpStartScriptScope .TRAINING . value :
139+ if image_scope == JumpStartScriptScope .TRAINING :
136140 return image_uris .get_training_image_uri (
137141 region = region ,
138142 framework = ecr_specs .framework ,
@@ -181,17 +185,21 @@ def _retrieve_model_uri(
181185 model_scope (str): The model type, i.e. what it is used for.
182186 Valid values: "training" and "inference".
183187 region (str): Region for which to retrieve model S3 URI.
184- tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
185- not raised). False if these models should raise an exception.
186- tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
187- not raised). False if these models should raise an exception.
188+ tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
189+ specifications should be tolerated (exception not raised). False or None, raises an
190+ exception if the script used by this version of the model has dependencies with known
191+ security vulnerabilities.
192+ tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
193+ specifications should be tolerated (exception not raised). False or None, raises
194+ an exception if the version of the model is deprecated.
188195 Returns:
189196 str: the model artifact S3 URI for the corresponding model.
190197
191198 Raises:
192199 ValueError: If the combination of arguments specified is not supported.
193- VulnerableJumpStartModelError: If the model is vulnerable.
194- DeprecatedJumpStartModelError: If the model is deprecated.
200+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
201+ known security vulnerabilities.
202+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
195203 """
196204 if region is None :
197205 region = JUMPSTART_DEFAULT_REGION_NAME
@@ -207,9 +215,9 @@ def _retrieve_model_uri(
207215 tolerate_deprecated_model = tolerate_deprecated_model ,
208216 )
209217
210- if model_scope == JumpStartScriptScope .INFERENCE . value :
218+ if model_scope == JumpStartScriptScope .INFERENCE :
211219 model_artifact_key = model_specs .hosting_artifact_key
212- elif model_scope == JumpStartScriptScope .TRAINING . value :
220+ elif model_scope == JumpStartScriptScope .TRAINING :
213221 assert model_specs .training_artifact_key is not None
214222 model_artifact_key = model_specs .training_artifact_key
215223
@@ -238,17 +246,21 @@ def _retrieve_script_uri(
238246 script_scope (str): The script type, i.e. what it is used for.
239247 Valid values: "training" and "inference".
240248 region (str): Region for which to retrieve model script S3 URI.
241- tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
242- not raised). False if these models should raise an exception.
243- tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
244- not raised). False if these models should raise an exception.
249+ tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
250+ specifications should be tolerated (exception not raised). False or None, raises an
251+ exception if the script used by this version of the model has dependencies with known
252+ security vulnerabilities.
253+ tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
254+ specifications should be tolerated (exception not raised). False or None, raises
255+ an exception if the version of the model is deprecated.
245256 Returns:
246257 str: the model script URI for the corresponding model.
247258
248259 Raises:
249260 ValueError: If the combination of arguments specified is not supported.
250- VulnerableJumpStartModelError: If the model is vulnerable.
251- DeprecatedJumpStartModelError: If the model is deprecated.
261+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
262+ known security vulnerabilities.
263+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
252264 """
253265 if region is None :
254266 region = JUMPSTART_DEFAULT_REGION_NAME
@@ -264,9 +276,9 @@ def _retrieve_script_uri(
264276 tolerate_deprecated_model = tolerate_deprecated_model ,
265277 )
266278
267- if script_scope == JumpStartScriptScope .INFERENCE . value :
279+ if script_scope == JumpStartScriptScope .INFERENCE :
268280 model_script_key = model_specs .hosting_script_key
269- elif script_scope == JumpStartScriptScope .TRAINING . value :
281+ elif script_scope == JumpStartScriptScope .TRAINING :
270282 assert model_specs .training_script_key is not None
271283 model_script_key = model_specs .training_script_key
272284
0 commit comments