1313"""This module contains functions for obtaining JumpStart resource requirements."""
1414from __future__ import absolute_import
1515
16- from typing import Optional
16+ from typing import Dict , Optional
1717
1818from sagemaker .jumpstart .constants import (
1919 DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
@@ -37,6 +37,7 @@ def _retrieve_default_resources(
3737 tolerate_vulnerable_model : bool = False ,
3838 tolerate_deprecated_model : bool = False ,
3939 sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
40+ instance_type : Optional [str ] = None ,
4041) -> ResourceRequirements :
4142 """Retrieves the default resource requirements for the model.
4243
@@ -60,6 +61,8 @@ def _retrieve_default_resources(
6061 object, used for SageMaker interactions. If not
6162 specified, one is created using the default AWS configuration
6263 chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64+ instance_type (str): An instance type to optionally supply in order to get
65+ host requirements specific for the instance type.
6366 Returns:
6467 ResourceRequirements: The default resource requirements to use for the model or None.
6568
@@ -87,12 +90,28 @@ def _retrieve_default_resources(
8790 is_dynamic_container_deployment_supported = (
8891 model_specs .dynamic_container_deployment_supported
8992 )
90- default_resource_requirements = model_specs .hosting_resource_requirements
93+ default_resource_requirements : Dict [str , int ] = (
94+ model_specs .hosting_resource_requirements or {}
95+ )
9196 else :
9297 raise NotImplementedError (
9398 f"Unsupported script scope for retrieving default resource requirements: '{ scope } '"
9499 )
95100
101+ instance_specific_resource_requirements : Dict [str , int ] = (
102+ model_specs .hosting_instance_type_variants .get_instance_specific_resource_requirements (
103+ instance_type
104+ )
105+ if instance_type
106+ and getattr (model_specs , "hosting_instance_type_variants" , None ) is not None
107+ else {}
108+ )
109+
110+ default_resource_requirements = {
111+ ** default_resource_requirements ,
112+ ** instance_specific_resource_requirements ,
113+ }
114+
96115 if is_dynamic_container_deployment_supported :
97116 requests = {}
98117 if "num_accelerators" in default_resource_requirements :
0 commit comments