|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""This module contains validators related to SageMaker JumpStart.""" |
| 14 | +from __future__ import absolute_import |
| 15 | +from typing import Any, List, Optional |
| 16 | +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME |
| 17 | + |
| 18 | +from sagemaker.jumpstart.enums import HyperparameterValidationMode, VariableScope, VariableTypes |
| 19 | +from sagemaker.jumpstart import accessors as jumpstart_accessors |
| 20 | +from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError |
| 21 | +from sagemaker.jumpstart.types import JumpStartHyperparameter |
| 22 | + |
| 23 | + |
def _validate_hyperparameter(
    hyperparameter_name: str,
    hyperparameter_value: Any,
    hyperparameter_specs: List[JumpStartHyperparameter],
) -> None:
    """Perform low-level hyperparameter validation on single parameter.

    Args:
        hyperparameter_name (str): The name of the hyperparameter to validate.
        hyperparameter_value (Any): The value of the hyperparameter to validate.
        hyperparameter_specs (List[JumpStartHyperparameter]): List of ``JumpStartHyperparameter``
            to use when validating the hyperparameter.

    Raises:
        JumpStartHyperparametersError: If no spec is named ``hyperparameter_name``; if a
            text hyperparameter is not a string or is not one of the spec's ``options``;
            or if a numeric hyperparameter cannot be converted to a number, is not an
            integer literal when the spec type is integer, or falls outside the spec's
            ``min``/``max`` bounds.
    """
    # Only the first spec with a matching name is used.
    hyperparameter_spec = next(
        (spec for spec in hyperparameter_specs if spec.name == hyperparameter_name),
        None,
    )
    if hyperparameter_spec is None:
        raise JumpStartHyperparametersError(
            "Unable to perform validation -- cannot find hyperparameter "
            f"'{hyperparameter_name}' in model specs."
        )

    if hyperparameter_spec.type == VariableTypes.TEXT.value:
        if not isinstance(hyperparameter_value, str):
            raise JumpStartHyperparametersError(
                "Expecting text valued hyperparameter to have string type."
            )

        # A missing or empty ``options`` attribute means any string is accepted.
        allowed_values = getattr(hyperparameter_spec, "options", None)
        if allowed_values and hyperparameter_value not in allowed_values:
            # The explicit ``+`` matters: two adjacent string literals would be
            # concatenated first and ``.join`` would then use the whole message
            # as the separator.
            raise JumpStartHyperparametersError(
                f"Hyperparameter '{hyperparameter_name}' must have one of the following values: "
                + ", ".join(str(option) for option in allowed_values)
            )

    # validate numeric types
    if hyperparameter_spec.type in [VariableTypes.INT.value, VariableTypes.FLOAT.value]:
        try:
            numeric_hyperparam_value = float(hyperparameter_value)
        except (TypeError, ValueError) as conversion_error:
            # ``float`` raises TypeError (not ValueError) for None, lists, dicts, ...
            raise JumpStartHyperparametersError(
                f"Hyperparameter '{hyperparameter_name}' must be numeric type "
                f"('{hyperparameter_value}')."
            ) from conversion_error

        if hyperparameter_spec.type == VariableTypes.INT.value:
            # An integer is an optional sign followed by digits only, so values
            # such as 1.0, "1e3" or " 5" are rejected even though float() accepts them.
            unsigned_digits = str(hyperparameter_value)
            if unsigned_digits.startswith(("+", "-")):
                unsigned_digits = unsigned_digits[1:]
            if not unsigned_digits.isdigit():
                raise JumpStartHyperparametersError(
                    f"Hyperparameter '{hyperparameter_name}' must be integer type "
                    f"('{hyperparameter_value}')."
                )

        # Compare bounds against None rather than by truthiness so that a bound
        # of 0 is still enforced.
        lower_bound = getattr(hyperparameter_spec, "min", None)
        if lower_bound is not None and numeric_hyperparam_value < lower_bound:
            raise JumpStartHyperparametersError(
                f"Hyperparameter '{hyperparameter_name}' can be no less than {lower_bound}."
            )

        upper_bound = getattr(hyperparameter_spec, "max", None)
        if upper_bound is not None and numeric_hyperparam_value > upper_bound:
            raise JumpStartHyperparametersError(
                f"Hyperparameter '{hyperparameter_name}' can be no greater than {upper_bound}."
            )
| 89 | + |
| 90 | + |
def validate_hyperparameters(
    model_id: str,
    model_version: str,
    hyperparameters: dict,
    validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
    region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME,
) -> None:
    """Validate hyperparameters for JumpStart models.

    Args:
        model_id (str): Model ID of the model for which to validate hyperparameters.
        model_version (str): Version of the model for which to validate hyperparameters.
        hyperparameters (dict): Hyperparameters to validate.
        validation_mode (HyperparameterValidationMode): Method of validation to use with
            hyperparameters. If set to ``VALIDATE_PROVIDED``, only hyperparameters provided
            to this function will be validated, the missing hyperparameters will be ignored.
            If set to ``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated.
            If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated.
            ``None`` is treated as ``VALIDATE_PROVIDED``.
        region (str): Region for which to validate hyperparameters. (Default: JumpStart
            default region).

    Raises:
        JumpStartHyperparametersError: If a hyperparameter required by the validation mode
            is missing from ``hyperparameters``, or if any validated value is rejected by
            ``_validate_hyperparameter``.
        NotImplementedError: If ``validation_mode`` is not one of the supported modes.
    """

    if validation_mode is None:
        validation_mode = HyperparameterValidationMode.VALIDATE_PROVIDED

    if region is None:
        region = JUMPSTART_DEFAULT_REGION_NAME

    model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
        region=region, model_id=model_id, version=model_version
    )
    hyperparameters_specs = model_specs.hyperparameters

    if validation_mode == HyperparameterValidationMode.VALIDATE_PROVIDED:
        # Only what the caller passed in is checked; absent hyperparameters are ignored.
        for hyperparam_name, hyperparam_value in hyperparameters.items():
            _validate_hyperparameter(hyperparam_name, hyperparam_value, hyperparameters_specs)
        return

    # The two remaining modes differ only in which specs are mandatory and in
    # the wording of the "missing" error.
    if validation_mode == HyperparameterValidationMode.VALIDATE_ALGORITHM:
        required_specs = [
            spec for spec in hyperparameters_specs if spec.scope == VariableScope.ALGORITHM
        ]
        missing_template = "Cannot find algorithm hyperparameter for '{}'."
    elif validation_mode == HyperparameterValidationMode.VALIDATE_ALL:
        required_specs = list(hyperparameters_specs)
        missing_template = "Cannot find hyperparameter for '{}'."
    else:
        # Fall back to the raw object when the mode is not an enum member, so the
        # caller gets NotImplementedError instead of an AttributeError on ``.value``.
        mode_label = getattr(validation_mode, "value", validation_mode)
        raise NotImplementedError(f"Unable to handle validation for the mode '{mode_label}'.")

    for spec in required_specs:
        if spec.name not in hyperparameters:
            raise JumpStartHyperparametersError(missing_template.format(spec.name))
        _validate_hyperparameter(spec.name, hyperparameters[spec.name], hyperparameters_specs)
0 commit comments