|
28 | 28 | from sagemaker.session import Session |
29 | 29 | from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements |
30 | 30 |
|
| 31 | +REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP: Dict[str, Dict[str, str]] = { |
| 32 | + "requests": { |
| 33 | + "num_accelerators": ("num_accelerators", "num_accelerators"), |
| 34 | + "num_cpus": ("num_cpus", "num_cpus"), |
| 35 | + "copies": ("copies", "copy_count"), |
| 36 | + "min_memory_mb": ("memory", "min_memory"), |
| 37 | + }, |
| 38 | + "limits": { |
| 39 | + "max_memory_mb": ("memory", "max_memory"), |
| 40 | + }, |
| 41 | +} |
| 42 | + |
31 | 43 |
|
32 | 44 | def _retrieve_default_resources( |
33 | 45 | model_id: str, |
@@ -113,16 +125,22 @@ def _retrieve_default_resources( |
113 | 125 | } |
114 | 126 |
|
115 | 127 | if is_dynamic_container_deployment_supported: |
116 | | - requests = {} |
117 | | - if "num_accelerators" in default_resource_requirements: |
118 | | - requests["num_accelerators"] = default_resource_requirements["num_accelerators"] |
119 | | - if "min_memory_mb" in default_resource_requirements: |
120 | | - requests["memory"] = default_resource_requirements["min_memory_mb"] |
121 | | - if "num_cpus" in default_resource_requirements: |
122 | | - requests["num_cpus"] = default_resource_requirements["num_cpus"] |
123 | | - |
124 | | - limits = {} |
125 | | - if "max_memory_mb" in default_resource_requirements: |
126 | | - limits["memory"] = default_resource_requirements["max_memory_mb"] |
127 | | - return ResourceRequirements(requests=requests, limits=limits) |
| 128 | + |
| 129 | + all_resource_requirement_kwargs = {} |
| 130 | + |
| 131 | + for ( |
| 132 | + requirement_type, |
| 133 | + spec_field_to_resource_requirement_map, |
| 134 | + ) in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.items(): |
| 135 | + requirement_type |
| 136 | + requirement_kwargs = {} |
| 137 | + for spec_field, resource_requirement in spec_field_to_resource_requirement_map.items(): |
| 138 | + if spec_field in default_resource_requirements: |
| 139 | + requirement_kwargs[resource_requirement[0]] = default_resource_requirements[ |
| 140 | + spec_field |
| 141 | + ] |
| 142 | + |
| 143 | + all_resource_requirement_kwargs[requirement_type] = requirement_kwargs |
| 144 | + |
| 145 | + return ResourceRequirements(**all_resource_requirement_kwargs) |
128 | 146 | return None |
0 commit comments