@@ -167,20 +167,13 @@ def retrieve(
167167 )
168168 else :
169169 _framework = framework
170- if (
171- framework == HUGGING_FACE_FRAMEWORK
172- or framework in TRAINIUM_ALLOWED_FRAMEWORKS
173- ):
170+ if framework == HUGGING_FACE_FRAMEWORK or framework in TRAINIUM_ALLOWED_FRAMEWORKS :
174171 inference_tool = _get_inference_tool (inference_tool , instance_type )
175172 if inference_tool in ["neuron" , "neuronx" ]:
176173 _framework = f"{ framework } -{ inference_tool } "
177- final_image_scope = _get_final_image_scope (
178- framework , instance_type , image_scope
179- )
174+ final_image_scope = _get_final_image_scope (framework , instance_type , image_scope )
180175 _validate_for_suppported_frameworks_and_instance_type (framework , instance_type )
181- config = _config_for_framework_and_scope (
182- _framework , final_image_scope , accelerator_type
183- )
176+ config = _config_for_framework_and_scope (_framework , final_image_scope , accelerator_type )
184177
185178 original_version = version
186179 version = _validate_version_and_set_if_needed (version , config , framework )
@@ -191,14 +184,10 @@ def retrieve(
191184 full_base_framework_version = version_config ["version_aliases" ].get (
192185 base_framework_version , base_framework_version
193186 )
194- _validate_arg (
195- full_base_framework_version , list (version_config .keys ()), "base framework"
196- )
187+ _validate_arg (full_base_framework_version , list (version_config .keys ()), "base framework" )
197188 version_config = version_config .get (full_base_framework_version )
198189
199- py_version = _validate_py_version_and_set_if_needed (
200- py_version , version_config , framework
201- )
190+ py_version = _validate_py_version_and_set_if_needed (py_version , version_config , framework )
202191 version_config = version_config .get (py_version ) or version_config
203192 registry = _registry_from_region (region , version_config ["registries" ])
204193 endpoint_data = utils ._botocore_resolver ().construct_endpoint ("ecr" , region )
@@ -226,9 +215,7 @@ def retrieve(
226215
227216 if framework == HUGGING_FACE_FRAMEWORK :
228217 pt_or_tf_version = (
229- re .compile ("^(pytorch|tensorflow)(.*)$" )
230- .match (base_framework_version )
231- .group (2 )
218+ re .compile ("^(pytorch|tensorflow)(.*)$" ).match (base_framework_version ).group (2 )
232219 )
233220 _version = original_version
234221
@@ -252,13 +239,11 @@ def retrieve(
252239 .get ("version_aliases" , {})
253240 .get (base_framework_version , {})
254241 ):
255- _base_framework_version = config .get ("versions" )[_version ][
256- "version_aliases"
257- ][ base_framework_version ]
242+ _base_framework_version = config .get ("versions" )[_version ]["version_aliases" ][
243+ base_framework_version
244+ ]
258245 pt_or_tf_version = (
259- re .compile ("^(pytorch|tensorflow)(.*)$" )
260- .match (_base_framework_version )
261- .group (2 )
246+ re .compile ("^(pytorch|tensorflow)(.*)$" ).match (_base_framework_version ).group (2 )
262247 )
263248
264249 tag_prefix = f"{ pt_or_tf_version } -transformers{ _version } "
@@ -285,9 +270,7 @@ def retrieve(
285270 if tag :
286271 repo += ":{}" .format (tag )
287272
288- return ECR_URI_TEMPLATE .format (
289- registry = registry , hostname = hostname , repository = repo
290- )
273+ return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo )
291274
292275
293276def _get_image_tag (
@@ -326,13 +309,9 @@ def _get_image_tag(
326309 }
327310 tag = version_to_arm64_tag_mapping [framework ][version ]
328311 else :
329- tag = _format_tag (
330- tag_prefix , processor , py_version , container_version , inference_tool
331- )
312+ tag = _format_tag (tag_prefix , processor , py_version , container_version , inference_tool )
332313 else :
333- tag = _format_tag (
334- tag_prefix , processor , py_version , container_version , inference_tool
335- )
314+ tag = _format_tag (tag_prefix , processor , py_version , container_version , inference_tool )
336315
337316 if instance_type is not None and _should_auto_select_container_version (
338317 instance_type , distribution
@@ -383,11 +362,7 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
383362 )
384363 image_scope = available_scopes [0 ]
385364
386- if (
387- not image_scope
388- and "scope" in config
389- and set (available_scopes ) == {"training" , "inference" }
390- ):
365+ if not image_scope and "scope" in config and set (available_scopes ) == {"training" , "inference" }:
391366 logger .info (
392367 "Same images used for training and inference. Defaulting to image scope: %s." ,
393368 available_scopes [0 ],
@@ -419,27 +394,20 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
419394 and "trn" in instance_type
420395 and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
421396 ):
422- _validate_framework (
423- framework , TRAINIUM_ALLOWED_FRAMEWORKS , "framework" , "Trainium"
424- )
397+ _validate_framework (framework , TRAINIUM_ALLOWED_FRAMEWORKS , "framework" , "Trainium" )
425398
426399 # Validate for Graviton allowed frameowrks
427400 if (
428401 instance_type is not None
429- and utils .get_instance_type_family (instance_type )
430- in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
402+ and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
431403 and framework not in GRAVITON_ALLOWED_FRAMEWORKS
432404 ):
433- _validate_framework (
434- framework , GRAVITON_ALLOWED_FRAMEWORKS , "framework" , "Graviton"
435- )
405+ _validate_framework (framework , GRAVITON_ALLOWED_FRAMEWORKS , "framework" , "Graviton" )
436406
437407
438408def config_for_framework (framework ):
439409 """Loads the JSON config for the given framework."""
440- fname = os .path .join (
441- os .path .dirname (__file__ ), "image_uri_config" , "{}.json" .format (framework )
442- )
410+ fname = os .path .join (os .path .dirname (__file__ ), "image_uri_config" , "{}.json" .format (framework ))
443411 with open (fname ) as f :
444412 return json .load (f )
445413
@@ -448,8 +416,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
448416 """Return final image scope based on provided framework and instance type."""
449417 if (
450418 framework in GRAVITON_ALLOWED_FRAMEWORKS
451- and utils .get_instance_type_family (instance_type )
452- in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
419+ and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
453420 ):
454421 return INFERENCE_GRAVITON
455422 if image_scope is None and framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
@@ -465,9 +432,7 @@ def _get_inference_tool(inference_tool, instance_type):
465432 """Extract the inference tool name from instance type."""
466433 if not inference_tool :
467434 instance_type_family = utils .get_instance_type_family (instance_type )
468- if instance_type_family .startswith ("inf" ) or instance_type_family .startswith (
469- "trn"
470- ):
435+ if instance_type_family .startswith ("inf" ) or instance_type_family .startswith ("trn" ):
471436 return "neuron"
472437 return inference_tool
473438
@@ -479,15 +444,10 @@ def _get_latest_versions(list_of_versions):
479444
480445def _validate_accelerator_type (accelerator_type ):
481446 """Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
482- if (
483- not accelerator_type .startswith ("ml.eia" )
484- and accelerator_type != "local_sagemaker_notebook"
485- ):
447+ if not accelerator_type .startswith ("ml.eia" ) and accelerator_type != "local_sagemaker_notebook" :
486448 raise ValueError (
487449 "Invalid SageMaker Elastic Inference accelerator type: {}. "
488- "See https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html" .format (
489- accelerator_type
490- )
450+ "See https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html" .format (accelerator_type )
491451 )
492452
493453
@@ -497,15 +457,11 @@ def _validate_version_and_set_if_needed(version, config, framework):
497457 aliased_versions = list (config .get ("version_aliases" , {}).keys ())
498458
499459 if len (available_versions ) == 1 and version not in aliased_versions :
500- log_message = (
501- "Defaulting to the only supported framework/algorithm version: {}." .format (
502- available_versions [0 ]
503- )
460+ log_message = "Defaulting to the only supported framework/algorithm version: {}." .format (
461+ available_versions [0 ]
504462 )
505463 if version and version != available_versions [0 ]:
506- logger .warning (
507- "%s Ignoring framework/algorithm version: %s." , log_message , version
508- )
464+ logger .warning ("%s Ignoring framework/algorithm version: %s." , log_message , version )
509465 elif not version :
510466 logger .info (log_message )
511467
@@ -518,9 +474,7 @@ def _validate_version_and_set_if_needed(version, config, framework):
518474 ]:
519475 version = _get_latest_versions (available_versions )
520476
521- _validate_arg (
522- version , available_versions + aliased_versions , "{} version" .format (framework )
523- )
477+ _validate_arg (version , available_versions + aliased_versions , "{} version" .format (framework ))
524478 return version
525479
526480
@@ -546,9 +500,7 @@ def _processor(instance_type, available_processors, serverless_inference_config=
546500 return None
547501
548502 if len (available_processors ) == 1 and not instance_type :
549- logger .info (
550- "Defaulting to only supported image scope: %s." , available_processors [0 ]
551- )
503+ logger .info ("Defaulting to only supported image scope: %s." , available_processors [0 ])
552504 return available_processors [0 ]
553505
554506 if serverless_inference_config is not None :
@@ -585,9 +537,7 @@ def _processor(instance_type, available_processors, serverless_inference_config=
585537 else :
586538 raise ValueError (
587539 "Invalid SageMaker instance type: {}. For options, see: "
588- "https://aws.amazon.com/sagemaker/pricing/instance-types" .format (
589- instance_type
590- )
540+ "https://aws.amazon.com/sagemaker/pricing/instance-types" .format (instance_type )
591541 )
592542
593543 _validate_arg (processor , available_processors , "processor" )
@@ -626,9 +576,7 @@ def _validate_py_version_and_set_if_needed(py_version, version_config, framework
626576 return None
627577
628578 if py_version is None and len (available_versions ) == 1 :
629- logger .info (
630- "Defaulting to only available Python version: %s" , available_versions [0 ]
631- )
579+ logger .info ("Defaulting to only available Python version: %s" , available_versions [0 ])
632580 return available_versions [0 ]
633581
634582 _validate_arg (py_version , available_versions , "Python version" )
@@ -641,9 +589,7 @@ def _validate_arg(arg, available_options, arg_name):
641589 raise ValueError (
642590 "Unsupported {arg_name}: {arg}. You may need to upgrade your SDK version "
643591 "(pip install -U sagemaker) for newer {arg_name}s. Supported {arg_name}(s): "
644- "{options}." .format (
645- arg_name = arg_name , arg = arg , options = ", " .join (available_options )
646- )
592+ "{options}." .format (arg_name = arg_name , arg = arg , options = ", " .join (available_options ))
647593 )
648594
649595
@@ -656,17 +602,11 @@ def _validate_framework(framework, allowed_frameworks, arg_name, hardware_name):
656602 )
657603
658604
659- def _format_tag (
660- tag_prefix , processor , py_version , container_version , inference_tool = None
661- ):
605+ def _format_tag (tag_prefix , processor , py_version , container_version , inference_tool = None ):
662606 """Creates a tag for the image URI."""
663607 if inference_tool :
664- return "-" .join (
665- x for x in (tag_prefix , inference_tool , py_version , container_version ) if x
666- )
667- return "-" .join (
668- x for x in (tag_prefix , processor , py_version , container_version ) if x
669- )
608+ return "-" .join (x for x in (tag_prefix , inference_tool , py_version , container_version ) if x )
609+ return "-" .join (x for x in (tag_prefix , processor , py_version , container_version ) if x )
670610
671611
672612@override_pipeline_parameter_var
@@ -775,6 +715,4 @@ def get_base_python_image_uri(region, py_version="310") -> str:
775715 repo = version_config ["repository" ] + "-" + py_version
776716 repo_and_tag = repo + ":" + version
777717
778- return ECR_URI_TEMPLATE .format (
779- registry = registry , hostname = hostname , repository = repo_and_tag
780- )
718+ return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo_and_tag )
0 commit comments