1212# language governing permissions and limitations under the License.
1313"""This module contains validators related to SageMaker JumpStart."""
1414from __future__ import absolute_import
15- from typing import Any , List , Optional
15+ from typing import Any , Dict , List , Optional
1616from sagemaker .jumpstart .constants import JUMPSTART_DEFAULT_REGION_NAME
1717
1818from sagemaker .jumpstart .enums import HyperparameterValidationMode , VariableScope , VariableTypes
@@ -33,36 +33,88 @@ def _validate_hyperparameter(
3333 hyperparameter_value (Any): The value of the hyperparemter to validate.
3434 hyperparameter_specs (List[JumpStartHyperparameter]): List of ``JumpStartHyperparameter`` to
3535 use when validating the hyperparameter.
36+
37+ Raises:
38+ JumpStartHyperparametersError: If the hyperparameter is not formatted correctly,
39+ according to its specs in the model metadata.
3640 """
3741 hyperparameter_spec = [
3842 spec for spec in hyperparameter_specs if spec .name == hyperparameter_name
3943 ]
4044 if len (hyperparameter_spec ) == 0 :
4145 raise JumpStartHyperparametersError (
42- f"Unable to perform validation -- cannot find hyperparameter '{ hyperparameter_name } ' in model specs."
46+ f"Unable to perform validation -- cannot find hyperparameter '{ hyperparameter_name } ' "
47+ "in model specs."
48+ )
49+
50+ if len (hyperparameter_spec ) > 1 :
51+ raise JumpStartHyperparametersError (
52+ f"Unable to perform validation -- found multiple hyperparameter "
53+ f"'{ hyperparameter_name } ' in model specs."
4354 )
55+
4456 hyperparameter_spec = hyperparameter_spec [0 ]
4557
46- if hyperparameter_spec .type == VariableTypes .TEXT .value :
47- if type (hyperparameter_value ) != str :
58+ if hyperparameter_spec .type == VariableTypes .BOOL .value :
59+ if isinstance (hyperparameter_value , bool ):
60+ return
61+ if not isinstance (hyperparameter_value , str ):
62+ raise JumpStartHyperparametersError (
63+ f"Expecting boolean valued hyperparameter, but got '{ str (hyperparameter_value )} '."
64+ )
65+ if str (hyperparameter_value ).lower () not in ["true" , "false" ]:
4866 raise JumpStartHyperparametersError (
49- f"Expecting text valued hyperparameter to have string type."
67+ f"Expecting boolean valued hyperparameter, but got '{ str (hyperparameter_value )} '."
68+ )
69+ elif hyperparameter_spec .type == VariableTypes .TEXT .value :
70+ if not isinstance (hyperparameter_value , str ):
71+ raise JumpStartHyperparametersError (
72+ "Expecting text valued hyperparameter to have string type."
5073 )
5174
52- if getattr (hyperparameter_spec , "options" , None ):
75+ if hasattr (hyperparameter_spec , "options" ):
5376 if hyperparameter_value not in hyperparameter_spec .options :
5477 raise JumpStartHyperparametersError (
55- f"Hyperparameter '{ hyperparameter_name } ' must have one of the following values: "
56- ", " .join (hyperparameter_spec .options )
78+ f"Hyperparameter '{ hyperparameter_name } ' must have one of the following "
79+ f"values: { ', ' .join (hyperparameter_spec .options )} "
80+ )
81+
82+ if hasattr (hyperparameter_spec , "min" ):
83+ if len (hyperparameter_value ) < hyperparameter_spec .min :
84+ raise JumpStartHyperparametersError (
85+ f"Hyperparameter '{ hyperparameter_name } ' must have length no less than "
86+ f"{ hyperparameter_spec .min } "
87+ )
88+
89+ if hasattr (hyperparameter_spec , "exclusive_min" ):
90+ if len (hyperparameter_value ) <= hyperparameter_spec .exclusive_min :
91+ raise JumpStartHyperparametersError (
92+ f"Hyperparameter '{ hyperparameter_name } ' must have length greater than "
93+ f"{ hyperparameter_spec .exclusive_min } "
94+ )
95+
96+ if hasattr (hyperparameter_spec , "max" ):
97+ if len (hyperparameter_value ) > hyperparameter_spec .max :
98+ raise JumpStartHyperparametersError (
99+ f"Hyperparameter '{ hyperparameter_name } ' must have length no greater than "
100+ f"{ hyperparameter_spec .max } "
101+ )
102+
103+ if hasattr (hyperparameter_spec , "exclusive_max" ):
104+ if len (hyperparameter_value ) >= hyperparameter_spec .exclusive_max :
105+ raise JumpStartHyperparametersError (
106+ f"Hyperparameter '{ hyperparameter_name } ' must have length less than "
107+ f"{ hyperparameter_spec .exclusive_max } "
57108 )
58109
59110 # validate numeric types
60- if hyperparameter_spec .type in [VariableTypes .INT .value , VariableTypes .FLOAT .value ]:
111+ elif hyperparameter_spec .type in [VariableTypes .INT .value , VariableTypes .FLOAT .value ]:
61112 try :
62113 numeric_hyperparam_value = float (hyperparameter_value )
63114 except ValueError :
64115 raise JumpStartHyperparametersError (
65- f"Hyperparameter '{ hyperparameter_name } ' must be numeric type ('{ hyperparameter_value } ')."
116+ f"Hyperparameter '{ hyperparameter_name } ' must be numeric type "
117+ f"('{ hyperparameter_value } ')."
66118 )
67119
68120 if hyperparameter_spec .type == VariableTypes .INT .value :
@@ -72,29 +124,46 @@ def _validate_hyperparameter(
72124 start_index = 1
73125 if not hyperparameter_value_str [start_index :].isdigit ():
74126 raise JumpStartHyperparametersError (
75- f"Hyperparameter '{ hyperparameter_name } ' must be integer type ('{ hyperparameter_value } ')."
127+ f"Hyperparameter '{ hyperparameter_name } ' must be integer type "
128+ "('{hyperparameter_value}')."
76129 )
77130
78- if getattr (hyperparameter_spec , "min" , None ):
131+ if hasattr (hyperparameter_spec , "min" ):
79132 if numeric_hyperparam_value < hyperparameter_spec .min :
80133 raise JumpStartHyperparametersError (
81- f"Hyperparameter '{ hyperparameter_name } ' can be no less than { hyperparameter_spec .min } ."
134+ f"Hyperparameter '{ hyperparameter_name } ' can be no less than "
135+ "{hyperparameter_spec.min}."
82136 )
83137
84- if getattr (hyperparameter_spec , "max" , None ):
138+ if hasattr (hyperparameter_spec , "max" ):
85139 if numeric_hyperparam_value > hyperparameter_spec .max :
86140 raise JumpStartHyperparametersError (
87- f"Hyperparameter '{ hyperparameter_name } ' can be no greater than { hyperparameter_spec .max } ."
141+ f"Hyperparameter '{ hyperparameter_name } ' can be no greater than "
142+ "{hyperparameter_spec.max}."
143+ )
144+
145+ if hasattr (hyperparameter_spec , "exclusive_min" ):
146+ if numeric_hyperparam_value <= hyperparameter_spec .exclusive_min :
147+ raise JumpStartHyperparametersError (
148+ f"Hyperparameter '{ hyperparameter_name } ' must be greater than "
149+ "{hyperparameter_spec.exclusive_min}."
150+ )
151+
152+ if hasattr (hyperparameter_spec , "exclusive_max" ):
153+ if numeric_hyperparam_value >= hyperparameter_spec .exclusive_max :
154+ raise JumpStartHyperparametersError (
155+ f"Hyperparameter '{ hyperparameter_name } ' must be less than than "
156+ "{hyperparameter_spec.exclusive_max}."
88157 )
89158
90159
91160def validate_hyperparameters (
92161 model_id : str ,
93162 model_version : str ,
94- hyperparameters : dict ,
163+ hyperparameters : Dict [ str , Any ] ,
95164 validation_mode : HyperparameterValidationMode = HyperparameterValidationMode .VALIDATE_PROVIDED ,
96165 region : Optional [str ] = JUMPSTART_DEFAULT_REGION_NAME ,
97- ):
166+ ) -> None :
98167 """Validate hyperparameters for JumpStart models.
99168
100169 Args:
@@ -109,6 +178,10 @@ def validate_hyperparameters(
109178 region (str): Region for which to validate hyperparameters. (Default: JumpStart
110179 default region).
111180
181+ Raises:
182+ JumpStartHyperparametersError: If the hyperparameter is not formatted correctly,
183+ according to its specs in the model metadata.
184+
112185 """
113186
114187 if validation_mode is None :
0 commit comments