1212# language governing permissions and limitations under the License.
1313"""This module contains functions for obtaining JumpStart ECR and S3 URIs."""
1414from __future__ import absolute_import
15- from typing import Optional
15+ from typing import Dict , Optional
1616from sagemaker import image_uris
1717from sagemaker .jumpstart .constants import (
1818 JUMPSTART_DEFAULT_REGION_NAME ,
1919 INFERENCE ,
2020 TRAINING ,
2121 SUPPORTED_JUMPSTART_SCOPES ,
2222 ModelFramework ,
23+ VariableScope ,
2324)
2425from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
2526from sagemaker .jumpstart import accessors as jumpstart_accessors
@@ -93,7 +94,7 @@ def _retrieve_image_uri(
9394 )
9495
9596 model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
96- region , model_id , model_version
97+ region = region , model_id = model_id , version = model_version
9798 )
9899
99100 if image_scope == INFERENCE :
@@ -110,19 +111,19 @@ def _retrieve_image_uri(
110111 if framework is not None and framework != ecr_specs .framework :
111112 raise ValueError (
112113 f"Incorrect container framework '{ framework } ' for JumpStart model ID '{ model_id } ' "
113- f"and version { model_version } '."
114+ f"and version ' { model_version } '."
114115 )
115116
116117 if version is not None and version != ecr_specs .framework_version :
117118 raise ValueError (
118119 f"Incorrect container framework version '{ version } ' for JumpStart model ID "
119- f"'{ model_id } ' and version { model_version } '."
120+ f"'{ model_id } ' and version ' { model_version } '."
120121 )
121122
122123 if py_version is not None and py_version != ecr_specs .py_version :
123124 raise ValueError (
124125 f"Incorrect python version '{ py_version } ' for JumpStart model ID '{ model_id } ' "
125- f"and version { model_version } '."
126+ f"and version ' { model_version } '."
126127 )
127128
128129 base_framework_version_override : Optional [str ] = None
@@ -201,7 +202,7 @@ def _retrieve_model_uri(
201202 )
202203
203204 model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
204- region , model_id , model_version
205+ region = region , model_id = model_id , version = model_version
205206 )
206207 if model_scope == INFERENCE :
207208 model_artifact_key = model_specs .hosting_artifact_key
@@ -260,7 +261,7 @@ def _retrieve_script_uri(
260261 )
261262
262263 model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
263- region , model_id , model_version
264+ region = region , model_id = model_id , version = model_version
264265 )
265266 if script_scope == INFERENCE :
266267 model_script_key = model_specs .hosting_script_key
@@ -278,3 +279,77 @@ def _retrieve_script_uri(
278279 script_s3_uri = f"s3://{ bucket } /{ model_script_key } "
279280
280281 return script_s3_uri
282+
283+
284+ def _retrieve_default_hyperparameters (
285+ model_id : str ,
286+ model_version : str ,
287+ region : Optional [str ],
288+ include_container_hyperparameters : bool = False ,
289+ ):
290+ """Retrieves the training hyperparameters for the model matching the given arguments.
291+
292+ Args:
293+ model_id (str): JumpStart model ID of the JumpStart model for which to
294+ retrieve the default hyperparameters.
295+ model_version (str): Version of the JumpStart model for which to retrieve the
296+ default hyperparameters.
297+ region (str): Region for which to retrieve default hyperparameters.
298+ include_container_hyperparameters (bool): True if container hyperparameters
299+ should be returned as well. Container hyperparameters are not used to tune
300+ the specific algorithm, but rather by SageMaker Training to setup
301+ the training container environment. For example, there is a container hyperparameter
302+ that indicates the entrypoint script to use. These hyperparameters may be required
303+ when creating a training job with boto3, however the ``Estimator`` classes
304+ should take care of adding container hyperparameters to the job. (Default: False).
305+ Returns:
306+ dict: the hyperparameters to use for the model.
307+ """
308+
309+ if region is None :
310+ region = JUMPSTART_DEFAULT_REGION_NAME
311+
312+ assert region is not None
313+
314+ model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
315+ region = region , model_id = model_id , version = model_version
316+ )
317+
318+ default_hyperparameters : Dict [str , str ] = {}
319+ for hyperparameter in model_specs .hyperparameters :
320+ if (
321+ include_container_hyperparameters and hyperparameter .scope == VariableScope .CONTAINER
322+ ) or hyperparameter .scope == VariableScope .ALGORITHM :
323+ default_hyperparameters [hyperparameter .name ] = str (hyperparameter .default )
324+ return default_hyperparameters
325+
326+
327+ def _retrieve_default_environment_variables (
328+ model_id : str ,
329+ model_version : str ,
330+ region : Optional [str ],
331+ ):
332+ """Retrieves the inference environment variables for the model matching the given arguments.
333+
334+ Args:
335+ model_id (str): JumpStart model ID of the JumpStart model for which to
336+ retrieve the default environment variables.
337+ model_version (str): Version of the JumpStart model for which to retrieve the
338+ default environment variables.
339+ region (Optional[str]): Region for which to retrieve default environment variables.
340+
341+ Returns:
342+ dict: the inference environment variables to use for the model.
343+ """
344+
345+ if region is None :
346+ region = JUMPSTART_DEFAULT_REGION_NAME
347+
348+ model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
349+ region = region , model_id = model_id , version = model_version
350+ )
351+
352+ default_environment_variables : Dict [str , str ] = {}
353+ for environment_variable in model_specs .inference_environment_variables :
354+ default_environment_variables [environment_variable .name ] = str (environment_variable .default )
355+ return default_environment_variables
0 commit comments