From c801ea702ff01bce99ebaad77204331354134530 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 19 Jan 2022 16:00:32 +0000 Subject: [PATCH 1/5] feat: jumpstart vulnerability and deprecated check --- src/sagemaker/image_uris.py | 8 ++ src/sagemaker/jumpstart/artifacts.py | 105 +++++++++--------- src/sagemaker/jumpstart/exceptions.py | 69 ++++++++++++ src/sagemaker/jumpstart/types.py | 15 +++ src/sagemaker/jumpstart/utils.py | 87 +++++++++++++++ src/sagemaker/model_uris.py | 15 ++- src/sagemaker/script_uris.py | 16 ++- .../image_uris/jumpstart/test_common.py | 15 ++- tests/unit/sagemaker/jumpstart/constants.py | 7 ++ tests/unit/sagemaker/jumpstart/test_utils.py | 97 +++++++++++++++- .../model_uris/jumpstart/test_common.py | 14 ++- .../script_uris/jumpstart/test_common.py | 14 ++- 12 files changed, 402 insertions(+), 60 deletions(-) create mode 100644 src/sagemaker/jumpstart/exceptions.py diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 01ac633cd8..30732df753 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -45,6 +45,8 @@ def retrieve( training_compiler_config=None, model_id=None, model_version=None, + tolerate_vulnerable_model=None, + tolerate_deprecated_model=None, ) -> str: """Retrieves the ECR URI for the Docker image matching the given arguments. @@ -79,6 +81,10 @@ def retrieve( (default: None). model_version (str): Version of the JumpStart model for which to retrieve the image URI (default: None). + tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception + not thrown). False if these models should throw an exception. (Default: None). + tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception + not thrown). False if these models should throw an exception. (Default: None). Returns: str: the ECR URI for the corresponding SageMaker Docker image. 
@@ -106,6 +112,8 @@ def retrieve( distribution, base_framework_version, training_compiler_config, + tolerate_vulnerable_model, + tolerate_deprecated_model, ) if training_compiler_config is None: diff --git a/src/sagemaker/jumpstart/artifacts.py b/src/sagemaker/jumpstart/artifacts.py index 2919fe44b2..86244dd857 100644 --- a/src/sagemaker/jumpstart/artifacts.py +++ b/src/sagemaker/jumpstart/artifacts.py @@ -18,11 +18,13 @@ JUMPSTART_DEFAULT_REGION_NAME, INFERENCE, TRAINING, - SUPPORTED_JUMPSTART_SCOPES, ModelFramework, VariableScope, ) -from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +from sagemaker.jumpstart.utils import ( + get_jumpstart_content_bucket, + verify_model_region_and_return_specs, +) from sagemaker.jumpstart import accessors as jumpstart_accessors @@ -40,6 +42,8 @@ def _retrieve_image_uri( distribution: Optional[str], base_framework_version: Optional[str], training_compiler_config: Optional[str], + tolerate_vulnerable_model: Optional[bool], + tolerate_deprecated_model: Optional[bool], ): """Retrieves the container image URI for JumpStart models. @@ -72,39 +76,36 @@ def _retrieve_image_uri( distribution (dict): A dictionary with information on how to run distributed training training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): A configuration class for the SageMaker Training Compiler. + tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception + not thrown). False if these models should throw an exception. + tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception + not thrown). False if these models should throw an exception. Returns: str: the ECR URI for the corresponding SageMaker Docker image. Raises: ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If the model is vulnerable. + DeprecatedJumpStartModelError: If the model is deprecated. 
""" if region is None: region = JUMPSTART_DEFAULT_REGION_NAME assert region is not None - if image_scope is None: - raise ValueError( - "Must specify `image_scope` argument to retrieve image uri for JumpStart models." - ) - if image_scope not in SUPPORTED_JUMPSTART_SCOPES: - raise ValueError( - f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}." - ) - - model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=model_version + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=image_scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, ) if image_scope == INFERENCE: ecr_specs = model_specs.hosting_ecr_specs elif image_scope == TRAINING: - if not model_specs.training_supported: - raise ValueError( - f"JumpStart model ID '{model_id}' and version '{model_version}' " - "does not support training." - ) assert model_specs.training_ecr_specs is not None ecr_specs = model_specs.training_ecr_specs @@ -168,6 +169,8 @@ def _retrieve_model_uri( model_version: str, model_scope: Optional[str], region: Optional[str], + tolerate_vulnerable_model: Optional[bool], + tolerate_deprecated_model: Optional[bool], ): """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -179,39 +182,35 @@ def _retrieve_model_uri( model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model S3 URI. + tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception + not thrown). False if these models should throw an exception. + tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception + not thrown). False if these models should throw an exception. 
Returns: str: the model artifact S3 URI for the corresponding model. Raises: ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If the model is vulnerable. + DeprecatedJumpStartModelError: If the model is deprecated. """ if region is None: region = JUMPSTART_DEFAULT_REGION_NAME assert region is not None - if model_scope is None: - raise ValueError( - "Must specify `model_scope` argument to retrieve model " - "artifact uri for JumpStart models." - ) - - if model_scope not in SUPPORTED_JUMPSTART_SCOPES: - raise ValueError( - f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}." - ) - - model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=model_version + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=model_scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, ) + if model_scope == INFERENCE: model_artifact_key = model_specs.hosting_artifact_key elif model_scope == TRAINING: - if not model_specs.training_supported: - raise ValueError( - f"JumpStart model ID '{model_id}' and version '{model_version}' " - "does not support training." - ) assert model_specs.training_artifact_key is not None model_artifact_key = model_specs.training_artifact_key @@ -227,6 +226,8 @@ def _retrieve_script_uri( model_version: str, script_scope: Optional[str], region: Optional[str], + tolerate_vulnerable_model: Optional[bool], + tolerate_deprecated_model: Optional[bool], ): """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -238,39 +239,35 @@ def _retrieve_script_uri( script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model script S3 URI. 
+ tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception + not thrown). False if these models should throw an exception. + tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception + not thrown). False if these models should throw an exception. Returns: str: the model script URI for the corresponding model. Raises: ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If the model is vulnerable. + DeprecatedJumpStartModelError: If the model is deprecated. """ if region is None: region = JUMPSTART_DEFAULT_REGION_NAME assert region is not None - if script_scope is None: - raise ValueError( - "Must specify `script_scope` argument to retrieve model script uri for " - "JumpStart models." - ) - - if script_scope not in SUPPORTED_JUMPSTART_SCOPES: - raise ValueError( - f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}." - ) - - model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=model_version + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=script_scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, ) + if script_scope == INFERENCE: model_script_key = model_specs.hosting_script_key elif script_scope == TRAINING: - if not model_specs.training_supported: - raise ValueError( - f"JumpStart model ID '{model_id}' and version '{model_version}' " - "does not support training." - ) assert model_specs.training_script_key is not None model_script_key = model_specs.training_script_key diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py new file mode 100644 index 0000000000..9ffbd6b3a3 --- /dev/null +++ b/src/sagemaker/jumpstart/exceptions.py @@ -0,0 +1,69 @@ +# Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores exceptions related to SageMaker JumpStart.""" + +from typing import List, Optional + + +class VulnerableJumpStartModelError(Exception): + """Exception raised for errors with vulnerable JumpStart models.""" + + def __init__( + self, + model_id: Optional[str] = None, + version: Optional[str] = None, + vulnerabilities: Optional[List[str]] = None, + inference: Optional[bool] = None, + message: Optional[str] = None, + ): + if message: + self.message = message + else: + if None in [model_id, version, vulnerabilities, inference]: + raise ValueError( + "Must specify `model_id`, `version`, `vulnerabilities`, " + "and inference arguments." + ) + if inference is True: + self.message = ( + f"JumpStart model '{model_id}' and version '{version}' has at least 1 " + "vulnerable dependency in the inference scripts. " + f"List of vulnerabilities: {', '.join(vulnerabilities)}" + ) + else: + self.message = ( + f"JumpStart model '{model_id}' and version '{version}' has at least 1 " + "vulnerable dependency in the training scripts. 
" + f"List of vulnerabilities: {', '.join(vulnerabilities)}" + ) + + super().__init__(self.message) + + +class DeprecatedJumpStartModelError(Exception): + """Exception raised for errors with deprecated JumpStart models.""" + + def __init__( + self, + model_id: Optional[str] = None, + version: Optional[str] = None, + message: Optional[str] = None, + ): + if message: + self.message = message + else: + if None in [model_id, version]: + raise ValueError("Must specify `model_id` and `version` arguments.") + self.message = f"JumpStart model '{model_id}' and version '{version}' is deprecated." + + super().__init__(self.message) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 9e4f224ba2..d5023010dd 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -274,6 +274,13 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "training_script_key", "hyperparameters", "inference_environment_variables", + "inference_vulnerable", + "inference_dependencies", + "inference_vulnerabilities", + "training_vulnerable", + "training_dependencies", + "training_vulnerabilities", + "deprecated", ] def __init__(self, spec: Dict[str, Any]): @@ -302,6 +309,14 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: JumpStartEnvironmentVariable(env_variable) for env_variable in json_obj["inference_environment_variables"] ] + self.inference_vulnerable: bool = bool(json_obj["inference_vulnerable"]) + self.inference_dependencies: List[str] = json_obj["inference_dependencies"] + self.inference_vulnerabilities: List[str] = json_obj["inference_vulnerabilities"] + self.training_vulnerable: bool = bool(json_obj["training_vulnerable"]) + self.training_dependencies: List[str] = json_obj["training_dependencies"] + self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"] + self.deprecated: bool = bool(json_obj["deprecated"]) + if self.training_supported: self.training_ecr_specs: JumpStartECRSpecs = 
JumpStartECRSpecs( json_obj["training_ecr_specs"] diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 7e54fbdc27..736a16c892 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -17,6 +17,10 @@ import sagemaker from sagemaker.jumpstart import constants from sagemaker.jumpstart import accessors +from sagemaker.jumpstart.exceptions import ( + DeprecatedJumpStartModelError, + VulnerableJumpStartModelError, +) from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId @@ -136,3 +140,86 @@ def is_jumpstart_model_input(model_id: Optional[str], version: Optional[str]) -> ) return True return False + + +def verify_model_region_and_return_specs( + model_id: Optional[str], + version: Optional[str], + scope: Optional[str], + region: str, + tolerate_vulnerable_model: Optional[bool] = None, + tolerate_deprecated_model: Optional[bool] = None, +): + """Verifies that an acceptable model_id, version, scope, and region combination is provided. + + If the scope is not supported, the model id/region/version has no spec, or the model is vulnerable + or deprecated, an exception will be raised. + + Args: + model_id (Optional[str]): model id of the JumpStart model to verify and + obtains specs. + version (Optional[str]): version of the JumpStart model to verify and + obtains specs. + scope (Optional[str]): scope of the JumpStart model to verify. + region (Optional[str]): region of the JumpStart model to verify and + obtains specs. + tolerate_vulnerable_model (Optional[bool]): True if vulnerable models should be tolerated (exception + not thrown). False if these models should throw an exception. (Default: None). + tolerate_deprecated_model (Optional[bool]): True if deprecated models should be tolerated (exception + not thrown). False if these models should throw an exception. (Default: None). 
+ """ + + if tolerate_vulnerable_model is None: + tolerate_vulnerable_model = False + + if tolerate_deprecated_model is None: + tolerate_deprecated_model = False + + if scope is None: + raise ValueError( + "Must specify `model_scope` argument to retrieve model " + "artifact uri for JumpStart models." + ) + + if scope not in constants.SUPPORTED_JUMPSTART_SCOPES: + raise ValueError( + f"JumpStart models only support scopes: {', '.join(constants.SUPPORTED_JUMPSTART_SCOPES)}." + ) + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + region=region, model_id=model_id, version=version + ) + + if scope == constants.TRAINING and not model_specs.training_supported: + raise ValueError( + f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training." + ) + + if model_specs.deprecated and not tolerate_deprecated_model: + raise DeprecatedJumpStartModelError(model_id=model_id, version=version) + + if ( + scope == constants.INFERENCE + and model_specs.inference_vulnerable + and not tolerate_vulnerable_model + ): + raise VulnerableJumpStartModelError( + model_id=model_id, + version=version, + vulnerabilities=model_specs.inference_vulnerabilities, + inference=True, + ) + + if ( + scope == constants.TRAINING + and model_specs.training_vulnerable + and not tolerate_vulnerable_model + ): + raise VulnerableJumpStartModelError( + model_id=model_id, + version=version, + vulnerabilities=model_specs.training_vulnerabilities, + inference=False, + ) + + return model_specs diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 78061d9c79..48ca969bcf 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -28,6 +28,8 @@ def retrieve( model_id=None, model_version: Optional[str] = None, model_scope: Optional[str] = None, + tolerate_vulnerable_model: Optional[bool] = None, + tolerate_deprecated_model: Optional[bool] = None, ) -> str: """Retrieves the model artifact S3 URI for the model matching the given 
arguments. @@ -39,6 +41,10 @@ def retrieve( the model artifact S3 URI. model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". + tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception + not thrown). False if these models should throw an exception. (Default: None). + tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception + not thrown). False if these models should throw an exception. (Default: None). Returns: str: the model artifact S3 URI for the corresponding model. @@ -52,4 +58,11 @@ def retrieve( assert model_id is not None assert model_version is not None - return artifacts._retrieve_model_uri(model_id, model_version, model_scope, region) + return artifacts._retrieve_model_uri( + model_id, + model_version, + model_scope, + region, + tolerate_vulnerable_model, + tolerate_deprecated_model, + ) diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index f5c2a6b97f..fe2d06e275 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -15,6 +15,7 @@ from __future__ import absolute_import import logging +from typing import Optional from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts @@ -27,6 +28,8 @@ def retrieve( model_id=None, model_version=None, script_scope=None, + tolerate_vulnerable_model: Optional[bool] = None, + tolerate_deprecated_model: Optional[bool] = None, ) -> str: """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -38,6 +41,10 @@ def retrieve( model script S3 URI. script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception + not thrown). False if these models should throw an exception. (Default: None). 
+ tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception + not thrown). False if these models should throw an exception. (Default: None). Returns: str: the model script URI for the corresponding model. @@ -51,4 +58,11 @@ def retrieve( assert model_id is not None assert model_version is not None - return artifacts._retrieve_script_uri(model_id, model_version, script_scope, region) + return artifacts._retrieve_script_uri( + model_id, + model_version, + script_scope, + region, + tolerate_vulnerable_model, + tolerate_deprecated_model, + ) diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index d214065276..f8ba78fde6 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -16,13 +16,19 @@ import pytest from sagemaker import image_uris +from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_common_image_uri(patched_get_model_specs): +def test_jumpstart_common_image_uri( + patched_get_model_specs, patched_verify_model_region_and_return_specs +): + + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec image_uris.retrieve( @@ -36,8 +42,10 @@ def test_jumpstart_common_image_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + 
patched_verify_model_region_and_return_specs.reset_mock() image_uris.retrieve( framework=None, @@ -50,8 +58,10 @@ def test_jumpstart_common_image_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() image_uris.retrieve( framework=None, @@ -66,8 +76,10 @@ def test_jumpstart_common_image_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="*", ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() image_uris.retrieve( framework=None, @@ -82,6 +94,7 @@ def test_jumpstart_common_image_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="1.*", ) + patched_verify_model_region_and_return_specs.assert_called_once() with pytest.raises(ValueError): image_uris.retrieve( diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index d0d59be817..ebb3214e4c 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1167,6 +1167,13 @@ "scope": "container", }, ], + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, } BASE_HEADER = { diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 008293b8b0..dfe7d887b3 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -14,8 +14,13 @@ from mock.mock import Mock, patch import pytest from sagemaker.jumpstart import utils -from sagemaker.jumpstart.constants 
import JUMPSTART_REGION_NAME_SET +from sagemaker.jumpstart.constants import INFERENCE, JUMPSTART_REGION_NAME_SET, TRAINING +from sagemaker.jumpstart.exceptions import ( + DeprecatedJumpStartModelError, + VulnerableJumpStartModelError, +) from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec def test_get_jumpstart_content_bucket(): @@ -112,3 +117,93 @@ def test_get_sagemaker_version(patched_parse_sm_version: Mock): utils.get_sagemaker_version() utils.get_sagemaker_version() assert patched_parse_sm_version.called_only_once() + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_vulnerable_model(patched_get_model_specs): + def make_vulnerable_inference_spec(*largs, **kwargs): + spec = get_spec_from_base_spec(*largs, **kwargs) + spec.inference_vulnerable = True + spec.inference_vulnerabilities = ["some", "vulnerability"] + return spec + + patched_get_model_specs.side_effect = make_vulnerable_inference_spec + + with pytest.raises(VulnerableJumpStartModelError) as e: + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", version="*", scope=INFERENCE, region="us-west-2" + ) + assert ( + "JumpStart model 'pytorch-eqa-bert-base-cased' and version '*' has " + "at least 1 vulnerable dependency in the inference scripts. 
List of vulnerabilities: " + "some, vulnerability" == str(e.value.message) + ) + + assert ( + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=INFERENCE, + region="us-west-2", + tolerate_vulnerable_model=True, + ) + is not None + ) + + def make_vulnerable_training_spec(*largs, **kwargs): + spec = get_spec_from_base_spec(*largs, **kwargs) + spec.training_vulnerable = True + spec.training_vulnerabilities = ["some", "vulnerability"] + return spec + + patched_get_model_specs.side_effect = make_vulnerable_training_spec + + with pytest.raises(VulnerableJumpStartModelError) as e: + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", version="*", scope=TRAINING, region="us-west-2" + ) + assert ( + "JumpStart model 'pytorch-eqa-bert-base-cased' and version '*' has " + "at least 1 vulnerable dependency in the training scripts. List of vulnerabilities: " + "some, vulnerability" == str(e.value.message) + ) + + assert ( + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=TRAINING, + region="us-west-2", + tolerate_vulnerable_model=True, + ) + is not None + ) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_deprecated_model(patched_get_model_specs): + def make_deprecated_spec(*largs, **kwargs): + spec = get_spec_from_base_spec(*largs, **kwargs) + spec.deprecated = True + return spec + + patched_get_model_specs.side_effect = make_deprecated_spec + + with pytest.raises(DeprecatedJumpStartModelError) as e: + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", version="*", scope=INFERENCE, region="us-west-2" + ) + assert "JumpStart model 'pytorch-eqa-bert-base-cased' and version '*' is deprecated." 
== str( + e.value.message + ) + + assert ( + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=INFERENCE, + region="us-west-2", + tolerate_deprecated_model=True, + ) + is not None + ) diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 379c8033ba..7b930e51e0 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -16,14 +16,19 @@ import pytest from sagemaker import model_uris +from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_common_model_uri(patched_get_model_specs): +def test_jumpstart_common_model_uri( + patched_get_model_specs, patched_verify_model_region_and_return_specs +): + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec model_uris.retrieve( @@ -36,8 +41,10 @@ def test_jumpstart_common_model_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="*", ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() model_uris.retrieve( model_scope="inference", @@ -49,8 +56,10 @@ def test_jumpstart_common_model_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="1.*", ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() model_uris.retrieve( 
region="us-west-2", @@ -61,8 +70,10 @@ def test_jumpstart_common_model_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() model_uris.retrieve( region="us-west-2", @@ -73,6 +84,7 @@ def test_jumpstart_common_model_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" ) + patched_verify_model_region_and_return_specs.assert_called_once() with pytest.raises(ValueError): model_uris.retrieve( diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 0f61a27ad9..545c01e9ec 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -16,14 +16,19 @@ from mock.mock import patch from sagemaker import script_uris +from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from sagemaker.jumpstart import constants as sagemaker_constants +@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_common_script_uri(patched_get_model_specs): +def test_jumpstart_common_script_uri( + patched_get_model_specs, patched_verify_model_region_and_return_specs +): + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec script_uris.retrieve( @@ -36,8 +41,10 @@ def test_jumpstart_common_script_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="*", ) + 
patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() script_uris.retrieve( script_scope="inference", @@ -49,8 +56,10 @@ def test_jumpstart_common_script_uri(patched_get_model_specs): model_id="pytorch-ic-mobilenet-v2", version="1.*", ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() script_uris.retrieve( region="us-west-2", @@ -61,8 +70,10 @@ def test_jumpstart_common_script_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" ) + patched_verify_model_region_and_return_specs.assert_called_once() patched_get_model_specs.reset_mock() + patched_verify_model_region_and_return_specs.reset_mock() script_uris.retrieve( region="us-west-2", @@ -73,6 +84,7 @@ def test_jumpstart_common_script_uri(patched_get_model_specs): patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" ) + patched_verify_model_region_and_return_specs.assert_called_once() with pytest.raises(ValueError): script_uris.retrieve( From 878c7057fc27e771d823f161d583886cdba7cfc1 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 20 Jan 2022 15:50:22 +0000 Subject: [PATCH 2/5] change: cleanup code --- src/sagemaker/image_uris.py | 4 +- src/sagemaker/jumpstart/artifacts.py | 29 ++++---- src/sagemaker/jumpstart/constants.py | 13 +++- src/sagemaker/jumpstart/exceptions.py | 66 ++++++++++++++----- src/sagemaker/jumpstart/utils.py | 49 +++++++++----- src/sagemaker/model_uris.py | 4 +- src/sagemaker/script_uris.py | 4 +- .../image_uris/jumpstart/test_common.py | 2 +- tests/unit/sagemaker/jumpstart/test_utils.py | 46 ++++++++----- .../model_uris/jumpstart/test_common.py | 2 +- .../script_uris/jumpstart/test_common.py | 2 +- 11 files 
changed, 145 insertions(+), 76 deletions(-) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 30732df753..717ee64d5c 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -82,9 +82,9 @@ def retrieve( model_version (str): Version of the JumpStart model for which to retrieve the image URI (default: None). tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception - not thrown). False if these models should throw an exception. (Default: None). + not raised). False if these models should raise an exception. (Default: None). tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception - not thrown). False if these models should throw an exception. (Default: None). + not raised). False if these models should raise an exception. (Default: None). Returns: str: the ECR URI for the corresponding SageMaker Docker image. diff --git a/src/sagemaker/jumpstart/artifacts.py b/src/sagemaker/jumpstart/artifacts.py index 86244dd857..8fba926bac 100644 --- a/src/sagemaker/jumpstart/artifacts.py +++ b/src/sagemaker/jumpstart/artifacts.py @@ -16,8 +16,7 @@ from sagemaker import image_uris from sagemaker.jumpstart.constants import ( JUMPSTART_DEFAULT_REGION_NAME, - INFERENCE, - TRAINING, + JumpStartScriptScope, ModelFramework, VariableScope, ) @@ -77,9 +76,9 @@ def _retrieve_image_uri( training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): A configuration class for the SageMaker Training Compiler. tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception - not thrown). False if these models should throw an exception. + not raised). False if these models should raise an exception. tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception - not thrown). False if these models should throw an exception. + not raised). False if these models should raise an exception. 
Returns: str: the ECR URI for the corresponding SageMaker Docker image. @@ -103,9 +102,9 @@ def _retrieve_image_uri( tolerate_deprecated_model=tolerate_deprecated_model, ) - if image_scope == INFERENCE: + if image_scope == JumpStartScriptScope.INFERENCE.value: ecr_specs = model_specs.hosting_ecr_specs - elif image_scope == TRAINING: + elif image_scope == JumpStartScriptScope.TRAINING.value: assert model_specs.training_ecr_specs is not None ecr_specs = model_specs.training_ecr_specs @@ -133,7 +132,7 @@ def _retrieve_image_uri( base_framework_version_override = ecr_specs.framework_version version_override = ecr_specs.huggingface_transformers_version - if image_scope == TRAINING: + if image_scope == JumpStartScriptScope.TRAINING.value: return image_uris.get_training_image_uri( region=region, framework=ecr_specs.framework, @@ -183,9 +182,9 @@ def _retrieve_model_uri( Valid values: "training" and "inference". region (str): Region for which to retrieve model S3 URI. tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception - not thrown). False if these models should throw an exception. + not raised). False if these models should raise an exception. tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception - not thrown). False if these models should throw an exception. + not raised). False if these models should raise an exception. Returns: str: the model artifact S3 URI for the corresponding model. 
@@ -208,9 +207,9 @@ def _retrieve_model_uri( tolerate_deprecated_model=tolerate_deprecated_model, ) - if model_scope == INFERENCE: + if model_scope == JumpStartScriptScope.INFERENCE.value: model_artifact_key = model_specs.hosting_artifact_key - elif model_scope == TRAINING: + elif model_scope == JumpStartScriptScope.TRAINING.value: assert model_specs.training_artifact_key is not None model_artifact_key = model_specs.training_artifact_key @@ -240,9 +239,9 @@ def _retrieve_script_uri( Valid values: "training" and "inference". region (str): Region for which to retrieve model script S3 URI. tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception - not thrown). False if these models should throw an exception. + not raised). False if these models should raise an exception. tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception - not thrown). False if these models should throw an exception. + not raised). False if these models should raise an exception. Returns: str: the model script URI for the corresponding model. 
@@ -265,9 +264,9 @@ def _retrieve_script_uri( tolerate_deprecated_model=tolerate_deprecated_model, ) - if script_scope == INFERENCE: + if script_scope == JumpStartScriptScope.INFERENCE.value: model_script_key = model_specs.hosting_script_key - elif script_scope == TRAINING: + elif script_scope == JumpStartScriptScope.TRAINING.value: assert model_specs.training_script_key is not None model_script_key = model_specs.training_script_key diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index aedce0e0da..adb1227803 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -116,14 +116,21 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" -INFERENCE = "inference" -TRAINING = "training" -SUPPORTED_JUMPSTART_SCOPES = set([INFERENCE, TRAINING]) INFERENCE_ENTRYPOINT_SCRIPT_NAME = "inference.py" TRAINING_ENTRYPOINT_SCRIPT_NAME = "transfer_learning.py" +class JumpStartScriptScope(str, Enum): + """Enum class for JumpStart script scopes.""" + + INFERENCE = "inference" + TRAINING = "training" + + +SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope) + + class ModelFramework(str, Enum): """Enum class for JumpStart model framework. diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index 9ffbd6b3a3..4fdb6e2534 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -12,46 +12,79 @@ # language governing permissions and limitations under the License. """This module stores exceptions related to SageMaker JumpStart.""" +from __future__ import absolute_import from typing import List, Optional +from sagemaker.jumpstart.constants import JumpStartScriptScope + class VulnerableJumpStartModelError(Exception): - """Exception raised for errors with vulnerable JumpStart models.""" + """Exception raised when trying to access a JumpStart model specs flagged as vulnerable. 
+ + Raise this exception only if the scope of attributes accessed in the specifications has + vulnerabilities. For example, a model training script may have vulnerabilities, but not + the hosting scripts. In such a case, raise a ``VulnerableJumpStartModelError`` only when + accessing the training specifications. + """ def __init__( self, model_id: Optional[str] = None, version: Optional[str] = None, vulnerabilities: Optional[List[str]] = None, - inference: Optional[bool] = None, + scope: Optional[JumpStartScriptScope] = None, message: Optional[str] = None, ): + """Instantiates VulnerableJumpStartModelError exception. + + Args: + model_id (Optional[str]): model id of vulnerable JumpStart model. + (Default: None). + version (Optional[str]): version of vulnerable JumpStart model. + (Default: None). + vulnerabilities (Optional[List[str]]): vulnerabilities associated with + model. (Default: None). + + """ if message: self.message = message else: - if None in [model_id, version, vulnerabilities, inference]: + if None in [model_id, version, vulnerabilities, scope]: raise ValueError( - "Must specify `model_id`, `version`, `vulnerabilities`, " - "and inference arguments." + "Must specify `model_id`, `version`, `vulnerabilities`, " "and scope arguments." ) - if inference is True: + if scope == JumpStartScriptScope.INFERENCE: self.message = ( - f"JumpStart model '{model_id}' and version '{version}' has at least 1 " - "vulnerable dependency in the inference scripts. " - f"List of vulnerabilities: {', '.join(vulnerabilities)}" + f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore + "has at least 1 vulnerable dependency in the inference script. " + "Please try targetting a higher version of the model. 
" + f"List of vulnerabilities: {', '.join(vulnerabilities)}" # type: ignore ) - else: + elif scope == JumpStartScriptScope.TRAINING: self.message = ( - f"JumpStart model '{model_id}' and version '{version}' has at least 1 " - "vulnerable dependency in the training scripts. " - f"List of vulnerabilities: {', '.join(vulnerabilities)}" + f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore + "has at least 1 vulnerable dependency in the training script. " + "Please try targetting a higher version of the model. " + f"List of vulnerabilities: {', '.join(vulnerabilities)}" # type: ignore + ) + else: + raise NotImplementedError( + "Unsupported scope for VulnerableJumpStartModelError: " # type: ignore + f"'{scope.value}'" ) super().__init__(self.message) class DeprecatedJumpStartModelError(Exception): - """Exception raised for errors with deprecated JumpStart models.""" + """Exception raised when trying to access deprecated specifications of a JumpStart model. + + A deprecated specification for a JumpStart model does not mean the whole model is + deprecated. There may be more recent specifications available for this model. For + example, all specifications before version ``2.0.0`` may be deprecated; in such a + case, the SDK would raise this exception only when specifications ``1.*`` are + accessed. + """ def __init__( self, @@ -64,6 +97,9 @@ def __init__( else: if None in [model_id, version]: raise ValueError("Must specify `model_id` and `version` arguments.") - self.message = f"JumpStart model '{model_id}' and version '{version}' is deprecated." + self.message = ( + f"Version '{version}' of JumpStart model '{model_id}' is deprecated. " + "Please try targetting a higher version of the model." 
+ ) super().__init__(self.message) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 736a16c892..906f49d74c 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -21,7 +21,11 @@ DeprecatedJumpStartModelError, VulnerableJumpStartModelError, ) -from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId +from sagemaker.jumpstart.types import ( + JumpStartModelHeader, + JumpStartModelSpecs, + JumpStartVersionedModelId, +) def get_jumpstart_launched_regions_message() -> str: @@ -149,12 +153,9 @@ def verify_model_region_and_return_specs( region: str, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, -): +) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. - If the scope is not supported, the model id/region/version has no spec, or the model is vulnerable - or deprecated, an exception will be raised. - Args: model_id (Optional[str]): model id of the JumpStart model to verify and obtains specs. @@ -163,10 +164,19 @@ def verify_model_region_and_return_specs( scope (Optional[str]): scope of the JumpStart model to verify. region (Optional[str]): region of the JumpStart model to verify and obtains specs. - tolerate_vulnerable_model (Optional[bool]): True if vulnerable models should be tolerated (exception - not thrown). False if these models should throw an exception. (Default: None). - tolerate_deprecated_model (Optional[bool]): True if deprecated models should be tolerated (exception - not thrown). False if these models should throw an exception. (Default: None). + tolerate_vulnerable_model (Optional[bool]): True if vulnerable models should be tolerated + (exception not raised). False if these models should raise an exception. + (Default: None). + tolerate_deprecated_model (Optional[bool]): True if deprecated models should be tolerated + (exception not raised). 
False if these models should raise an exception. + (Default: None). + + + Raises: + ValueError: If the combination of arguments specified is not supported. + NotImplementedError: If the scope is not supported. + VulnerableJumpStartModelError: If the model is vulnerable. + DeprecatedJumpStartModelError: If the model is deprecated. """ if tolerate_vulnerable_model is None: @@ -182,15 +192,22 @@ def verify_model_region_and_return_specs( ) if scope not in constants.SUPPORTED_JUMPSTART_SCOPES: - raise ValueError( - f"JumpStart models only support scopes: {', '.join(constants.SUPPORTED_JUMPSTART_SCOPES)}." + raise NotImplementedError( + "JumpStart models only support scopes: " + f"{', '.join(constants.SUPPORTED_JUMPSTART_SCOPES)}." ) + assert model_id is not None + assert version is not None + model_specs = accessors.JumpStartModelsAccessor.get_model_specs( region=region, model_id=model_id, version=version ) - if scope == constants.TRAINING and not model_specs.training_supported: + if ( + scope == constants.JumpStartScriptScope.TRAINING.value + and not model_specs.training_supported + ): raise ValueError( f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training." 
) @@ -199,7 +216,7 @@ def verify_model_region_and_return_specs( raise DeprecatedJumpStartModelError(model_id=model_id, version=version) if ( - scope == constants.INFERENCE + scope == constants.JumpStartScriptScope.INFERENCE.value and model_specs.inference_vulnerable and not tolerate_vulnerable_model ): @@ -207,11 +224,11 @@ def verify_model_region_and_return_specs( model_id=model_id, version=version, vulnerabilities=model_specs.inference_vulnerabilities, - inference=True, + scope=constants.JumpStartScriptScope.INFERENCE, ) if ( - scope == constants.TRAINING + scope == constants.JumpStartScriptScope.TRAINING.value and model_specs.training_vulnerable and not tolerate_vulnerable_model ): @@ -219,7 +236,7 @@ def verify_model_region_and_return_specs( model_id=model_id, version=version, vulnerabilities=model_specs.training_vulnerabilities, - inference=False, + scope=constants.JumpStartScriptScope.TRAINING, ) return model_specs diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 48ca969bcf..1044b0d567 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -42,9 +42,9 @@ def retrieve( model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception - not thrown). False if these models should throw an exception. (Default: None). + not raised). False if these models should raise an exception. (Default: None). tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception - not thrown). False if these models should throw an exception. (Default: None). + not raised). False if these models should raise an exception. (Default: None). Returns: str: the model artifact S3 URI for the corresponding model. 
diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index fe2d06e275..03f4599fd2 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -42,9 +42,9 @@ def retrieve( script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception - not thrown). False if these models should throw an exception. (Default: None). + not raised). False if these models should raise an exception. (Default: None). tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception - not thrown). False if these models should throw an exception. (Default: None). + not raised). False if these models should raise an exception. (Default: None). Returns: str: the model script URI for the corresponding model. diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index f8ba78fde6..091f13ea46 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -96,7 +96,7 @@ def test_jumpstart_common_image_uri( ) patched_verify_model_region_and_return_specs.assert_called_once() - with pytest.raises(ValueError): + with pytest.raises(NotImplementedError): image_uris.retrieve( framework=None, region="us-west-2", diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index dfe7d887b3..388cc23d48 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -14,7 +14,7 @@ from mock.mock import Mock, patch import pytest from sagemaker.jumpstart import utils -from sagemaker.jumpstart.constants import INFERENCE, JUMPSTART_REGION_NAME_SET, TRAINING +from sagemaker.jumpstart.constants import JUMPSTART_REGION_NAME_SET, JumpStartScriptScope from sagemaker.jumpstart.exceptions 
import ( DeprecatedJumpStartModelError, VulnerableJumpStartModelError, @@ -131,19 +131,23 @@ def make_vulnerable_inference_spec(*largs, **kwargs): with pytest.raises(VulnerableJumpStartModelError) as e: utils.verify_model_region_and_return_specs( - model_id="pytorch-eqa-bert-base-cased", version="*", scope=INFERENCE, region="us-west-2" + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.INFERENCE.value, + region="us-west-2", ) assert ( - "JumpStart model 'pytorch-eqa-bert-base-cased' and version '*' has " - "at least 1 vulnerable dependency in the inference scripts. List of vulnerabilities: " - "some, vulnerability" == str(e.value.message) - ) + "Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' has at least 1 " + "vulnerable dependency in the inference script. " + "Please try targetting a higher version of the model. " + "List of vulnerabilities: some, vulnerability" + ) == str(e.value.message) assert ( utils.verify_model_region_and_return_specs( model_id="pytorch-eqa-bert-base-cased", version="*", - scope=INFERENCE, + scope=JumpStartScriptScope.INFERENCE.value, region="us-west-2", tolerate_vulnerable_model=True, ) @@ -160,19 +164,23 @@ def make_vulnerable_training_spec(*largs, **kwargs): with pytest.raises(VulnerableJumpStartModelError) as e: utils.verify_model_region_and_return_specs( - model_id="pytorch-eqa-bert-base-cased", version="*", scope=TRAINING, region="us-west-2" + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.TRAINING.value, + region="us-west-2", ) assert ( - "JumpStart model 'pytorch-eqa-bert-base-cased' and version '*' has " - "at least 1 vulnerable dependency in the training scripts. List of vulnerabilities: " - "some, vulnerability" == str(e.value.message) - ) + "Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' has at least 1 " + "vulnerable dependency in the training script. " + "Please try targetting a higher version of the model. 
" + "List of vulnerabilities: some, vulnerability" + ) == str(e.value.message) assert ( utils.verify_model_region_and_return_specs( model_id="pytorch-eqa-bert-base-cased", version="*", - scope=TRAINING, + scope=JumpStartScriptScope.TRAINING.value, region="us-west-2", tolerate_vulnerable_model=True, ) @@ -191,17 +199,19 @@ def make_deprecated_spec(*largs, **kwargs): with pytest.raises(DeprecatedJumpStartModelError) as e: utils.verify_model_region_and_return_specs( - model_id="pytorch-eqa-bert-base-cased", version="*", scope=INFERENCE, region="us-west-2" + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.INFERENCE.value, + region="us-west-2", ) - assert "JumpStart model 'pytorch-eqa-bert-base-cased' and version '*' is deprecated." == str( - e.value.message - ) + assert "Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' is deprecated. " + "Please try targetting a higher version of the model." == str(e.value.message) assert ( utils.verify_model_region_and_return_specs( model_id="pytorch-eqa-bert-base-cased", version="*", - scope=INFERENCE, + scope=JumpStartScriptScope.INFERENCE.value, region="us-west-2", tolerate_deprecated_model=True, ) diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 7b930e51e0..699f5836f3 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -86,7 +86,7 @@ def test_jumpstart_common_model_uri( ) patched_verify_model_region_and_return_specs.assert_called_once() - with pytest.raises(ValueError): + with pytest.raises(NotImplementedError): model_uris.retrieve( region="us-west-2", model_scope="BAD_SCOPE", diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 545c01e9ec..05d8368bf3 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ 
b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -86,7 +86,7 @@ def test_jumpstart_common_script_uri( ) patched_verify_model_region_and_return_specs.assert_called_once() - with pytest.raises(ValueError): + with pytest.raises(NotImplementedError): script_uris.retrieve( region="us-west-2", script_scope="BAD_SCOPE", From 09c97f41c977d4a74ba5aa6aac60ca35efb58230 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 21 Jan 2022 15:30:12 +0000 Subject: [PATCH 3/5] change: cleanup code, docstrings --- src/sagemaker/image_uris.py | 11 +++-- src/sagemaker/jumpstart/artifacts.py | 64 +++++++++++++++++----------- src/sagemaker/jumpstart/utils.py | 14 +++--- src/sagemaker/model_uris.py | 11 +++-- src/sagemaker/script_uris.py | 11 +++-- 5 files changed, 67 insertions(+), 44 deletions(-) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 717ee64d5c..bc29bde7fe 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -81,10 +81,13 @@ def retrieve( (default: None). model_version (str): Version of the JumpStart model for which to retrieve the image URI (default: None). - tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception - not raised). False if these models should raise an exception. (Default: None). - tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception - not raised). False if these models should raise an exception. (Default: None). + tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications + should be tolerated (exception not raised). False or None, raises an exception if + the script used by this version of the model has dependencies with known security + vulnerabilities. (Default: None). + tolerate_deprecated_model (bool): True if deprecated versions of model specifications + should be tolerated (exception not raised). False or None, raises an exception + if the version of the model is deprecated. 
(Default: None). Returns: str: the ECR URI for the corresponding SageMaker Docker image. diff --git a/src/sagemaker/jumpstart/artifacts.py b/src/sagemaker/jumpstart/artifacts.py index 8fba926bac..cb9e3515b9 100644 --- a/src/sagemaker/jumpstart/artifacts.py +++ b/src/sagemaker/jumpstart/artifacts.py @@ -75,18 +75,22 @@ def _retrieve_image_uri( distribution (dict): A dictionary with information on how to run distributed training training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): A configuration class for the SageMaker Training Compiler. - tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception - not raised). False if these models should raise an exception. - tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception - not raised). False if these models should raise an exception. + tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model + specifications should be tolerated (exception not raised). False or None, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. + tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model + specifications should be tolerated (exception not raised). False or None, raises + an exception if the version of the model is deprecated. Returns: str: the ECR URI for the corresponding SageMaker Docker image. Raises: ValueError: If the combination of arguments specified is not supported. - VulnerableJumpStartModelError: If the model is vulnerable. - DeprecatedJumpStartModelError: If the model is deprecated. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. 
""" if region is None: region = JUMPSTART_DEFAULT_REGION_NAME @@ -102,9 +106,9 @@ def _retrieve_image_uri( tolerate_deprecated_model=tolerate_deprecated_model, ) - if image_scope == JumpStartScriptScope.INFERENCE.value: + if image_scope == JumpStartScriptScope.INFERENCE: ecr_specs = model_specs.hosting_ecr_specs - elif image_scope == JumpStartScriptScope.TRAINING.value: + elif image_scope == JumpStartScriptScope.TRAINING: assert model_specs.training_ecr_specs is not None ecr_specs = model_specs.training_ecr_specs @@ -128,11 +132,11 @@ def _retrieve_image_uri( base_framework_version_override: Optional[str] = None version_override: Optional[str] = None - if ecr_specs.framework == ModelFramework.HUGGINGFACE.value: + if ecr_specs.framework == ModelFramework.HUGGINGFACE: base_framework_version_override = ecr_specs.framework_version version_override = ecr_specs.huggingface_transformers_version - if image_scope == JumpStartScriptScope.TRAINING.value: + if image_scope == JumpStartScriptScope.TRAINING: return image_uris.get_training_image_uri( region=region, framework=ecr_specs.framework, @@ -181,17 +185,21 @@ def _retrieve_model_uri( model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model S3 URI. - tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception - not raised). False if these models should raise an exception. - tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception - not raised). False if these models should raise an exception. + tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model + specifications should be tolerated (exception not raised). False or None, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. 
+ tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model + specifications should be tolerated (exception not raised). False or None, raises + an exception if the version of the model is deprecated. Returns: str: the model artifact S3 URI for the corresponding model. Raises: ValueError: If the combination of arguments specified is not supported. - VulnerableJumpStartModelError: If the model is vulnerable. - DeprecatedJumpStartModelError: If the model is deprecated. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if region is None: region = JUMPSTART_DEFAULT_REGION_NAME @@ -207,9 +215,9 @@ def _retrieve_model_uri( tolerate_deprecated_model=tolerate_deprecated_model, ) - if model_scope == JumpStartScriptScope.INFERENCE.value: + if model_scope == JumpStartScriptScope.INFERENCE: model_artifact_key = model_specs.hosting_artifact_key - elif model_scope == JumpStartScriptScope.TRAINING.value: + elif model_scope == JumpStartScriptScope.TRAINING: assert model_specs.training_artifact_key is not None model_artifact_key = model_specs.training_artifact_key @@ -238,17 +246,21 @@ def _retrieve_script_uri( script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model script S3 URI. - tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception - not raised). False if these models should raise an exception. - tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception - not raised). False if these models should raise an exception. + tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model + specifications should be tolerated (exception not raised). 
False or None, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. + tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model + specifications should be tolerated (exception not raised). False or None, raises + an exception if the version of the model is deprecated. Returns: str: the model script URI for the corresponding model. Raises: ValueError: If the combination of arguments specified is not supported. - VulnerableJumpStartModelError: If the model is vulnerable. - DeprecatedJumpStartModelError: If the model is deprecated. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if region is None: region = JUMPSTART_DEFAULT_REGION_NAME @@ -264,9 +276,9 @@ def _retrieve_script_uri( tolerate_deprecated_model=tolerate_deprecated_model, ) - if script_scope == JumpStartScriptScope.INFERENCE.value: + if script_scope == JumpStartScriptScope.INFERENCE: model_script_key = model_specs.hosting_script_key - elif script_scope == JumpStartScriptScope.TRAINING.value: + elif script_scope == JumpStartScriptScope.TRAINING: assert model_specs.training_script_key is not None model_script_key = model_specs.training_script_key diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 906f49d74c..5b35eee96a 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -164,19 +164,21 @@ def verify_model_region_and_return_specs( scope (Optional[str]): scope of the JumpStart model to verify. region (Optional[str]): region of the JumpStart model to verify and obtains specs. - tolerate_vulnerable_model (Optional[bool]): True if vulnerable models should be tolerated - (exception not raised). False if these models should raise an exception. - (Default: None). 
+ tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model + specifications should be tolerated (exception not raised). False or None, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: None). tolerate_deprecated_model (Optional[bool]): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: None). Raises: - ValueError: If the combination of arguments specified is not supported. NotImplementedError: If the scope is not supported. - VulnerableJumpStartModelError: If the model is vulnerable. - DeprecatedJumpStartModelError: If the model is deprecated. + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if tolerate_vulnerable_model is None: diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 1044b0d567..c37e0aa2e9 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -41,10 +41,13 @@ def retrieve( the model artifact S3 URI. model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". - tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception - not raised). False if these models should raise an exception. (Default: None). - tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception - not raised). False if these models should raise an exception. (Default: None). + tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model + specifications should be tolerated (exception not raised). 
False or None, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: None). + tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model + specifications should be tolerated (exception not raised). False or None, raises + an exception if the version of the model is deprecated. (Default: None). Returns: str: the model artifact S3 URI for the corresponding model. diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 03f4599fd2..24323222f3 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -41,10 +41,13 @@ def retrieve( model script S3 URI. script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". - tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception - not raised). False if these models should raise an exception. (Default: None). - tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception - not raised). False if these models should raise an exception. (Default: None). + tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model + specifications should be tolerated (exception not raised). False or None, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: None). + tolerate_deprecated_model (Optional[bool]): True if deprecated models should be tolerated + (exception not raised). False if these models should raise an exception. + (Default: None). Returns: str: the model script URI for the corresponding model. 
From f4711f878b63563d27d10b18a4b2fe26178184d8 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 21 Jan 2022 15:54:07 +0000 Subject: [PATCH 4/5] change: update raises section of retrieve docstrings --- src/sagemaker/image_uris.py | 4 ++++ src/sagemaker/model_uris.py | 4 ++++ src/sagemaker/script_uris.py | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index bc29bde7fe..d7b61587d1 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -93,7 +93,11 @@ def retrieve( str: the ECR URI for the corresponding SageMaker Docker image. Raises: + NotImplementedError: If the scope is not supported. ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if is_jumpstart_model_input(model_id, model_version): diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index c37e0aa2e9..19228f25f3 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -52,7 +52,11 @@ def retrieve( str: the model artifact S3 URI for the corresponding model. Raises: + NotImplementedError: If the scope is not supported. ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. 
""" if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 24323222f3..012f7667ee 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -52,7 +52,11 @@ def retrieve( str: the model script URI for the corresponding model. Raises: + NotImplementedError: If the scope is not supported. ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. """ if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") From 42aa79cbce17696c3144ce87f18b35e3d19cb427 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Sat, 22 Jan 2022 15:31:55 +0000 Subject: [PATCH 5/5] change: log warnings for tolerated vulnerabilities/deprecations, improve default parameter values --- src/sagemaker/environment_variables.py | 4 - src/sagemaker/image_uris.py | 16 ++-- src/sagemaker/jumpstart/accessors.py | 7 +- src/sagemaker/jumpstart/artifacts.py | 47 +++++------- src/sagemaker/jumpstart/cache.py | 18 ++--- src/sagemaker/jumpstart/utils.py | 77 ++++++++++---------- src/sagemaker/model_uris.py | 22 +++--- src/sagemaker/script_uris.py | 19 ++--- tests/unit/sagemaker/jumpstart/test_utils.py | 72 +++++++++++------- 9 files changed, 131 insertions(+), 151 deletions(-) diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index d4646d2617..108dda4209 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -46,8 +46,4 @@ def retrieve_default( if not jumpstart_utils.is_jumpstart_model_input(model_id, 
model_version): raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") - # mypy type checking require these assertions - assert model_id is not None - assert model_version is not None - return artifacts._retrieve_default_environment_variables(model_id, model_version, region) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index d7b61587d1..fa4fd782d3 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -45,8 +45,8 @@ def retrieve( training_compiler_config=None, model_id=None, model_version=None, - tolerate_vulnerable_model=None, - tolerate_deprecated_model=None, + tolerate_vulnerable_model=False, + tolerate_deprecated_model=False, ) -> str: """Retrieves the ECR URI for the Docker image matching the given arguments. @@ -82,12 +82,12 @@ def retrieve( model_version (str): Version of the JumpStart model for which to retrieve the image URI (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications - should be tolerated (exception not raised). False or None, raises an exception if + should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known security - vulnerabilities. (Default: None). + vulnerabilities. (Default: False). tolerate_deprecated_model (bool): True if deprecated versions of model specifications - should be tolerated (exception not raised). False or None, raises an exception - if the version of the model is deprecated. (Default: None). + should be tolerated (exception not raised). If False, raises an exception + if the version of the model is deprecated. (Default: False). Returns: str: the ECR URI for the corresponding SageMaker Docker image. 
@@ -101,10 +101,6 @@ def retrieve( """ if is_jumpstart_model_input(model_id, model_version): - # adding assert statements to satisfy mypy type checker - assert model_id is not None - assert model_version is not None - return artifacts._retrieve_image_uri( model_id, model_version, diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index d666824849..e297358251 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -56,7 +56,6 @@ def _validate_and_mutate_region_cache_kwargs( region (str): The region to validate along with the kwargs. """ cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs - assert isinstance(cache_kwargs_dict, dict) if region is not None and "region" in cache_kwargs_dict: if region != cache_kwargs_dict["region"]: raise ValueError( @@ -92,8 +91,7 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel JumpStartModelsAccessor._cache_kwargs, region ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) - assert JumpStartModelsAccessor._cache is not None - return JumpStartModelsAccessor._cache.get_header( + return JumpStartModelsAccessor._cache.get_header( # type: ignore model_id=model_id, semantic_version_str=version ) @@ -110,8 +108,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS JumpStartModelsAccessor._cache_kwargs, region ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) - assert JumpStartModelsAccessor._cache is not None - return JumpStartModelsAccessor._cache.get_specs( + return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, semantic_version_str=version ) diff --git a/src/sagemaker/jumpstart/artifacts.py b/src/sagemaker/jumpstart/artifacts.py index cb9e3515b9..7c9b835b3c 100644 --- a/src/sagemaker/jumpstart/artifacts.py +++ b/src/sagemaker/jumpstart/artifacts.py @@ -41,8 +41,8 @@ def _retrieve_image_uri( distribution: 
Optional[str], base_framework_version: Optional[str], training_compiler_config: Optional[str], - tolerate_vulnerable_model: Optional[bool], - tolerate_deprecated_model: Optional[bool], + tolerate_vulnerable_model: bool, + tolerate_deprecated_model: bool, ): """Retrieves the container image URI for JumpStart models. @@ -75,12 +75,12 @@ def _retrieve_image_uri( distribution (dict): A dictionary with information on how to run distributed training training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): A configuration class for the SageMaker Training Compiler. - tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model - specifications should be tolerated (exception not raised). False or None, raises an + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known security vulnerabilities. - tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model - specifications should be tolerated (exception not raised). False or None, raises + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. 
Returns: @@ -95,8 +95,6 @@ def _retrieve_image_uri( if region is None: region = JUMPSTART_DEFAULT_REGION_NAME - assert region is not None - model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, @@ -109,7 +107,6 @@ def _retrieve_image_uri( if image_scope == JumpStartScriptScope.INFERENCE: ecr_specs = model_specs.hosting_ecr_specs elif image_scope == JumpStartScriptScope.TRAINING: - assert model_specs.training_ecr_specs is not None ecr_specs = model_specs.training_ecr_specs if framework is not None and framework != ecr_specs.framework: @@ -172,8 +169,8 @@ def _retrieve_model_uri( model_version: str, model_scope: Optional[str], region: Optional[str], - tolerate_vulnerable_model: Optional[bool], - tolerate_deprecated_model: Optional[bool], + tolerate_vulnerable_model: bool, + tolerate_deprecated_model: bool, ): """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -185,12 +182,12 @@ def _retrieve_model_uri( model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model S3 URI. - tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model - specifications should be tolerated (exception not raised). False or None, raises an + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known security vulnerabilities. - tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model - specifications should be tolerated (exception not raised). False or None, raises + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. 
Returns: str: the model artifact S3 URI for the corresponding model. @@ -204,8 +201,6 @@ def _retrieve_model_uri( if region is None: region = JUMPSTART_DEFAULT_REGION_NAME - assert region is not None - model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, @@ -218,7 +213,6 @@ def _retrieve_model_uri( if model_scope == JumpStartScriptScope.INFERENCE: model_artifact_key = model_specs.hosting_artifact_key elif model_scope == JumpStartScriptScope.TRAINING: - assert model_specs.training_artifact_key is not None model_artifact_key = model_specs.training_artifact_key bucket = get_jumpstart_content_bucket(region) @@ -233,8 +227,8 @@ def _retrieve_script_uri( model_version: str, script_scope: Optional[str], region: Optional[str], - tolerate_vulnerable_model: Optional[bool], - tolerate_deprecated_model: Optional[bool], + tolerate_vulnerable_model: bool, + tolerate_deprecated_model: bool, ): """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -246,12 +240,12 @@ def _retrieve_script_uri( script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model script S3 URI. - tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model - specifications should be tolerated (exception not raised). False or None, raises an + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known security vulnerabilities. - tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model - specifications should be tolerated (exception not raised). False or None, raises + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). 
If False, raises an exception if the version of the model is deprecated. Returns: str: the model script URI for the corresponding model. @@ -265,8 +259,6 @@ def _retrieve_script_uri( if region is None: region = JUMPSTART_DEFAULT_REGION_NAME - assert region is not None - model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, @@ -279,7 +271,6 @@ def _retrieve_script_uri( if script_scope == JumpStartScriptScope.INFERENCE: model_script_key = model_specs.hosting_script_key elif script_scope == JumpStartScriptScope.TRAINING: - assert model_specs.training_script_key is not None model_script_key = model_specs.training_script_key bucket = get_jumpstart_content_bucket(region) @@ -317,8 +308,6 @@ def _retrieve_default_hyperparameters( if region is None: region = JUMPSTART_DEFAULT_REGION_NAME - assert region is not None - model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs( region=region, model_id=model_id, version=model_version ) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index fbd711ddf7..26284419de 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -166,13 +166,12 @@ def _get_manifest_key_from_model_id_semantic_version( manifest = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) ).formatted_content - assert isinstance(manifest, dict) sm_version = utils.get_sagemaker_version() versions_compatible_with_sagemaker = [ Version(header.version) - for header in manifest.values() + for header in manifest.values() # type: ignore if header.model_id == model_id and Version(header.min_version) <= Version(sm_version) ] @@ -184,7 +183,8 @@ def _get_manifest_key_from_model_id_semantic_version( return JumpStartVersionedModelId(model_id, sm_compatible_model_version) versions_incompatible_with_sagemaker = [ - Version(header.version) for header in manifest.values() if header.model_id == model_id + 
Version(header.version) for header in manifest.values() # type: ignore + if header.model_id == model_id ] sm_incompatible_model_version = self._select_version( version, versions_incompatible_with_sagemaker @@ -194,7 +194,7 @@ def _get_manifest_key_from_model_id_semantic_version( model_version_to_use_incompatible_with_sagemaker = sm_incompatible_model_version sm_version_to_use_list = [ header.min_version - for header in manifest.values() + for header in manifest.values() # type: ignore if header.model_id == model_id and header.version == model_version_to_use_incompatible_with_sagemaker ] @@ -262,8 +262,7 @@ def get_manifest(self) -> List[JumpStartModelHeader]: manifest_dict = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) ).formatted_content - assert isinstance(manifest_dict, dict) - manifest = list(manifest_dict.values()) + manifest = list(manifest_dict.values()) # type: ignore return manifest def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader: @@ -324,9 +323,7 @@ def _get_header_impl( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) ).formatted_content try: - assert isinstance(manifest, dict) - header = manifest[versioned_model_id] - assert isinstance(header, JumpStartModelHeader) + header = manifest[versioned_model_id] # type: ignore return header except KeyError: if attempt > 0: @@ -348,8 +345,7 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS specs = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key) ).formatted_content - assert isinstance(specs, JumpStartModelSpecs) - return specs + return specs # type: ignore def clear(self) -> None: """Clears the model id/version and s3 cache.""" diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 5b35eee96a..3d87ade3c1 100644 --- a/src/sagemaker/jumpstart/utils.py +++ 
b/src/sagemaker/jumpstart/utils.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. """This module contains utilities related to SageMaker JumpStart.""" from __future__ import absolute_import +import logging from typing import Dict, List, Optional from packaging.version import Version import sagemaker @@ -28,6 +29,9 @@ ) +LOGGER = logging.getLogger(__name__) + + def get_jumpstart_launched_regions_message() -> str: """Returns formatted string indicating where JumpStart is launched.""" if len(constants.JUMPSTART_REGION_NAME_SET) == 0: @@ -151,8 +155,8 @@ def verify_model_region_and_return_specs( version: Optional[str], scope: Optional[str], region: str, - tolerate_vulnerable_model: Optional[bool] = None, - tolerate_deprecated_model: Optional[bool] = None, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. @@ -164,13 +168,13 @@ def verify_model_region_and_return_specs( scope (Optional[str]): scope of the JumpStart model to verify. region (Optional[str]): region of the JumpStart model to verify and obtains specs. - tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model - specifications should be tolerated (exception not raised). False or None, raises an + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known - security vulnerabilities. (Default: None). - tolerate_deprecated_model (Optional[bool]): True if deprecated models should be tolerated + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. - (Default: None). 
+ (Default: False). Raises: @@ -181,12 +185,6 @@ def verify_model_region_and_return_specs( DeprecatedJumpStartModelError: If the version of the model is deprecated. """ - if tolerate_vulnerable_model is None: - tolerate_vulnerable_model = False - - if tolerate_deprecated_model is None: - tolerate_deprecated_model = False - if scope is None: raise ValueError( "Must specify `model_scope` argument to retrieve model " @@ -199,11 +197,8 @@ def verify_model_region_and_return_specs( f"{', '.join(constants.SUPPORTED_JUMPSTART_SCOPES)}." ) - assert model_id is not None - assert version is not None - model_specs = accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=version + region=region, model_id=model_id, version=version # type: ignore ) if ( @@ -214,31 +209,33 @@ def verify_model_region_and_return_specs( f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training." ) - if model_specs.deprecated and not tolerate_deprecated_model: - raise DeprecatedJumpStartModelError(model_id=model_id, version=version) - - if ( - scope == constants.JumpStartScriptScope.INFERENCE.value - and model_specs.inference_vulnerable - and not tolerate_vulnerable_model - ): - raise VulnerableJumpStartModelError( - model_id=model_id, - version=version, - vulnerabilities=model_specs.inference_vulnerabilities, - scope=constants.JumpStartScriptScope.INFERENCE, + if model_specs.deprecated: + if not tolerate_deprecated_model: + raise DeprecatedJumpStartModelError(model_id=model_id, version=version) + LOGGER.warning("Using deprecated JumpStart model '%s' and version '%s'.", model_id, version) + + if scope == constants.JumpStartScriptScope.INFERENCE.value and model_specs.inference_vulnerable: + if not tolerate_vulnerable_model: + raise VulnerableJumpStartModelError( + model_id=model_id, + version=version, + vulnerabilities=model_specs.inference_vulnerabilities, + scope=constants.JumpStartScriptScope.INFERENCE, + ) + LOGGER.warning( + 
"Using vulnerable JumpStart model '%s' and version '%s' (inference).", model_id, version ) - if ( - scope == constants.JumpStartScriptScope.TRAINING.value - and model_specs.training_vulnerable - and not tolerate_vulnerable_model - ): - raise VulnerableJumpStartModelError( - model_id=model_id, - version=version, - vulnerabilities=model_specs.training_vulnerabilities, - scope=constants.JumpStartScriptScope.TRAINING, + if scope == constants.JumpStartScriptScope.TRAINING.value and model_specs.training_vulnerable: + if not tolerate_vulnerable_model: + raise VulnerableJumpStartModelError( + model_id=model_id, + version=version, + vulnerabilities=model_specs.training_vulnerabilities, + scope=constants.JumpStartScriptScope.TRAINING, + ) + LOGGER.warning( + "Using vulnerable JumpStart model '%s' and version '%s' (training).", model_id, version ) return model_specs diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 19228f25f3..8894583f89 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -28,8 +28,8 @@ def retrieve( model_id=None, model_version: Optional[str] = None, model_scope: Optional[str] = None, - tolerate_vulnerable_model: Optional[bool] = None, - tolerate_deprecated_model: Optional[bool] = None, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, ) -> str: """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -41,13 +41,13 @@ def retrieve( the model artifact S3 URI. model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". - tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model - specifications should be tolerated (exception not raised). False or None, raises an + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known - security vulnerabilities. (Default: None). - tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model - specifications should be tolerated (exception not raised). False or None, raises - an exception if the version of the model is deprecated. (Default: None). + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated versions of model + specifications should be tolerated (exception not raised). If False, raises + an exception if the version of the model is deprecated. (Default: False). Returns: str: the model artifact S3 URI for the corresponding model. @@ -61,13 +61,9 @@ def retrieve( if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") - # mypy type checking require these assertions - assert model_id is not None - assert model_version is not None - return artifacts._retrieve_model_uri( model_id, - model_version, + model_version, # type: ignore model_scope, region, tolerate_vulnerable_model, diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 012f7667ee..77fda3ce26 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -15,7 +15,6 @@ from __future__ import absolute_import import logging -from typing import Optional from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts @@ -28,8 +27,8 @@ def retrieve( model_id=None, model_version=None, script_scope=None, - tolerate_vulnerable_model: Optional[bool] = None, - tolerate_deprecated_model: Optional[bool] = None, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, ) -> str: """Retrieves the script S3 URI associated with the model matching the given arguments. 
@@ -41,13 +40,13 @@ def retrieve( model script S3 URI. script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". - tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model - specifications should be tolerated (exception not raised). False or None, raises an + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known - security vulnerabilities. (Default: None). - tolerate_deprecated_model (Optional[bool]): True if deprecated models should be tolerated + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. - (Default: None). + (Default: False). Returns: str: the model script URI for the corresponding model. 
@@ -61,10 +60,6 @@ def retrieve( if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.") - # mypy type checking require these assertions - assert model_id is not None - assert model_version is not None - return artifacts._retrieve_script_uri( model_id, model_version, diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 388cc23d48..4401513031 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -143,16 +143,22 @@ def make_vulnerable_inference_spec(*largs, **kwargs): "List of vulnerabilities: some, vulnerability" ) == str(e.value.message) - assert ( - utils.verify_model_region_and_return_specs( - model_id="pytorch-eqa-bert-base-cased", - version="*", - scope=JumpStartScriptScope.INFERENCE.value, - region="us-west-2", - tolerate_vulnerable_model=True, + with patch("logging.Logger.warning") as mocked_warning_log: + assert ( + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.INFERENCE.value, + region="us-west-2", + tolerate_vulnerable_model=True, + ) + is not None + ) + mocked_warning_log.assert_called_once_with( + "Using vulnerable JumpStart model '%s' and version '%s' (inference).", + "pytorch-eqa-bert-base-cased", + "*", ) - is not None - ) def make_vulnerable_training_spec(*largs, **kwargs): spec = get_spec_from_base_spec(*largs, **kwargs) @@ -176,16 +182,22 @@ def make_vulnerable_training_spec(*largs, **kwargs): "List of vulnerabilities: some, vulnerability" ) == str(e.value.message) - assert ( - utils.verify_model_region_and_return_specs( - model_id="pytorch-eqa-bert-base-cased", - version="*", - scope=JumpStartScriptScope.TRAINING.value, - region="us-west-2", - tolerate_vulnerable_model=True, + with patch("logging.Logger.warning") as mocked_warning_log: + 
assert ( + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.TRAINING.value, + region="us-west-2", + tolerate_vulnerable_model=True, + ) + is not None + ) + mocked_warning_log.assert_called_once_with( + "Using vulnerable JumpStart model '%s' and version '%s' (training).", + "pytorch-eqa-bert-base-cased", + "*", ) - is not None - ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -207,13 +219,19 @@ def make_deprecated_spec(*largs, **kwargs): assert "Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' is deprecated. " "Please try targetting a higher version of the model." == str(e.value.message) - assert ( - utils.verify_model_region_and_return_specs( - model_id="pytorch-eqa-bert-base-cased", - version="*", - scope=JumpStartScriptScope.INFERENCE.value, - region="us-west-2", - tolerate_deprecated_model=True, + with patch("logging.Logger.warning") as mocked_warning_log: + assert ( + utils.verify_model_region_and_return_specs( + model_id="pytorch-eqa-bert-base-cased", + version="*", + scope=JumpStartScriptScope.INFERENCE.value, + region="us-west-2", + tolerate_deprecated_model=True, + ) + is not None + ) + mocked_warning_log.assert_called_once_with( + "Using deprecated JumpStart model '%s' and version '%s'.", + "pytorch-eqa-bert-base-cased", + "*", ) - is not None - )