@@ -41,8 +41,8 @@ def _retrieve_image_uri(
4141 distribution : Optional [str ],
4242 base_framework_version : Optional [str ],
4343 training_compiler_config : Optional [str ],
44- tolerate_vulnerable_model : Optional [ bool ] ,
45- tolerate_deprecated_model : Optional [ bool ] ,
44+ tolerate_vulnerable_model : bool ,
45+ tolerate_deprecated_model : bool ,
4646):
4747 """Retrieves the container image URI for JumpStart models.
4848
@@ -75,12 +75,12 @@ def _retrieve_image_uri(
7575 distribution (dict): A dictionary with information on how to run distributed training
7676 training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
7777 A configuration class for the SageMaker Training Compiler.
78- tolerate_vulnerable_model (Optional[ bool] ): True if vulnerable versions of model
79- specifications should be tolerated (exception not raised). False or None , raises an
78+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
79+ specifications should be tolerated (exception not raised). If False , raises an
8080 exception if the script used by this version of the model has dependencies with known
8181 security vulnerabilities.
82- tolerate_deprecated_model (Optional[ bool] ): True if deprecated versions of model
83- specifications should be tolerated (exception not raised). False or None , raises
82+ tolerate_deprecated_model (bool): True if deprecated versions of model
83+ specifications should be tolerated (exception not raised). If False , raises
8484 an exception if the version of the model is deprecated.
8585
8686 Returns:
@@ -95,8 +95,6 @@ def _retrieve_image_uri(
9595 if region is None :
9696 region = JUMPSTART_DEFAULT_REGION_NAME
9797
98- assert region is not None
99-
10098 model_specs = verify_model_region_and_return_specs (
10199 model_id = model_id ,
102100 version = model_version ,
@@ -109,7 +107,6 @@ def _retrieve_image_uri(
109107 if image_scope == JumpStartScriptScope .INFERENCE :
110108 ecr_specs = model_specs .hosting_ecr_specs
111109 elif image_scope == JumpStartScriptScope .TRAINING :
112- assert model_specs .training_ecr_specs is not None
113110 ecr_specs = model_specs .training_ecr_specs
114111
115112 if framework is not None and framework != ecr_specs .framework :
@@ -172,8 +169,8 @@ def _retrieve_model_uri(
172169 model_version : str ,
173170 model_scope : Optional [str ],
174171 region : Optional [str ],
175- tolerate_vulnerable_model : Optional [ bool ] ,
176- tolerate_deprecated_model : Optional [ bool ] ,
172+ tolerate_vulnerable_model : bool ,
173+ tolerate_deprecated_model : bool ,
177174):
178175 """Retrieves the model artifact S3 URI for the model matching the given arguments.
179176
@@ -185,12 +182,12 @@ def _retrieve_model_uri(
185182 model_scope (str): The model type, i.e. what it is used for.
186183 Valid values: "training" and "inference".
187184 region (str): Region for which to retrieve model S3 URI.
188- tolerate_vulnerable_model (Optional[ bool] ): True if vulnerable versions of model
189- specifications should be tolerated (exception not raised). False or None , raises an
185+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
186+ specifications should be tolerated (exception not raised). If False , raises an
190187 exception if the script used by this version of the model has dependencies with known
191188 security vulnerabilities.
192- tolerate_deprecated_model (Optional[ bool] ): True if deprecated versions of model
193- specifications should be tolerated (exception not raised). False or None , raises
189+ tolerate_deprecated_model (bool): True if deprecated versions of model
190+ specifications should be tolerated (exception not raised). If False , raises
194191 an exception if the version of the model is deprecated.
195192 Returns:
196193 str: the model artifact S3 URI for the corresponding model.
@@ -204,8 +201,6 @@ def _retrieve_model_uri(
204201 if region is None :
205202 region = JUMPSTART_DEFAULT_REGION_NAME
206203
207- assert region is not None
208-
209204 model_specs = verify_model_region_and_return_specs (
210205 model_id = model_id ,
211206 version = model_version ,
@@ -218,7 +213,6 @@ def _retrieve_model_uri(
218213 if model_scope == JumpStartScriptScope .INFERENCE :
219214 model_artifact_key = model_specs .hosting_artifact_key
220215 elif model_scope == JumpStartScriptScope .TRAINING :
221- assert model_specs .training_artifact_key is not None
222216 model_artifact_key = model_specs .training_artifact_key
223217
224218 bucket = get_jumpstart_content_bucket (region )
@@ -233,8 +227,8 @@ def _retrieve_script_uri(
233227 model_version : str ,
234228 script_scope : Optional [str ],
235229 region : Optional [str ],
236- tolerate_vulnerable_model : Optional [ bool ] ,
237- tolerate_deprecated_model : Optional [ bool ] ,
230+ tolerate_vulnerable_model : bool ,
231+ tolerate_deprecated_model : bool ,
238232):
239233 """Retrieves the script S3 URI associated with the model matching the given arguments.
240234
@@ -246,12 +240,12 @@ def _retrieve_script_uri(
246240 script_scope (str): The script type, i.e. what it is used for.
247241 Valid values: "training" and "inference".
248242 region (str): Region for which to retrieve model script S3 URI.
249- tolerate_vulnerable_model (Optional[ bool] ): True if vulnerable versions of model
250- specifications should be tolerated (exception not raised). False or None , raises an
243+ tolerate_vulnerable_model (bool): True if vulnerable versions of model
244+ specifications should be tolerated (exception not raised). If False , raises an
251245 exception if the script used by this version of the model has dependencies with known
252246 security vulnerabilities.
253- tolerate_deprecated_model (Optional[ bool] ): True if deprecated versions of model
254- specifications should be tolerated (exception not raised). False or None , raises
247+ tolerate_deprecated_model (bool): True if deprecated versions of model
248+ specifications should be tolerated (exception not raised). If False , raises
255249 an exception if the version of the model is deprecated.
256250 Returns:
257251 str: the model script URI for the corresponding model.
@@ -265,8 +259,6 @@ def _retrieve_script_uri(
265259 if region is None :
266260 region = JUMPSTART_DEFAULT_REGION_NAME
267261
268- assert region is not None
269-
270262 model_specs = verify_model_region_and_return_specs (
271263 model_id = model_id ,
272264 version = model_version ,
@@ -279,7 +271,6 @@ def _retrieve_script_uri(
279271 if script_scope == JumpStartScriptScope .INFERENCE :
280272 model_script_key = model_specs .hosting_script_key
281273 elif script_scope == JumpStartScriptScope .TRAINING :
282- assert model_specs .training_script_key is not None
283274 model_script_key = model_specs .training_script_key
284275
285276 bucket = get_jumpstart_content_bucket (region )
@@ -317,8 +308,6 @@ def _retrieve_default_hyperparameters(
317308 if region is None :
318309 region = JUMPSTART_DEFAULT_REGION_NAME
319310
320- assert region is not None
321-
322311 model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
323312 region = region , model_id = model_id , version = model_version
324313 )
0 commit comments