1212# language governing permissions and limitations under the License.
1313"""This module contains functions for obtaining JumpStart ECR and S3 URIs."""
1414from __future__ import absolute_import
15- from typing import Optional
15+ from typing import Dict , Optional
1616from sagemaker import image_uris
1717from sagemaker .jumpstart .constants import (
1818 JUMPSTART_DEFAULT_REGION_NAME ,
1919 INFERENCE ,
2020 TRAINING ,
2121 SUPPORTED_JUMPSTART_SCOPES ,
2222 ModelFramework ,
23+ VariableScope ,
2324)
2425from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
2526from sagemaker .jumpstart import accessors as jumpstart_accessors
@@ -295,7 +296,12 @@ def _retrieve_default_hyperparameters(
295296 default hyperparameters.
296297 region (str): Region for which to retrieve default hyperparameters.
297298 include_container_hyperparameters (bool): True if container hyperparameters
298- should be returned as well. (Default: False)
299+ should be returned as well. Container hyperparameters are not used to tune
300+ the specific algorithm, but rather by SageMaker Training to setup
301+ the training container environment. For example, there is a container hyperparameter
302+ that indicates the entrypoint script to use. These hyperparameters may be required
303+ when creating a training job with boto3, however the ``Estimator`` classes
304+ should take care of adding container hyperparameters to the job. (Default: False).
299305 Returns:
300306 dict: the hyperparameters to use for the model.
301307
@@ -312,11 +318,11 @@ def _retrieve_default_hyperparameters(
312318 region = region , model_id = model_id , version = model_version
313319 )
314320
315- default_hyperparameters = {}
321+ default_hyperparameters : Dict [ str , str ] = {}
316322 for hyperparameter in model_specs .hyperparameters :
317323 if (
318- include_container_hyperparameters and hyperparameter .scope == "container"
319- ) or hyperparameter .scope == "algorithm" :
324+ include_container_hyperparameters and hyperparameter .scope == VariableScope . CONTAINER
325+ ) or hyperparameter .scope == VariableScope . ALGORITHM :
320326 default_hyperparameters [hyperparameter .name ] = str (hyperparameter .default )
321327 return default_hyperparameters
322328
@@ -333,7 +339,7 @@ def _retrieve_default_environment_variables(
333339 retrieve the default environment variables.
334340 model_version (str): Version of the JumpStart model for which to retrieve the
335341 default environment variables.
336- region (str): Region for which to retrieve default environment variables.
342+ region (Optional[ str] ): Region for which to retrieve default environment variables.
337343
338344 Returns:
339345 dict: the inference environment variables to use for the model.
@@ -345,13 +351,11 @@ def _retrieve_default_environment_variables(
345351 if region is None :
346352 region = JUMPSTART_DEFAULT_REGION_NAME
347353
348- assert region is not None
349-
350354 model_specs = jumpstart_accessors .JumpStartModelsCache .get_model_specs (
351355 region = region , model_id = model_id , version = model_version
352356 )
353357
354- default_environment_variables = {}
358+ default_environment_variables : Dict [ str , str ] = {}
355359 for environment_variable in model_specs .inference_environment_variables :
356360 default_environment_variables [environment_variable .name ] = str (environment_variable .default )
357361 return default_environment_variables
0 commit comments