diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py
index 108dda4209..4ea2bcc812 100644
--- a/src/sagemaker/environment_variables.py
+++ b/src/sagemaker/environment_variables.py
@@ -44,6 +44,8 @@ def retrieve_default(
         ValueError: If the combination of arguments specified is not supported.
     """
     if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
-        raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
+        raise ValueError(
+            "Must specify `model_id` and `model_version` when retrieving environment variables."
+        )

     return artifacts._retrieve_default_environment_variables(model_id, model_version, region)
diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py
index 621ed228e0..08ef24e19f 100644
--- a/src/sagemaker/hyperparameters.py
+++ b/src/sagemaker/hyperparameters.py
@@ -15,10 +15,12 @@ from __future__ import absolute_import

 import logging
-from typing import Dict
+from typing import Dict, Optional

 from sagemaker.jumpstart import utils as jumpstart_utils
 from sagemaker.jumpstart import artifacts
+from sagemaker.jumpstart.enums import HyperparameterValidationMode
+from sagemaker.jumpstart.validators import validate_hyperparameters

 logger = logging.getLogger(__name__)

@@ -51,8 +53,58 @@ def retrieve_default(
         ValueError: If the combination of arguments specified is not supported.
     """
     if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
-        raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
+        raise ValueError(
+            "Must specify `model_id` and `model_version` when retrieving hyperparameters."
+        )

     return artifacts._retrieve_default_hyperparameters(
         model_id, model_version, region, include_container_hyperparameters
     )
+
+
+def validate(
+    region: Optional[str] = None,
+    model_id: Optional[str] = None,
+    model_version: Optional[str] = None,
+    hyperparameters: Optional[dict] = None,
+    validation_mode: Optional[HyperparameterValidationMode] = None,
+) -> None:
+    """Validate hyperparameters for models.
+
+    Args:
+        region (str): Region for which to validate hyperparameters. (Default: None).
+        model_id (str): Model ID of the model for which to validate hyperparameters.
+            (Default: None)
+        model_version (str): Version of the model for which to validate hyperparameters.
+            (Default: None)
+        hyperparameters (dict): Hyperparameters to validate.
+            (Default: None)
+        validation_mode (HyperparameterValidationMode): Method of validation to use with
+            hyperparameters. If set to ``VALIDATE_PROVIDED``, only hyperparameters provided
+            to this function will be validated; missing hyperparameters will be ignored.
+            If set to ``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated.
+            If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated.
+            (Default: None)
+
+    Raises:
+        JumpStartHyperparametersError: If the hyperparameter is not formatted correctly,
+            according to its specs in the model metadata.
+        ValueError: If the combination of arguments specified is not supported.
+
+    """
+
+    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
+        raise ValueError(
+            "Must specify `model_id` and `model_version` when validating hyperparameters."
+ ) + + if hyperparameters is None: + raise ValueError("Must specify hyperparameters.") + + return validate_hyperparameters( + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameters, + validation_mode=validation_mode, + region=region, + ) diff --git a/src/sagemaker/jumpstart/artifacts.py b/src/sagemaker/jumpstart/artifacts.py index 7c9b835b3c..a61f46702f 100644 --- a/src/sagemaker/jumpstart/artifacts.py +++ b/src/sagemaker/jumpstart/artifacts.py @@ -16,6 +16,8 @@ from sagemaker import image_uris from sagemaker.jumpstart.constants import ( JUMPSTART_DEFAULT_REGION_NAME, +) +from sagemaker.jumpstart.enums import ( JumpStartScriptScope, ModelFramework, VariableScope, diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index f41117d7ef..363e542b02 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -13,8 +13,8 @@ """This module stores constants related to SageMaker JumpStart.""" from __future__ import absolute_import from typing import Set -from enum import Enum import boto3 +from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo @@ -118,52 +118,7 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" - -INFERENCE_ENTRYPOINT_SCRIPT_NAME = "inference.py" -TRAINING_ENTRYPOINT_SCRIPT_NAME = "transfer_learning.py" - - -class JumpStartScriptScope(str, Enum): - """Enum class for JumpStart script scopes.""" - - INFERENCE = "inference" - TRAINING = "training" - +INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" +TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py" SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope) - - -class ModelFramework(str, Enum): - """Enum class for JumpStart model framework. - - The ML framework as referenced in the prefix of the model ID. - This value does not necessarily correspond to the container name. - """ - - PYTORCH = "pytorch" - TENSORFLOW = "tensorflow" - MXNET = "mxnet" - HUGGINGFACE = "huggingface" - LIGHTGBM = "lightgbm" - CATBOOST = "catboost" - XGBOOST = "xgboost" - SKLEARN = "sklearn" - - -class VariableScope(str, Enum): - """Possible value of the ``scope`` attribute for a hyperparameter or environment variable. - - Used for hosting environment variables and training hyperparameters. - """ - - CONTAINER = "container" - ALGORITHM = "algorithm" - - -class JumpStartTag(str, Enum): - """Enum class for tag keys to apply to JumpStart models.""" - - INFERENCE_MODEL_URI = "aws-jumpstart-inference-model-uri" - INFERENCE_SCRIPT_URI = "aws-jumpstart-inference-script-uri" - TRAINING_MODEL_URI = "aws-jumpstart-training-model-uri" - TRAINING_SCRIPT_URI = "aws-jumpstart-training-script-uri" diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py new file mode 100644 index 0000000000..74708bd852 --- /dev/null +++ b/src/sagemaker/jumpstart/enums.py @@ -0,0 +1,77 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module stores enums related to SageMaker JumpStart.""" + +from __future__ import absolute_import + +from enum import Enum + + +class ModelFramework(str, Enum): + """Enum class for JumpStart model framework. + + The ML framework as referenced in the prefix of the model ID. + This value does not necessarily correspond to the container name. + """ + + PYTORCH = "pytorch" + TENSORFLOW = "tensorflow" + MXNET = "mxnet" + HUGGINGFACE = "huggingface" + LIGHTGBM = "lightgbm" + CATBOOST = "catboost" + XGBOOST = "xgboost" + SKLEARN = "sklearn" + + +class VariableScope(str, Enum): + """Possible value of the ``scope`` attribute for a hyperparameter or environment variable. + + Used for hosting environment variables and training hyperparameters. + """ + + CONTAINER = "container" + ALGORITHM = "algorithm" + + +class JumpStartScriptScope(str, Enum): + """Enum class for JumpStart script scopes.""" + + INFERENCE = "inference" + TRAINING = "training" + + +class HyperparameterValidationMode(str, Enum): + """Possible modes for validating hyperparameters.""" + + VALIDATE_PROVIDED = "validate_provided" + VALIDATE_ALGORITHM = "validate_algorithm" + VALIDATE_ALL = "validate_all" + + +class VariableTypes(str, Enum): + """Possible types for hyperparameters and environment variables.""" + + TEXT = "text" + INT = "int" + FLOAT = "float" + BOOL = "bool" + + +class JumpStartTag(str, Enum): + """Enum class for tag keys to apply to JumpStart models.""" + + INFERENCE_MODEL_URI = "aws-jumpstart-inference-model-uri" + INFERENCE_SCRIPT_URI = "aws-jumpstart-inference-script-uri" + TRAINING_MODEL_URI = "aws-jumpstart-training-model-uri" + TRAINING_SCRIPT_URI = "aws-jumpstart-training-script-uri" diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index 4fdb6e2534..769f8bc7a6 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -11,13 +11,24 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """This module stores exceptions related to SageMaker JumpStart.""" - from __future__ import absolute_import from typing import List, Optional from sagemaker.jumpstart.constants import JumpStartScriptScope +class JumpStartHyperparametersError(Exception): + """Exception raised for bad hyperparameters of a JumpStart model.""" + + def __init__( + self, + message: Optional[str] = None, + ): + self.message = message + + super().__init__(self.message) + + class VulnerableJumpStartModelError(Exception): """Exception raised when trying to access a JumpStart model specs flagged as vulnerable. 
diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index d5023010dd..7c36795652 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -181,6 +181,8 @@ class JumpStartHyperparameter(JumpStartDataHolderType): "scope", "min", "max", + "exclusive_min", + "exclusive_max", } def __init__(self, spec: Dict[str, Any]): @@ -215,6 +217,14 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if max_val is not None: self.max = max_val + exclusive_min_val = json_obj.get("exclusive_min") + if exclusive_min_val is not None: + self.exclusive_min = exclusive_min_val + + exclusive_max_val = json_obj.get("exclusive_max") + if exclusive_max_val is not None: + self.exclusive_max = exclusive_max_val + def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartHyperparameter object.""" json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 511fade585..16bdd9fc4f 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -17,7 +17,7 @@ from urllib.parse import urlparse from packaging.version import Version import sagemaker -from sagemaker.jumpstart import constants +from sagemaker.jumpstart import constants, enums from sagemaker.jumpstart import accessors from sagemaker.s3 import parse_s3_url from sagemaker.jumpstart.exceptions import ( @@ -200,13 +200,13 @@ def get_tag_value(tag_key: str, tag_array: List[Dict[str, str]]) -> str: def add_single_jumpstart_tag( - uri: str, tag_key: constants.JumpStartTag, curr_tags: Optional[List[Dict[str, str]]] + uri: str, tag_key: enums.JumpStartTag, curr_tags: Optional[List[Dict[str, str]]] ) -> Optional[List]: """Adds ``tag_key`` to ``curr_tags`` if ``uri`` corresponds to a JumpStart model. Args: uri (str): URI which may correspond to a JumpStart model. - tag_key (constants.JumpStartTag): Custom tag to apply to current tags if the URI + tag_key (enums.JumpStartTag): Custom tag to apply to current tags if the URI corresponds to a JumpStart model. curr_tags (Optional[List]): Current tags associated with ``Estimator`` or ``Model``. """ @@ -249,22 +249,22 @@ def add_jumpstart_tags( if inference_model_uri: tags = add_single_jumpstart_tag( - inference_model_uri, constants.JumpStartTag.INFERENCE_MODEL_URI, tags + inference_model_uri, enums.JumpStartTag.INFERENCE_MODEL_URI, tags ) if inference_script_uri: tags = add_single_jumpstart_tag( - inference_script_uri, constants.JumpStartTag.INFERENCE_SCRIPT_URI, tags + inference_script_uri, enums.JumpStartTag.INFERENCE_SCRIPT_URI, tags ) if training_model_uri: tags = add_single_jumpstart_tag( - training_model_uri, constants.JumpStartTag.TRAINING_MODEL_URI, tags + training_model_uri, enums.JumpStartTag.TRAINING_MODEL_URI, tags ) if training_script_uri: tags = add_single_jumpstart_tag( - training_script_uri, constants.JumpStartTag.TRAINING_SCRIPT_URI, tags + training_script_uri, enums.JumpStartTag.TRAINING_SCRIPT_URI, tags ) return tags @@ -280,7 +280,7 @@ def update_inference_tags_with_jumpstart_training_tags( training_tags (Optional[List[Dict[str, str]]]): Tags from training job. 
""" if training_tags: - for tag_key in constants.JumpStartTag: + for tag_key in enums.JumpStartTag: if tag_key_in_array(tag_key, training_tags): tag_value = get_tag_value(tag_key, training_tags) if inference_tags is None: diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py new file mode 100644 index 0000000000..57d5fe6d72 --- /dev/null +++ b/src/sagemaker/jumpstart/validators.py @@ -0,0 +1,226 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains validators related to SageMaker JumpStart.""" +from __future__ import absolute_import +from typing import Any, Dict, List, Optional +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME + +from sagemaker.jumpstart.enums import HyperparameterValidationMode, VariableScope, VariableTypes +from sagemaker.jumpstart import accessors as jumpstart_accessors +from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError +from sagemaker.jumpstart.types import JumpStartHyperparameter + + +def _validate_hyperparameter( + hyperparameter_name: str, + hyperparameter_value: Any, + hyperparameter_specs: List[JumpStartHyperparameter], +) -> None: + """Perform low-level hyperparameter validation on single parameter. + + Args: + hyperparameter_name (str): The name of the hyperparameter to validate. + hyperparameter_value (Any): The value of the hyperparemter to validate. + hyperparameter_specs (List[JumpStartHyperparameter]): List of ``JumpStartHyperparameter`` to + use when validating the hyperparameter. + + Raises: + JumpStartHyperparametersError: If the hyperparameter is not formatted correctly, + according to its specs in the model metadata. + """ + hyperparameter_spec = [ + spec for spec in hyperparameter_specs if spec.name == hyperparameter_name + ] + if len(hyperparameter_spec) == 0: + raise JumpStartHyperparametersError( + f"Unable to perform validation -- cannot find hyperparameter '{hyperparameter_name}' " + "in model specs." + ) + + if len(hyperparameter_spec) > 1: + raise JumpStartHyperparametersError( + f"Unable to perform validation -- found multiple hyperparameter " + f"'{hyperparameter_name}' in model specs." + ) + + hyperparameter_spec = hyperparameter_spec[0] + + if hyperparameter_spec.type == VariableTypes.BOOL.value: + if isinstance(hyperparameter_value, bool): + return + if not isinstance(hyperparameter_value, str): + raise JumpStartHyperparametersError( + f"Expecting boolean valued hyperparameter, but got '{str(hyperparameter_value)}'." + ) + if str(hyperparameter_value).lower() not in ["true", "false"]: + raise JumpStartHyperparametersError( + f"Expecting boolean valued hyperparameter, but got '{str(hyperparameter_value)}'." + ) + elif hyperparameter_spec.type == VariableTypes.TEXT.value: + if not isinstance(hyperparameter_value, str): + raise JumpStartHyperparametersError( + "Expecting text valued hyperparameter to have string type." 
+            )
+
+        if hasattr(hyperparameter_spec, "options"):
+            if hyperparameter_value not in hyperparameter_spec.options:
+                raise JumpStartHyperparametersError(
+                    f"Hyperparameter '{hyperparameter_name}' must have one of the following "
+                    f"values: {', '.join(hyperparameter_spec.options)}"
+                )
+
+        if hasattr(hyperparameter_spec, "min"):
+            if len(hyperparameter_value) < hyperparameter_spec.min:
+                raise JumpStartHyperparametersError(
+                    f"Hyperparameter '{hyperparameter_name}' must have length no less than "
+                    f"{hyperparameter_spec.min}"
+                )
+
+        if hasattr(hyperparameter_spec, "exclusive_min"):
+            if len(hyperparameter_value) <= hyperparameter_spec.exclusive_min:
+                raise JumpStartHyperparametersError(
+                    f"Hyperparameter '{hyperparameter_name}' must have length greater than "
+                    f"{hyperparameter_spec.exclusive_min}"
+                )
+
+        if hasattr(hyperparameter_spec, "max"):
+            if len(hyperparameter_value) > hyperparameter_spec.max:
+                raise JumpStartHyperparametersError(
+                    f"Hyperparameter '{hyperparameter_name}' must have length no greater than "
+                    f"{hyperparameter_spec.max}"
+                )
+
+        if hasattr(hyperparameter_spec, "exclusive_max"):
+            if len(hyperparameter_value) >= hyperparameter_spec.exclusive_max:
+                raise JumpStartHyperparametersError(
+                    f"Hyperparameter '{hyperparameter_name}' must have length less than "
+                    f"{hyperparameter_spec.exclusive_max}"
+                )
+
+    # validate numeric types
+    elif hyperparameter_spec.type in [VariableTypes.INT.value, VariableTypes.FLOAT.value]:
+        try:
+            numeric_hyperparam_value = float(hyperparameter_value)
+        except ValueError:
+            raise JumpStartHyperparametersError(
+                f"Hyperparameter '{hyperparameter_name}' must be numeric type "
+                f"('{hyperparameter_value}')."
+            )
+
+        if hyperparameter_spec.type == VariableTypes.INT.value:
+            hyperparameter_value_str = str(hyperparameter_value)
+            start_index = 0
+            if hyperparameter_value_str[0] in ["+", "-"]:
+                start_index = 1
+            if not hyperparameter_value_str[start_index:].isdigit():
+                raise JumpStartHyperparametersError(
+                    f"Hyperparameter '{hyperparameter_name}' must be integer type "
+                    f"('{hyperparameter_value}')."
+                )
+
+        if hasattr(hyperparameter_spec, "min"):
+            if numeric_hyperparam_value < hyperparameter_spec.min:
+                raise JumpStartHyperparametersError(
+                    f"Hyperparameter '{hyperparameter_name}' can be no less than "
+                    f"{hyperparameter_spec.min}."
+                )
+
+        if hasattr(hyperparameter_spec, "max"):
+            if numeric_hyperparam_value > hyperparameter_spec.max:
+                raise JumpStartHyperparametersError(
+                    f"Hyperparameter '{hyperparameter_name}' can be no greater than "
+                    f"{hyperparameter_spec.max}."
+                )
+
+        if hasattr(hyperparameter_spec, "exclusive_min"):
+            if numeric_hyperparam_value <= hyperparameter_spec.exclusive_min:
+                raise JumpStartHyperparametersError(
+                    f"Hyperparameter '{hyperparameter_name}' must be greater than "
+                    f"{hyperparameter_spec.exclusive_min}."
+                )
+
+        if hasattr(hyperparameter_spec, "exclusive_max"):
+            if numeric_hyperparam_value >= hyperparameter_spec.exclusive_max:
+                raise JumpStartHyperparametersError(
+                    f"Hyperparameter '{hyperparameter_name}' must be less than "
+                    f"{hyperparameter_spec.exclusive_max}."
+                )
+
+
+def validate_hyperparameters(
+    model_id: str,
+    model_version: str,
+    hyperparameters: Dict[str, Any],
+    validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
+    region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME,
+) -> None:
+    """Validate hyperparameters for JumpStart models.
+
+    Args:
+        model_id (str): Model ID of the model for which to validate hyperparameters.
+        model_version (str): Version of the model for which to validate hyperparameters.
+        hyperparameters (dict): Hyperparameters to validate.
+        validation_mode (HyperparameterValidationMode): Method of validation to use with
+            hyperparameters. If set to ``VALIDATE_PROVIDED``, only hyperparameters provided
+            to this function will be validated; missing hyperparameters will be ignored.
+            If set to ``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated.
+            If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated.
+        region (str): Region for which to validate hyperparameters. (Default: JumpStart
+            default region).
+
+    Raises:
+        JumpStartHyperparametersError: If the hyperparameters are not formatted correctly,
+            according to their metadata specs.
+
+    """
+
+    if validation_mode is None:
+        validation_mode = HyperparameterValidationMode.VALIDATE_PROVIDED
+
+    if region is None:
+        region = JUMPSTART_DEFAULT_REGION_NAME
+
+    model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
+        region=region, model_id=model_id, version=model_version
+    )
+    hyperparameters_specs = model_specs.hyperparameters
+
+    if validation_mode == HyperparameterValidationMode.VALIDATE_PROVIDED:
+        for hyperparam_name, hyperparam_value in hyperparameters.items():
+            _validate_hyperparameter(hyperparam_name, hyperparam_value, hyperparameters_specs)
+
+    elif validation_mode == HyperparameterValidationMode.VALIDATE_ALGORITHM:
+        for hyperparam in hyperparameters_specs:
+            if hyperparam.scope == VariableScope.ALGORITHM:
+                if hyperparam.name not in hyperparameters:
+                    raise JumpStartHyperparametersError(
+                        f"Cannot find algorithm hyperparameter for '{hyperparam.name}'."
+                    )
+                _validate_hyperparameter(
+                    hyperparam.name, hyperparameters[hyperparam.name], hyperparameters_specs
+                )
+
+    elif validation_mode == HyperparameterValidationMode.VALIDATE_ALL:
+        for hyperparam in hyperparameters_specs:
+            if hyperparam.name not in hyperparameters:
+                raise JumpStartHyperparametersError(
+                    f"Cannot find hyperparameter for '{hyperparam.name}'."
+                )
+            _validate_hyperparameter(
+                hyperparam.name, hyperparameters[hyperparam.name], hyperparameters_specs
+            )
+
+    else:
+        raise NotImplementedError(
+            f"Unable to handle validation for the mode '{validation_mode.value}'."
+        )
diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py
index 8894583f89..692fc89ae9 100644
--- a/src/sagemaker/model_uris.py
+++ b/src/sagemaker/model_uris.py
@@ -59,7 +59,7 @@ def retrieve(
         DeprecatedJumpStartModelError: If the version of the model is deprecated.
""" if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): - raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") + raise ValueError("Must specify `model_id` and `model_version` when retrieving model URIs.") return artifacts._retrieve_model_uri( model_id, diff --git a/tests/integ/sagemaker/jumpstart/script_mode_class/test_inference.py b/tests/integ/sagemaker/jumpstart/script_mode_class/test_inference.py index e05ae87a31..69ede8b8c8 100644 --- a/tests/integ/sagemaker/jumpstart/script_mode_class/test_inference.py +++ b/tests/integ/sagemaker/jumpstart/script_mode_class/test_inference.py @@ -14,7 +14,7 @@ import os from sagemaker import image_uris, model_uris, script_uris -from sagemaker.jumpstart.constants import INFERENCE_ENTRYPOINT_SCRIPT_NAME +from sagemaker.jumpstart.constants import INFERENCE_ENTRY_POINT_SCRIPT_NAME from sagemaker.model import Model from tests.integ.sagemaker.jumpstart.constants import ( ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, @@ -57,7 +57,7 @@ def test_jumpstart_inference_model_class(setup): image_uri=image_uri, model_data=model_uri, source_dir=script_uri, - entry_point=INFERENCE_ENTRYPOINT_SCRIPT_NAME, + entry_point=INFERENCE_ENTRY_POINT_SCRIPT_NAME, role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), enable_network_isolation=True, diff --git a/tests/integ/sagemaker/jumpstart/script_mode_class/test_transfer_learning.py b/tests/integ/sagemaker/jumpstart/script_mode_class/test_transfer_learning.py index 46de1b5d27..29b16cf9a5 100644 --- a/tests/integ/sagemaker/jumpstart/script_mode_class/test_transfer_learning.py +++ b/tests/integ/sagemaker/jumpstart/script_mode_class/test_transfer_learning.py @@ -16,9 +16,9 @@ from sagemaker import hyperparameters, image_uris, model_uris, script_uris from sagemaker.estimator import Estimator from sagemaker.jumpstart.constants import ( - INFERENCE_ENTRYPOINT_SCRIPT_NAME, + INFERENCE_ENTRY_POINT_SCRIPT_NAME, JUMPSTART_DEFAULT_REGION_NAME, - TRAINING_ENTRYPOINT_SCRIPT_NAME, + TRAINING_ENTRY_POINT_SCRIPT_NAME, ) from sagemaker.jumpstart.utils import get_jumpstart_content_bucket from sagemaker.utils import name_from_base @@ -70,7 +70,7 @@ def test_jumpstart_transfer_learning_estimator_class(setup): image_uri=image_uri, source_dir=script_uri, model_uri=model_uri, - entry_point=TRAINING_ENTRYPOINT_SCRIPT_NAME, + entry_point=TRAINING_ENTRY_POINT_SCRIPT_NAME, role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), enable_network_isolation=True, @@ -111,7 +111,7 @@ def test_jumpstart_transfer_learning_estimator_class(setup): estimator.deploy( initial_instance_count=instance_count, instance_type=inference_instance_type, - entry_point=INFERENCE_ENTRYPOINT_SCRIPT_NAME, + entry_point=INFERENCE_ENTRY_POINT_SCRIPT_NAME, image_uri=image_uri, source_dir=script_uri, endpoint_name=endpoint_name, diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py new file mode 100644 index 0000000000..ddeeccba1d --- /dev/null +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -0,0 +1,489 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + + +from mock.mock import patch +import pytest + +from sagemaker import hyperparameters +from sagemaker.jumpstart.enums import HyperparameterValidationMode +from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError +from sagemaker.jumpstart.types import JumpStartHyperparameter + +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_validate_provided_hyperparameters(patched_get_model_specs): + def add_options_to_hyperparameter(*largs, **kwargs): + spec = get_spec_from_base_spec(*largs, **kwargs) + spec.hyperparameters.extend( + [ + JumpStartHyperparameter( + { + "name": "penalty", + "type": "text", + "default": "l2", + "options": ["l1", "l2", "elasticnet", "none"], + "scope": "algorithm", + } + ), + JumpStartHyperparameter( + { + "name": "test_bool_param", + "type": "bool", + "default": True, + "scope": "algorithm", + } + ), + JumpStartHyperparameter( + { + "name": "test_exclusive_min_param", + "type": "float", + "default": 4, + "scope": "algorithm", + "exclusive_min": 1, + } + ), + JumpStartHyperparameter( + { + "name": "test_exclusive_max_param", + "type": "int", + "default": -4, + "scope": "algorithm", + "exclusive_max": 4, + } + ), + JumpStartHyperparameter( + { + "name": "test_exclusive_min_param_text", + "type": "text", + "default": "hello", + "scope": "algorithm", + "exclusive_min": 1, + } + ), + JumpStartHyperparameter( + { + "name": "test_exclusive_max_param_text", + "type": "text", + "default": "hello", + "scope": "algorithm", + "exclusive_max": 6, + } + ), + JumpStartHyperparameter( + { + "name": "test_min_param_text", + "type": "text", + "default": "hello", + "scope": "algorithm", + "min": 1, + } + ), + JumpStartHyperparameter( + { + "name": "test_max_param_text", + "type": "text", + "default": "hello", + "scope": "algorithm", + "max": 6, + } + ), + ] + ) + return spec + + patched_get_model_specs.side_effect = add_options_to_hyperparameter + + model_id, model_version = "pytorch-eqa-bert-base-cased", "*" + region = "us-west-2" + + hyperparameter_to_test = { + "adam-learning-rate": "0.05", + "batch-size": "4", + "epochs": "3", + "penalty": "l2", + "test_bool_param": False, + "test_exclusive_min_param": 4, + "test_exclusive_max_param": -4, + "test_exclusive_min_param_text": "hello", + "test_exclusive_max_param_text": "hello", + "test_min_param_text": "hello", + "test_max_param_text": "hello", + } + + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version=model_version + ) + + patched_get_model_specs.reset_mock() + + del hyperparameter_to_test["adam-learning-rate"] + + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + hyperparameter_to_test["batch-size"] = "0" + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + hyperparameter_to_test["batch-size"] = "-1" + with 
pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + hyperparameter_to_test["batch-size"] = "-1.5" + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + hyperparameter_to_test["batch-size"] = "1.5" + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + hyperparameter_to_test["batch-size"] = "99999" + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + hyperparameter_to_test["batch-size"] = 5 + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + original_bool_val = hyperparameter_to_test["test_bool_param"] + for val in ["False", "fAlSe", "false", "True", "TrUe", "true", True, False]: + hyperparameter_to_test["test_bool_param"] = val + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + for val in [None, "", 5, "Truesday", "Falsehood"]: + hyperparameter_to_test["test_bool_param"] = val + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + hyperparameter_to_test["test_bool_param"] = original_bool_val + + original_exclusive_min_val = hyperparameter_to_test["test_exclusive_min_param"] + for val in [2, 1 + 1e-9]: + hyperparameter_to_test["test_exclusive_min_param"] = val + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + for val in [1, 1 - 1e-99, -99]: + hyperparameter_to_test["test_exclusive_min_param"] = val + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + hyperparameter_to_test["test_exclusive_min_param"] = original_exclusive_min_val + + original_exclusive_max_val = hyperparameter_to_test["test_exclusive_max_param"] + for val in [-2, 2, 3]: + hyperparameter_to_test["test_exclusive_max_param"] = val + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + for val in [4, 5, 99]: + hyperparameter_to_test["test_exclusive_max_param"] = val + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + hyperparameter_to_test["test_exclusive_max_param"] = original_exclusive_max_val + + original_exclusive_max_text_val = hyperparameter_to_test["test_exclusive_max_param_text"] + for val in ["", "sd", "12345"]: + hyperparameter_to_test["test_exclusive_max_param_text"] = val + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + for val in ["123456", "123456789"]: + 
hyperparameter_to_test["test_exclusive_max_param_text"] = val + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + hyperparameter_to_test["test_exclusive_max_param_text"] = original_exclusive_max_text_val + + original_max_text_val = hyperparameter_to_test["test_max_param_text"] + for val in ["", "sd", "12345", "123456"]: + hyperparameter_to_test["test_max_param_text"] = val + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + for val in ["1234567", "123456789"]: + hyperparameter_to_test["test_max_param_text"] = val + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + hyperparameter_to_test["test_max_param_text"] = original_max_text_val + + original_exclusive_min_text_val = hyperparameter_to_test["test_exclusive_min_param_text"] + for val in ["12", "sdfs", "12345dsfs"]: + hyperparameter_to_test["test_exclusive_min_param_text"] = val + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + for val in ["1", "d", ""]: + hyperparameter_to_test["test_exclusive_min_param_text"] = val + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + hyperparameter_to_test["test_exclusive_min_param_text"] = original_exclusive_min_text_val + + original_min_text_val = hyperparameter_to_test["test_min_param_text"] + for val in ["1", "s", "12", "sdfs", "12345dsfs"]: + hyperparameter_to_test["test_min_param_text"] = val + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + for val in [""]: + hyperparameter_to_test["test_min_param_text"] = val + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + hyperparameter_to_test["test_min_param_text"] = original_min_text_val + + del hyperparameter_to_test["batch-size"] + hyperparameter_to_test["penalty"] = "blah" + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + hyperparameter_to_test["penalty"] = "elasticnet" + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_validate_algorithm_hyperparameters(patched_get_model_specs): + def add_options_to_hyperparameter(*largs, **kwargs): + spec = get_spec_from_base_spec(*largs, **kwargs) + spec.hyperparameters.append( + JumpStartHyperparameter( + { + "name": "penalty", + "type": "text", + "default": "l2", + "options": ["l1", "l2", "elasticnet", "none"], + "scope": "algorithm", + } + ) + ) + return spec + + patched_get_model_specs.side_effect = add_options_to_hyperparameter + + model_id, model_version = "pytorch-eqa-bert-base-cased", "*" + region = "us-west-2" + + 
hyperparameter_to_test = { + "adam-learning-rate": "0.05", + "batch-size": "4", + "epochs": "3", + "penalty": "l2", + } + + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + validation_mode=HyperparameterValidationMode.VALIDATE_ALGORITHM, + ) + + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version=model_version + ) + + patched_get_model_specs.reset_mock() + + hyperparameter_to_test["random-param"] = "random_val" + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + validation_mode=HyperparameterValidationMode.VALIDATE_ALGORITHM, + ) + + del hyperparameter_to_test["adam-learning-rate"] + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + validation_mode=HyperparameterValidationMode.VALIDATE_ALGORITHM, + ) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_spec_from_base_spec + + model_id, model_version = "pytorch-eqa-bert-base-cased", "*" + region = "us-west-2" + + hyperparameter_to_test = { + "adam-learning-rate": "0.05", + "batch-size": "4", + "epochs": "3", + "sagemaker_container_log_level": "20", + "sagemaker_program": "transfer_learning.py", + "sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz", + } + + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + validation_mode=HyperparameterValidationMode.VALIDATE_ALL, + ) + + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version=model_version + ) + + patched_get_model_specs.reset_mock() + + del hyperparameter_to_test["sagemaker_submit_directory"] + + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + validation_mode=HyperparameterValidationMode.VALIDATE_ALL, + ) + + hyperparameter_to_test[ + "sagemaker_submit_directory" + ] = "/opt/ml/input/data/code/sourcedir.tar.gz" + del hyperparameter_to_test["epochs"] + + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + validation_mode=HyperparameterValidationMode.VALIDATE_ALL, + ) + + hyperparameter_to_test["epochs"] = "3" + + hyperparameter_to_test["other_hyperparam"] = "blah" + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + validation_mode=HyperparameterValidationMode.VALIDATE_ALL, + ) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 76c6161469..1877ede054 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -18,9 +18,9 @@ from sagemaker.jumpstart.constants import ( JUMPSTART_BUCKET_NAME_SET, JUMPSTART_REGION_NAME_SET, - JumpStartTag, JumpStartScriptScope, ) +from sagemaker.jumpstart.enums import JumpStartTag from sagemaker.jumpstart.exceptions import ( 
DeprecatedJumpStartModelError, VulnerableJumpStartModelError, diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 7ffea2b69f..42effef480 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -19,7 +19,8 @@ import sagemaker from sagemaker.model import FrameworkModel, Model from sagemaker.huggingface.model import HuggingFaceModel -from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JumpStartTag +from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET +from sagemaker.jumpstart.enums import JumpStartTag from sagemaker.mxnet.model import MXNetModel from sagemaker.pytorch.model import PyTorchModel from sagemaker.sklearn.model import SKLearnModel diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 792faa61b0..37bdc4d8ed 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -24,7 +24,8 @@ from botocore.exceptions import ClientError from mock import ANY, MagicMock, Mock, patch from sagemaker.huggingface.estimator import HuggingFace -from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JumpStartTag +from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET +from sagemaker.jumpstart.enums import JumpStartTag import sagemaker.local from sagemaker import TrainingInput, utils, vpc_utils
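For completeness, a sketch of how the new ``exclusive_min``/``exclusive_max`` slots on ``JumpStartHyperparameter`` (added in ``types.py`` above) drive the strict-bound checks in the validator. The spec dict is illustrative, mirroring those constructed in the unit tests; ``_validate_hyperparameter`` is an internal helper, used here only to show the semantics:

    from sagemaker.jumpstart.types import JumpStartHyperparameter
    from sagemaker.jumpstart.validators import _validate_hyperparameter
    from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError

    spec = JumpStartHyperparameter(
        {
            "name": "dropout-rate",  # illustrative name, not from this patch
            "type": "float",
            "default": 0.1,
            "scope": "algorithm",
            "exclusive_min": 0,  # accepted values must satisfy value > 0
            "exclusive_max": 1,  # accepted values must satisfy value < 1
        }
    )
    assert spec.to_json()["exclusive_max"] == 1  # new fields round-trip via to_json

    _validate_hyperparameter("dropout-rate", "0.5", [spec])  # in (0, 1): passes
    try:
        _validate_hyperparameter("dropout-rate", "1", [spec])  # 1 >= exclusive_max
    except JumpStartHyperparametersError as err:
        print(err)  # Hyperparameter 'dropout-rate' must be less than 1.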