Skip to content

Commit f4b0536

Browse files
committed
feat: hyperparameter validation
1 parent 00f23e6 commit f4b0536

File tree

7 files changed

+549
-29
lines changed

7 files changed

+549
-29
lines changed

src/sagemaker/hyperparameters.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
from __future__ import absolute_import
1616

1717
import logging
18-
from typing import Dict
18+
from typing import Dict, Optional
1919

2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
22+
from sagemaker.jumpstart.enums import HyperparameterValidationMode
23+
from sagemaker.jumpstart.validators import validate_hyperparameters
2224

2325
logger = logging.getLogger(__name__)
2426

@@ -56,3 +58,45 @@ def retrieve_default(
5658
return artifacts._retrieve_default_hyperparameters(
5759
model_id, model_version, region, include_container_hyperparameters
5860
)
61+
62+
63+
def validate(
64+
region: Optional[str] = None,
65+
model_id: Optional[str] = None,
66+
model_version: Optional[str] = None,
67+
hyperparameters: Optional[dict] = None,
68+
validation_mode: Optional[HyperparameterValidationMode] = None,
69+
):
70+
"""Validate hyperparameters for models.
71+
72+
Args:
73+
region (str): Region for which to validate hyperparameters. (Default: None).
74+
model_id (str): Model ID of the model for which to validate hyperparameters.
75+
(Default: None)
76+
model_version (str): Version of the model for which to validate hyperparameters.
77+
(Default: None)
78+
hyperparameters (dict): Hyperparameters to validate.
79+
(Default: None)
80+
validation_mode (HyperparameterValidationMode): Method of validation to use with
81+
hyperparameters. If set to ``VALIDATE_PROVIDED``, only hyperparameters provided
82+
to this function will be validated, the missing hyperparameters will be ignored.
83+
If set to``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated.
84+
If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated.
85+
(Default: None)
86+
87+
88+
"""
89+
90+
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
91+
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
92+
93+
if hyperparameters is None:
94+
raise ValueError("Must specify hyperparameters.")
95+
96+
return validate_hyperparameters(
97+
model_id=model_id,
98+
model_version=model_version,
99+
hyperparameters=hyperparameters,
100+
validation_mode=validation_mode,
101+
region=region,
102+
)

src/sagemaker/jumpstart/artifacts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
INFERENCE,
2020
TRAINING,
2121
SUPPORTED_JUMPSTART_SCOPES,
22+
)
23+
from sagemaker.jumpstart.enums import (
2224
ModelFramework,
2325
VariableScope,
2426
)

src/sagemaker/jumpstart/constants.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
"""This module stores constants related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
1515
from typing import Set
16-
from enum import Enum
1716
import boto3
1817
from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo
1918

@@ -122,30 +121,3 @@
122121

123122
INFERENCE_ENTRYPOINT_SCRIPT_NAME = "inference.py"
124123
TRAINING_ENTRYPOINT_SCRIPT_NAME = "transfer_learning.py"
125-
126-
127-
class ModelFramework(str, Enum):
128-
"""Enum class for JumpStart model framework.
129-
130-
The ML framework as referenced in the prefix of the model ID.
131-
This value does not necessarily correspond to the container name.
132-
"""
133-
134-
PYTORCH = "pytorch"
135-
TENSORFLOW = "tensorflow"
136-
MXNET = "mxnet"
137-
HUGGINGFACE = "huggingface"
138-
LIGHTGBM = "lightgbm"
139-
CATBOOST = "catboost"
140-
XGBOOST = "xgboost"
141-
SKLEARN = "sklearn"
142-
143-
144-
class VariableScope(str, Enum):
145-
"""Possible value of the ``scope`` attribute for a hyperparameter or environment variable.
146-
147-
Used for hosting environment variables and training hyperparameters.
148-
"""
149-
150-
CONTAINER = "container"
151-
ALGORITHM = "algorithm"

src/sagemaker/jumpstart/enums.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from enum import Enum
2+
3+
4+
class ModelFramework(str, Enum):
5+
"""Enum class for JumpStart model framework.
6+
7+
The ML framework as referenced in the prefix of the model ID.
8+
This value does not necessarily correspond to the container name.
9+
"""
10+
11+
PYTORCH = "pytorch"
12+
TENSORFLOW = "tensorflow"
13+
MXNET = "mxnet"
14+
HUGGINGFACE = "huggingface"
15+
LIGHTGBM = "lightgbm"
16+
CATBOOST = "catboost"
17+
XGBOOST = "xgboost"
18+
SKLEARN = "sklearn"
19+
20+
21+
class VariableScope(str, Enum):
22+
"""Possible value of the ``scope`` attribute for a hyperparameter or environment variable.
23+
24+
Used for hosting environment variables and training hyperparameters.
25+
"""
26+
27+
CONTAINER = "container"
28+
ALGORITHM = "algorithm"
29+
30+
31+
class HyperparameterValidationMode(str, Enum):
32+
"""Possible modes for validating hyperparameters."""
33+
34+
VALIDATE_PROVIDED = "validate_provided"
35+
VALIDATE_ALGORITHM = "validate_algorithm"
36+
VALIDATE_ALL = "validate_all"
37+
38+
39+
class VariableTypes(str, Enum):
40+
"""Possible types for hyperparameters and environment variables."""
41+
42+
TEXT = "text"
43+
INT = "int"
44+
FLOAT = "float"
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module stores exceptions related to SageMaker JumpStart."""
14+
15+
from typing import Optional
16+
17+
18+
class JumpStartHyperparametersError(Exception):
19+
"""Exception raised for errors with hyperparameters for JumpStart models."""
20+
21+
def __init__(
22+
self,
23+
message: Optional[str] = None,
24+
):
25+
self.message = message
26+
27+
super().__init__(self.message)
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains validators related to SageMaker JumpStart."""
14+
from __future__ import absolute_import
15+
from typing import Any, List, Optional
16+
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
17+
18+
from sagemaker.jumpstart.enums import HyperparameterValidationMode, VariableScope, VariableTypes
19+
from sagemaker.jumpstart import accessors as jumpstart_accessors
20+
from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError
21+
from sagemaker.jumpstart.types import JumpStartHyperparameter
22+
23+
24+
def _validate_hyperparameter(
25+
hyperparameter_name: str,
26+
hyperparameter_value: Any,
27+
hyperparameter_specs: List[JumpStartHyperparameter],
28+
):
29+
"""Perform low-level hyperparameter validation on single parameter.
30+
31+
Args:
32+
hyperparameter_name (str): The name of the hyperparameter to validate.
33+
hyperparameter_value (Any): The value of the hyperparemter to validate.
34+
hyperparameter_specs (List[JumpStartHyperparameter]): List of ``JumpStartHyperparameter`` to
35+
use when validating the hyperparameter.
36+
"""
37+
hyperparameter_spec = [
38+
spec for spec in hyperparameter_specs if spec.name == hyperparameter_name
39+
]
40+
if len(hyperparameter_spec) == 0:
41+
raise JumpStartHyperparametersError(
42+
f"Unable to perform validation -- cannot find hyperparameter '{hyperparameter_name}' in model specs."
43+
)
44+
hyperparameter_spec = hyperparameter_spec[0]
45+
46+
if hyperparameter_spec.type == VariableTypes.TEXT.value:
47+
if type(hyperparameter_value) != str:
48+
raise JumpStartHyperparametersError(
49+
f"Expecting text valued hyperparameter to have string type."
50+
)
51+
52+
if getattr(hyperparameter_spec, "options", None):
53+
if hyperparameter_value not in hyperparameter_spec.options:
54+
raise JumpStartHyperparametersError(
55+
f"Hyperparameter '{hyperparameter_name}' must have one of the following values: "
56+
", ".join(hyperparameter_spec.options)
57+
)
58+
59+
# validate numeric types
60+
if hyperparameter_spec.type in [VariableTypes.INT.value, VariableTypes.FLOAT.value]:
61+
try:
62+
numeric_hyperparam_value = float(hyperparameter_value)
63+
except ValueError:
64+
raise JumpStartHyperparametersError(
65+
f"Hyperparameter '{hyperparameter_name}' must be numeric type ('{hyperparameter_value}')."
66+
)
67+
68+
if hyperparameter_spec.type == VariableTypes.INT.value:
69+
hyperparameter_value_str = str(hyperparameter_value)
70+
start_index = 0
71+
if hyperparameter_value_str[0] in ["+", "-"]:
72+
start_index = 1
73+
if not hyperparameter_value_str[start_index:].isdigit():
74+
raise JumpStartHyperparametersError(
75+
f"Hyperparameter '{hyperparameter_name}' must be integer type ('{hyperparameter_value}')."
76+
)
77+
78+
if getattr(hyperparameter_spec, "min", None):
79+
if numeric_hyperparam_value < hyperparameter_spec.min:
80+
raise JumpStartHyperparametersError(
81+
f"Hyperparameter '{hyperparameter_name}' can be no less than {hyperparameter_spec.min}."
82+
)
83+
84+
if getattr(hyperparameter_spec, "max", None):
85+
if numeric_hyperparam_value > hyperparameter_spec.max:
86+
raise JumpStartHyperparametersError(
87+
f"Hyperparameter '{hyperparameter_name}' can be no greater than {hyperparameter_spec.max}."
88+
)
89+
90+
91+
def validate_hyperparameters(
92+
model_id: str,
93+
model_version: str,
94+
hyperparameters: dict,
95+
validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
96+
region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME,
97+
):
98+
"""Validate hyperparameters for JumpStart models.
99+
100+
Args:
101+
model_id (str): Model ID of the model for which to validate hyperparameters.
102+
model_version (str): Version of the model for which to validate hyperparameters.
103+
hyperparameters (dict): Hyperparameters to validate.
104+
validation_mode (HyperparameterValidationMode): Method of validation to use with
105+
hyperparameters. If set to ``VALIDATE_PROVIDED``, only hyperparameters provided
106+
to this function will be validated, the missing hyperparameters will be ignored.
107+
If set to``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated.
108+
If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated.
109+
region (str): Region for which to validate hyperparameters. (Default: JumpStart
110+
default region).
111+
112+
"""
113+
114+
if validation_mode is None:
115+
validation_mode = HyperparameterValidationMode.VALIDATE_PROVIDED
116+
117+
if region is None:
118+
region = JUMPSTART_DEFAULT_REGION_NAME
119+
120+
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
121+
region=region, model_id=model_id, version=model_version
122+
)
123+
hyperparameters_specs = model_specs.hyperparameters
124+
125+
if validation_mode == HyperparameterValidationMode.VALIDATE_PROVIDED:
126+
for hyperparam_name, hyperparam_value in hyperparameters.items():
127+
_validate_hyperparameter(hyperparam_name, hyperparam_value, hyperparameters_specs)
128+
129+
elif validation_mode == HyperparameterValidationMode.VALIDATE_ALGORITHM:
130+
for hyperparam in hyperparameters_specs:
131+
if hyperparam.scope == VariableScope.ALGORITHM:
132+
if hyperparam.name not in hyperparameters:
133+
raise JumpStartHyperparametersError(
134+
f"Cannot find algorithm hyperparameter for '{hyperparam.name}'."
135+
)
136+
_validate_hyperparameter(
137+
hyperparam.name, hyperparameters[hyperparam.name], hyperparameters_specs
138+
)
139+
140+
elif validation_mode == HyperparameterValidationMode.VALIDATE_ALL:
141+
for hyperparam in hyperparameters_specs:
142+
if hyperparam.name not in hyperparameters:
143+
raise JumpStartHyperparametersError(
144+
f"Cannot find hyperparameter for '{hyperparam.name}'."
145+
)
146+
_validate_hyperparameter(
147+
hyperparam.name, hyperparameters[hyperparam.name], hyperparameters_specs
148+
)
149+
150+
else:
151+
raise NotImplementedError(
152+
f"Unable to handle validation for the mode '{validation_mode.value}'."
153+
)

0 commit comments

Comments
 (0)