1313"""This module contains functions for obtaining JumpStart resoure requirements."""
1414from __future__ import absolute_import
1515
16- from typing import Optional
16+ from typing import Dict , Optional , Tuple
1717
1818from sagemaker .jumpstart .constants import (
1919 DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
2828from sagemaker .session import Session
2929from sagemaker .compute_resource_requirements .resource_requirements import ResourceRequirements
3030
31+ REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP : Dict [
32+ str , Dict [str , Tuple [str , str ]]
33+ ] = {
34+ "requests" : {
35+ "num_accelerators" : ("num_accelerators" , "num_accelerators" ),
36+ "num_cpus" : ("num_cpus" , "num_cpus" ),
37+ "copies" : ("copies" , "copy_count" ),
38+ "min_memory_mb" : ("memory" , "min_memory" ),
39+ },
40+ "limits" : {
41+ "max_memory_mb" : ("memory" , "max_memory" ),
42+ },
43+ }
44+
3145
3246def _retrieve_default_resources (
3347 model_id : str ,
@@ -38,6 +52,7 @@ def _retrieve_default_resources(
3852 tolerate_vulnerable_model : bool = False ,
3953 tolerate_deprecated_model : bool = False ,
4054 sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
55+ instance_type : Optional [str ] = None ,
4156) -> ResourceRequirements :
4257 """Retrieves the default resource requirements for the model.
4358
@@ -63,6 +78,8 @@ def _retrieve_default_resources(
6378 object, used for SageMaker interactions. If not
6479 specified, one is created using the default AWS configuration
6580 chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
81+ instance_type (str): An instance type to optionally supply in order to get
82+ host requirements specific for the instance type.
6683 Returns:
6784 str: The default resource requirements to use for the model or None.
6885
@@ -91,23 +108,44 @@ def _retrieve_default_resources(
91108 is_dynamic_container_deployment_supported = (
92109 model_specs .dynamic_container_deployment_supported
93110 )
94- default_resource_requirements = model_specs .hosting_resource_requirements
111+ default_resource_requirements : Dict [str , int ] = (
112+ model_specs .hosting_resource_requirements or {}
113+ )
95114 else :
96115 raise NotImplementedError (
97116 f"Unsupported script scope for retrieving default resource requirements: '{ scope } '"
98117 )
99118
119+ instance_specific_resource_requirements : Dict [str , int ] = (
120+ model_specs .hosting_instance_type_variants .get_instance_specific_resource_requirements (
121+ instance_type
122+ )
123+ if instance_type
124+ and getattr (model_specs , "hosting_instance_type_variants" , None ) is not None
125+ else {}
126+ )
127+
128+ default_resource_requirements = {
129+ ** default_resource_requirements ,
130+ ** instance_specific_resource_requirements ,
131+ }
132+
100133 if is_dynamic_container_deployment_supported :
101- requests = {}
102- if "num_accelerators" in default_resource_requirements :
103- requests ["num_accelerators" ] = default_resource_requirements ["num_accelerators" ]
104- if "min_memory_mb" in default_resource_requirements :
105- requests ["memory" ] = default_resource_requirements ["min_memory_mb" ]
106- if "num_cpus" in default_resource_requirements :
107- requests ["num_cpus" ] = default_resource_requirements ["num_cpus" ]
108-
109- limits = {}
110- if "max_memory_mb" in default_resource_requirements :
111- limits ["memory" ] = default_resource_requirements ["max_memory_mb" ]
112- return ResourceRequirements (requests = requests , limits = limits )
134+
135+ all_resource_requirement_kwargs = {}
136+
137+ for (
138+ requirement_type ,
139+ spec_field_to_resource_requirement_map ,
140+ ) in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP .items ():
141+ requirement_kwargs = {}
142+ for spec_field , resource_requirement in spec_field_to_resource_requirement_map .items ():
143+ if spec_field in default_resource_requirements :
144+ requirement_kwargs [resource_requirement [0 ]] = default_resource_requirements [
145+ spec_field
146+ ]
147+
148+ all_resource_requirement_kwargs [requirement_type ] = requirement_kwargs
149+
150+ return ResourceRequirements (** all_resource_requirement_kwargs )
113151 return None
0 commit comments