Skip to content

Commit c96649e

Browse files
committed
chore: add js support for copies resource requirement, enforce coupling with ResourceRequirements class
1 parent c5af857 commit c96649e

File tree

3 files changed

+51
-13
lines changed

3 files changed

+51
-13
lines changed

src/sagemaker/jumpstart/artifacts/resource_requirements.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,18 @@
2828
from sagemaker.session import Session
2929
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
3030

31+
REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP: Dict[str, Dict[str, str]] = {
32+
"requests": {
33+
"num_accelerators": ("num_accelerators", "num_accelerators"),
34+
"num_cpus": ("num_cpus", "num_cpus"),
35+
"copies": ("copies", "copy_count"),
36+
"min_memory_mb": ("memory", "min_memory"),
37+
},
38+
"limits": {
39+
"max_memory_mb": ("memory", "max_memory"),
40+
},
41+
}
42+
3143

3244
def _retrieve_default_resources(
3345
model_id: str,
@@ -113,16 +125,22 @@ def _retrieve_default_resources(
113125
}
114126

115127
if is_dynamic_container_deployment_supported:
116-
requests = {}
117-
if "num_accelerators" in default_resource_requirements:
118-
requests["num_accelerators"] = default_resource_requirements["num_accelerators"]
119-
if "min_memory_mb" in default_resource_requirements:
120-
requests["memory"] = default_resource_requirements["min_memory_mb"]
121-
if "num_cpus" in default_resource_requirements:
122-
requests["num_cpus"] = default_resource_requirements["num_cpus"]
123-
124-
limits = {}
125-
if "max_memory_mb" in default_resource_requirements:
126-
limits["memory"] = default_resource_requirements["max_memory_mb"]
127-
return ResourceRequirements(requests=requests, limits=limits)
128+
129+
all_resource_requirement_kwargs = {}
130+
131+
for (
132+
requirement_type,
133+
spec_field_to_resource_requirement_map,
134+
) in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.items():
135+
requirement_type
136+
requirement_kwargs = {}
137+
for spec_field, resource_requirement in spec_field_to_resource_requirement_map.items():
138+
if spec_field in default_resource_requirements:
139+
requirement_kwargs[resource_requirement[0]] = default_resource_requirements[
140+
spec_field
141+
]
142+
143+
all_resource_requirement_kwargs[requirement_type] = requirement_kwargs
144+
145+
return ResourceRequirements(**all_resource_requirement_kwargs)
128146
return None

src/sagemaker/resource_requirements.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import logging
1818
from typing import Optional
19+
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
1920

2021
from sagemaker.jumpstart import utils as jumpstart_utils
2122
from sagemaker.jumpstart import artifacts
@@ -34,7 +35,7 @@ def retrieve_default(
3435
tolerate_deprecated_model: bool = False,
3536
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3637
instance_type: Optional[str] = None,
37-
) -> str:
38+
) -> ResourceRequirements:
3839
"""Retrieves the default resource requirements for the model matching the given arguments.
3940
4041
Args:

tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
import pytest
1919

2020
from sagemaker import resource_requirements
21+
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
22+
from sagemaker.jumpstart.artifacts.resource_requirements import (
23+
REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP,
24+
)
2125

2226
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
2327

@@ -129,3 +133,18 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs):
129133
resource_requirements.retrieve_default(
130134
region=region, model_id=model_id, model_version=model_version, scope="training"
131135
)
136+
137+
138+
def test_jumpstart_supports_all_resource_requirement_fields():
139+
140+
all_tracked_resource_requirement_fields = {
141+
field
142+
for requirements in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.values()
143+
for _, field in requirements.values()
144+
}
145+
146+
excluded_resource_requirement_fields = {"requests", "limits"}
147+
assert (
148+
set(ResourceRequirements().__dict__.keys()) - excluded_resource_requirement_fields
149+
== all_tracked_resource_requirement_fields
150+
)

0 commit comments

Comments
 (0)