diff --git a/.coveragerc b/.coveragerc index 8ed7382211..3f40836d93 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,4 +1,4 @@ [run] -concurrency = threading +concurrency = thread omit = sagemaker/tests/* timid = True diff --git a/setup.py b/setup.py index 338313f25d..e398611a8b 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ def read_version(): "packaging>=20.0", "pandas", "pathos", + "semantic-version", ] # Specific use case dependencies diff --git a/src/sagemaker/jumpstart/__init__.py b/src/sagemaker/jumpstart/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py new file mode 100644 index 0000000000..117d1e8ba6 --- /dev/null +++ b/src/sagemaker/jumpstart/cache.py @@ -0,0 +1,327 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module defines the JumpStartModelsCache class.""" +from __future__ import absolute_import +import datetime +from typing import List, Optional +import json +import boto3 +import botocore +import semantic_version +from sagemaker.jumpstart.constants import ( + JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + JUMPSTART_DEFAULT_REGION_NAME, +) +from sagemaker.jumpstart.parameters import ( + JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, + JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, + JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, + JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, +) +from sagemaker.jumpstart.types import ( + JumpStartCachedS3ContentKey, + JumpStartCachedS3ContentValue, + JumpStartModelHeader, + JumpStartModelSpecs, + JumpStartS3FileType, + JumpStartVersionedModelId, +) +from sagemaker.jumpstart import utils +from sagemaker.utilities.cache import LRUCache + + +class JumpStartModelsCache: + """Class that implements a cache for JumpStart models manifests and specs. + + The manifest and specs associated with JumpStart models provide the information necessary + for launching JumpStart models from the SageMaker SDK. + """ + + def __init__( + self, + region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME, + max_s3_cache_items: Optional[int] = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, + s3_cache_expiration_horizon: Optional[ + datetime.timedelta + ] = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, + max_semantic_version_cache_items: Optional[ + int + ] = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, + semantic_version_cache_expiration_horizon: Optional[ + datetime.timedelta + ] = JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, + manifest_file_s3_key: Optional[str] = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, + s3_bucket_name: Optional[str] = None, + s3_client_config: Optional[botocore.config.Config] = None, + ) -> None: + """Initialize a ``JumpStartModelsCache`` instance. 
+ + Args: + region (Optional[str]): AWS region to associate with cache. Default: region associated + with boto3 session. + max_s3_cache_items (Optional[int]): Maximum number of items to store in s3 cache. + Default: 20. + s3_cache_expiration_horizon (Optional[datetime.timedelta]): Maximum time to hold + items in s3 cache before invalidation. Default: 6 hours. + max_semantic_version_cache_items (Optional[int]): Maximum number of items to store in + semantic version cache. Default: 20. + semantic_version_cache_expiration_horizon (Optional[datetime.timedelta]): + Maximum time to hold items in semantic version cache before invalidation. + Default: 6 hours. + s3_bucket_name (Optional[str]): S3 bucket to associate with cache. + Default: JumpStart-hosted content bucket for region. + s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache. + Default: None (no config). + """ + + self._region = region + self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( + max_cache_items=max_s3_cache_items, + expiration_horizon=s3_cache_expiration_horizon, + retrieval_function=self._get_file_from_s3, + ) + self._model_id_semantic_version_manifest_key_cache = LRUCache[ + JumpStartVersionedModelId, JumpStartVersionedModelId + ]( + max_cache_items=max_semantic_version_cache_items, + expiration_horizon=semantic_version_cache_expiration_horizon, + retrieval_function=self._get_manifest_key_from_model_id_semantic_version, + ) + self._manifest_file_s3_key = manifest_file_s3_key + self.s3_bucket_name = ( + utils.get_jumpstart_content_bucket(self._region) + if s3_bucket_name is None + else s3_bucket_name + ) + self._s3_client = ( + boto3.client("s3", region_name=self._region, config=s3_client_config) + if s3_client_config + else boto3.client("s3", region_name=self._region) + ) + + def set_region(self, region: str) -> None: + """Set region for cache. 
Clears cache after new region is set.""" + if region != self._region: + self._region = region + self.clear() + + def get_region(self) -> str: + """Return region for cache.""" + return self._region + + def set_manifest_file_s3_key(self, key: str) -> None: + """Set manifest file s3 key. Clears cache after new key is set.""" + if key != self._manifest_file_s3_key: + self._manifest_file_s3_key = key + self.clear() + + def get_manifest_file_s3_key(self) -> None: + """Return manifest file s3 key for cache.""" + return self._manifest_file_s3_key + + def set_s3_bucket_name(self, s3_bucket_name: str) -> None: + """Set s3 bucket used for cache.""" + if s3_bucket_name != self.s3_bucket_name: + self.s3_bucket_name = s3_bucket_name + self.clear() + + def get_bucket(self) -> None: + """Return bucket used for cache.""" + return self.s3_bucket_name + + def _get_manifest_key_from_model_id_semantic_version( + self, + key: JumpStartVersionedModelId, + value: Optional[JumpStartVersionedModelId], # pylint: disable=W0613 + ) -> JumpStartVersionedModelId: + """Return model id and version in manifest that matches semantic version/id. + + Uses ``semantic_version`` to perform version comparison. The highest model version + matching the semantic version is used, which is compatible with the SageMaker + version. + + Args: + key (JumpStartVersionedModelId): Key for which to fetch versioned model id. + value (Optional[JumpStartVersionedModelId]): Unused variable for current value of + old cached model id/version. + + Raises: + KeyError: If the semantic version is not found in the manifest, or is found but + the SageMaker version needs to be upgraded in order for the model to be used. 
+ """ + + model_id, version = key.model_id, key.version + + manifest = self._s3_cache.get( + JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + ).formatted_content + + sm_version = utils.get_sagemaker_version() + + versions_compatible_with_sagemaker = [ + semantic_version.Version(header.version) + for header in manifest.values() + if header.model_id == model_id + and semantic_version.Version(header.min_version) <= semantic_version.Version(sm_version) + ] + + spec = ( + semantic_version.SimpleSpec("*") + if version is None + else semantic_version.SimpleSpec(version) + ) + + sm_compatible_model_version = spec.select(versions_compatible_with_sagemaker) + if sm_compatible_model_version is not None: + return JumpStartVersionedModelId(model_id, str(sm_compatible_model_version)) + + versions_incompatible_with_sagemaker = [ + semantic_version.Version(header.version) + for header in manifest.values() + if header.model_id == model_id + ] + sm_incompatible_model_version = spec.select(versions_incompatible_with_sagemaker) + if sm_incompatible_model_version is not None: + model_version_to_use_incompatible_with_sagemaker = str(sm_incompatible_model_version) + sm_version_to_use = [ + header.min_version + for header in manifest.values() + if header.model_id == model_id + and header.version == model_version_to_use_incompatible_with_sagemaker + ] + if len(sm_version_to_use) != 1: + # ``manifest`` dict should already enforce this + raise RuntimeError("Found more than one incompatible SageMaker version to use.") + sm_version_to_use = sm_version_to_use[0] + + error_msg = ( + f"Unable to find model manifest for {model_id} with version {version} " + f"compatible with your SageMaker version ({sm_version}). " + f"Consider upgrading your SageMaker library to at least version " + f"{sm_version_to_use} so you can use version " + f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}." 
+ ) + raise KeyError(error_msg) + error_msg = f"Unable to find model manifest for {model_id} with version {version}." + raise KeyError(error_msg) + + def _get_file_from_s3( + self, + key: JumpStartCachedS3ContentKey, + value: Optional[JumpStartCachedS3ContentValue], + ) -> JumpStartCachedS3ContentValue: + """Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``. + + If a manifest file is being fetched, we only download the object if the md5 hash in + ``head_object`` does not match the current md5 hash for the stored value. This prevents + unnecessarily downloading the full manifest when it hasn't changed. + + Args: + key (JumpStartCachedS3ContentKey): key for which to fetch s3 content. + value (Optional[JumpStartVersionedModelId]): Current value of old cached + s3 content. This is used for the manifest file, so that it is only + downloaded when its content changes. + """ + + file_type, s3_key = key.file_type, key.s3_key + + if file_type == JumpStartS3FileType.MANIFEST: + if value is not None: + etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"] + if etag == value.md5_hash: + return value + response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key) + formatted_body = json.loads(response["Body"].read().decode("utf-8")) + etag = response["ETag"] + return JumpStartCachedS3ContentValue( + formatted_content=utils.get_formatted_manifest(formatted_body), + md5_hash=etag, + ) + if file_type == JumpStartS3FileType.SPECS: + response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key) + formatted_body = json.loads(response["Body"].read().decode("utf-8")) + return JumpStartCachedS3ContentValue( + formatted_content=JumpStartModelSpecs(formatted_body) + ) + raise ValueError( + f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}" + ) + + def get_manifest(self) -> List[JumpStartModelHeader]: + """Return entire JumpStart models manifest.""" 
+ + return self._s3_cache.get( + JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + ).formatted_content.values() + + def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModelHeader: + """Return header for a given JumpStart model id and semantic version. + + Args: + model_id (str): model id for which to get a header. + semantic_version_str (str): The semantic version for which to get a + header. + """ + + return self._get_header_impl(model_id, semantic_version_str=semantic_version_str) + + def _get_header_impl( + self, + model_id: str, + semantic_version_str: str, + attempt: Optional[int] = 0, + ) -> JumpStartModelHeader: + """Lower-level function to return header. + + Allows a single retry if the cache is old. + + Args: + model_id (str): model id for which to get a header. + semantic_version_str (str): The semantic version for which to get a + header. + attempt (Optional[int]): attempt number at retrieving a header. + """ + + versioned_model_id = self._model_id_semantic_version_manifest_key_cache.get( + JumpStartVersionedModelId(model_id, semantic_version_str) + ) + manifest = self._s3_cache.get( + JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + ).formatted_content + try: + return manifest[versioned_model_id] + except KeyError: + if attempt > 0: + raise + self.clear() + return self._get_header_impl(model_id, semantic_version_str, attempt + 1) + + def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelSpecs: + """Return specs for a given JumpStart model id and semantic version. + + Args: + model_id (str): model id for which to get specs. + semantic_version_str (str): The semantic version for which to get + specs. 
+ """ + + header = self.get_header(model_id, semantic_version_str) + spec_key = header.spec_key + return self._s3_cache.get( + JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key) + ).formatted_content + + def clear(self) -> None: + """Clears the model id/version and s3 cache.""" + self._s3_cache.clear() + self._model_id_semantic_version_manifest_key_cache.clear() diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py new file mode 100644 index 0000000000..71452433b6 --- /dev/null +++ b/src/sagemaker/jumpstart/constants.py @@ -0,0 +1,29 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
"""This module stores constants related to SageMaker JumpStart."""
from __future__ import absolute_import
from typing import Set
import boto3
from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo


# Regions where JumpStart content buckets are launched. Empty for now;
# entries are added here as JumpStart launches in each region.
JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set()

# Lookup table derived once at import time: region name -> launched-region info.
JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT = {
    region.region_name: region for region in JUMPSTART_LAUNCHED_REGIONS
}
# Names of all launched regions, derived once at import time from the set above.
JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS}

# NOTE(review): resolved at import time from the default boto3 session; this is
# presumably None when no AWS region is configured in the environment — confirm
# downstream callers tolerate a None region.
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name

# Default S3 key of the JumpStart models manifest file.
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
+"""This module stores parameters related to SageMaker JumpStart.""" +from __future__ import absolute_import +import datetime + +JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS = 20 +JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS = 20 +JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON = datetime.timedelta(hours=6) +JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON = datetime.timedelta(hours=6) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py new file mode 100644 index 0000000000..9bb865cc65 --- /dev/null +++ b/src/sagemaker/jumpstart/types.py @@ -0,0 +1,291 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores types related to SageMaker JumpStart.""" +from __future__ import absolute_import +from enum import Enum +from typing import Any, Dict, List, Optional, Union + + +class JumpStartDataHolderType: + """Base class for many JumpStart types. + + Allows objects to be added to dicts and sets, + and improves string representation. This class overrides the ``__eq__`` + and ``__hash__`` methods so that different objects with the same attributes/types + can be compared. + """ + + __slots__: List[str] = [] + + def __eq__(self, other: Any) -> bool: + """Returns True if ``other`` is of the same type and has all attributes equal. + + Args: + other (Any): Other object to which to compare this object. 
+ """ + + if not isinstance(other, type(self)): + return False + if getattr(other, "__slots__", None) is None: + return False + if self.__slots__ != other.__slots__: + return False + for attribute in self.__slots__: + if getattr(self, attribute) != getattr(other, attribute): + return False + return True + + def __hash__(self) -> int: + """Makes hash of object. + + Maps object to unique tuple, which then gets hashed. + """ + + return hash((type(self),) + tuple([getattr(self, att) for att in self.__slots__])) + + def __str__(self) -> str: + """Returns string representation of object. Example: + + "JumpStartLaunchedRegionInfo: + {'content_bucket': 'bucket', 'region_name': 'us-west-2'}" + """ + + att_dict = {att: getattr(self, att) for att in self.__slots__} + return f"{type(self).__name__}: {str(att_dict)}" + + def __repr__(self) -> str: + """Returns ``__repr__`` string of object. Example: + + "JumpStartLaunchedRegionInfo at 0x7f664529efa0: + {'content_bucket': 'bucket', 'region_name': 'us-west-2'}" + """ + + att_dict = {att: getattr(self, att) for att in self.__slots__} + return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}" + + +class JumpStartS3FileType(str, Enum): + """Type of files published in JumpStart S3 distribution buckets.""" + + MANIFEST = "manifest" + SPECS = "specs" + + +class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): + """Data class for launched region info.""" + + __slots__ = ["content_bucket", "region_name"] + + def __init__(self, content_bucket: str, region_name: str): + """Instantiates JumpStartLaunchedRegionInfo object. + + Args: + content_bucket (str): Name of JumpStart s3 content bucket associated with region. + region_name (str): Name of JumpStart launched region. 
+ """ + self.content_bucket = content_bucket + self.region_name = region_name + + +class JumpStartModelHeader(JumpStartDataHolderType): + """Data class JumpStart model header.""" + + __slots__ = ["model_id", "version", "min_version", "spec_key"] + + def __init__(self, header: Dict[str, str]): + """Initializes a JumpStartModelHeader object from its json representation. + + Args: + header (Dict[str, str]): Dictionary representation of header. + """ + self.from_json(header) + + def to_json(self) -> Dict[str, str]: + """Returns json representation of JumpStartModelHeader object.""" + json_obj = {att: getattr(self, att) for att in self.__slots__} + return json_obj + + def from_json(self, json_obj: Dict[str, str]) -> None: + """Sets fields in object based on json of header. + + Args: + json_obj (Dict[str, str]): Dictionary representation of header. + """ + self.model_id: str = json_obj["model_id"] + self.version: str = json_obj["version"] + self.min_version: str = json_obj["min_version"] + self.spec_key: str = json_obj["spec_key"] + + +class JumpStartECRSpecs(JumpStartDataHolderType): + """Data class for JumpStart ECR specs.""" + + __slots__ = { + "framework", + "framework_version", + "py_version", + } + + def __init__(self, spec: Dict[str, Any]): + """Initializes a JumpStartECRSpecs object from its json representation. + + Args: + spec (Dict[str, Any]): Dictionary representation of spec. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of spec. 
+ """ + + self.framework = json_obj["framework"] + self.framework_version = json_obj["framework_version"] + self.py_version = json_obj["py_version"] + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of JumpStartECRSpecs object.""" + json_obj = {att: getattr(self, att) for att in self.__slots__} + return json_obj + + +class JumpStartModelSpecs(JumpStartDataHolderType): + """Data class JumpStart model specs.""" + + __slots__ = [ + "model_id", + "version", + "min_sdk_version", + "incremental_training_supported", + "hosting_ecr_specs", + "hosting_artifact_key", + "hosting_script_key", + "training_supported", + "training_ecr_specs", + "training_artifact_key", + "training_script_key", + "hyperparameters", + ] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a JumpStartModelSpecs object from its json representation. + + Args: + spec (Dict[str, Any]): Dictionary representation of spec. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json of header. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of spec. 
+ """ + self.model_id: str = json_obj["model_id"] + self.version: str = json_obj["version"] + self.min_sdk_version: str = json_obj["min_sdk_version"] + self.incremental_training_supported: bool = bool(json_obj["incremental_training_supported"]) + self.hosting_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) + self.hosting_artifact_key: str = json_obj["hosting_artifact_key"] + self.hosting_script_key: str = json_obj["hosting_script_key"] + self.training_supported: bool = bool(json_obj["training_supported"]) + if self.training_supported: + self.training_ecr_specs: Optional[JumpStartECRSpecs] = JumpStartECRSpecs( + json_obj["training_ecr_specs"] + ) + self.training_artifact_key: Optional[str] = json_obj["training_artifact_key"] + self.training_script_key: Optional[str] = json_obj["training_script_key"] + self.hyperparameters: Optional[Dict[str, Any]] = json_obj.get("hyperparameters") + else: + self.training_ecr_specs = ( + self.training_artifact_key + ) = self.training_script_key = self.hyperparameters = None + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of JumpStartModelSpecs object.""" + json_obj = {} + for att in self.__slots__: + cur_val = getattr(self, att) + if isinstance(cur_val, JumpStartECRSpecs): + json_obj[att] = cur_val.to_json() + else: + json_obj[att] = cur_val + return json_obj + + +class JumpStartVersionedModelId(JumpStartDataHolderType): + """Data class for versioned model ids.""" + + __slots__ = ["model_id", "version"] + + def __init__( + self, + model_id: str, + version: str, + ) -> None: + """Instantiates JumpStartVersionedModelId object. + + Args: + model_id (str): JumpStart model id. + version (str): JumpStart model version. 
+ """ + self.model_id = model_id + self.version = version + + +class JumpStartCachedS3ContentKey(JumpStartDataHolderType): + """Data class for the s3 cached content keys.""" + + __slots__ = ["file_type", "s3_key"] + + def __init__( + self, + file_type: JumpStartS3FileType, + s3_key: str, + ) -> None: + """Instantiates JumpStartCachedS3ContentKey object. + + Args: + file_type (JumpStartS3FileType): JumpStart file type. + s3_key (str): object key in s3. + """ + self.file_type = file_type + self.s3_key = s3_key + + +class JumpStartCachedS3ContentValue(JumpStartDataHolderType): + """Data class for the s3 cached content values.""" + + __slots__ = ["formatted_content", "md5_hash"] + + def __init__( + self, + formatted_content: Union[ + Dict[JumpStartVersionedModelId, JumpStartModelHeader], + List[JumpStartModelSpecs], + ], + md5_hash: Optional[str] = None, + ) -> None: + """Instantiates JumpStartCachedS3ContentValue object. + + Args: + formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader], + List[JumpStartModelSpecs]]): + Formatted content for model specs and mappings from + versioned model ids to specs. + md5_hash (str): md5_hash for stored file content from s3. + """ + self.formatted_content = formatted_content + self.md5_hash = md5_hash diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py new file mode 100644 index 0000000000..1e1f4c4b6d --- /dev/null +++ b/src/sagemaker/jumpstart/utils.py @@ -0,0 +1,130 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains utilities related to SageMaker JumpStart.""" +from __future__ import absolute_import +from typing import Dict, List +import semantic_version +import sagemaker +from sagemaker.jumpstart import constants +from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId + + +class SageMakerSettings(object): + """Static class for storing the SageMaker settings.""" + + _PARSED_SAGEMAKER_VERSION = "" + + @staticmethod + def set_sagemaker_version(version: str) -> None: + """Set SageMaker version.""" + SageMakerSettings._PARSED_SAGEMAKER_VERSION = version + + @staticmethod + def get_sagemaker_version() -> str: + """Return SageMaker version.""" + return SageMakerSettings._PARSED_SAGEMAKER_VERSION + + +def get_jumpstart_launched_regions_message() -> str: + """Returns formatted string indicating where JumpStart is launched.""" + if len(constants.JUMPSTART_REGION_NAME_SET) == 0: + return "JumpStart is not available in any region." + if len(constants.JUMPSTART_REGION_NAME_SET) == 1: + region = list(constants.JUMPSTART_REGION_NAME_SET)[0] + return f"JumpStart is available in {region} region." + + sorted_regions = sorted(list(constants.JUMPSTART_REGION_NAME_SET)) + if len(constants.JUMPSTART_REGION_NAME_SET) == 2: + return f"JumpStart is available in {sorted_regions[0]} and {sorted_regions[1]} regions." + + formatted_launched_regions_list = [] + for i, region in enumerate(sorted_regions): + region_prefix = "" if i < len(sorted_regions) - 1 else "and " + formatted_launched_regions_list.append(region_prefix + region) + formatted_launched_regions_str = ", ".join(formatted_launched_regions_list) + return f"JumpStart is available in {formatted_launched_regions_str} regions." + + +def get_jumpstart_content_bucket(region: str) -> str: + """Returns regionalized content bucket name for JumpStart. 
+ + Raises: + RuntimeError: If JumpStart is not launched in ``region``. + """ + try: + return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket + except KeyError: + formatted_launched_regions_str = get_jumpstart_launched_regions_message() + raise ValueError( + f"Unable to get content bucket for JumpStart in {region} region. " + f"{formatted_launched_regions_str}" + ) + + +def get_formatted_manifest( + manifest: List[Dict], +) -> Dict[JumpStartVersionedModelId, JumpStartModelHeader]: + """Returns formatted manifest dictionary from raw manifest. + + Keys are JumpStartVersionedModelId objects, values are + ``JumpStartModelHeader`` objects. + """ + manifest_dict = {} + for header in manifest: + header_obj = JumpStartModelHeader(header) + manifest_dict[ + JumpStartVersionedModelId(header_obj.model_id, header_obj.version) + ] = header_obj + return manifest_dict + + +def get_sagemaker_version() -> str: + """Returns sagemaker library version. + + If the sagemaker library version has not been set, this function + calls ``parse_sagemaker_version`` to retrieve the version and set + the constant. + """ + if SageMakerSettings.get_sagemaker_version() == "": + SageMakerSettings.set_sagemaker_version(parse_sagemaker_version()) + return SageMakerSettings.get_sagemaker_version() + + +def parse_sagemaker_version() -> str: + """Returns sagemaker library version. This should only be called once. + + Function reads ``__version__`` variable in ``sagemaker`` module. + In order to maintain compatibility with the ``semantic_version`` + library, versions with fewer than 2, or more than 3, periods are rejected. + All versions that cannot be parsed with ``semantic_version`` are also + rejected. + + Raises: + RuntimeError: If the SageMaker version is not readable. An exception is also raised if + the version cannot be parsed by ``semantic_version``. 
+ """ + version = sagemaker.__version__ + parsed_version = None + + num_periods = version.count(".") + if num_periods == 2: + parsed_version = version + elif num_periods == 3: + trailing_period_index = version.rfind(".") + parsed_version = version[:trailing_period_index] + else: + raise RuntimeError(f"Bad value for SageMaker version: {sagemaker.__version__}") + + semantic_version.Version(parsed_version) + + return parsed_version diff --git a/src/sagemaker/utilities/__init__.py b/src/sagemaker/utilities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/utilities/cache.py b/src/sagemaker/utilities/cache.py new file mode 100644 index 0000000000..b5a48ccef8 --- /dev/null +++ b/src/sagemaker/utilities/cache.py @@ -0,0 +1,166 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module defines a LRU cache class.""" +from __future__ import absolute_import + +import datetime +import collections +from typing import TypeVar, Generic, Callable, Optional + +KeyType = TypeVar("KeyType") +ValType = TypeVar("ValType") + + +class LRUCache(Generic[KeyType, ValType]): + """Class that implements LRU cache with expiring items. + + LRU caches remove items in a FIFO manner, such that the oldest + items to be used are the first to be removed. + If you attempt to retrieve a cache item that is older than the + expiration time, the item will be invalidated. + """ + + class Element: + """Class describes the values in the cache. 
+ + This object stores the value itself as well as a timestamp so that this + element can be invalidated if it becomes too old. + """ + + def __init__(self, value: ValType, creation_time: datetime.datetime): + """Initialize an ``Element`` instance for ``LRUCache``. + + Args: + value (ValType): Value that is stored in cache. + creation_time (datetime.datetime): Time at which cache item was created. + """ + self.value = value + self.creation_time = creation_time + + def __init__( + self, + max_cache_items: int, + expiration_horizon: datetime.timedelta, + retrieval_function: Callable[[KeyType, ValType], ValType], + ) -> None: + """Initialize an ``LRUCache`` instance. + + Args: + max_cache_items (int): Maximum number of items to store in cache. + expiration_horizon (datetime.timedelta): Maximum time duration a cache element can + persist before being invalidated. + retrieval_function (Callable[[KeyType, ValType], ValType]): Function which maps cache + keys and current values to new values. This function must have kwarg arguments + ``key`` and ``value``. This function is called as a fallback when the key + is not found in the cache, or a key has expired. + + """ + self._max_cache_items = max_cache_items + self._lru_cache: collections.OrderedDict = collections.OrderedDict() + self._expiration_horizon = expiration_horizon + self._retrieval_function = retrieval_function + + def __len__(self) -> int: + """Returns number of elements in cache.""" + return len(self._lru_cache) + + def __contains__(self, key: KeyType) -> bool: + """Returns True if key is found in cache, False otherwise. + + Args: + key (KeyType): Key in cache to retrieve. + """ + return key in self._lru_cache + + def clear(self) -> None: + """Deletes all elements from the cache.""" + self._lru_cache.clear() + + def get(self, key: KeyType, data_source_fallback: Optional[bool] = True) -> ValType: + """Returns value corresponding to key in cache. + + Args: + key (KeyType): Key in cache to retrieve. 
+            data_source_fallback (Optional[bool]): True if data should be retrieved if
+                it's stale or not in cache. Default: True.
+        Raises:
+            KeyError: If key is not found in cache or is outdated and
+                ``data_source_fallback`` is False.
+        """
+        if data_source_fallback:
+            if key in self._lru_cache:
+                return self._get_item(key, False)
+            self.put(key)
+            return self._get_item(key, False)
+        return self._get_item(key, True)
+
+    def put(self, key: KeyType, value: Optional[ValType] = None) -> None:
+        """Adds key to cache using ``retrieval_function``.
+
+        If value is provided, this is used instead. If the key is already in cache,
+        the old element is removed. If the cache size exceeds the size limit, old
+        elements are removed in order to meet the limit.
+
+        Args:
+            key (KeyType): Key in cache to retrieve.
+            value (Optional[ValType]): Value to store for key. Default: None.
+        """
+        curr_value = None
+        if key in self._lru_cache:
+            curr_value = self._lru_cache.pop(key)
+
+        while len(self._lru_cache) >= self._max_cache_items:
+            self._lru_cache.popitem(last=False)
+
+        if value is None:
+            value = self._retrieval_function(  # type: ignore
+                key=key, value=curr_value.value if curr_value else None
+            )
+
+        self._lru_cache[key] = self.Element(
+            value=value, creation_time=datetime.datetime.now(tz=datetime.timezone.utc)
+        )
+
+    def _get_item(self, key: KeyType, fail_on_old_value: bool) -> ValType:
+        """Returns value from cache corresponding to key.
+
+        If ``fail_on_old_value``, a KeyError is raised instead of a new value
+        getting fetched.
+
+        Args:
+            key (KeyType): Key in cache to retrieve.
+            fail_on_old_value (bool): True if a KeyError is raised when the cache value
+                is old.
+
+        Raises:
+            KeyError: If key is not in cache or if key is old in cache
+                and fail_on_old_value is True.
+ """ + try: + element = self._lru_cache.pop(key) + curr_time = datetime.datetime.now(tz=datetime.timezone.utc) + element_age = curr_time - element.creation_time + if element_age > self._expiration_horizon: + if fail_on_old_value: + raise KeyError( + f"{key} has aged beyond allowed time {self._expiration_horizon}. " + f"Element created at {element.creation_time}." + ) + element.value = self._retrieval_function( # type: ignore + key=key, value=element.value + ) + element.creation_time = curr_time + self._lru_cache[key] = element + return element.value + except KeyError: + raise KeyError(f"{key} not found in LRUCache!") diff --git a/tests/unit/sagemaker/jumpstart/__init__.py b/tests/unit/sagemaker/jumpstart/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py new file mode 100644 index 0000000000..e073a80d67 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -0,0 +1,714 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +import copy +import datetime +import io +import json +from botocore.stub import Stubber +import botocore + +from mock.mock import MagicMock +import pytest +from mock import patch + +from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache +from sagemaker.jumpstart.types import ( + JumpStartCachedS3ContentKey, + JumpStartCachedS3ContentValue, + JumpStartModelHeader, + JumpStartModelSpecs, + JumpStartS3FileType, + JumpStartVersionedModelId, +) +from sagemaker.jumpstart.utils import get_formatted_manifest + +BASE_SPEC = { + "model_id": "pytorch-ic-mobilenet-v2", + "version": "1.0.0", + "min_sdk_version": "2.49.0", + "training_supported": True, + "incremental_training_supported": True, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.7.0", + "py_version": "py3", + }, + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.9.0", + "py_version": "py3", + }, + "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "hyperparameters": { + "adam-learning-rate": {"type": "float", "default": 0.05, "min": 1e-08, "max": 1}, + "epochs": {"type": "int", "default": 3, "min": 1, "max": 1000}, + "batch-size": {"type": "int", "default": 4, "min": 1, "max": 1024}, + }, +} + +BASE_MANIFEST = [ + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v1.0.0.json", + }, + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": 
"2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v2.0.0.json", + }, + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-" + "imagenet-inception-v3-classification-4/specs_v1.0.0.json", + }, + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-imagenet-" + "inception-v3-classification-4/specs_v2.0.0.json", + }, + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "3.0.0", + "min_version": "4.49.0", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v3.0.0.json", + }, +] + + +def get_spec_from_base_spec(model_id: str, version: str) -> JumpStartModelSpecs: + spec = copy.deepcopy(BASE_SPEC) + + spec["version"] = version + spec["model_id"] = model_id + return JumpStartModelSpecs(spec) + + +def patched_get_file_from_s3( + _modelCacheObj: JumpStartModelsCache, + key: JumpStartCachedS3ContentKey, + value: JumpStartCachedS3ContentValue, +) -> JumpStartCachedS3ContentValue: + + filetype, s3_key = key.file_type, key.s3_key + if filetype == JumpStartS3FileType.MANIFEST: + + return JumpStartCachedS3ContentValue( + formatted_content=get_formatted_manifest(BASE_MANIFEST) + ) + + if filetype == JumpStartS3FileType.SPECS: + _, model_id, specs_version = s3_key.split("/") + version = specs_version.replace("specs_v", "").replace(".json", "") + return JumpStartCachedS3ContentValue( + formatted_content=get_spec_from_base_spec(model_id, version) + ) + + raise ValueError(f"Bad value for filetype: {filetype}") + + +@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) +@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") +def test_jumpstart_cache_get_header(): + + 
cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic" + "-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) + + # See if we can make the same query 2 times consecutively + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic" + "-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="2.*" + ) + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="2.*.*", + ) + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ) == 
cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="2.0.0", + ) + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v1.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="1.0.0", + ) + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v1.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="1.*" + ) + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v1.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="1.*.*", + ) + + with pytest.raises(KeyError) as e: + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="3.*", + ) + assert ( + "Unable to find model manifest for tensorflow-ic-imagenet-inception-v3-classification-4 " + "with version 3.* compatible with your SageMaker version (2.68.3). Consider upgrading " + "your SageMaker library to at least version 4.49.0 so you can use version 3.0.0 of " + "tensorflow-ic-imagenet-inception-v3-classification-4." 
in str(e.value) + ) + + with pytest.raises(KeyError) as e: + cache.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="3.*" + ) + assert "Consider upgrading" not in str(e.value) + + with pytest.raises(ValueError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="BAD", + ) + + with pytest.raises(KeyError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4-bak", + semantic_version_str="*", + ) + + +@patch("boto3.client") +def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client): + + mock_boto3_client.return_value.get_object.side_effect = Exception() + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + + with pytest.raises(Exception): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + ) + + mock_boto3_client.return_value.reset_mock() + + mock_boto3_client.return_value.head_object.side_effect = Exception() + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + + with pytest.raises(Exception): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + ) + + +@patch("boto3.client") +def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client): + cache = JumpStartModelsCache( + s3_bucket_name="some_bucket", region="some_region", manifest_file_s3_key="some_key" + ) + + cache.clear = MagicMock() + cache.set_s3_bucket_name("some_bucket") + cache.clear.assert_not_called() + cache.clear.reset_mock() + cache.set_region("some_region") + cache.clear.assert_not_called() + cache.clear.reset_mock() + cache.set_manifest_file_s3_key("some_key") + cache.clear.assert_not_called() + + cache.clear.reset_mock() + + cache.set_s3_bucket_name("some_bucket1") + cache.clear.assert_called_once() + cache.clear.reset_mock() + cache.set_region("some_region1") + cache.clear.assert_called_once() + cache.clear.reset_mock() + 
cache.set_manifest_file_s3_key("some_key1") + cache.clear.assert_called_once() + + +def test_jumpstart_cache_handles_boto3_client_errors(): + # Testing get_object + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + stubbed_s3_client = Stubber(cache._s3_client) + stubbed_s3_client.add_client_error("get_object", http_status_code=404) + stubbed_s3_client.activate() + with pytest.raises(botocore.exceptions.ClientError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="*", + ) + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + stubbed_s3_client = Stubber(cache._s3_client) + stubbed_s3_client.add_client_error("get_object", service_error_code="AccessDenied") + stubbed_s3_client.activate() + with pytest.raises(botocore.exceptions.ClientError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="*", + ) + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + stubbed_s3_client = Stubber(cache._s3_client) + stubbed_s3_client.add_client_error("get_object", service_error_code="EndpointConnectionError") + stubbed_s3_client.activate() + with pytest.raises(botocore.exceptions.ClientError): + cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", + semantic_version_str="*", + ) + + # Testing head_object: + mock_now = datetime.datetime.fromtimestamp(1636730651.079551) + with patch("datetime.datetime") as mock_datetime: + mock_manifest_json = json.dumps( + [ + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-" + "imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ] + ) + + get_object_mocked_response = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_manifest_json, "utf-8")), + content_length=len(mock_manifest_json), + ), + "ETag": "etag", + } + + 
mock_datetime.now.return_value = mock_now + + cache1 = JumpStartModelsCache( + s3_bucket_name="some_bucket", s3_cache_expiration_horizon=datetime.timedelta(hours=1) + ) + stubbed_s3_client1 = Stubber(cache1._s3_client) + + stubbed_s3_client1.add_response("get_object", copy.deepcopy(get_object_mocked_response)) + stubbed_s3_client1.activate() + cache1.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) + + mock_datetime.now.return_value += datetime.timedelta(weeks=1) + + stubbed_s3_client1.add_client_error("head_object", http_status_code=404) + with pytest.raises(botocore.exceptions.ClientError): + cache1.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", + semantic_version_str="*", + ) + + cache2 = JumpStartModelsCache( + s3_bucket_name="some_bucket", s3_cache_expiration_horizon=datetime.timedelta(hours=1) + ) + stubbed_s3_client2 = Stubber(cache2._s3_client) + + stubbed_s3_client2.add_response("get_object", copy.deepcopy(get_object_mocked_response)) + stubbed_s3_client2.activate() + cache2.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) + + mock_datetime.now.return_value += datetime.timedelta(weeks=1) + + stubbed_s3_client2.add_client_error("head_object", service_error_code="AccessDenied") + with pytest.raises(botocore.exceptions.ClientError): + cache2.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", + semantic_version_str="*", + ) + + cache3 = JumpStartModelsCache( + s3_bucket_name="some_bucket", s3_cache_expiration_horizon=datetime.timedelta(hours=1) + ) + stubbed_s3_client3 = Stubber(cache3._s3_client) + + stubbed_s3_client3.add_response("get_object", copy.deepcopy(get_object_mocked_response)) + stubbed_s3_client3.activate() + cache3.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) + + mock_datetime.now.return_value += datetime.timedelta(weeks=1) 
+ + stubbed_s3_client3.add_client_error( + "head_object", service_error_code="EndpointConnectionError" + ) + with pytest.raises(botocore.exceptions.ClientError): + cache3.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", + semantic_version_str="*", + ) + + +def test_jumpstart_cache_accepts_input_parameters(): + + region = "us-east-1" + max_s3_cache_items = 1 + s3_cache_expiration_horizon = datetime.timedelta(weeks=2) + max_semantic_version_cache_items = 3 + semantic_version_cache_expiration_horizon = datetime.timedelta(microseconds=4) + bucket = "my-amazing-bucket" + manifest_file_key = "some_s3_key" + + cache = JumpStartModelsCache( + region=region, + max_s3_cache_items=max_s3_cache_items, + s3_cache_expiration_horizon=s3_cache_expiration_horizon, + max_semantic_version_cache_items=max_semantic_version_cache_items, + semantic_version_cache_expiration_horizon=semantic_version_cache_expiration_horizon, + s3_bucket_name=bucket, + manifest_file_s3_key=manifest_file_key, + ) + + assert cache.get_manifest_file_s3_key() == manifest_file_key + assert cache.get_region() == region + assert cache.get_bucket() == bucket + assert cache._s3_cache._max_cache_items == max_s3_cache_items + assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert ( + cache._model_id_semantic_version_manifest_key_cache._max_cache_items + == max_semantic_version_cache_items + ) + assert ( + cache._model_id_semantic_version_manifest_key_cache._expiration_horizon + == semantic_version_cache_expiration_horizon + ) + + +@patch("boto3.client") +def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client): + + mock_json = json.dumps( + [ + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-" + "imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ] + ) + + bucket_name = "bucket_name" + now = 
datetime.datetime.fromtimestamp(1636730651.079551) + + with patch("datetime.datetime") as mock_datetime: + mock_datetime.now.return_value = now + + cache = JumpStartModelsCache( + s3_bucket_name=bucket_name, s3_cache_expiration_horizon=datetime.timedelta(hours=1) + ) + + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) + ), + "ETag": "hash1", + } + mock_boto3_client.return_value.head_object.return_value = {"ETag": "hash1"} + + cache.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) + + # first time accessing cache should just involve get_object + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY + ) + mock_boto3_client.return_value.head_object.assert_not_called() + + mock_boto3_client.return_value.get_object.reset_mock() + mock_boto3_client.return_value.head_object.reset_mock() + + # second time accessing cache should just involve head_object if hash hasn't changed + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) + ), + "ETag": "hash1", + } + mock_boto3_client.return_value.head_object.return_value = {"ETag": "hash1"} + + # invalidate cache + mock_datetime.now.return_value += datetime.timedelta(hours=2) + + cache.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) + + mock_boto3_client.return_value.head_object.assert_called_with( + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY + ) + mock_boto3_client.return_value.get_object.assert_not_called() + + mock_boto3_client.return_value.get_object.reset_mock() + mock_boto3_client.return_value.head_object.reset_mock() + + # third time accessing cache should involve head_object and 
get_object if hash has changed + mock_boto3_client.return_value.head_object.return_value = {"ETag": "hash2"} + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) + ), + "ETag": "hash2", + } + + # invalidate cache + mock_datetime.now.return_value += datetime.timedelta(hours=2) + + cache.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) + + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY + ) + mock_boto3_client.return_value.head_object.assert_called_with( + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY + ) + + +@patch("boto3.client") +def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): + + # test get_header + mock_json = json.dumps( + [ + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-" + "imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ] + ) + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) + ), + "ETag": "etag", + } + + mock_boto3_client.return_value.head_object.return_value = {"ETag": "some-hash"} + + bucket_name = "bucket_name" + client_config = botocore.config.Config(signature_version="my_signature_version") + cache = JumpStartModelsCache( + s3_bucket_name=bucket_name, s3_client_config=client_config, region="my_region" + ) + cache.get_header( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) + + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, Key=JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY + ) + 
mock_boto3_client.return_value.head_object.assert_not_called() + + mock_boto3_client.assert_called_with("s3", region_name="my_region", config=client_config) + + # test get_specs. manifest already in cache, so only s3 call will be to get specs. + mock_json = json.dumps(BASE_SPEC) + + mock_boto3_client.return_value.reset_mock() + + mock_boto3_client.return_value.get_object.return_value = { + "Body": botocore.response.StreamingBody( + io.BytesIO(bytes(mock_json, "utf-8")), content_length=len(mock_json) + ), + "ETag": "etag", + } + cache.get_specs( + model_id="pytorch-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) + + mock_boto3_client.return_value.get_object.assert_called_with( + Bucket=bucket_name, + Key="community_models_specs/pytorch-ic-imagenet-" + "inception-v3-classification-4/specs_v2.0.0.json", + ) + mock_boto3_client.return_value.head_object.assert_not_called() + + +@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) +def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + + cache.clear = MagicMock() + cache._model_id_semantic_version_manifest_key_cache = MagicMock() + cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [ + JumpStartVersionedModelId( + "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0" + ), + JumpStartVersionedModelId("tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0"), + ] + + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v1.0.0.json", + } + ) == cache.get_header( + model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="*" + ) + cache.clear.assert_called_once() + cache.clear.reset_mock() + + 
cache._model_id_semantic_version_manifest_key_cache.get.side_effect = [
+        JumpStartVersionedModelId(
+            "tensorflow-ic-imagenet-inception-v3-classification-4", "999.0.0"
+        ),
+        JumpStartVersionedModelId(
+            "tensorflow-ic-imagenet-inception-v3-classification-4", "987.0.0"
+        ),
+    ]
+    with pytest.raises(KeyError):
+        cache.get_header(
+            model_id="tensorflow-ic-imagenet-inception-v3-classification-4",
+            semantic_version_str="*",
+        )
+    cache.clear.assert_called_once()
+
+
+@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3)
+@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
+def test_jumpstart_get_full_manifest():
+    cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
+    raw_manifest = [header.to_json() for header in cache.get_manifest()]
+
+    assert raw_manifest == BASE_MANIFEST
+
+
+@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3)
+@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3")
+def test_jumpstart_cache_get_specs():
+    cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
+
+    model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0"
+    assert get_spec_from_base_spec(model_id, version) == cache.get_specs(
+        model_id=model_id, semantic_version_str=version
+    )
+
+    model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0"
+    assert get_spec_from_base_spec(model_id, version) == cache.get_specs(
+        model_id=model_id, semantic_version_str=version
+    )
+
+    model_id = "pytorch-ic-imagenet-inception-v3-classification-4"
+    assert get_spec_from_base_spec(model_id, "1.0.0") == cache.get_specs(
+        model_id=model_id, semantic_version_str="1.*"
+    )
+
+    with pytest.raises(KeyError):
+        cache.get_specs(model_id=model_id + "bak", semantic_version_str="*")
+
+    with pytest.raises(KeyError):
+        cache.get_specs(model_id=model_id, semantic_version_str="9.*")
+
+    with pytest.raises(ValueError):
+        cache.get_specs(model_id=model_id, 
semantic_version_str="BAD") diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py new file mode 100644 index 0000000000..6f970d6e58 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -0,0 +1,127 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import copy +from sagemaker.jumpstart.types import JumpStartECRSpecs, JumpStartModelSpecs, JumpStartModelHeader + + +def test_jumpstart_model_header(): + + header_dict = { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + } + + header1 = JumpStartModelHeader(header_dict) + + assert header1.model_id == "tensorflow-ic-imagenet-inception-v3-classification-4" + assert header1.version == "1.0.0" + assert header1.min_version == "2.49.0" + assert ( + header1.spec_key + == "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json" + ) + + assert header1.to_json() == header_dict + + header2 = JumpStartModelHeader( + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json", + } + ) + + assert header1 != header2 + + header3 = 
copy.deepcopy(header1) + assert header1 == header3 + + +def test_jumpstart_model_specs(): + + specs_dict = { + "model_id": "pytorch-ic-mobilenet-v2", + "version": "1.0.0", + "min_sdk_version": "2.49.0", + "training_supported": True, + "incremental_training_supported": True, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.7.0", + "py_version": "py3", + }, + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.9.0", + "py_version": "py3", + }, + "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "hyperparameters": { + "adam-learning-rate": {"type": "float", "default": 0.05, "min": 1e-08, "max": 1}, + "epochs": {"type": "int", "default": 3, "min": 1, "max": 1000}, + "batch-size": {"type": "int", "default": 4, "min": 1, "max": 1024}, + }, + } + + specs1 = JumpStartModelSpecs(specs_dict) + + assert specs1.model_id == "pytorch-ic-mobilenet-v2" + assert specs1.version == "1.0.0" + assert specs1.min_sdk_version == "2.49.0" + assert specs1.training_supported + assert specs1.incremental_training_supported + assert specs1.hosting_ecr_specs == JumpStartECRSpecs( + { + "framework": "pytorch", + "framework_version": "1.7.0", + "py_version": "py3", + } + ) + assert specs1.training_ecr_specs == JumpStartECRSpecs( + { + "framework": "pytorch", + "framework_version": "1.9.0", + "py_version": "py3", + } + ) + assert specs1.hosting_artifact_key == "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz" + assert specs1.training_artifact_key == "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" + assert ( + specs1.hosting_script_key + == "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz" + ) + assert 
( + specs1.training_script_key + == "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz" + ) + assert specs1.hyperparameters == { + "adam-learning-rate": {"type": "float", "default": 0.05, "min": 1e-08, "max": 1}, + "epochs": {"type": "int", "default": 3, "min": 1, "max": 1000}, + "batch-size": {"type": "int", "default": 4, "min": 1, "max": 1024}, + } + + assert specs1.to_json() == specs_dict + + specs_dict["model_id"] = "diff model id" + specs2 = JumpStartModelSpecs(specs_dict) + assert specs1 != specs2 + + specs3 = copy.deepcopy(specs1) + assert specs3 == specs1 diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py new file mode 100644 index 0000000000..39b4706796 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -0,0 +1,114 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
from __future__ import absolute_import

from mock.mock import Mock, patch
import pytest

from sagemaker.jumpstart import utils
from sagemaker.jumpstart.constants import JUMPSTART_REGION_NAME_SET
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId


def test_get_jumpstart_content_bucket():
    """A region with no JumpStart launch must raise ValueError."""
    bad_region = "bad_region"
    # Sanity check: the fixture region really is unknown.
    assert bad_region not in JUMPSTART_REGION_NAME_SET
    with pytest.raises(ValueError):
        utils.get_jumpstart_content_bucket(bad_region)


def test_get_jumpstart_launched_regions_message():
    """The launched-regions message is grammatical for 0, 1, 2, and 3+ regions."""
    with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}):
        assert (
            utils.get_jumpstart_launched_regions_message()
            == "JumpStart is not available in any region."
        )

    with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"some_region"}):
        assert (
            utils.get_jumpstart_launched_regions_message()
            == "JumpStart is available in some_region region."
        )

    with patch(
        "sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"some_region1", "some_region2"}
    ):
        assert (
            utils.get_jumpstart_launched_regions_message()
            == "JumpStart is available in some_region1 and some_region2 regions."
        )

    with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"a", "b", "c"}):
        assert (
            utils.get_jumpstart_launched_regions_message()
            == "JumpStart is available in a, b, and c regions."
        )


def test_get_formatted_manifest():
    """The raw manifest list is re-keyed by (model_id, version) into header objects."""
    mock_manifest = [
        {
            "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
            "version": "1.0.0",
            "min_version": "2.49.0",
            "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v1.0.0.json",
        },
    ]

    assert utils.get_formatted_manifest(mock_manifest) == {
        JumpStartVersionedModelId(
            "tensorflow-ic-imagenet-inception-v3-classification-4", "1.0.0"
        ): JumpStartModelHeader(mock_manifest[0])
    }

    # An empty manifest formats to an empty dict.
    assert utils.get_formatted_manifest([]) == {}


def test_parse_sagemaker_version():
    """parse_sagemaker_version keeps major.minor.patch and rejects malformed versions."""
    # Exactly three numeric components pass through unchanged.
    with patch("sagemaker.__version__", "1.2.3"):
        assert utils.parse_sagemaker_version() == "1.2.3"

    # A fourth dotted component (dev/build suffix) is dropped.
    with patch("sagemaker.__version__", "1.2.3.3332j"):
        assert utils.parse_sagemaker_version() == "1.2.3"

    with patch("sagemaker.__version__", "1.2.3."):
        assert utils.parse_sagemaker_version() == "1.2.3"

    # Trailing garbage fused to the patch component is a ValueError.
    with pytest.raises(ValueError):
        with patch("sagemaker.__version__", "1.2.3dfsdfs"):
            utils.parse_sagemaker_version()

    # Fewer than three or more than four components is a RuntimeError.
    with pytest.raises(RuntimeError):
        with patch("sagemaker.__version__", "1.2"):
            utils.parse_sagemaker_version()

    with pytest.raises(RuntimeError):
        with patch("sagemaker.__version__", "1"):
            utils.parse_sagemaker_version()

    with pytest.raises(RuntimeError):
        with patch("sagemaker.__version__", ""):
            utils.parse_sagemaker_version()

    with pytest.raises(RuntimeError):
        with patch("sagemaker.__version__", "1.2.3.4.5"):
            utils.parse_sagemaker_version()


@patch("sagemaker.jumpstart.utils.parse_sagemaker_version")
@patch("sagemaker.jumpstart.utils.SageMakerSettings._PARSED_SAGEMAKER_VERSION", "")
def test_get_sagemaker_version(patched_parse_sm_version: Mock):
    """get_sagemaker_version memoizes: repeated calls parse the version only once."""
    utils.get_sagemaker_version()
    utils.get_sagemaker_version()
    utils.get_sagemaker_version()
    # Fix: the original `assert patched_parse_sm_version.called_only_once()` invoked a
    # nonexistent Mock method, which auto-creates a truthy child Mock — the assertion
    # could never fail. assert_called_once() is the real Mock API and actually verifies.
    patched_parse_sm_version.assert_called_once()
index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/utilities/test_cache.py b/tests/unit/sagemaker/utilities/test_cache.py new file mode 100644 index 0000000000..10fbe45767 --- /dev/null +++ b/tests/unit/sagemaker/utilities/test_cache.py @@ -0,0 +1,195 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +from typing import Optional, Union +from mock.mock import MagicMock, patch +import pytest + + +from sagemaker.utilities import cache +import datetime + + +def retrieval_function(key: Optional[int] = None, value: Optional[str] = None) -> str: + return str(hash(str(key))) + + +def test_cache_retrieves_item(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=10, + expiration_horizon=datetime.timedelta(hours=1), + retrieval_function=retrieval_function, + ) + + my_cache.put(5) + assert my_cache.get(5, False) == retrieval_function(key=5) + + my_cache.put(6, 7) + assert my_cache.get(6, False) == 7 + assert len(my_cache) == 2 + + my_cache.put(5, 6) + assert my_cache.get(5, False) == 6 + assert len(my_cache) == 2 + + with pytest.raises(KeyError): + my_cache.get(21, False) + + +def test_cache_invalidates_old_item(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=10, + expiration_horizon=datetime.timedelta(milliseconds=1), + retrieval_function=retrieval_function, + ) + + mock_curr_time = datetime.datetime.fromtimestamp(1636730651.079551) + with patch("datetime.datetime") as 
mock_datetime: + mock_datetime.now.return_value = mock_curr_time + my_cache.put(5) + mock_datetime.now.return_value += datetime.timedelta(milliseconds=2) + with pytest.raises(KeyError): + my_cache.get(5, False) + + with patch("datetime.datetime") as mock_datetime: + mock_datetime.now.return_value = mock_curr_time + my_cache.put(5) + mock_datetime.now.return_value += datetime.timedelta(milliseconds=0.5) + assert my_cache.get(5, False) == retrieval_function(key=5) + + +def test_cache_fetches_new_item(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=10, + expiration_horizon=datetime.timedelta(milliseconds=1), + retrieval_function=retrieval_function, + ) + + mock_curr_time = datetime.datetime.fromtimestamp(1636730651.079551) + with patch("datetime.datetime") as mock_datetime: + mock_datetime.now.return_value = mock_curr_time + my_cache.put(5, 10) + mock_datetime.now.return_value += datetime.timedelta(milliseconds=2) + assert my_cache.get(5) == retrieval_function(key=5) + + with patch("datetime.datetime") as mock_datetime: + mock_datetime.now.return_value = mock_curr_time + my_cache.put(5, 10) + mock_datetime.now.return_value += datetime.timedelta(milliseconds=0.5) + assert my_cache.get(5, False) == 10 + mock_datetime.now.return_value += datetime.timedelta(milliseconds=0.75) + with pytest.raises(KeyError): + my_cache.get(5, False) + + +def test_cache_removes_old_items_once_size_limit_reached(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=5, + expiration_horizon=datetime.timedelta(hours=1), + retrieval_function=retrieval_function, + ) + + for i in [1, 2, 3, 4, 5]: + my_cache.put(i) + + assert len(my_cache) == 5 + + my_cache.put(6) + assert len(my_cache) == 5 + with pytest.raises(KeyError): + my_cache.get(1, False) + assert my_cache.get(2, False) == retrieval_function(key=2) + + +def test_cache_get_with_data_source_fallback(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=5, + 
expiration_horizon=datetime.timedelta(hours=1), + retrieval_function=retrieval_function, + ) + + for i in range(10): + val = my_cache.get(i) + assert val == retrieval_function(key=i) + + assert len(my_cache) == 5 + + +def test_cache_gets_stored_value(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=5, + expiration_horizon=datetime.timedelta(hours=1), + retrieval_function=retrieval_function, + ) + + for i in range(5): + my_cache.put(i) + + my_cache._retrieval_function = MagicMock() + my_cache.get(4) + my_cache._retrieval_function.assert_not_called() + + my_cache._retrieval_function.reset_mock() + my_cache.get(5) + my_cache._retrieval_function.assert_called_with(key=5, value=None) + + my_cache._retrieval_function.reset_mock() + my_cache.get(0) + my_cache._retrieval_function.assert_called_with(key=0, value=None) + + +def test_cache_bad_retrieval_function(): + + cache_no_retrieval_fx = cache.LRUCache[int, Union[int, str]]( + max_cache_items=5, + expiration_horizon=datetime.timedelta(hours=1), + retrieval_function=None, + ) + + with pytest.raises(TypeError): + cache_no_retrieval_fx.put(1) + + cache_bad_retrieval_fx_signature = cache.LRUCache[int, Union[int, str]]( + max_cache_items=5, + expiration_horizon=datetime.timedelta(hours=1), + retrieval_function=lambda: 1, + ) + + with pytest.raises(TypeError): + cache_bad_retrieval_fx_signature.put(1) + + cache_retrieval_fx_throws = cache.LRUCache[int, Union[int, str]]( + max_cache_items=5, + expiration_horizon=datetime.timedelta(hours=1), + retrieval_function=lambda key, value: exec("raise(RuntimeError())"), + ) + + with pytest.raises(RuntimeError): + cache_retrieval_fx_throws.put(1) + + +def test_cache_clear_and_contains(): + my_cache = cache.LRUCache[int, Union[int, str]]( + max_cache_items=5, + expiration_horizon=datetime.timedelta(hours=1), + retrieval_function=retrieval_function, + ) + + for i in range(5): + my_cache.put(i) + assert i in my_cache + + my_cache.clear() + assert len(my_cache) == 0 + 
with pytest.raises(KeyError): + my_cache.get(1, False)