From f4b05362adb943eb4ef80ce87ca3e1dfbe1cfaa8 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 19 Jan 2022 17:16:01 +0000 Subject: [PATCH 1/7] feat: hyperparameter validation --- src/sagemaker/hyperparameters.py | 46 ++- src/sagemaker/jumpstart/artifacts.py | 2 + src/sagemaker/jumpstart/constants.py | 28 -- src/sagemaker/jumpstart/enums.py | 44 +++ src/sagemaker/jumpstart/exceptions.py | 27 ++ src/sagemaker/jumpstart/validators.py | 153 ++++++++++ .../jumpstart/test_validate.py | 278 ++++++++++++++++++ 7 files changed, 549 insertions(+), 29 deletions(-) create mode 100644 src/sagemaker/jumpstart/enums.py create mode 100644 src/sagemaker/jumpstart/exceptions.py create mode 100644 src/sagemaker/jumpstart/validators.py create mode 100644 tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 621ed228e0..5b5cb4f80f 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -15,10 +15,12 @@ from __future__ import absolute_import import logging -from typing import Dict +from typing import Dict, Optional from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.enums import HyperparameterValidationMode +from sagemaker.jumpstart.validators import validate_hyperparameters logger = logging.getLogger(__name__) @@ -56,3 +58,45 @@ def retrieve_default( return artifacts._retrieve_default_hyperparameters( model_id, model_version, region, include_container_hyperparameters ) + + +def validate( + region: Optional[str] = None, + model_id: Optional[str] = None, + model_version: Optional[str] = None, + hyperparameters: Optional[dict] = None, + validation_mode: Optional[HyperparameterValidationMode] = None, +): + """Validate hyperparameters for models. + + Args: + region (str): Region for which to validate hyperparameters. (Default: None). 
+ model_id (str): Model ID of the model for which to validate hyperparameters. + (Default: None) + model_version (str): Version of the model for which to validate hyperparameters. + (Default: None) + hyperparameters (dict): Hyperparameters to validate. + (Default: None) + validation_mode (HyperparameterValidationMode): Method of validation to use with + hyperparameters. If set to ``VALIDATE_PROVIDED``, only hyperparameters provided + to this function will be validated, the missing hyperparameters will be ignored. + If set to``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated. + If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated. + (Default: None) + + + """ + + if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): + raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") + + if hyperparameters is None: + raise ValueError("Must specify hyperparameters.") + + return validate_hyperparameters( + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameters, + validation_mode=validation_mode, + region=region, + ) diff --git a/src/sagemaker/jumpstart/artifacts.py b/src/sagemaker/jumpstart/artifacts.py index 2919fe44b2..babfdef637 100644 --- a/src/sagemaker/jumpstart/artifacts.py +++ b/src/sagemaker/jumpstart/artifacts.py @@ -19,6 +19,8 @@ INFERENCE, TRAINING, SUPPORTED_JUMPSTART_SCOPES, +) +from sagemaker.jumpstart.enums import ( ModelFramework, VariableScope, ) diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index aedce0e0da..9ae8b71f44 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -13,7 +13,6 @@ """This module stores constants related to SageMaker JumpStart.""" from __future__ import absolute_import from typing import Set -from enum import Enum import boto3 from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo @@ -122,30 +121,3 @@ 
INFERENCE_ENTRYPOINT_SCRIPT_NAME = "inference.py" TRAINING_ENTRYPOINT_SCRIPT_NAME = "transfer_learning.py" - - -class ModelFramework(str, Enum): - """Enum class for JumpStart model framework. - - The ML framework as referenced in the prefix of the model ID. - This value does not necessarily correspond to the container name. - """ - - PYTORCH = "pytorch" - TENSORFLOW = "tensorflow" - MXNET = "mxnet" - HUGGINGFACE = "huggingface" - LIGHTGBM = "lightgbm" - CATBOOST = "catboost" - XGBOOST = "xgboost" - SKLEARN = "sklearn" - - -class VariableScope(str, Enum): - """Possible value of the ``scope`` attribute for a hyperparameter or environment variable. - - Used for hosting environment variables and training hyperparameters. - """ - - CONTAINER = "container" - ALGORITHM = "algorithm" diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py new file mode 100644 index 0000000000..93714bd945 --- /dev/null +++ b/src/sagemaker/jumpstart/enums.py @@ -0,0 +1,44 @@ +from enum import Enum + + +class ModelFramework(str, Enum): + """Enum class for JumpStart model framework. + + The ML framework as referenced in the prefix of the model ID. + This value does not necessarily correspond to the container name. + """ + + PYTORCH = "pytorch" + TENSORFLOW = "tensorflow" + MXNET = "mxnet" + HUGGINGFACE = "huggingface" + LIGHTGBM = "lightgbm" + CATBOOST = "catboost" + XGBOOST = "xgboost" + SKLEARN = "sklearn" + + +class VariableScope(str, Enum): + """Possible value of the ``scope`` attribute for a hyperparameter or environment variable. + + Used for hosting environment variables and training hyperparameters. 
+ """ + + CONTAINER = "container" + ALGORITHM = "algorithm" + + +class HyperparameterValidationMode(str, Enum): + """Possible modes for validating hyperparameters.""" + + VALIDATE_PROVIDED = "validate_provided" + VALIDATE_ALGORITHM = "validate_algorithm" + VALIDATE_ALL = "validate_all" + + +class VariableTypes(str, Enum): + """Possible types for hyperparameters and environment variables.""" + + TEXT = "text" + INT = "int" + FLOAT = "float" diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py new file mode 100644 index 0000000000..9693db71bd --- /dev/null +++ b/src/sagemaker/jumpstart/exceptions.py @@ -0,0 +1,27 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores exceptions related to SageMaker JumpStart.""" + +from typing import Optional + + +class JumpStartHyperparametersError(Exception): + """Exception raised for errors with hyperparameters for JumpStart models.""" + + def __init__( + self, + message: Optional[str] = None, + ): + self.message = message + + super().__init__(self.message) diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py new file mode 100644 index 0000000000..87a57e9021 --- /dev/null +++ b/src/sagemaker/jumpstart/validators.py @@ -0,0 +1,153 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains validators related to SageMaker JumpStart.""" +from __future__ import absolute_import +from typing import Any, List, Optional +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME + +from sagemaker.jumpstart.enums import HyperparameterValidationMode, VariableScope, VariableTypes +from sagemaker.jumpstart import accessors as jumpstart_accessors +from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError +from sagemaker.jumpstart.types import JumpStartHyperparameter + + +def _validate_hyperparameter( + hyperparameter_name: str, + hyperparameter_value: Any, + hyperparameter_specs: List[JumpStartHyperparameter], +): + """Perform low-level hyperparameter validation on single parameter. + + Args: + hyperparameter_name (str): The name of the hyperparameter to validate. + hyperparameter_value (Any): The value of the hyperparemter to validate. + hyperparameter_specs (List[JumpStartHyperparameter]): List of ``JumpStartHyperparameter`` to + use when validating the hyperparameter. + """ + hyperparameter_spec = [ + spec for spec in hyperparameter_specs if spec.name == hyperparameter_name + ] + if len(hyperparameter_spec) == 0: + raise JumpStartHyperparametersError( + f"Unable to perform validation -- cannot find hyperparameter '{hyperparameter_name}' in model specs." 
+ ) + hyperparameter_spec = hyperparameter_spec[0] + + if hyperparameter_spec.type == VariableTypes.TEXT.value: + if type(hyperparameter_value) != str: + raise JumpStartHyperparametersError( + f"Expecting text valued hyperparameter to have string type." + ) + + if getattr(hyperparameter_spec, "options", None): + if hyperparameter_value not in hyperparameter_spec.options: + raise JumpStartHyperparametersError( + f"Hyperparameter '{hyperparameter_name}' must have one of the following values: " + ", ".join(hyperparameter_spec.options) + ) + + # validate numeric types + if hyperparameter_spec.type in [VariableTypes.INT.value, VariableTypes.FLOAT.value]: + try: + numeric_hyperparam_value = float(hyperparameter_value) + except ValueError: + raise JumpStartHyperparametersError( + f"Hyperparameter '{hyperparameter_name}' must be numeric type ('{hyperparameter_value}')." + ) + + if hyperparameter_spec.type == VariableTypes.INT.value: + hyperparameter_value_str = str(hyperparameter_value) + start_index = 0 + if hyperparameter_value_str[0] in ["+", "-"]: + start_index = 1 + if not hyperparameter_value_str[start_index:].isdigit(): + raise JumpStartHyperparametersError( + f"Hyperparameter '{hyperparameter_name}' must be integer type ('{hyperparameter_value}')." + ) + + if getattr(hyperparameter_spec, "min", None): + if numeric_hyperparam_value < hyperparameter_spec.min: + raise JumpStartHyperparametersError( + f"Hyperparameter '{hyperparameter_name}' can be no less than {hyperparameter_spec.min}." + ) + + if getattr(hyperparameter_spec, "max", None): + if numeric_hyperparam_value > hyperparameter_spec.max: + raise JumpStartHyperparametersError( + f"Hyperparameter '{hyperparameter_name}' can be no greater than {hyperparameter_spec.max}." 
+ ) + + +def validate_hyperparameters( + model_id: str, + model_version: str, + hyperparameters: dict, + validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, + region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME, +): + """Validate hyperparameters for JumpStart models. + + Args: + model_id (str): Model ID of the model for which to validate hyperparameters. + model_version (str): Version of the model for which to validate hyperparameters. + hyperparameters (dict): Hyperparameters to validate. + validation_mode (HyperparameterValidationMode): Method of validation to use with + hyperparameters. If set to ``VALIDATE_PROVIDED``, only hyperparameters provided + to this function will be validated, the missing hyperparameters will be ignored. + If set to``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated. + If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated. + region (str): Region for which to validate hyperparameters. (Default: JumpStart + default region). + + """ + + if validation_mode is None: + validation_mode = HyperparameterValidationMode.VALIDATE_PROVIDED + + if region is None: + region = JUMPSTART_DEFAULT_REGION_NAME + + model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( + region=region, model_id=model_id, version=model_version + ) + hyperparameters_specs = model_specs.hyperparameters + + if validation_mode == HyperparameterValidationMode.VALIDATE_PROVIDED: + for hyperparam_name, hyperparam_value in hyperparameters.items(): + _validate_hyperparameter(hyperparam_name, hyperparam_value, hyperparameters_specs) + + elif validation_mode == HyperparameterValidationMode.VALIDATE_ALGORITHM: + for hyperparam in hyperparameters_specs: + if hyperparam.scope == VariableScope.ALGORITHM: + if hyperparam.name not in hyperparameters: + raise JumpStartHyperparametersError( + f"Cannot find algorithm hyperparameter for '{hyperparam.name}'." 
+ ) + _validate_hyperparameter( + hyperparam.name, hyperparameters[hyperparam.name], hyperparameters_specs + ) + + elif validation_mode == HyperparameterValidationMode.VALIDATE_ALL: + for hyperparam in hyperparameters_specs: + if hyperparam.name not in hyperparameters: + raise JumpStartHyperparametersError( + f"Cannot find hyperparameter for '{hyperparam.name}'." + ) + _validate_hyperparameter( + hyperparam.name, hyperparameters[hyperparam.name], hyperparameters_specs + ) + + else: + raise NotImplementedError( + f"Unable to handle validation for the mode '{validation_mode.value}'." + ) diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py new file mode 100644 index 0000000000..c73fe76b1e --- /dev/null +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -0,0 +1,278 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + + +from mock.mock import patch +import pytest + +from sagemaker import hyperparameters +from sagemaker.jumpstart.enums import HyperparameterValidationMode +from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError +from sagemaker.jumpstart.types import JumpStartHyperparameter + +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_validate_provided_hyperparameters(patched_get_model_specs): + def add_options_to_hyperparameter(*largs, **kwargs): + spec = get_spec_from_base_spec(*largs, **kwargs) + spec.hyperparameters.append( + JumpStartHyperparameter( + { + "name": "penalty", + "type": "text", + "default": "l2", + "options": ["l1", "l2", "elasticnet", "none"], + "scope": "algorithm", + } + ) + ) + return spec + + patched_get_model_specs.side_effect = add_options_to_hyperparameter + + model_id, model_version = "pytorch-eqa-bert-base-cased", "*" + region = "us-west-2" + + hyperparameter_to_test = { + "adam-learning-rate": "0.05", + "batch-size": "4", + "epochs": "3", + "penalty": "l2", + } + + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version=model_version + ) + + patched_get_model_specs.reset_mock() + + del hyperparameter_to_test["adam-learning-rate"] + + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + hyperparameter_to_test["batch-size"] = "0" + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + hyperparameter_to_test["batch-size"] = "-1" + with 
pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + hyperparameter_to_test["batch-size"] = "-1.5" + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + hyperparameter_to_test["batch-size"] = "1.5" + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + hyperparameter_to_test["batch-size"] = "99999" + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + hyperparameter_to_test["batch-size"] = 5 + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + del hyperparameter_to_test["batch-size"] + hyperparameter_to_test["penalty"] = "blah" + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + hyperparameter_to_test["penalty"] = "elasticnet" + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_validate_algorithm_hyperparameters(patched_get_model_specs): + def add_options_to_hyperparameter(*largs, **kwargs): + spec = get_spec_from_base_spec(*largs, **kwargs) + spec.hyperparameters.append( + JumpStartHyperparameter( + { + "name": "penalty", + "type": "text", + "default": "l2", + 
"options": ["l1", "l2", "elasticnet", "none"], + "scope": "algorithm", + } + ) + ) + return spec + + patched_get_model_specs.side_effect = add_options_to_hyperparameter + + model_id, model_version = "pytorch-eqa-bert-base-cased", "*" + region = "us-west-2" + + hyperparameter_to_test = { + "adam-learning-rate": "0.05", + "batch-size": "4", + "epochs": "3", + "penalty": "l2", + } + + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + validation_mode=HyperparameterValidationMode.VALIDATE_ALGORITHM, + ) + + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version=model_version + ) + + patched_get_model_specs.reset_mock() + + hyperparameter_to_test["random-param"] = "random_val" + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + validation_mode=HyperparameterValidationMode.VALIDATE_ALGORITHM, + ) + + del hyperparameter_to_test["adam-learning-rate"] + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + validation_mode=HyperparameterValidationMode.VALIDATE_ALGORITHM, + ) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_spec_from_base_spec + + model_id, model_version = "pytorch-eqa-bert-base-cased", "*" + region = "us-west-2" + + hyperparameter_to_test = { + "adam-learning-rate": "0.05", + "batch-size": "4", + "epochs": "3", + "sagemaker_container_log_level": "20", + "sagemaker_program": "transfer_learning.py", + "sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz", + } + + hyperparameters.validate( + region=region, + model_id=model_id, + 
model_version=model_version, + hyperparameters=hyperparameter_to_test, + validation_mode=HyperparameterValidationMode.VALIDATE_ALL, + ) + + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version=model_version + ) + + patched_get_model_specs.reset_mock() + + del hyperparameter_to_test["sagemaker_submit_directory"] + + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + validation_mode=HyperparameterValidationMode.VALIDATE_ALL, + ) + + hyperparameter_to_test[ + "sagemaker_submit_directory" + ] = "/opt/ml/input/data/code/sourcedir.tar.gz" + del hyperparameter_to_test["epochs"] + + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + validation_mode=HyperparameterValidationMode.VALIDATE_ALL, + ) + + hyperparameter_to_test["epochs"] = "3" + + hyperparameter_to_test["other_hyperparam"] = "blah" + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + validation_mode=HyperparameterValidationMode.VALIDATE_ALL, + ) From 8f4aecffab79aee8100d6c0ecceb72beb3992a81 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Sat, 22 Jan 2022 01:18:02 +0000 Subject: [PATCH 2/7] change: improve jumpstart hyperparam validation logic --- src/sagemaker/environment_variables.py | 4 +- src/sagemaker/hyperparameters.py | 11 +- src/sagemaker/jumpstart/enums.py | 16 +++ src/sagemaker/jumpstart/exceptions.py | 3 +- src/sagemaker/jumpstart/types.py | 10 ++ src/sagemaker/jumpstart/validators.py | 107 ++++++++++++++--- src/sagemaker/model_uris.py | 2 +- .../jumpstart/test_validate.py | 111 ++++++++++++++++-- 8 files changed, 232 insertions(+), 32 deletions(-) diff --git a/src/sagemaker/environment_variables.py 
b/src/sagemaker/environment_variables.py index d4646d2617..21da25a110 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -44,7 +44,9 @@ def retrieve_default( ValueError: If the combination of arguments specified is not supported. """ if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): - raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") + raise ValueError( + "Must specify `model_id` and `model_version` when retrieving environment variables." + ) # mypy type checking require these assertions assert model_id is not None diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 5b5cb4f80f..8cf94e7153 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -53,7 +53,9 @@ def retrieve_default( ValueError: If the combination of arguments specified is not supported. """ if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): - raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") + raise ValueError( + "Must specify `model_id` and `model_version` when retrieving hyperparameters." + ) return artifacts._retrieve_default_hyperparameters( model_id, model_version, region, include_container_hyperparameters @@ -84,11 +86,16 @@ def validate( If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated. (Default: None) + Raises: + JumpStartHyperparametersError: If the hyperparameter is not formatted correctly, + according to its specs in the model metadata. """ if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): - raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") + raise ValueError( + "Must specify `model_id` and `model_version` when validating hyperparameters." 
+ ) if hyperparameters is None: raise ValueError("Must specify hyperparameters.") diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index 93714bd945..7bca1b20e1 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -1,3 +1,18 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores enums related to SageMaker JumpStart.""" +from __future__ import absolute_import + from enum import Enum @@ -42,3 +57,4 @@ class VariableTypes(str, Enum): TEXT = "text" INT = "int" FLOAT = "float" + BOOL = "bool" diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index 9693db71bd..eec21910f4 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -11,12 +11,13 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
"""This module stores exceptions related to SageMaker JumpStart.""" +from __future__ import absolute_import from typing import Optional class JumpStartHyperparametersError(Exception): - """Exception raised for errors with hyperparameters for JumpStart models.""" + """Exception raised for bad hyperparameters of a JumpStart model.""" def __init__( self, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 9e4f224ba2..ce341eaf50 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -181,6 +181,8 @@ class JumpStartHyperparameter(JumpStartDataHolderType): "scope", "min", "max", + "exclusive_min", + "exclusive_max", } def __init__(self, spec: Dict[str, Any]): @@ -215,6 +217,14 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if max_val is not None: self.max = max_val + exclusive_min_val = json_obj.get("exclusive_min") + if exclusive_min_val is not None: + self.exclusive_min = exclusive_min_val + + exclusive_max_val = json_obj.get("exclusive_max") + if exclusive_max_val is not None: + self.exclusive_max = exclusive_max_val + def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartHyperparameter object.""" json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index 87a57e9021..a9ee23142c 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. 
"""This module contains validators related to SageMaker JumpStart.""" from __future__ import absolute_import -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.enums import HyperparameterValidationMode, VariableScope, VariableTypes @@ -33,36 +33,88 @@ def _validate_hyperparameter( hyperparameter_value (Any): The value of the hyperparemter to validate. hyperparameter_specs (List[JumpStartHyperparameter]): List of ``JumpStartHyperparameter`` to use when validating the hyperparameter. + + Raises: + JumpStartHyperparametersError: If the hyperparameter is not formatted correctly, + according to its specs in the model metadata. """ hyperparameter_spec = [ spec for spec in hyperparameter_specs if spec.name == hyperparameter_name ] if len(hyperparameter_spec) == 0: raise JumpStartHyperparametersError( - f"Unable to perform validation -- cannot find hyperparameter '{hyperparameter_name}' in model specs." + f"Unable to perform validation -- cannot find hyperparameter '{hyperparameter_name}' " + "in model specs." + ) + + if len(hyperparameter_spec) > 1: + raise JumpStartHyperparametersError( + f"Unable to perform validation -- found multiple hyperparameter " + f"'{hyperparameter_name}' in model specs." ) + hyperparameter_spec = hyperparameter_spec[0] - if hyperparameter_spec.type == VariableTypes.TEXT.value: - if type(hyperparameter_value) != str: + if hyperparameter_spec.type == VariableTypes.BOOL.value: + if isinstance(hyperparameter_value, bool): + return + if not isinstance(hyperparameter_value, str): + raise JumpStartHyperparametersError( + f"Expecting boolean valued hyperparameter, but got '{str(hyperparameter_value)}'." + ) + if str(hyperparameter_value).lower() not in ["true", "false"]: raise JumpStartHyperparametersError( - f"Expecting text valued hyperparameter to have string type." 
+ f"Expecting boolean valued hyperparameter, but got '{str(hyperparameter_value)}'." + ) + elif hyperparameter_spec.type == VariableTypes.TEXT.value: + if not isinstance(hyperparameter_value, str): + raise JumpStartHyperparametersError( + "Expecting text valued hyperparameter to have string type." ) - if getattr(hyperparameter_spec, "options", None): + if hasattr(hyperparameter_spec, "options"): if hyperparameter_value not in hyperparameter_spec.options: raise JumpStartHyperparametersError( - f"Hyperparameter '{hyperparameter_name}' must have one of the following values: " - ", ".join(hyperparameter_spec.options) + f"Hyperparameter '{hyperparameter_name}' must have one of the following " + f"values: {', '.join(hyperparameter_spec.options)}" + ) + + if hasattr(hyperparameter_spec, "min"): + if len(hyperparameter_value) < hyperparameter_spec.min: + raise JumpStartHyperparametersError( + f"Hyperparameter '{hyperparameter_name}' must have length no less than " + f"{hyperparameter_spec.min}" + ) + + if hasattr(hyperparameter_spec, "exclusive_min"): + if len(hyperparameter_value) <= hyperparameter_spec.exclusive_min: + raise JumpStartHyperparametersError( + f"Hyperparameter '{hyperparameter_name}' must have length greater than " + f"{hyperparameter_spec.exclusive_min}" + ) + + if hasattr(hyperparameter_spec, "max"): + if len(hyperparameter_value) > hyperparameter_spec.max: + raise JumpStartHyperparametersError( + f"Hyperparameter '{hyperparameter_name}' must have length no greater than " + f"{hyperparameter_spec.max}" + ) + + if hasattr(hyperparameter_spec, "exclusive_max"): + if len(hyperparameter_value) >= hyperparameter_spec.exclusive_max: + raise JumpStartHyperparametersError( + f"Hyperparameter '{hyperparameter_name}' must have length less than " + f"{hyperparameter_spec.exclusive_max}" ) # validate numeric types - if hyperparameter_spec.type in [VariableTypes.INT.value, VariableTypes.FLOAT.value]: + elif hyperparameter_spec.type in [VariableTypes.INT.value, 
VariableTypes.FLOAT.value]: try: numeric_hyperparam_value = float(hyperparameter_value) except ValueError: raise JumpStartHyperparametersError( - f"Hyperparameter '{hyperparameter_name}' must be numeric type ('{hyperparameter_value}')." + f"Hyperparameter '{hyperparameter_name}' must be numeric type " + f"('{hyperparameter_value}')." ) if hyperparameter_spec.type == VariableTypes.INT.value: @@ -72,29 +124,46 @@ def _validate_hyperparameter( start_index = 1 if not hyperparameter_value_str[start_index:].isdigit(): raise JumpStartHyperparametersError( - f"Hyperparameter '{hyperparameter_name}' must be integer type ('{hyperparameter_value}')." + f"Hyperparameter '{hyperparameter_name}' must be integer type " + "('{hyperparameter_value}')." ) - if getattr(hyperparameter_spec, "min", None): + if hasattr(hyperparameter_spec, "min"): if numeric_hyperparam_value < hyperparameter_spec.min: raise JumpStartHyperparametersError( - f"Hyperparameter '{hyperparameter_name}' can be no less than {hyperparameter_spec.min}." + f"Hyperparameter '{hyperparameter_name}' can be no less than " + "{hyperparameter_spec.min}." ) - if getattr(hyperparameter_spec, "max", None): + if hasattr(hyperparameter_spec, "max"): if numeric_hyperparam_value > hyperparameter_spec.max: raise JumpStartHyperparametersError( - f"Hyperparameter '{hyperparameter_name}' can be no greater than {hyperparameter_spec.max}." + f"Hyperparameter '{hyperparameter_name}' can be no greater than " + "{hyperparameter_spec.max}." + ) + + if hasattr(hyperparameter_spec, "exclusive_min"): + if numeric_hyperparam_value <= hyperparameter_spec.exclusive_min: + raise JumpStartHyperparametersError( + f"Hyperparameter '{hyperparameter_name}' must be greater than " + "{hyperparameter_spec.exclusive_min}." 
+ ) + + if hasattr(hyperparameter_spec, "exclusive_max"): + if numeric_hyperparam_value >= hyperparameter_spec.exclusive_max: + raise JumpStartHyperparametersError( + f"Hyperparameter '{hyperparameter_name}' must be less than " + f"{hyperparameter_spec.exclusive_max}." ) def validate_hyperparameters( model_id: str, model_version: str, - hyperparameters: dict, + hyperparameters: Dict[str, Any], validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME, -): +) -> None: """Validate hyperparameters for JumpStart models. Args: @@ -109,6 +178,10 @@ def validate_hyperparameters( region (str): Region for which to validate hyperparameters. (Default: JumpStart default region). + Raises: + JumpStartHyperparametersError: If the hyperparameter is not formatted correctly, + according to its specs in the model metadata. + """ if validation_mode is None: diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 78061d9c79..86ad571150 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -46,7 +46,7 @@ def retrieve( ValueError: If the combination of arguments specified is not supported. 
""" if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): - raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") + raise ValueError("Must specify `model_id` and `model_version` when retrieving model URIs.") # mypy type checking require these assertions assert model_id is not None diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index c73fe76b1e..1a2462b81f 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -28,16 +28,44 @@ def test_jumpstart_validate_provided_hyperparameters(patched_get_model_specs): def add_options_to_hyperparameter(*largs, **kwargs): spec = get_spec_from_base_spec(*largs, **kwargs) - spec.hyperparameters.append( - JumpStartHyperparameter( - { - "name": "penalty", - "type": "text", - "default": "l2", - "options": ["l1", "l2", "elasticnet", "none"], - "scope": "algorithm", - } - ) + spec.hyperparameters.extend( + [ + JumpStartHyperparameter( + { + "name": "penalty", + "type": "text", + "default": "l2", + "options": ["l1", "l2", "elasticnet", "none"], + "scope": "algorithm", + } + ), + JumpStartHyperparameter( + { + "name": "test_bool_param", + "type": "bool", + "default": True, + "scope": "algorithm", + } + ), + JumpStartHyperparameter( + { + "name": "test_exclusive_min_param", + "type": "float", + "default": 4, + "scope": "algorithm", + "exclusive_min": 1, + } + ), + JumpStartHyperparameter( + { + "name": "test_exclusive_max_param", + "type": "int", + "default": -4, + "scope": "algorithm", + "exclusive_max": 4, + } + ), + ] ) return spec @@ -51,6 +79,9 @@ def add_options_to_hyperparameter(*largs, **kwargs): "batch-size": "4", "epochs": "3", "penalty": "l2", + "test_bool_param": False, + "test_exclusive_min_param": 4, + "test_exclusive_max_param": -4, } hyperparameters.validate( @@ -128,6 +159,66 @@ 
def add_options_to_hyperparameter(*largs, **kwargs): hyperparameters=hyperparameter_to_test, ) + original_bool_val = hyperparameter_to_test["test_bool_param"] + for val in ["False", "fAlSe", "false", "True", "TrUe", "true", True, False]: + hyperparameter_to_test["test_bool_param"] = val + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + for val in [None, "", 5, "Truesday", "Falsehood"]: + hyperparameter_to_test["test_bool_param"] = val + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + hyperparameter_to_test["test_bool_param"] = original_bool_val + + original_exclusive_min_val = hyperparameter_to_test["test_exclusive_min_param"] + for val in [2, 1 + 1e-9]: + hyperparameter_to_test["test_exclusive_min_param"] = val + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + for val in [1, 1 - 1e-99, -99]: + hyperparameter_to_test["test_exclusive_min_param"] = val + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + hyperparameter_to_test["test_exclusive_min_param"] = original_exclusive_min_val + + original_exclusive_max_val = hyperparameter_to_test["test_exclusive_max_param"] + for val in [-2, 2, 3]: + hyperparameter_to_test["test_exclusive_max_param"] = val + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + for val in [4, 5, 99]: + hyperparameter_to_test["test_exclusive_max_param"] = val + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + 
model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + hyperparameter_to_test["test_exclusive_max_param"] = original_exclusive_max_val + del hyperparameter_to_test["batch-size"] hyperparameter_to_test["penalty"] = "blah" with pytest.raises(JumpStartHyperparametersError): From dd26117e19c16fd4b7617614571698fe3f43847c Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Sat, 22 Jan 2022 01:21:53 +0000 Subject: [PATCH 3/7] change: add return typing --- src/sagemaker/hyperparameters.py | 2 +- src/sagemaker/jumpstart/validators.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 8cf94e7153..eef67ecc73 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -68,7 +68,7 @@ def validate( model_version: Optional[str] = None, hyperparameters: Optional[dict] = None, validation_mode: Optional[HyperparameterValidationMode] = None, -): +) -> None: """Validate hyperparameters for models. Args: diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index a9ee23142c..30809ba911 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -25,7 +25,7 @@ def _validate_hyperparameter( hyperparameter_name: str, hyperparameter_value: Any, hyperparameter_specs: List[JumpStartHyperparameter], -): +) -> None: """Perform low-level hyperparameter validation on single parameter. 
Args: From 063cdc7204781fe2f55bf5bfb2943c257bfd25b9 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Sat, 22 Jan 2022 15:48:48 +0000 Subject: [PATCH 4/7] change: add js unit tests for hyperparam validation for min/max/exclusiveMin/exclusiveMax text params --- .../jumpstart/test_validate.py | 120 ++++++++++++++++++ 1 file changed, 120 insertions(+) diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 1a2462b81f..ddeeccba1d 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -65,6 +65,42 @@ def add_options_to_hyperparameter(*largs, **kwargs): "exclusive_max": 4, } ), + JumpStartHyperparameter( + { + "name": "test_exclusive_min_param_text", + "type": "text", + "default": "hello", + "scope": "algorithm", + "exclusive_min": 1, + } + ), + JumpStartHyperparameter( + { + "name": "test_exclusive_max_param_text", + "type": "text", + "default": "hello", + "scope": "algorithm", + "exclusive_max": 6, + } + ), + JumpStartHyperparameter( + { + "name": "test_min_param_text", + "type": "text", + "default": "hello", + "scope": "algorithm", + "min": 1, + } + ), + JumpStartHyperparameter( + { + "name": "test_max_param_text", + "type": "text", + "default": "hello", + "scope": "algorithm", + "max": 6, + } + ), ] ) return spec @@ -82,6 +118,10 @@ def add_options_to_hyperparameter(*largs, **kwargs): "test_bool_param": False, "test_exclusive_min_param": 4, "test_exclusive_max_param": -4, + "test_exclusive_min_param_text": "hello", + "test_exclusive_max_param_text": "hello", + "test_min_param_text": "hello", + "test_max_param_text": "hello", } hyperparameters.validate( @@ -219,6 +259,86 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) hyperparameter_to_test["test_exclusive_max_param"] = original_exclusive_max_val + original_exclusive_max_text_val = 
hyperparameter_to_test["test_exclusive_max_param_text"] + for val in ["", "sd", "12345"]: + hyperparameter_to_test["test_exclusive_max_param_text"] = val + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + for val in ["123456", "123456789"]: + hyperparameter_to_test["test_exclusive_max_param_text"] = val + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + hyperparameter_to_test["test_exclusive_max_param_text"] = original_exclusive_max_text_val + + original_max_text_val = hyperparameter_to_test["test_max_param_text"] + for val in ["", "sd", "12345", "123456"]: + hyperparameter_to_test["test_max_param_text"] = val + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + for val in ["1234567", "123456789"]: + hyperparameter_to_test["test_max_param_text"] = val + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + hyperparameter_to_test["test_max_param_text"] = original_max_text_val + + original_exclusive_min_text_val = hyperparameter_to_test["test_exclusive_min_param_text"] + for val in ["12", "sdfs", "12345dsfs"]: + hyperparameter_to_test["test_exclusive_min_param_text"] = val + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + for val in ["1", "d", ""]: + hyperparameter_to_test["test_exclusive_min_param_text"] = val + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + 
hyperparameter_to_test["test_exclusive_min_param_text"] = original_exclusive_min_text_val + + original_min_text_val = hyperparameter_to_test["test_min_param_text"] + for val in ["1", "s", "12", "sdfs", "12345dsfs"]: + hyperparameter_to_test["test_min_param_text"] = val + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + for val in [""]: + hyperparameter_to_test["test_min_param_text"] = val + with pytest.raises(JumpStartHyperparametersError): + hyperparameters.validate( + region=region, + model_id=model_id, + model_version=model_version, + hyperparameters=hyperparameter_to_test, + ) + hyperparameter_to_test["test_min_param_text"] = original_min_text_val + del hyperparameter_to_test["batch-size"] hyperparameter_to_test["penalty"] = "blah" with pytest.raises(JumpStartHyperparametersError): From c6eb14dd373d01eeada781d6f7836a1ad637d968 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Sat, 22 Jan 2022 18:19:19 +0000 Subject: [PATCH 5/7] fix: jumpstart docstring --- src/sagemaker/jumpstart/validators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index 30809ba911..57d5fe6d72 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -179,8 +179,8 @@ def validate_hyperparameters( default region). Raises: - JumpStartHyperparametersError: If the hyperparameter is not formatted correctly, - according to its specs in the model metadata. + JumpStartHyperparametersError: If the hyperparameters are not formatted correctly, + according to their metadata specs. 
""" From 617adc1614481ff61a96cec09c37a929ffcf87fa Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 25 Jan 2022 18:49:26 +0000 Subject: [PATCH 6/7] fix: update jumpstart docstring --- src/sagemaker/hyperparameters.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index eef67ecc73..08ef24e19f 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -89,6 +89,7 @@ def validate( Raises: JumpStartHyperparametersError: If the hyperparameter is not formatted correctly, according to its specs in the model metadata. + ValueError: If the combination of arguments specified is not supported. """ From 8a358947c17ec54120b09cdc22c56eddad6dfe56 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 25 Jan 2022 20:47:49 +0000 Subject: [PATCH 7/7] fix: entry point variable name --- .../jumpstart/script_mode_class/test_inference.py | 4 ++-- .../jumpstart/script_mode_class/test_transfer_learning.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/integ/sagemaker/jumpstart/script_mode_class/test_inference.py b/tests/integ/sagemaker/jumpstart/script_mode_class/test_inference.py index e05ae87a31..69ede8b8c8 100644 --- a/tests/integ/sagemaker/jumpstart/script_mode_class/test_inference.py +++ b/tests/integ/sagemaker/jumpstart/script_mode_class/test_inference.py @@ -14,7 +14,7 @@ import os from sagemaker import image_uris, model_uris, script_uris -from sagemaker.jumpstart.constants import INFERENCE_ENTRYPOINT_SCRIPT_NAME +from sagemaker.jumpstart.constants import INFERENCE_ENTRY_POINT_SCRIPT_NAME from sagemaker.model import Model from tests.integ.sagemaker.jumpstart.constants import ( ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, @@ -57,7 +57,7 @@ def test_jumpstart_inference_model_class(setup): image_uri=image_uri, model_data=model_uri, source_dir=script_uri, - entry_point=INFERENCE_ENTRYPOINT_SCRIPT_NAME, + entry_point=INFERENCE_ENTRY_POINT_SCRIPT_NAME, 
role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), enable_network_isolation=True, diff --git a/tests/integ/sagemaker/jumpstart/script_mode_class/test_transfer_learning.py b/tests/integ/sagemaker/jumpstart/script_mode_class/test_transfer_learning.py index 46de1b5d27..29b16cf9a5 100644 --- a/tests/integ/sagemaker/jumpstart/script_mode_class/test_transfer_learning.py +++ b/tests/integ/sagemaker/jumpstart/script_mode_class/test_transfer_learning.py @@ -16,9 +16,9 @@ from sagemaker import hyperparameters, image_uris, model_uris, script_uris from sagemaker.estimator import Estimator from sagemaker.jumpstart.constants import ( - INFERENCE_ENTRYPOINT_SCRIPT_NAME, + INFERENCE_ENTRY_POINT_SCRIPT_NAME, JUMPSTART_DEFAULT_REGION_NAME, - TRAINING_ENTRYPOINT_SCRIPT_NAME, + TRAINING_ENTRY_POINT_SCRIPT_NAME, ) from sagemaker.jumpstart.utils import get_jumpstart_content_bucket from sagemaker.utils import name_from_base @@ -70,7 +70,7 @@ def test_jumpstart_transfer_learning_estimator_class(setup): image_uri=image_uri, source_dir=script_uri, model_uri=model_uri, - entry_point=TRAINING_ENTRYPOINT_SCRIPT_NAME, + entry_point=TRAINING_ENTRY_POINT_SCRIPT_NAME, role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), enable_network_isolation=True, @@ -111,7 +111,7 @@ def test_jumpstart_transfer_learning_estimator_class(setup): estimator.deploy( initial_instance_count=instance_count, instance_type=inference_instance_type, - entry_point=INFERENCE_ENTRYPOINT_SCRIPT_NAME, + entry_point=INFERENCE_ENTRY_POINT_SCRIPT_NAME, image_uri=image_uri, source_dir=script_uri, endpoint_name=endpoint_name,