1414from __future__ import absolute_import
1515
1616import json
17+ import logging
1718import os
1819
1920from sagemaker import utils
2021
22+ logger = logging .getLogger (__name__ )
23+
2124ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}:{tag}"
2225
2326
24- def retrieve (framework , region , version = None , py_version = None , instance_type = None ):
27+ def retrieve (
28+ framework ,
29+ region ,
30+ version = None ,
31+ py_version = None ,
32+ instance_type = None ,
33+ accelerator_type = None ,
34+ image_scope = None ,
35+ ):
2536 """Retrieves the ECR URI for the Docker image matching the given arguments.
2637
2738 Args:
@@ -34,28 +45,48 @@ def retrieve(framework, region, version=None, py_version=None, instance_type=Non
3445 instance_type (str): The SageMaker instance type. For supported types, see
3546 https://aws.amazon.com/sagemaker/pricing/instance-types. This is required if
3647 there are different images for different processor types.
48+ accelerator_type (str): Elastic Inference accelerator type. For more, see
49+ https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
50+ image_scope (str): The image type, i.e. what it is used for.
51+ Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
52+ ``image_scope`` is ignored.
3753
3854 Returns:
3955 str: the ECR URI for the corresponding SageMaker Docker image.
4056
4157 Raises:
42- ValueError: If the framework version, Python version, processor type, or region is
43- not supported given the other arguments.
58+ ValueError: If the combination of arguments specified is not supported.
4459 """
45- config = config_for_framework (framework )
60+ config = _config_for_framework_and_scope (framework , image_scope , accelerator_type )
4661 version_config = config ["versions" ][_version_for_config (version , config , framework )]
4762
63+ py_version = _validate_py_version_and_set_if_needed (py_version , version_config )
64+ version_config = version_config .get (py_version ) or version_config
65+
4866 registry = _registry_from_region (region , version_config ["registries" ])
4967 hostname = utils ._botocore_resolver ().construct_endpoint ("ecr" , region )["hostname" ]
5068
5169 repo = version_config ["repository" ]
52-
53- _validate_py_version (py_version , version_config ["py_versions" ], framework , version )
54- tag = "{}-{}-{}" .format (version , _processor (instance_type , config ["processors" ]), py_version )
70+ tag = _format_tag (version , _processor (instance_type , config ["processors" ]), py_version )
5571
5672 return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo , tag = tag )
5773
5874
75+ def _config_for_framework_and_scope (framework , image_scope , accelerator_type = None ):
76+ """Loads the JSON config for the given framework and image scope."""
77+ config = config_for_framework (framework )
78+
79+ if accelerator_type :
80+ if image_scope not in ("eia" , "inference" ):
81+ logger .warning (
82+ "Elastic inference is for inference only. Ignoring image scope: %s." , image_scope
83+ )
84+ image_scope = "eia"
85+
86+ _validate_arg ("image scope" , image_scope , config .get ("scope" , config .keys ()))
87+ return config if "scope" in config else config [image_scope ]
88+
89+
5990def config_for_framework (framework ):
6091 """Loads the JSON config for the given framework."""
6192 fname = os .path .join (os .path .dirname (__file__ ), "image_uri_config" , "{}.json" .format (framework ))
@@ -69,27 +100,13 @@ def _version_for_config(version, config, framework):
69100 if version in config ["version_aliases" ].keys ():
70101 return config ["version_aliases" ][version ]
71102
72- available_versions = config ["versions" ].keys ()
73- if version in available_versions :
74- return version
75-
76- raise ValueError (
77- "Unsupported {} version: {}. "
78- "You may need to upgrade your SDK version (pip install -U sagemaker) for newer versions. "
79- "Supported version(s): {}." .format (framework , version , ", " .join (available_versions ))
80- )
103+ _validate_arg ("{} version" .format (framework ), version , config ["versions" ].keys ())
104+ return version
81105
82106
83107def _registry_from_region (region , registry_dict ):
84108 """Returns the ECR registry (AWS account number) for the given region."""
85- available_regions = registry_dict .keys ()
86- if region not in available_regions :
87- raise ValueError (
88- "Unsupported region: {}. You may need to upgrade "
89- "your SDK version (pip install -U sagemaker) for newer regions. "
90- "Supported region(s): {}." .format (region , ", " .join (available_regions ))
91- )
92-
109+ _validate_arg ("region" , region , registry_dict .keys ())
93110 return registry_dict [region ]
94111
95112
@@ -106,22 +123,37 @@ def _processor(instance_type, available_processors):
106123 family = instance_type .split ("." )[1 ]
107124 processor = "gpu" if family [0 ] in ("g" , "p" ) else "cpu"
108125
109- if processor in available_processors :
110- return processor
111-
112- raise ValueError (
113- "Unsupported processor type: {} (for {}). "
114- "Supported type(s): {}." .format (processor , instance_type , ", " .join (available_processors ))
115- )
126+ _validate_arg ("processor" , processor , available_processors )
127+ return processor
116128
117129
118- def _validate_py_version (py_version , available_versions , framework , fw_version ):
130+ def _validate_py_version_and_set_if_needed (py_version , version_config ):
119131 """Checks if the Python version is one of the supported versions."""
120- if py_version not in available_versions :
132+ available_versions = version_config .get ("py_versions" , version_config .keys ())
133+
134+ if len (available_versions ) == 0 :
135+ if py_version :
136+ logger .info ("Ignoring unnecessary Python version: %s." , py_version )
137+ return None
138+
139+ if py_version is None and len (available_versions ) == 1 :
140+ logger .info ("Defaulting to only available Python version: %s" , available_versions [0 ])
141+ return available_versions [0 ]
142+
143+ _validate_arg ("Python version" , py_version , available_versions )
144+ return py_version
145+
146+
147+ def _validate_arg (arg_name , arg , available_options ):
148+ """Checks if the arg is in the available options, and raises a ``ValueError`` if not."""
149+ if arg not in available_options :
121150 raise ValueError (
122- "Unsupported Python version for {} {}: {}. You may need to upgrade "
123- "your SDK version (pip install -U sagemaker) for newer versions. "
124- "Supported Python version(s): {}." .format (
125- framework , fw_version , py_version , ", " .join (available_versions )
126- )
151+ "Unsupported {arg_name}: {arg}. You may need to upgrade your SDK version "
152+ "(pip install -U sagemaker) for newer {arg_name}s. Supported {arg_name}(s): "
153+ "{options}." .format (arg_name = arg_name , arg = arg , options = ", " .join (available_options ))
127154 )
155+
156+
157+ def _format_tag (version , processor , py_version ):
158+ """Creates a tag for the image URI."""
159+ return "-" .join ([x for x in (version , processor , py_version ) if x ])
0 commit comments