1010# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
13- """This module contains functions for obtainining JumpStart artifacts ."""
13+ """This module contains functions for obtaining JumpStart ECR and S3 URIs ."""
1414from __future__ import absolute_import
1515from typing import Optional
1616from sagemaker import image_uris
@@ -42,13 +42,14 @@ def _retrieve_image_uri(
4242):
4343 """Retrieves the container image URI for JumpStart models.
4444
45- Only `model_id` and `model_version ` are required to be non-None ;
45+ Only `model_id`, `model_version`, and `image_scope ` are required;
4646 the rest of the fields are auto-populated.
4747
4848
4949 Args:
50- model_id (str): JumpStart model id for which to retrieve image URI.
51- model_version (str): JumpStart model version for which to retrieve image URI.
50+ model_id (str): JumpStart model ID for which to retrieve image URI.
51+ model_version (str): Version of the JumpStart model for which to retrieve
52+ the image URI (default: None).
5253 framework (str): The name of the framework or algorithm.
5354 region (str): The AWS region.
5455 version (str): The framework or algorithm version. This is required if there is
@@ -89,7 +90,9 @@ def _retrieve_image_uri(
8990 "Must specify `image_scope` argument to retrieve image uri for JumpStart models."
9091 )
9192 if image_scope not in SUPPORTED_JUMPSTART_SCOPES :
92- raise ValueError ("JumpStart models only support inference and training." )
93+ raise ValueError (
94+ f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
95+ )
9396
9497 model_specs = jumpstart_accessors .JumpStartModelsCache .get_model_specs (
9598 region , model_id , model_version
@@ -99,25 +102,33 @@ def _retrieve_image_uri(
99102 ecr_specs = model_specs .hosting_ecr_specs
100103 elif image_scope == TRAINING :
101104 if not model_specs .training_supported :
102- raise ValueError (f"JumpStart model id '{ model_id } ' does not support training." )
105+ raise ValueError (
106+ f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
107+ "does not support training."
108+ )
103109 assert model_specs .training_ecr_specs is not None
104110 ecr_specs = model_specs .training_ecr_specs
105111
106112 if framework is not None and framework != ecr_specs .framework :
107- raise ValueError (f"Bad value for container framework for JumpStart model: '{ framework } '." )
113+ raise ValueError (
114+ f"Incorrect container framework '{ framework } ' for JumpStart model ID '{ model_id } ' "
115+ "and version {model_version}'."
116+ )
108117
109118 if version is not None and version != ecr_specs .framework_version :
110119 raise ValueError (
111- f"Bad value for container framework version for JumpStart model: '{ version } '."
120+ f"Incorrect container framework version '{ version } ' for JumpStart model ID "
121+ f"'{ model_id } ' and version { model_version } '."
112122 )
113123
114124 if py_version is not None and py_version != ecr_specs .py_version :
115125 raise ValueError (
116- f"Bad value for container python version for JumpStart model: '{ py_version } '."
126+ f"Incorrect python version '{ py_version } ' for JumpStart model ID '{ model_id } ' "
127+ "and version {model_version}'."
117128 )
118129
119- base_framework_version_override = None
120- version_override = None
130+ base_framework_version_override : Optional [ str ] = None
131+ version_override : Optional [ str ] = None
121132 if ecr_specs .framework == ModelFramework .HUGGINGFACE .value :
122133 base_framework_version_override = ecr_specs .framework_version
123134 version_override = ecr_specs .huggingface_transformers_version
@@ -162,8 +173,10 @@ def _retrieve_model_uri(
162173 """Retrieves the model artifact S3 URI for the model matching the given arguments.
163174
164175 Args:
165- model_id (str): JumpStart model id for which to retrieve model S3 URI.
166- model_version (str): JumpStart model version for which to retrieve model S3 URI.
176+ model_id (str): JumpStart model ID of the JumpStart model for which to retrieve
177+ the model artifact S3 URI.
178+ model_version (str): Version of the JumpStart model for which to retrieve the model
179+ artifact S3 URI.
167180 model_scope (str): The model type, i.e. what it is used for.
168181 Valid values: "training" and "inference".
169182 region (str): Region for which to retrieve model S3 URI.
@@ -185,7 +198,9 @@ def _retrieve_model_uri(
185198 )
186199
187200 if model_scope not in SUPPORTED_JUMPSTART_SCOPES :
188- raise ValueError ("JumpStart models only support inference and training." )
201+ raise ValueError (
202+ f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
203+ )
189204
190205 model_specs = jumpstart_accessors .JumpStartModelsCache .get_model_specs (
191206 region , model_id , model_version
@@ -194,7 +209,10 @@ def _retrieve_model_uri(
194209 model_artifact_key = model_specs .hosting_artifact_key
195210 elif model_scope == TRAINING :
196211 if not model_specs .training_supported :
197- raise ValueError (f"JumpStart model id '{ model_id } ' does not support training." )
212+ raise ValueError (
213+ f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
214+ "does not support training."
215+ )
198216 assert model_specs .training_artifact_key is not None
199217 model_artifact_key = model_specs .training_artifact_key
200218
@@ -211,11 +229,13 @@ def _retrieve_script_uri(
211229 script_scope : Optional [str ],
212230 region : Optional [str ],
213231):
214- """Retrieves the model script s3 URI for the model matching the given arguments.
232+ """Retrieves the script S3 URI associated with the model matching the given arguments.
215233
216234 Args:
217- model_id (str): JumpStart model id for which to retrieve model script S3 URI.
218- model_version (str): JumpStart model version for which to retrieve model script S3 URI.
235+ model_id (str): JumpStart model ID of the JumpStart model for which to
236+ retrieve the script S3 URI.
237+ model_version (str): Version of the JumpStart model for which to
238+ retrieve the model script S3 URI.
219239 script_scope (str): The script type, i.e. what it is used for.
220240 Valid values: "training" and "inference".
221241 region (str): Region for which to retrieve model script S3 URI.
@@ -237,7 +257,9 @@ def _retrieve_script_uri(
237257 )
238258
239259 if script_scope not in SUPPORTED_JUMPSTART_SCOPES :
240- raise ValueError ("JumpStart models only support inference and training." )
260+ raise ValueError (
261+ f"JumpStart models only support scopes: { ', ' .join (SUPPORTED_JUMPSTART_SCOPES )} ."
262+ )
241263
242264 model_specs = jumpstart_accessors .JumpStartModelsCache .get_model_specs (
243265 region , model_id , model_version
@@ -246,7 +268,10 @@ def _retrieve_script_uri(
246268 model_script_key = model_specs .hosting_script_key
247269 elif script_scope == TRAINING :
248270 if not model_specs .training_supported :
249- raise ValueError (f"JumpStart model id '{ model_id } ' does not support training." )
271+ raise ValueError (
272+ f"JumpStart model ID '{ model_id } ' and version '{ model_version } ' "
273+ "does not support training."
274+ )
250275 assert model_specs .training_script_key is not None
251276 model_script_key = model_specs .training_script_key
252277
0 commit comments