Skip to content

Commit 8f4aecf

Browse files
committed
change: improve jumpstart hyperparam validation logic
1 parent f4b0536 commit 8f4aecf

File tree

8 files changed

+232
-32
lines changed

8 files changed

+232
-32
lines changed

src/sagemaker/environment_variables.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def retrieve_default(
4444
ValueError: If the combination of arguments specified is not supported.
4545
"""
4646
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
47-
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
47+
raise ValueError(
48+
"Must specify `model_id` and `model_version` when retrieving environment variables."
49+
)
4850

4951
# mypy type checking requires these assertions
5052
assert model_id is not None

src/sagemaker/hyperparameters.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ def retrieve_default(
5353
ValueError: If the combination of arguments specified is not supported.
5454
"""
5555
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
56-
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
56+
raise ValueError(
57+
"Must specify `model_id` and `model_version` when retrieving hyperparameters."
58+
)
5759

5860
return artifacts._retrieve_default_hyperparameters(
5961
model_id, model_version, region, include_container_hyperparameters
@@ -84,11 +86,16 @@ def validate(
8486
If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated.
8587
(Default: None)
8688
89+
Raises:
90+
JumpStartHyperparametersError: If the hyperparameter is not formatted correctly,
91+
according to its specs in the model metadata.
8792
8893
"""
8994

9095
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
91-
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
96+
raise ValueError(
97+
"Must specify `model_id` and `model_version` when validating hyperparameters."
98+
)
9299

93100
if hyperparameters is None:
94101
raise ValueError("Must specify hyperparameters.")

src/sagemaker/jumpstart/enums.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module stores enums related to SageMaker JumpStart."""
14+
from __future__ import absolute_import
15+
116
from enum import Enum
217

318

@@ -42,3 +57,4 @@ class VariableTypes(str, Enum):
4257
TEXT = "text"
4358
INT = "int"
4459
FLOAT = "float"
60+
BOOL = "bool"

src/sagemaker/jumpstart/exceptions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""This module stores exceptions related to SageMaker JumpStart."""
14+
from __future__ import absolute_import
1415

1516
from typing import Optional
1617

1718

1819
class JumpStartHyperparametersError(Exception):
19-
"""Exception raised for errors with hyperparameters for JumpStart models."""
20+
"""Exception raised for bad hyperparameters of a JumpStart model."""
2021

2122
def __init__(
2223
self,

src/sagemaker/jumpstart/types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ class JumpStartHyperparameter(JumpStartDataHolderType):
181181
"scope",
182182
"min",
183183
"max",
184+
"exclusive_min",
185+
"exclusive_max",
184186
}
185187

186188
def __init__(self, spec: Dict[str, Any]):
@@ -215,6 +217,14 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
215217
if max_val is not None:
216218
self.max = max_val
217219

220+
exclusive_min_val = json_obj.get("exclusive_min")
221+
if exclusive_min_val is not None:
222+
self.exclusive_min = exclusive_min_val
223+
224+
exclusive_max_val = json_obj.get("exclusive_max")
225+
if exclusive_max_val is not None:
226+
self.exclusive_max = exclusive_max_val
227+
218228
def to_json(self) -> Dict[str, Any]:
219229
"""Returns json representation of JumpStartHyperparameter object."""
220230
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}

src/sagemaker/jumpstart/validators.py

Lines changed: 90 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains validators related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
15-
from typing import Any, List, Optional
15+
from typing import Any, Dict, List, Optional
1616
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
1717

1818
from sagemaker.jumpstart.enums import HyperparameterValidationMode, VariableScope, VariableTypes
@@ -33,36 +33,88 @@ def _validate_hyperparameter(
3333
hyperparameter_value (Any): The value of the hyperparameter to validate.
3434
hyperparameter_specs (List[JumpStartHyperparameter]): List of ``JumpStartHyperparameter`` to
3535
use when validating the hyperparameter.
36+
37+
Raises:
38+
JumpStartHyperparametersError: If the hyperparameter is not formatted correctly,
39+
according to its specs in the model metadata.
3640
"""
3741
hyperparameter_spec = [
3842
spec for spec in hyperparameter_specs if spec.name == hyperparameter_name
3943
]
4044
if len(hyperparameter_spec) == 0:
4145
raise JumpStartHyperparametersError(
42-
f"Unable to perform validation -- cannot find hyperparameter '{hyperparameter_name}' in model specs."
46+
f"Unable to perform validation -- cannot find hyperparameter '{hyperparameter_name}' "
47+
"in model specs."
48+
)
49+
50+
if len(hyperparameter_spec) > 1:
51+
raise JumpStartHyperparametersError(
52+
f"Unable to perform validation -- found multiple hyperparameter "
53+
f"'{hyperparameter_name}' in model specs."
4354
)
55+
4456
hyperparameter_spec = hyperparameter_spec[0]
4557

46-
if hyperparameter_spec.type == VariableTypes.TEXT.value:
47-
if type(hyperparameter_value) != str:
58+
if hyperparameter_spec.type == VariableTypes.BOOL.value:
59+
if isinstance(hyperparameter_value, bool):
60+
return
61+
if not isinstance(hyperparameter_value, str):
62+
raise JumpStartHyperparametersError(
63+
f"Expecting boolean valued hyperparameter, but got '{str(hyperparameter_value)}'."
64+
)
65+
if str(hyperparameter_value).lower() not in ["true", "false"]:
4866
raise JumpStartHyperparametersError(
49-
f"Expecting text valued hyperparameter to have string type."
67+
f"Expecting boolean valued hyperparameter, but got '{str(hyperparameter_value)}'."
68+
)
69+
elif hyperparameter_spec.type == VariableTypes.TEXT.value:
70+
if not isinstance(hyperparameter_value, str):
71+
raise JumpStartHyperparametersError(
72+
"Expecting text valued hyperparameter to have string type."
5073
)
5174

52-
if getattr(hyperparameter_spec, "options", None):
75+
if hasattr(hyperparameter_spec, "options"):
5376
if hyperparameter_value not in hyperparameter_spec.options:
5477
raise JumpStartHyperparametersError(
55-
f"Hyperparameter '{hyperparameter_name}' must have one of the following values: "
56-
", ".join(hyperparameter_spec.options)
78+
f"Hyperparameter '{hyperparameter_name}' must have one of the following "
79+
f"values: {', '.join(hyperparameter_spec.options)}"
80+
)
81+
82+
if hasattr(hyperparameter_spec, "min"):
83+
if len(hyperparameter_value) < hyperparameter_spec.min:
84+
raise JumpStartHyperparametersError(
85+
f"Hyperparameter '{hyperparameter_name}' must have length no less than "
86+
f"{hyperparameter_spec.min}"
87+
)
88+
89+
if hasattr(hyperparameter_spec, "exclusive_min"):
90+
if len(hyperparameter_value) <= hyperparameter_spec.exclusive_min:
91+
raise JumpStartHyperparametersError(
92+
f"Hyperparameter '{hyperparameter_name}' must have length greater than "
93+
f"{hyperparameter_spec.exclusive_min}"
94+
)
95+
96+
if hasattr(hyperparameter_spec, "max"):
97+
if len(hyperparameter_value) > hyperparameter_spec.max:
98+
raise JumpStartHyperparametersError(
99+
f"Hyperparameter '{hyperparameter_name}' must have length no greater than "
100+
f"{hyperparameter_spec.max}"
101+
)
102+
103+
if hasattr(hyperparameter_spec, "exclusive_max"):
104+
if len(hyperparameter_value) >= hyperparameter_spec.exclusive_max:
105+
raise JumpStartHyperparametersError(
106+
f"Hyperparameter '{hyperparameter_name}' must have length less than "
107+
f"{hyperparameter_spec.exclusive_max}"
57108
)
58109

59110
# validate numeric types
60-
if hyperparameter_spec.type in [VariableTypes.INT.value, VariableTypes.FLOAT.value]:
111+
elif hyperparameter_spec.type in [VariableTypes.INT.value, VariableTypes.FLOAT.value]:
61112
try:
62113
numeric_hyperparam_value = float(hyperparameter_value)
63114
except ValueError:
64115
raise JumpStartHyperparametersError(
65-
f"Hyperparameter '{hyperparameter_name}' must be numeric type ('{hyperparameter_value}')."
116+
f"Hyperparameter '{hyperparameter_name}' must be numeric type "
117+
f"('{hyperparameter_value}')."
66118
)
67119

68120
if hyperparameter_spec.type == VariableTypes.INT.value:
@@ -72,29 +124,46 @@ def _validate_hyperparameter(
72124
start_index = 1
73125
if not hyperparameter_value_str[start_index:].isdigit():
74126
raise JumpStartHyperparametersError(
75-
f"Hyperparameter '{hyperparameter_name}' must be integer type ('{hyperparameter_value}')."
127+
f"Hyperparameter '{hyperparameter_name}' must be integer type "
128+
"('{hyperparameter_value}')."
76129
)
77130

78-
if getattr(hyperparameter_spec, "min", None):
131+
if hasattr(hyperparameter_spec, "min"):
79132
if numeric_hyperparam_value < hyperparameter_spec.min:
80133
raise JumpStartHyperparametersError(
81-
f"Hyperparameter '{hyperparameter_name}' can be no less than {hyperparameter_spec.min}."
134+
f"Hyperparameter '{hyperparameter_name}' can be no less than "
135+
"{hyperparameter_spec.min}."
82136
)
83137

84-
if getattr(hyperparameter_spec, "max", None):
138+
if hasattr(hyperparameter_spec, "max"):
85139
if numeric_hyperparam_value > hyperparameter_spec.max:
86140
raise JumpStartHyperparametersError(
87-
f"Hyperparameter '{hyperparameter_name}' can be no greater than {hyperparameter_spec.max}."
141+
f"Hyperparameter '{hyperparameter_name}' can be no greater than "
142+
"{hyperparameter_spec.max}."
143+
)
144+
145+
if hasattr(hyperparameter_spec, "exclusive_min"):
146+
if numeric_hyperparam_value <= hyperparameter_spec.exclusive_min:
147+
raise JumpStartHyperparametersError(
148+
f"Hyperparameter '{hyperparameter_name}' must be greater than "
149+
"{hyperparameter_spec.exclusive_min}."
150+
)
151+
152+
if hasattr(hyperparameter_spec, "exclusive_max"):
153+
if numeric_hyperparam_value >= hyperparameter_spec.exclusive_max:
154+
raise JumpStartHyperparametersError(
155+
f"Hyperparameter '{hyperparameter_name}' must be less than than "
156+
"{hyperparameter_spec.exclusive_max}."
88157
)
89158

90159

91160
def validate_hyperparameters(
92161
model_id: str,
93162
model_version: str,
94-
hyperparameters: dict,
163+
hyperparameters: Dict[str, Any],
95164
validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
96165
region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME,
97-
):
166+
) -> None:
98167
"""Validate hyperparameters for JumpStart models.
99168
100169
Args:
@@ -109,6 +178,10 @@ def validate_hyperparameters(
109178
region (str): Region for which to validate hyperparameters. (Default: JumpStart
110179
default region).
111180
181+
Raises:
182+
JumpStartHyperparametersError: If the hyperparameter is not formatted correctly,
183+
according to its specs in the model metadata.
184+
112185
"""
113186

114187
if validation_mode is None:

src/sagemaker/model_uris.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def retrieve(
4646
ValueError: If the combination of arguments specified is not supported.
4747
"""
4848
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
49-
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
49+
raise ValueError("Must specify `model_id` and `model_version` when retrieving model URIs.")
5050

5151
# mypy type checking requires these assertions
5252
assert model_id is not None

0 commit comments

Comments
 (0)