From 3712df6860a4cd879e51382f99a01c78f1c126ca Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 10 Nov 2021 19:28:42 +0000 Subject: [PATCH 01/13] feat: client cache for jumpstart models --- setup.py | 1 + src/sagemaker/jumpstart/__init__.py | 0 src/sagemaker/jumpstart/cache.py | 280 ++++++++++ src/sagemaker/jumpstart/constants.py | 20 + src/sagemaker/jumpstart/types.py | 184 +++++++ src/sagemaker/jumpstart/utils.py | 100 ++++ src/sagemaker/utilities/__init__.py | 0 src/sagemaker/utilities/cache.py | 150 ++++++ tests/unit/sagemaker/jumpstart/__init__.py | 0 tests/unit/sagemaker/jumpstart/test_cache.py | 540 +++++++++++++++++++ tests/unit/sagemaker/jumpstart/test_types.py | 122 +++++ tests/unit/sagemaker/jumpstart/test_utils.py | 103 ++++ tests/unit/sagemaker/utilities/__init__.py | 0 tests/unit/sagemaker/utilities/test_cache.py | 163 ++++++ 14 files changed, 1663 insertions(+) create mode 100644 src/sagemaker/jumpstart/__init__.py create mode 100644 src/sagemaker/jumpstart/cache.py create mode 100644 src/sagemaker/jumpstart/constants.py create mode 100644 src/sagemaker/jumpstart/types.py create mode 100644 src/sagemaker/jumpstart/utils.py create mode 100644 src/sagemaker/utilities/__init__.py create mode 100644 src/sagemaker/utilities/cache.py create mode 100644 tests/unit/sagemaker/jumpstart/__init__.py create mode 100644 tests/unit/sagemaker/jumpstart/test_cache.py create mode 100644 tests/unit/sagemaker/jumpstart/test_types.py create mode 100644 tests/unit/sagemaker/jumpstart/test_utils.py create mode 100644 tests/unit/sagemaker/utilities/__init__.py create mode 100644 tests/unit/sagemaker/utilities/test_cache.py diff --git a/setup.py b/setup.py index 83295fba71..0d2c770470 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ def read_version(): "packaging>=20.0", "pandas", "pathos", + "semantic-version", ] # Specific use case dependencies diff --git a/src/sagemaker/jumpstart/__init__.py b/src/sagemaker/jumpstart/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py new file mode 100644 index 0000000000..fabae85ea7 --- /dev/null +++ b/src/sagemaker/jumpstart/cache.py @@ -0,0 +1,280 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
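For orientation, the snippet below is an editorial sketch (not part of this patch) of how the cache introduced in src/sagemaker/jumpstart/cache.py is intended to be used; the bucket name and model id are hypothetical, and the calls assume S3 access to a JumpStart content bucket.

    from sagemaker.jumpstart.cache import JumpStartModelsCache

    # Build a cache for one region; passing ``bucket`` explicitly avoids the
    # region-to-bucket lookup, which needs a launched-region entry in constants.py.
    cache = JumpStartModelsCache(region="us-west-2", bucket="some-jumpstart-bucket")

    # Resolve the newest model version compatible with the installed SageMaker SDK
    # and return its manifest header (model_id, version, min_version, spec_key).
    header = cache.get_header(model_id="pytorch-ic-mobilenet-v2", semantic_version="1.*")

    # Fetch the full specs file referenced by that header from S3.
    specs = cache.get_specs(model_id="pytorch-ic-mobilenet-v2", semantic_version="1.*")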
+import datetime +from typing import List, Optional +from sagemaker.jumpstart.types import ( + JumpStartCachedS3ContentKey, + JumpStartCachedS3ContentValue, + JumpStartModelHeader, + JumpStartModelSpecs, + JumpStartModelSpecs, + JumpStartS3FileType, + JumpStartVersionedModelId, +) +from sagemaker.jumpstart import utils +from sagemaker.utilities.cache import LRUCache +import boto3 +import json +import semantic_version + + +DEFAULT_REGION_NAME = boto3.session.Session().region_name + +DEFAULT_MAX_S3_CACHE_ITEMS = 20 +DEFAULT_S3_CACHE_EXPIRATION_TIME = datetime.timedelta(hours=6) + +DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS = 20 +DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_TIME = datetime.timedelta(hours=6) + +DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" + + +class JumpStartModelsCache: + """Class that implements a cache for JumpStart models manifests and specs. + The manifest and specs associated with JumpStart models provide the information necessary + for launching JumpStart models from the SageMaker SDK. + """ + + def __init__( + self, + region: Optional[str] = DEFAULT_REGION_NAME, + max_s3_cache_items: Optional[int] = DEFAULT_MAX_S3_CACHE_ITEMS, + s3_cache_expiration_time: Optional[datetime.timedelta] = DEFAULT_S3_CACHE_EXPIRATION_TIME, + max_semantic_version_cache_items: Optional[int] = DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, + semantic_version_cache_expiration_time: Optional[ + datetime.timedelta + ] = DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_TIME, + manifest_file_s3_key: Optional[str] = DEFAULT_MANIFEST_FILE_S3_KEY, + bucket: Optional[str] = None, + ) -> None: + """Initialize a ``JumpStartModelsCache`` instance. + + Args: + region (Optional[str]): AWS region to associate with cache. Default: region associated + with botocore session. + max_s3_cache_items (Optional[int]): Maximum number of files to store in s3 cache. Default: 20. + s3_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold items in s3 + cache before invalidation. Default: 6 hours. + max_semantic_version_cache_items (Optional[int]): Maximum number of files to store in + semantic version cache. Default: 20. + semantic_version_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold + items in semantic version cache before invalidation. Default: 6 hours. + bucket (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted content + bucket for region. + """ + + self._region = region + self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( + max_cache_items=max_s3_cache_items, + expiration_time=s3_cache_expiration_time, + retrieval_function=self._get_file_from_s3, + ) + self._model_id_semantic_version_manifest_key_cache = LRUCache[ + JumpStartVersionedModelId, JumpStartVersionedModelId + ]( + max_cache_items=max_semantic_version_cache_items, + expiration_time=semantic_version_cache_expiration_time, + retrieval_function=self._get_manifest_key_from_model_id_semantic_version, + ) + self._manifest_file_s3_key = manifest_file_s3_key + self._bucket = ( + utils.get_jumpstart_content_bucket(self._region) if bucket is None else bucket + ) + self._has_retried_cache_refresh = False + + def set_region(self, region: str) -> None: + """Set region for cache. Clears cache after new region is set.""" + self._region = region + self.clear() + + def get_region(self) -> str: + """Return region for cache.""" + return self._region + + def set_manifest_file_s3_key(self, key: str) -> None: + """Set manifest file s3 key. 
Clears cache after new key is set.""" + self._manifest_file_s3_key = key + self.clear() + + def get_manifest_file_s3_key(self) -> None: + """Return manifest file s3 key for cache.""" + return self._manifest_file_s3_key + + def set_bucket(self, bucket: str) -> None: + """Set s3 bucket used for cache.""" + self._bucket = bucket + self.clear() + + def get_bucket(self) -> None: + """Return bucket used for cache.""" + return self._bucket + + def _get_manifest_key_from_model_id_semantic_version( + self, key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId] + ) -> JumpStartVersionedModelId: + """Return model id and version in manifest that matches semantic version/id + from customer request. + + Args: + key (JumpStartVersionedModelId): Key for which to fetch versioned model id. + value (Optional[JumpStartVersionedModelId]): Unused variable for current value of old cached + model id/version. + + Raises: + KeyError: If the semantic version is not found in the manifest. + """ + + model_id, version = key.model_id, key.version + + manifest = self._s3_cache.get( + JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + ).formatted_file_content + + sm_version = utils.get_sagemaker_version() + + versions_compatible_with_sagemaker = [ + semantic_version.Version(header.version) + for _, header in manifest.items() + if header.model_id == model_id + and semantic_version.Version(header.min_version) <= semantic_version.Version(sm_version) + ] + + spec = ( + semantic_version.SimpleSpec("*") + if version is None + else semantic_version.SimpleSpec(version) + ) + + sm_compatible_model_version = spec.select(versions_compatible_with_sagemaker) + if sm_compatible_model_version is not None: + return JumpStartVersionedModelId(model_id, str(sm_compatible_model_version)) + else: + versions_incompatible_with_sagemaker = [ + semantic_version.Version(header.version) + for _, header in manifest.items() + if header.model_id == model_id + ] + sm_incompatible_model_version = spec.select(versions_incompatible_with_sagemaker) + if sm_incompatible_model_version is not None: + model_version_to_use_incompatible_with_sagemaker = str( + sm_incompatible_model_version + ) + sm_version_to_use = [ + header.min_version + for _, header in manifest.items() + if header.model_id == model_id + and header.version == model_version_to_use_incompatible_with_sagemaker + ] + assert len(sm_version_to_use) == 1 + sm_version_to_use = sm_version_to_use[0] + + error_msg = ( + f"Unable to find model manifest for {model_id} with version {version} compatible with your SageMaker version ({sm_version}). " + f"Consider upgrading your SageMaker library to at least version {sm_version_to_use} so you can use version " + f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}." + ) + raise KeyError(error_msg) + else: + error_msg = f"Unable to find model manifest for {model_id} with version {version}" + raise KeyError(error_msg) + + def _get_file_from_s3( + self, + key: JumpStartCachedS3ContentKey, + value: Optional[JumpStartCachedS3ContentValue], + ) -> JumpStartCachedS3ContentValue: + """Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``. + If a manifest file is being fetched, we only download the object if the md5 hash in + ``head_object`` does not match the current md5 hash for the stored value. This prevents + unnecessarily downloading the full manifest when it hasn't changed. + + Args: + key (JumpStartCachedS3ContentKey): key for which to fetch s3 content. 
+ value (Optional[JumpStartVersionedModelId]): Current value of old cached + s3 content. This is used for the manifest file, so that it is only + downloaded when its content changes. + """ + + file_type, s3_key = key.file_type, key.s3_key + + s3_client = boto3.client("s3", region_name=self._region) + + if file_type == JumpStartS3FileType.MANIFEST: + etag = s3_client.head_object(Bucket=self._bucket, Key=s3_key)["ETag"] + if value is not None and etag == value.md5_hash: + return value + response = s3_client.get_object(Bucket=self._bucket, Key=s3_key) + formatted_body = json.loads(response["Body"].read().decode("utf-8")) + return JumpStartCachedS3ContentValue( + formatted_file_content=utils.get_formatted_manifest(formatted_body), + md5_hash=etag, + ) + if file_type == JumpStartS3FileType.SPECS: + response = s3_client.get_object(Bucket=self._bucket, Key=s3_key) + formatted_body = json.loads(response["Body"].read().decode("utf-8")) + return JumpStartCachedS3ContentValue( + formatted_file_content=JumpStartModelSpecs(formatted_body) + ) + raise RuntimeError(f"Bad value for key: {key}") + + def get_header( + self, model_id: str, semantic_version: Optional[str] = None + ) -> List[JumpStartModelHeader]: + """Return list of headers for a given JumpStart model id and semantic version. + + Args: + model_id (str): model id for which to get a header. + semantic_version (Optional[str]): The semantic version for which to get a header. + If None, the highest compatible version is returned. + """ + + versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get( + JumpStartVersionedModelId(model_id, semantic_version) + ) + manifest = self._s3_cache.get( + JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + ).formatted_file_content + try: + header = manifest[versioned_model_id] + if self._has_retried_cache_refresh: + self._has_retried_cache_refresh = False + return header + except KeyError: + if self._has_retried_cache_refresh: + self._has_retried_cache_refresh = False + raise + self.clear() + self._has_retried_cache_refresh = True + return self.get_header(model_id, semantic_version) + + def get_specs( + self, model_id: str, semantic_version: Optional[str] = None + ) -> JumpStartModelSpecs: + """Return specs for a given JumpStart model id and semantic version. + + Args: + model_id (str): model id for which to get specs. + semantic_version (Optional[str]): The semantic version for which to get specs. + If None, the highest compatible version is returned. + """ + header = self.get_header(model_id, semantic_version) + spec_key = header.spec_key + return self._s3_cache.get( + JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key) + ).formatted_file_content + + def clear(self) -> None: + """Clears the model id/version and s3 cache and resets ``_has_retried_cache_refresh``.""" + self._s3_cache.clear() + self._model_id_semantic_version_manifest_key_cache.clear() + self._has_retried_cache_refresh = False diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py new file mode 100644 index 0000000000..339888e7c0 --- /dev/null +++ b/src/sagemaker/jumpstart/constants.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from typing import Set +from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo + + +LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set() + +REGION_NAME_TO_LAUNCHED_REGION_DICT = {region.region_name: region for region in LAUNCHED_REGIONS} +REGION_NAME_SET = {region.region_name for region in LAUNCHED_REGIONS} diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py new file mode 100644 index 0000000000..c9ab5f50de --- /dev/null +++ b/src/sagemaker/jumpstart/types.py @@ -0,0 +1,184 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from enum import Enum +from typing import Any, Dict, List, Optional, Union + + +class JumpStartDataHolderType: + """Base class for many JumpStart types. Allows objects to be added to dicts and sets, + and improves string representation. This class allows different objects with the same + attributes and types to have equality. + """ + + def __eq__(self, other: Any) -> bool: + """Returns True if other object is of the same type + and has all attributes equal.""" + if not isinstance(other, type(self)): + return False + for attribute in self.__slots__: + if getattr(self, attribute) != getattr(other, attribute): + return False + return True + + def __hash__(self) -> int: + """Makes hash of object by first mapping to unique tuple, which then + gets hashed. + """ + return hash((type(self),) + tuple([getattr(self, att) for att in self.__slots__])) + + def __str__(self) -> str: + """Returns string representation of object. Example: + "JumpStartLaunchedRegionInfo: {'content_bucket': 'jumpstart-bucket-us-west-2', 'region_name': 'us-west-2'}" + """ + att_dict = {att: getattr(self, att) for att in self.__slots__} + return f"{type(self).__name__}: {str(att_dict)}" + + def __repr__(self) -> str: + """This is often called instead of __str__ and is the official string representation + of an object, typicaly used for debugging. 
Example: + "JumpStartLaunchedRegionInfo at 0x7f664529efa0: {'content_bucket': 'jumpstart-bucket-us-west-2', 'region_name': 'us-west-2'}" + """ + att_dict = {att: getattr(self, att) for att in self.__slots__} + return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}" + + +class JumpStartS3FileType(str, Enum): + """Simple enum for classifying S3 file type.""" + + MANIFEST = "manifest" + SPECS = "specs" + + +class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): + """Data class for launched region info.""" + + __slots__ = ["content_bucket", "region_name"] + + def __init__(self, content_bucket: str, region_name: str): + self.content_bucket = content_bucket + self.region_name = region_name + + +class JumpStartModelHeader(JumpStartDataHolderType): + """Data class JumpStart model header.""" + + __slots__ = ["model_id", "version", "min_version", "spec_key"] + + def __init__(self, header: Dict[str, str]): + """Initializes a JumpStartModelHeader object from its json representation.""" + self.from_json(header) + + def to_json(self) -> Dict[str, str]: + """Returns json representation of JumpStartModelHeader object.""" + json_obj = {att: getattr(self, att) for att in self.__slots__} + return json_obj + + def from_json(self, json_obj: Dict[str, str]) -> None: + """Sets fields in object based on json of header.""" + self.model_id: str = json_obj["model_id"] + self.version: str = json_obj["version"] + self.min_version: str = json_obj["min_version"] + self.spec_key: str = json_obj["spec_key"] + + +class JumpStartModelSpecs(JumpStartDataHolderType): + """Data class JumpStart model specs.""" + + __slots__ = [ + "model_id", + "version", + "min_sdk_version", + "incremental_training_supported", + "hosting_ecr_specs", + "hosting_artifact_uri", + "hosting_script_uri", + "training_supported", + "training_ecr_specs", + "training_artifact_uri", + "training_script_uri", + "hyperparameters", + ] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a JumpStartModelSpecs object from its json representation.""" + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json of header.""" + self.model_id: str = json_obj["model_id"] + self.version: str = json_obj["version"] + self.min_sdk_version: str = json_obj["min_sdk_version"] + self.incremental_training_supported: bool = bool(json_obj["incremental_training_supported"]) + self.hosting_ecr_specs: dict = json_obj["hosting_ecr_specs"] + self.hosting_artifact_uri: str = json_obj["hosting_artifact_uri"] + self.hosting_script_uri: str = json_obj["hosting_script_uri"] + self.training_supported: bool = bool(json_obj["training_supported"]) + if self.training_supported: + self.training_ecr_specs: Optional[dict] = json_obj["training_ecr_specs"] + self.training_artifact_uri: Optional[str] = json_obj["training_artifact_uri"] + self.training_script_uri: Optional[str] = json_obj["training_script_uri"] + self.hyperparameters: Optional[dict] = json_obj["hyperparameters"] + else: + self.training_ecr_specs = ( + self.training_artifact_uri + ) = self.training_script_uri = self.hyperparameters = None + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of JumpStartModelSpecs object.""" + json_obj = {att: getattr(self, att) for att in self.__slots__} + return json_obj + + +class JumpStartVersionedModelId(JumpStartDataHolderType): + """Data class for versioned model ids.""" + + __slots__ = ["model_id", "version"] + + def __init__( + self, + model_id: str, + version: str, + ) -> None: + 
self.model_id = model_id + self.version = version + + +class JumpStartCachedS3ContentKey(JumpStartDataHolderType): + """Data class for the s3 cached content keys.""" + + __slots__ = ["file_type", "s3_key"] + + def __init__( + self, + file_type: JumpStartS3FileType, + s3_key: str, + ) -> None: + self.file_type = file_type + self.s3_key = s3_key + + +class JumpStartCachedS3ContentValue(JumpStartDataHolderType): + """Data class for the s3 cached content values.""" + + __slots__ = ["formatted_file_content", "md5_hash"] + + def __init__( + self, + formatted_file_content: Union[ + Dict[JumpStartVersionedModelId, JumpStartModelHeader], + List[JumpStartModelSpecs], + ], + md5_hash: Optional[str] = None, + ) -> None: + self.formatted_file_content = formatted_file_content + self.md5_hash = md5_hash diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py new file mode 100644 index 0000000000..b279ad03e6 --- /dev/null +++ b/src/sagemaker/jumpstart/utils.py @@ -0,0 +1,100 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from typing import Dict, List + +import semantic_version +from sagemaker.jumpstart import constants +import sagemaker +from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId + + +def get_jumpstart_launched_regions_string() -> str: + """Returns formatted string indicating where JumpStart is launched.""" + if len(constants.REGION_NAME_SET) == 0: + return "JumpStart is not available in any region." + if len(constants.REGION_NAME_SET) == 1: + region = list(constants.REGION_NAME_SET)[0] + return f"JumpStart is available in {region} region." + + sorted_regions = sorted(list(constants.REGION_NAME_SET)) + if len(constants.REGION_NAME_SET) == 2: + return f"JumpStart is available in {sorted_regions[0]} and {sorted_regions[1]} regions." + + formatted_launched_regions_list = [] + for i in range(len(sorted_regions)): + region_prefix = "" + if i == len(sorted_regions) - 1: + region_prefix = "and " + formatted_launched_regions_list.append(region_prefix + sorted_regions[i]) + formatted_launched_regions_str = ", ".join(formatted_launched_regions_list) + return f"JumpStart is available in {formatted_launched_regions_str} regions." + + +def get_jumpstart_content_bucket(region: str) -> str: + """Returns regionalized content bucket name for JumpStart. + + Raises: + RuntimeError: If JumpStart is not launched in ``region``. + """ + try: + return constants.REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket + except KeyError: + formatted_launched_regions_str = get_jumpstart_launched_regions_string() + raise RuntimeError( + f"Unable to get content bucket for JumpStart in {region} region. " + f"{formatted_launched_regions_str}" + ) + + +def get_formatted_manifest( + manifest: List[Dict], +) -> Dict[JumpStartVersionedModelId, JumpStartModelHeader]: + """Returns formatted manifest dictionary from raw manifest. 
Keys are JumpStartVersionedModelId objects, + values are JumpStartModelHeader objects.""" + manifest_dict = {} + for header in manifest: + header_obj = JumpStartModelHeader(header) + manifest_dict[ + JumpStartVersionedModelId(header_obj.model_id, header_obj.version) + ] = header_obj + return manifest_dict + + +def get_sagemaker_version() -> str: + """Returns sagemaker library version by reading __version__ variable + in module. In order to maintain compatibility with the ``semantic_version`` + library, versions with fewer than 2, or more than 3, periods are rejected. + All versions that cannot be parsed with ``semantic_version`` are also + rejected. + + Raises: + RuntimeError: If the SageMaker version is not readable. + """ + version = sagemaker.__version__ + parsed_version = None + + num_periods = version.count(".") + if num_periods == 2: + parsed_version = version + elif num_periods == 3: + trailing_period_index = version.rfind(".") + parsed_version = version[:trailing_period_index] + else: + raise RuntimeError(f"Bad value for SageMaker version: {sagemaker.__version__}") + + try: + semantic_version.Version(parsed_version) + except ValueError: + raise RuntimeError(f"Bad value for SageMaker version: {sagemaker.__version__}") + + return parsed_version diff --git a/src/sagemaker/utilities/__init__.py b/src/sagemaker/utilities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/utilities/cache.py b/src/sagemaker/utilities/cache.py new file mode 100644 index 0000000000..9c5e997d00 --- /dev/null +++ b/src/sagemaker/utilities/cache.py @@ -0,0 +1,150 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import datetime +import collections +from typing import Any, TypeVar, Generic, Callable, Optional + +KeyType = TypeVar("KeyType") +ValType = TypeVar("ValType") + + +class LRUCache(Generic[KeyType, ValType]): + """Class that implements LRU cache with expiring items. + LRU caches remove items in a FIFO manner, such that the oldest items to be used are the first to be removed. + If you attempt to retrieve a cache item that is older than the expiration time, the item will be invalidated. + """ + + class Element: + """Class describes the values in the cache. + This object stores the value itself as well as a timestamp so that this element can be invalidated if + it becomes too old. + """ + + def __init__(self, value: ValType, creation_time: datetime.datetime): + """Initialize an ``Element`` instance for ``LRUCache``. + + Args: + value (ValType): Value that is stored in cache. + creation_time (datetime.datetime): Time at which cache item was created. + """ + self.value = value + self.creation_time = creation_time + + def __init__( + self, + max_cache_items: int, + expiration_time: datetime.timedelta, + retrieval_function: Callable[[KeyType, ValType], ValType], + ) -> None: + """Initialize an ``LRUCache`` instance. + + Args: + max_cache_items (int): Maximum number of items to store in cache. 
+ expiration_time (datetime.timedelta): Maximum time duration a cache element can persist + before being invalidated. + retrieval_function (Callable[[KeyType, ValType], ValType]): Function which maps cache keys + and current values to new values. This function must have kwarg arguments ``key`` and + and ``value``. This function is called as a fallback when the key is not found in the + cache, or a key has expired. + + """ + self._max_cache_items = max_cache_items + self._lru_cache: collections.OrderedDict = collections.OrderedDict() + self._expiration_time = expiration_time + self._retrieval_function = retrieval_function + + def __len__(self) -> int: + """Returns number of elements in cache.""" + return len(self._lru_cache) + + def __contains__(self, key: KeyType) -> bool: + """Returns True if key is found in cache, False otherwise. + + Args: + key (KeyType): Key in cache to retrieve. + """ + return key in self._lru_cache + + def clear(self) -> None: + """Deletes all elements from the cache.""" + self._lru_cache.clear() + + def get(self, key: KeyType, data_source_fallback: Optional[bool] = True) -> ValType: + """Returns value corresponding to key in cache. + + Args: + key (KeyType): Key in cache to retrieve. + data_source_fallback (Optional[bool]): True if data should be retrieved if it's stale or not in cache. + Default: True. + Raises: + KeyError: If key is not found in cache or is outdated and + ``data_source_fallback`` is False. + """ + if data_source_fallback: + if key in self._lru_cache: + return self._get_item(key, False) + else: + self.put(key) + return self._get_item(key, False) + return self._get_item(key, True) + + def put(self, key: KeyType, value: Optional[ValType] = None) -> None: + """Adds key to cache using retrieval_function. If value is provided, this is used instead. + If the key is already in cache, the old element is removed. + If the cache size exceeds the size limit, old elements are removed in order to meet the limit. + + Args: + key (KeyType): Key in cache to retrieve. + value (Optional[ValType]): Value to store for key. Default: None. + """ + curr_value = None + if key in self._lru_cache: + curr_value = self._lru_cache.pop(key) + + while len(self._lru_cache) >= self._max_cache_items: + self._lru_cache.popitem(last=False) + + if value is None: + value = self._retrieval_function( # type: ignore + key=key, value=curr_value.element if curr_value else None + ) + + self._lru_cache[key] = self.Element( + value=value, creation_time=datetime.datetime.now(tz=datetime.timezone.utc) + ) + + def _get_item(self, key: KeyType, fail_on_old_value: bool) -> ValType: + """Returns value from cache corresponding to key. If ``fail_on_old_value``, a + KeyError is thrown instead of a new value getting fetched. + + Args: + key (KeyType): Key in cache to retrieve. + fail_on_old_value (bool): True if a KeyError is thrown when the cache value is old. + + Raises: + KeyError: If key is not in cache or if key is old in cache + and fail_on_old_value is True. + """ + try: + element: Any = self._lru_cache.pop(key) + curr_time = datetime.datetime.now(tz=datetime.timezone.utc) + element_age = curr_time - element.creation_time + if element_age > self._expiration_time: + if fail_on_old_value: + raise KeyError(f"{key} is old! 
Created at {element.creation_time}") + element.value = self._retrieval_function(key=key, value=element.value) # type: ignore + element.creation_time = curr_time + self._lru_cache[key] = element + return element.value + except KeyError: + raise KeyError(f"{key} not found in LRUCache!") diff --git a/tests/unit/sagemaker/jumpstart/__init__.py b/tests/unit/sagemaker/jumpstart/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py new file mode 100644 index 0000000000..206af358a2 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -0,0 +1,540 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import copy +import datetime +import io +import json +from typing import Any, Dict, Tuple, Union +import botocore + +from mock.mock import MagicMock, Mock +import pytest +from mock import patch + +from sagemaker.jumpstart.cache import DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache +from sagemaker.jumpstart.types import ( + JumpStartCachedS3ContentKey, + JumpStartCachedS3ContentValue, + JumpStartModelHeader, + JumpStartModelSpecs, + JumpStartS3FileType, + JumpStartVersionedModelId, +) +from sagemaker.jumpstart.utils import get_formatted_manifest + +BASE_SPEC = { + "model_id": "pytorch-ic-mobilenet-v2", + "version": "1.0.0", + "min_sdk_version": "2.49.0", + "training_supported": True, + "incremental_training_supported": True, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.7.0", + "py_version": "py3", + }, + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.9.0", + "py_version": "py3", + }, + "hosting_artifact_uri": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + "training_artifact_uri": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_script_uri": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + "training_script_uri": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "hyperparameters": { + "adam-learning-rate": {"type": "float", "default": 0.05, "min": 1e-08, "max": 1}, + "epochs": {"type": "int", "default": 3, "min": 1, "max": 1000}, + "batch-size": {"type": "int", "default": 4, "min": 1, "max": 1024}, + }, +} + + +def get_spec_from_base_spec(model_id: str, version: str) -> JumpStartModelSpecs: + spec = copy.deepcopy(BASE_SPEC) + + spec["version"] = version + spec["model_id"] = model_id + return JumpStartModelSpecs(spec) + + +def patched_get_file_from_s3( + _modelCacheObj: JumpStartModelsCache, + key: JumpStartCachedS3ContentKey, + value: JumpStartCachedS3ContentValue, +) -> JumpStartCachedS3ContentValue: + + filetype, s3_key = key.file_type, key.s3_key + if filetype == JumpStartS3FileType.MANIFEST: + manifest = [ + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": 
"community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + }, + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + }, + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + }, + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + }, + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "3.0.0", + "min_version": "4.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v3.0.0.json", + }, + ] + return JumpStartCachedS3ContentValue( + formatted_file_content=get_formatted_manifest(manifest) + ) + + if filetype == JumpStartS3FileType.SPECS: + _, model_id, specs_version = s3_key.split("/") + version = specs_version.replace("specs_v", "").replace(".json", "") + return JumpStartCachedS3ContentValue( + formatted_file_content=get_spec_from_base_spec(model_id, version) + ) + + raise ValueError(f"Bad value for filetype: {filetype}") + + +@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) +@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") +def test_jumpstart_cache_get_header(): + + cache = JumpStartModelsCache(bucket="some_bucket") + + assert ( + JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ) + == cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4") + ) + + # See if we can make the same query 2 times consecutively + assert ( + JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ) + == cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4") + ) + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="*" + ) + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="2.*" + ) + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": 
"community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="2.*.*" + ) + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="2.0.0" + ) + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="1.0.0" + ) + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="1.*" + ) + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="1.*.*" + ) + + with pytest.raises(KeyError) as e: + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="3.*" + ) + assert ( + "Unable to find model manifest for tensorflow-ic-imagenet-inception-v3-classification-4 with version 3.* " + "compatible with your SageMaker version (2.68.3). Consider upgrading your SageMaker library to at least " + "version 4.49.0 so you can use version 3.0.0 of tensorflow-ic-imagenet-inception-v3-classification-4." 
+ in str(e.value) + ) + + with pytest.raises(KeyError) as e: + cache.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version="3.*" + ) + assert "Consider upgrading" not in str(e.value) + + with pytest.raises(ValueError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="BAD" + ) + + with pytest.raises(KeyError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4-bak", + ) + + +@patch("boto3.client") +def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client): + + mock_boto3_client.return_value.get_object.side_effect = Exception() + + cache = JumpStartModelsCache(bucket="some_bucket") + + with pytest.raises(Exception): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + ) + + mock_boto3_client.return_value.reset_mock() + + mock_boto3_client.return_value.head_object.side_effect = Exception() + + cache = JumpStartModelsCache(bucket="some_bucket") + + with pytest.raises(Exception): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + ) + + +def test_jumpstart_cache_gets_cleared_when_params_are_set(): + cache = JumpStartModelsCache(bucket="some_bucket") + cache.clear = MagicMock() + cache.set_bucket("some_bucket") + cache.clear.assert_called_once() + cache.clear.reset_mock() + cache.set_region("some_region") + cache.clear.assert_called_once() + cache.clear.reset_mock() + cache.set_manifest_file_s3_key("some_key") + cache.clear.assert_called_once() + + +def test_jumpstart_cache_accepts_input_parameters(): + + region = "us-east-1" + max_s3_cache_items = 1 + s3_cache_expiration_time = datetime.timedelta(weeks=2) + max_semantic_version_cache_items = 3 + semantic_version_cache_expiration_time = datetime.timedelta(microseconds=4) + bucket = "my-amazing-bucket" + manifest_file_key = "some_s3_key" + + cache = JumpStartModelsCache( + region=region, + max_s3_cache_items=max_s3_cache_items, + s3_cache_expiration_time=s3_cache_expiration_time, + max_semantic_version_cache_items=max_semantic_version_cache_items, + semantic_version_cache_expiration_time=semantic_version_cache_expiration_time, + bucket=bucket, + manifest_file_s3_key=manifest_file_key, + ) + + assert cache.get_region() == region + assert cache.get_bucket() == bucket + assert cache._s3_cache._max_cache_items == max_s3_cache_items + assert cache._s3_cache._expiration_time == s3_cache_expiration_time + assert ( + cache._model_id_semantic_version_manifest_key_cache._max_cache_items + == max_semantic_version_cache_items + ) + assert ( + cache._model_id_semantic_version_manifest_key_cache._expiration_time + == semantic_version_cache_expiration_time + ) + + +@patch("boto3.client") +def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): + + mock_json = json.dumps( + [ + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ] + ) + + bucket_name = "bucket_name" + now = datetime.datetime.now() + + with patch("datetime.datetime") as mock_datetime: + mock_datetime.now.return_value = now + + cache = JumpStartModelsCache( + bucket=bucket_name, s3_cache_expiration_time=datetime.timedelta(hours=1) + ) + + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) 
+ ) + } + mock_boto3_client.return_value.head_object.return_value = {"ETag": "hash1"} + + cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + + # first time accessing cache should involve get_object and head_object + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, Key=DEFAULT_MANIFEST_FILE_S3_KEY + ) + mock_boto3_client.return_value.head_object.assert_called_with( + Bucket=bucket_name, Key=DEFAULT_MANIFEST_FILE_S3_KEY + ) + + mock_boto3_client.return_value.get_object.reset_mock() + mock_boto3_client.return_value.head_object.reset_mock() + + # second time accessing cache should just involve head_object if hash hasn't changed + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) + ) + } + mock_boto3_client.return_value.head_object.return_value = {"ETag": "hash1"} + + # invalidate cache + mock_datetime.now.return_value += datetime.timedelta(hours=2) + + cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + + mock_boto3_client.return_value.head_object.assert_called_with( + Bucket=bucket_name, Key=DEFAULT_MANIFEST_FILE_S3_KEY + ) + mock_boto3_client.return_value.get_object.assert_not_called() + + mock_boto3_client.return_value.get_object.reset_mock() + mock_boto3_client.return_value.head_object.reset_mock() + + # third time accessing cache should involve head_object and get_object if hash has changed + mock_boto3_client.return_value.head_object.return_value = {"ETag": "hash2"} + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) + ) + } + + # invalidate cache + mock_datetime.now.return_value += datetime.timedelta(hours=2) + + cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, Key=DEFAULT_MANIFEST_FILE_S3_KEY + ) + mock_boto3_client.return_value.head_object.assert_called_with( + Bucket=bucket_name, Key=DEFAULT_MANIFEST_FILE_S3_KEY + ) + + +@patch("boto3.client") +def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): + + # test get_header + mock_json = json.dumps( + [ + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ] + ) + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) + ) + } + + mock_boto3_client.return_value.head_object.return_value = {"ETag": "some-hash"} + + bucket_name = "bucket_name" + cache = JumpStartModelsCache(bucket=bucket_name) + cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, Key=DEFAULT_MANIFEST_FILE_S3_KEY + ) + mock_boto3_client.return_value.head_object.assert_called_with( + Bucket=bucket_name, Key=DEFAULT_MANIFEST_FILE_S3_KEY + ) + + # test get_specs. manifest already in cache, so only s3 call will be to get specs. 
+ mock_json = json.dumps(BASE_SPEC) + + mock_boto3_client.return_value.reset_mock() + + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) + ) + } + cache.get_specs(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, + Key="community_models_specs/pytorch-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + ) + mock_boto3_client.return_value.head_object.assert_not_called() + + +@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) +def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): + cache = JumpStartModelsCache(bucket="some_bucket") + + cache.clear = MagicMock() + cache._model_id_semantic_version_manifest_key_cache = MagicMock() + cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [ + JumpStartVersionedModelId( + "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" + ), + JumpStartVersionedModelId("tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0"), + ] + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + ) + cache.clear.assert_called_once() + cache.clear.reset_mock() + + cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [ + JumpStartVersionedModelId( + "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" + ), + JumpStartVersionedModelId( + "tensorflow-ic-imagenet-inception-v3-classification-4", "987.0.0" + ), + ] + with pytest.raises(KeyError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + ) + cache.clear.assert_called_once() + + +@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) +@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") +def test_jumpstart_cache_get_specs(): + cache = JumpStartModelsCache(bucket="some_bucket") + + model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" + assert get_spec_from_base_spec(model_id, version) == cache.get_specs( + model_id=model_id, semantic_version=version + ) + + model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0" + assert get_spec_from_base_spec(model_id, version) == cache.get_specs( + model_id=model_id, semantic_version=version + ) + + model_id = "pytorch-ic-imagenet-inception-v3-classification-4" + assert get_spec_from_base_spec(model_id, "1.0.0") == cache.get_specs( + model_id=model_id, semantic_version="1.*" + ) + + with pytest.raises(KeyError): + cache.get_specs( + model_id=model_id + "bak", + ) + + with pytest.raises(KeyError): + cache.get_specs(model_id=model_id, semantic_version="9.*") + + with pytest.raises(ValueError): + cache.get_specs(model_id=model_id, semantic_version="BAD") diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py new file mode 100644 index 0000000000..217722669c --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -0,0 +1,122 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import copy +from sagemaker.jumpstart.types import JumpStartModelSpecs, JumpStartModelHeader + + +def test_jumpstart_model_header(): + + header_dict = { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + } + + header1 = JumpStartModelHeader(header_dict) + + assert header1.model_id == "tensorflow-ic-imagenet-inception-v3-classification-4" + assert header1.version == "1.0.0" + assert header1.min_version == "2.49.0" + assert ( + header1.spec_key + == "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json" + ) + + assert header1.to_json() == header_dict + + header2 = JumpStartModelHeader( + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + } + ) + + assert header1 != header2 + + header3 = copy.deepcopy(header1) + assert header1 == header3 + + +def test_jumpstart_model_specs(): + + specs_dict = { + "model_id": "pytorch-ic-mobilenet-v2", + "version": "1.0.0", + "min_sdk_version": "2.49.0", + "training_supported": True, + "incremental_training_supported": True, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.7.0", + "py_version": "py3", + }, + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.9.0", + "py_version": "py3", + }, + "hosting_artifact_uri": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + "training_artifact_uri": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_script_uri": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + "training_script_uri": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "hyperparameters": { + "adam-learning-rate": {"type": "float", "default": 0.05, "min": 1e-08, "max": 1}, + "epochs": {"type": "int", "default": 3, "min": 1, "max": 1000}, + "batch-size": {"type": "int", "default": 4, "min": 1, "max": 1024}, + }, + } + + specs1 = JumpStartModelSpecs(specs_dict) + + assert specs1.model_id == "pytorch-ic-mobilenet-v2" + assert specs1.version == "1.0.0" + assert specs1.min_sdk_version == "2.49.0" + assert specs1.training_supported == True + assert specs1.incremental_training_supported == True + assert specs1.hosting_ecr_specs == { + "framework": "pytorch", + "framework_version": "1.7.0", + "py_version": "py3", + } + assert specs1.training_ecr_specs == { + "framework": "pytorch", + "framework_version": "1.9.0", + "py_version": "py3", + } + assert specs1.hosting_artifact_uri == "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz" + assert specs1.training_artifact_uri == "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" + assert ( + specs1.hosting_script_uri + == "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz" + ) + 
assert ( + specs1.training_script_uri + == "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz" + ) + assert specs1.hyperparameters == { + "adam-learning-rate": {"type": "float", "default": 0.05, "min": 1e-08, "max": 1}, + "epochs": {"type": "int", "default": 3, "min": 1, "max": 1000}, + "batch-size": {"type": "int", "default": 4, "min": 1, "max": 1024}, + } + + assert specs1.to_json() == specs_dict + + specs_dict["model_id"] = "diff model id" + specs2 = JumpStartModelSpecs(specs_dict) + assert specs1 != specs2 + + specs3 = copy.deepcopy(specs1) + assert specs3 == specs1 diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py new file mode 100644 index 0000000000..cc8d02cb6f --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -0,0 +1,103 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from mock.mock import patch +import pytest +from sagemaker.jumpstart import utils +from sagemaker.jumpstart.constants import REGION_NAME_SET +from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId +from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo + + +def test_get_jumpstart_content_bucket(): + bad_region = "bad_region" + assert bad_region not in REGION_NAME_SET + with pytest.raises(RuntimeError): + utils.get_jumpstart_content_bucket(bad_region) + + +def test_get_jumpstart_launched_regions_string(): + + with patch("sagemaker.jumpstart.constants.REGION_NAME_SET", {}): + assert ( + utils.get_jumpstart_launched_regions_string() + == "JumpStart is not available in any region." + ) + + with patch("sagemaker.jumpstart.constants.REGION_NAME_SET", {"some_region"}): + assert ( + utils.get_jumpstart_launched_regions_string() + == "JumpStart is available in some_region region." + ) + + with patch("sagemaker.jumpstart.constants.REGION_NAME_SET", {"some_region1", "some_region2"}): + assert ( + utils.get_jumpstart_launched_regions_string() + == "JumpStart is available in some_region1 and some_region2 regions." + ) + + with patch("sagemaker.jumpstart.constants.REGION_NAME_SET", {"a", "b", "c"}): + assert ( + utils.get_jumpstart_launched_regions_string() + == "JumpStart is available in a, b, and c regions." 
+ ) + + +def test_get_formatted_manifest(): + mock_manifest = [ + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + }, + ] + + assert utils.get_formatted_manifest(mock_manifest) == { + JumpStartVersionedModelId( + "tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0" + ): JumpStartModelHeader(mock_manifest[0]) + } + + assert utils.get_formatted_manifest([]) == {} + + +def test_get_sagemaker_version(): + + with patch("sagemaker.__version__", "1.2.3"): + assert utils.get_sagemaker_version() == "1.2.3" + + with patch("sagemaker.__version__", "1.2.3.3332j"): + assert utils.get_sagemaker_version() == "1.2.3" + + with patch("sagemaker.__version__", "1.2.3."): + assert utils.get_sagemaker_version() == "1.2.3" + + with pytest.raises(RuntimeError): + with patch("sagemaker.__version__", "1.2.3dfsdfs"): + utils.get_sagemaker_version() + + with pytest.raises(RuntimeError): + with patch("sagemaker.__version__", "1.2"): + utils.get_sagemaker_version() + + with pytest.raises(RuntimeError): + with patch("sagemaker.__version__", "1"): + utils.get_sagemaker_version() + + with pytest.raises(RuntimeError): + with patch("sagemaker.__version__", ""): + utils.get_sagemaker_version() + + with pytest.raises(RuntimeError): + with patch("sagemaker.__version__", "1.2.3.4.5"): + utils.get_sagemaker_version() diff --git a/tests/unit/sagemaker/utilities/__init__.py b/tests/unit/sagemaker/utilities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/utilities/test_cache.py b/tests/unit/sagemaker/utilities/test_cache.py new file mode 100644 index 0000000000..a71c2e2abd --- /dev/null +++ b/tests/unit/sagemaker/utilities/test_cache.py @@ -0,0 +1,163 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
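As context for the unit tests that follow, here is an editorial sketch (not part of this patch) of the ``LRUCache`` API they exercise; the key names and the ``fetch`` fallback are made up for illustration.

    import datetime
    from typing import Optional
    from sagemaker.utilities.cache import LRUCache

    def fetch(key: str, value: Optional[str] = None) -> str:
        # Fallback invoked on a cache miss or when a cached entry has expired.
        return f"value-for-{key}"

    lru = LRUCache[str, str](
        max_cache_items=10,
        expiration_time=datetime.timedelta(hours=1),
        retrieval_function=fetch,
    )

    lru.put("model-a")         # stored via fetch(), since no explicit value is given
    print(lru.get("model-a"))  # "value-for-model-a", served from the cache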
+from typing import Optional, Union +from mock.mock import MagicMock, patch +import pytest + +from sagemaker.utilities import cache +import datetime + + +def retrieval_function(key: Optional[int] = None, value: Optional[str] = None) -> str: + return str(hash(str(key))) + + +def test_cache_retrieves_item(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=10, + expiration_time=datetime.timedelta(hours=1), + retrieval_function=retrieval_function, + ) + + my_cache.put(5) + assert my_cache.get(5, False) == retrieval_function(key=5) + + my_cache.put(6, 7) + assert my_cache.get(6, False) == 7 + assert len(my_cache) == 2 + + my_cache.put(5, 6) + assert my_cache.get(5, False) == 6 + assert len(my_cache) == 2 + + with pytest.raises(KeyError): + my_cache.get(21, False) + + +def test_cache_invalidates_old_item(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=10, + expiration_time=datetime.timedelta(milliseconds=1), + retrieval_function=retrieval_function, + ) + + curr_time = datetime.datetime.now() + with patch("datetime.datetime") as mock_datetime: + mock_datetime.now.return_value = curr_time + my_cache.put(5) + mock_datetime.now.return_value += datetime.timedelta(milliseconds=2) + with pytest.raises(KeyError): + my_cache.get(5, False) + + with patch("datetime.datetime") as mock_datetime: + mock_datetime.now.return_value = curr_time + my_cache.put(5) + mock_datetime.now.return_value += datetime.timedelta(milliseconds=0.5) + assert my_cache.get(5, False) == retrieval_function(key=5) + + +def test_cache_fetches_new_item(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=10, + expiration_time=datetime.timedelta(milliseconds=1), + retrieval_function=retrieval_function, + ) + + curr_time = datetime.datetime.now() + with patch("datetime.datetime") as mock_datetime: + mock_datetime.now.return_value = curr_time + my_cache.put(5, 10) + mock_datetime.now.return_value += datetime.timedelta(milliseconds=2) + assert my_cache.get(5) == retrieval_function(key=5) + + with patch("datetime.datetime") as mock_datetime: + mock_datetime.now.return_value = curr_time + my_cache.put(5, 10) + mock_datetime.now.return_value += datetime.timedelta(milliseconds=0.5) + assert my_cache.get(5, False) == 10 + mock_datetime.now.return_value += datetime.timedelta(milliseconds=0.75) + with pytest.raises(KeyError): + my_cache.get(5, False) + + +def test_cache_removes_old_items_once_size_limit_reached(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=5, + expiration_time=datetime.timedelta(hours=1), + retrieval_function=retrieval_function, + ) + + for i in [1, 2, 3, 4, 5]: + my_cache.put(i) + + assert len(my_cache) == 5 + + my_cache.put(6) + assert len(my_cache) == 5 + with pytest.raises(KeyError): + my_cache.get(1, False) + assert my_cache.get(2, False) == retrieval_function(key=2) + + +def test_cache_get_with_data_source_fallback(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=5, + expiration_time=datetime.timedelta(hours=1), + retrieval_function=retrieval_function, + ) + + for i in range(10): + val = my_cache.get(i) + assert val == retrieval_function(key=i) + + assert len(my_cache) == 5 + + +def test_cache_gets_stored_value(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=5, + expiration_time=datetime.timedelta(hours=1), + retrieval_function=retrieval_function, + ) + + for i in range(5): + my_cache.put(i) + + my_cache._retrieval_function = MagicMock() + my_cache.get(4) + 
my_cache._retrieval_function.assert_not_called() + + my_cache._retrieval_function.reset_mock() + my_cache.get(5) + my_cache._retrieval_function.assert_called() + + my_cache._retrieval_function.reset_mock() + my_cache.get(0) + my_cache._retrieval_function.assert_called() + + +def test_cache_clear_and_contains(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=5, + expiration_time=datetime.timedelta(hours=1), + retrieval_function=retrieval_function, + ) + + for i in range(5): + my_cache.put(i) + assert i in my_cache + + my_cache.clear() + assert len(my_cache) == 0 + with pytest.raises(KeyError): + my_cache.get(1, False) From d08fa11d12af6bfea59d962109d730ac44239233 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 11 Nov 2021 16:27:03 +0000 Subject: [PATCH 02/13] fix: formatting/style/lint errors --- src/sagemaker/jumpstart/cache.py | 106 ++++++++++--------- src/sagemaker/jumpstart/constants.py | 2 + src/sagemaker/jumpstart/types.py | 31 ++++-- src/sagemaker/jumpstart/utils.py | 19 ++-- src/sagemaker/utilities/cache.py | 53 ++++++---- tests/unit/sagemaker/jumpstart/test_cache.py | 102 +++++++++++------- tests/unit/sagemaker/jumpstart/test_types.py | 5 +- tests/unit/sagemaker/jumpstart/test_utils.py | 2 +- tests/unit/sagemaker/utilities/test_cache.py | 1 + 9 files changed, 193 insertions(+), 128 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index fabae85ea7..cd89773f43 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -10,23 +10,23 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""This module defines the JumpStartModelsCache class.""" +from __future__ import absolute_import import datetime from typing import List, Optional +import json +import boto3 +import semantic_version from sagemaker.jumpstart.types import ( JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue, JumpStartModelHeader, JumpStartModelSpecs, - JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, ) from sagemaker.jumpstart import utils from sagemaker.utilities.cache import LRUCache -import boto3 -import json -import semantic_version - DEFAULT_REGION_NAME = boto3.session.Session().region_name @@ -41,6 +41,7 @@ class JumpStartModelsCache: """Class that implements a cache for JumpStart models manifests and specs. + The manifest and specs associated with JumpStart models provide the information necessary for launching JumpStart models from the SageMaker SDK. """ @@ -62,15 +63,16 @@ def __init__( Args: region (Optional[str]): AWS region to associate with cache. Default: region associated with botocore session. - max_s3_cache_items (Optional[int]): Maximum number of files to store in s3 cache. Default: 20. - s3_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold items in s3 - cache before invalidation. Default: 6 hours. + max_s3_cache_items (Optional[int]): Maximum number of files to store in s3 cache. + Default: 20. + s3_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold items in + s3 cache before invalidation. Default: 6 hours. max_semantic_version_cache_items (Optional[int]): Maximum number of files to store in semantic version cache. Default: 20. 
- semantic_version_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold - items in semantic version cache before invalidation. Default: 6 hours. - bucket (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted content - bucket for region. + semantic_version_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to + hold items in semantic version cache before invalidation. Default: 6 hours. + bucket (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted + content bucket for region. """ self._region = region @@ -120,15 +122,16 @@ def get_bucket(self) -> None: return self._bucket def _get_manifest_key_from_model_id_semantic_version( - self, key: JumpStartVersionedModelId, value: Optional[JumpStartVersionedModelId] + self, + key: JumpStartVersionedModelId, + value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 ) -> JumpStartVersionedModelId: - """Return model id and version in manifest that matches semantic version/id - from customer request. + """Return model id and version in manifest that matches semantic version/id. Args: key (JumpStartVersionedModelId): Key for which to fetch versioned model id. - value (Optional[JumpStartVersionedModelId]): Unused variable for current value of old cached - model id/version. + value (Optional[JumpStartVersionedModelId]): Unused variable for current value of + old cached model id/version. Raises: KeyError: If the semantic version is not found in the manifest. @@ -158,35 +161,34 @@ def _get_manifest_key_from_model_id_semantic_version( sm_compatible_model_version = spec.select(versions_compatible_with_sagemaker) if sm_compatible_model_version is not None: return JumpStartVersionedModelId(model_id, str(sm_compatible_model_version)) - else: - versions_incompatible_with_sagemaker = [ - semantic_version.Version(header.version) + + versions_incompatible_with_sagemaker = [ + semantic_version.Version(header.version) + for _, header in manifest.items() + if header.model_id == model_id + ] + sm_incompatible_model_version = spec.select(versions_incompatible_with_sagemaker) + if sm_incompatible_model_version is not None: + model_version_to_use_incompatible_with_sagemaker = str(sm_incompatible_model_version) + sm_version_to_use = [ + header.min_version for _, header in manifest.items() if header.model_id == model_id + and header.version == model_version_to_use_incompatible_with_sagemaker ] - sm_incompatible_model_version = spec.select(versions_incompatible_with_sagemaker) - if sm_incompatible_model_version is not None: - model_version_to_use_incompatible_with_sagemaker = str( - sm_incompatible_model_version - ) - sm_version_to_use = [ - header.min_version - for _, header in manifest.items() - if header.model_id == model_id - and header.version == model_version_to_use_incompatible_with_sagemaker - ] - assert len(sm_version_to_use) == 1 - sm_version_to_use = sm_version_to_use[0] - - error_msg = ( - f"Unable to find model manifest for {model_id} with version {version} compatible with your SageMaker version ({sm_version}). " - f"Consider upgrading your SageMaker library to at least version {sm_version_to_use} so you can use version " - f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}." 
- ) - raise KeyError(error_msg) - else: - error_msg = f"Unable to find model manifest for {model_id} with version {version}" - raise KeyError(error_msg) + assert len(sm_version_to_use) == 1 + sm_version_to_use = sm_version_to_use[0] + + error_msg = ( + f"Unable to find model manifest for {model_id} with version {version} " + f"compatible with your SageMaker version ({sm_version}). " + f"Consider upgrading your SageMaker library to at least version " + f"{sm_version_to_use} so you can use version " + f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}." + ) + raise KeyError(error_msg) + error_msg = f"Unable to find model manifest for {model_id} with version {version}" + raise KeyError(error_msg) def _get_file_from_s3( self, @@ -194,6 +196,7 @@ def _get_file_from_s3( value: Optional[JumpStartCachedS3ContentValue], ) -> JumpStartCachedS3ContentValue: """Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``. + If a manifest file is being fetched, we only download the object if the md5 hash in ``head_object`` does not match the current md5 hash for the stored value. This prevents unnecessarily downloading the full manifest when it hasn't changed. @@ -228,18 +231,18 @@ def _get_file_from_s3( raise RuntimeError(f"Bad value for key: {key}") def get_header( - self, model_id: str, semantic_version: Optional[str] = None + self, model_id: str, semantic_version_str: Optional[str] = None ) -> List[JumpStartModelHeader]: """Return list of headers for a given JumpStart model id and semantic version. Args: model_id (str): model id for which to get a header. - semantic_version (Optional[str]): The semantic version for which to get a header. - If None, the highest compatible version is returned. + semantic_version_str (Optional[str]): The semantic version for which to get a + header. If None, the highest compatible version is returned. """ versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get( - JumpStartVersionedModelId(model_id, semantic_version) + JumpStartVersionedModelId(model_id, semantic_version_str) ) manifest = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) @@ -258,16 +261,17 @@ def get_header( return self.get_header(model_id, semantic_version) def get_specs( - self, model_id: str, semantic_version: Optional[str] = None + self, model_id: str, semantic_version_str: Optional[str] = None ) -> JumpStartModelSpecs: """Return specs for a given JumpStart model id and semantic version. Args: model_id (str): model id for which to get specs. - semantic_version (Optional[str]): The semantic version for which to get specs. - If None, the highest compatible version is returned. + semantic_version_str (Optional[str]): The semantic version for which to get + specs. If None, the highest compatible version is returned. """ - header = self.get_header(model_id, semantic_version) + + header = self.get_header(model_id, semantic_version_str) spec_key = header.spec_key return self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key) diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 339888e7c0..65700ecec8 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -10,6 +10,8 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
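# --- Illustrative sketch, not part of the patch ---
# The version resolution above leans on semantic_version's spec selection: ``select()``
# returns the highest Version in a list that satisfies the spec, or None when nothing
# matches. The cache presumably builds its ``spec`` from the requested version string
# (that construction is not shown in this hunk); SimpleSpec and the literal versions
# below are illustrative only.
import semantic_version

spec = semantic_version.SimpleSpec("2.*")
candidates = [
    semantic_version.Version("1.0.0"),
    semantic_version.Version("2.0.0"),
    semantic_version.Version("2.1.0"),
]
assert spec.select(candidates) == semantic_version.Version("2.1.0")  # highest 2.x wins
assert semantic_version.SimpleSpec("3.*").select(candidates) is None  # no match -> None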
+"""This module stores constants related to SageMaker JumpStart.""" +from __future__ import absolute_import from typing import Set from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index c9ab5f50de..a6f19e29e7 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -10,19 +10,25 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""This module stores types related to SageMaker JumpStart.""" +from __future__ import absolute_import from enum import Enum from typing import Any, Dict, List, Optional, Union class JumpStartDataHolderType: - """Base class for many JumpStart types. Allows objects to be added to dicts and sets, + """Base class for many JumpStart types. + + Allows objects to be added to dicts and sets, and improves string representation. This class allows different objects with the same attributes and types to have equality. """ + __slots__: List[str] = [] + def __eq__(self, other: Any) -> bool: - """Returns True if other object is of the same type - and has all attributes equal.""" + """Returns True if ``other`` is of the same type and has all attributes equal.""" + if not isinstance(other, type(self)): return False for attribute in self.__slots__: @@ -31,23 +37,30 @@ def __eq__(self, other: Any) -> bool: return True def __hash__(self) -> int: - """Makes hash of object by first mapping to unique tuple, which then - gets hashed. + """Makes hash of object. + + Maps object to unique tuple, which then gets hashed. """ + return hash((type(self),) + tuple([getattr(self, att) for att in self.__slots__])) def __str__(self) -> str: """Returns string representation of object. Example: - "JumpStartLaunchedRegionInfo: {'content_bucket': 'jumpstart-bucket-us-west-2', 'region_name': 'us-west-2'}" + + "JumpStartLaunchedRegionInfo: + {'content_bucket': 'bucket', 'region_name': 'us-west-2'}" """ + att_dict = {att: getattr(self, att) for att in self.__slots__} return f"{type(self).__name__}: {str(att_dict)}" def __repr__(self) -> str: - """This is often called instead of __str__ and is the official string representation - of an object, typicaly used for debugging. Example: - "JumpStartLaunchedRegionInfo at 0x7f664529efa0: {'content_bucket': 'jumpstart-bucket-us-west-2', 'region_name': 'us-west-2'}" + """Returns ``__repr__`` string of object. Example: + + "JumpStartLaunchedRegionInfo at 0x7f664529efa0: + {'content_bucket': 'bucket', 'region_name': 'us-west-2'}" """ + att_dict = {att: getattr(self, att) for att in self.__slots__} return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}" diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index b279ad03e6..560a22138e 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -10,6 +10,8 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
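# --- Illustrative sketch, not part of the patch ---
# What the __eq__/__hash__ overrides on JumpStartDataHolderType (above) buy the caches:
# two instances with the same attribute values compare equal and hash identically, so they
# collapse to a single dict/set entry and can serve as cache keys. The model id and version
# strings below are examples only.
from sagemaker.jumpstart.types import JumpStartVersionedModelId

a = JumpStartVersionedModelId("pytorch-ic-mobilenet-v2", "1.0.0")
b = JumpStartVersionedModelId("pytorch-ic-mobilenet-v2", "1.0.0")
assert a == b            # attribute-wise equality from the base class
assert len({a, b}) == 1  # equal hash, so the set deduplicates them
assert str(a).startswith("JumpStartVersionedModelId: ")  # readable __str__ for debugging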
+"""This module contains utilities related to SageMaker JumpStart.""" +from __future__ import absolute_import from typing import Dict, List import semantic_version @@ -31,11 +33,11 @@ def get_jumpstart_launched_regions_string() -> str: return f"JumpStart is available in {sorted_regions[0]} and {sorted_regions[1]} regions." formatted_launched_regions_list = [] - for i in range(len(sorted_regions)): + for i, region in enumerate(sorted_regions): region_prefix = "" if i == len(sorted_regions) - 1: region_prefix = "and " - formatted_launched_regions_list.append(region_prefix + sorted_regions[i]) + formatted_launched_regions_list.append(region_prefix + region) formatted_launched_regions_str = ", ".join(formatted_launched_regions_list) return f"JumpStart is available in {formatted_launched_regions_str} regions." @@ -59,8 +61,11 @@ def get_jumpstart_content_bucket(region: str) -> str: def get_formatted_manifest( manifest: List[Dict], ) -> Dict[JumpStartVersionedModelId, JumpStartModelHeader]: - """Returns formatted manifest dictionary from raw manifest. Keys are JumpStartVersionedModelId objects, - values are JumpStartModelHeader objects.""" + """Returns formatted manifest dictionary from raw manifest. + + Keys are JumpStartVersionedModelId objects, values are + ``JumpStartModelHeader`` objects. + """ manifest_dict = {} for header in manifest: header_obj = JumpStartModelHeader(header) @@ -71,8 +76,10 @@ def get_formatted_manifest( def get_sagemaker_version() -> str: - """Returns sagemaker library version by reading __version__ variable - in module. In order to maintain compatibility with the ``semantic_version`` + """Returns sagemaker library version. + + Function reads ``__version__`` variable in ``sagemaker`` module. + In order to maintain compatibility with the ``semantic_version`` library, versions with fewer than 2, or more than 3, periods are rejected. All versions that cannot be parsed with ``semantic_version`` are also rejected. diff --git a/src/sagemaker/utilities/cache.py b/src/sagemaker/utilities/cache.py index 9c5e997d00..0c1cbe0edb 100644 --- a/src/sagemaker/utilities/cache.py +++ b/src/sagemaker/utilities/cache.py @@ -10,6 +10,9 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""This module defines a LRU cache class.""" +from __future__ import absolute_import + import datetime import collections from typing import Any, TypeVar, Generic, Callable, Optional @@ -20,14 +23,18 @@ class LRUCache(Generic[KeyType, ValType]): """Class that implements LRU cache with expiring items. - LRU caches remove items in a FIFO manner, such that the oldest items to be used are the first to be removed. - If you attempt to retrieve a cache item that is older than the expiration time, the item will be invalidated. + + LRU caches remove items in a FIFO manner, such that the oldest + items to be used are the first to be removed. + If you attempt to retrieve a cache item that is older than the + expiration time, the item will be invalidated. """ class Element: """Class describes the values in the cache. - This object stores the value itself as well as a timestamp so that this element can be invalidated if - it becomes too old. + + This object stores the value itself as well as a timestamp so that this + element can be invalidated if it becomes too old. 
""" def __init__(self, value: ValType, creation_time: datetime.datetime): @@ -52,10 +59,10 @@ def __init__( max_cache_items (int): Maximum number of items to store in cache. expiration_time (datetime.timedelta): Maximum time duration a cache element can persist before being invalidated. - retrieval_function (Callable[[KeyType, ValType], ValType]): Function which maps cache keys - and current values to new values. This function must have kwarg arguments ``key`` and - and ``value``. This function is called as a fallback when the key is not found in the - cache, or a key has expired. + retrieval_function (Callable[[KeyType, ValType], ValType]): Function which maps cache + keys and current values to new values. This function must have kwarg arguments + ``key`` and and ``value``. This function is called as a fallback when the key + is not found in the cache, or a key has expired. """ self._max_cache_items = max_cache_items @@ -84,8 +91,8 @@ def get(self, key: KeyType, data_source_fallback: Optional[bool] = True) -> ValT Args: key (KeyType): Key in cache to retrieve. - data_source_fallback (Optional[bool]): True if data should be retrieved if it's stale or not in cache. - Default: True. + data_source_fallback (Optional[bool]): True if data should be retrieved if + it's stale or not in cache. Default: True. Raises: KeyError: If key is not found in cache or is outdated and ``data_source_fallback`` is False. @@ -93,15 +100,16 @@ def get(self, key: KeyType, data_source_fallback: Optional[bool] = True) -> ValT if data_source_fallback: if key in self._lru_cache: return self._get_item(key, False) - else: - self.put(key) - return self._get_item(key, False) + self.put(key) + return self._get_item(key, False) return self._get_item(key, True) def put(self, key: KeyType, value: Optional[ValType] = None) -> None: - """Adds key to cache using retrieval_function. If value is provided, this is used instead. - If the key is already in cache, the old element is removed. - If the cache size exceeds the size limit, old elements are removed in order to meet the limit. + """Adds key to cache using ``retrieval_function``. + + If value is provided, this is used instead. If the key is already in cache, + the old element is removed. If the cache size exceeds the size limit, old + elements are removed in order to meet the limit. Args: key (KeyType): Key in cache to retrieve. @@ -124,12 +132,15 @@ def put(self, key: KeyType, value: Optional[ValType] = None) -> None: ) def _get_item(self, key: KeyType, fail_on_old_value: bool) -> ValType: - """Returns value from cache corresponding to key. If ``fail_on_old_value``, a - KeyError is thrown instead of a new value getting fetched. + """Returns value from cache corresponding to key. + + If ``fail_on_old_value``, a KeyError is thrown instead of a new value + getting fetched. Args: key (KeyType): Key in cache to retrieve. - fail_on_old_value (bool): True if a KeyError is thrown when the cache value is old. + fail_on_old_value (bool): True if a KeyError is thrown when the cache value + is old. Raises: KeyError: If key is not in cache or if key is old in cache @@ -142,7 +153,9 @@ def _get_item(self, key: KeyType, fail_on_old_value: bool) -> ValType: if element_age > self._expiration_time: if fail_on_old_value: raise KeyError(f"{key} is old! 
Created at {element.creation_time}") - element.value = self._retrieval_function(key=key, value=element.value) # type: ignore + element.value = self._retrieval_function( # type: ignore + key=key, value=element.value + ) element.creation_time = curr_time self._lru_cache[key] = element return element.value diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 206af358a2..71ff109399 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -10,14 +10,14 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +from __future__ import absolute_import import copy import datetime import io import json -from typing import Any, Dict, Tuple, Union import botocore -from mock.mock import MagicMock, Mock +from mock.mock import MagicMock import pytest from mock import patch @@ -81,31 +81,36 @@ def patched_get_file_from_s3( "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "1.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v1.0.0.json", }, { "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "2.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v2.0.0.json", }, { "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", "version": "1.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/pytorch-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + "spec_key": "community_models_specs/pytorch-ic-" + "imagenet-inception-v3-classification-4/specs_v1.0.0.json", }, { "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", "version": "2.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/pytorch-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + "spec_key": "community_models_specs/pytorch-ic-imagenet-" + "inception-v3-classification-4/specs_v2.0.0.json", }, { "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "3.0.0", "min_version": "4.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v3.0.0.json", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v3.0.0.json", }, ] return JumpStartCachedS3ContentValue( @@ -134,7 +139,8 @@ def test_jumpstart_cache_get_header(): "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "2.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + "spec_key": "community_models_specs/tensorflow-ic" + "-imagenet-inception-v3-classification-4/specs_v2.0.0.json", } ) == cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4") @@ -147,7 +153,8 @@ def test_jumpstart_cache_get_header(): "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "2.0.0", "min_version": "2.49.0", - "spec_key": 
"community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + "spec_key": "community_models_specs/tensorflow-ic" + "-imagenet-inception-v3-classification-4/specs_v2.0.0.json", } ) == cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4") @@ -158,10 +165,11 @@ def test_jumpstart_cache_get_header(): "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "2.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + "spec_key": "community_models_specs/tensorflow-ic" + "-imagenet-inception-v3-classification-4/specs_v2.0.0.json", } ) == cache.get_header( - model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="*" + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" ) assert JumpStartModelHeader( @@ -169,10 +177,11 @@ def test_jumpstart_cache_get_header(): "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "2.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v2.0.0.json", } ) == cache.get_header( - model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="2.*" + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="2.*" ) assert JumpStartModelHeader( @@ -180,10 +189,12 @@ def test_jumpstart_cache_get_header(): "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "2.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v2.0.0.json", } ) == cache.get_header( - model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="2.*.*" + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="2.*.*", ) assert JumpStartModelHeader( @@ -191,10 +202,12 @@ def test_jumpstart_cache_get_header(): "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "2.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v2.0.0.json", } ) == cache.get_header( - model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="2.0.0" + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="2.0.0", ) assert JumpStartModelHeader( @@ -202,10 +215,12 @@ def test_jumpstart_cache_get_header(): "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "1.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v1.0.0.json", } ) == cache.get_header( - model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="1.0.0" + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="1.0.0", ) assert JumpStartModelHeader( @@ -213,10 +228,11 @@ def 
test_jumpstart_cache_get_header(): "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "1.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v1.0.0.json", } ) == cache.get_header( - model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="1.*" + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="1.*" ) assert JumpStartModelHeader( @@ -224,32 +240,36 @@ def test_jumpstart_cache_get_header(): "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "1.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v1.0.0.json", } ) == cache.get_header( - model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="1.*.*" + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="1.*.*", ) with pytest.raises(KeyError) as e: cache.get_header( - model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="3.*" + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="3.*", ) assert ( - "Unable to find model manifest for tensorflow-ic-imagenet-inception-v3-classification-4 with version 3.* " - "compatible with your SageMaker version (2.68.3). Consider upgrading your SageMaker library to at least " - "version 4.49.0 so you can use version 3.0.0 of tensorflow-ic-imagenet-inception-v3-classification-4." - in str(e.value) + "Unable to find model manifest for tensorflow-ic-imagenet-inception-v3-classification-4 " + "with version 3.* compatible with your SageMaker version (2.68.3). Consider upgrading " + "your SageMaker library to at least version 4.49.0 so you can use version 3.0.0 of " + "tensorflow-ic-imagenet-inception-v3-classification-4." 
in str(e.value) ) with pytest.raises(KeyError) as e: cache.get_header( - model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version="3.*" + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="3.*" ) assert "Consider upgrading" not in str(e.value) with pytest.raises(ValueError): cache.get_header( - model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version="BAD" + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="BAD", ) with pytest.raises(KeyError): @@ -338,7 +358,8 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", "version": "2.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/pytorch-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + "spec_key": "community_models_specs/pytorch-ic-" + "imagenet-inception-v3-classification-4/specs_v2.0.0.json", } ] ) @@ -425,7 +446,8 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", "version": "2.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/pytorch-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + "spec_key": "community_models_specs/pytorch-ic-" + "imagenet-inception-v3-classification-4/specs_v2.0.0.json", } ] ) @@ -462,7 +484,8 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): mock_boto3_client.return_value.get_object.assert_called_with( Bucket=bucket_name, - Key="community_models_specs/pytorch-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + Key="community_models_specs/pytorch-ic-imagenet-" + "inception-v3-classification-4/specs_v2.0.0.json", ) mock_boto3_client.return_value.head_object.assert_not_called() @@ -485,7 +508,8 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", "version": "1.0.0", "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v1.0.0.json", } ) == cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", @@ -515,17 +539,17 @@ def test_jumpstart_cache_get_specs(): model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" assert get_spec_from_base_spec(model_id, version) == cache.get_specs( - model_id=model_id, semantic_version=version + model_id=model_id, semantic_version_str=version ) model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0" assert get_spec_from_base_spec(model_id, version) == cache.get_specs( - model_id=model_id, semantic_version=version + model_id=model_id, semantic_version_str=version ) model_id = "pytorch-ic-imagenet-inception-v3-classification-4" assert get_spec_from_base_spec(model_id, "1.0.0") == cache.get_specs( - model_id=model_id, semantic_version="1.*" + model_id=model_id, semantic_version_str="1.*" ) with pytest.raises(KeyError): @@ -534,7 +558,7 @@ def test_jumpstart_cache_get_specs(): ) with pytest.raises(KeyError): - cache.get_specs(model_id=model_id, semantic_version="9.*") + cache.get_specs(model_id=model_id, semantic_version_str="9.*") with pytest.raises(ValueError): - cache.get_specs(model_id=model_id, semantic_version="BAD") + cache.get_specs(model_id=model_id, 
semantic_version_str="BAD") diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 217722669c..b04c0803ac 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -10,6 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +from __future__ import absolute_import import copy from sagemaker.jumpstart.types import JumpStartModelSpecs, JumpStartModelHeader @@ -84,8 +85,8 @@ def test_jumpstart_model_specs(): assert specs1.model_id == "pytorch-ic-mobilenet-v2" assert specs1.version == "1.0.0" assert specs1.min_sdk_version == "2.49.0" - assert specs1.training_supported == True - assert specs1.incremental_training_supported == True + assert specs1.training_supported + assert specs1.incremental_training_supported assert specs1.hosting_ecr_specs == { "framework": "pytorch", "framework_version": "1.7.0", diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index cc8d02cb6f..944904d908 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -10,12 +10,12 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +from __future__ import absolute_import from mock.mock import patch import pytest from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import REGION_NAME_SET from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId -from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo def test_get_jumpstart_content_bucket(): diff --git a/tests/unit/sagemaker/utilities/test_cache.py b/tests/unit/sagemaker/utilities/test_cache.py index a71c2e2abd..0f5e053775 100644 --- a/tests/unit/sagemaker/utilities/test_cache.py +++ b/tests/unit/sagemaker/utilities/test_cache.py @@ -10,6 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
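# --- Illustrative sketch, not part of the patch ---
# Rough shape of the lookup the tests above exercise end to end: get_header() resolves the
# requested semantic version against the manifest, and get_specs() then fetches and parses
# the matching spec file from S3. The bucket name and model id are placeholders; running
# this for real needs AWS credentials and a bucket laid out like the JumpStart content
# bucket.
from sagemaker.jumpstart.cache import JumpStartModelsCache

js_cache = JumpStartModelsCache(region="us-west-2", bucket="example-jumpstart-bucket")
header = js_cache.get_header(
    model_id="pytorch-ic-imagenet-inception-v3-classification-4",
    semantic_version_str="1.*",
)
specs = js_cache.get_specs(
    model_id="pytorch-ic-imagenet-inception-v3-classification-4",
    semantic_version_str="1.*",
)
print(header.spec_key)           # s3 key of the spec file for the resolved version
print(specs.training_supported)  # field parsed from the spec json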
+from __future__ import absolute_import from typing import Optional, Union from mock.mock import MagicMock, patch import pytest From 37a502eff17549245f5faf4ea2352f1bbc737594 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 11 Nov 2021 18:40:45 +0000 Subject: [PATCH 03/13] change: improve unit tests for jumpstart cache --- tests/unit/sagemaker/jumpstart/test_cache.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 71ff109399..52383d39f0 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -335,6 +335,7 @@ def test_jumpstart_cache_accepts_input_parameters(): manifest_file_s3_key=manifest_file_key, ) + assert cache.get_manifest_file_s3_key() == manifest_file_key assert cache.get_region() == region assert cache.get_bucket() == bucket assert cache._s3_cache._max_cache_items == max_s3_cache_items From d7ff8f8b2115f3d9c7198a052580bd17a0843bb7 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 12 Nov 2021 20:21:23 +0000 Subject: [PATCH 04/13] fix: update jumpstart cache and unit tests --- src/sagemaker/jumpstart/cache.py | 163 +++++++++++-------- src/sagemaker/jumpstart/constants.py | 13 +- src/sagemaker/jumpstart/parameters.py | 20 +++ src/sagemaker/jumpstart/types.py | 52 +++++- src/sagemaker/jumpstart/utils.py | 30 ++-- src/sagemaker/utilities/cache.py | 19 ++- tests/unit/sagemaker/jumpstart/test_cache.py | 104 ++++++++---- tests/unit/sagemaker/jumpstart/test_types.py | 26 +-- tests/unit/sagemaker/jumpstart/test_utils.py | 28 ++-- tests/unit/sagemaker/utilities/test_cache.py | 52 ++++-- 10 files changed, 342 insertions(+), 165 deletions(-) create mode 100644 src/sagemaker/jumpstart/parameters.py diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index cd89773f43..2cd281efe4 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -13,10 +13,21 @@ """This module defines the JumpStartModelsCache class.""" from __future__ import absolute_import import datetime -from typing import List, Optional +from typing import Optional import json import boto3 +import botocore import semantic_version +from sagemaker.jumpstart.constants import ( + JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + JUMPSTART_DEFAULT_REGION_NAME, +) +from sagemaker.jumpstart.parameters import ( + JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, + JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, + JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, + JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, +) from sagemaker.jumpstart.types import ( JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue, @@ -28,16 +39,6 @@ from sagemaker.jumpstart import utils from sagemaker.utilities.cache import LRUCache -DEFAULT_REGION_NAME = boto3.session.Session().region_name - -DEFAULT_MAX_S3_CACHE_ITEMS = 20 -DEFAULT_S3_CACHE_EXPIRATION_TIME = datetime.timedelta(hours=6) - -DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS = 20 -DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_TIME = datetime.timedelta(hours=6) - -DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" - class JumpStartModelsCache: """Class that implements a cache for JumpStart models manifests and specs. 
@@ -48,56 +49,71 @@ class JumpStartModelsCache: def __init__( self, - region: Optional[str] = DEFAULT_REGION_NAME, - max_s3_cache_items: Optional[int] = DEFAULT_MAX_S3_CACHE_ITEMS, - s3_cache_expiration_time: Optional[datetime.timedelta] = DEFAULT_S3_CACHE_EXPIRATION_TIME, - max_semantic_version_cache_items: Optional[int] = DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, - semantic_version_cache_expiration_time: Optional[ + region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME, + max_s3_cache_items: Optional[int] = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, + s3_cache_expiration_horizon: Optional[ datetime.timedelta - ] = DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_TIME, - manifest_file_s3_key: Optional[str] = DEFAULT_MANIFEST_FILE_S3_KEY, - bucket: Optional[str] = None, + ] = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, + max_semantic_version_cache_items: Optional[ + int + ] = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, + semantic_version_cache_expiration_horizon: Optional[ + datetime.timedelta + ] = JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, + manifest_file_s3_key: Optional[str] = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + s3_bucket_name: Optional[str] = None, + s3_client_config: Optional[botocore.config.Config] = None, ) -> None: """Initialize a ``JumpStartModelsCache`` instance. Args: region (Optional[str]): AWS region to associate with cache. Default: region associated - with botocore session. - max_s3_cache_items (Optional[int]): Maximum number of files to store in s3 cache. + with boto3 session. + max_s3_cache_items (Optional[int]): Maximum number of items to store in s3 cache. Default: 20. - s3_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to hold items in - s3 cache before invalidation. Default: 6 hours. - max_semantic_version_cache_items (Optional[int]): Maximum number of files to store in + s3_cache_expiration_horizon (Optional[datetime.timedelta]): Maximum time to hold + items in s3 cache before invalidation. Default: 6 hours. + max_semantic_version_cache_items (Optional[int]): Maximum number of items to store in semantic version cache. Default: 20. - semantic_version_cache_expiration_time (Optional[datetime.timedelta]): Maximum time to - hold items in semantic version cache before invalidation. Default: 6 hours. - bucket (Optional[str]): S3 bucket to associate with cache. Default: JumpStart-hosted - content bucket for region. + semantic_version_cache_expiration_horizon (Optional[datetime.timedelta]): + Maximum time to hold items in semantic version cache before invalidation. + Default: 6 hours. + s3_bucket_name (Optional[str]): S3 bucket to associate with cache. + Default: JumpStart-hosted content bucket for region. + s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache. + Default: None (no config). 
""" self._region = region self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( max_cache_items=max_s3_cache_items, - expiration_time=s3_cache_expiration_time, + expiration_horizon=s3_cache_expiration_horizon, retrieval_function=self._get_file_from_s3, ) self._model_id_semantic_version_manifest_key_cache = LRUCache[ JumpStartVersionedModelId, JumpStartVersionedModelId ]( max_cache_items=max_semantic_version_cache_items, - expiration_time=semantic_version_cache_expiration_time, + expiration_horizon=semantic_version_cache_expiration_horizon, retrieval_function=self._get_manifest_key_from_model_id_semantic_version, ) self._manifest_file_s3_key = manifest_file_s3_key - self._bucket = ( - utils.get_jumpstart_content_bucket(self._region) if bucket is None else bucket + self.s3_bucket_name = ( + utils.get_jumpstart_content_bucket(self._region) + if s3_bucket_name is None + else s3_bucket_name + ) + self._s3_client = ( + boto3.client("s3", region_name=self._region, config=s3_client_config) + if s3_client_config + else boto3.client("s3", region_name=self._region) ) - self._has_retried_cache_refresh = False def set_region(self, region: str) -> None: """Set region for cache. Clears cache after new region is set.""" - self._region = region - self.clear() + if region != self._region: + self._region = region + self.clear() def get_region(self) -> str: """Return region for cache.""" @@ -105,21 +121,23 @@ def get_region(self) -> str: def set_manifest_file_s3_key(self, key: str) -> None: """Set manifest file s3 key. Clears cache after new key is set.""" - self._manifest_file_s3_key = key - self.clear() + if key != self._manifest_file_s3_key: + self._manifest_file_s3_key = key + self.clear() def get_manifest_file_s3_key(self) -> None: """Return manifest file s3 key for cache.""" return self._manifest_file_s3_key - def set_bucket(self, bucket: str) -> None: + def set_s3_bucket_name(self, s3_bucket_name: str) -> None: """Set s3 bucket used for cache.""" - self._bucket = bucket - self.clear() + if s3_bucket_name != self.s3_bucket_name: + self.s3_bucket_name = s3_bucket_name + self.clear() def get_bucket(self) -> None: """Return bucket used for cache.""" - return self._bucket + return self.s3_bucket_name def _get_manifest_key_from_model_id_semantic_version( self, @@ -128,13 +146,18 @@ def _get_manifest_key_from_model_id_semantic_version( ) -> JumpStartVersionedModelId: """Return model id and version in manifest that matches semantic version/id. + Uses ``semantic_version`` to perform version comparison. The highest model version + matching the semantic version is used, which is compatible with the SageMaker + version. + Args: key (JumpStartVersionedModelId): Key for which to fetch versioned model id. value (Optional[JumpStartVersionedModelId]): Unused variable for current value of old cached model id/version. Raises: - KeyError: If the semantic version is not found in the manifest. + KeyError: If the semantic version is not found in the manifest, or is found but + the SageMaker version needs to be upgraded in order for the model to be used. 
""" model_id, version = key.model_id, key.version @@ -147,7 +170,7 @@ def _get_manifest_key_from_model_id_semantic_version( versions_compatible_with_sagemaker = [ semantic_version.Version(header.version) - for _, header in manifest.items() + for header in manifest.values() if header.model_id == model_id and semantic_version.Version(header.min_version) <= semantic_version.Version(sm_version) ] @@ -164,7 +187,7 @@ def _get_manifest_key_from_model_id_semantic_version( versions_incompatible_with_sagemaker = [ semantic_version.Version(header.version) - for _, header in manifest.items() + for header in manifest.values() if header.model_id == model_id ] sm_incompatible_model_version = spec.select(versions_incompatible_with_sagemaker) @@ -172,11 +195,11 @@ def _get_manifest_key_from_model_id_semantic_version( model_version_to_use_incompatible_with_sagemaker = str(sm_incompatible_model_version) sm_version_to_use = [ header.min_version - for _, header in manifest.items() + for header in manifest.values() if header.model_id == model_id and header.version == model_version_to_use_incompatible_with_sagemaker ] - assert len(sm_version_to_use) == 1 + assert len(sm_version_to_use) == 1 # ``manifest`` dict should already enforce this sm_version_to_use = sm_version_to_use[0] error_msg = ( @@ -187,7 +210,7 @@ def _get_manifest_key_from_model_id_semantic_version( f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}." ) raise KeyError(error_msg) - error_msg = f"Unable to find model manifest for {model_id} with version {version}" + error_msg = f"Unable to find model manifest for {model_id} with version {version}." raise KeyError(error_msg) def _get_file_from_s3( @@ -210,33 +233,49 @@ def _get_file_from_s3( file_type, s3_key = key.file_type, key.s3_key - s3_client = boto3.client("s3", region_name=self._region) - if file_type == JumpStartS3FileType.MANIFEST: - etag = s3_client.head_object(Bucket=self._bucket, Key=s3_key)["ETag"] + etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"] if value is not None and etag == value.md5_hash: return value - response = s3_client.get_object(Bucket=self._bucket, Key=s3_key) + response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key) formatted_body = json.loads(response["Body"].read().decode("utf-8")) return JumpStartCachedS3ContentValue( formatted_file_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) if file_type == JumpStartS3FileType.SPECS: - response = s3_client.get_object(Bucket=self._bucket, Key=s3_key) + response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key) formatted_body = json.loads(response["Body"].read().decode("utf-8")) return JumpStartCachedS3ContentValue( formatted_file_content=JumpStartModelSpecs(formatted_body) ) - raise RuntimeError(f"Bad value for key: {key}") + raise ValueError( + f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}" + ) def get_header( self, model_id: str, semantic_version_str: Optional[str] = None - ) -> List[JumpStartModelHeader]: - """Return list of headers for a given JumpStart model id and semantic version. + ) -> JumpStartModelHeader: + """Return header for a given JumpStart model id and semantic version. + + Args: + model_id (str): model id for which to get a header. + semantic_version_str (Optional[str]): The semantic version for which to get a + header. If None, the highest compatible version is returned. 
+ """ + + return self._get_header_impl(model_id, 0, semantic_version_str) + + def _get_header_impl( + self, model_id: str, attempt: int, semantic_version_str: Optional[str] = None + ) -> JumpStartModelHeader: + """Lower-level function to return header. + + Allows a single retry if the cache is old. Args: model_id (str): model id for which to get a header. + attempt (int): attempt number at retrieving a header. semantic_version_str (Optional[str]): The semantic version for which to get a header. If None, the highest compatible version is returned. """ @@ -248,17 +287,12 @@ def get_header( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) ).formatted_file_content try: - header = manifest[versioned_model_id] - if self._has_retried_cache_refresh: - self._has_retried_cache_refresh = False - return header + return manifest[versioned_model_id] except KeyError: - if self._has_retried_cache_refresh: - self._has_retried_cache_refresh = False + if attempt > 0: raise self.clear() - self._has_retried_cache_refresh = True - return self.get_header(model_id, semantic_version) + return self._get_header_impl(model_id, attempt + 1, semantic_version_str) def get_specs( self, model_id: str, semantic_version_str: Optional[str] = None @@ -278,7 +312,6 @@ def get_specs( ).formatted_file_content def clear(self) -> None: - """Clears the model id/version and s3 cache and resets ``_has_retried_cache_refresh``.""" + """Clears the model id/version and s3 cache.""" self._s3_cache.clear() self._model_id_semantic_version_manifest_key_cache.clear() - self._has_retried_cache_refresh = False diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 65700ecec8..71452433b6 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -13,10 +13,17 @@ """This module stores constants related to SageMaker JumpStart.""" from __future__ import absolute_import from typing import Set +import boto3 from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo -LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set() +JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set() -REGION_NAME_TO_LAUNCHED_REGION_DICT = {region.region_name: region for region in LAUNCHED_REGIONS} -REGION_NAME_SET = {region.region_name for region in LAUNCHED_REGIONS} +JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT = { + region.region_name: region for region in JUMPSTART_LAUNCHED_REGIONS +} +JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS} + +JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name + +JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" diff --git a/src/sagemaker/jumpstart/parameters.py b/src/sagemaker/jumpstart/parameters.py new file mode 100644 index 0000000000..2010c39382 --- /dev/null +++ b/src/sagemaker/jumpstart/parameters.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
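# --- Illustrative sketch, not part of the patch ---
# The reworked setters above clear the internal LRU caches only when a value actually
# changes, so repeated calls with an unchanged region, manifest key, or bucket are cheap
# no-ops. The bucket name is a placeholder; supplying it explicitly also bypasses the
# launched-region content-bucket lookup in the constructor.
from sagemaker.jumpstart.cache import JumpStartModelsCache

js_cache = JumpStartModelsCache(region="us-west-2", s3_bucket_name="example-jumpstart-bucket")
js_cache.set_region("us-west-2")    # unchanged region: cached entries are kept
js_cache.set_region("us-east-1")    # new region: both internal caches are cleared
js_cache.set_manifest_file_s3_key("models_manifest.json")  # equals the default key: no-op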
+"""This module stores parameters related to SageMaker JumpStart.""" +from __future__ import absolute_import +import datetime + +JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS = 20 +JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS = 20 +JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON = datetime.timedelta(hours=6) +JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON = datetime.timedelta(hours=6) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index a6f19e29e7..32cb5e7e02 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -20,8 +20,9 @@ class JumpStartDataHolderType: """Base class for many JumpStart types. Allows objects to be added to dicts and sets, - and improves string representation. This class allows different objects with the same - attributes and types to have equality. + and improves string representation. This class overrides the ``__eq__`` + and ``__hash__`` methods so that different objects with the same attributes/types + can be compared. """ __slots__: List[str] = [] @@ -34,6 +35,9 @@ def __eq__(self, other: Any) -> bool: for attribute in self.__slots__: if getattr(self, attribute) != getattr(other, attribute): return False + for attribute in other.__slots__: + if getattr(self, attribute) != getattr(other, attribute): + return False return True def __hash__(self) -> int: @@ -66,7 +70,7 @@ def __repr__(self) -> str: class JumpStartS3FileType(str, Enum): - """Simple enum for classifying S3 file type.""" + """Type of files published in JumpStart S3 distribution buckets.""" MANIFEST = "manifest" SPECS = "specs" @@ -104,6 +108,32 @@ def from_json(self, json_obj: Dict[str, str]) -> None: self.spec_key: str = json_obj["spec_key"] +class JumpStartECRSpecs(JumpStartDataHolderType): + """Data class for JumpStart ECR specs.""" + + __slots__ = { + "framework", + "framework_version", + "py_version", + } + + def __init__(self, spec: Dict[str, Any]): + """Initializes a JumpStartECRSpecs object from its json representation.""" + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json.""" + + self.framework = json_obj["framework"] + self.framework_version = json_obj["framework_version"] + self.py_version = json_obj["py_version"] + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of JumpStartECRSpecs object.""" + json_obj = {att: getattr(self, att) for att in self.__slots__} + return json_obj + + class JumpStartModelSpecs(JumpStartDataHolderType): """Data class JumpStart model specs.""" @@ -132,15 +162,17 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.version: str = json_obj["version"] self.min_sdk_version: str = json_obj["min_sdk_version"] self.incremental_training_supported: bool = bool(json_obj["incremental_training_supported"]) - self.hosting_ecr_specs: dict = json_obj["hosting_ecr_specs"] + self.hosting_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) self.hosting_artifact_uri: str = json_obj["hosting_artifact_uri"] self.hosting_script_uri: str = json_obj["hosting_script_uri"] self.training_supported: bool = bool(json_obj["training_supported"]) if self.training_supported: - self.training_ecr_specs: Optional[dict] = json_obj["training_ecr_specs"] + self.training_ecr_specs: Optional[JumpStartECRSpecs] = JumpStartECRSpecs( + json_obj["training_ecr_specs"] + ) self.training_artifact_uri: Optional[str] = json_obj["training_artifact_uri"] self.training_script_uri: Optional[str] = 
json_obj["training_script_uri"] - self.hyperparameters: Optional[dict] = json_obj["hyperparameters"] + self.hyperparameters: Optional[Dict[str, Any]] = json_obj["hyperparameters"] else: self.training_ecr_specs = ( self.training_artifact_uri @@ -148,7 +180,13 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartModelSpecs object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__} + json_obj = {} + for att in self.__slots__: + cur_val = getattr(self, att) + if isinstance(cur_val, JumpStartECRSpecs): + json_obj[att] = cur_val.to_json() + else: + json_obj[att] = cur_val return json_obj diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 560a22138e..ca6ce231bf 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -13,30 +13,27 @@ """This module contains utilities related to SageMaker JumpStart.""" from __future__ import absolute_import from typing import Dict, List - import semantic_version -from sagemaker.jumpstart import constants import sagemaker +from sagemaker.jumpstart import constants from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId -def get_jumpstart_launched_regions_string() -> str: +def get_jumpstart_launched_regions_message() -> str: """Returns formatted string indicating where JumpStart is launched.""" - if len(constants.REGION_NAME_SET) == 0: + if len(constants.JUMPSTART_REGION_NAME_SET) == 0: return "JumpStart is not available in any region." - if len(constants.REGION_NAME_SET) == 1: - region = list(constants.REGION_NAME_SET)[0] + if len(constants.JUMPSTART_REGION_NAME_SET) == 1: + region = list(constants.JUMPSTART_REGION_NAME_SET)[0] return f"JumpStart is available in {region} region." - sorted_regions = sorted(list(constants.REGION_NAME_SET)) - if len(constants.REGION_NAME_SET) == 2: + sorted_regions = sorted(list(constants.JUMPSTART_REGION_NAME_SET)) + if len(constants.JUMPSTART_REGION_NAME_SET) == 2: return f"JumpStart is available in {sorted_regions[0]} and {sorted_regions[1]} regions." formatted_launched_regions_list = [] for i, region in enumerate(sorted_regions): - region_prefix = "" - if i == len(sorted_regions) - 1: - region_prefix = "and " + region_prefix = "" if i < len(sorted_regions) - 1 else "and " formatted_launched_regions_list.append(region_prefix + region) formatted_launched_regions_str = ", ".join(formatted_launched_regions_list) return f"JumpStart is available in {formatted_launched_regions_str} regions." @@ -49,10 +46,10 @@ def get_jumpstart_content_bucket(region: str) -> str: RuntimeError: If JumpStart is not launched in ``region``. """ try: - return constants.REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket + return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket except KeyError: - formatted_launched_regions_str = get_jumpstart_launched_regions_string() - raise RuntimeError( + formatted_launched_regions_str = get_jumpstart_launched_regions_message() + raise ValueError( f"Unable to get content bucket for JumpStart in {region} region. 
" f"{formatted_launched_regions_str}" ) @@ -99,9 +96,6 @@ def get_sagemaker_version() -> str: else: raise RuntimeError(f"Bad value for SageMaker version: {sagemaker.__version__}") - try: - semantic_version.Version(parsed_version) - except ValueError: - raise RuntimeError(f"Bad value for SageMaker version: {sagemaker.__version__}") + semantic_version.Version(parsed_version) return parsed_version diff --git a/src/sagemaker/utilities/cache.py b/src/sagemaker/utilities/cache.py index 0c1cbe0edb..0312e776b3 100644 --- a/src/sagemaker/utilities/cache.py +++ b/src/sagemaker/utilities/cache.py @@ -15,7 +15,7 @@ import datetime import collections -from typing import Any, TypeVar, Generic, Callable, Optional +from typing import TypeVar, Generic, Callable, Optional KeyType = TypeVar("KeyType") ValType = TypeVar("ValType") @@ -50,15 +50,15 @@ def __init__(self, value: ValType, creation_time: datetime.datetime): def __init__( self, max_cache_items: int, - expiration_time: datetime.timedelta, + expiration_horizon: datetime.timedelta, retrieval_function: Callable[[KeyType, ValType], ValType], ) -> None: """Initialize an ``LRUCache`` instance. Args: max_cache_items (int): Maximum number of items to store in cache. - expiration_time (datetime.timedelta): Maximum time duration a cache element can persist - before being invalidated. + expiration_horizon (datetime.timedelta): Maximum time duration a cache element can + persist before being invalidated. retrieval_function (Callable[[KeyType, ValType], ValType]): Function which maps cache keys and current values to new values. This function must have kwarg arguments ``key`` and and ``value``. This function is called as a fallback when the key @@ -67,7 +67,7 @@ def __init__( """ self._max_cache_items = max_cache_items self._lru_cache: collections.OrderedDict = collections.OrderedDict() - self._expiration_time = expiration_time + self._expiration_horizon = expiration_horizon self._retrieval_function = retrieval_function def __len__(self) -> int: @@ -147,12 +147,15 @@ def _get_item(self, key: KeyType, fail_on_old_value: bool) -> ValType: and fail_on_old_value is True. """ try: - element: Any = self._lru_cache.pop(key) + element = self._lru_cache.pop(key) curr_time = datetime.datetime.now(tz=datetime.timezone.utc) element_age = curr_time - element.creation_time - if element_age > self._expiration_time: + if element_age > self._expiration_horizon: if fail_on_old_value: - raise KeyError(f"{key} is old! Created at {element.creation_time}") + raise KeyError( + f"{key} has aged beyond allowed time {self._expiration_horizon}. " + f"Element created at {element.creation_time}." 
+ ) element.value = self._retrieval_function( # type: ignore key=key, value=element.value ) diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 52383d39f0..712d00bb94 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -15,13 +15,14 @@ import datetime import io import json +from botocore.stub import Stubber import botocore from mock.mock import MagicMock import pytest from mock import patch -from sagemaker.jumpstart.cache import DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache +from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache from sagemaker.jumpstart.types import ( JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue, @@ -131,7 +132,7 @@ def patched_get_file_from_s3( @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_cache_get_header(): - cache = JumpStartModelsCache(bucket="some_bucket") + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") assert ( JumpStartModelHeader( @@ -283,7 +284,7 @@ def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client): mock_boto3_client.return_value.get_object.side_effect = Exception() - cache = JumpStartModelsCache(bucket="some_bucket") + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") with pytest.raises(Exception): cache.get_header( @@ -294,7 +295,7 @@ def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client): mock_boto3_client.return_value.head_object.side_effect = Exception() - cache = JumpStartModelsCache(bucket="some_bucket") + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") with pytest.raises(Exception): cache.get_header( @@ -302,36 +303,77 @@ def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client): ) -def test_jumpstart_cache_gets_cleared_when_params_are_set(): - cache = JumpStartModelsCache(bucket="some_bucket") +@patch("boto3.client") +def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client): + cache = JumpStartModelsCache( + s3_bucket_name="some_bucket", region="some_region", manifest_file_s3_key="some_key" + ) + cache.clear = MagicMock() - cache.set_bucket("some_bucket") - cache.clear.assert_called_once() + cache.set_s3_bucket_name("some_bucket") + cache.clear.assert_not_called() cache.clear.reset_mock() cache.set_region("some_region") - cache.clear.assert_called_once() + cache.clear.assert_not_called() cache.clear.reset_mock() cache.set_manifest_file_s3_key("some_key") + cache.clear.assert_not_called() + + cache.clear.reset_mock() + + cache.set_s3_bucket_name("some_bucket1") + cache.clear.assert_called_once() + cache.clear.reset_mock() + cache.set_region("some_region1") + cache.clear.assert_called_once() + cache.clear.reset_mock() + cache.set_manifest_file_s3_key("some_key1") cache.clear.assert_called_once() +def test_jumpstart_cache_handles_boto3_client_errors(): + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + stubbed_s3_client = Stubber(cache._s3_client) + stubbed_s3_client.add_client_error("head_object", http_status_code=404) + stubbed_s3_client.add_client_error("get_object", http_status_code=404) + stubbed_s3_client.activate() + with pytest.raises(botocore.exceptions.ClientError): + cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4") + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + stubbed_s3_client = Stubber(cache._s3_client) + stubbed_s3_client.add_client_error("head_object", 
service_error_code="AccessDenied") + stubbed_s3_client.add_client_error("get_object", service_error_code="AccessDenied") + stubbed_s3_client.activate() + with pytest.raises(botocore.exceptions.ClientError): + cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4") + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + stubbed_s3_client = Stubber(cache._s3_client) + stubbed_s3_client.add_client_error("head_object", service_error_code="EndpointConnectionError") + stubbed_s3_client.add_client_error("get_object", service_error_code="EndpointConnectionError") + stubbed_s3_client.activate() + with pytest.raises(botocore.exceptions.ClientError): + cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4") + + def test_jumpstart_cache_accepts_input_parameters(): region = "us-east-1" max_s3_cache_items = 1 - s3_cache_expiration_time = datetime.timedelta(weeks=2) + s3_cache_expiration_horizon = datetime.timedelta(weeks=2) max_semantic_version_cache_items = 3 - semantic_version_cache_expiration_time = datetime.timedelta(microseconds=4) + semantic_version_cache_expiration_horizon = datetime.timedelta(microseconds=4) bucket = "my-amazing-bucket" manifest_file_key = "some_s3_key" cache = JumpStartModelsCache( region=region, max_s3_cache_items=max_s3_cache_items, - s3_cache_expiration_time=s3_cache_expiration_time, + s3_cache_expiration_horizon=s3_cache_expiration_horizon, max_semantic_version_cache_items=max_semantic_version_cache_items, - semantic_version_cache_expiration_time=semantic_version_cache_expiration_time, - bucket=bucket, + semantic_version_cache_expiration_horizon=semantic_version_cache_expiration_horizon, + s3_bucket_name=bucket, manifest_file_s3_key=manifest_file_key, ) @@ -339,14 +381,14 @@ def test_jumpstart_cache_accepts_input_parameters(): assert cache.get_region() == region assert cache.get_bucket() == bucket assert cache._s3_cache._max_cache_items == max_s3_cache_items - assert cache._s3_cache._expiration_time == s3_cache_expiration_time + assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon assert ( cache._model_id_semantic_version_manifest_key_cache._max_cache_items == max_semantic_version_cache_items ) assert ( - cache._model_id_semantic_version_manifest_key_cache._expiration_time - == semantic_version_cache_expiration_time + cache._model_id_semantic_version_manifest_key_cache._expiration_horizon + == semantic_version_cache_expiration_horizon ) @@ -366,13 +408,13 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): ) bucket_name = "bucket_name" - now = datetime.datetime.now() + now = datetime.datetime.fromtimestamp(1636730651.079551) with patch("datetime.datetime") as mock_datetime: mock_datetime.now.return_value = now cache = JumpStartModelsCache( - bucket=bucket_name, s3_cache_expiration_time=datetime.timedelta(hours=1) + s3_bucket_name=bucket_name, s3_cache_expiration_horizon=datetime.timedelta(hours=1) ) mock_boto3_client.return_value.get_object.return_value = { @@ -386,10 +428,10 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): # first time accessing cache should involve get_object and head_object mock_boto3_client.return_value.get_object.assert_called_with( - Bucket=bucket_name, Key=DEFAULT_MANIFEST_FILE_S3_KEY + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ) mock_boto3_client.return_value.head_object.assert_called_with( - Bucket=bucket_name, Key=DEFAULT_MANIFEST_FILE_S3_KEY + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ) 
mock_boto3_client.return_value.get_object.reset_mock() @@ -409,7 +451,7 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") mock_boto3_client.return_value.head_object.assert_called_with( - Bucket=bucket_name, Key=DEFAULT_MANIFEST_FILE_S3_KEY + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ) mock_boto3_client.return_value.get_object.assert_not_called() @@ -430,10 +472,10 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") mock_boto3_client.return_value.get_object.assert_called_with( - Bucket=bucket_name, Key=DEFAULT_MANIFEST_FILE_S3_KEY + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ) mock_boto3_client.return_value.head_object.assert_called_with( - Bucket=bucket_name, Key=DEFAULT_MANIFEST_FILE_S3_KEY + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ) @@ -461,15 +503,19 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): mock_boto3_client.return_value.head_object.return_value = {"ETag": "some-hash"} bucket_name = "bucket_name" - cache = JumpStartModelsCache(bucket=bucket_name) + client_config = botocore.config.Config(signature_version="my_signature_version") + cache = JumpStartModelsCache( + s3_bucket_name=bucket_name, s3_client_config=client_config, region="my_region" + ) cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") mock_boto3_client.return_value.get_object.assert_called_with( - Bucket=bucket_name, Key=DEFAULT_MANIFEST_FILE_S3_KEY + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ) mock_boto3_client.return_value.head_object.assert_called_with( - Bucket=bucket_name, Key=DEFAULT_MANIFEST_FILE_S3_KEY + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ) + mock_boto3_client.assert_called_with("s3", region_name="my_region", config=client_config) # test get_specs. manifest already in cache, so only s3 call will be to get specs. mock_json = json.dumps(BASE_SPEC) @@ -493,7 +539,7 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): @patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): - cache = JumpStartModelsCache(bucket="some_bucket") + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") cache.clear = MagicMock() cache._model_id_semantic_version_manifest_key_cache = MagicMock() @@ -536,7 +582,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): @patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_cache_get_specs(): - cache = JumpStartModelsCache(bucket="some_bucket") + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" assert get_spec_from_base_spec(model_id, version) == cache.get_specs( diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index b04c0803ac..75dd841c9d 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import import copy -from sagemaker.jumpstart.types import JumpStartModelSpecs, JumpStartModelHeader +from sagemaker.jumpstart.types import JumpStartECRSpecs, JumpStartModelSpecs, JumpStartModelHeader def test_jumpstart_model_header(): @@ -87,16 +87,20 @@ def test_jumpstart_model_specs(): assert specs1.min_sdk_version == "2.49.0" assert specs1.training_supported assert specs1.incremental_training_supported - assert specs1.hosting_ecr_specs == { - "framework": "pytorch", - "framework_version": "1.7.0", - "py_version": "py3", - } - assert specs1.training_ecr_specs == { - "framework": "pytorch", - "framework_version": "1.9.0", - "py_version": "py3", - } + assert specs1.hosting_ecr_specs == JumpStartECRSpecs( + { + "framework": "pytorch", + "framework_version": "1.7.0", + "py_version": "py3", + } + ) + assert specs1.training_ecr_specs == JumpStartECRSpecs( + { + "framework": "pytorch", + "framework_version": "1.9.0", + "py_version": "py3", + } + ) assert specs1.hosting_artifact_uri == "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz" assert specs1.training_artifact_uri == "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" assert ( diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 944904d908..8eff3d5310 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -14,40 +14,42 @@ from mock.mock import patch import pytest from sagemaker.jumpstart import utils -from sagemaker.jumpstart.constants import REGION_NAME_SET +from sagemaker.jumpstart.constants import JUMPSTART_REGION_NAME_SET from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId def test_get_jumpstart_content_bucket(): bad_region = "bad_region" - assert bad_region not in REGION_NAME_SET - with pytest.raises(RuntimeError): + assert bad_region not in JUMPSTART_REGION_NAME_SET + with pytest.raises(ValueError): utils.get_jumpstart_content_bucket(bad_region) -def test_get_jumpstart_launched_regions_string(): +def test_get_jumpstart_launched_regions_message(): - with patch("sagemaker.jumpstart.constants.REGION_NAME_SET", {}): + with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}): assert ( - utils.get_jumpstart_launched_regions_string() + utils.get_jumpstart_launched_regions_message() == "JumpStart is not available in any region." ) - with patch("sagemaker.jumpstart.constants.REGION_NAME_SET", {"some_region"}): + with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"some_region"}): assert ( - utils.get_jumpstart_launched_regions_string() + utils.get_jumpstart_launched_regions_message() == "JumpStart is available in some_region region." ) - with patch("sagemaker.jumpstart.constants.REGION_NAME_SET", {"some_region1", "some_region2"}): + with patch( + "sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"some_region1", "some_region2"} + ): assert ( - utils.get_jumpstart_launched_regions_string() + utils.get_jumpstart_launched_regions_message() == "JumpStart is available in some_region1 and some_region2 regions." ) - with patch("sagemaker.jumpstart.constants.REGION_NAME_SET", {"a", "b", "c"}): + with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"a", "b", "c"}): assert ( - utils.get_jumpstart_launched_regions_string() + utils.get_jumpstart_launched_regions_message() == "JumpStart is available in a, b, and c regions." 
) @@ -82,7 +84,7 @@ def test_get_sagemaker_version(): with patch("sagemaker.__version__", "1.2.3."): assert utils.get_sagemaker_version() == "1.2.3" - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): with patch("sagemaker.__version__", "1.2.3dfsdfs"): utils.get_sagemaker_version() diff --git a/tests/unit/sagemaker/utilities/test_cache.py b/tests/unit/sagemaker/utilities/test_cache.py index 0f5e053775..407c517013 100644 --- a/tests/unit/sagemaker/utilities/test_cache.py +++ b/tests/unit/sagemaker/utilities/test_cache.py @@ -26,7 +26,7 @@ def retrieval_function(key: Optional[int] = None, value: Optional[str] = None) - def test_cache_retrieves_item(): my_cache = cache.LRUCache[int, Union[int, str]]( max_cache_items=10, - expiration_time=datetime.timedelta(hours=1), + expiration_horizon=datetime.timedelta(hours=1), retrieval_function=retrieval_function, ) @@ -48,11 +48,11 @@ def test_cache_retrieves_item(): def test_cache_invalidates_old_item(): my_cache = cache.LRUCache[int, Union[int, str]]( max_cache_items=10, - expiration_time=datetime.timedelta(milliseconds=1), + expiration_horizon=datetime.timedelta(milliseconds=1), retrieval_function=retrieval_function, ) - curr_time = datetime.datetime.now() + curr_time = datetime.datetime.fromtimestamp(1636730651.079551) with patch("datetime.datetime") as mock_datetime: mock_datetime.now.return_value = curr_time my_cache.put(5) @@ -70,11 +70,11 @@ def test_cache_invalidates_old_item(): def test_cache_fetches_new_item(): my_cache = cache.LRUCache[int, Union[int, str]]( max_cache_items=10, - expiration_time=datetime.timedelta(milliseconds=1), + expiration_horizon=datetime.timedelta(milliseconds=1), retrieval_function=retrieval_function, ) - curr_time = datetime.datetime.now() + curr_time = datetime.datetime.fromtimestamp(1636730651.079551) with patch("datetime.datetime") as mock_datetime: mock_datetime.now.return_value = curr_time my_cache.put(5, 10) @@ -94,7 +94,7 @@ def test_cache_fetches_new_item(): def test_cache_removes_old_items_once_size_limit_reached(): my_cache = cache.LRUCache[int, Union[int, str]]( max_cache_items=5, - expiration_time=datetime.timedelta(hours=1), + expiration_horizon=datetime.timedelta(hours=1), retrieval_function=retrieval_function, ) @@ -113,7 +113,7 @@ def test_cache_removes_old_items_once_size_limit_reached(): def test_cache_get_with_data_source_fallback(): my_cache = cache.LRUCache[int, Union[int, str]]( max_cache_items=5, - expiration_time=datetime.timedelta(hours=1), + expiration_horizon=datetime.timedelta(hours=1), retrieval_function=retrieval_function, ) @@ -127,7 +127,7 @@ def test_cache_get_with_data_source_fallback(): def test_cache_gets_stored_value(): my_cache = cache.LRUCache[int, Union[int, str]]( max_cache_items=5, - expiration_time=datetime.timedelta(hours=1), + expiration_horizon=datetime.timedelta(hours=1), retrieval_function=retrieval_function, ) @@ -140,17 +140,47 @@ def test_cache_gets_stored_value(): my_cache._retrieval_function.reset_mock() my_cache.get(5) - my_cache._retrieval_function.assert_called() + my_cache._retrieval_function.assert_called_with(key=5, value=None) my_cache._retrieval_function.reset_mock() my_cache.get(0) - my_cache._retrieval_function.assert_called() + my_cache._retrieval_function.assert_called_with(key=0, value=None) + + +def test_cache_bad_retrieval_function(): + + cache_no_retrieval_fx = cache.LRUCache[int, Union[int, str]]( + max_cache_items=5, + expiration_horizon=datetime.timedelta(hours=1), + retrieval_function=None, + ) + + with 
pytest.raises(TypeError): + cache_no_retrieval_fx.put(1) + + cache_bad_retrieval_fx_signature = cache.LRUCache[int, Union[int, str]]( + max_cache_items=5, + expiration_horizon=datetime.timedelta(hours=1), + retrieval_function=lambda: 1, + ) + + with pytest.raises(TypeError): + cache_bad_retrieval_fx_signature.put(1) + + cache_retrieval_fx_throws = cache.LRUCache[int, Union[int, str]]( + max_cache_items=5, + expiration_horizon=datetime.timedelta(hours=1), + retrieval_function=lambda key, value: exec("raise(RuntimeError())"), + ) + + with pytest.raises(RuntimeError): + cache_retrieval_fx_throws.put(1) def test_cache_clear_and_contains(): my_cache = cache.LRUCache[int, Union[int, str]]( max_cache_items=5, - expiration_time=datetime.timedelta(hours=1), + expiration_horizon=datetime.timedelta(hours=1), retrieval_function=retrieval_function, ) From 7511bb0cd0e47465008784d74c6443bbc3f26da6 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 15 Nov 2021 14:18:56 +0000 Subject: [PATCH 05/13] change: add test for cache memory usage, add get_manifest function for jumpstart cache --- src/sagemaker/jumpstart/cache.py | 9 +- src/sagemaker/jumpstart/types.py | 7 +- tests/unit/sagemaker/jumpstart/test_cache.py | 87 +++++++++++--------- tests/unit/sagemaker/utilities/test_cache.py | 23 ++++++ 4 files changed, 84 insertions(+), 42 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 2cd281efe4..81ebd91599 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -13,7 +13,7 @@ """This module defines the JumpStartModelsCache class.""" from __future__ import absolute_import import datetime -from typing import Optional +from typing import List, Optional import json import boto3 import botocore @@ -253,6 +253,13 @@ def _get_file_from_s3( f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}" ) + def get_manifest(self) -> List[JumpStartModelHeader]: + """Return entire JumpStart models manifest.""" + + return self._s3_cache.get( + JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + ).formatted_file_content.values() + def get_header( self, model_id: str, semantic_version_str: Optional[str] = None ) -> JumpStartModelHeader: diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 32cb5e7e02..981148983e 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -32,12 +32,13 @@ def __eq__(self, other: Any) -> bool: if not isinstance(other, type(self)): return False + if getattr(other, "__slots__", None) is None: + return False + if self.__slots__ != other.__slots__: + return False for attribute in self.__slots__: if getattr(self, attribute) != getattr(other, attribute): return False - for attribute in other.__slots__: - if getattr(self, attribute) != getattr(other, attribute): - return False return True def __hash__(self) -> int: diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 712d00bb94..844167654e 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -60,6 +60,44 @@ }, } +BASE_MANIFEST = [ + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v1.0.0.json", + }, + { + "model_id": 
"tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v2.0.0.json", + }, + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-" + "imagenet-inception-v3-classification-4/specs_v1.0.0.json", + }, + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-imagenet-" + "inception-v3-classification-4/specs_v2.0.0.json", + }, + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "3.0.0", + "min_version": "4.49.0", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v3.0.0.json", + }, +] + def get_spec_from_base_spec(model_id: str, version: str) -> JumpStartModelSpecs: spec = copy.deepcopy(BASE_SPEC) @@ -77,45 +115,9 @@ def patched_get_file_from_s3( filetype, s3_key = key.file_type, key.s3_key if filetype == JumpStartS3FileType.MANIFEST: - manifest = [ - { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "1.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet" - "-inception-v3-classification-4/specs_v1.0.0.json", - }, - { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "2.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet" - "-inception-v3-classification-4/specs_v2.0.0.json", - }, - { - "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", - "version": "1.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/pytorch-ic-" - "imagenet-inception-v3-classification-4/specs_v1.0.0.json", - }, - { - "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", - "version": "2.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/pytorch-ic-imagenet-" - "inception-v3-classification-4/specs_v2.0.0.json", - }, - { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "3.0.0", - "min_version": "4.49.0", - "spec_key": "community_models_specs/tensorflow-ic-" - "imagenet-inception-v3-classification-4/specs_v3.0.0.json", - }, - ] + return JumpStartCachedS3ContentValue( - formatted_file_content=get_formatted_manifest(manifest) + formatted_file_content=get_formatted_manifest(BASE_MANIFEST) ) if filetype == JumpStartS3FileType.SPECS: @@ -579,6 +581,15 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache.clear.assert_called_once() +@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) +@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") +def test_jumpstart_get_full_manifest(): + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + raw_manifest = [header.to_json() for header in cache.get_manifest()] + + raw_manifest == BASE_MANIFEST + + @patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_cache_get_specs(): diff --git a/tests/unit/sagemaker/utilities/test_cache.py b/tests/unit/sagemaker/utilities/test_cache.py index 407c517013..ff00822443 100644 --- a/tests/unit/sagemaker/utilities/test_cache.py +++ 
b/tests/unit/sagemaker/utilities/test_cache.py @@ -14,6 +14,8 @@ from typing import Optional, Union from mock.mock import MagicMock, patch import pytest +import pickle + from sagemaker.utilities import cache import datetime @@ -192,3 +194,24 @@ def test_cache_clear_and_contains(): assert len(my_cache) == 0 with pytest.raises(KeyError): my_cache.get(1, False) + + +def test_cache_memory_usage(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=10, + expiration_horizon=datetime.timedelta(hours=1), + retrieval_function=retrieval_function, + ) + cache_size_bytes = [] + cache_size_bytes.append(len(pickle.dumps(my_cache))) + for i in range(50): + my_cache.put(i) + cache_size_bytes.append(len(pickle.dumps(my_cache))) + + max_cache_items_iter_cache_size = cache_size_bytes[10] + past_capacity_iter_cache_size = cache_size_bytes[50] + percent_difference_cache_size = ( + abs(past_capacity_iter_cache_size - max_cache_items_iter_cache_size) + / max_cache_items_iter_cache_size + ) * 100.0 + assert percent_difference_cache_size < 1 From 04143a65e61d6e1e626e3445d7a536749a496df0 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 17 Nov 2021 19:57:58 +0000 Subject: [PATCH 06/13] change: remove cache memory test, improve jumpstart cache design/tests --- src/sagemaker/jumpstart/cache.py | 12 +- src/sagemaker/jumpstart/types.py | 20 ++-- src/sagemaker/jumpstart/utils.py | 3 +- src/sagemaker/utilities/cache.py | 4 +- tests/unit/sagemaker/jumpstart/test_cache.py | 110 +++++++++++++++---- tests/unit/sagemaker/jumpstart/test_types.py | 16 +-- tests/unit/sagemaker/utilities/test_cache.py | 34 +----- 7 files changed, 126 insertions(+), 73 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 81ebd91599..2464154e95 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -234,11 +234,13 @@ def _get_file_from_s3( file_type, s3_key = key.file_type, key.s3_key if file_type == JumpStartS3FileType.MANIFEST: - etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"] - if value is not None and etag == value.md5_hash: - return value + if value is not None: + etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"] + if etag == value.md5_hash: + return value response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key) formatted_body = json.loads(response["Body"].read().decode("utf-8")) + etag = response["ETag"] return JumpStartCachedS3ContentValue( formatted_file_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, @@ -271,10 +273,10 @@ def get_header( header. If None, the highest compatible version is returned. """ - return self._get_header_impl(model_id, 0, semantic_version_str) + return self._get_header_impl(model_id, semantic_version_str=semantic_version_str) def _get_header_impl( - self, model_id: str, attempt: int, semantic_version_str: Optional[str] = None + self, model_id: str, attempt: Optional[int] = 0, semantic_version_str: Optional[str] = None ) -> JumpStartModelHeader: """Lower-level function to return header. 
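For orientation, a minimal usage sketch of the cache API as it stands after the change above. This sketch is illustrative only and is not part of the patch: the bucket name and model id are placeholders borrowed from the unit tests, and semantic_version_str="*" simply requests the highest version compatible with the installed SDK.

    from sagemaker.jumpstart.cache import JumpStartModelsCache

    # Placeholder bucket: a real bucket must host models_manifest.json plus the
    # per-model specs files referenced by each header's spec_key.
    cache = JumpStartModelsCache(s3_bucket_name="some_bucket")

    # Both calls go through the LRU caches; repeated identical calls are served
    # from memory until the expiration horizon elapses.
    header = cache.get_header(
        model_id="pytorch-ic-imagenet-inception-v3-classification-4",
        semantic_version_str="*",
    )
    specs = cache.get_specs(
        model_id="pytorch-ic-imagenet-inception-v3-classification-4",
        semantic_version_str="*",
    )
    print(header.version, specs.hosting_ecr_specs.framework)
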
diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 981148983e..4989118e70 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -144,12 +144,12 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "min_sdk_version", "incremental_training_supported", "hosting_ecr_specs", - "hosting_artifact_uri", - "hosting_script_uri", + "hosting_artifact_key", + "hosting_script_key", "training_supported", "training_ecr_specs", - "training_artifact_uri", - "training_script_uri", + "training_artifact_key", + "training_script_key", "hyperparameters", ] @@ -164,20 +164,20 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.min_sdk_version: str = json_obj["min_sdk_version"] self.incremental_training_supported: bool = bool(json_obj["incremental_training_supported"]) self.hosting_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) - self.hosting_artifact_uri: str = json_obj["hosting_artifact_uri"] - self.hosting_script_uri: str = json_obj["hosting_script_uri"] + self.hosting_artifact_key: str = json_obj["hosting_artifact_key"] + self.hosting_script_key: str = json_obj["hosting_script_key"] self.training_supported: bool = bool(json_obj["training_supported"]) if self.training_supported: self.training_ecr_specs: Optional[JumpStartECRSpecs] = JumpStartECRSpecs( json_obj["training_ecr_specs"] ) - self.training_artifact_uri: Optional[str] = json_obj["training_artifact_uri"] - self.training_script_uri: Optional[str] = json_obj["training_script_uri"] + self.training_artifact_key: Optional[str] = json_obj["training_artifact_key"] + self.training_script_key: Optional[str] = json_obj["training_script_key"] self.hyperparameters: Optional[Dict[str, Any]] = json_obj["hyperparameters"] else: self.training_ecr_specs = ( - self.training_artifact_uri - ) = self.training_script_uri = self.hyperparameters = None + self.training_artifact_key + ) = self.training_script_key = self.hyperparameters = None def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartModelSpecs object.""" diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index ca6ce231bf..fafe074e7d 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -82,7 +82,8 @@ def get_sagemaker_version() -> str: rejected. Raises: - RuntimeError: If the SageMaker version is not readable. + RuntimeError: If the SageMaker version is not readable. An exception is also raised if + the version cannot be parsed by ``semantic_version``. """ version = sagemaker.__version__ parsed_version = None diff --git a/src/sagemaker/utilities/cache.py b/src/sagemaker/utilities/cache.py index 0312e776b3..ab7e086a3f 100644 --- a/src/sagemaker/utilities/cache.py +++ b/src/sagemaker/utilities/cache.py @@ -134,12 +134,12 @@ def put(self, key: KeyType, value: Optional[ValType] = None) -> None: def _get_item(self, key: KeyType, fail_on_old_value: bool) -> ValType: """Returns value from cache corresponding to key. - If ``fail_on_old_value``, a KeyError is thrown instead of a new value + If ``fail_on_old_value``, a KeyError is raised instead of a new value getting fetched. Args: key (KeyType): Key in cache to retrieve. - fail_on_old_value (bool): True if a KeyError is thrown when the cache value + fail_on_old_value (bool): True if a KeyError is raised when the cache value is old. 
Raises: diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 844167654e..34c857231b 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -49,10 +49,10 @@ "framework_version": "1.9.0", "py_version": "py3", }, - "hosting_artifact_uri": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", - "training_artifact_uri": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", - "hosting_script_uri": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", - "training_script_uri": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", "hyperparameters": { "adam-learning-rate": {"type": "float", "default": 0.05, "min": 1e-08, "max": 1}, "epochs": {"type": "int", "default": 3, "min": 1, "max": 1000}, @@ -334,9 +334,9 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client): def test_jumpstart_cache_handles_boto3_client_errors(): + # Testing get_object cache = JumpStartModelsCache(s3_bucket_name="some_bucket") stubbed_s3_client = Stubber(cache._s3_client) - stubbed_s3_client.add_client_error("head_object", http_status_code=404) stubbed_s3_client.add_client_error("get_object", http_status_code=404) stubbed_s3_client.activate() with pytest.raises(botocore.exceptions.ClientError): @@ -344,7 +344,6 @@ def test_jumpstart_cache_handles_boto3_client_errors(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") stubbed_s3_client = Stubber(cache._s3_client) - stubbed_s3_client.add_client_error("head_object", service_error_code="AccessDenied") stubbed_s3_client.add_client_error("get_object", service_error_code="AccessDenied") stubbed_s3_client.activate() with pytest.raises(botocore.exceptions.ClientError): @@ -352,12 +351,83 @@ def test_jumpstart_cache_handles_boto3_client_errors(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") stubbed_s3_client = Stubber(cache._s3_client) - stubbed_s3_client.add_client_error("head_object", service_error_code="EndpointConnectionError") stubbed_s3_client.add_client_error("get_object", service_error_code="EndpointConnectionError") stubbed_s3_client.activate() with pytest.raises(botocore.exceptions.ClientError): cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4") + # Testing head_object: + mock_now = datetime.datetime.fromtimestamp(1636730651.079551) + with patch("datetime.datetime") as mock_datetime: + mock_manifest_json = json.dumps( + [ + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-" + "imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ] + ) + + get_object_mocked_response = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_manifest_json, "utf-8")), + content_length=len(mock_manifest_json), + ), + "ETag": "etag", + } + + mock_datetime.now.return_value = mock_now + + cache1 = JumpStartModelsCache( + s3_bucket_name="some_bucket", s3_cache_expiration_horizon=datetime.timedelta(hours=1) + ) + stubbed_s3_client1 = 
Stubber(cache1._s3_client) + + stubbed_s3_client1.add_response("get_object", copy.deepcopy(get_object_mocked_response)) + stubbed_s3_client1.activate() + cache1.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + + mock_datetime.now.return_value += datetime.timedelta(weeks=1) + + stubbed_s3_client1.add_client_error("head_object", http_status_code=404) + with pytest.raises(botocore.exceptions.ClientError): + cache1.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + + cache2 = JumpStartModelsCache( + s3_bucket_name="some_bucket", s3_cache_expiration_horizon=datetime.timedelta(hours=1) + ) + stubbed_s3_client2 = Stubber(cache2._s3_client) + + stubbed_s3_client2.add_response("get_object", copy.deepcopy(get_object_mocked_response)) + stubbed_s3_client2.activate() + cache2.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + + mock_datetime.now.return_value += datetime.timedelta(weeks=1) + + stubbed_s3_client2.add_client_error("head_object", service_error_code="AccessDenied") + with pytest.raises(botocore.exceptions.ClientError): + cache2.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + + cache3 = JumpStartModelsCache( + s3_bucket_name="some_bucket", s3_cache_expiration_horizon=datetime.timedelta(hours=1) + ) + stubbed_s3_client3 = Stubber(cache3._s3_client) + + stubbed_s3_client3.add_response("get_object", copy.deepcopy(get_object_mocked_response)) + stubbed_s3_client3.activate() + cache3.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + + mock_datetime.now.return_value += datetime.timedelta(weeks=1) + + stubbed_s3_client3.add_client_error( + "head_object", service_error_code="EndpointConnectionError" + ) + with pytest.raises(botocore.exceptions.ClientError): + cache3.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + def test_jumpstart_cache_accepts_input_parameters(): @@ -422,19 +492,18 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): mock_boto3_client.return_value.get_object.return_value = { "Body": botocore.response.StreamingBody( io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) - ) + ), + "ETag": "hash1", } mock_boto3_client.return_value.head_object.return_value = {"ETag": "hash1"} cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") - # first time accessing cache should involve get_object and head_object + # first time accessing cache should just involve get_object mock_boto3_client.return_value.get_object.assert_called_with( Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ) - mock_boto3_client.return_value.head_object.assert_called_with( - Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY - ) + mock_boto3_client.return_value.head_object.assert_not_called() mock_boto3_client.return_value.get_object.reset_mock() mock_boto3_client.return_value.head_object.reset_mock() @@ -443,7 +512,8 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): mock_boto3_client.return_value.get_object.return_value = { "Body": botocore.response.StreamingBody( io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) - ) + ), + "ETag": "hash1", } mock_boto3_client.return_value.head_object.return_value = {"ETag": "hash1"} @@ -465,7 +535,8 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): mock_boto3_client.return_value.get_object.return_value = { "Body": botocore.response.StreamingBody( io.BytesIO(bytes(mock_json, "utf-8")), 
content_length=len(mock_json) - ) + ), + "ETag": "hash2", } # invalidate cache @@ -499,7 +570,8 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): mock_boto3_client.return_value.get_object.return_value = { "Body": botocore.response.StreamingBody( io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) - ) + ), + "ETag": "etag", } mock_boto3_client.return_value.head_object.return_value = {"ETag": "some-hash"} @@ -514,9 +586,8 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): mock_boto3_client.return_value.get_object.assert_called_with( Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ) - mock_boto3_client.return_value.head_object.assert_called_with( - Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY - ) + mock_boto3_client.return_value.head_object.assert_not_called() + mock_boto3_client.assert_called_with("s3", region_name="my_region", config=client_config) # test get_specs. manifest already in cache, so only s3 call will be to get specs. @@ -527,7 +598,8 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): mock_boto3_client.return_value.get_object.return_value = { "Body": botocore.response.StreamingBody( io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) - ) + ), + "ETag": "etag", } cache.get_specs(model_id="pytorch-ic-imagenet-inception-v3-classification-4") diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 75dd841c9d..6f970d6e58 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -69,10 +69,10 @@ def test_jumpstart_model_specs(): "framework_version": "1.9.0", "py_version": "py3", }, - "hosting_artifact_uri": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", - "training_artifact_uri": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", - "hosting_script_uri": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", - "training_script_uri": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", "hyperparameters": { "adam-learning-rate": {"type": "float", "default": 0.05, "min": 1e-08, "max": 1}, "epochs": {"type": "int", "default": 3, "min": 1, "max": 1000}, @@ -101,14 +101,14 @@ def test_jumpstart_model_specs(): "py_version": "py3", } ) - assert specs1.hosting_artifact_uri == "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz" - assert specs1.training_artifact_uri == "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" + assert specs1.hosting_artifact_key == "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz" + assert specs1.training_artifact_key == "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" assert ( - specs1.hosting_script_uri + specs1.hosting_script_key == "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz" ) assert ( - specs1.training_script_uri + specs1.training_script_key == "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz" ) assert specs1.hyperparameters == { diff --git a/tests/unit/sagemaker/utilities/test_cache.py b/tests/unit/sagemaker/utilities/test_cache.py 
index ff00822443..10fbe45767 100644 --- a/tests/unit/sagemaker/utilities/test_cache.py +++ b/tests/unit/sagemaker/utilities/test_cache.py @@ -14,7 +14,6 @@ from typing import Optional, Union from mock.mock import MagicMock, patch import pytest -import pickle from sagemaker.utilities import cache @@ -54,16 +53,16 @@ def test_cache_invalidates_old_item(): retrieval_function=retrieval_function, ) - curr_time = datetime.datetime.fromtimestamp(1636730651.079551) + mock_curr_time = datetime.datetime.fromtimestamp(1636730651.079551) with patch("datetime.datetime") as mock_datetime: - mock_datetime.now.return_value = curr_time + mock_datetime.now.return_value = mock_curr_time my_cache.put(5) mock_datetime.now.return_value += datetime.timedelta(milliseconds=2) with pytest.raises(KeyError): my_cache.get(5, False) with patch("datetime.datetime") as mock_datetime: - mock_datetime.now.return_value = curr_time + mock_datetime.now.return_value = mock_curr_time my_cache.put(5) mock_datetime.now.return_value += datetime.timedelta(milliseconds=0.5) assert my_cache.get(5, False) == retrieval_function(key=5) @@ -76,15 +75,15 @@ def test_cache_fetches_new_item(): retrieval_function=retrieval_function, ) - curr_time = datetime.datetime.fromtimestamp(1636730651.079551) + mock_curr_time = datetime.datetime.fromtimestamp(1636730651.079551) with patch("datetime.datetime") as mock_datetime: - mock_datetime.now.return_value = curr_time + mock_datetime.now.return_value = mock_curr_time my_cache.put(5, 10) mock_datetime.now.return_value += datetime.timedelta(milliseconds=2) assert my_cache.get(5) == retrieval_function(key=5) with patch("datetime.datetime") as mock_datetime: - mock_datetime.now.return_value = curr_time + mock_datetime.now.return_value = mock_curr_time my_cache.put(5, 10) mock_datetime.now.return_value += datetime.timedelta(milliseconds=0.5) assert my_cache.get(5, False) == 10 @@ -194,24 +193,3 @@ def test_cache_clear_and_contains(): assert len(my_cache) == 0 with pytest.raises(KeyError): my_cache.get(1, False) - - -def test_cache_memory_usage(): - my_cache = cache.LRUCache[int, Union[int, str]]( - max_cache_items=10, - expiration_horizon=datetime.timedelta(hours=1), - retrieval_function=retrieval_function, - ) - cache_size_bytes = [] - cache_size_bytes.append(len(pickle.dumps(my_cache))) - for i in range(50): - my_cache.put(i) - cache_size_bytes.append(len(pickle.dumps(my_cache))) - - max_cache_items_iter_cache_size = cache_size_bytes[10] - past_capacity_iter_cache_size = cache_size_bytes[50] - percent_difference_cache_size = ( - abs(past_capacity_iter_cache_size - max_cache_items_iter_cache_size) - / max_cache_items_iter_cache_size - ) * 100.0 - assert percent_difference_cache_size < 1 From 02b1a971c9bcf7dc77721f49395713b1b807e571 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 18 Nov 2021 21:33:40 +0000 Subject: [PATCH 07/13] fix: JumpStartModelSpecs hyperparameters --- src/sagemaker/jumpstart/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 4989118e70..d6a55f3fa9 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -173,7 +173,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ) self.training_artifact_key: Optional[str] = json_obj["training_artifact_key"] self.training_script_key: Optional[str] = json_obj["training_script_key"] - self.hyperparameters: Optional[Dict[str, Any]] = json_obj["hyperparameters"] + self.hyperparameters: Optional[Dict[str, 
Any]] = json_obj.get("hyperparameters") else: self.training_ecr_specs = ( self.training_artifact_key From 6604acc7c91a0edc1ab0010fb8bbebd3bc8ec96d Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 22 Nov 2021 18:38:25 +0000 Subject: [PATCH 08/13] change: require version for cache get operations --- src/sagemaker/jumpstart/cache.py | 29 +++--- tests/unit/sagemaker/jumpstart/test_cache.py | 104 +++++++++++-------- 2 files changed, 76 insertions(+), 57 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 2464154e95..e3b237245a 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -262,21 +262,22 @@ def get_manifest(self) -> List[JumpStartModelHeader]: JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) ).formatted_file_content.values() - def get_header( - self, model_id: str, semantic_version_str: Optional[str] = None - ) -> JumpStartModelHeader: + def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader: """Return header for a given JumpStart model id and semantic version. Args: model_id (str): model id for which to get a header. - semantic_version_str (Optional[str]): The semantic version for which to get a - header. If None, the highest compatible version is returned. + semantic_version_str (str): The semantic version for which to get a + header. """ return self._get_header_impl(model_id, semantic_version_str=semantic_version_str) def _get_header_impl( - self, model_id: str, attempt: Optional[int] = 0, semantic_version_str: Optional[str] = None + self, + model_id: str, + semantic_version_str: str, + attempt: Optional[int] = 0, ) -> JumpStartModelHeader: """Lower-level function to return header. @@ -284,9 +285,9 @@ def _get_header_impl( Args: model_id (str): model id for which to get a header. - attempt (int): attempt number at retrieving a header. - semantic_version_str (Optional[str]): The semantic version for which to get a - header. If None, the highest compatible version is returned. + semantic_version_str (str): The semantic version for which to get a + header. + attempt (Optional[int]): attempt number at retrieving a header. """ versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get( @@ -301,17 +302,15 @@ def _get_header_impl( if attempt > 0: raise self.clear() - return self._get_header_impl(model_id, attempt + 1, semantic_version_str) + return self._get_header_impl(model_id, semantic_version_str, attempt + 1) - def get_specs( - self, model_id: str, semantic_version_str: Optional[str] = None - ) -> JumpStartModelSpecs: + def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelSpecs: """Return specs for a given JumpStart model id and semantic version. Args: model_id (str): model id for which to get specs. - semantic_version_str (Optional[str]): The semantic version for which to get - specs. If None, the highest compatible version is returned. + semantic_version_str (str): The semantic version for which to get + specs. 
""" header = self.get_header(model_id, semantic_version_str) diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 34c857231b..571c526bdf 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -136,33 +136,19 @@ def test_jumpstart_cache_get_header(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") - assert ( - JumpStartModelHeader( - { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "2.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic" - "-imagenet-inception-v3-classification-4/specs_v2.0.0.json", - } - ) - == cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4") + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic" + "-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" ) # See if we can make the same query 2 times consecutively - assert ( - JumpStartModelHeader( - { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "2.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic" - "-imagenet-inception-v3-classification-4/specs_v2.0.0.json", - } - ) - == cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4") - ) - assert JumpStartModelHeader( { "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", @@ -278,6 +264,7 @@ def test_jumpstart_cache_get_header(): with pytest.raises(KeyError): cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4-bak", + semantic_version_str="*", ) @@ -340,21 +327,30 @@ def test_jumpstart_cache_handles_boto3_client_errors(): stubbed_s3_client.add_client_error("get_object", http_status_code=404) stubbed_s3_client.activate() with pytest.raises(botocore.exceptions.ClientError): - cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4") + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="*", + ) cache = JumpStartModelsCache(s3_bucket_name="some_bucket") stubbed_s3_client = Stubber(cache._s3_client) stubbed_s3_client.add_client_error("get_object", service_error_code="AccessDenied") stubbed_s3_client.activate() with pytest.raises(botocore.exceptions.ClientError): - cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4") + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="*", + ) cache = JumpStartModelsCache(s3_bucket_name="some_bucket") stubbed_s3_client = Stubber(cache._s3_client) stubbed_s3_client.add_client_error("get_object", service_error_code="EndpointConnectionError") stubbed_s3_client.activate() with pytest.raises(botocore.exceptions.ClientError): - cache.get_header(model_id="tensorflow-ic-imagenet-inception-v3-classification-4") + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="*", + ) # Testing head_object: mock_now = datetime.datetime.fromtimestamp(1636730651.079551) @@ -388,13 +384,18 @@ def test_jumpstart_cache_handles_boto3_client_errors(): stubbed_s3_client1.add_response("get_object", 
copy.deepcopy(get_object_mocked_response)) stubbed_s3_client1.activate() - cache1.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + cache1.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) mock_datetime.now.return_value += datetime.timedelta(weeks=1) stubbed_s3_client1.add_client_error("head_object", http_status_code=404) with pytest.raises(botocore.exceptions.ClientError): - cache1.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + cache1.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", + semantic_version_str="*", + ) cache2 = JumpStartModelsCache( s3_bucket_name="some_bucket", s3_cache_expiration_horizon=datetime.timedelta(hours=1) @@ -403,13 +404,18 @@ def test_jumpstart_cache_handles_boto3_client_errors(): stubbed_s3_client2.add_response("get_object", copy.deepcopy(get_object_mocked_response)) stubbed_s3_client2.activate() - cache2.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + cache2.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) mock_datetime.now.return_value += datetime.timedelta(weeks=1) stubbed_s3_client2.add_client_error("head_object", service_error_code="AccessDenied") with pytest.raises(botocore.exceptions.ClientError): - cache2.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + cache2.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", + semantic_version_str="*", + ) cache3 = JumpStartModelsCache( s3_bucket_name="some_bucket", s3_cache_expiration_horizon=datetime.timedelta(hours=1) @@ -418,7 +424,9 @@ def test_jumpstart_cache_handles_boto3_client_errors(): stubbed_s3_client3.add_response("get_object", copy.deepcopy(get_object_mocked_response)) stubbed_s3_client3.activate() - cache3.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + cache3.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) mock_datetime.now.return_value += datetime.timedelta(weeks=1) @@ -426,7 +434,10 @@ def test_jumpstart_cache_handles_boto3_client_errors(): "head_object", service_error_code="EndpointConnectionError" ) with pytest.raises(botocore.exceptions.ClientError): - cache3.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + cache3.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", + semantic_version_str="*", + ) def test_jumpstart_cache_accepts_input_parameters(): @@ -497,7 +508,9 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): } mock_boto3_client.return_value.head_object.return_value = {"ETag": "hash1"} - cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + cache.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) # first time accessing cache should just involve get_object mock_boto3_client.return_value.get_object.assert_called_with( @@ -520,7 +533,9 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): # invalidate cache mock_datetime.now.return_value += datetime.timedelta(hours=2) - cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + cache.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) mock_boto3_client.return_value.head_object.assert_called_with( Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY @@ -542,7 
+557,9 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): # invalidate cache mock_datetime.now.return_value += datetime.timedelta(hours=2) - cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + cache.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) mock_boto3_client.return_value.get_object.assert_called_with( Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY @@ -581,7 +598,9 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): cache = JumpStartModelsCache( s3_bucket_name=bucket_name, s3_client_config=client_config, region="my_region" ) - cache.get_header(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + cache.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) mock_boto3_client.return_value.get_object.assert_called_with( Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY @@ -601,7 +620,9 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): ), "ETag": "etag", } - cache.get_specs(model_id="pytorch-ic-imagenet-inception-v3-classification-4") + cache.get_specs( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) mock_boto3_client.return_value.get_object.assert_called_with( Bucket=bucket_name, @@ -633,7 +654,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): "imagenet-inception-v3-classification-4/specs_v1.0.0.json", } ) == cache.get_header( - model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" ) cache.clear.assert_called_once() cache.clear.reset_mock() @@ -649,6 +670,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): with pytest.raises(KeyError): cache.get_header( model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="*", ) cache.clear.assert_called_once() @@ -683,9 +705,7 @@ def test_jumpstart_cache_get_specs(): ) with pytest.raises(KeyError): - cache.get_specs( - model_id=model_id + "bak", - ) + cache.get_specs(model_id=model_id + "bak", semantic_version_str="*") with pytest.raises(KeyError): cache.get_specs(model_id=model_id, semantic_version_str="9.*") From b8e130cee6b0ec708d089181b28011be688b71cd Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 23 Nov 2021 01:12:01 +0000 Subject: [PATCH 09/13] change: improve getting sm version logic --- src/sagemaker/jumpstart/constants.py | 2 ++ src/sagemaker/jumpstart/utils.py | 11 ++++++++ tests/unit/sagemaker/jumpstart/test_utils.py | 29 +++++++++++++------- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 71452433b6..97a86eb350 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -27,3 +27,5 @@ JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" + +PARSED_SAGEMAKER_VERSION = "" # this gets set by sagemaker.jumpstart.utils.get_sagemaker_version() diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index fafe074e7d..6188dd2924 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -75,6 +75,17 @@ def get_formatted_manifest( def get_sagemaker_version() -> str: """Returns sagemaker library version. 
+ If the sagemaker library version has not been set yet, this function + calls ``parse_sagemaker_version`` to retrive the version. + """ + if constants.PARSED_SAGEMAKER_VERSION == "": + constants.PARSED_SAGEMAKER_VERSION = parse_sagemaker_version() + return constants.PARSED_SAGEMAKER_VERSION + + +def parse_sagemaker_version() -> str: + """Returns sagemaker library version. This should only be called once. + Function reads ``__version__`` variable in ``sagemaker`` module. In order to maintain compatibility with the ``semantic_version`` library, versions with fewer than 2, or more than 3, periods are rejected. diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 8eff3d5310..06a4ce960b 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -11,7 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import -from mock.mock import patch +from mock.mock import Mock, patch import pytest from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import JUMPSTART_REGION_NAME_SET @@ -73,33 +73,42 @@ def test_get_formatted_manifest(): assert utils.get_formatted_manifest([]) == {} -def test_get_sagemaker_version(): +def test_parse_sagemaker_version(): with patch("sagemaker.__version__", "1.2.3"): - assert utils.get_sagemaker_version() == "1.2.3" + assert utils.parse_sagemaker_version() == "1.2.3" with patch("sagemaker.__version__", "1.2.3.3332j"): - assert utils.get_sagemaker_version() == "1.2.3" + assert utils.parse_sagemaker_version() == "1.2.3" with patch("sagemaker.__version__", "1.2.3."): - assert utils.get_sagemaker_version() == "1.2.3" + assert utils.parse_sagemaker_version() == "1.2.3" with pytest.raises(ValueError): with patch("sagemaker.__version__", "1.2.3dfsdfs"): - utils.get_sagemaker_version() + utils.parse_sagemaker_version() with pytest.raises(RuntimeError): with patch("sagemaker.__version__", "1.2"): - utils.get_sagemaker_version() + utils.parse_sagemaker_version() with pytest.raises(RuntimeError): with patch("sagemaker.__version__", "1"): - utils.get_sagemaker_version() + utils.parse_sagemaker_version() with pytest.raises(RuntimeError): with patch("sagemaker.__version__", ""): - utils.get_sagemaker_version() + utils.parse_sagemaker_version() with pytest.raises(RuntimeError): with patch("sagemaker.__version__", "1.2.3.4.5"): - utils.get_sagemaker_version() + utils.parse_sagemaker_version() + + +@patch("sagemaker.jumpstart.utils.parse_sagemaker_version") +@patch("sagemaker.jumpstart.constants.PARSED_SAGEMAKER_VERSION", "") +def test_get_sagemaker_version(patched_parse_sm_version: Mock): + utils.get_sagemaker_version() + utils.get_sagemaker_version() + utils.get_sagemaker_version() + assert patched_parse_sm_version.called_only_once() From e94a87b3fea9ea51fe0ad92fe4c64bdfd2e7f4dc Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 23 Nov 2021 01:13:48 +0000 Subject: [PATCH 10/13] fix: typo --- src/sagemaker/jumpstart/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 6188dd2924..5fb63751d1 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -75,8 +75,9 @@ def get_formatted_manifest( def get_sagemaker_version() -> str: """Returns sagemaker library version. 
- If the sagemaker library version has not been set yet, this function - calls ``parse_sagemaker_version`` to retrive the version. + If the sagemaker library version has not been set, this function + calls ``parse_sagemaker_version`` to retrieve the version and set + the constant. """ if constants.PARSED_SAGEMAKER_VERSION == "": constants.PARSED_SAGEMAKER_VERSION = parse_sagemaker_version() From 3bde0381080a633cfe492667fa897ea55868384f Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 24 Nov 2021 13:25:57 +0000 Subject: [PATCH 11/13] fix: sagemaker version for jumpstart --- src/sagemaker/jumpstart/constants.py | 2 -- src/sagemaker/jumpstart/utils.py | 22 +++++++++++++++++--- tests/unit/sagemaker/jumpstart/test_utils.py | 2 +- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 97a86eb350..71452433b6 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -27,5 +27,3 @@ JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" - -PARSED_SAGEMAKER_VERSION = "" # this gets set by sagemaker.jumpstart.utils.get_sagemaker_version() diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 5fb63751d1..1e1f4c4b6d 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -19,6 +19,22 @@ from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId +class SageMakerSettings(object): + """Static class for storing the SageMaker settings.""" + + _PARSED_SAGEMAKER_VERSION = "" + + @staticmethod + def set_sagemaker_version(version: str) -> None: + """Set SageMaker version.""" + SageMakerSettings._PARSED_SAGEMAKER_VERSION = version + + @staticmethod + def get_sagemaker_version() -> str: + """Return SageMaker version.""" + return SageMakerSettings._PARSED_SAGEMAKER_VERSION + + def get_jumpstart_launched_regions_message() -> str: """Returns formatted string indicating where JumpStart is launched.""" if len(constants.JUMPSTART_REGION_NAME_SET) == 0: @@ -79,9 +95,9 @@ def get_sagemaker_version() -> str: calls ``parse_sagemaker_version`` to retrieve the version and set the constant. 
""" - if constants.PARSED_SAGEMAKER_VERSION == "": - constants.PARSED_SAGEMAKER_VERSION = parse_sagemaker_version() - return constants.PARSED_SAGEMAKER_VERSION + if SageMakerSettings.get_sagemaker_version() == "": + SageMakerSettings.set_sagemaker_version(parse_sagemaker_version()) + return SageMakerSettings.get_sagemaker_version() def parse_sagemaker_version() -> str: diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 06a4ce960b..39b4706796 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -106,7 +106,7 @@ def test_parse_sagemaker_version(): @patch("sagemaker.jumpstart.utils.parse_sagemaker_version") -@patch("sagemaker.jumpstart.constants.PARSED_SAGEMAKER_VERSION", "") +@patch("sagemaker.jumpstart.utils.SageMakerSettings._PARSED_SAGEMAKER_VERSION", "") def test_get_sagemaker_version(patched_parse_sm_version: Mock): utils.get_sagemaker_version() utils.get_sagemaker_version() From 85c65d580bcb1494db8713f5874d2f212fb81f9f Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 26 Nov 2021 17:47:30 +0000 Subject: [PATCH 12/13] change: improve docstrings, typos --- src/sagemaker/jumpstart/cache.py | 16 +++-- src/sagemaker/jumpstart/types.py | 75 +++++++++++++++++--- src/sagemaker/utilities/cache.py | 2 +- tests/unit/sagemaker/jumpstart/test_cache.py | 4 +- 4 files changed, 77 insertions(+), 20 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index e3b237245a..117d1e8ba6 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -164,7 +164,7 @@ def _get_manifest_key_from_model_id_semantic_version( manifest = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) - ).formatted_file_content + ).formatted_content sm_version = utils.get_sagemaker_version() @@ -199,7 +199,9 @@ def _get_manifest_key_from_model_id_semantic_version( if header.model_id == model_id and header.version == model_version_to_use_incompatible_with_sagemaker ] - assert len(sm_version_to_use) == 1 # ``manifest`` dict should already enforce this + if len(sm_version_to_use) != 1: + # ``manifest`` dict should already enforce this + raise RuntimeError("Found more than one incompatible SageMaker version to use.") sm_version_to_use = sm_version_to_use[0] error_msg = ( @@ -242,14 +244,14 @@ def _get_file_from_s3( formatted_body = json.loads(response["Body"].read().decode("utf-8")) etag = response["ETag"] return JumpStartCachedS3ContentValue( - formatted_file_content=utils.get_formatted_manifest(formatted_body), + formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) if file_type == JumpStartS3FileType.SPECS: response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key) formatted_body = json.loads(response["Body"].read().decode("utf-8")) return JumpStartCachedS3ContentValue( - formatted_file_content=JumpStartModelSpecs(formatted_body) + formatted_content=JumpStartModelSpecs(formatted_body) ) raise ValueError( f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}" @@ -260,7 +262,7 @@ def get_manifest(self) -> List[JumpStartModelHeader]: return self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) - ).formatted_file_content.values() + ).formatted_content.values() def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader: 
"""Return header for a given JumpStart model id and semantic version. @@ -295,7 +297,7 @@ def _get_header_impl( ) manifest = self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) - ).formatted_file_content + ).formatted_content try: return manifest[versioned_model_id] except KeyError: @@ -317,7 +319,7 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS spec_key = header.spec_key return self._s3_cache.get( JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key) - ).formatted_file_content + ).formatted_content def clear(self) -> None: """Clears the model id/version and s3 cache.""" diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index d6a55f3fa9..9bb865cc65 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -28,7 +28,11 @@ class JumpStartDataHolderType: __slots__: List[str] = [] def __eq__(self, other: Any) -> bool: - """Returns True if ``other`` is of the same type and has all attributes equal.""" + """Returns True if ``other`` is of the same type and has all attributes equal. + + Args: + other (Any): Other object to which to compare this object. + """ if not isinstance(other, type(self)): return False @@ -83,6 +87,12 @@ class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): __slots__ = ["content_bucket", "region_name"] def __init__(self, content_bucket: str, region_name: str): + """Instantiates JumpStartLaunchedRegionInfo object. + + Args: + content_bucket (str): Name of JumpStart s3 content bucket associated with region. + region_name (str): Name of JumpStart launched region. + """ self.content_bucket = content_bucket self.region_name = region_name @@ -93,7 +103,11 @@ class JumpStartModelHeader(JumpStartDataHolderType): __slots__ = ["model_id", "version", "min_version", "spec_key"] def __init__(self, header: Dict[str, str]): - """Initializes a JumpStartModelHeader object from its json representation.""" + """Initializes a JumpStartModelHeader object from its json representation. + + Args: + header (Dict[str, str]): Dictionary representation of header. + """ self.from_json(header) def to_json(self) -> Dict[str, str]: @@ -102,7 +116,11 @@ def to_json(self) -> Dict[str, str]: return json_obj def from_json(self, json_obj: Dict[str, str]) -> None: - """Sets fields in object based on json of header.""" + """Sets fields in object based on json of header. + + Args: + json_obj (Dict[str, str]): Dictionary representation of header. + """ self.model_id: str = json_obj["model_id"] self.version: str = json_obj["version"] self.min_version: str = json_obj["min_version"] @@ -119,11 +137,19 @@ class JumpStartECRSpecs(JumpStartDataHolderType): } def __init__(self, spec: Dict[str, Any]): - """Initializes a JumpStartECRSpecs object from its json representation.""" + """Initializes a JumpStartECRSpecs object from its json representation. + + Args: + spec (Dict[str, Any]): Dictionary representation of spec. + """ self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: - """Sets fields in object based on json.""" + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of spec. 
+ """ self.framework = json_obj["framework"] self.framework_version = json_obj["framework_version"] @@ -154,11 +180,19 @@ class JumpStartModelSpecs(JumpStartDataHolderType): ] def __init__(self, spec: Dict[str, Any]): - """Initializes a JumpStartModelSpecs object from its json representation.""" + """Initializes a JumpStartModelSpecs object from its json representation. + + Args: + spec (Dict[str, Any]): Dictionary representation of spec. + """ self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: - """Sets fields in object based on json of header.""" + """Sets fields in object based on json of header. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of spec. + """ self.model_id: str = json_obj["model_id"] self.version: str = json_obj["version"] self.min_sdk_version: str = json_obj["min_sdk_version"] @@ -201,6 +235,12 @@ def __init__( model_id: str, version: str, ) -> None: + """Instantiates JumpStartVersionedModelId object. + + Args: + model_id (str): JumpStart model id. + version (str): JumpStart model version. + """ self.model_id = model_id self.version = version @@ -215,6 +255,12 @@ def __init__( file_type: JumpStartS3FileType, s3_key: str, ) -> None: + """Instantiates JumpStartCachedS3ContentKey object. + + Args: + file_type (JumpStartS3FileType): JumpStart file type. + s3_key (str): object key in s3. + """ self.file_type = file_type self.s3_key = s3_key @@ -222,15 +268,24 @@ def __init__( class JumpStartCachedS3ContentValue(JumpStartDataHolderType): """Data class for the s3 cached content values.""" - __slots__ = ["formatted_file_content", "md5_hash"] + __slots__ = ["formatted_content", "md5_hash"] def __init__( self, - formatted_file_content: Union[ + formatted_content: Union[ Dict[JumpStartVersionedModelId, JumpStartModelHeader], List[JumpStartModelSpecs], ], md5_hash: Optional[str] = None, ) -> None: - self.formatted_file_content = formatted_file_content + """Instantiates JumpStartCachedS3ContentValue object. + + Args: + formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader], + List[JumpStartModelSpecs]]): + Formatted content for model specs and mappings from + versioned model ids to specs. + md5_hash (str): md5_hash for stored file content from s3. + """ + self.formatted_content = formatted_content self.md5_hash = md5_hash diff --git a/src/sagemaker/utilities/cache.py b/src/sagemaker/utilities/cache.py index ab7e086a3f..b5a48ccef8 100644 --- a/src/sagemaker/utilities/cache.py +++ b/src/sagemaker/utilities/cache.py @@ -61,7 +61,7 @@ def __init__( persist before being invalidated. retrieval_function (Callable[[KeyType, ValType], ValType]): Function which maps cache keys and current values to new values. This function must have kwarg arguments - ``key`` and and ``value``. This function is called as a fallback when the key + ``key`` and ``value``. This function is called as a fallback when the key is not found in the cache, or a key has expired. 
""" diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 571c526bdf..e073a80d67 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -117,14 +117,14 @@ def patched_get_file_from_s3( if filetype == JumpStartS3FileType.MANIFEST: return JumpStartCachedS3ContentValue( - formatted_file_content=get_formatted_manifest(BASE_MANIFEST) + formatted_content=get_formatted_manifest(BASE_MANIFEST) ) if filetype == JumpStartS3FileType.SPECS: _, model_id, specs_version = s3_key.split("/") version = specs_version.replace("specs_v", "").replace(".json", "") return JumpStartCachedS3ContentValue( - formatted_file_content=get_spec_from_base_spec(model_id, version) + formatted_content=get_spec_from_base_spec(model_id, version) ) raise ValueError(f"Bad value for filetype: {filetype}") From 2f8c692a10cefe82b10a022981ec135157dbab09 Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala Date: Mon, 29 Nov 2021 15:45:56 -0800 Subject: [PATCH 13/13] fix: Fix coverage concurency --- .coveragerc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.coveragerc b/.coveragerc index 8ed7382211..3f40836d93 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,4 +1,4 @@ [run] -concurrency = threading +concurrency = thread omit = sagemaker/tests/* timid = True