Skip to content

Commit 423f389

Browse files
feat: Hyperparameter validation (#2856)
Co-authored-by: Shreya Pandit <[email protected]>
1 parent c9aa29b commit 423f389

File tree

16 files changed

+896
-70
lines changed

16 files changed

+896
-70
lines changed

src/sagemaker/environment_variables.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def retrieve_default(
4444
ValueError: If the combination of arguments specified is not supported.
4545
"""
4646
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
47-
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
47+
raise ValueError(
48+
"Must specify `model_id` and `model_version` when retrieving environment variables."
49+
)
4850

4951
return artifacts._retrieve_default_environment_variables(model_id, model_version, region)

src/sagemaker/hyperparameters.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
from __future__ import absolute_import
1616

1717
import logging
18-
from typing import Dict
18+
from typing import Dict, Optional
1919

2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
22+
from sagemaker.jumpstart.enums import HyperparameterValidationMode
23+
from sagemaker.jumpstart.validators import validate_hyperparameters
2224

2325
logger = logging.getLogger(__name__)
2426

@@ -51,8 +53,58 @@ def retrieve_default(
5153
ValueError: If the combination of arguments specified is not supported.
5254
"""
5355
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
54-
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
56+
raise ValueError(
57+
"Must specify `model_id` and `model_version` when retrieving hyperparameters."
58+
)
5559

5660
return artifacts._retrieve_default_hyperparameters(
5761
model_id, model_version, region, include_container_hyperparameters
5862
)
63+
64+
65+
def validate(
66+
region: Optional[str] = None,
67+
model_id: Optional[str] = None,
68+
model_version: Optional[str] = None,
69+
hyperparameters: Optional[dict] = None,
70+
validation_mode: Optional[HyperparameterValidationMode] = None,
71+
) -> None:
72+
"""Validate hyperparameters for models.
73+
74+
Args:
75+
region (str): Region for which to validate hyperparameters. (Default: None).
76+
model_id (str): Model ID of the model for which to validate hyperparameters.
77+
(Default: None)
78+
model_version (str): Version of the model for which to validate hyperparameters.
79+
(Default: None)
80+
hyperparameters (dict): Hyperparameters to validate.
81+
(Default: None)
82+
validation_mode (HyperparameterValidationMode): Method of validation to use with
83+
hyperparameters. If set to ``VALIDATE_PROVIDED``, only hyperparameters provided
84+
to this function will be validated; any missing hyperparameters will be ignored.
85+
If set to ``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated.
86+
If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated.
87+
(Default: None)
88+
89+
Raises:
90+
JumpStartHyperparametersError: If the hyperparameter is not formatted correctly,
91+
according to its specs in the model metadata.
92+
ValueError: If the combination of arguments specified is not supported.
93+
94+
"""
95+
96+
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
97+
raise ValueError(
98+
"Must specify `model_id` and `model_version` when validating hyperparameters."
99+
)
100+
101+
if hyperparameters is None:
102+
raise ValueError("Must specify hyperparameters.")
103+
104+
return validate_hyperparameters(
105+
model_id=model_id,
106+
model_version=model_version,
107+
hyperparameters=hyperparameters,
108+
validation_mode=validation_mode,
109+
region=region,
110+
)

src/sagemaker/jumpstart/artifacts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from sagemaker import image_uris
1717
from sagemaker.jumpstart.constants import (
1818
JUMPSTART_DEFAULT_REGION_NAME,
19+
)
20+
from sagemaker.jumpstart.enums import (
1921
JumpStartScriptScope,
2022
ModelFramework,
2123
VariableScope,

src/sagemaker/jumpstart/constants.py

Lines changed: 3 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
"""This module stores constants related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
1515
from typing import Set
16-
from enum import Enum
1716
import boto3
17+
from sagemaker.jumpstart.enums import JumpStartScriptScope
1818
from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo
1919

2020

@@ -118,52 +118,7 @@
118118

119119
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
120120

121-
122-
INFERENCE_ENTRYPOINT_SCRIPT_NAME = "inference.py"
123-
TRAINING_ENTRYPOINT_SCRIPT_NAME = "transfer_learning.py"
124-
125-
126-
class JumpStartScriptScope(str, Enum):
127-
"""Enum class for JumpStart script scopes."""
128-
129-
INFERENCE = "inference"
130-
TRAINING = "training"
131-
121+
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"
122+
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"
132123

133124
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)
134-
135-
136-
class ModelFramework(str, Enum):
137-
"""Enum class for JumpStart model framework.
138-
139-
The ML framework as referenced in the prefix of the model ID.
140-
This value does not necessarily correspond to the container name.
141-
"""
142-
143-
PYTORCH = "pytorch"
144-
TENSORFLOW = "tensorflow"
145-
MXNET = "mxnet"
146-
HUGGINGFACE = "huggingface"
147-
LIGHTGBM = "lightgbm"
148-
CATBOOST = "catboost"
149-
XGBOOST = "xgboost"
150-
SKLEARN = "sklearn"
151-
152-
153-
class VariableScope(str, Enum):
154-
"""Possible value of the ``scope`` attribute for a hyperparameter or environment variable.
155-
156-
Used for hosting environment variables and training hyperparameters.
157-
"""
158-
159-
CONTAINER = "container"
160-
ALGORITHM = "algorithm"
161-
162-
163-
class JumpStartTag(str, Enum):
164-
"""Enum class for tag keys to apply to JumpStart models."""
165-
166-
INFERENCE_MODEL_URI = "aws-jumpstart-inference-model-uri"
167-
INFERENCE_SCRIPT_URI = "aws-jumpstart-inference-script-uri"
168-
TRAINING_MODEL_URI = "aws-jumpstart-training-model-uri"
169-
TRAINING_SCRIPT_URI = "aws-jumpstart-training-script-uri"

src/sagemaker/jumpstart/enums.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module stores enums related to SageMaker JumpStart."""
14+
15+
from __future__ import absolute_import
16+
17+
from enum import Enum
18+
19+
20+
class ModelFramework(str, Enum):
21+
"""Enum class for JumpStart model framework.
22+
23+
The ML framework as referenced in the prefix of the model ID.
24+
This value does not necessarily correspond to the container name.
25+
"""
26+
27+
PYTORCH = "pytorch"
28+
TENSORFLOW = "tensorflow"
29+
MXNET = "mxnet"
30+
HUGGINGFACE = "huggingface"
31+
LIGHTGBM = "lightgbm"
32+
CATBOOST = "catboost"
33+
XGBOOST = "xgboost"
34+
SKLEARN = "sklearn"
35+
36+
37+
class VariableScope(str, Enum):
38+
"""Possible value of the ``scope`` attribute for a hyperparameter or environment variable.
39+
40+
Used for hosting environment variables and training hyperparameters.
41+
"""
42+
43+
CONTAINER = "container"
44+
ALGORITHM = "algorithm"
45+
46+
47+
class JumpStartScriptScope(str, Enum):
48+
"""Enum class for JumpStart script scopes."""
49+
50+
INFERENCE = "inference"
51+
TRAINING = "training"
52+
53+
54+
class HyperparameterValidationMode(str, Enum):
55+
"""Possible modes for validating hyperparameters."""
56+
57+
VALIDATE_PROVIDED = "validate_provided"
58+
VALIDATE_ALGORITHM = "validate_algorithm"
59+
VALIDATE_ALL = "validate_all"
60+
61+
62+
class VariableTypes(str, Enum):
63+
"""Possible types for hyperparameters and environment variables."""
64+
65+
TEXT = "text"
66+
INT = "int"
67+
FLOAT = "float"
68+
BOOL = "bool"
69+
70+
71+
class JumpStartTag(str, Enum):
72+
"""Enum class for tag keys to apply to JumpStart models."""
73+
74+
INFERENCE_MODEL_URI = "aws-jumpstart-inference-model-uri"
75+
INFERENCE_SCRIPT_URI = "aws-jumpstart-inference-script-uri"
76+
TRAINING_MODEL_URI = "aws-jumpstart-training-model-uri"
77+
TRAINING_SCRIPT_URI = "aws-jumpstart-training-script-uri"

src/sagemaker/jumpstart/exceptions.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,24 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""This module stores exceptions related to SageMaker JumpStart."""
14-
1514
from __future__ import absolute_import
1615
from typing import List, Optional
1716

1817
from sagemaker.jumpstart.constants import JumpStartScriptScope
1918

2019

20+
class JumpStartHyperparametersError(Exception):
21+
"""Exception raised for bad hyperparameters of a JumpStart model."""
22+
23+
def __init__(
24+
self,
25+
message: Optional[str] = None,
26+
):
27+
self.message = message
28+
29+
super().__init__(self.message)
30+
31+
2132
class VulnerableJumpStartModelError(Exception):
2233
"""Exception raised when trying to access a JumpStart model specs flagged as vulnerable.
2334

src/sagemaker/jumpstart/types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ class JumpStartHyperparameter(JumpStartDataHolderType):
181181
"scope",
182182
"min",
183183
"max",
184+
"exclusive_min",
185+
"exclusive_max",
184186
}
185187

186188
def __init__(self, spec: Dict[str, Any]):
@@ -215,6 +217,14 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
215217
if max_val is not None:
216218
self.max = max_val
217219

220+
exclusive_min_val = json_obj.get("exclusive_min")
221+
if exclusive_min_val is not None:
222+
self.exclusive_min = exclusive_min_val
223+
224+
exclusive_max_val = json_obj.get("exclusive_max")
225+
if exclusive_max_val is not None:
226+
self.exclusive_max = exclusive_max_val
227+
218228
def to_json(self) -> Dict[str, Any]:
219229
"""Returns json representation of JumpStartHyperparameter object."""
220230
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}

src/sagemaker/jumpstart/utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from urllib.parse import urlparse
1818
from packaging.version import Version
1919
import sagemaker
20-
from sagemaker.jumpstart import constants
20+
from sagemaker.jumpstart import constants, enums
2121
from sagemaker.jumpstart import accessors
2222
from sagemaker.s3 import parse_s3_url
2323
from sagemaker.jumpstart.exceptions import (
@@ -200,13 +200,13 @@ def get_tag_value(tag_key: str, tag_array: List[Dict[str, str]]) -> str:
200200

201201

202202
def add_single_jumpstart_tag(
203-
uri: str, tag_key: constants.JumpStartTag, curr_tags: Optional[List[Dict[str, str]]]
203+
uri: str, tag_key: enums.JumpStartTag, curr_tags: Optional[List[Dict[str, str]]]
204204
) -> Optional[List]:
205205
"""Adds ``tag_key`` to ``curr_tags`` if ``uri`` corresponds to a JumpStart model.
206206
207207
Args:
208208
uri (str): URI which may correspond to a JumpStart model.
209-
tag_key (constants.JumpStartTag): Custom tag to apply to current tags if the URI
209+
tag_key (enums.JumpStartTag): Custom tag to apply to current tags if the URI
210210
corresponds to a JumpStart model.
211211
curr_tags (Optional[List]): Current tags associated with ``Estimator`` or ``Model``.
212212
"""
@@ -249,22 +249,22 @@ def add_jumpstart_tags(
249249

250250
if inference_model_uri:
251251
tags = add_single_jumpstart_tag(
252-
inference_model_uri, constants.JumpStartTag.INFERENCE_MODEL_URI, tags
252+
inference_model_uri, enums.JumpStartTag.INFERENCE_MODEL_URI, tags
253253
)
254254

255255
if inference_script_uri:
256256
tags = add_single_jumpstart_tag(
257-
inference_script_uri, constants.JumpStartTag.INFERENCE_SCRIPT_URI, tags
257+
inference_script_uri, enums.JumpStartTag.INFERENCE_SCRIPT_URI, tags
258258
)
259259

260260
if training_model_uri:
261261
tags = add_single_jumpstart_tag(
262-
training_model_uri, constants.JumpStartTag.TRAINING_MODEL_URI, tags
262+
training_model_uri, enums.JumpStartTag.TRAINING_MODEL_URI, tags
263263
)
264264

265265
if training_script_uri:
266266
tags = add_single_jumpstart_tag(
267-
training_script_uri, constants.JumpStartTag.TRAINING_SCRIPT_URI, tags
267+
training_script_uri, enums.JumpStartTag.TRAINING_SCRIPT_URI, tags
268268
)
269269

270270
return tags
@@ -280,7 +280,7 @@ def update_inference_tags_with_jumpstart_training_tags(
280280
training_tags (Optional[List[Dict[str, str]]]): Tags from training job.
281281
"""
282282
if training_tags:
283-
for tag_key in constants.JumpStartTag:
283+
for tag_key in enums.JumpStartTag:
284284
if tag_key_in_array(tag_key, training_tags):
285285
tag_value = get_tag_value(tag_key, training_tags)
286286
if inference_tags is None:

0 commit comments

Comments
 (0)