From 1689c7cfde6eb637ad4693381c401d222a9ccd52 Mon Sep 17 00:00:00 2001 From: Namrata Madan Date: Fri, 2 Dec 2022 12:48:09 -0800 Subject: [PATCH] feature: sagemaker remote function Co-authored-by: Ao Guo Co-authored-by: Rohan Gujarathi Co-authored-by: Zhankui Lu Co-authored-by: Dipankar Patro Co-authored-by: Mourya Baddam Co-authored-by: Namrata Madan --- .flake8 | 1 + README.rst | 2 + requirements/extras/test_requirements.txt | 1 + setup.py | 2 + src/sagemaker/config/config_schema.py | 102 +- src/sagemaker/experiments/run.py | 50 +- .../sagemaker-base-python.json | 36 + src/sagemaker/image_uris.py | 26 + src/sagemaker/local/local_session.py | 12 + src/sagemaker/remote_function/__init__.py | 16 + src/sagemaker/remote_function/client.py | 881 +++++++++++ .../remote_function/core/serialization.py | 271 ++++ .../remote_function/core/stored_function.py | 105 ++ src/sagemaker/remote_function/errors.py | 99 ++ .../remote_function/invoke_function.py | 103 ++ src/sagemaker/remote_function/job.py | 653 ++++++++ .../remote_function/logging_config.py | 38 + .../runtime_environment/__init__.py | 0 .../bootstrap_runtime_environment.py | 147 ++ .../runtime_environment_manager.py | 366 +++++ src/sagemaker/s3.py | 87 +- src/sagemaker/session.py | 424 +++--- tests/data/config/config.yaml | 23 + tests/data/remote_function/config.yaml | 18 + .../non_existent_requirements.txt | 1 + .../remote_function/old_deps_requirements.txt | 1 + tests/data/remote_function/pre_exec_commands | 4 + .../remote_function/pre_exec_commands_bad_cmd | 3 + tests/data/remote_function/requirements.txt | 1 + .../sagemaker/remote_function/__init__.py | 0 .../sagemaker/remote_function/conftest.py | 214 +++ .../remote_function/helpers/__init__.py | 0 .../remote_function/helpers/local_module.py | 5 + .../helpers/nested_helper/local_module2.py | 5 + .../remote_function/test_decorator.py | 598 ++++++++ .../remote_function/test_executor.py | 255 ++++ tests/integ/test_s3.py | 32 + 
tests/unit/sagemaker/config/conftest.py | 27 + .../sagemaker/config/test_config_schema.py | 14 + tests/unit/sagemaker/experiments/helpers.py | 3 + tests/unit/sagemaker/experiments/test_run.py | 31 + .../sagemaker/image_uris/expected_uris.py | 6 + .../sagemaker/image_uris/test_base_python.py | 61 + .../unit/sagemaker/local/test_local_image.py | 1 + .../core/test_serialization.py | 394 +++++ .../core/test_stored_function.py | 124 ++ .../test_bootstrap_runtime_environment.py | 206 +++ .../test_runtime_environment_manager.py | 415 +++++ .../sagemaker/remote_function/test_client.py | 1334 +++++++++++++++++ .../sagemaker/remote_function/test_errors.py | 81 + .../remote_function/test_invoke_function.py | 109 ++ .../sagemaker/remote_function/test_job.py | 558 +++++++ .../remote_function/test_logging_config.py | 28 + tests/unit/test_exception_on_bad_status.py | 14 +- tests/unit/test_session.py | 80 +- 55 files changed, 7832 insertions(+), 236 deletions(-) create mode 100644 src/sagemaker/image_uri_config/sagemaker-base-python.json create mode 100644 src/sagemaker/remote_function/__init__.py create mode 100644 src/sagemaker/remote_function/client.py create mode 100644 src/sagemaker/remote_function/core/serialization.py create mode 100644 src/sagemaker/remote_function/core/stored_function.py create mode 100644 src/sagemaker/remote_function/errors.py create mode 100644 src/sagemaker/remote_function/invoke_function.py create mode 100644 src/sagemaker/remote_function/job.py create mode 100644 src/sagemaker/remote_function/logging_config.py create mode 100644 src/sagemaker/remote_function/runtime_environment/__init__.py create mode 100644 src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py create mode 100644 src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py create mode 100644 tests/data/remote_function/config.yaml create mode 100644 tests/data/remote_function/non_existent_requirements.txt create mode 100644 
tests/data/remote_function/old_deps_requirements.txt create mode 100644 tests/data/remote_function/pre_exec_commands create mode 100644 tests/data/remote_function/pre_exec_commands_bad_cmd create mode 100644 tests/data/remote_function/requirements.txt create mode 100644 tests/integ/sagemaker/remote_function/__init__.py create mode 100644 tests/integ/sagemaker/remote_function/conftest.py create mode 100644 tests/integ/sagemaker/remote_function/helpers/__init__.py create mode 100644 tests/integ/sagemaker/remote_function/helpers/local_module.py create mode 100644 tests/integ/sagemaker/remote_function/helpers/nested_helper/local_module2.py create mode 100644 tests/integ/sagemaker/remote_function/test_decorator.py create mode 100644 tests/integ/sagemaker/remote_function/test_executor.py create mode 100644 tests/unit/sagemaker/image_uris/test_base_python.py create mode 100644 tests/unit/sagemaker/remote_function/core/test_serialization.py create mode 100644 tests/unit/sagemaker/remote_function/core/test_stored_function.py create mode 100644 tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py create mode 100644 tests/unit/sagemaker/remote_function/runtime_environment/test_runtime_environment_manager.py create mode 100644 tests/unit/sagemaker/remote_function/test_client.py create mode 100644 tests/unit/sagemaker/remote_function/test_errors.py create mode 100644 tests/unit/sagemaker/remote_function/test_invoke_function.py create mode 100644 tests/unit/sagemaker/remote_function/test_job.py create mode 100644 tests/unit/sagemaker/remote_function/test_logging_config.py diff --git a/.flake8 b/.flake8 index 51ecee6eee..53e43383ac 100644 --- a/.flake8 +++ b/.flake8 @@ -3,3 +3,4 @@ application_import_names = sagemaker, tests import-order-style = google per-file-ignores = tests/unit/test_tuner.py: F405 + src/sagemaker/config/config_schema.py: E501 diff --git a/README.rst b/README.rst index a41c0da7ac..1e9bdf162a 100644 --- a/README.rst +++ 
b/README.rst @@ -133,6 +133,8 @@ To run the integration tests, the following prerequisites must be met 1. AWS account credentials are available in the environment for the boto3 client to use. 2. The AWS account has an IAM role named :code:`SageMakerRole`. It should have the AmazonSageMakerFullAccess policy attached as well as a policy with `the necessary permissions to use Elastic Inference `__. +3. To run remote_function tests, dummy ecr repo should be created. It can be created by running - + :code:`aws ecr create-repository --repository-name remote-function-dummy-container` We recommend selectively running just those integration tests you'd like to run. You can filter by individual test function names with: diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index d2bcbe60c6..695c5b2d47 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -21,3 +21,4 @@ sagemaker-experiments==0.1.35 Jinja2==3.0.3 pandas>=1.3.5,<1.5 scikit-learn==1.0.2 +cloudpickle==2.2.1 diff --git a/setup.py b/setup.py index 98a63c9d32..ad4118a80a 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ def read_requirements(filename): required_packages = [ "attrs>=20.3.0,<23", "boto3>=1.26.28,<2.0", + "cloudpickle==2.2.1", "google-pasta", "numpy>=1.9.0,<2.0", "protobuf>=3.1,<4.0", @@ -62,6 +63,7 @@ def read_requirements(filename): "PyYAML==5.4.1", "jsonschema", "platformdirs", + "tblib==1.7.0", ] # Specific use case dependencies diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index cd1ce48baf..033742603a 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -44,6 +44,17 @@ SAGEMAKER = "SageMaker" PYTHON_SDK = "PythonSDK" MODULES = "Modules" +REMOTE_FUNCTION = "RemoteFunction" +DEPENDENCIES = "Dependencies" +PRE_EXECUTION_SCRIPT = "PreExecutionScript" +PRE_EXECUTION_COMMANDS = "PreExecutionCommands" 
+ENVIRONMENT_VARIABLES = "EnvironmentVariables" +IMAGE_URI = "ImageUri" +INCLUDE_LOCAL_WORKDIR = "IncludeLocalWorkDir" +INSTANCE_TYPE = "InstanceType" +S3_KMS_KEY_ID = "S3KmsKeyId" +S3_ROOT_URI = "S3RootUri" +JOB_CONDA_ENV = "JobCondaEnvironment" OFFLINE_STORE_CONFIG = "OfflineStoreConfig" ONLINE_STORE_CONFIG = "OnlineStoreConfig" S3_STORAGE_CONFIG = "S3StorageConfig" @@ -221,6 +232,49 @@ def _simple_path(*args: str): SAGEMAKER, MODEL_PACKAGE, VALIDATION_SPECIFICATION, VALIDATION_PROFILES ) +REMOTE_FUNCTION_DEPENDENCIES = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, DEPENDENCIES +) +REMOTE_FUNCTION_PRE_EXECUTION_COMMANDS = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, PRE_EXECUTION_COMMANDS +) +REMOTE_FUNCTION_PRE_EXECUTION_SCRIPT = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, PRE_EXECUTION_SCRIPT +) +REMOTE_FUNCTION_ENVIRONMENT_VARIABLES = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENVIRONMENT_VARIABLES +) +REMOTE_FUNCTION_IMAGE_URI = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, IMAGE_URI) +REMOTE_FUNCTION_INCLUDE_LOCAL_WORKDIR = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, INCLUDE_LOCAL_WORKDIR +) +REMOTE_FUNCTION_INSTANCE_TYPE = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, INSTANCE_TYPE +) +REMOTE_FUNCTION_JOB_CONDA_ENV = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, JOB_CONDA_ENV +) +REMOTE_FUNCTION_ROLE_ARN = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ROLE_ARN) +REMOTE_FUNCTION_S3_KMS_KEY_ID = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, S3_KMS_KEY_ID +) +REMOTE_FUNCTION_S3_ROOT_URI = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, S3_ROOT_URI +) +REMOTE_FUNCTION_TAGS = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, TAGS) +REMOTE_FUNCTION_VOLUME_KMS_KEY_ID = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, 
VOLUME_KMS_KEY_ID +) +REMOTE_FUNCTION_VPC_CONFIG_SUBNETS = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VPC_CONFIG, SUBNETS +) +REMOTE_FUNCTION_VPC_CONFIG_SECURITY_GROUP_IDS = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VPC_CONFIG, SECURITY_GROUP_IDS +) +REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION +) + # Paths for reference elsewhere in the SDK. # Names include the schema version since the paths could change with other schema versions MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path( @@ -245,7 +299,6 @@ def _simple_path(*args: str): SAGEMAKER, TRAINING_JOB, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION ) - SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = { "$schema": "https://json-schema.org/draft/2020-12/schema", TYPE: OBJECT, @@ -377,6 +430,23 @@ def _simple_path(*args: str): "minItems": 0, "maxItems": 50, }, + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html#sagemaker-CreateTrainingJob-request-Environment + "environmentVariables": { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PATTERN_PROPERTIES: { + r"([a-zA-Z_][a-zA-Z0-9_]*){1,512}": { + TYPE: "string", + "pattern": r"[\S\s]*", + "maxLength": 512, + } + }, + "maxProperties": 48, + }, + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html#sagemaker-Type-S3DataSource-S3Uri + "s3Uri": {TYPE: "string", "pattern": "^(https|s3)://([^/]+)/?(.*)$", "maxLength": 1024}, + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint + "preExecutionCommand": {TYPE: "string", "pattern": r".*"}, }, PROPERTIES: { SCHEMA_VERSION: { @@ -406,6 +476,36 @@ def _simple_path(*args: str): # Any SageMaker Python SDK specific configuration will be added 
here. TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + REMOTE_FUNCTION: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + DEPENDENCIES: {TYPE: "string"}, + PRE_EXECUTION_COMMANDS: { + TYPE: "array", + "items": {"$ref": "#/definitions/preExecutionCommand"}, + }, + PRE_EXECUTION_SCRIPT: {TYPE: "string"}, + ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: { + TYPE: "boolean" + }, + ENVIRONMENT_VARIABLES: { + "$ref": "#/definitions/environmentVariables" + }, + IMAGE_URI: {TYPE: "string"}, + INCLUDE_LOCAL_WORKDIR: {TYPE: "boolean"}, + INSTANCE_TYPE: {TYPE: "string"}, + JOB_CONDA_ENV: {TYPE: "string"}, + ROLE_ARN: {"$ref": "#/definitions/roleArn"}, + S3_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}, + S3_ROOT_URI: {"$ref": "#/definitions/s3Uri"}, + TAGS: {"$ref": "#/definitions/tags"}, + VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}, + VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, + }, + } + }, } }, }, diff --git a/src/sagemaker/experiments/run.py b/src/sagemaker/experiments/run.py index b5eee03636..07b7080ea3 100644 --- a/src/sagemaker/experiments/run.py +++ b/src/sagemaker/experiments/run.py @@ -715,6 +715,14 @@ def __exit__(self, exc_type, exc_value, exc_traceback): self.close() + def __getstate__(self): + """Overriding this method to prevent instance of Run from being pickled. + + Raise: + NotImplementedError: If attempting to pickle this instance. + """ + raise NotImplementedError("Instance of Run type is not allowed to be pickled.") + def load_run( run_name: Optional[str] = None, @@ -787,36 +795,38 @@ def load_run( Returns: Run: The loaded Run object. """ - sagemaker_session = sagemaker_session or _utils.default_session() environment = _RunEnvironment.load() verify_load_input_names(run_name=run_name, experiment_name=experiment_name) - if run_name or environment: - if run_name: - logger.warning( - "run_name is explicitly supplied in load_run, " - "which will be prioritized to load the Run object. 
" - "In other words, the run name in the experiment config, fetched from the " - "job environment or the current run context, will be ignored." - ) - else: - exp_config = get_tc_and_exp_config_from_job_env( - environment=environment, sagemaker_session=sagemaker_session - ) - run_name = Run._extract_run_name_from_tc_name( - trial_component_name=exp_config[RUN_NAME], - experiment_name=exp_config[EXPERIMENT_NAME], - ) - experiment_name = exp_config[EXPERIMENT_NAME] - + if run_name: + logger.warning( + "run_name is explicitly supplied in load_run, " + "which will be prioritized to load the Run object. " + "In other words, the run name in the experiment config, fetched from the " + "job environment or the current run context, will be ignored." + ) run_instance = Run( experiment_name=experiment_name, run_name=run_name, - sagemaker_session=sagemaker_session, + sagemaker_session=sagemaker_session or _utils.default_session(), ) elif _RunContext.get_current_run(): run_instance = _RunContext.get_current_run() + elif environment: + exp_config = get_tc_and_exp_config_from_job_env( + environment=environment, sagemaker_session=sagemaker_session or _utils.default_session() + ) + run_name = Run._extract_run_name_from_tc_name( + trial_component_name=exp_config[RUN_NAME], + experiment_name=exp_config[EXPERIMENT_NAME], + ) + experiment_name = exp_config[EXPERIMENT_NAME] + run_instance = Run( + experiment_name=experiment_name, + run_name=run_name, + sagemaker_session=sagemaker_session or _utils.default_session(), + ) else: raise RuntimeError( "Failed to load a Run object. 
" diff --git a/src/sagemaker/image_uri_config/sagemaker-base-python.json b/src/sagemaker/image_uri_config/sagemaker-base-python.json new file mode 100644 index 0000000000..771f66ab95 --- /dev/null +++ b/src/sagemaker/image_uri_config/sagemaker-base-python.json @@ -0,0 +1,36 @@ +{ + "versions": { + "1.0": { + "registries": { + "us-east-2": "429704687514", + "me-south-1": "117516905037", + "us-west-2": "236514542706", + "ca-central-1": "310906938811", + "ap-east-1": "493642496378", + "us-east-1": "081325390199", + "ap-northeast-2": "806072073708", + "eu-west-2": "712779665605", + "ap-southeast-2": "52832661640", + "cn-northwest-1": "390780980154", + "eu-north-1": "243637512696", + "cn-north-1": "390048526115", + "ap-south-1": "394103062818", + "eu-west-3": "615547856133", + "ap-southeast-3": "276181064229", + "af-south-1": "559312083959", + "eu-west-1": "470317259841", + "eu-central-1": "936697816551", + "sa-east-1": "782484402741", + "ap-northeast-3": "792733760839", + "eu-south-1": "592751261982", + "ap-northeast-1": "102112518831", + "us-west-1": "742091327244", + "ap-southeast-1": "492261229750", + "me-central-1": "103105715889", + "us-gov-east-1": "107072934176", + "us-gov-west-1": "107173498710" + }, + "repository": "sagemaker-base-python" + } + } +} \ No newline at end of file diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index b9101acd96..e39225e60c 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -663,3 +663,29 @@ def get_training_image_uri( container_version=container_version, training_compiler_config=compiler_config, ) + + +def get_base_python_image_uri(region, py_version="310") -> str: + """Retrieves the image URI for base python image. + + Args: + region (str): The AWS region to use for image URI. + py_version (str): The python version to use for the image. Can be 310 or 38 + Default to 310 + + Returns: + str: The image URI string. 
+ """ + + framework = "sagemaker-base-python" + version = "1.0" + hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"] + config = config_for_framework(framework) + version_config = config["versions"][_version_for_config(version, config)] + + registry = _registry_from_region(region, version_config["registries"]) + + repo = version_config["repository"] + "-" + py_version + repo_and_tag = repo + ":" + version + + return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo_and_tag) diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index a9d8c73223..c94f695e0d 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -674,10 +674,22 @@ def _initialize( self.sagemaker_client = LocalSagemakerClient(self) self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config) self.local_mode = True + sagemaker_config = kwargs.get("sagemaker_config", None) + if sagemaker_config: + validate_sagemaker_config(sagemaker_config) if self.s3_endpoint_url is not None: self.s3_resource = boto_session.resource("s3", endpoint_url=self.s3_endpoint_url) self.s3_client = boto_session.client("s3", endpoint_url=self.s3_endpoint_url) + self.sagemaker_config = ( + sagemaker_config + if sagemaker_config + else load_sagemaker_config(s3_resource=self.s3_resource) + ) + else: + self.sagemaker_config = ( + sagemaker_config if sagemaker_config else load_sagemaker_config() + ) sagemaker_config = kwargs.get("sagemaker_config", None) if sagemaker_config: diff --git a/src/sagemaker/remote_function/__init__.py b/src/sagemaker/remote_function/__init__.py new file mode 100644 index 0000000000..5e7f94b724 --- /dev/null +++ b/src/sagemaker/remote_function/__init__.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Defines classes and helper methods used in remote function executions.""" +from __future__ import absolute_import + +from sagemaker.remote_function.client import remote, RemoteExecutor # noqa: F401 diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py new file mode 100644 index 0000000000..a07f7baeb0 --- /dev/null +++ b/src/sagemaker/remote_function/client.py @@ -0,0 +1,881 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""SageMaker remote function client.""" +from __future__ import absolute_import + +from concurrent.futures import ThreadPoolExecutor +from collections import deque +import time +import threading +from typing import Dict, List, Tuple, Any +import functools +import itertools +import inspect + +from botocore.exceptions import ClientError +from sagemaker.exceptions import UnexpectedStatusException +from sagemaker.experiments._run_context import _RunContext + +import sagemaker.remote_function.core.serialization as serialization +from sagemaker.remote_function.errors import RemoteFunctionError, ServiceError, DeserializationError +from sagemaker.remote_function.core.stored_function import RESULTS_FOLDER, EXCEPTION_FOLDER +from sagemaker.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentError, +) + +from sagemaker.session import Session +from sagemaker.s3 import s3_path_join +from sagemaker.remote_function.job import _JobSettings, _Job, _RunInfo +from sagemaker.remote_function import logging_config +from sagemaker.utils import name_from_base, base_from_name + +_API_CALL_LIMIT = { + "SubmittingIntervalInSecs": 1, + "MinBatchPollingIntervalInSecs": 10, + "PollingIntervalInSecs": 0.5, +} + +# Possible future states. +_PENDING = "PENDING" +_RUNNING = "RUNNING" +# The future was cancelled by the user... 
+_CANCELLED = "CANCELLED" +_FINISHED = "FINISHED" + +logger = logging_config.get_logger() + + +def remote( + _func=None, + *, + dependencies: str = None, + pre_execution_commands: List[str] = None, + pre_execution_script: str = None, + environment_variables: Dict[str, str] = None, + image_uri: str = None, + include_local_workdir: bool = False, + instance_count: int = 1, + instance_type: str = None, + job_conda_env: str = None, + job_name_prefix: str = None, + keep_alive_period_in_seconds: int = 0, + max_retry_attempts: int = 1, + max_runtime_in_seconds: int = 24 * 60 * 60, + role: str = None, + s3_kms_key: str = None, + s3_root_uri: str = None, + sagemaker_session: Session = None, + security_group_ids: List[str] = None, + subnets: List[str] = None, + tags: List[Tuple[str, str]] = None, + volume_kms_key: str = None, + volume_size: int = 30, + encrypt_inter_container_traffic: bool = None, +): + """Function that starts a new SageMaker job synchronously with overridden runtime settings. + + Args: + _func (Optional): Python function to be executed on the SageMaker job runtime environment. + dependencies (str): Path to dependencies file or a reserved keyword + ``auto_capture``. Defaults to None. + pre_execution_commands (List[str]): List of commands to be executed prior to executing + remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` + can be specified at the same time. Defaults to None. + pre_execution_script (str): Path to script file to be executed prior to executing + remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` + can be specified at the same time. Defaults to None. + environment_variables (Dict): environment variables + image_uri (str): Docker image URI on ECR. + include_local_workdir (bool): Set to ``True`` if the remote function code imports local + modules and methods that are not available via PyPI or conda. Default value is ``False``. + instance_count (int): Number of instance to use. 
Default is 1. + instance_type (str): EC2 instance type. + job_conda_env (str): Name of the conda environment to activate during execution of the job. + Default is None. + job_name_prefix (str): Prefix used to identify the underlying sagemaker job. + keep_alive_period_in_seconds (int): The duration of time in seconds to retain configured + resources in a warm pool for subsequent training jobs. Default is 0. + max_retry_attempts (int): Max number of times the job is retried on InternalServerFailure. + Default is 1. + max_runtime_in_seconds (int): Timeout in seconds for training. After this amount of time + Amazon SageMaker terminates the job regardless of its current status. + Default is 86400 seconds (1 day). + role (str): IAM role used for SageMaker execution. + s3_kms_key (str): The encryption key used for storing serialized data. + s3_root_uri (str): The root S3 folder where the code archives and data are uploaded to. + sagemaker_session (sagemaker.session.Session): The underlying SageMaker session which + AWS service calls are delegated to (default: None). If not provided, one is created + with default AWS configuration chain. + security_group_ids (List[str]): List of security group IDs. + subnets (List[str]): List of subnet IDs. + tags (List[Tuple[str, str]]): List of tags attached to the job. + volume_kms_key (str): KMS key used for encrypting EBS volume attached to the training + instance. + volume_size (int): Size in GB of the storage volume to use for storing input and output + data. Default is 30. + encrypt_inter_container_traffic (bool): Specifies whether traffic between training + containers is encrypted for the training job. (default: ``False``). 
+ """ + + def _remote(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + + RemoteExecutor._validate_submit_args(func, *args, **kwargs) + + job_settings = _JobSettings( + dependencies=dependencies, + pre_execution_commands=pre_execution_commands, + pre_execution_script=pre_execution_script, + environment_variables=environment_variables, + image_uri=image_uri, + include_local_workdir=include_local_workdir, + instance_count=instance_count, + instance_type=instance_type, + job_conda_env=job_conda_env, + job_name_prefix=job_name_prefix, + keep_alive_period_in_seconds=keep_alive_period_in_seconds, + max_retry_attempts=max_retry_attempts, + max_runtime_in_seconds=max_runtime_in_seconds, + role=role, + s3_kms_key=s3_kms_key, + s3_root_uri=s3_root_uri, + sagemaker_session=sagemaker_session, + security_group_ids=security_group_ids, + subnets=subnets, + tags=tags, + volume_kms_key=volume_kms_key, + volume_size=volume_size, + encrypt_inter_container_traffic=encrypt_inter_container_traffic, + ) + job = _Job.start(job_settings, func, args, kwargs) + + try: + job.wait() + except UnexpectedStatusException as usex: + if usex.actual_status == "Failed": + try: + exception = serialization.deserialize_exception_from_s3( + sagemaker_session=job_settings.sagemaker_session, + s3_uri=s3_path_join( + job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER + ), + ) + except ServiceError as serr: + chained_e = serr.__cause__ + if ( + isinstance(chained_e, ClientError) + and chained_e.response["Error"]["Code"] # pylint: disable=no-member + == "404" + and chained_e.response["Error"]["Message"] # pylint: disable=no-member + == "Not Found" + ): + describe_result = job.describe() + if ( + "FailureReason" in describe_result + and describe_result["FailureReason"] + and "RuntimeEnvironmentError: " in describe_result["FailureReason"] + ): + failure_msg = describe_result["FailureReason"].replace( + "RuntimeEnvironmentError: ", "" + ) + raise RuntimeEnvironmentError(failure_msg) + raise 
RemoteFunctionError( + "Failed to execute remote function. " + + "Check corresponding job for details." + ) + raise serr + + raise exception + + raise TimeoutError( + "Job for remote function timed out before reaching a termination status." + ) + + if job.describe()["TrainingJobStatus"] == "Completed": + return serialization.deserialize_obj_from_s3( + sagemaker_session=job_settings.sagemaker_session, + s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER), + ) + + if job.describe()["TrainingJobStatus"] == "Stopped": + raise RemoteFunctionError("Job for remote function has been aborted.") + + return None + + return wrapper + + if _func is None: + return _remote + return _remote(_func) + + +class _SubmitRequest: + """Class that holds parameters and data for creating a new job.""" + + def __init__( + self, future, job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None + ): + self.future = future + self.job_settings = job_settings + self.func = func + self.args = func_args + self.kwargs = func_kwargs + self.run_info = run_info + + +def _submit_worker(executor): + """Background worker that submits job requests.""" + + def has_work_to_do(): + return ( + len(executor._pending_request_queue) > 0 + and len(executor._running_jobs) < executor.max_parallel_jobs + ) + + try: + while True: + with executor._state_condition: + executor._state_condition.wait_for(has_work_to_do) + request = executor._pending_request_queue[0] + + if request is None: + with executor._state_condition: + # remove the anchor from the pending queue + executor._pending_request_queue.popleft() + return + + time.sleep(_API_CALL_LIMIT["SubmittingIntervalInSecs"]) + # submit a new job + job = request.future._start_and_notify( + request.job_settings, request.func, request.args, request.kwargs, request.run_info + ) + + with executor._state_condition: + if job: + executor._running_jobs[job.job_name] = job + # remove the request from the pending queue + 
executor._pending_request_queue.popleft() + except Exception: # pylint: disable=broad-except + logger.exception("Error occurred while submitting CreateTrainingJob requests.") + + +def _polling_worker(executor): + """Background worker that polls the status of the running jobs.""" + try: + while True: + with executor._state_condition: + if ( + executor._shutdown + and len(executor._running_jobs) + len(executor._pending_request_queue) == 0 + ): + return + + time.sleep( + max( + _API_CALL_LIMIT["MinBatchPollingIntervalInSecs"] + - len(executor._running_jobs) * _API_CALL_LIMIT["PollingIntervalInSecs"], + 0, + ) + ) + + # check if running jobs are terminated + for job_name in list(executor._running_jobs.keys()): + try: + time.sleep(_API_CALL_LIMIT["PollingIntervalInSecs"]) + if executor._running_jobs[job_name].describe()["TrainingJobStatus"] in [ + "Completed", + "Failed", + "Stopped", + ]: + with executor._state_condition: + del executor._running_jobs[job_name] + executor._state_condition.notify_all() + except Exception as e: # pylint: disable=broad-except + if ( + not isinstance(e, ClientError) + or e.response["Error"]["Code"] # pylint: disable=no-member + != "LimitExceededException" + ): + # Couldn't check the job status, move on + logger.exception( + "Error occurred while checking the status of job %s", job_name + ) + with executor._state_condition: + del executor._running_jobs[job_name] + executor._state_condition.notify_all() + except Exception: # pylint: disable=broad-except + logger.exception("Error occurred while monitoring the job statuses.") + + +class RemoteExecutor(object): + """Run Python functions asynchronously as SageMaker jobs""" + + def __init__( + self, + *, + dependencies: str = None, + pre_execution_commands: List[str] = None, + pre_execution_script: str = None, + environment_variables: Dict[str, str] = None, + image_uri: str = None, + include_local_workdir: bool = False, + instance_count: int = 1, + instance_type: str = None, + job_conda_env: str = 
None, + job_name_prefix: str = None, + keep_alive_period_in_seconds: int = 0, + max_parallel_jobs: int = 1, + max_retry_attempts: int = 1, + max_runtime_in_seconds: int = 24 * 60 * 60, + role: str = None, + s3_kms_key: str = None, + s3_root_uri: str = None, + sagemaker_session: Session = None, + security_group_ids: List[str] = None, + subnets: List[str] = None, + tags: List[Tuple[str, str]] = None, + volume_kms_key: str = None, + volume_size: int = 30, + encrypt_inter_container_traffic: bool = None, + ): + """Initiates a ``RemoteExecutor`` instance. + + Args: + dependencies (str): Path to dependencies file or a reserved keyword + ``auto_capture``. Defaults to None. + pre_execution_commands (List[str]): List of commands to be executed prior to executing + remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` + can be specified at the same time. Defaults to None. + pre_execution_script (str): Path to script file to be executed prior to executing + remote function. Only one of ``pre_execution_commands`` or ``pre_execution_script`` + can be specified at the same time. Defaults to None. + environment_variables (Dict): Environment variables passed to the underlying sagemaker + job. Defaults to None + image_uri (str): Docker image URI on ECR. Defaults to base Python image. + include_local_workdir (bool): Set to ``True`` if the remote function code imports local + modules and methods that are not available via PyPI or conda. Default value is + ``False``. + instance_count (int): Number of instance to use. Defaults to 1. + instance_type (str): EC2 instance type. + job_conda_env (str): Name of the conda environment to activate during execution + of the job. Default is None. + job_name_prefix (str): Prefix used to identify the underlying sagemaker job. + keep_alive_period_in_seconds (int): The duration of time in seconds to retain configured + resources in a warm pool for subsequent training jobs. Defaults to 0. 
+ max_parallel_jobs (int): Maximal number of jobs that run in parallel. Defaults to 1. + max_retry_attempts (int): Max number of times the job is retried on + InternalServerFailure. Defaults to 1. + max_runtime_in_seconds (int): Timeout in seconds for training. After this amount of + time Amazon SageMaker terminates the job regardless of its current status. + Defaults to 86400 seconds (1 day). + role (str): IAM role used for SageMaker execution. Defaults to SageMaker default + execution role. + s3_kms_key (str): The encryption key used for storing serialized data. Defaults to S3 + managed key. + s3_root_uri (str): The root S3 folder where the code archives and data are uploaded to. + This parameter is autogenerated using information regarding the image uri if not + provided. + sagemaker_session (sagemaker.session.Session): The underlying SageMaker session which + AWS service calls are delegated to (default: None). If not provided, one is created + with default AWS configuration chain. + security_group_ids (List[str]): List of security group IDs. Defaults to None. + subnets (List[str]): List of subnet IDs. Defaults to None. + tags (List[Tuple[str, str]]): List of tags attached to the job. Defaults to None. + volume_kms_key (str): KMS key used for encrypting EBS volume attached to the training + instance. + volume_size (int): Size in GB of the storage volume to use for storing input and output + data. Defaults to 30. + encrypt_inter_container_traffic (bool): Specifies whether traffic between training + containers is encrypted for the training job. (default: ``False``). 
+ """ + self.max_parallel_jobs = max_parallel_jobs + + if self.max_parallel_jobs <= 0: + raise ValueError("max_parallel_jobs must be greater than 0.") + + self.job_settings = _JobSettings( + dependencies=dependencies, + pre_execution_commands=pre_execution_commands, + pre_execution_script=pre_execution_script, + environment_variables=environment_variables, + image_uri=image_uri, + include_local_workdir=include_local_workdir, + instance_count=instance_count, + instance_type=instance_type, + job_conda_env=job_conda_env, + job_name_prefix=job_name_prefix, + keep_alive_period_in_seconds=keep_alive_period_in_seconds, + max_retry_attempts=max_retry_attempts, + max_runtime_in_seconds=max_runtime_in_seconds, + role=role, + s3_kms_key=s3_kms_key, + s3_root_uri=s3_root_uri, + sagemaker_session=sagemaker_session, + security_group_ids=security_group_ids, + subnets=subnets, + tags=tags, + volume_kms_key=volume_kms_key, + volume_size=volume_size, + encrypt_inter_container_traffic=encrypt_inter_container_traffic, + ) + + self._state_condition = threading.Condition() + self._pending_request_queue = deque() + # For thread safety, see + # https://web.archive.org/web/20201108091210/http://effbot.org/pyfaq/what-kinds-of-global-value-mutation-are-thread-safe.htm + self._running_jobs = dict() + self._shutdown = False + + self._workers: ThreadPoolExecutor = None + + def submit(self, func, *args, **kwargs): + """Execute the input function as a SageMaker job asynchronously. + + Args: + func: Python function to run as a SageMaker job. + *args: Positional arguments to the input function. 
+ **kwargs: keyword arguments to the input function + """ + if self._shutdown: + raise RuntimeError("Cannot schedule new remote function executions after shutdown") + + self._validate_submit_args(func, *args, **kwargs) + + with self._state_condition: + future = Future() + + run_info = None + if _RunContext.get_current_run() is not None: + run = _RunContext.get_current_run() + run_info = _RunInfo(run.experiment_name, run.run_name) + + self._pending_request_queue.append( + _SubmitRequest(future, self.job_settings, func, args, kwargs, run_info) + ) + + if self._workers is None: + self._workers = ThreadPoolExecutor(2) + self._workers.submit(_submit_worker, self) + self._workers.submit(_polling_worker, self) + + self._state_condition.notify_all() + + return future + + def map(self, func, *iterables): + """Return an iterator that applies function to every item of iterable, yielding the results. + + If additional iterables arguments are passed, function must take that many arguments and + is applied to the items from all iterables in parallel. With multiple iterables, the + iterator stops when the shortest iterable is exhausted. + + Args: + func: Python function to run as a SageMaker job. + iterables: Arguments of the input python function. 
+ """ + + futures = map(self.submit, itertools.repeat(func), *iterables) + return [future.result() for future in futures] + + def shutdown(self): + """Prevent more function executions to be submitted to this executor.""" + with self._state_condition: + self._shutdown = True + + # give a signal to the submitting worker so that it doesn't block on empty queue forever + self._pending_request_queue.append(None) + + self._state_condition.notify_all() + + if self._workers is not None: + self._workers.shutdown(wait=True) + + def __enter__(self): + """Create an executor instance and return it""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Make sure the executor instance is shutdown.""" + self.shutdown() + return False + + @staticmethod + def _validate_submit_args(func, *args, **kwargs): + """Validates input args passed to submit method.""" + + full_arg_spec = inspect.getfullargspec(func) + + # args related validations + + is_accepting_variable_positional_args = full_arg_spec.varargs is not None + num_default_positional_args = len(full_arg_spec.defaults) if full_arg_spec.defaults else 0 + minimum_num_expected_positional_args = len(full_arg_spec.args) - num_default_positional_args + + if not is_accepting_variable_positional_args and len(args) > len(full_arg_spec.args): + raise TypeError( + f"{func.__name__}() takes {len(full_arg_spec.args)} positional " + + f"{'arguments' if len(full_arg_spec.args) > 1 else 'argument'} but {len(args)} " + + f"{'were' if len(args) > 1 else 'was'} given." 
+ ) + + if len(args) < minimum_num_expected_positional_args: + missing_positional_args = full_arg_spec.args[ + len(args) : minimum_num_expected_positional_args + ] + missing_args = list(filter(lambda arg: arg not in kwargs, missing_positional_args)) + if missing_args: + missing_args_str = ( + ", ".join(map(lambda x: f"'{x}'", missing_args[:-1])) + + f", and '{missing_args[-1]}'" + if len(missing_args) > 1 + else f"'{missing_args[0]}'" + ) + raise TypeError( + f"{func.__name__}() missing {len(missing_args)} required positional " + + f"{'arguments' if len(missing_args) > 1 else 'argument'}: {missing_args_str}" + ) + + # kwargs related validations + + for k in kwargs: + if k in full_arg_spec.args and len(args) > full_arg_spec.args.index(k): + raise TypeError(f"{func.__name__}() got multiple values for argument '{k}'") + if k not in full_arg_spec.kwonlyargs and k not in full_arg_spec.args: + raise TypeError(f"{func.__name__}() got an unexpected keyword argument '{k}'") + + missing_kwargs = [ + k + for k in full_arg_spec.kwonlyargs + if k not in full_arg_spec.kwonlydefaults and k not in kwargs + ] + if missing_kwargs: + missing_kwargs_string = ( + ", ".join(map(lambda x: f"'{x}'", missing_kwargs[:-1])) + + f", and '{missing_kwargs[-1]}'" + if len(missing_kwargs) > 1 + else f"'{missing_kwargs[0]}'" + ) + + raise TypeError( + f"{func.__name__}() missing {len(missing_kwargs)} required keyword-only " + + f"{'arguments' if len(missing_kwargs) > 1 else 'argument'}: " + + f"{missing_kwargs_string}" + ) + + +class Future(object): + """Class representing a reference to a sagemaker job result. + + The sagemaker job represented may or may not have finished running. 
+ """ + + def __init__(self): + self._condition = threading.Condition() + self._state = _PENDING + self._job = None + self._exception = None + self._return = None + + @staticmethod + def from_describe_response(describe_training_job_response, sagemaker_session): + """Construct a Future from a describe_training_job_response object.""" + future = Future() + job_exception = None + client_exception = None + job_return = None + job = _Job.from_describe_response(describe_training_job_response, sagemaker_session) + if describe_training_job_response["TrainingJobStatus"] in ["Stopping", "Stopped"]: + state = _CANCELLED + elif describe_training_job_response["TrainingJobStatus"] == "Completed": + state = _FINISHED + try: + job_return = serialization.deserialize_obj_from_s3( + sagemaker_session=sagemaker_session, + s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER), + ) + except DeserializationError as e: + client_exception = e + except ServiceError as e: + client_exception = e + elif describe_training_job_response["TrainingJobStatus"] == "Failed": + state = _FINISHED + try: + job_exception = serialization.deserialize_exception_from_s3( + sagemaker_session=sagemaker_session, + s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER), + ) + except ServiceError as serr: + chained_e = serr.__cause__ + if ( + isinstance(chained_e, ClientError) + and chained_e.response["Error"]["Code"] == "404" # pylint: disable=no-member + and chained_e.response["Error"]["Message"] # pylint: disable=no-member + == "Not Found" + ): + if ( + "FailureReason" in describe_training_job_response + and describe_training_job_response["FailureReason"] + and "RuntimeEnvironmentError: " + in describe_training_job_response["FailureReason"] + ): + failure_msg = describe_training_job_response["FailureReason"].replace( + "RuntimeEnvironmentError: ", "" + ) + job_exception = RuntimeEnvironmentError(failure_msg) + else: + job_exception = RemoteFunctionError( + "Failed to execute remote function. 
" + + "Check corresponding job for details." + ) + else: + job_exception = serr + except DeserializationError as e: + client_exception = e + else: + state = _RUNNING + + future._job = job + future._state = state + future._exception = job_exception or client_exception + future._return = job_return + return future + + def _start_and_notify( + self, job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None + ): + """Start and record the newly created job in the future object. + + The job is recorded if one is successfully started. Otherwise, the exception is + recorded. The state update will be broadcast to other waiting threads. + """ + with self._condition: + if self._state in [_PENDING]: + + try: + self._job = _Job.start(job_settings, func, func_args, func_kwargs, run_info) + except (Exception,) as e: # pylint: disable=broad-except + self._exception = e + self._state = _FINISHED + self._condition.notify_all() + return None + + self._state = _RUNNING + self._condition.notify_all() + return self._job + return None + + def result(self, timeout: float = None) -> Any: + """Returns the function result. + + This method blocks on the sagemaker job completing for up to the timeout value (if + specified). If timeout is ``None``, this method will block until the job is completed. + Args: + timeout (float): Timeout in seconds to wait until the job is completed. ``None`` by + default. 
+ + Returns: + The Python object returned by the function + """ + try: + self.wait(timeout) + except UnexpectedStatusException: + pass + + with self._condition: + if self._state == _PENDING: + raise RuntimeError() + + if self._state == _RUNNING: + if self._job.describe()["TrainingJobStatus"] == "Completed": + self._return = serialization.deserialize_obj_from_s3( + sagemaker_session=self._job.sagemaker_session, + s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER), + ) + self._state = _FINISHED + return self._return + if self._job.describe()["TrainingJobStatus"] == "Failed": + try: + self._exception = serialization.deserialize_exception_from_s3( + sagemaker_session=self._job.sagemaker_session, + s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER), + ) + except ServiceError as serr: + chained_e = serr.__cause__ + if ( + isinstance(chained_e, ClientError) + and chained_e.response["Error"]["Code"] # pylint: disable=no-member + == "404" + and chained_e.response["Error"]["Message"] # pylint: disable=no-member + == "Not Found" + ): + if ( + "FailureReason" in self._job.describe() + and self._job.describe()["FailureReason"] + and "RuntimeEnvironmentError: " + in self._job.describe()["FailureReason"] + ): + failure_msg = self._job.describe()["FailureReason"].replace( + "RuntimeEnvironmentError: ", "" + ) + self._exception = RuntimeEnvironmentError(failure_msg) + else: + self._exception = RemoteFunctionError( + "Failed to execute remote function. " + + "Check corresponding job for details." + ) + else: + self._exception = serr + self._state = _FINISHED + elif self._job.describe()["TrainingJobStatus"] == "Stopped": + self._state = _CANCELLED + raise RemoteFunctionError("Job for remote function has been aborted.") + else: + raise TimeoutError( + "Job for remote function timed out before reaching a termination status." 
+ ) + + if self._state == _FINISHED: + if self._exception: + raise self._exception + return self._return + + return None + + def wait( + self, + timeout: int = None, + ) -> None: + """Wait for the underlying sagemaker job to complete. + + This method blocks on the sagemaker job completing for up to the timeout value (if + specified). If timeout is ``None``, this method will block until the job is completed. + Args: + timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by + default. + + Returns: None + """ + + with self._condition: + if self._state == _PENDING: + self._condition.wait(timeout=timeout) + + if self._state == _RUNNING: + self._job.wait(timeout=timeout) + + def cancel(self): + """Cancel the function execution. + + It prevents the SageMaker job being created or stops the underlying sagemaker job early + if it is already in progress. + + Returns: ``True`` if the underlying sagemaker job is cancelled. + """ + with self._condition: + if self._state == _FINISHED: + return False + if self._state == _CANCELLED: + return True + + if self._job: + self._job.stop() + self._state = _CANCELLED + return True + + def running(self): + """Returns ``True`` if the underlying sagemaker job is still running.""" + with self._condition: + return self._state == _RUNNING + + def cancelled(self): + """Returns ``True`` if the underlying sagemaker job was cancelled. ``False``, otherwise.""" + with self._condition: + return self._state == _CANCELLED + + def done(self): + """Returns ``True`` if the underlying sagemaker job finished running.""" + with self._condition: + if self._state == _RUNNING and self._job.describe()["TrainingJobStatus"] in [ + "Completed", + "Failed", + ]: + self._state = _FINISHED + return True + + if self._state == _FINISHED: + return True + + return False + + +def get_future(job_name, sagemaker_session=None): + """Get a future object with information about a job with the given job_name. 
+ + Args: + job_name (str): name of the underlying SageMaker job. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + + Returns: + A `sagemaker.remote_function.client.Future` instance. + """ + if not sagemaker_session: + sagemaker_session = Session() + describe_training_job_response = sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=job_name + ) + return Future.from_describe_response(describe_training_job_response, sagemaker_session) + + +def list_futures(job_name_prefix, sagemaker_session=None): + """Generates Future objects with information about jobs with given job_name_prefix. + + Args: + job_name_prefix (str): prefix used to identify relevant SageMaker jobs. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + + Yields: + A `sagemaker.remote_function.client.Future` instance. + """ + if not sagemaker_session: + sagemaker_session = Session() + job_name = name_from_base(job_name_prefix) + # perform the following transformation because we might have trimmed the job_name_prefix while + # creating the job. 
+ transformed_job_name_prefix = base_from_name(job_name) + next_token = None + list_training_job_kwargs = {"NameContains": transformed_job_name_prefix} + while True: + if next_token: + list_training_job_kwargs["NextToken"] = next_token + list_training_job_response = sagemaker_session.sagemaker_client.list_training_jobs( + **list_training_job_kwargs + ) + training_job_names = [ + job["TrainingJobName"] for job in list_training_job_response["TrainingJobSummaries"] + ] + for training_job_name in training_job_names: + describe_training_job_response = ( + sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=training_job_name + ) + ) + yield Future.from_describe_response(describe_training_job_response, sagemaker_session) + if "NextToken" in list_training_job_response: + next_token = list_training_job_response["NextToken"] + else: + break diff --git a/src/sagemaker/remote_function/core/serialization.py b/src/sagemaker/remote_function/core/serialization.py new file mode 100644 index 0000000000..29b7f18bb1 --- /dev/null +++ b/src/sagemaker/remote_function/core/serialization.py @@ -0,0 +1,271 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""SageMaker remote function data serializer/deserializer.""" +from __future__ import absolute_import + +import dataclasses +import json +import os +import sys + +import cloudpickle + +from typing import Any, Callable +from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError +from sagemaker.s3 import S3Downloader, S3Uploader +from tblib import pickling_support + + +def _get_python_version(): + return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + + +@dataclasses.dataclass +class _MetaData: + """Metadata about the serialized data or functions.""" + + version: str = "2023-04-24" + python_version: str = _get_python_version() + serialization_module: str = "cloudpickle" + + def to_json(self): + return json.dumps(dataclasses.asdict(self)).encode() + + @staticmethod + def from_json(s): + try: + obj = json.loads(s) + except json.decoder.JSONDecodeError: + raise DeserializationError("Corrupt metadata file. It is not a valid json file.") + + metadata = _MetaData() + metadata.version = obj.get("version") + metadata.python_version = obj.get("python_version") + metadata.serialization_module = obj.get("serialization_module") + + if not ( + metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle" + ): + raise DeserializationError( + f"Corrupt metadata file. Serialization approach {s} is not supported." + ) + + return metadata + + +class CloudpickleSerializer: + """Serializer using cloudpickle.""" + + @staticmethod + def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None): + """Serializes data object and uploads it to S3. + + Args: + sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service + calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. 
+ obj: object to be serialized and persisted + Raises: + SerializationError: when fail to serialize object to bytes. + """ + try: + bytes_to_upload = cloudpickle.dumps(obj) + except Exception as e: + if isinstance( + e, NotImplementedError + ) and "Instance of Run type is not allowed to be pickled." in str(e): + raise SerializationError( + """You are trying to pass a sagemaker.experiments.run.Run object to a remote function + or are trying to access a global sagemaker.experiments.run.Run object from within the function. + This is not supported. You must use `load_run` to load an existing Run in the remote function + or instantiate a new Run in the function.""" + ) from e + + raise SerializationError( + "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e)) + ) from e + + _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session) + + @staticmethod + def deserialize(sagemaker_session, s3_uri) -> Any: + """Downloads from S3 and then deserializes data objects. + + Args: + sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service + calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + Returns : + List of deserialized python objects. + Raises: + DeserializationError: when fail to serialize object to bytes. + """ + bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session) + + try: + return cloudpickle.loads(bytes_to_deserialize) + except Exception as e: + raise DeserializationError( + "Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e)) + ) from e + + +# TODO: use dask serializer in case dask distributed is installed in users' environment. +def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None): + """Serializes function and uploads it to S3. 
+ + Args: + sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service + calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + func: function to be serialized and persisted + Raises: + SerializationError: when fail to serialize function to bytes. + """ + + _upload_bytes_to_s3( + _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session + ) + CloudpickleSerializer.serialize( + func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key + ) + + +def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable: + """Downloads from S3 and then deserializes data objects. + + This method downloads the serialized training job outputs to a temporary directory and + then deserializes them using dask. + + Args: + sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service + calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + Returns : + The deserialized function. + Raises: + DeserializationError: when fail to serialize function to bytes. + """ + _MetaData.from_json( + _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session) + ) + + return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl")) + + +def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None): + """Serializes data object and uploads it to S3. + + Args: + sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service + calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. 
+ obj: object to be serialized and persisted + Raises: + SerializationError: when fail to serialize object to bytes. + """ + + _upload_bytes_to_s3( + _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session + ) + CloudpickleSerializer.serialize( + obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key + ) + + +def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any: + """Downloads from S3 and then deserializes data objects. + + Args: + sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service + calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + Returns : + Deserialized python objects. + Raises: + DeserializationError: when fail to serialize object to bytes. + """ + + _MetaData.from_json( + _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session) + ) + + return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl")) + + +def serialize_exception_to_s3( + exc: Exception, sagemaker_session, s3_uri: str, s3_kms_key: str = None +): + """Serializes exception with traceback and uploads it to S3. + + Args: + sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service + calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + exc: Exception to be serialized and persisted + Raises: + SerializationError: when fail to serialize object to bytes. 
+ """ + pickling_support.install() + _upload_bytes_to_s3( + _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session + ) + CloudpickleSerializer.serialize( + exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key + ) + + +def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any: + """Downloads from S3 and then deserializes exception. + + Args: + sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service + calls are delegated to. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + Returns : + Deserialized exception with traceback. + Raises: + DeserializationError: when fail to serialize object to bytes. + """ + + _MetaData.from_json( + _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session) + ) + + return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl")) + + +def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session): + """Wrapping s3 uploading with exception translation for remote function.""" + try: + S3Uploader.upload_bytes( + bytes, s3_uri, kms_key=s3_kms_key, sagemaker_session=sagemaker_session + ) + except Exception as e: + raise ServiceError( + "Failed to upload serialized bytes to {}: {}".format(s3_uri, repr(e)) + ) from e + + +def _read_bytes_from_s3(s3_uri, sagemaker_session): + """Wrapping s3 downloading with exception translation for remote function.""" + try: + return S3Downloader.read_bytes(s3_uri, sagemaker_session=sagemaker_session) + except Exception as e: + raise ServiceError( + "Failed to read serialized bytes from {}: {}".format(s3_uri, repr(e)) + ) from e diff --git a/src/sagemaker/remote_function/core/stored_function.py b/src/sagemaker/remote_function/core/stored_function.py new file mode 100644 index 0000000000..0204cf3e51 --- /dev/null +++ b/src/sagemaker/remote_function/core/stored_function.py @@ -0,0 +1,105 @@ +# Copyright 
Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""SageMaker job function serializer/deserializer.""" +from __future__ import absolute_import + +from sagemaker.s3 import s3_path_join +from sagemaker.remote_function import logging_config + +import sagemaker.remote_function.core.serialization as serialization + + +logger = logging_config.get_logger() + + +FUNCTION_FOLDER = "function" +ARGUMENTS_FOLDER = "arguments" +RESULTS_FOLDER = "results" +EXCEPTION_FOLDER = "exception" + + +class StoredFunction: + """Class representing a remote function stored in S3.""" + + def __init__(self, sagemaker_session, s3_base_uri, s3_kms_key=None): + """Construct a StoredFunction object. + + Args: + sagemaker_session: (sagemaker.session.Session): The underlying sagemaker session which + AWS service calls are delegated to. + s3_base_uri: the base uri to which serialized artifacts will be uploaded. + s3_kms_key: KMS key used to encrypt artifacts uploaded to S3. + """ + self.sagemaker_session = sagemaker_session + self.s3_base_uri = s3_base_uri + self.s3_kms_key = s3_kms_key + + def save(self, func, *args, **kwargs): + """Serialize and persist the function and arguments. + + Args: + func: the python function. + args: the positional arguments to func. + kwargs: the keyword arguments to func. 
+ Returns: + None + """ + logger.info( + f"Serializing function code to {s3_path_join(self.s3_base_uri, FUNCTION_FOLDER)}" + ) + serialization.serialize_func_to_s3( + func, + self.sagemaker_session, + s3_path_join(self.s3_base_uri, FUNCTION_FOLDER), + self.s3_kms_key, + ) + + logger.info( + f"Serializing function arguments to {s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER)}" + ) + serialization.serialize_obj_to_s3( + (args, kwargs), + self.sagemaker_session, + s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER), + self.s3_kms_key, + ) + + def load_and_invoke(self) -> None: + """Load and deserialize the function and the arguments and then execute it.""" + + logger.info( + f"Deserializing function code from {s3_path_join(self.s3_base_uri, FUNCTION_FOLDER)}" + ) + func = serialization.deserialize_func_from_s3( + self.sagemaker_session, s3_path_join(self.s3_base_uri, FUNCTION_FOLDER) + ) + + logger.info( + f"Deserializing function arguments from {s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER)}" + ) + args, kwargs = serialization.deserialize_obj_from_s3( + self.sagemaker_session, s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER) + ) + + logger.info("Invoking the function") + result = func(*args, **kwargs) + + logger.info( + f"Serializing the function return and uploading to {s3_path_join(self.s3_base_uri, RESULTS_FOLDER)}" + ) + serialization.serialize_obj_to_s3( + result, + self.sagemaker_session, + s3_path_join(self.s3_base_uri, RESULTS_FOLDER), + self.s3_kms_key, + ) diff --git a/src/sagemaker/remote_function/errors.py b/src/sagemaker/remote_function/errors.py new file mode 100644 index 0000000000..b0f1f7031c --- /dev/null +++ b/src/sagemaker/remote_function/errors.py @@ -0,0 +1,99 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Definitions for remote job errors and error handling""" +from __future__ import absolute_import + +import os + +from tblib import pickling_support +from sagemaker.s3 import s3_path_join +import sagemaker.remote_function.core.serialization as serialization + + +DEFAULT_FAILURE_CODE = 1 +FAILURE_REASON_PATH = "/opt/ml/output/failure" + + +@pickling_support.install +class RemoteFunctionError(Exception): + """The base exception class for remote function exceptions""" + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +@pickling_support.install +class ServiceError(RemoteFunctionError): + """Raised when errors encountered during interaction with SageMaker, S3 service APIs""" + + +@pickling_support.install +class SerializationError(RemoteFunctionError): + """Raised when errors encountered during serialization of remote function objects""" + + +@pickling_support.install +class DeserializationError(RemoteFunctionError): + """Raised when errors encountered during deserialization of remote function objects""" + + +def _get_valid_failure_exit_code(exit_code) -> int: + """Normalize exit code for terminating the process""" + try: + valid_exit_code = int(exit_code) + except (TypeError, ValueError): + valid_exit_code = DEFAULT_FAILURE_CODE + + return valid_exit_code + + +def _write_failure_reason_file(failure_msg): + """Create a file 'failure' with failure reason written if remote function execution failed. + + See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html + Args: + failure_msg: The content of file to be written. 
+ """ + if not os.path.exists(FAILURE_REASON_PATH): + with open(FAILURE_REASON_PATH, "w") as f: + f.write(failure_msg) + + +def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: + """Handle all exceptions raised during remote function execution. + + Args: + error (Exception): The error to be handled. + sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which + AWS service calls are delegated to. + s3_base_uri (str): S3 root uri to which resulting serialized exception will be uploaded. + s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + Returns : + exit_code (int): Exit code to terminate current job. + """ + + failure_reason = repr(error) + if isinstance(error, RemoteFunctionError): + exit_code = DEFAULT_FAILURE_CODE + else: + error_number = getattr(error, "errno", DEFAULT_FAILURE_CODE) + exit_code = _get_valid_failure_exit_code(error_number) + + _write_failure_reason_file(failure_reason) + + serialization.serialize_exception_to_s3( + error, sagemaker_session, s3_path_join(s3_base_uri, "exception"), s3_kms_key + ) + + return exit_code diff --git a/src/sagemaker/remote_function/invoke_function.py b/src/sagemaker/remote_function/invoke_function.py new file mode 100644 index 0000000000..66c866a1b0 --- /dev/null +++ b/src/sagemaker/remote_function/invoke_function.py @@ -0,0 +1,103 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""An entry point for invoking remote function inside a job.""" + +from __future__ import absolute_import + +import argparse +import sys +import json + +import boto3 +from sagemaker.experiments.run import Run +from sagemaker.remote_function.job import ( + KEY_EXPERIMENT_NAME, + KEY_RUN_NAME, +) + +from sagemaker.session import Session +from sagemaker.remote_function.errors import handle_error +from sagemaker.remote_function import logging_config + + +SUCCESS_EXIT_CODE = 0 + + +def _parse_agrs(): + """Parses CLI arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--region", type=str, required=True) + parser.add_argument("--s3_base_uri", type=str, required=True) + parser.add_argument("--s3_kms_key", type=str) + parser.add_argument("--run_in_context", type=str) + + args, _ = parser.parse_known_args() + return args + + +def _get_sagemaker_session(region): + """Get sagemaker session for interacting with AWS or Sagemaker services""" + boto_session = boto3.session.Session(region_name=region) + return Session(boto_session=boto_session) + + +def _load_run_object(run_in_context: str, sagemaker_session: Session) -> Run: + """Load current run in json string into run object""" + run_dict = json.loads(run_in_context) + return Run( + experiment_name=run_dict.get(KEY_EXPERIMENT_NAME), + run_name=run_dict.get(KEY_RUN_NAME), + sagemaker_session=sagemaker_session, + ) + + +def _execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key, run_in_context): + """Execute stored remote function""" + from sagemaker.remote_function.core.stored_function import StoredFunction + + stored_function = StoredFunction(sagemaker_session, s3_base_uri, s3_kms_key) + + if run_in_context: + run_obj = _load_run_object(run_in_context, sagemaker_session) + with run_obj: + stored_function.load_and_invoke() + else: + stored_function.load_and_invoke() + + +def main(): + """Entry point for invoke function script""" + + logger = logging_config.get_logger() + + exit_code = 
SUCCESS_EXIT_CODE + + try: + args = _parse_agrs() + region = args.region + s3_base_uri = args.s3_base_uri + s3_kms_key = args.s3_kms_key + run_in_context = args.run_in_context + + sagemaker_session = _get_sagemaker_session(region) + _execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key, run_in_context) + + except Exception as e: # pylint: disable=broad-except + logger.exception("Error encountered while invoking the remote function.") + exit_code = handle_error(e, sagemaker_session, s3_base_uri, s3_kms_key) + finally: + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py new file mode 100644 index 0000000000..04ebfada13 --- /dev/null +++ b/src/sagemaker/remote_function/job.py @@ -0,0 +1,653 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Helper classes that interact with SageMaker Training service.""" +from __future__ import absolute_import +import dataclasses + +import os +import re +import shutil +import sys +import json +from typing import Dict, List, Tuple + +from sagemaker.config.config_schema import ( + REMOTE_FUNCTION_ENVIRONMENT_VARIABLES, + REMOTE_FUNCTION_IMAGE_URI, + REMOTE_FUNCTION_DEPENDENCIES, + REMOTE_FUNCTION_PRE_EXECUTION_COMMANDS, + REMOTE_FUNCTION_PRE_EXECUTION_SCRIPT, + REMOTE_FUNCTION_INCLUDE_LOCAL_WORKDIR, + REMOTE_FUNCTION_INSTANCE_TYPE, + REMOTE_FUNCTION_JOB_CONDA_ENV, + REMOTE_FUNCTION_ROLE_ARN, + REMOTE_FUNCTION_S3_ROOT_URI, + REMOTE_FUNCTION_S3_KMS_KEY_ID, + REMOTE_FUNCTION_VOLUME_KMS_KEY_ID, + REMOTE_FUNCTION_TAGS, + REMOTE_FUNCTION_VPC_CONFIG_SUBNETS, + REMOTE_FUNCTION_VPC_CONFIG_SECURITY_GROUP_IDS, + REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, +) +from sagemaker.experiments._run_context import _RunContext +from sagemaker.experiments.run import Run +from sagemaker.image_uris import get_base_python_image_uri +from sagemaker.session import get_execution_role, _logs_for_job, Session +from sagemaker.utils import name_from_base, _tmpdir, resolve_value_from_config +from sagemaker.s3 import s3_path_join, S3Uploader +from sagemaker import vpc_utils +from sagemaker.remote_function.core.stored_function import StoredFunction +from sagemaker.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentManager, +) +from sagemaker.remote_function import logging_config + +# runtime script names +BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py" +ENTRYPOINT_SCRIPT_NAME = "job_driver.sh" +PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh" +RUNTIME_MANAGER_SCRIPT_NAME = "runtime_environment_manager.py" + +# training channel names +RUNTIME_SCRIPTS_CHANNEL_NAME = "sagemaker_remote_function_bootstrap" +REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws" +JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace" + +# run context 
dictionary keys +KEY_EXPERIMENT_NAME = "experiment_name" +KEY_RUN_NAME = "run_name" + +JOBS_CONTAINER_ENTRYPOINT = [ + "/bin/bash", + f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", +] + +ENTRYPOINT_SCRIPT = f""" +#!/bin/bash + +# Entry point for bootstrapping runtime environment and invoking remote function + +set -eu + +PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}} +export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs +printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n" +export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip +printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" + + +printf "INFO: Bootstraping runtime environment.\\n" +python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" + +if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] +then + if [ -f "remote_function_conda_env.txt" ] + then + cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt + fi + printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n" + cd {JOB_REMOTE_FUNCTION_WORKSPACE} +fi + +if [ -f "remote_function_conda_env.txt" ] +then + conda_env=$(cat remote_function_conda_env.txt) + + if which mamba >/dev/null; then + conda_exe="mamba" + else + conda_exe="conda" + fi + + printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n" + $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@" +else + printf "INFO: No conda env provided. Invoking remote function\\n" + python -m sagemaker.remote_function.invoke_function "$@" +fi +""" + +logger = logging_config.get_logger() + + +class _JobSettings: + """Helper class that processes the job settings. + + It validates the job settings and provides default values if necessary. 
+ """ + + def __init__( + self, + *, + dependencies: str = None, + pre_execution_commands: List[str] = None, + pre_execution_script: str = None, + environment_variables: Dict[str, str] = None, + image_uri: str = None, + include_local_workdir: bool = None, + instance_count: int = 1, + instance_type: str = None, + job_conda_env: str = None, + job_name_prefix: str = None, + keep_alive_period_in_seconds: int = 0, + max_retry_attempts: int = 1, + max_runtime_in_seconds: int = 24 * 60 * 60, + role: str = None, + s3_kms_key: str = None, + s3_root_uri: str = None, + sagemaker_session: Session = None, + security_group_ids: List[str] = None, + subnets: List[str] = None, + tags: List[Tuple[str, str]] = None, + volume_kms_key: str = None, + volume_size: int = 30, + encrypt_inter_container_traffic: bool = None, + ): + + self.sagemaker_session = sagemaker_session or Session() + + self.environment_variables = resolve_value_from_config( + direct_input=environment_variables, + config_path=REMOTE_FUNCTION_ENVIRONMENT_VARIABLES, + default_value={}, + sagemaker_session=self.sagemaker_session, + ) + self.environment_variables.update( + {"AWS_DEFAULT_REGION": self.sagemaker_session.boto_region_name} + ) + + _image_uri = resolve_value_from_config( + direct_input=image_uri, + config_path=REMOTE_FUNCTION_IMAGE_URI, + sagemaker_session=self.sagemaker_session, + ) + if _image_uri: + self.image_uri = _image_uri + else: + self.image_uri = self._get_default_image(self.sagemaker_session) + + self.dependencies = resolve_value_from_config( + direct_input=dependencies, + config_path=REMOTE_FUNCTION_DEPENDENCIES, + sagemaker_session=self.sagemaker_session, + ) + + self.pre_execution_commands = resolve_value_from_config( + direct_input=pre_execution_commands, + config_path=REMOTE_FUNCTION_PRE_EXECUTION_COMMANDS, + sagemaker_session=self.sagemaker_session, + ) + + self.pre_execution_script = resolve_value_from_config( + direct_input=pre_execution_script, + 
config_path=REMOTE_FUNCTION_PRE_EXECUTION_SCRIPT, + sagemaker_session=self.sagemaker_session, + ) + + if self.pre_execution_commands is not None and self.pre_execution_script is not None: + raise ValueError( + "Only one of pre_execution_commands or pre_execution_script can be specified!" + ) + + self.include_local_workdir = resolve_value_from_config( + direct_input=include_local_workdir, + config_path=REMOTE_FUNCTION_INCLUDE_LOCAL_WORKDIR, + default_value=False, + sagemaker_session=self.sagemaker_session, + ) + self.instance_type = resolve_value_from_config( + direct_input=instance_type, + config_path=REMOTE_FUNCTION_INSTANCE_TYPE, + sagemaker_session=self.sagemaker_session, + ) + if not self.instance_type: + raise ValueError("instance_type is a required parameter!") + + self.instance_count = instance_count + self.volume_size = volume_size + self.max_runtime_in_seconds = max_runtime_in_seconds + self.max_retry_attempts = max_retry_attempts + self.keep_alive_period_in_seconds = keep_alive_period_in_seconds + self.job_conda_env = resolve_value_from_config( + direct_input=job_conda_env, + config_path=REMOTE_FUNCTION_JOB_CONDA_ENV, + sagemaker_session=self.sagemaker_session, + ) + self.job_name_prefix = job_name_prefix + self.encrypt_inter_container_traffic = resolve_value_from_config( + direct_input=encrypt_inter_container_traffic, + config_path=REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION, + default_value=False, + sagemaker_session=self.sagemaker_session, + ) + self.enable_network_isolation = False + + _role = resolve_value_from_config( + direct_input=role, + config_path=REMOTE_FUNCTION_ROLE_ARN, + sagemaker_session=self.sagemaker_session, + ) + if _role: + self.role = self.sagemaker_session.expand_role(_role) + else: + self.role = get_execution_role(self.sagemaker_session) + + self.s3_root_uri = resolve_value_from_config( + direct_input=s3_root_uri, + config_path=REMOTE_FUNCTION_S3_ROOT_URI, + default_value=os.path.join("s3://", 
self.sagemaker_session.default_bucket()), + sagemaker_session=self.sagemaker_session, + ) + + self.s3_kms_key = resolve_value_from_config( + direct_input=s3_kms_key, + config_path=REMOTE_FUNCTION_S3_KMS_KEY_ID, + sagemaker_session=self.sagemaker_session, + ) + self.volume_kms_key = resolve_value_from_config( + direct_input=volume_kms_key, + config_path=REMOTE_FUNCTION_VOLUME_KMS_KEY_ID, + sagemaker_session=self.sagemaker_session, + ) + + _subnets = resolve_value_from_config( + direct_input=subnets, + config_path=REMOTE_FUNCTION_VPC_CONFIG_SUBNETS, + sagemaker_session=self.sagemaker_session, + ) + _security_group_ids = resolve_value_from_config( + direct_input=security_group_ids, + config_path=REMOTE_FUNCTION_VPC_CONFIG_SECURITY_GROUP_IDS, + sagemaker_session=self.sagemaker_session, + ) + vpc_config = vpc_utils.to_dict(subnets=_subnets, security_group_ids=_security_group_ids) + self.vpc_config = vpc_utils.sanitize(vpc_config) + + self.tags = self.sagemaker_session._append_sagemaker_config_tags( + [{"Key": k, "Value": v} for k, v in tags] if tags else None, REMOTE_FUNCTION_TAGS + ) + + @staticmethod + def _get_default_image(session): + """Return Studio notebook image, if in Studio env. Else, base python""" + + if ( + "SAGEMAKER_INTERNAL_IMAGE_URI" in os.environ + and os.environ["SAGEMAKER_INTERNAL_IMAGE_URI"] + ): + return os.environ["SAGEMAKER_INTERNAL_IMAGE_URI"] + + py_version = str(sys.version_info[0]) + str(sys.version_info[1]) + + if py_version not in ["310", "38"]: + raise ValueError( + "Default image is supported only for Python versions 3.8 and 3.10. If you " + "are using any other python version, you must provide a compatible image_uri." 
+ ) + + region = session.boto_region_name + image_uri = get_base_python_image_uri(region=region, py_version=py_version) + + return image_uri + + +class _Job: + """Helper class that interacts with the SageMaker training service.""" + + def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session): + """Initialize a _Job object.""" + self.job_name = job_name + self.s3_uri = s3_uri + self.sagemaker_session = sagemaker_session + self._last_describe_response = None + + @staticmethod + def from_describe_response(describe_training_job_response, sagemaker_session): + """Construct a _Job from a describe_training_job_response object.""" + job_name = describe_training_job_response["TrainingJobName"] + s3_uri = describe_training_job_response["OutputDataConfig"]["S3OutputPath"] + job = _Job(job_name, s3_uri, sagemaker_session) + job._last_describe_response = describe_training_job_response + return job + + @staticmethod + def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None): + """Start a training job. + + Args: + job_settings (_JobSettings): the job settings. + func: the function to be executed. + func_args: the positional arguments to the function. + func_kwargs: the keyword arguments to the function + + Returns: the _Job object. 
+ """ + job_name = _Job._get_job_name(job_settings, func) + s3_base_uri = s3_path_join(job_settings.s3_root_uri, job_name) + + bootstrap_scripts_s3uri = _prepare_and_upload_runtime_scripts( + s3_base_uri=s3_base_uri, + s3_kms_key=job_settings.s3_kms_key, + sagemaker_session=job_settings.sagemaker_session, + ) + + dependencies_list_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies) + user_dependencies_s3uri = _prepare_and_upload_dependencies( + local_dependencies_path=dependencies_list_path, + include_local_workdir=job_settings.include_local_workdir, + pre_execution_commands=job_settings.pre_execution_commands, + pre_execution_script_local_path=job_settings.pre_execution_script, + s3_base_uri=s3_base_uri, + s3_kms_key=job_settings.s3_kms_key, + sagemaker_session=job_settings.sagemaker_session, + ) + + stored_function = StoredFunction( + sagemaker_session=job_settings.sagemaker_session, + s3_base_uri=s3_base_uri, + s3_kms_key=job_settings.s3_kms_key, + ) + + stored_function.save(func, *func_args, **func_kwargs) + + request_dict = dict( + TrainingJobName=job_name, + RoleArn=job_settings.role, + StoppingCondition={ + "MaxRuntimeInSeconds": job_settings.max_runtime_in_seconds, + }, + RetryStrategy={"MaximumRetryAttempts": job_settings.max_retry_attempts}, + ) + + if job_settings.tags: + request_dict["Tags"] = job_settings.tags + + input_data_config = [ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": bootstrap_scripts_s3uri, + "S3DataType": "S3Prefix", + } + }, + ) + ] + + if user_dependencies_s3uri: + input_data_config.append( + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": s3_path_join(s3_base_uri, REMOTE_FUNCTION_WORKSPACE), + "S3DataType": "S3Prefix", + } + }, + ) + ) + + request_dict["InputDataConfig"] = input_data_config + + output_config = {"S3OutputPath": s3_base_uri} + if job_settings.s3_kms_key is not None: + output_config["KmsKeyId"] = 
job_settings.s3_kms_key + request_dict["OutputDataConfig"] = output_config + + container_args = ["--s3_base_uri", s3_base_uri] + container_args.extend(["--region", job_settings.sagemaker_session.boto_region_name]) + container_args.extend( + ["--client_python_version", RuntimeEnvironmentManager()._current_python_version()] + ) + if job_settings.s3_kms_key: + container_args.extend(["--s3_kms_key", job_settings.s3_kms_key]) + + if job_settings.job_conda_env: + container_args.extend(["--job_conda_env", job_settings.job_conda_env]) + + if run_info is not None: + container_args.extend(["--run_in_context", json.dumps(dataclasses.asdict(run_info))]) + elif _RunContext.get_current_run() is not None: + container_args.extend( + ["--run_in_context", _convert_run_to_json(_RunContext.get_current_run())] + ) + + algorithm_spec = dict( + TrainingImage=job_settings.image_uri, + TrainingInputMode="File", + ContainerEntrypoint=JOBS_CONTAINER_ENTRYPOINT, + ContainerArguments=container_args, + ) + + request_dict["AlgorithmSpecification"] = algorithm_spec + + resource_config = dict( + VolumeSizeInGB=job_settings.volume_size, + InstanceCount=job_settings.instance_count, + InstanceType=job_settings.instance_type, + ) + if job_settings.volume_kms_key is not None: + resource_config["VolumeKmsKeyId"] = job_settings.volume_kms_key + if job_settings.keep_alive_period_in_seconds is not None: + resource_config["KeepAlivePeriodInSeconds"] = job_settings.keep_alive_period_in_seconds + + request_dict["ResourceConfig"] = resource_config + + if job_settings.enable_network_isolation is not None: + request_dict["EnableNetworkIsolation"] = job_settings.enable_network_isolation + + if job_settings.encrypt_inter_container_traffic is not None: + request_dict[ + "EnableInterContainerTrafficEncryption" + ] = job_settings.encrypt_inter_container_traffic + + if job_settings.vpc_config: + request_dict["VpcConfig"] = job_settings.vpc_config + + if job_settings.environment_variables: + request_dict["Environment"] 
= job_settings.environment_variables + + logger.info("Creating job: %s", job_name) + job_settings.sagemaker_session.sagemaker_client.create_training_job(**request_dict) + + return _Job(job_name, s3_base_uri, job_settings.sagemaker_session) + + def describe(self): + """Describe the underlying sagemaker training job.""" + if self._last_describe_response is not None and self._last_describe_response[ + "TrainingJobStatus" + ] in ["Completed", "Failed", "Stopped"]: + return self._last_describe_response + + self._last_describe_response = ( + self.sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=self.job_name + ) + ) + + return self._last_describe_response + + def stop(self): + """Stop the underlying sagemaker training job.""" + self.sagemaker_session.sagemaker_client.stop_training_job(TrainingJobName=self.job_name) + + def wait(self, timeout: int = None): + """Wait for the underlying sagemaker job to finish and displays its logs . + + This method blocks on the sagemaker job completing for up to the timeout value (if + specified). If timeout is ``None``, this method will block until the job is completed. + + Args: + timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by + default. 
+ + Returns: None + """ + + self._last_describe_response = _logs_for_job( + boto_session=self.sagemaker_session.boto_session, + job_name=self.job_name, + wait=True, + timeout=timeout, + ) + + @staticmethod + def _get_job_name(job_settings, func): + """Get the underlying SageMaker job name from job_name_prefix or func.""" + job_name_prefix = job_settings.job_name_prefix + if not job_name_prefix: + job_name_prefix = func.__name__ + # remove all special characters in the beginning of function name + job_name_prefix = re.sub(r"^[^a-zA-Z0-9]+", "", job_name_prefix) + # convert all remaining special characters to '-' + job_name_prefix = re.sub(r"[^a-zA-Z0-9-]", "-", job_name_prefix) + return name_from_base(job_name_prefix) + + +def _prepare_and_upload_runtime_scripts( + s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session +): + """Copy runtime scripts to a folder and upload to S3""" + + with _tmpdir() as bootstrap_scripts: + + # write entrypoint script to tmpdir + entrypoint_script_path = os.path.join(bootstrap_scripts, ENTRYPOINT_SCRIPT_NAME) + with open(entrypoint_script_path, "w") as file: + file.writelines(ENTRYPOINT_SCRIPT) + + bootstrap_script_path = os.path.join( + os.path.dirname(__file__), "runtime_environment", BOOTSTRAP_SCRIPT_NAME + ) + runtime_manager_script_path = os.path.join( + os.path.dirname(__file__), "runtime_environment", RUNTIME_MANAGER_SCRIPT_NAME + ) + + # copy runtime scripts to tmpdir + shutil.copy2(bootstrap_script_path, bootstrap_scripts) + shutil.copy2(runtime_manager_script_path, bootstrap_scripts) + + return S3Uploader.upload( + bootstrap_scripts, + s3_path_join(s3_base_uri, RUNTIME_SCRIPTS_CHANNEL_NAME), + s3_kms_key, + sagemaker_session, + ) + + +def _prepare_and_upload_dependencies( + local_dependencies_path: str, + include_local_workdir: bool, + pre_execution_commands: List[str], + pre_execution_script_local_path: str, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, +) -> str: + """Upload the job 
dependencies to S3 if present""" + + if not ( + local_dependencies_path + or include_local_workdir + or pre_execution_commands + or pre_execution_script_local_path + ): + return None + + with _tmpdir() as tmp_dir: + tmp_workspace_dir = os.path.join(tmp_dir, "temp_workspace/") + os.mkdir(tmp_workspace_dir) + # TODO Remove the following hack to avoid dir_exists error in the copy_tree call below. + tmp_workspace = os.path.join(tmp_workspace_dir, JOB_REMOTE_FUNCTION_WORKSPACE) + + if include_local_workdir: + shutil.copytree( + os.getcwd(), + tmp_workspace, + ignore=_filter_non_python_files, + ) + logger.info("Copied user workspace python scripts to '%s'", tmp_workspace) + + if local_dependencies_path: + if not os.path.isdir(tmp_workspace): + # create the directory if no workdir_path was provided in the input. + os.mkdir(tmp_workspace) + dst_path = shutil.copy2(local_dependencies_path, tmp_workspace) + logger.info( + "Copied dependencies file at '%s' to '%s'", local_dependencies_path, dst_path + ) + + if pre_execution_commands or pre_execution_script_local_path: + if not os.path.isdir(tmp_workspace): + os.mkdir(tmp_workspace) + pre_execution_script = os.path.join(tmp_workspace, PRE_EXECUTION_SCRIPT_NAME) + if pre_execution_commands: + with open(pre_execution_script, "w") as target_script: + commands = [cmd + "\n" for cmd in pre_execution_commands] + target_script.writelines(commands) + logger.info( + "Generated pre-execution script from commands to '%s'", pre_execution_script + ) + else: + shutil.copy(pre_execution_script_local_path, pre_execution_script) + logger.info( + "Copied pre-execution commands from script at '%s' to '%s'", + pre_execution_script_local_path, + pre_execution_script, + ) + + workspace_archive_path = os.path.join(tmp_dir, "workspace") + workspace_archive_path = shutil.make_archive( + workspace_archive_path, "zip", tmp_workspace_dir + ) + logger.info("Successfully created workdir archive at '%s'", workspace_archive_path) + + upload_path = 
S3Uploader.upload( + workspace_archive_path, + s3_path_join(s3_base_uri, REMOTE_FUNCTION_WORKSPACE), + s3_kms_key, + sagemaker_session, + ) + logger.info("Successfully uploaded workdir to '%s'", upload_path) + return upload_path + + +def _convert_run_to_json(run: Run) -> str: + """Convert current run into json string""" + run_info = _RunInfo(run.experiment_name, run.run_name) + return json.dumps(dataclasses.asdict(run_info)) + + +def _filter_non_python_files(path: str, names: List) -> List: + """Ignore function for filtering out non python files.""" + to_ignore = [] + for name in names: + full_path = os.path.join(path, name) + if os.path.isfile(full_path): + if not name.endswith(".py"): + to_ignore.append(name) + elif os.path.isdir(full_path): + if name == "__pycache__": + to_ignore.append(name) + else: + to_ignore.append(name) + + return to_ignore + + +@dataclasses.dataclass +class _RunInfo: + """Data class to hold information of the run object from context.""" + + experiment_name: str + run_name: str diff --git a/src/sagemaker/remote_function/logging_config.py b/src/sagemaker/remote_function/logging_config.py new file mode 100644 index 0000000000..875fabf6e0 --- /dev/null +++ b/src/sagemaker/remote_function/logging_config.py @@ -0,0 +1,38 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Utilities related to logging.""" +from __future__ import absolute_import + +import logging +import time + + +class _UTCFormatter(logging.Formatter): + """Class that overrides the default local time provider in log formatter.""" + + converter = time.gmtime + + +def get_logger(): + """Return a logger with the name 'sagemaker'""" + sagemaker_logger = logging.getLogger("sagemaker.remote_function") + if len(sagemaker_logger.handlers) == 0: + sagemaker_logger.setLevel(logging.INFO) + handler = logging.StreamHandler() + formatter = _UTCFormatter("%(asctime)s %(name)s %(levelname)-8s %(message)s") + handler.setFormatter(formatter) + sagemaker_logger.addHandler(handler) + # don't stream logs with the root logger handler + sagemaker_logger.propagate = 0 + + return sagemaker_logger diff --git a/src/sagemaker/remote_function/runtime_environment/__init__.py b/src/sagemaker/remote_function/runtime_environment/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py new file mode 100644 index 0000000000..6655e1febf --- /dev/null +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -0,0 +1,147 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""An entry point for runtime environment. 
This must be kept independent of SageMaker PySDK""" +from __future__ import absolute_import + +import argparse +import sys +import os +import shutil +import pathlib + +if __package__ is None or __package__ == "": + from runtime_environment_manager import RuntimeEnvironmentManager, get_logger +else: + from sagemaker.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentManager, + get_logger, + ) + +SUCCESS_EXIT_CODE = 0 +DEFAULT_FAILURE_CODE = 1 + +REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws" +BASE_CHANNEL_PATH = "/opt/ml/input/data" +FAILURE_REASON_PATH = "/opt/ml/output/failure" +PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh" +JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace" + + +logger = get_logger() + + +def main(): + """Entry point for bootstrap script""" + + exit_code = DEFAULT_FAILURE_CODE + + try: + args = _parse_agrs() + client_python_version = args.client_python_version + job_conda_env = args.job_conda_env + + conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") + + RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) + + _bootstrap_runtime_environment(client_python_version, conda_env) + + exit_code = SUCCESS_EXIT_CODE + except Exception as e: # pylint: disable=broad-except + logger.exception("Error encountered while bootstrapping runtime environment: %s", e) + + _write_failure_reason_file(str(e)) + finally: + sys.exit(exit_code) + + +def _bootstrap_runtime_environment( + client_python_version: str, + conda_env: str = None, +): + """Bootstrap runtime environment for remote function invocation + + Args: + conda_env (str): conda environment to be activated. Default is None. + """ + workspace_archive_dir_path = os.path.join(BASE_CHANNEL_PATH, REMOTE_FUNCTION_WORKSPACE) + + if not os.path.exists(workspace_archive_dir_path): + logger.info( + "Directory '%s' does not exist. 
Assuming no dependencies to bootstrap.", + workspace_archive_dir_path, + ) + return + + # Unpack user workspace archive first. + workspace_archive_path = os.path.join(workspace_archive_dir_path, "workspace.zip") + if not os.path.isfile(workspace_archive_path): + logger.info( + "Workspace archive '%s' does not exist. Assuming no dependencies to bootstrap.", + workspace_archive_dir_path, + ) + return + + workspace_unpack_dir = pathlib.Path(os.getcwd()).absolute() + shutil.unpack_archive(filename=workspace_archive_path, extract_dir=workspace_unpack_dir) + logger.info("Successfully unpacked workspace archive at '%s'.", workspace_unpack_dir) + workspace_unpack_dir = pathlib.Path(workspace_unpack_dir, JOB_REMOTE_FUNCTION_WORKSPACE) + + # Handle pre-execution commands + path_to_pre_exec_script = os.path.join(workspace_unpack_dir, PRE_EXECUTION_SCRIPT_NAME) + RuntimeEnvironmentManager().run_pre_exec_script(pre_exec_script_path=path_to_pre_exec_script) + + # Handle dependencies file. + dependencies_file = None + for file in os.listdir(workspace_unpack_dir): + if file.endswith(".txt") or file.endswith(".yml") or file.endswith(".yaml"): + dependencies_file = os.path.join(workspace_unpack_dir, file) + break + + if dependencies_file: + RuntimeEnvironmentManager().bootstrap( + local_dependencies_file=dependencies_file, + conda_env=conda_env, + client_python_version=client_python_version, + ) + else: + logger.info( + "Did not find any dependency file in workspace directory at '%s'." + " Assuming no additional dependencies to install.", + workspace_archive_dir_path, + ) + + +def _write_failure_reason_file(failure_msg): + """Create a file 'failure' with failure reason written if bootstrap runtime env failed. + + See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html + Args: + failure_msg: The content of file to be written. 
+ """ + if not os.path.exists(FAILURE_REASON_PATH): + with open(FAILURE_REASON_PATH, "w") as f: + f.write("RuntimeEnvironmentError: " + failure_msg) + + +def _parse_agrs(): + """Parses CLI arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--job_conda_env", type=str) + parser.add_argument("--client_python_version") + args, _ = parser.parse_known_args() + return args + + +if __name__ == "__main__": + main() diff --git a/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py new file mode 100644 index 0000000000..8a7ee8f686 --- /dev/null +++ b/src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py @@ -0,0 +1,366 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""SageMaker runtime environment module. 
This must be kept independent of SageMaker PySDK""" + +from __future__ import absolute_import + + +import logging +import sys +import shlex +import os +import subprocess +import time + + +class _UTCFormatter(logging.Formatter): + """Class that overrides the default local time provider in log formatter.""" + + converter = time.gmtime + + +def get_logger(): + """Return a logger with the name 'sagemaker'""" + sagemaker_logger = logging.getLogger("sagemaker.remote_function") + if len(sagemaker_logger.handlers) == 0: + sagemaker_logger.setLevel(logging.INFO) + handler = logging.StreamHandler() + formatter = _UTCFormatter("%(asctime)s %(name)s %(levelname)-8s %(message)s") + handler.setFormatter(formatter) + sagemaker_logger.addHandler(handler) + # don't stream logs with the root logger handler + sagemaker_logger.propagate = 0 + + return sagemaker_logger + + +logger = get_logger() + + +class RuntimeEnvironmentManager: + """Runtime Environment Manager class to manage runtime environment.""" + + def snapshot(self, dependencies: str = None) -> str: + """Creates snapshot of the user's environment + + If a req.txt or conda.yml file is provided, it verifies their existence and + returns the local file path + If ``auto_capture`` is set, this method will take the snapshot of + user's dependencies installed in the local runtime. + Current support for ``auto_capture``: + * conda env, generate a yml file and return it's local path + + Args: + dependencies (str): Local path where dependencies file exists. 
+ + Returns: + file path of the existing or generated dependencies file + """ + + # No additional dependencies specified + if dependencies is None: + return None + + if dependencies == "auto_capture": + return self._capture_from_local_runtime() + + # Dependencies specified as either req.txt or conda_env.yml + if ( + dependencies.endswith(".txt") + or dependencies.endswith(".yml") + or dependencies.endswith(".yaml") + ): + self._is_file_exists(dependencies) + return dependencies + + raise ValueError(f'Invalid dependencies provided: "{dependencies}"') + + def _capture_from_local_runtime(self) -> str: + """Generates dependencies list from the user's local runtime. + + Raises RuntimeEnvironmentError if not able to. + + Currently supports: conda environments + """ + + # Try to capture dependencies from the conda environment, if any. + conda_env_name = self._get_active_conda_env_name() + conda_env_prefix = self._get_active_conda_env_prefix() + if conda_env_name: + logger.info("Found conda_env_name: '%s'", conda_env_name) + elif conda_env_prefix: + logger.info("Found conda_env_prefix: '%s'", conda_env_prefix) + else: + raise ValueError("No conda environment seems to be active.") + + if conda_env_name == "base": + logger.warning( + "We recommend using an environment other than base to " + "isolate your project dependencies from conda dependencies" + ) + + local_dependencies_path = os.path.join(os.getcwd(), "env_snapshot.yml") + self._export_conda_env_from_prefix(conda_env_prefix, local_dependencies_path) + + return local_dependencies_path + + def _get_active_conda_env_prefix(self) -> str: + """Returns the conda prefix from the set environment variable. None otherwise.""" + return os.getenv("CONDA_PREFIX") + + def _get_active_conda_env_name(self) -> str: + """Returns the conda environment name from the set environment variable. 
None otherwise.""" + return os.getenv("CONDA_DEFAULT_ENV") + + def bootstrap( + self, local_dependencies_file: str, client_python_version: str, conda_env: str = None + ): + """Bootstraps the runtime environment by installing the additional dependencies if any. + + Args: + local_dependencies_file (str): path where dependencies file exists. + conda_env (str): conda environment to be activated. Default is None. + + Returns: None + """ + + if local_dependencies_file.endswith(".txt"): + if conda_env: + self._install_req_txt_in_conda_env(conda_env, local_dependencies_file) + self._write_conda_env_to_file(conda_env) + + else: + self._install_requirements_txt(local_dependencies_file, _python_executable()) + + elif local_dependencies_file.endswith(".yml") or local_dependencies_file.endswith(".yaml"): + if conda_env: + self._update_conda_env(conda_env, local_dependencies_file) + else: + conda_env = "sagemaker-runtime-env" + self._create_conda_env(conda_env, local_dependencies_file) + self._validate_python_version(client_python_version, conda_env) + self._write_conda_env_to_file(conda_env) + + def run_pre_exec_script(self, pre_exec_script_path: str): + """Runs script of pre-execution commands if existing. + + Args: + pre_exec_script_path (str): Path to pre-execution command script file. + """ + if os.path.isfile(pre_exec_script_path): + logger.info("Running pre-execution commands in '%s'", pre_exec_script_path) + return_code, error_logs = _run_pre_execution_command_script(pre_exec_script_path) + + if return_code: + error_message = ( + f"Encountered error while running pre-execution commands. Reason: {error_logs}" + ) + raise RuntimeEnvironmentError(error_message) + else: + logger.info( + "'%s' does not exist. Assuming no pre-execution commands to run", + pre_exec_script_path, + ) + + def _is_file_exists(self, dependencies): + """Check whether the dependencies file exists at the given location. 
+ + Raises error if not + """ + if not os.path.isfile(dependencies): + raise ValueError(f'No dependencies file named "{dependencies}" was found.') + + def _install_requirements_txt(self, local_path, python_executable): + """Install requirements.txt file""" + cmd = f"{python_executable} -m pip install -r {local_path}" + logger.info("Running command: '%s' in the dir: '%s' ", cmd, os.getcwd()) + _run_shell_cmd(cmd) + logger.info("Command %s ran successfully", cmd) + + def _create_conda_env(self, env_name, local_path): + """Create conda env using conda yml file""" + + cmd = f"{self._get_conda_exe()} env create -n {env_name} --file {local_path}" + logger.info("Creating conda environment %s using: %s.", env_name, cmd) + _run_shell_cmd(cmd) + logger.info("Conda environment %s created successfully.", env_name) + + def _install_req_txt_in_conda_env(self, env_name, local_path): + """Install requirements.txt in the given conda environment""" + + cmd = f"{self._get_conda_exe()} run -n {env_name} pip install -r {local_path}" + logger.info("Activating conda env and installing requirements: %s", cmd) + _run_shell_cmd(cmd) + logger.info("Requirements installed successfully in conda env %s", env_name) + + def _update_conda_env(self, env_name, local_path): + """Update conda env using conda yml file""" + + cmd = f"{self._get_conda_exe()} env update -n {env_name} --file {local_path}" + logger.info("Updating conda env: %s", cmd) + _run_shell_cmd(cmd) + logger.info("Conda env %s updated succesfully", env_name) + + def _export_conda_env_from_prefix(self, prefix, local_path): + """Export the conda env to a conda yml file""" + + cmd = f"{self._get_conda_exe()} env export -p {prefix} --no-builds > {local_path}" + logger.info("Exporting conda environment: %s", cmd) + _run_shell_cmd(cmd) + logger.info("Conda environment %s exported successfully", prefix) + + def _write_conda_env_to_file(self, env_name): + """Writes conda env to the text file""" + + file_name = "remote_function_conda_env.txt" 
+ file_path = os.path.join(os.getcwd(), file_name) + with open(file_path, "w") as output_file: + output_file.write(env_name) + + def _get_conda_exe(self): + """Checks whether conda or mamba is available to use""" + + if not subprocess.Popen(["which", "mamba"]).wait(): + return "mamba" + if not subprocess.Popen(["which", "conda"]).wait(): + return "conda" + raise ValueError("Neither conda nor mamba is installed on the image") + + def _python_version_in_conda_env(self, env_name): + """Returns python version inside a conda environment""" + cmd = f"{self._get_conda_exe()} run -n {env_name} python --version" + try: + output = ( + subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT) + .decode("utf-8") + .strip() + ) + # convert 'Python 3.7.16' to [3, 7, 16] + version = output.split("Python ")[1].split(".") + return version[0] + "." + version[1] + except subprocess.CalledProcessError as e: + raise RuntimeEnvironmentError(e.output) + + def _current_python_version(self): + """Returns the current python version where program is running""" + + return f"{sys.version_info.major}.{sys.version_info.minor}" + + def _validate_python_version(self, client_python_version: str, conda_env: str = None): + """Validate the python version + + Validates if the python version where remote function runs + matches the one used on client side. + """ + if conda_env: + job_python_version = self._python_version_in_conda_env(conda_env) + else: + job_python_version = self._current_python_version() + if client_python_version != job_python_version: + raise RuntimeEnvironmentError( + f"Python version found in the container is {job_python_version} which " + f"does not match python version {client_python_version} on the local client . " + f"Please make sure that the python version used in the training container " + f"is same as the local python version." 
+ ) + + +def _run_and_get_output_shell_cmd(cmd: str) -> str: + """Run and return the output of the given shell command""" + return subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT).decode("utf-8") + + +def _run_pre_execution_command_script(script_path: str): + """This method runs a given shell script using subprocess + + Raises RuntimeEnvironmentError if the shell script fails + """ + current_dir = os.path.dirname(script_path) + + process = subprocess.Popen( + ["/bin/bash", "-eu", script_path], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=current_dir, + ) + + _log_output(process) + error_logs = _log_error(process) + return_code = process.wait() + + return return_code, error_logs + + +def _run_shell_cmd(cmd: str): + """This method runs a given shell command using subprocess + + Raises RuntimeEnvironmentError if the command fails + """ + + process = subprocess.Popen((cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) + + _log_output(process) + error_logs = _log_error(process) + return_code = process.wait() + if return_code: + error_message = f"Encountered error while running command '{cmd}'. Reason: {error_logs}" + raise RuntimeEnvironmentError(error_message) + + +def _log_output(process): + """This method takes in Popen process and logs the output of that process""" + with process.stdout as pipe: + for line in iter(pipe.readline, b""): + logger.info(str(line, "UTF-8")) + + +def _log_error(process): + """This method takes in Popen process and logs the error of that process. + + Returns those logs as a string + """ + + error_logs = "" + with process.stderr as pipe: + for line in iter(pipe.readline, b""): + error_str = str(line, "UTF-8") + if "ERROR:" in error_str: + logger.error(error_str) + else: + logger.warning(error_str) + error_logs = error_logs + error_str + + return error_logs + + +def _python_executable(): + """Return the real path for the Python executable, if it exists. 
+ + Return RuntimeEnvironmentError otherwise. + + Returns: + (str): The real path of the current Python executable. + """ + if not sys.executable: + raise RuntimeEnvironmentError( + "Failed to retrieve the path for the Python executable binary" + ) + return sys.executable + + +class RuntimeEnvironmentError(Exception): + """The base exception class for bootstrap env excepitons""" + + def __init__(self, message): + self.message = message + super().__init__(self.message) diff --git a/src/sagemaker/s3.py b/src/sagemaker/s3.py index e365a4eab5..9817f83d37 100644 --- a/src/sagemaker/s3.py +++ b/src/sagemaker/s3.py @@ -15,7 +15,9 @@ import pathlib import logging +import io +from typing import Union from six.moves.urllib.parse import urlparse from sagemaker.session import Session @@ -72,7 +74,7 @@ def upload(local_path, desired_s3_uri, kms_key=None, sagemaker_session=None): kms_key (str): The KMS key to use to encrypt the files. sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other - AWS services needed. If not specified, the estimator creates one + AWS services needed. If not specified, one is created using the default AWS configuration chain. Returns: @@ -92,7 +94,9 @@ def upload(local_path, desired_s3_uri, kms_key=None, sagemaker_session=None): ) @staticmethod - def upload_string_as_file_body(body, desired_s3_uri=None, kms_key=None, sagemaker_session=None): + def upload_string_as_file_body( + body: str, desired_s3_uri=None, kms_key=None, sagemaker_session=None + ): """Static method that uploads a given file or directory to S3. Args: @@ -105,10 +109,11 @@ def upload_string_as_file_body(body, desired_s3_uri=None, kms_key=None, sagemake using the default AWS configuration chain. Returns: - str: The S3 uri of the uploaded file(s). + str: The S3 uri of the uploaded file. 
""" sagemaker_session = sagemaker_session or Session() + bucket, key = parse_s3_url(desired_s3_uri) sagemaker_session.upload_string_as_file_body( @@ -117,6 +122,39 @@ def upload_string_as_file_body(body, desired_s3_uri=None, kms_key=None, sagemake return desired_s3_uri + @staticmethod + def upload_bytes(b: Union[bytes, io.BytesIO], s3_uri, kms_key=None, sagemaker_session=None): + """Static method that uploads a given file or directory to S3. + + Args: + b (bytes or io.BytesIO): bytes. + s3_uri (str): The S3 uri to upload to. + kms_key (str): The KMS key to use to encrypt the files. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created + using the default AWS configuration chain. + + Returns: + str: The S3 uri of the uploaded file. + + """ + sagemaker_session = sagemaker_session or Session() + + bucket, object_key = parse_s3_url(s3_uri) + + if kms_key is not None: + extra_args = {"SSEKMSKeyId": kms_key, "ServerSideEncryption": "aws:kms"} + else: + extra_args = None + + b = b if isinstance(b, io.BytesIO) else io.BytesIO(b) + sagemaker_session.s3_resource.Bucket(bucket).upload_fileobj( + b, object_key, ExtraArgs=extra_args + ) + + return s3_uri + class S3Downloader(object): """Contains static methods for downloading directories or files from S3.""" @@ -131,8 +169,11 @@ def download(s3_uri, local_path, kms_key=None, sagemaker_session=None): kms_key (str): The KMS key to use to decrypt the files. sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other - AWS services needed. If not specified, the estimator creates one + AWS services needed. If not specified, one is created using the default AWS configuration chain. 
+ + Returns: + list[str]: List of local paths of downloaded files """ sagemaker_session = sagemaker_session or Session() bucket, key_prefix = parse_s3_url(url=s3_uri) @@ -141,28 +182,52 @@ def download(s3_uri, local_path, kms_key=None, sagemaker_session=None): else: extra_args = None - sagemaker_session.download_data( + return sagemaker_session.download_data( path=local_path, bucket=bucket, key_prefix=key_prefix, extra_args=extra_args ) @staticmethod - def read_file(s3_uri, sagemaker_session=None): - """Static method that returns the contents of an s3 uri file body as a string. + def read_file(s3_uri, sagemaker_session=None) -> str: + """Static method that returns the contents of a s3 uri file body as a string. Args: s3_uri (str): An S3 uri that refers to a single file. sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other - AWS services needed. If not specified, the estimator creates one + AWS services needed. If not specified, one is created using the default AWS configuration chain. Returns: str: The body of the file. """ sagemaker_session = sagemaker_session or Session() - bucket, key_prefix = parse_s3_url(url=s3_uri) - return sagemaker_session.read_s3_file(bucket=bucket, key_prefix=key_prefix) + bucket, object_key = parse_s3_url(url=s3_uri) + + return sagemaker_session.read_s3_file(bucket=bucket, key_prefix=object_key) + + @staticmethod + def read_bytes(s3_uri, sagemaker_session=None) -> bytes: + """Static method that returns the contents of a s3 object as bytes. + + Args: + s3_uri (str): An S3 uri that refers to a s3 object. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created + using the default AWS configuration chain. + + Returns: + bytes: The body of the file. 
+ """ + sagemaker_session = sagemaker_session or Session() + + bucket, object_key = parse_s3_url(s3_uri) + + bytes_io = io.BytesIO() + sagemaker_session.s3_resource.Bucket(bucket).download_fileobj(object_key, bytes_io) + bytes_io.seek(0) + return bytes_io.read() @staticmethod def list(s3_uri, sagemaker_session=None): @@ -172,7 +237,7 @@ def list(s3_uri, sagemaker_session=None): s3_uri (str): The S3 base uri to list objects in. sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other - AWS services needed. If not specified, the estimator creates one + AWS services needed. If not specified, one is created using the default AWS configuration chain. Returns: diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index d99d8826fd..62833000f5 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -267,6 +267,9 @@ def _initialize( self.sagemaker_metrics_client = self.boto_session.client("sagemaker-metrics") prepend_user_agent(self.sagemaker_metrics_client) + self.s3_client = self.boto_session.client("s3", region_name=self.boto_region_name) + self.s3_resource = self.boto_session.resource("s3", region_name=self.boto_region_name) + self.local_mode = False if sagemaker_config: @@ -386,6 +389,9 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None): download operation. Please refer to the ExtraArgs parameter in the boto3 documentation here: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-example-download-file.html + + Returns: + list[str]: List of local paths of downloaded files """ # Initialize the S3 client. 
if self.s3_client is None: @@ -405,7 +411,12 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None): if next_token != "": request_parameters.update({"ContinuationToken": next_token}) response = s3.list_objects_v2(**request_parameters) - contents = response.get("Contents") + contents = response.get("Contents", None) + if not contents: + LOGGER.info( + "Nothing to download from bucket: %s, key_prefix: %s.", bucket, key_prefix + ) + return [] # For each object, save its key or directory. for s3_object in contents: key = s3_object.get("Key") @@ -414,6 +425,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None): # For each object key, create the directory on the local machine if needed, and then # download the file. + downloaded_paths = [] for key in keys: tail_s3_uri_path = os.path.basename(key) if not os.path.splitext(key_prefix)[1]: @@ -424,6 +436,8 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None): s3.download_file( Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args ) + downloaded_paths.append(destination_path) + return downloaded_paths def read_s3_file(self, bucket, key_prefix): """Read a single file from S3. @@ -481,10 +495,7 @@ def default_bucket(self): default_bucket = self._default_bucket_name_override if not default_bucket: - account = self.boto_session.client( - "sts", region_name=region, endpoint_url=sts_regional_endpoint(region) - ).get_caller_identity()["Account"] - default_bucket = "sagemaker-{}-{}".format(region, account) + default_bucket = generate_default_sagemaker_bucket_name(self.boto_session) self._create_s3_bucket_if_it_does_not_exist(bucket_name=default_bucket, region=region) @@ -2280,7 +2291,7 @@ def wait_for_auto_ml_job(self, job, poll=5): exceptions.UnexpectedStatusException: If the auto ml job fails. 
""" desc = _wait_until(lambda: _auto_ml_job_status(self.sagemaker_client, job), poll) - self._check_job_status(job, desc, "AutoMLJobStatus") + _check_job_status(job, desc, "AutoMLJobStatus") return desc def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this method @@ -2306,7 +2317,7 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m description = _wait_until(lambda: self.describe_auto_ml_job(job_name), poll) instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init( - self, description, job="AutoML" + self.boto_session, description, job="AutoML" ) state = _get_initial_job_state(description, "AutoMLJobStatus", wait) @@ -2361,7 +2372,7 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m state = LogState.JOB_COMPLETE if wait: - self._check_job_status(job_name, description, "AutoMLJobStatus") + _check_job_status(job_name, description, "AutoMLJobStatus") if dot: print() @@ -4099,7 +4110,7 @@ def wait_for_job(self, job, poll=5): desc = _wait_until_training_done( lambda last_desc: _train_done(self.sagemaker_client, job, last_desc), None, poll ) - self._check_job_status(job, desc, "TrainingJobStatus") + _check_job_status(job, desc, "TrainingJobStatus") return desc def wait_for_processing_job(self, job, poll=5): @@ -4117,7 +4128,7 @@ def wait_for_processing_job(self, job, poll=5): exceptions.UnexpectedStatusException: If the processing job fails. """ desc = _wait_until(lambda: _processing_job_status(self.sagemaker_client, job), poll) - self._check_job_status(job, desc, "ProcessingJobStatus") + _check_job_status(job, desc, "ProcessingJobStatus") return desc def wait_for_compilation_job(self, job, poll=5): @@ -4135,7 +4146,7 @@ def wait_for_compilation_job(self, job, poll=5): exceptions.UnexpectedStatusException: If the compilation job fails. 
""" desc = _wait_until(lambda: _compilation_job_status(self.sagemaker_client, job), poll) - self._check_job_status(job, desc, "CompilationJobStatus") + _check_job_status(job, desc, "CompilationJobStatus") return desc def wait_for_edge_packaging_job(self, job, poll=5): @@ -4153,7 +4164,7 @@ def wait_for_edge_packaging_job(self, job, poll=5): exceptions.UnexpectedStatusException: If the edge packaging job fails. """ desc = _wait_until(lambda: _edge_packaging_job_status(self.sagemaker_client, job), poll) - self._check_job_status(job, desc, "EdgePackagingJobStatus") + _check_job_status(job, desc, "EdgePackagingJobStatus") return desc def wait_for_tuning_job(self, job, poll=5): @@ -4171,7 +4182,7 @@ def wait_for_tuning_job(self, job, poll=5): exceptions.UnexpectedStatusException: If the hyperparameter tuning job fails. """ desc = _wait_until(lambda: _tuning_job_status(self.sagemaker_client, job), poll) - self._check_job_status(job, desc, "HyperParameterTuningJobStatus") + _check_job_status(job, desc, "HyperParameterTuningJobStatus") return desc def describe_transform_job(self, job_name): @@ -4200,7 +4211,7 @@ def wait_for_transform_job(self, job, poll=5): exceptions.UnexpectedStatusException: If the transform job fails. """ desc = _wait_until(lambda: _transform_job_status(self.sagemaker_client, job), poll) - self._check_job_status(job, desc, "TransformJobStatus") + _check_job_status(job, desc, "TransformJobStatus") return desc def stop_transform_job(self, name): @@ -4224,48 +4235,6 @@ def stop_transform_job(self, name): LOGGER.error("Error occurred while attempting to stop transform job: %s.", name) raise - def _check_job_status(self, job, desc, status_key_name): - """Check to see if the job completed successfully. - - If not, construct and raise a exceptions. (UnexpectedStatusException). - - Args: - job (str): The name of the job to check. - desc (dict[str, str]): The result of ``describe_training_job()``. - status_key_name (str): Status key name to check for. 
- - Raises: - exceptions.CapacityError: If the training job fails with CapacityError. - exceptions.UnexpectedStatusException: If the training job fails. - """ - status = desc[status_key_name] - # If the status is capital case, then convert it to Camel case - status = _STATUS_CODE_TABLE.get(status, status) - - if status == "Stopped": - LOGGER.warning( - "Job ended with status 'Stopped' rather than 'Completed'. " - "This could mean the job timed out or stopped early for some other reason: " - "Consider checking whether it completed as you expect." - ) - elif status != "Completed": - reason = desc.get("FailureReason", "(No reason provided)") - job_type = status_key_name.replace("JobStatus", " job") - message = "Error for {job_type} {job_name}: {status}. Reason: {reason}".format( - job_type=job_type, job_name=job, status=status, reason=reason - ) - if "CapacityError" in str(reason): - raise exceptions.CapacityError( - message=message, - allowed_statuses=["Completed", "Stopped"], - actual_status=status, - ) - raise exceptions.UnexpectedStatusException( - message=message, - allowed_statuses=["Completed", "Stopped"], - actual_status=status, - ) - def wait_for_endpoint(self, endpoint, poll=30): """Wait for an Amazon SageMaker endpoint deployment to complete. @@ -4635,9 +4604,7 @@ def get_caller_identity_arn(self): return role - def logs_for_job( # noqa: C901 - suppress complexity warning for this method - self, job_name, wait=False, poll=10, log_type="All" - ): + def logs_for_job(self, job_name, wait=False, poll=10, log_type="All", timeout=None): """Display logs for a given training job, optionally tailing them until job is complete. If the output is a tty or a Jupyter cell, it will be color-coded @@ -4649,124 +4616,16 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method (default: False). poll (int): The interval in seconds between polling for new log entries and job completion (default: 5). 
- + log_type ([str]): A list of strings specifying which logs to print. Acceptable + strings are "All", "None", "Training", or "Rules". To maintain backwards + compatibility, boolean values are also accepted and converted to strings. + timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by + default. Raises: exceptions.CapacityError: If the training job fails with CapacityError. exceptions.UnexpectedStatusException: If waiting and the training job fails. """ - - description = _wait_until(lambda: self.describe_training_job(job_name), poll) - print(secondary_training_status_message(description, None), end="") - - instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init( - self, description, job="Training" - ) - - state = _get_initial_job_state(description, "TrainingJobStatus", wait) - - # The loop below implements a state machine that alternates between checking the job status - # and reading whatever is available in the logs at this point. Note, that if we were - # called with wait == False, we never check the job status. - # - # If wait == TRUE and job is not completed, the initial state is TAILING - # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is - # complete). - # - # The state table: - # - # STATE ACTIONS CONDITION NEW STATE - # ---------------- ---------------- ----------------- ---------------- - # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE - # Else TAILING - # JOB_COMPLETE Read logs, Pause Any COMPLETE - # COMPLETE Read logs, Exit N/A - # - # Notes: - # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to - # Cloudwatch after the job was marked complete. 
- last_describe_job_call = time.time() - last_description = description - last_debug_rule_statuses = None - last_profiler_rule_statuses = None - - while True: - _flush_log_streams( - stream_names, - instance_count, - client, - log_group, - job_name, - positions, - dot, - color_wrap, - ) - if state == LogState.COMPLETE: - break - - time.sleep(poll) - - if state == LogState.JOB_COMPLETE: - state = LogState.COMPLETE - elif time.time() - last_describe_job_call >= 30: - description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name) - last_describe_job_call = time.time() - - if secondary_training_status_changed(description, last_description): - print() - print(secondary_training_status_message(description, last_description), end="") - last_description = description - - status = description["TrainingJobStatus"] - - if status in ("Completed", "Failed", "Stopped"): - print() - state = LogState.JOB_COMPLETE - - # Print prettified logs related to the status of SageMaker Debugger rules. - debug_rule_statuses = description.get("DebugRuleEvaluationStatuses", {}) - if ( - debug_rule_statuses - and _rule_statuses_changed(debug_rule_statuses, last_debug_rule_statuses) - and (log_type in {"All", "Rules"}) - ): - for status in debug_rule_statuses: - rule_log = ( - f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}" - ) - print(rule_log) - - last_debug_rule_statuses = debug_rule_statuses - - # Print prettified logs related to the status of SageMaker Profiler rules. 
- profiler_rule_statuses = description.get("ProfilerRuleEvaluationStatuses", {}) - if ( - profiler_rule_statuses - and _rule_statuses_changed(profiler_rule_statuses, last_profiler_rule_statuses) - and (log_type in {"All", "Rules"}) - ): - for status in profiler_rule_statuses: - rule_log = ( - f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}" - ) - print(rule_log) - - last_profiler_rule_statuses = profiler_rule_statuses - - if wait: - self._check_job_status(job_name, description, "TrainingJobStatus") - if dot: - print() - # Customers are not billed for hardware provisioning, so billable time is less than - # total time - training_time = description.get("TrainingTimeInSeconds") - billable_time = description.get("BillableTimeInSeconds") - if training_time is not None: - print("Training seconds:", training_time * instance_count) - if billable_time is not None: - print("Billable seconds:", billable_time * instance_count) - if description.get("EnableManagedSpotTraining"): - saving = (1 - float(billable_time) / training_time) * 100 - print("Managed Spot Training savings: {:.1f}%".format(saving)) + _logs_for_job(self.boto_session, job_name, wait, poll, log_type, timeout) def logs_for_processing_job(self, job_name, wait=False, poll=10): """Display logs for a given processing job, optionally tailing them until the is complete. 
@@ -4785,7 +4644,7 @@ def logs_for_processing_job(self, job_name, wait=False, poll=10): description = _wait_until(lambda: self.describe_processing_job(job_name), poll) instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init( - self, description, job="Processing" + self.boto_session, description, job="Processing" ) state = _get_initial_job_state(description, "ProcessingJobStatus", wait) @@ -4842,7 +4701,7 @@ def logs_for_processing_job(self, job_name, wait=False, poll=10): state = LogState.JOB_COMPLETE if wait: - self._check_job_status(job_name, description, "ProcessingJobStatus") + _check_job_status(job_name, description, "ProcessingJobStatus") if dot: print() @@ -4866,7 +4725,7 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): description = _wait_until(lambda: self.describe_transform_job(job_name), poll) instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init( - self, description, job="Transform" + self.boto_session, description, job="Transform" ) state = _get_initial_job_state(description, "TransformJobStatus", wait) @@ -4923,7 +4782,7 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): state = LogState.JOB_COMPLETE if wait: - self._check_job_status(job_name, description, "TransformJobStatus") + _check_job_status(job_name, description, "TransformJobStatus") if dot: print() @@ -5615,7 +5474,7 @@ def wait_for_inference_recommendations_job( else: raise ValueError("log_level must be either Quiet or Verbose") desc = _describe_inference_recommendations_job_status(self.sagemaker_client, job_name) - self._check_job_status(job_name, desc, "Status") + _check_job_status(job_name, desc, "Status") return desc @@ -6013,6 +5872,19 @@ def get_execution_role(sagemaker_session=None): raise ValueError(message.format(arn)) +def generate_default_sagemaker_bucket_name(boto_session): + """Generates a name for the default sagemaker S3 bucket. 
+ + Args: + boto_session (boto3.session.Session): The underlying Boto3 session which AWS service calls are delegated to. + """ + region = boto_session.region_name + account = boto_session.client( + "sts", region_name=region, endpoint_url=sts_regional_endpoint(region) + ).get_caller_identity()["Account"] + return "sagemaker-{}-{}".format(region, account) + + def _deployment_entity_exists(describe_fn): """Placeholder docstring""" try: @@ -6483,7 +6355,199 @@ def _rule_statuses_changed(current_statuses, last_statuses): return False -def _logs_init(sagemaker_session, description, job): +def _logs_for_job( # noqa: C901 - suppress complexity warning for this method + boto_session, job_name, wait=False, poll=10, log_type="All", timeout=None +): + """Display logs for a given training job, optionally tailing them until job is complete. + + If the output is a tty or a Jupyter cell, it will be color-coded + based on which instance the log entry is from. + + Args: + boto_session (boto3.session.Session): The underlying Boto3 session which AWS service + calls are delegated to (default: None). If not provided, one is created with + default AWS configuration chain. + job_name (str): Name of the training job to display the logs for. + wait (bool): Whether to keep looking for new log entries until the job completes + (default: False). + poll (int): The interval in seconds between polling for new log entries and job + completion (default: 10). + log_type ([str]): A list of strings specifying which logs to print. Acceptable + strings are "All", "None", "Training", or "Rules". To maintain backwards + compatibility, boolean values are also accepted and converted to strings. + timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by + default. + Returns: + Last call to sagemaker DescribeTrainingJob + Raises: + exceptions.CapacityError: If the training job fails with CapacityError. + exceptions.UnexpectedStatusException: If waiting and the training job fails. 
+ """ + sagemaker_client = boto_session.client("sagemaker") + request_end_time = time.time() + timeout if timeout else None + description = sagemaker_client.describe_training_job(TrainingJobName=job_name) + print(secondary_training_status_message(description, None), end="") + + instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init( + boto_session, description, job="Training" + ) + + state = _get_initial_job_state(description, "TrainingJobStatus", wait) + + # The loop below implements a state machine that alternates between checking the job status + # and reading whatever is available in the logs at this point. Note, that if we were + # called with wait == False, we never check the job status. + # + # If wait == TRUE and job is not completed, the initial state is TAILING + # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is + # complete). + # + # The state table: + # + # STATE ACTIONS CONDITION NEW STATE + # ---------------- ---------------- ----------------- ---------------- + # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE + # Else TAILING + # JOB_COMPLETE Read logs, Pause Any COMPLETE + # COMPLETE Read logs, Exit N/A + # + # Notes: + # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to + # Cloudwatch after the job was marked complete. + last_describe_job_call = time.time() + last_description = description + last_debug_rule_statuses = None + last_profiler_rule_statuses = None + + while True: + _flush_log_streams( + stream_names, + instance_count, + client, + log_group, + job_name, + positions, + dot, + color_wrap, + ) + if timeout and time.time() > request_end_time: + print("Timeout Exceeded. 
{} seconds elapsed.".format(timeout)) + break + + if state == LogState.COMPLETE: + break + + time.sleep(poll) + + if state == LogState.JOB_COMPLETE: + state = LogState.COMPLETE + elif time.time() - last_describe_job_call >= 30: + description = sagemaker_client.describe_training_job(TrainingJobName=job_name) + last_describe_job_call = time.time() + + if secondary_training_status_changed(description, last_description): + print() + print(secondary_training_status_message(description, last_description), end="") + last_description = description + + status = description["TrainingJobStatus"] + + if status in ("Completed", "Failed", "Stopped"): + print() + state = LogState.JOB_COMPLETE + + # Print prettified logs related to the status of SageMaker Debugger rules. + debug_rule_statuses = description.get("DebugRuleEvaluationStatuses", {}) + if ( + debug_rule_statuses + and _rule_statuses_changed(debug_rule_statuses, last_debug_rule_statuses) + and (log_type in {"All", "Rules"}) + ): + for status in debug_rule_statuses: + rule_log = ( + f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}" + ) + print(rule_log) + + last_debug_rule_statuses = debug_rule_statuses + + # Print prettified logs related to the status of SageMaker Profiler rules. 
+ profiler_rule_statuses = description.get("ProfilerRuleEvaluationStatuses", {}) + if ( + profiler_rule_statuses + and _rule_statuses_changed(profiler_rule_statuses, last_profiler_rule_statuses) + and (log_type in {"All", "Rules"}) + ): + for status in profiler_rule_statuses: + rule_log = ( + f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}" + ) + print(rule_log) + + last_profiler_rule_statuses = profiler_rule_statuses + + if wait: + _check_job_status(job_name, description, "TrainingJobStatus") + if dot: + print() + # Customers are not billed for hardware provisioning, so billable time is less than + # total time + training_time = description.get("TrainingTimeInSeconds") + billable_time = description.get("BillableTimeInSeconds") + if training_time is not None: + print("Training seconds:", training_time * instance_count) + if billable_time is not None: + print("Billable seconds:", billable_time * instance_count) + if description.get("EnableManagedSpotTraining"): + saving = (1 - float(billable_time) / training_time) * 100 + print("Managed Spot Training savings: {:.1f}%".format(saving)) + return last_description + + +def _check_job_status(job, desc, status_key_name): + """Check to see if the job completed successfully. + + If not, construct and raise an ``exceptions.UnexpectedStatusException``. + + Args: + job (str): The name of the job to check. + desc (dict[str, str]): The result of ``describe_training_job()``. + status_key_name (str): Status key name to check for. + + Raises: + exceptions.CapacityError: If the training job fails with CapacityError. + exceptions.UnexpectedStatusException: If the training job fails. + """ + status = desc[status_key_name] + # If the status is capital case, then convert it to Camel case + status = _STATUS_CODE_TABLE.get(status, status) + + if status == "Stopped": + LOGGER.warning( + "Job ended with status 'Stopped' rather than 'Completed'. 
" + "This could mean the job timed out or stopped early for some other reason: " + "Consider checking whether it completed as you expect." + ) + elif status != "Completed": + reason = desc.get("FailureReason", "(No reason provided)") + job_type = status_key_name.replace("JobStatus", " job") + message = "Error for {job_type} {job_name}: {status}. Reason: {reason}".format( + job_type=job_type, job_name=job, status=status, reason=reason + ) + if "CapacityError" in str(reason): + raise exceptions.CapacityError( + message=message, + allowed_statuses=["Completed", "Stopped"], + actual_status=status, + ) + raise exceptions.UnexpectedStatusException( + message=message, + allowed_statuses=["Completed", "Stopped"], + actual_status=status, + ) + + +def _logs_init(boto_session, description, job): """Placeholder docstring""" if job == "Training": if "InstanceGroups" in description["ResourceConfig"]: @@ -6505,7 +6569,7 @@ def _logs_init(sagemaker_session, description, job): # Increase retries allowed (from default of 4), as we don't want waiting for a training job # to be interrupted by a transient exception. 
config = botocore.config.Config(retries={"max_attempts": 15}) - client = sagemaker_session.boto_session.client("logs", config=config) + client = boto_session.client("logs", config=config) log_group = "/aws/sagemaker/" + job + "Jobs" dot = False diff --git a/tests/data/config/config.yaml b/tests/data/config/config.yaml index 0abb47e70e..fc052f2ddd 100644 --- a/tests/data/config/config.yaml +++ b/tests/data/config/config.yaml @@ -118,3 +118,26 @@ SageMaker: OutputConfig: KmsKeyId: 'kmskeyid1' RoleArn: 'arn:aws:iam::555555555555:role/IMRole' + PythonSDK: + Modules: + RemoteFunction: + Dependencies: "./requirements.txt" + EnvironmentVariables: + "var1": "value1" + "var2": "value2" + ImageUri: "123456789012.dkr.ecr.us-west-2.amazonaws.com/myimage:latest" + IncludeLocalWorkDir: true + InstanceType: "ml.m5.xlarge" + JobCondaEnvironment: "some_conda_env" + RoleArn: "arn:aws:iam::555555555555:role/IMRole" + S3KmsKeyId: "kmskeyid1" + S3RootUri: "s3://my-bucket/key" + Tags: + - Key: "tag1" + Value: "tagValue1" + VolumeKmsKeyId: "kmskeyid2" + VpcConfig: + SecurityGroupIds: + - 'sg123' + Subnets: + - 'subnet-1234' diff --git a/tests/data/remote_function/config.yaml b/tests/data/remote_function/config.yaml new file mode 100644 index 0000000000..6239de48cc --- /dev/null +++ b/tests/data/remote_function/config.yaml @@ -0,0 +1,18 @@ +SchemaVersion: '1.0' +SageMaker: + PythonSDK: + Modules: + RemoteFunction: + Dependencies: "path/to/requirements.txt" + PreExecutionCommands: ["command_1", "command_2"] + EnableInterContainerTrafficEncryption: true + EnvironmentVariables: {"EnvVarKey": "EnvVarValue"} + IncludeLocalWorkDir: true + InstanceType: "ml.m5.large" + JobCondaEnvironment: "my_conda_env" + S3KmsKeyId: "someS3KmsKey" + VpcConfig: + SecurityGroupIds: ["sg123"] + Subnets: ["subnet-1234"] + Tags: [{"Key": "someTagKey", "Value":"someTagValue"}, {"Key":"someTagKey2", "Value":"someTagValue2"}] + VolumeKmsKeyId: "someVolumeKmsKey" diff --git 
a/tests/data/remote_function/non_existent_requirements.txt b/tests/data/remote_function/non_existent_requirements.txt new file mode 100644 index 0000000000..12e3093c4a --- /dev/null +++ b/tests/data/remote_function/non_existent_requirements.txt @@ -0,0 +1 @@ +does_not_exist diff --git a/tests/data/remote_function/old_deps_requirements.txt b/tests/data/remote_function/old_deps_requirements.txt new file mode 100644 index 0000000000..684864f2bc --- /dev/null +++ b/tests/data/remote_function/old_deps_requirements.txt @@ -0,0 +1 @@ +pandas==1.1.0 diff --git a/tests/data/remote_function/pre_exec_commands b/tests/data/remote_function/pre_exec_commands new file mode 100644 index 0000000000..6f9ace9f18 --- /dev/null +++ b/tests/data/remote_function/pre_exec_commands @@ -0,0 +1,4 @@ +echo "test-content-1" > test_file_1 +echo "test-content-2" > test_file_2 +echo "test-content-3" > test_file_3 +rm ./test_file_2 \ No newline at end of file diff --git a/tests/data/remote_function/pre_exec_commands_bad_cmd b/tests/data/remote_function/pre_exec_commands_bad_cmd new file mode 100644 index 0000000000..87cf9b8185 --- /dev/null +++ b/tests/data/remote_function/pre_exec_commands_bad_cmd @@ -0,0 +1,3 @@ +echo "test-content-1" > test_file_1 +bws sagemaker describe-training-job +echo "test-content-3" > test_file_3 \ No newline at end of file diff --git a/tests/data/remote_function/requirements.txt b/tests/data/remote_function/requirements.txt new file mode 100644 index 0000000000..886acca403 --- /dev/null +++ b/tests/data/remote_function/requirements.txt @@ -0,0 +1 @@ +scipy==1.7.3 diff --git a/tests/integ/sagemaker/remote_function/__init__.py b/tests/integ/sagemaker/remote_function/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/remote_function/conftest.py b/tests/integ/sagemaker/remote_function/conftest.py new file mode 100644 index 0000000000..9f62315d77 --- /dev/null +++ b/tests/integ/sagemaker/remote_function/conftest.py @@ -0,0 +1,214 
@@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import base64 +import os +import subprocess +import shutil +import pytest +import docker + +from sagemaker.utils import sagemaker_timestamp, _tmpdir, sts_regional_endpoint + +REPO_ACCOUNT_ID = "033110030271" + +REPO_NAME = "remote-function-dummy-container" + +DOCKERFILE_TEMPLATE = ( + "FROM public.ecr.aws/docker/library/python:{py_version}-slim\n\n" + "RUN apt-get update -y \ + && apt-get install -y unzip curl\n\n" + "RUN curl 'https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip' -o 'awscliv2.zip' \ + && unzip awscliv2.zip \ + && ./aws/install\n\n" + "COPY {source_archive} ./\n" + "RUN pip3 install '{source_archive}'\n" + "RUN rm {source_archive}\n" +) + +DOCKERFILE_TEMPLATE_WITH_CONDA = ( + "FROM public.ecr.aws/docker/library/python:{py_version}-slim\n\n" + 'SHELL ["/bin/bash", "-c"]\n' + "RUN apt-get update -y \ + && apt-get install -y unzip curl\n\n" + "RUN curl -L -O 'https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh' \ + && bash Mambaforge-Linux-x86_64.sh -b -p '/opt/conda' \ + && /opt/conda/bin/conda init bash\n\n" + "ENV PATH $PATH:/opt/conda/bin\n" + "RUN mamba create -n integ_test_env python={py_version} -y \ + && mamba create -n default_env python={py_version} -y\n" + "COPY {source_archive} ./\n" + "RUN pip install '{source_archive}' \ + && mamba run -n base pip install '{source_archive}' \ + && mamba run 
-n default_env pip install '{source_archive}' \ + && mamba run -n integ_test_env pip install '{source_archive}'\n" + "ENV SHELL=/bin/bash\n" + "ENV SAGEMAKER_JOB_CONDA_ENV=default_env\n" +) + +CONDA_YML_FILE_TEMPLATE = ( + "name: integ_test_env\n" + "channels:\n" + " - defaults\n" + "dependencies:\n" + " - scipy=1.7.3\n" + " - pip:\n" + " - /sagemaker-{sagemaker_version}.tar.gz\n" + "prefix: /opt/conda/bin/conda\n" +) + +CONDA_YML_FILE_WITH_SM_FROM_INPUT_CHANNEL = ( + "name: integ_test_env\n" + "channels:\n" + " - defaults\n" + "dependencies:\n" + " - scipy=1.7.3\n" + " - pip:\n" + " - sagemaker-2.132.1.dev0-py2.py3-none-any.whl\n" + "prefix: /opt/conda/bin/conda\n" +) + + +@pytest.fixture(scope="package") +def dummy_container_without_error(sagemaker_session): + # TODO: the python version should be dynamically specified instead of hardcoding + ecr_uri = _build_container(sagemaker_session, "3.7", DOCKERFILE_TEMPLATE) + return ecr_uri + + +@pytest.fixture(scope="package") +def dummy_container_incompatible_python_runtime(sagemaker_session): + ecr_uri = _build_container(sagemaker_session, "3.10", DOCKERFILE_TEMPLATE) + return ecr_uri + + +@pytest.fixture(scope="package") +def dummy_container_with_conda(sagemaker_session): + ecr_uri = _build_container(sagemaker_session, "3.7", DOCKERFILE_TEMPLATE_WITH_CONDA) + return ecr_uri + + +@pytest.fixture(scope="package") +def conda_env_yml(): + """Write conda yml file needed for tests""" + + conda_yml_file_name = "conda_env.yml" + with open(os.path.join(os.getcwd(), "VERSION"), "r") as version_file: + sagemaker_version = version_file.readline().strip() + conda_file_path = os.path.join(os.getcwd(), conda_yml_file_name) + with open(conda_file_path, "w") as yml_file: + yml_file.writelines(CONDA_YML_FILE_TEMPLATE.format(sagemaker_version=sagemaker_version)) + yield conda_file_path + + # cleanup + if os.path.isfile(conda_yml_file_name): + os.remove(conda_yml_file_name) + + +@pytest.fixture(scope="package") +def 
conda_yml_file_sm_from_input_channel(): + """Write conda yml file needed for tests""" + + conda_yml_file_name = "conda_env_sm_from_input_channel.yml" + conda_file_path = os.path.join(os.getcwd(), conda_yml_file_name) + + with open(conda_file_path, "w") as yml_file: + yml_file.writelines(CONDA_YML_FILE_WITH_SM_FROM_INPUT_CHANNEL) + yield conda_file_path + + # cleanup + if os.path.isfile(conda_yml_file_name): + os.remove(conda_yml_file_name) + + +def _build_container(sagemaker_session, py_version, docker_templete): + """Build a dummy test container locally and push a container to an ecr repo""" + + region = sagemaker_session.boto_region_name + image_tag = f"{py_version.replace('.', '-')}-{sagemaker_timestamp()}" + ecr_client = sagemaker_session.boto_session.client("ecr") + username, password = _ecr_login(ecr_client) + + with _tmpdir() as tmpdir: + print("building docker image locally in ", tmpdir) + print("building source archive...") + source_archive = _generate_and_move_sagemaker_sdk_tar(tmpdir) + with open(os.path.join(tmpdir, "Dockerfile"), "w") as file: + file.writelines( + docker_templete.format(py_version=py_version, source_archive=source_archive) + ) + + docker_client = docker.from_env() + + print("building docker image...") + image, build_logs = docker_client.images.build(path=tmpdir, tag=REPO_NAME, rm=True) + + if _is_repository_exists(ecr_client, REPO_NAME): + sts_client = sagemaker_session.boto_session.client( + "sts", region_name=region, endpoint_url=sts_regional_endpoint(region) + ) + account_id = sts_client.get_caller_identity()["Account"] + # When the test is run locally, repo will exist in same account whose credentials are used to run the test + ecr_image = _ecr_image_uri( + account_id, sagemaker_session.boto_region_name, REPO_NAME, image_tag + ) + else: + ecr_image = _ecr_image_uri( + REPO_ACCOUNT_ID, + sagemaker_session.boto_region_name, + REPO_NAME, + image_tag, + ) + + print("pushing image...") + image.tag(ecr_image, tag=image_tag) + 
docker_client.images.push(ecr_image, auth_config={"username": username, "password": password}) + + return ecr_image + + +def _is_repository_exists(ecr_client, repo_name): + try: + ecr_client.describe_repositories(repositoryNames=[repo_name]) + return True + except ecr_client.exceptions.RepositoryNotFoundException: + return False + + +def _ecr_login(ecr_client): + """Get login credentials for an ecr client.""" + login = ecr_client.get_authorization_token() + b64token = login["authorizationData"][0]["authorizationToken"].encode("utf-8") + username, password = base64.b64decode(b64token).decode("utf-8").split(":") + return username, password + + +def _ecr_image_uri(account, region, image_name, tag): + """Build an ECR image URI based on account, region and container name""" + return "{}.dkr.ecr.{}.amazonaws.com/{}:{}".format(account, region, image_name, tag) + + +def _generate_and_move_sagemaker_sdk_tar(destination_folder): + """ + Run setup.py sdist to generate the PySDK tar file and + copy it to appropriate test data folder + """ + subprocess.run("python3 setup.py sdist", shell=True) + dist_dir = "dist" + source_archive = os.listdir(dist_dir)[0] + source_path = os.path.join(dist_dir, source_archive) + destination_path = os.path.join(destination_folder, source_archive) + shutil.copy2(source_path, destination_path) + + return source_archive diff --git a/tests/integ/sagemaker/remote_function/helpers/__init__.py b/tests/integ/sagemaker/remote_function/helpers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/remote_function/helpers/local_module.py b/tests/integ/sagemaker/remote_function/helpers/local_module.py new file mode 100644 index 0000000000..089443fd35 --- /dev/null +++ b/tests/integ/sagemaker/remote_function/helpers/local_module.py @@ -0,0 +1,5 @@ +from __future__ import absolute_import + + +def square(x): + return x * x diff --git a/tests/integ/sagemaker/remote_function/helpers/nested_helper/local_module2.py 
b/tests/integ/sagemaker/remote_function/helpers/nested_helper/local_module2.py new file mode 100644 index 0000000000..3edffd61ab --- /dev/null +++ b/tests/integ/sagemaker/remote_function/helpers/nested_helper/local_module2.py @@ -0,0 +1,5 @@ +from __future__ import absolute_import + + +def cube(x): + return x**3 diff --git a/tests/integ/sagemaker/remote_function/test_decorator.py b/tests/integ/sagemaker/remote_function/test_decorator.py new file mode 100644 index 0000000000..6811a7a06d --- /dev/null +++ b/tests/integ/sagemaker/remote_function/test_decorator.py @@ -0,0 +1,598 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +import time + + +import pytest +import os +import logging +import random +import string +import pandas as pd +from sagemaker.experiments.run import Run, load_run +from tests.integ.sagemaker.experiments.helpers import cleanup_exp_resources +from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.experiments._api_types import _TrialComponentStatusType + +from sagemaker.remote_function import remote +from sagemaker.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentError, +) +from sagemaker.remote_function.errors import ( + DeserializationError, + SerializationError, +) +from sagemaker.utils import unique_name_from_base + +from tests.integ.kms_utils import get_or_create_kms_key +from tests.integ import DATA_DIR + +ROLE = "SageMakerRole" + + +@pytest.fixture(scope="module") +def s3_kms_key(sagemaker_session): + return get_or_create_kms_key(sagemaker_session=sagemaker_session) + + +def test_decorator(sagemaker_session, dummy_container_without_error, cpu_instance_type): + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=60, + ) + def divide(x, y): + return x / y + + assert divide(10, 2) == 5 + assert divide(20, 2) == 10 + + +def test_decorated_function_raises_exception( + sagemaker_session, dummy_container_without_error, cpu_instance_type +): + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + ) + def divide(x, y): + logging.warning(f"{x}/{y}") + return x / y + + with pytest.raises(ZeroDivisionError): + divide(10, 0) + + +def test_remote_python_runtime_is_incompatible( + sagemaker_session, dummy_container_incompatible_python_runtime, cpu_instance_type +): + @remote( + role=ROLE, + image_uri=dummy_container_incompatible_python_runtime, + 
instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + ) + def divide(x, y): + return x / y + + with pytest.raises( + RuntimeEnvironmentError, + match=( + "Please make sure that the python version used in the training container " + "is same as the local python version." + ), + ): + divide(10, 2) + + +# TODO: add VPC settings, update SageMakerRole with KMS permissions +@pytest.mark.skip +def test_advanced_job_setting( + sagemaker_session, dummy_container_without_error, cpu_instance_type, s3_kms_key +): + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + instance_type=cpu_instance_type, + s3_kms_key=s3_kms_key, + sagemaker_session=sagemaker_session, + ) + def divide(x, y): + return x / y + + assert divide(10, 2) == 5 + + +def test_with_local_dependencies( + sagemaker_session, dummy_container_without_error, cpu_instance_type, monkeypatch +): + source_dir_path = os.path.join(os.path.dirname(__file__)) + with monkeypatch.context() as m: + m.chdir(source_dir_path) + dependencies_path = os.path.join(DATA_DIR, "remote_function", "requirements.txt") + + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + dependencies=dependencies_path, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + include_local_workdir=True, + ) + def train(x): + from helpers import local_module + from helpers.nested_helper import local_module2 + + return local_module.square(x) + local_module2.cube(x) + + assert train(2) == 12 + + +def test_with_additional_dependencies( + sagemaker_session, dummy_container_without_error, cpu_instance_type +): + dependencies_path = os.path.join(DATA_DIR, "remote_function", "requirements.txt") + + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + dependencies=dependencies_path, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + ) + def cuberoot(x): + from scipy.special import cbrt + + return cbrt(27) + + assert cuberoot(27) == 3 + + +def 
test_additional_dependencies_with_job_conda_env( + sagemaker_session, dummy_container_with_conda, cpu_instance_type +): + dependencies_path = os.path.join(DATA_DIR, "remote_function", "requirements.txt") + + @remote( + role=ROLE, + image_uri=dummy_container_with_conda, + dependencies=dependencies_path, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + job_conda_env="integ_test_env", + ) + def cuberoot(x): + from scipy.special import cbrt + + return cbrt(x) + + assert cuberoot(27) == 3 + + +def test_additional_dependencies_with_default_conda_env( + sagemaker_session, dummy_container_with_conda, cpu_instance_type +): + dependencies_path = os.path.join(DATA_DIR, "remote_function", "requirements.txt") + + @remote( + role=ROLE, + image_uri=dummy_container_with_conda, + dependencies=dependencies_path, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + ) + def cuberoot(x): + from scipy.special import cbrt + + return cbrt(x) + + assert cuberoot(27) == 3 + + +def test_additional_dependencies_with_non_existent_conda_env( + sagemaker_session, dummy_container_with_conda, cpu_instance_type +): + dependencies_path = os.path.join(DATA_DIR, "remote_function", "requirements.txt") + + @remote( + role=ROLE, + image_uri=dummy_container_with_conda, + dependencies=dependencies_path, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + job_conda_env="non_existent_env", + ) + def cuberoot(x): + from scipy.special import cbrt + + return cbrt(x) + + with pytest.raises(RuntimeEnvironmentError): + cuberoot(27) == 3 + + +def test_additional_dependencies_with_conda_yml_file( + sagemaker_session, dummy_container_with_conda, cpu_instance_type, conda_env_yml +): + @remote( + role=ROLE, + image_uri=dummy_container_with_conda, + dependencies=conda_env_yml, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + job_conda_env="integ_test_env", + keep_alive_period_in_seconds=120, + ) + def cuberoot(x): + from 
scipy.special import cbrt + + return cbrt(x) + + assert cuberoot(27) == 3 + + +def test_with_non_existent_dependencies( + sagemaker_session, dummy_container_without_error, cpu_instance_type +): + dependencies_path = os.path.join(DATA_DIR, "remote_function", "non_existent_requirements.txt") + + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + dependencies=dependencies_path, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=30, + ) + def divide(x, y): + return x / y + + with pytest.raises(RuntimeEnvironmentError): + divide(10, 2) + + +def test_with_incompatible_dependencies( + sagemaker_session, dummy_container_without_error, cpu_instance_type +): + + dependencies_path = os.path.join(DATA_DIR, "remote_function", "old_deps_requirements.txt") + + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + dependencies=dependencies_path, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=30, + ) + def mul_ten(df: pd.DataFrame): + return df.mul(10) + + df1 = pd.DataFrame( + { + "A": [14, 4, 5, 4, 1], + "B": [5, 2, 54, 3, 2], + "C": [20, 20, 7, 3, 8], + "D": [14, 3, 6, 2, 6], + } + ) + + with pytest.raises(DeserializationError): + mul_ten(df1) + + +def test_decorator_with_exp_and_run_names_passed_to_remote_function( + sagemaker_session, dummy_container_without_error, cpu_instance_type +): + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=30, + ) + def train(exp_name, run_name): + + with Run(experiment_name=exp_name, run_name=run_name) as run: + print(f"Experiment name: {run.experiment_name}") + print(f"Run name: {run.run_name}") + print(f"Trial component name: {run._trial_component.trial_component_name}") + + run.log_parameter("p1", 1.0) + run.log_parameter("p2", 2) + + for i in range(2): + run.log_metric("A", i) + for i 
in range(2): + run.log_metric("B", i) + for i in range(2): + run.log_metric("C", i) + for i in range(2): + time.sleep(0.003) + run.log_metric("D", i) + for i in range(2): + time.sleep(0.003) + run.log_metric("E", i) + time.sleep(15) + + exp_name = unique_name_from_base("my-test-exp") + run_name = "my-test-run" + tc_name = Run._generate_trial_component_name(experiment_name=exp_name, run_name=run_name) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + train(exp_name, run_name) + + tc = _TrialComponent.load(trial_component_name=tc_name, sagemaker_session=sagemaker_session) + + assert tc.start_time + assert tc.end_time + assert tc.status.primary_status == _TrialComponentStatusType.Completed.value + assert tc.parameters["p1"] == 1.0 + assert tc.parameters["p2"] == 2.0 + assert len(tc.metrics) == 5 + for metric_summary in tc.metrics: + # metrics deletion is not supported at this point + # so its count would accumulate + assert metric_summary.count > 0 + assert metric_summary.min == 0.0 + assert metric_summary.max == 1.0 + + +def test_decorator_load_run_inside_remote_function( + sagemaker_session, dummy_container_without_error, cpu_instance_type +): + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=30, + ) + def train(): + with load_run() as run: + run.log_parameters({"p3": 3.0, "p4": 4}) + run.log_metric("test-job-load-log-metric", 0.1) + + exp_name = unique_name_from_base("my-test-exp") + run_name = "my-test-run" + tc_name = Run._generate_trial_component_name(experiment_name=exp_name, run_name=run_name) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ): + train() + + tc = _TrialComponent.load(trial_component_name=tc_name, sagemaker_session=sagemaker_session) + + 
assert tc.parameters["p3"] == 3.0 + assert tc.parameters["p4"] == 4.0 + assert len(tc.metrics) > 0 + for metric_summary in tc.metrics: + if metric_summary.metric_name != "test-job-load-log-metric": + continue + assert metric_summary.last == 0.1 + assert metric_summary.max == 0.1 + assert metric_summary.min == 0.1 + + +def test_decorator_with_nested_exp_run( + sagemaker_session, dummy_container_without_error, cpu_instance_type +): + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=30, + ) + def train(exp_name, run_name): + with Run( + experiment_name=exp_name, + run_name=run_name, + ) as run: + print(f"Experiment name: {run.experiment_name}") + print(f"Run name: {run.run_name}") + print(f"Trial component name: {run._trial_component.trial_component_name}") + + run.log_parameter("p1", 1.0) + run.log_parameter("p2", 2) + + for i in range(2): + run.log_metric("A", i) + for i in range(2): + run.log_metric("B", i) + for i in range(2): + run.log_metric("C", i) + for i in range(2): + time.sleep(0.003) + run.log_metric("D", i) + for i in range(2): + time.sleep(0.003) + run.log_metric("E", i) + time.sleep(15) + + exp_name = unique_name_from_base("my-test-exp") + run_name = "my-test-run" + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with pytest.raises( + RuntimeError, match="It is not allowed to use nested 'with' statements on the Run." 
+ ): + with Run( + experiment_name=exp_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ): + train( + exp_name=exp_name, + run_name=run_name, + ) + + +def test_decorator_function_defined_in_with_run( + sagemaker_session, dummy_container_without_error, cpu_instance_type +): + exp_name = unique_name_from_base("my-test-exp") + run_name = "my-test-run" + with Run( + experiment_name=exp_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ) as run: + + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + ) + def train(metric_1, value_1, metric_2, value_2): + run.log_parameter(metric_1, value_1) + run.log_parameter(metric_2, value_2) + + with pytest.raises(SerializationError) as e: + train("p1", 1.0, "p2", 0.5) + assert isinstance(e.value.__cause__, NotImplementedError) + + +def test_decorator_pre_execution_command( + sagemaker_session, dummy_container_without_error, cpu_instance_type +): + + random_str_1 = "".join(random.choices(string.ascii_uppercase + string.digits, k=10)) + random_str_2 = "".join(random.choices(string.ascii_uppercase + string.digits, k=10)) + random_str_3 = "".join(random.choices(string.ascii_uppercase + string.digits, k=10)) + + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + pre_execution_commands=[ + f"echo {random_str_1} > {random_str_1}", + f"echo {random_str_2} > {random_str_2}", + f"echo {random_str_3} > {random_str_3}", + f"rm ./{random_str_2}", + ], + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=60, + ) + def get_file_content(file_names): + joined_content = "" + for file_name in file_names: + if os.path.exists(file_name): + with open(f"{file_name}", "r") as f: + joined_content += f.read() + return joined_content + + assert ( + get_file_content([random_str_1, random_str_2, random_str_3]) + == random_str_1 + "\n" + random_str_3 + "\n" + ) + + +def
test_decorator_pre_execution_script( + sagemaker_session, dummy_container_without_error, cpu_instance_type +): + + pre_execution_script_path = os.path.join(DATA_DIR, "remote_function", "pre_exec_commands") + + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + pre_execution_script=pre_execution_script_path, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=60, + ) + def get_file_content(file_names): + joined_content = "" + for file_name in file_names: + if os.path.exists(file_name): + with open(f"{file_name}", "r") as f: + joined_content += f.read() + return joined_content + + assert ( + get_file_content(["test_file_1", "test_file_2", "test_file_3"]) + == "test-content-1" + "\n" + "test-content-3" + "\n" + ) + + +def test_decorator_pre_execution_command_error( + sagemaker_session, dummy_container_without_error, cpu_instance_type +): + + random_str_1 = "".join(random.choices(string.ascii_uppercase + string.digits, k=10)) + random_str_2 = "".join(random.choices(string.ascii_uppercase + string.digits, k=10)) + random_str_3 = "".join(random.choices(string.ascii_uppercase + string.digits, k=10)) + + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + pre_execution_commands=[ + f"echo {random_str_1} > {random_str_1}", + "aws sagemaker describe-training-job", + f"echo {random_str_3} > {random_str_3}", + ], + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=60, + ) + def get_file_content(file_names): + joined_content = "" + for file_name in file_names: + if os.path.exists(file_name): + with open(f"{file_name}", "r") as f: + joined_content += f.read() + return joined_content + + with pytest.raises(RuntimeEnvironmentError) as e: + get_file_content([random_str_1, random_str_2, random_str_3]) + assert "aws: error: the following arguments are required: --training-job-name" in str(e) + + +def test_decorator_pre_execution_script_error( + 
sagemaker_session, dummy_container_without_error, cpu_instance_type +): + + pre_execution_script_path = os.path.join( + DATA_DIR, "remote_function", "pre_exec_commands_bad_cmd" + ) + + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + pre_execution_script=pre_execution_script_path, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=60, + ) + def get_file_content(file_names): + joined_content = "" + for file_name in file_names: + if os.path.exists(file_name): + with open(f"{file_name}", "r") as f: + joined_content += f.read() + return joined_content + + with pytest.raises(RuntimeEnvironmentError) as e: + get_file_content(["test_file_1", "test_file_2", "test_file_3"]) + assert "line 2: bws: command not found" in str(e) diff --git a/tests/integ/sagemaker/remote_function/test_executor.py b/tests/integ/sagemaker/remote_function/test_executor.py new file mode 100644 index 0000000000..576ad5bd14 --- /dev/null +++ b/tests/integ/sagemaker/remote_function/test_executor.py @@ -0,0 +1,255 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.remote_function import RemoteExecutor +from sagemaker.remote_function.client import get_future, list_futures +from sagemaker.experiments.run import Run, load_run +from tests.integ.sagemaker.experiments.helpers import cleanup_exp_resources +from sagemaker.utils import unique_name_from_base + +ROLE = "SageMakerRole" + + +def test_executor_submit(sagemaker_session, dummy_container_without_error, cpu_instance_type): + def square(x): + return x * x + + def cube(x): + return x * x * x + + with RemoteExecutor( + max_parallel_jobs=1, + role=ROLE, + image_uri=dummy_container_without_error, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=30, + ) as e: + future_1 = e.submit(square, 10) + future_2 = e.submit(cube, 10) + + assert future_1.result() == 100 + assert future_2.result() == 1000 + + assert get_future(future_1._job.job_name, sagemaker_session).result() == 100 + assert get_future(future_2._job.job_name, sagemaker_session).result() == 1000 + + assert next( + list_futures(job_name_prefix="square", sagemaker_session=sagemaker_session) + )._job.job_name.startswith("square") + assert next( + list_futures(job_name_prefix="cube", sagemaker_session=sagemaker_session) + )._job.job_name.startswith("cube") + + +def test_executor_map(sagemaker_session, dummy_container_without_error, cpu_instance_type): + def power(a, b): + return a**b + + with RemoteExecutor( + max_parallel_jobs=1, + role=ROLE, + image_uri=dummy_container_without_error, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=30, + ) as e: + results = e.map(power, [5, 6], [2, 3]) + + assert len(results) == 2 + assert results[0] == 25 + assert results[1] == 216 + + assert next( + list_futures(job_name_prefix="power", sagemaker_session=sagemaker_session) + )._job.job_name.startswith("power") 
+ + +def test_executor_submit_with_run_inside( + sagemaker_session, dummy_container_without_error, cpu_instance_type +): + def square(x): + with load_run() as run: + result = x * x + run.log_metric("x", result) + return result + + def cube(x): + with load_run() as run: + result = x * x * x + run.log_metric("x", result) + return result + + exp_name = unique_name_from_base("my-test-exp") + run_name = "my-test-run" + tc_name = Run._generate_trial_component_name(experiment_name=exp_name, run_name=run_name) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with RemoteExecutor( + max_parallel_jobs=1, + role=ROLE, + image_uri=dummy_container_without_error, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=30, + ) as e: + with Run( + experiment_name=exp_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ): + future_1 = e.submit(square, 10) + future_2 = e.submit(cube, 10) + + assert future_1.result() == 100 + assert future_2.result() == 1000 + + assert get_future(future_1._job.job_name, sagemaker_session).result() == 100 + assert get_future(future_2._job.job_name, sagemaker_session).result() == 1000 + + assert next( + list_futures(job_name_prefix="square", sagemaker_session=sagemaker_session) + )._job.job_name.startswith("square") + assert next( + list_futures(job_name_prefix="cube", sagemaker_session=sagemaker_session) + )._job.job_name.startswith("cube") + + tc = _TrialComponent.load(trial_component_name=tc_name, sagemaker_session=sagemaker_session) + assert len(tc.metrics) > 0 + for metric_summary in tc.metrics: + if metric_summary.metric_name != "x": + continue + assert metric_summary.max == 1000 + assert metric_summary.min == 100 + assert metric_summary.avg == 550 + + +def test_executor_submit_with_run_outside( + sagemaker_session, dummy_container_without_error, cpu_instance_type +): + def square(x): + with load_run() as run: + result = x * x + 
run.log_metric("x", result) + return result + + def cube(x): + with load_run() as run: + result = x * x * x + run.log_metric("x", result) + return result + + exp_name = unique_name_from_base("my-test-exp") + run_name = "my-test-run" + tc_name = Run._generate_trial_component_name(experiment_name=exp_name, run_name=run_name) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ): + with RemoteExecutor( + max_parallel_jobs=1, + role=ROLE, + image_uri=dummy_container_without_error, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=30, + ) as e: + future_1 = e.submit(square, 10) + future_2 = e.submit(cube, 10) + + assert future_1.result() == 100 + assert future_2.result() == 1000 + + assert get_future(future_1._job.job_name, sagemaker_session).result() == 100 + assert get_future(future_2._job.job_name, sagemaker_session).result() == 1000 + + assert next( + list_futures(job_name_prefix="square", sagemaker_session=sagemaker_session) + )._job.job_name.startswith("square") + assert next( + list_futures(job_name_prefix="cube", sagemaker_session=sagemaker_session) + )._job.job_name.startswith("cube") + + tc = _TrialComponent.load(trial_component_name=tc_name, sagemaker_session=sagemaker_session) + assert len(tc.metrics) > 0 + for metric_summary in tc.metrics: + if metric_summary.metric_name != "x": + continue + assert metric_summary.max == 1000 + assert metric_summary.min == 100 + assert metric_summary.avg == 550 + + +def test_executor_map_with_run(sagemaker_session, dummy_container_without_error, cpu_instance_type): + def square(x): + with load_run() as run: + result = x * x + run.log_metric("x", result) + return result + + exp_name = unique_name_from_base("my-test-exp") + run_name = "my-test-run" + tc_name = Run._generate_trial_component_name(experiment_name=exp_name, 
run_name=run_name) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ): + with RemoteExecutor( + max_parallel_jobs=2, + role=ROLE, + image_uri=dummy_container_without_error, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=30, + ) as e: + results = e.map(square, [2, 4]) + + assert len(results) == 2 + assert results[0] == 4 + assert results[1] == 16 + + with RemoteExecutor( + max_parallel_jobs=2, + role=ROLE, + image_uri=dummy_container_without_error, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=30, + ) as e: + with Run( + experiment_name=exp_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ): + results = e.map(square, [6, 8]) + + assert len(results) == 2 + assert results[0] == 36 + assert results[1] == 64 + + tc = _TrialComponent.load(trial_component_name=tc_name, sagemaker_session=sagemaker_session) + assert len(tc.metrics) > 0 + for metric_summary in tc.metrics: + if metric_summary.metric_name != "x": + continue + assert metric_summary.max == 64 + assert metric_summary.min == 4 + assert metric_summary.avg == 30 diff --git a/tests/integ/test_s3.py b/tests/integ/test_s3.py index ef7e3bf85b..9e81336425 100644 --- a/tests/integ/test_s3.py +++ b/tests/integ/test_s3.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import +import io import os import uuid @@ -238,3 +239,34 @@ def test_s3_uploader_and_downloader_downloads_files_when_given_directory_uris_wi with open(os.path.join(TMP_BASE_PATH, my_inner_directory_uuid, file_2_name), "r") as f: assert file_2_body == f.read() + + +def test_upload_and_read_bytes(sagemaker_session, s3_files_kms_key): + my_uuid = str(uuid.uuid4()) + base_s3_uri = os.path.join( + "s3://", sagemaker_session.default_bucket(), "integ-test-test-upload-read-bytes", my_uuid + ) + + body = bytes(my_uuid, "utf-8") + + S3Uploader.upload_bytes( + body, + s3_uri=os.path.join(base_s3_uri, "from_bytes"), + kms_key=s3_files_kms_key, + sagemaker_session=sagemaker_session, + ) + + S3Uploader.upload_bytes( + io.BytesIO(body), + s3_uri=os.path.join(base_s3_uri, "from_bytes_io"), + kms_key=s3_files_kms_key, + sagemaker_session=sagemaker_session, + ) + + assert body == S3Downloader.read_bytes( + s3_uri=os.path.join(base_s3_uri, "from_bytes"), sagemaker_session=sagemaker_session + ) + + assert body == S3Downloader.read_bytes( + s3_uri=os.path.join(base_s3_uri, "from_bytes_io"), sagemaker_session=sagemaker_session + ) diff --git a/tests/unit/sagemaker/config/conftest.py b/tests/unit/sagemaker/config/conftest.py index ef51538dc9..3e1feb4adc 100644 --- a/tests/unit/sagemaker/config/conftest.py +++ b/tests/unit/sagemaker/config/conftest.py @@ -32,6 +32,11 @@ def valid_iam_role_arn(): return "arn:aws:iam::555555555555:role/IMRole" +@pytest.fixture() +def valid_tags(): + return [{"Key": "tag1", "Value": "tagValue1"}] + + @pytest.fixture() def valid_feature_group_config(valid_iam_role_arn): security_storage_config = {"KmsKeyId": "kmskeyid1"} @@ -164,6 +169,26 @@ def valid_monitoring_schedule_config(valid_iam_role_arn, valid_vpc_config): } +@pytest.fixture() +def valid_remote_function_config(valid_iam_role_arn, valid_tags, valid_vpc_config): + return { + "RemoteFunction": { + "Dependencies": "./requirements.txt", + "EnvironmentVariables": {"var1": 
"value1", "var2": "value2"}, + "ImageUri": "123456789012.dkr.ecr.us-west-2.amazonaws.com/myimage:latest", + "IncludeLocalWorkDir": True, + "InstanceType": "ml.m5.xlarge", + "JobCondaEnvironment": "some_conda_env", + "RoleArn": valid_iam_role_arn, + "S3KmsKeyId": "kmskeyid1", + "S3RootUri": "s3://my-bucket/key", + "Tags": valid_tags, + "VolumeKmsKeyId": "kmskeyid2", + "VpcConfig": valid_vpc_config, + } + } + + @pytest.fixture() def valid_config_with_all_the_scopes( valid_feature_group_config, @@ -178,6 +203,7 @@ def valid_config_with_all_the_scopes( valid_processing_job_config, valid_training_job_config, valid_edge_packaging_config, + valid_remote_function_config, ): return { "FeatureGroup": valid_feature_group_config, @@ -192,6 +218,7 @@ def valid_config_with_all_the_scopes( "ProcessingJob": valid_processing_job_config, "TrainingJob": valid_training_job_config, "EdgePackagingJob": valid_edge_packaging_config, + "PythonSDK": {"Modules": valid_remote_function_config}, } diff --git a/tests/unit/sagemaker/config/test_config_schema.py b/tests/unit/sagemaker/config/test_config_schema.py index f7170c3d3e..2efe47b8a2 100644 --- a/tests/unit/sagemaker/config/test_config_schema.py +++ b/tests/unit/sagemaker/config/test_config_schema.py @@ -95,6 +95,12 @@ def test_valid_monitoring_schedule_schema( ) +def test_valid_remote_function_schema(base_config_with_schema, valid_remote_function_config): + _validate_config( + base_config_with_schema, {"PythonSDK": {"Modules": valid_remote_function_config}} + ) + + def test_tags_with_invalid_schema(base_config_with_schema, valid_edge_packaging_config): edge_packaging_config = valid_edge_packaging_config.copy() edge_packaging_config["Tags"] = [{"Key": "somekey"}] @@ -185,3 +191,11 @@ def test_invalid_custom_parameters_schema(base_config_with_schema): config["CustomParameters"] = {"custom_key": {"custom_key": "custom_value"}} with pytest.raises(exceptions.ValidationError): validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + +def 
test_invalid_s3uri_schema(base_config_with_schema): + config = base_config_with_schema + + config["SageMaker"] = {"PythonSDK": {"Modules": {"RemoteFunction": {"S3RootUri": "bad_regex"}}}} + with pytest.raises(exceptions.ValidationError): + validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) diff --git a/tests/unit/sagemaker/experiments/helpers.py b/tests/unit/sagemaker/experiments/helpers.py index b7914010e5..0fec9f7fc3 100644 --- a/tests/unit/sagemaker/experiments/helpers.py +++ b/tests/unit/sagemaker/experiments/helpers.py @@ -18,6 +18,9 @@ TEST_EXP_NAME = "my-experiment" TEST_RUN_NAME = "my-run" +TEST_EXP_DISPLAY_NAME = "my-experiment-display-name" +TEST_RUN_DISPLAY_NAME = "my-run-display-name" +TEST_TAGS = [{"Key": "some-key", "Value": "some-value"}] def mock_tc_load_or_create_func( diff --git a/tests/unit/sagemaker/experiments/test_run.py b/tests/unit/sagemaker/experiments/test_run.py index c936ae4ddb..a6495fc914 100644 --- a/tests/unit/sagemaker/experiments/test_run.py +++ b/tests/unit/sagemaker/experiments/test_run.py @@ -14,6 +14,7 @@ import datetime import unittest +import cloudpickle from math import inf, nan from unittest.mock import patch, Mock, MagicMock @@ -48,6 +49,8 @@ mock_tc_load_or_create_func, TEST_EXP_NAME, TEST_RUN_NAME, + TEST_EXP_DISPLAY_NAME, + TEST_RUN_DISPLAY_NAME, ) @@ -340,6 +343,34 @@ def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session): client.describe_transform_job.assert_called_once_with(TransformJobName=job_name) +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch( + "sagemaker.experiments.run._TrialComponent._load_or_create", + MagicMock(side_effect=mock_tc_load_or_create_func), +) +@patch.object(_TrialComponent, "save") 
+def test_run_object_serialize_deserialize(mock_tc_save, sagemaker_session): + run_obj = Run( + experiment_name=TEST_EXP_NAME, + run_name=TEST_RUN_NAME, + experiment_display_name=TEST_EXP_DISPLAY_NAME, + run_display_name=TEST_RUN_DISPLAY_NAME, + sagemaker_session=sagemaker_session, + ) + with pytest.raises( + NotImplementedError, match="Instance of Run type is not allowed to be pickled." + ): + cloudpickle.dumps(run_obj) + + def test_log_parameter_outside_run_context(run_obj): with pytest.raises(RuntimeError) as err: run_obj.log_parameter("foo", "bar") diff --git a/tests/unit/sagemaker/image_uris/expected_uris.py b/tests/unit/sagemaker/image_uris/expected_uris.py index 2729d7db51..ce9f07ff69 100644 --- a/tests/unit/sagemaker/image_uris/expected_uris.py +++ b/tests/unit/sagemaker/image_uris/expected_uris.py @@ -78,3 +78,9 @@ def djl_framework_uri(repo, account, djl_version, primary_framework, region=REGI domain = ALTERNATE_DOMAINS.get(region, DOMAIN) tag = f"{djl_version}-{primary_framework}" return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) + + +def base_python_uri(repo, account, region=REGION): + domain = ALTERNATE_DOMAINS.get(region, DOMAIN) + tag = "1.0" + return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) diff --git a/tests/unit/sagemaker/image_uris/test_base_python.py b/tests/unit/sagemaker/image_uris/test_base_python.py new file mode 100644 index 0000000000..52ea9743bd --- /dev/null +++ b/tests/unit/sagemaker/image_uris/test_base_python.py @@ -0,0 +1,61 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import pytest +from sagemaker import image_uris +from tests.unit.sagemaker.image_uris import expected_uris + +REGISTRIES = { + "us-east-2": "429704687514", + "me-south-1": "117516905037", + "us-west-2": "236514542706", + "ca-central-1": "310906938811", + "ap-east-1": "493642496378", + "us-east-1": "081325390199", + "ap-northeast-2": "806072073708", + "eu-west-2": "712779665605", + "ap-southeast-2": "52832661640", + "cn-northwest-1": "390780980154", + "eu-north-1": "243637512696", + "cn-north-1": "390048526115", + "ap-south-1": "394103062818", + "eu-west-3": "615547856133", + "ap-southeast-3": "276181064229", + "af-south-1": "559312083959", + "eu-west-1": "470317259841", + "eu-central-1": "936697816551", + "sa-east-1": "782484402741", + "ap-northeast-3": "792733760839", + "eu-south-1": "592751261982", + "ap-northeast-1": "102112518831", + "us-west-1": "742091327244", + "ap-southeast-1": "492261229750", + "me-central-1": "103105715889", + "us-gov-east-1": "107072934176", + "us-gov-west-1": "107173498710", +} + + +@pytest.mark.parametrize("py_version", ["310", "38"]) +def test_get_base_python_image_uri(py_version): + for region in REGISTRIES.keys(): + uri = image_uris.get_base_python_image_uri( + region=region, + py_version=py_version, + ) + + repo = "sagemaker-base-python-" + py_version + expected = expected_uris.base_python_uri( + repo=repo, account=REGISTRIES[region], region=region + ) + assert expected == uri diff --git a/tests/unit/sagemaker/local/test_local_image.py b/tests/unit/sagemaker/local/test_local_image.py index da3cec026c..f7632a748d 100644 --- a/tests/unit/sagemaker/local/test_local_image.py +++ b/tests/unit/sagemaker/local/test_local_image.py @@ -300,6 +300,7 @@ def test_retrieve_artifacts(LocalSession, tmpdir): output_tar_files = [m.name for m in tar.getmembers()] for f in expected_output: assert f in 
output_tar_files + tar.close() def test_stream_output(): diff --git a/tests/unit/sagemaker/remote_function/core/test_serialization.py b/tests/unit/sagemaker/remote_function/core/test_serialization.py new file mode 100644 index 0000000000..e48137e30a --- /dev/null +++ b/tests/unit/sagemaker/remote_function/core/test_serialization.py @@ -0,0 +1,394 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os.path +import random +import string +import pytest + +from mock import patch, Mock +from sagemaker.experiments.run import Run +from sagemaker.remote_function.core.serialization import ( + serialize_func_to_s3, + deserialize_func_from_s3, + serialize_obj_to_s3, + deserialize_obj_from_s3, + serialize_exception_to_s3, + deserialize_exception_from_s3, +) +from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError +from tblib import pickling_support + +KMS_KEY = "kms-key" + + +mock_s3 = {} + + +def random_s3_uri(): + return "".join(random.choices(string.ascii_uppercase + string.digits, k=10)) + + +def upload(b, s3_uri, kms_key=None, sagemaker_session=None): + assert kms_key == KMS_KEY + mock_s3[s3_uri] = b + + +def read(s3_uri, sagemaker_session=None): + return mock_s3[s3_uri] + + +def upload_error(b, s3_uri, kms_key=None, sagemaker_session=None): + raise RuntimeError("some failure when upload_bytes") + + +def read_error(s3_uri, sagemaker_session=None): + raise 
RuntimeError("some failure when read_bytes") + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +def test_serialize_deserialize_func(): + def square(x): + return x * x + + s3_uri = random_s3_uri() + serialize_func_to_s3(func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + + del square + + deserialized = deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + + assert deserialized(3) == 9 + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +def test_serialize_deserialize_lambda(): + + s3_uri = random_s3_uri() + serialize_func_to_s3( + func=lambda x: x * x, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + ) + + deserialized = deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + + assert deserialized(3) == 9 + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +@patch("sagemaker.experiments.run._Experiment") +@patch("sagemaker.experiments.run._Trial") +@patch("sagemaker.experiments.run._TrialComponent._load_or_create", return_value=(Mock(), False)) +@patch("sagemaker.experiments.run._MetricsManager") +@patch("sagemaker.remote_function.job.Session") +def test_serialize_func_referencing_to_run(*args, **kwargs): + + with Run(experiment_name="exp_name", run_name="run_name") as run: + + def train(x): + return run.log_metric() + + s3_uri = random_s3_uri() + with pytest.raises( + SerializationError, + match="or instantiate a new Run in the function.", + ): + serialize_func_to_s3( + func=train, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + ) + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +@patch("cloudpickle.dumps") +def test_serialize_func_serialization_error(mock_cloudpickle_dumps): + 
mock_cloudpickle_dumps.side_effect = RuntimeError("some failure when dumps") + + def square(x): + return x * x + + s3_uri = random_s3_uri() + + with pytest.raises( + SerializationError, + match=r"Error when serializing object of type \[function\]: RuntimeError\('some failure when dumps'\)", + ): + serialize_func_to_s3( + func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + ) + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +@patch("cloudpickle.loads") +def test_deserialize_func_deserialization_error(mock_cloudpickle_loads): + mock_cloudpickle_loads.side_effect = RuntimeError("some failure when loads") + + def square(x): + return x * x + + s3_uri = random_s3_uri() + + serialize_func_to_s3(func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + + del square + + with pytest.raises( + DeserializationError, + match=rf"Error when deserializing bytes downloaded from {s3_uri}/payload.pkl: " + + r"RuntimeError\('some failure when loads'\)", + ): + deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +def test_deserialize_func_corrupt_metadata(): + def square(x): + return x * x + + s3_uri = random_s3_uri() + + serialize_func_to_s3(func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + mock_s3[os.path.join(s3_uri, "metadata.json")] = b"not json serializable" + + del square + + with pytest.raises(DeserializationError, match=r"Corrupt metadata file."): + deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +def test_serialize_deserialize_custom_class_data(): + class MyData: + def __init__(self, x): + self.x = x + + my_data = MyData(10) + + s3_uri = random_s3_uri() + 
serialize_obj_to_s3(my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + + del my_data + del MyData + + deserialized = deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + + assert deserialized.x == 10 + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +def test_serialize_deserialize_data_built_in_types(): + + my_data = {"a": [10]} + + s3_uri = random_s3_uri() + serialize_obj_to_s3(my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + + del my_data + + deserialized = deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + + assert deserialized == {"a": [10]} + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +def test_serialize_deserialize_none(): + + s3_uri = random_s3_uri() + serialize_obj_to_s3(None, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + + deserialized = deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + + assert deserialized is None + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +@patch("sagemaker.experiments.run._Experiment") +@patch("sagemaker.experiments.run._Trial") +@patch("sagemaker.experiments.run._TrialComponent._load_or_create", return_value=(Mock(), False)) +@patch("sagemaker.experiments.run._MetricsManager") +@patch("sagemaker.remote_function.job.Session") +def test_serialize_run(*args, **kwargs): + with Run(experiment_name="exp_name", run_name="run_name") as run: + s3_uri = random_s3_uri() + with pytest.raises( + SerializationError, + match="or instantiate a new Run in the function.", + ): + serialize_obj_to_s3( + obj=run, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + ) + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +@patch("cloudpickle.dumps") 
+def test_serialize_obj_serialization_error(mock_cloudpickle_dumps): + mock_cloudpickle_dumps.side_effect = RuntimeError("some failure when dumps") + + class MyData: + def __init__(self, x): + self.x = x + + my_data = MyData(10) + s3_uri = random_s3_uri() + + with pytest.raises( + SerializationError, + match=r"Error when serializing object of type \[MyData\]: RuntimeError\('some failure when dumps'\)", + ): + serialize_obj_to_s3( + obj=my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + ) + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +@patch("cloudpickle.loads") +def test_deserialize_obj_deserialization_error(mock_cloudpickle_loads): + mock_cloudpickle_loads.side_effect = RuntimeError("some failure when loads") + + class MyData: + def __init__(self, x): + self.x = x + + my_data = MyData(10) + s3_uri = random_s3_uri() + + serialize_obj_to_s3(obj=my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + + del my_data + del MyData + + with pytest.raises( + DeserializationError, + match=rf"Error when deserializing bytes downloaded from {s3_uri}/payload.pkl: " + + r"RuntimeError\('some failure when loads'\)", + ): + deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload_error) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read_error) +def test_serialize_deserialize_service_error(): + + my_func = lambda a: a + 10 # noqa: E731 + + s3_uri = random_s3_uri() + with pytest.raises( + ServiceError, + match=rf"Failed to upload serialized bytes to {s3_uri}/metadata.json: " + + r"RuntimeError\('some failure when upload_bytes'\)", + ): + serialize_func_to_s3( + func=my_func, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + ) + + del my_func + + with pytest.raises( + ServiceError, + match=rf"Failed to read serialized bytes from {s3_uri}/metadata.json: " + + r"RuntimeError\('some 
failure when read_bytes'\)", + ): + deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +def test_serialize_deserialize_exception_with_traceback(): + s3_uri = random_s3_uri() + + class CustomError(Exception): + ... + + def func_a(): + raise TypeError + + def func_b(): + try: + func_a() + except TypeError as first_exception: + raise CustomError("Some error") from first_exception + + try: + func_b() + except Exception as e: + pickling_support.install() + serialize_obj_to_s3(e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + + with pytest.raises(CustomError, match="Some error") as exc_info: + raise deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + assert type(exc_info.value.__cause__) is TypeError + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +def test_serialize_deserialize_custom_exception_with_traceback(): + s3_uri = random_s3_uri() + + class CustomError(Exception): + ... + + def func_a(): + raise TypeError + + def func_b(): + try: + func_a() + except TypeError as first_exception: + raise CustomError("Some error") from first_exception + + try: + func_b() + except Exception as e: + serialize_exception_to_s3(e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + + with pytest.raises(CustomError, match="Some error") as exc_info: + raise deserialize_exception_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + assert type(exc_info.value.__cause__) is TypeError + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +def test_serialize_deserialize_remote_function_error_with_traceback(): + s3_uri = random_s3_uri() + + class CustomError(Exception): + ... 
+ + def func_a(): + raise TypeError + + def func_b(): + try: + func_a() + except TypeError as first_exception: + raise ServiceError("Some error") from first_exception + + try: + func_b() + except Exception as e: + serialize_exception_to_s3(e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + + with pytest.raises(ServiceError, match="Some error") as exc_info: + raise deserialize_exception_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + assert type(exc_info.value.__cause__) is TypeError diff --git a/tests/unit/sagemaker/remote_function/core/test_stored_function.py b/tests/unit/sagemaker/remote_function/core/test_stored_function.py new file mode 100644 index 0000000000..759f06f7cf --- /dev/null +++ b/tests/unit/sagemaker/remote_function/core/test_stored_function.py @@ -0,0 +1,124 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import pytest +import random +import string +from mock import MagicMock, Mock, patch +from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.run import Run +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + +from sagemaker.remote_function.core.stored_function import StoredFunction +from sagemaker.remote_function.core.serialization import deserialize_obj_from_s3 +from sagemaker.remote_function.errors import SerializationError +from tests.unit.sagemaker.experiments.helpers import ( + TEST_EXP_DISPLAY_NAME, + TEST_EXP_NAME, + TEST_RUN_DISPLAY_NAME, + TEST_RUN_NAME, + mock_tc_load_or_create_func, + mock_trial_load_or_create_func, +) + +KMS_KEY = "kms-key" + +mock_s3 = {} + + +def random_s3_uri(): + return "".join(random.choices(string.ascii_uppercase + string.digits, k=10)) + + +def upload_bytes(b, s3_uri, kms_key=None, sagemaker_session=None): + assert kms_key == KMS_KEY + mock_s3[s3_uri] = b + + +def read_bytes(s3_uri, sagemaker_session=None): + return mock_s3[s3_uri] + + +def quadratic(x=2, *, a=1, b=0, c=0): + return a * x * x + b * x + c + + +def log_bigger(a, b, run: Run): + if a >= b: + run.log_metric("bigger", a) + else: + run.log_metric("bigger", b) + + +@pytest.mark.parametrize( + "args, kwargs", + [([], {}), ([3], {}), ([], {"a": 2, "b": 1, "c": 1})], +) +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload_bytes) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read_bytes) +@patch("sagemaker.s3.S3Uploader.upload") +@patch("sagemaker.s3.S3Downloader.download") +def test_save_and_load(s3_source_dir_download, s3_source_dir_upload, args, kwargs): + session = Mock() + s3_base_uri = random_s3_uri() + + stored_function = StoredFunction( + sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY + ) + stored_function.save(quadratic, *args, **kwargs) + stored_function.load_and_invoke() + + assert 
deserialize_obj_from_s3(session, s3_uri=f"{s3_base_uri}/results") == quadratic( + *args, **kwargs + ) + + +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch( + "sagemaker.experiments.run._TrialComponent._load_or_create", + MagicMock(side_effect=mock_tc_load_or_create_func), +) +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload_bytes) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read_bytes) +@patch.object(_TrialComponent, "save") +@patch("sagemaker.s3.S3Uploader.upload") +@patch("sagemaker.s3.S3Downloader.download") +def test_save_with_parameter_of_run_type( + s3_source_dir_download, s3_source_dir_upload, mock_tc_save +): + session = Mock() + s3_base_uri = random_s3_uri() + + run = Run( + experiment_name=TEST_EXP_NAME, + run_name=TEST_RUN_NAME, + experiment_display_name=TEST_EXP_DISPLAY_NAME, + run_display_name=TEST_RUN_DISPLAY_NAME, + sagemaker_session=session, + ) + stored_function = StoredFunction( + sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY + ) + with pytest.raises(SerializationError) as e: + stored_function.save(log_bigger, 1, 2, run) + assert isinstance(e.value.__cause__, NotImplementedError) diff --git a/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py b/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py new file mode 100644 index 0000000000..cf88775e49 --- /dev/null +++ b/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py @@ -0,0 +1,206 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock import patch, Mock +from sagemaker.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentError, +) + +import sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment as bootstrap +import pathlib + +TEST_JOB_CONDA_ENV = "conda_env" +CURR_WORKING_DIR = "/user/set/workdir" +TEST_DEPENDENCIES_PATH = "/user/set/workdir/sagemaker_remote_function_workspace" +TEST_PYTHON_VERSION = "3.10" +TEST_WORKSPACE_ARCHIVE_DIR_PATH = "/opt/ml/input/data/sm_rf_user_ws" +TEST_WORKSPACE_ARCHIVE_PATH = "/opt/ml/input/data/sm_rf_user_ws/workspace.zip" + + +def mock_args(): + args = Mock() + args.job_conda_env = TEST_JOB_CONDA_ENV + args.client_python_version = TEST_PYTHON_VERSION + return args + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._parse_agrs", + new=mock_args, +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager._validate_python_version" +) +@patch("sys.exit") +@patch("shutil.unpack_archive", Mock()) +@patch("os.getcwd", return_value=CURR_WORKING_DIR) +@patch("os.path.exists", return_value=True) +@patch("os.path.isfile", return_value=True) +@patch("os.listdir", return_value=["fileA.py", "fileB.sh", "requirements.txt"]) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." 
+ "RuntimeEnvironmentManager.run_pre_exec_script" +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager.bootstrap" +) +def test_main_success( + bootstrap_runtime, + run_pre_exec_script, + list_dir, + file_exists, + path_exists, + getcwd, + _exit_process, + validate_python, +): + bootstrap.main() + validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) + path_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_DIR_PATH) + file_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_PATH) + getcwd.assert_called() + list_dir.assert_called_once_with(pathlib.Path(TEST_DEPENDENCIES_PATH)) + run_pre_exec_script.assert_called() + bootstrap_runtime.assert_called() + _exit_process.assert_called_with(0) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._parse_agrs", + new=mock_args, +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager._validate_python_version" +) +@patch("sys.exit") +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._write_failure_reason_file" +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." 
+ "RuntimeEnvironmentManager.run_pre_exec_script" +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._bootstrap_runtime_environment" +) +def test_main_failure( + bootstrap_runtime, run_pre_exec_script, write_failure, _exit_process, validate_python +): + runtime_err = RuntimeEnvironmentError("some failure reason") + bootstrap_runtime.side_effect = runtime_err + + bootstrap.main() + + validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) + run_pre_exec_script.assert_not_called() + bootstrap_runtime.assert_called() + write_failure.assert_called_with(str(runtime_err)) + _exit_process.assert_called_with(1) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._parse_agrs", + new=mock_args, +) +@patch("sys.exit") +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager._validate_python_version" +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._write_failure_reason_file" +) +@patch("os.path.exists", return_value=False) +def test_main_channel_folder_does_not_exist( + path_exists, write_failure, validate_python, _exit_process +): + bootstrap.main() + path_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_DIR_PATH) + validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) + write_failure.assert_not_called() + _exit_process.assert_called_with(0) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._parse_agrs", + new=mock_args, +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._validate_python_version" +) +@patch("sys.exit") +@patch("os.path.exists", return_value=True) +@patch("os.path.isfile", return_value=False) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." 
+ "RuntimeEnvironmentManager.run_pre_exec_script" +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager.bootstrap" +) +def test_main_no_workspace_archive( + bootstrap_runtime, run_pre_exec_script, file_exists, path_exists, _exit_process, validate_python +): + bootstrap.main() + validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) + path_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_DIR_PATH) + file_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_PATH) + run_pre_exec_script.assert_not_called() + bootstrap_runtime.assert_not_called() + _exit_process.assert_called_with(0) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment._parse_agrs", + new=mock_args, +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager._validate_python_version" +) +@patch("sys.exit") +@patch("shutil.unpack_archive", Mock()) +@patch("os.path.exists", return_value=True) +@patch("os.path.isfile", return_value=True) +@patch("os.getcwd", return_value=CURR_WORKING_DIR) +@patch("os.listdir", return_value=["fileA.py", "fileB.sh"]) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." 
+ "RuntimeEnvironmentManager.run_pre_exec_script" +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager.RuntimeEnvironmentManager.bootstrap" +) +def test_main_no_dependency_file( + bootstrap_runtime, + run_pre_exec_script, + list_dir, + get_cwd, + file_exists, + path_exists, + _exit_process, + validate_python, +): + bootstrap.main() + validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) + path_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_DIR_PATH) + file_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_PATH) + get_cwd.assert_called_once() + list_dir.assert_called_once_with(pathlib.Path(TEST_DEPENDENCIES_PATH)) + run_pre_exec_script.assert_called() + bootstrap_runtime.assert_not_called() + _exit_process.assert_called_with(0) diff --git a/tests/unit/sagemaker/remote_function/runtime_environment/test_runtime_environment_manager.py b/tests/unit/sagemaker/remote_function/runtime_environment/test_runtime_environment_manager.py new file mode 100644 index 0000000000..516140c4da --- /dev/null +++ b/tests/unit/sagemaker/remote_function/runtime_environment/test_runtime_environment_manager.py @@ -0,0 +1,415 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import pytest +from mock import patch, Mock +import sys +import shlex +import os + +from sagemaker.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentManager, + RuntimeEnvironmentError, +) + +TEST_REQUIREMENTS_TXT = "usr/local/requirements.txt" +TEST_CONDA_YML = "usr/local/conda_env.yml" +CLIENT_PYTHON_VERSION = "3.10" + + +def test_snapshot_no_dependencies(): + response = RuntimeEnvironmentManager().snapshot(dependencies=None) + assert response is None + + +@patch("os.path.isfile", return_value=True) +def test_snapshot_with_requirements_txt(isfile): + response = RuntimeEnvironmentManager().snapshot(TEST_REQUIREMENTS_TXT) + isfile.assert_called_once_with(TEST_REQUIREMENTS_TXT) + assert response == TEST_REQUIREMENTS_TXT + + +@patch("os.path.isfile", return_value=True) +def test_snapshot_with_conda_yml(isfile): + response = RuntimeEnvironmentManager().snapshot(TEST_CONDA_YML) + isfile.assert_called_once_with(TEST_CONDA_YML) + assert response == TEST_CONDA_YML + + +@patch("os.path.isfile", return_value=False) +def test_snapshot_file_not_exists(isfile): + with pytest.raises(ValueError): + RuntimeEnvironmentManager().snapshot(TEST_REQUIREMENTS_TXT) + + isfile.assert_called_once_with(TEST_REQUIREMENTS_TXT) + + +def test_snapshot_invalid_depedencies(): + + # scenario 1: invalid file format + invalid_dependencies_file = "usr/local/requirements.py" + with pytest.raises(ValueError): + RuntimeEnvironmentManager().snapshot(invalid_dependencies_file) + + # scenario 2: invalid keyword + invalid_dependencies = "from_some_invalid_keyword" + with pytest.raises(ValueError): + RuntimeEnvironmentManager().snapshot(invalid_dependencies) + + +def test__get_conda_env_name(): + with patch("os.getenv") as getenv_patch: + getenv_patch.return_value = "some-conda-env-name" + + result = RuntimeEnvironmentManager()._get_active_conda_env_name() + + assert result == "some-conda-env-name" + call_arg = 
getenv_patch.call_args[0][0] + assert call_arg == "CONDA_DEFAULT_ENV" + + +def test__get_active_conda_env_prefix(): + with patch("os.getenv") as getenv_patch: + getenv_patch.return_value = "some-conda-prefix" + + result = RuntimeEnvironmentManager()._get_active_conda_env_prefix() + + assert result == "some-conda-prefix" + call_arg = getenv_patch.call_args[0][0] + assert call_arg == "CONDA_PREFIX" + + +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_error", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_output", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._get_conda_exe", + return_value="some_exe", +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._get_active_conda_env_name", + return_value="test_env", +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._get_active_conda_env_prefix", + return_value="/some/conda/env/prefix", +) +def test_snapshot_from_active_conda_env_when_name_available( + conda_env_prefix, conda_default_env, stub_conda_exe +): + expected_result = os.path.join(os.getcwd(), "env_snapshot.yml") + with patch("subprocess.Popen") as popen: + popen.return_value.wait.return_value = 0 + + result = RuntimeEnvironmentManager().snapshot("auto_capture") + assert result == expected_result + + call_args = popen.call_args[0][0] + assert call_args is not None + expected_cmd = ( + f"{stub_conda_exe.return_value} env export -p {conda_env_prefix.return_value} " + f"--no-builds > {expected_result}" + ) + assert call_args == expected_cmd + + +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_error", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_output", Mock() +) 
+@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._get_conda_exe", + return_value="conda", +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._get_active_conda_env_name", + return_value=None, +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._get_active_conda_env_prefix", + return_value="/some/conda/env/prefix", +) +def test_snapshot_from_active_conda_env_when_prefix_available( + conda_env_prefix, no_conda_env_name, conda_env +): + expected_result = os.path.join(os.getcwd(), "env_snapshot.yml") + with patch("subprocess.Popen") as popen: + popen.return_value.wait.return_value = 0 + + result = RuntimeEnvironmentManager().snapshot("auto_capture") + assert result == expected_result + + call_args = popen.call_args[0][0] + assert call_args is not None + expected_cmd = "{} env export -p {} --no-builds > {}".format( + conda_env.return_value, conda_env_prefix.return_value, expected_result + ) + assert call_args == expected_cmd + + +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_error", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_output", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._get_active_conda_env_name", + return_value=None, +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._get_active_conda_env_prefix", + return_value=None, +) +def test_snapshot_auto_capture_no_active_conda_env(no_conda_env_prefix, no_conda_env_name): + with pytest.raises(ValueError): + RuntimeEnvironmentManager().snapshot("auto_capture") + + +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_error", Mock() +) 
+@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_output", Mock() +) +def test_bootstrap_req_txt(): + with patch("subprocess.Popen") as popen: + popen.return_value.wait.return_value = 0 + RuntimeEnvironmentManager().bootstrap(TEST_REQUIREMENTS_TXT, CLIENT_PYTHON_VERSION) + python_exe = sys.executable + call_args = popen.call_args[0][0] + assert call_args is not None + + expected_cmd = "{} -m pip install -r {}".format(python_exe, TEST_REQUIREMENTS_TXT) + assert call_args == expected_cmd + + +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_error", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_output", Mock() +) +def test_bootstrap_req_txt_error(): + with patch("subprocess.Popen") as popen: + popen.return_value.wait.return_value = 1 + + with pytest.raises(RuntimeEnvironmentError): + RuntimeEnvironmentManager().bootstrap(TEST_REQUIREMENTS_TXT, CLIENT_PYTHON_VERSION) + + python_exe = sys.executable + call_args = popen.call_args[0][0] + assert call_args is not None + + expected_cmd = "{} -m pip install -r {}".format(python_exe, TEST_REQUIREMENTS_TXT) + assert call_args == expected_cmd + + +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_error", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_output", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._write_conda_env_to_file", + Mock(), +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._get_conda_exe", + return_value="some_exe", +) +def test_bootstrap_req_txt_with_conda_env(mock_conda_exe): + with patch("subprocess.Popen") as popen: + popen.return_value.wait.return_value = 0 + job_conda_env = "conda_env" + RuntimeEnvironmentManager().bootstrap( + 
TEST_REQUIREMENTS_TXT, CLIENT_PYTHON_VERSION, job_conda_env + ) + + call_args = popen.call_args[0][0] + assert call_args is not None + + expected_cmd = f"{mock_conda_exe.return_value} run -n conda_env pip install -r usr/local/requirements.txt" + assert call_args == expected_cmd + + +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_error", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_output", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._write_conda_env_to_file", + Mock(), +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager._validate_python_version" +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._get_conda_exe", + return_value="some_exe", +) +def test_bootstrap_conda_yml_create_env(mock_conda_exe, mock_validate_python): + with patch("subprocess.Popen") as popen: + popen.return_value.wait.return_value = 0 + + RuntimeEnvironmentManager().bootstrap(TEST_CONDA_YML, CLIENT_PYTHON_VERSION) + + call_args = popen.call_args[0][0] + assert call_args is not None + + expected_cmd = f"{mock_conda_exe.return_value} env create -n sagemaker-runtime-env --file {TEST_CONDA_YML}" + assert call_args == expected_cmd + mock_validate_python.assert_called_once_with(CLIENT_PYTHON_VERSION, "sagemaker-runtime-env") + + +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_error", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_output", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._write_conda_env_to_file", + Mock(), +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + 
".RuntimeEnvironmentManager._get_conda_exe", + return_value="conda", +) +def test_bootstrap_conda_yml_update_env(mock_conda_exe): + with patch("subprocess.Popen") as popen: + popen.return_value.wait.return_value = 0 + job_conda_env = "conda_env" + + RuntimeEnvironmentManager().bootstrap(TEST_CONDA_YML, CLIENT_PYTHON_VERSION, job_conda_env) + + call_args = popen.call_args[0][0] + assert call_args is not None + + expected_cmd = "{} env update -n {} --file {}".format( + mock_conda_exe.return_value, job_conda_env, TEST_CONDA_YML + ) + assert call_args == expected_cmd + + +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager" + ".RuntimeEnvironmentManager._get_conda_exe", + return_value="conda", +) +def test_python_version_in_conda(mock_conda_exe): + with patch("subprocess.check_output") as check_output: + check_output.return_value = b"Python 3.10.7" + + job_conda_env = "conda_env" + version = RuntimeEnvironmentManager()._python_version_in_conda_env(job_conda_env) + call_args = check_output.call_args[0][0] + assert call_args is not None + + expected_cmd = f"{mock_conda_exe.return_value} run -n {job_conda_env} python --version" + assert call_args == shlex.split(expected_cmd) + assert version == "3.10" + + +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager._python_version_in_conda_env", + return_value="3.10", +) +def test_validate_python_version(python_version_in_conda_env): + try: + RuntimeEnvironmentManager()._validate_python_version(CLIENT_PYTHON_VERSION, "conda_env") + except Exception: + pytest.fail("Unexpected error") + + +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." 
+ "RuntimeEnvironmentManager._python_version_in_conda_env", + return_value="3.9", +) +def test_validate_python_version_error(python_version_in_conda_env): + with pytest.raises(RuntimeEnvironmentError): + RuntimeEnvironmentManager()._validate_python_version(CLIENT_PYTHON_VERSION, "conda_env") + + +@patch("os.path.isfile", return_value=True) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_error", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_output", Mock() +) +def test_run_pre_exec_script(isfile): + with patch("subprocess.Popen") as popen: + popen.return_value.wait.return_value = 0 + RuntimeEnvironmentManager().run_pre_exec_script(pre_exec_script_path="path/to/pre_exec.sh") + call_args = popen.call_args[0][0] + expected_cmd = ["/bin/bash", "-eu", "path/to/pre_exec.sh"] + assert call_args == expected_cmd + + +@patch("os.path.isfile", return_value=False) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_error", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_output", Mock() +) +def test_run_pre_exec_script_no_script(isfile): + with patch("subprocess.Popen") as popen: + RuntimeEnvironmentManager().run_pre_exec_script(pre_exec_script_path="path/to/pre_exec.sh") + popen.assert_not_called() + + +@patch("os.path.isfile", return_value=True) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_error", Mock() +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager._log_output", Mock() +) +def test_run_pre_exec_script_cmd_error(isfile): + with patch("subprocess.Popen") as popen: + popen.return_value.wait.return_value = 1 + with pytest.raises(RuntimeEnvironmentError): + RuntimeEnvironmentManager().run_pre_exec_script( + pre_exec_script_path="path/to/pre_exec.sh" + ) + call_args = popen.call_args[0][0] + expected_cmd = 
["/bin/bash", "-eu", "path/to/pre_exec.sh"] + assert call_args == expected_cmd diff --git a/tests/unit/sagemaker/remote_function/test_client.py b/tests/unit/sagemaker/remote_function/test_client.py new file mode 100644 index 0000000000..6d4541a0df --- /dev/null +++ b/tests/unit/sagemaker/remote_function/test_client.py @@ -0,0 +1,1334 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import threading +import time + +import pytest +from mock import MagicMock, patch, Mock, ANY, call +from sagemaker.exceptions import UnexpectedStatusException + +from botocore.exceptions import ClientError +from sagemaker import Session +from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.run import Run +from sagemaker.remote_function.client import ( + remote, + RemoteExecutor, + Future, + get_future, + list_futures, +) +from sagemaker.remote_function.errors import DeserializationError, RemoteFunctionError, ServiceError +from sagemaker.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentError, +) +from sagemaker.remote_function.job import _RunInfo + +from tests.unit.sagemaker.experiments.helpers import ( + mock_tc_load_or_create_func, + mock_trial_load_or_create_func, +) + +TRAINING_JOB_ARN = "training-job-arn" +TRAINING_JOB_NAME = "job-name" +IMAGE = "image_uri" +BUCKET = "my-s3-bucket" +S3_URI = f"s3://{BUCKET}/keyprefix" +EXPECTED_JOB_RESULT = [1, 2, 3] 
# Fake local source directory used when tests exercise dependency upload.
PATH_TO_SRC_DIR = "path/to/src/dir"


def describe_training_job_response(job_status):
    """Build a minimal DescribeTrainingJob API response with the given status.

    Only the fields that the remote-function client reads are populated;
    ``S3OutputPath`` is where ``get_future`` looks for serialized results.
    """
    return {
        "TrainingJobName": TRAINING_JOB_NAME,
        "TrainingJobArn": TRAINING_JOB_ARN,
        "TrainingJobStatus": job_status,
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.c4.xlarge",
            "VolumeSizeInGB": 30,
        },
        "OutputDataConfig": {"S3OutputPath": "s3://sagemaker-123/image_uri/output"},
    }


# Canned describe() responses, one per job status the client distinguishes.
COMPLETED_TRAINING_JOB = describe_training_job_response("Completed")
INPROGRESS_TRAINING_JOB = describe_training_job_response("InProgress")
CANCELLED_TRAINING_JOB = describe_training_job_response("Stopped")
FAILED_TRAINING_JOB = describe_training_job_response("Failed")

# Very short submit/poll intervals so executor tests run in milliseconds;
# patched over the client module's real _API_CALL_LIMIT in each test.
API_CALL_LIMIT = {
    "SubmittingIntervalInSecs": 0.005,
    "MinBatchPollingIntervalInSecs": 0.01,
    "PollingIntervalInSecs": 0.01,
}


@pytest.fixture
def client():
    """Mock SageMaker boto client; Session's setup reads the user agent."""
    client_mock = Mock()
    client_mock._client_config.user_agent = (
        "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
    )
    return client_mock


@pytest.fixture
def sagemaker_session(client):
    """Session wired to the mock client so no real AWS calls happen."""
    return Session(
        sagemaker_client=client,
    )


@pytest.fixture
def run_obj(sagemaker_session):
    """Real ``Run`` instance with all experiment back-end calls patched out."""
    client = sagemaker_session.sagemaker_client
    client.update_trial_component.return_value = {}
    client.associate_trial_component.return_value = {}
    # The three nested patches stub out the load-or-create round trips that
    # Run's constructor would otherwise make against the Experiments service.
    with patch(
        "sagemaker.experiments.run._Experiment._load_or_create",
        MagicMock(
            return_value=_Experiment(
                experiment_name="test-exp", sagemaker_session=sagemaker_session
            )
        ),
    ):
        with patch(
            "sagemaker.experiments.run._TrialComponent._load_or_create",
            MagicMock(side_effect=mock_tc_load_or_create_func),
        ):
            with patch(
                "sagemaker.experiments.run._Trial._load_or_create",
                MagicMock(side_effect=mock_trial_load_or_create_func),
            ):
                run = Run(
                    experiment_name="test-exp",
                    sagemaker_session=sagemaker_session,
                )
                # Replace side-effecting collaborators with mocks.
                run._artifact_uploader = Mock()
                run._lineage_artifact_tracker = Mock()
                run._metrics_manager = Mock()

                return run
def create_mock_job(job_name, describe_return):
    """Return a mock _Job whose describe() yields the given response dict."""
    mock_job = Mock(job_name=job_name, s3_uri=S3_URI)
    mock_job.describe.return_value = describe_return

    return mock_job


def job_function(a, b=1, *, c, d=3):
    # Sample function with positional, defaulted and keyword-only parameters,
    # used to exercise the executor's signature validation.
    return a * b * c * d


def job_function2(a, b):
    # called with positional args only in these tests — TODO confirm intent
    # of the original "positional-only" note (signature has no `/` marker)
    return a**b


def inner_func_0():
    # Raises ZeroDivisionError when invoked.
    return 1 / 0


def inner_func_1():
    return inner_func_0()


def inner_func_2():
    raise ValueError("some value error")


@patch(
    "sagemaker.remote_function.client.serialization.deserialize_obj_from_s3",
    return_value=EXPECTED_JOB_RESULT,
)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_decorator(mock_start, mock_job_settings, mock_deserialize_obj_from_s3):
    """Happy path: completed job returns the deserialized S3 result."""
    mock_job = Mock(job_name=TRAINING_JOB_NAME)
    mock_job.describe.return_value = COMPLETED_TRAINING_JOB

    mock_start.return_value = mock_job

    @remote(image_uri=IMAGE, s3_root_uri=S3_URI)
    def square(x):
        return x * x

    result = square(5)
    assert result == EXPECTED_JOB_RESULT
    assert mock_job_settings.call_args.kwargs["image_uri"] == IMAGE


@patch(
    "sagemaker.remote_function.client.serialization.deserialize_exception_from_s3",
    return_value=ZeroDivisionError("division by zero"),
)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_decorator_underlying_job_failed_remote_error_client_function(
    mock_start, mock_job_settings, mock_deserialize_exception_from_s3
):
    """A failed job with a serialized exception in S3 re-raises that exception."""
    mock_job = Mock(job_name=TRAINING_JOB_NAME)
    mock_job.describe.return_value = FAILED_TRAINING_JOB

    mock_start.return_value = mock_job
    mock_job.wait.side_effect = UnexpectedStatusException(
        message="some message",
        allowed_statuses=["Completed", "Stopped"],
        actual_status="Failed",
    )

    @remote(image_uri=IMAGE, s3_root_uri=S3_URI)
    def square(x):
        return x * x

    with pytest.raises(ZeroDivisionError, match=r"division by zero"):
        square(5)


@patch("sagemaker.s3.S3Downloader.read_bytes")
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_decorator_underlying_job_failed_no_exception_in_s3(
    mock_start, mock_job_settings, read_bytes
):
    """A failed job with no exception object in S3 (404) surfaces RemoteFunctionError."""
    mock_job = Mock(job_name=TRAINING_JOB_NAME)
    mock_job.describe.return_value = FAILED_TRAINING_JOB
    read_bytes.side_effect = ClientError(
        error_response={"Error": {"Code": "404", "Message": "Not Found"}},
        operation_name="HeadObject",
    )

    mock_start.return_value = mock_job
    mock_job.wait.side_effect = UnexpectedStatusException(
        message="some message",
        allowed_statuses=["Completed", "Stopped"],
        actual_status="Failed",
    )

    @remote(image_uri=IMAGE, s3_root_uri=S3_URI)
    def square(x):
        return x * x

    with pytest.raises(
        RemoteFunctionError,
        match=r"Failed to execute remote function. Check corresponding job for details.",
    ):
        square(5)


@patch("sagemaker.s3.S3Downloader.read_bytes")
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_decorator_underlying_job_failed_runtime_environment_error(
    mock_start, mock_job_settings, read_bytes
):
    """A FailureReason tagged RuntimeEnvironmentError is re-raised as that type."""
    failed_training_job = FAILED_TRAINING_JOB.copy()
    failed_training_job.update(
        {"FailureReason": "RuntimeEnvironmentError: failure while installing dependencies."}
    )
    mock_job = Mock(job_name=TRAINING_JOB_NAME)
    mock_job.describe.return_value = failed_training_job
    read_bytes.side_effect = ClientError(
        error_response={"Error": {"Code": "404", "Message": "Not Found"}},
        operation_name="HeadObject",
    )

    mock_start.return_value = mock_job
    mock_job.wait.side_effect = UnexpectedStatusException(
        message="some message",
        allowed_statuses=["Completed", "Stopped"],
        actual_status="Failed",
    )

    @remote(image_uri=IMAGE, s3_root_uri=S3_URI)
    def square(x):
        return x * x

    with pytest.raises(
        RuntimeEnvironmentError,
        match=r"failure while installing dependencies.",
    ):
        square(5)


@patch("sagemaker.s3.S3Downloader.read_bytes")
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_decorator_job_failed_failure_reason_without_runtime_environment_error(
    mock_start, mock_job_settings, read_bytes
):
    """An untagged FailureReason falls back to the generic RemoteFunctionError."""
    failed_training_job = FAILED_TRAINING_JOB.copy()
    failed_training_job.update({"FailureReason": "failure while installing dependencies."})
    mock_job = Mock(job_name=TRAINING_JOB_NAME)
    mock_job.describe.return_value = failed_training_job
    read_bytes.side_effect = ClientError(
        error_response={"Error": {"Code": "404", "Message": "Not Found"}},
        operation_name="HeadObject",
    )

    mock_start.return_value = mock_job
    mock_job.wait.side_effect = UnexpectedStatusException(
        message="some message",
        allowed_statuses=["Completed", "Stopped"],
        actual_status="Failed",
    )

    @remote(image_uri=IMAGE, s3_root_uri=S3_URI)
    def square(x):
        return x * x

    with pytest.raises(
        RemoteFunctionError,
        match=r"Failed to execute remote function. Check corresponding job for details.",
    ):
        square(5)


@patch("sagemaker.s3.S3Downloader.read_bytes")
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_decorator_underlying_job_failed_local_error_service_error(
    mock_start, mock_job_settings, read_bytes
):
    """A non-ClientError failure while reading S3 is wrapped in ServiceError."""
    mock_job = Mock(job_name=TRAINING_JOB_NAME)
    mock_job.describe.return_value = FAILED_TRAINING_JOB
    re = RuntimeError("some error when reading from s3")
    read_bytes.side_effect = re

    mock_start.return_value = mock_job
    mock_job.wait.side_effect = UnexpectedStatusException(
        message="some message",
        allowed_statuses=["Completed", "Stopped"],
        actual_status="Failed",
    )

    @remote(image_uri=IMAGE, s3_root_uri=S3_URI)
    def square(x):
        return x * x

    with pytest.raises(
        ServiceError,
        match=r"Failed to read serialized bytes from .+: RuntimeError\('some error when reading from s3'\)",
    ):
        square(5)


@patch(
    "sagemaker.remote_function.client.serialization.deserialize_exception_from_s3",
    side_effect=DeserializationError("Failed to deserialize the exception."),
)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_decorator_underlying_job_failed_local_error_remote_function_error(
    mock_start, mock_job_settings, mock_deserialize_exception_from_s3
):
    """A DeserializationError while reading the remote exception propagates as-is."""
    mock_job = Mock(job_name=TRAINING_JOB_NAME)
    mock_job.describe.return_value = FAILED_TRAINING_JOB

    mock_start.return_value = mock_job
    mock_job.wait.side_effect = UnexpectedStatusException(
        message="some message",
        allowed_statuses=["Completed", "Stopped"],
        actual_status="Failed",
    )

    @remote(image_uri=IMAGE, s3_root_uri=S3_URI)
    def square(x):
        return x * x

    with pytest.raises(
        DeserializationError,
        match=r"Failed to deserialize the exception.",
    ):
        square(5)
    assert mock_job_settings.call_args.kwargs["image_uri"] == IMAGE
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_decorator_underlying_job_stopped_somehow(mock_start, mock_job_settings):
    """A job that ends up Stopped raises RemoteFunctionError (aborted)."""
    mock_job = Mock(job_name=TRAINING_JOB_NAME)
    mock_job.describe.return_value = CANCELLED_TRAINING_JOB

    mock_start.return_value = mock_job

    @remote(image_uri=IMAGE, s3_root_uri=S3_URI)
    def square(x):
        return x * x

    with pytest.raises(RemoteFunctionError, match=r"Job for remote function has been aborted."):
        square(5)


@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_decorator_underlying_job_timed_out(mock_start, mock_job_settings):
    """wait() giving up while the job is still InProgress raises TimeoutError."""
    mock_job = Mock(job_name=TRAINING_JOB_NAME)
    mock_job.describe.return_value = INPROGRESS_TRAINING_JOB

    mock_start.return_value = mock_job
    mock_job.wait.side_effect = UnexpectedStatusException(
        message="some message",
        allowed_statuses=["Completed", "Stopped"],
        actual_status="InProgress",
    )

    @remote(image_uri=IMAGE, s3_root_uri=S3_URI)
    def square(x):
        return x * x

    with pytest.raises(
        TimeoutError,
        match=r"Job for remote function timed out before reaching a termination status.",
    ):
        square(5)


@patch(
    "sagemaker.remote_function.core.serialization.deserialize_obj_from_s3",
    return_value=EXPECTED_JOB_RESULT,
)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_decorator_no_arguments(mock_start, mock_job_settings, mock_deserialize):
    """@remote used bare (no call parentheses) still works; image_uri defaults to None."""
    mock_job = Mock(job_name=TRAINING_JOB_NAME)
    mock_job.describe.return_value = COMPLETED_TRAINING_JOB

    mock_start.return_value = mock_job

    @remote
    def square(x):
        return x * x

    result = square(5)
    assert result == EXPECTED_JOB_RESULT
    assert mock_job_settings.call_args.kwargs["image_uri"] is None


@pytest.mark.parametrize(
    "args, kwargs, error_message",
    [
        (
            [1, 2, 3],
            {},
            "decorated_function() missing 2 required keyword-only arguments: 'd', and 'e'",
        ),
        ([1, 2, 3], {"d": 4}, "decorated_function() missing 1 required keyword-only argument: 'e'"),
        (
            [1, 2, 3],
            {"d": 3, "e": 4, "g": "extra_arg"},
            "decorated_function() got an unexpected keyword argument 'g'",
        ),
        (
            [],
            {"c": 3, "d": 4},
            "decorated_function() missing 2 required positional arguments: 'a', and 'b'",
        ),
        ([1], {"c": 3, "d": 4}, "decorated_function() missing 1 required positional argument: 'b'"),
        (
            [1, 2, 3, "extra_arg"],
            {"d": 3, "e": 4},
            "decorated_function() takes 3 positional arguments but 4 were given.",
        ),
        # A valid combination: no error expected.
        ([], {"a": 1, "b": 2, "d": 3, "e": 2}, None),
        (
            (1, 2),
            {"a": 1, "c": 3, "d": 2},
            "decorated_function() got multiple values for argument 'a'",
        ),
        (
            (1, 2),
            {"b": 1, "c": 3, "d": 2},
            "decorated_function() got multiple values for argument 'b'",
        ),
    ],
)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_decorator_invalid_function_args(
    mock_job_start, mock_job_settings, args, kwargs, error_message
):
    """Argument binding is validated client-side before any job is started."""
    @remote(image_uri=IMAGE, s3_root_uri=S3_URI)
    def decorated_function(a, b, c=1, *, d, e, f=3):
        return a * b * c * d * e * f

    if error_message:
        with pytest.raises(TypeError) as e:
            decorated_function(*args, **kwargs)
        assert error_message in str(e.value)
    else:
        try:
            decorated_function(*args, **kwargs)
        except Exception as ex:
            pytest.fail("Unexpected Exception: " + str(ex))


def test_executor_invalid_arguments():
    """max_parallel_jobs must be positive."""
    with pytest.raises(ValueError):
        with RemoteExecutor(max_parallel_jobs=0, s3_root_uri="s3://bucket/") as e:
            e.submit(job_function, 1, 2, c=3, d=4)


@patch("sagemaker.remote_function.client._JobSettings")
def test_executor_submit_after_shutdown(*args):
    """Submitting to an already shut-down executor raises RuntimeError."""
    with pytest.raises(RuntimeError):
        with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/") as e:
            pass
        e.submit(job_function, 1, 2, c=3, d=4)
@pytest.mark.parametrize("parallelism", [1, 2, 3, 4])
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism):
    """All submitted jobs start and complete regardless of parallelism level."""
    mock_job_1 = create_mock_job("job_1", COMPLETED_TRAINING_JOB)
    mock_job_2 = create_mock_job("job_2", COMPLETED_TRAINING_JOB)
    mock_job_3 = create_mock_job("job_3", COMPLETED_TRAINING_JOB)
    mock_job_4 = create_mock_job("job_4", COMPLETED_TRAINING_JOB)
    mock_start.side_effect = [mock_job_1, mock_job_2, mock_job_3, mock_job_4]

    with RemoteExecutor(max_parallel_jobs=parallelism, s3_root_uri="s3://bucket/") as e:
        future_1 = e.submit(job_function, 1, 2, c=3, d=4)
        future_2 = e.submit(job_function, 5, 6, c=7, d=8)
        future_3 = e.submit(job_function, 9, 10, c=11, d=12)
        future_4 = e.submit(job_function, 13, 14, c=15, d=16)

    mock_start.assert_has_calls(
        [
            call(ANY, job_function, (1, 2), {"c": 3, "d": 4}, None),
            call(ANY, job_function, (5, 6), {"c": 7, "d": 8}, None),
            call(ANY, job_function, (9, 10), {"c": 11, "d": 12}, None),
            call(ANY, job_function, (13, 14), {"c": 15, "d": 16}, None),
        ]
    )
    mock_job_1.describe.assert_called()
    mock_job_2.describe.assert_called()
    mock_job_3.describe.assert_called()
    mock_job_4.describe.assert_called()

    assert future_1.done()
    assert future_2.done()
    assert future_3.done()
    assert future_4.done()


@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj):
    """The active experiment Run is forwarded to _Job.start, whichever way the
    Run and executor contexts are nested."""
    mock_job_1 = create_mock_job("job_1", COMPLETED_TRAINING_JOB)
    mock_job_2 = create_mock_job("job_2", COMPLETED_TRAINING_JOB)
    mock_job_3 = create_mock_job("job_3", COMPLETED_TRAINING_JOB)
    mock_job_4 = create_mock_job("job_4", COMPLETED_TRAINING_JOB)
    mock_start.side_effect = [mock_job_1, mock_job_2, mock_job_3, mock_job_4]

    run_info = _RunInfo(run_obj.experiment_name, run_obj.run_name)

    # Run context outside the executor context.
    with run_obj:
        with RemoteExecutor(max_parallel_jobs=2, s3_root_uri="s3://bucket/") as e:
            future_1 = e.submit(job_function, 1, 2, c=3, d=4)
            future_2 = e.submit(job_function, 5, 6, c=7, d=8)

    mock_start.assert_has_calls(
        [
            call(ANY, job_function, (1, 2), {"c": 3, "d": 4}, run_info),
            call(ANY, job_function, (5, 6), {"c": 7, "d": 8}, run_info),
        ]
    )
    mock_job_1.describe.assert_called()
    mock_job_2.describe.assert_called()

    assert future_1.done()
    assert future_2.done()

    # Run context inside the executor context.
    with RemoteExecutor(max_parallel_jobs=2, s3_root_uri="s3://bucket/") as e:
        with run_obj:
            future_3 = e.submit(job_function, 9, 10, c=11, d=12)
            future_4 = e.submit(job_function, 13, 14, c=15, d=16)

    mock_start.assert_has_calls(
        [
            call(ANY, job_function, (9, 10), {"c": 11, "d": 12}, run_info),
            call(ANY, job_function, (13, 14), {"c": 15, "d": 16}, run_info),
        ]
    )
    mock_job_3.describe.assert_called()
    mock_job_4.describe.assert_called()

    assert future_3.done()
    assert future_4.done()


@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_executor_submit_enforcing_max_parallel_jobs(mock_start, *args):
    """With max_parallel_jobs=1 the second job waits until the first finishes."""
    mock_job_1 = create_mock_job("job_1", INPROGRESS_TRAINING_JOB)
    mock_job_2 = create_mock_job("job_2", INPROGRESS_TRAINING_JOB)
    mock_start.side_effect = [mock_job_1, mock_job_2]

    e = RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/")
    future_1 = e.submit(job_function, 1, 2, c=3, d=4)
    future_2 = e.submit(job_function, 5, 6, c=7, d=8)

    # Give the worker thread a chance to start the first job.
    time.sleep(0.02)

    assert future_1.running()
    assert not future_2.running()
    mock_start.assert_called_with(ANY, job_function, (1, 2), {"c": 3, "d": 4}, None)

    # Let both jobs complete so shutdown() can drain the queue.
    mock_job_1.describe.return_value = COMPLETED_TRAINING_JOB
    mock_job_2.describe.return_value = COMPLETED_TRAINING_JOB

    e.shutdown()

    mock_start.assert_called_with(ANY, job_function, (5, 6), {"c": 7, "d": 8}, None)
    mock_job_1.describe.assert_called()
    mock_job_2.describe.assert_called()

    assert future_1.done()
    assert future_2.done()


@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_executor_fails_to_start_job(mock_start, *args):
    """A failure starting one job does not block subsequent submissions."""
    mock_job = Mock()
    mock_job.describe.return_value = COMPLETED_TRAINING_JOB

    # First start attempt raises; second succeeds.
    mock_start.side_effect = [TypeError(), mock_job]

    with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/") as e:
        future_1 = e.submit(job_function, 1, 2, c=3, d=4)
        future_2 = e.submit(job_function, 5, 6, c=7, d=8)

    with pytest.raises(TypeError):
        future_1.result()
    assert future_2.done()


@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_executor_submit_and_cancel(mock_start, *args):
    """Cancelling a queued future prevents its job from ever starting."""
    mock_job_1 = create_mock_job("job_1", INPROGRESS_TRAINING_JOB)
    mock_job_2 = create_mock_job("job_2", INPROGRESS_TRAINING_JOB)
    mock_start.side_effect = [mock_job_1, mock_job_2]

    e = RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/")

    # submit first job and stay in progress
    future_1 = e.submit(job_function, 1, 2, c=3, d=4)

    # submit second job and cancel
    future_2 = e.submit(job_function, 5, 6, c=7, d=8)
    future_2.cancel()

    # let the first job complete
    mock_job_1.describe.return_value = COMPLETED_TRAINING_JOB
    e.shutdown()

    # Only the first job was ever started.
    mock_start.assert_called_once_with(ANY, job_function, (1, 2), {"c": 3, "d": 4}, None)
    mock_job_1.describe.assert_called()

    assert future_1.done()
    assert future_2.cancelled()


@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_executor_describe_job_throttled_temporarily(mock_start, *args):
    """Transient LimitExceededException from describe() is retried, not fatal."""
    throttling_error = ClientError(
        error_response={"Error": {"Code": "LimitExceededException"}},
        operation_name="SomeOperation",
    )
    mock_job = Mock()
    mock_job.describe.side_effect = [
        throttling_error,
        throttling_error,
        COMPLETED_TRAINING_JOB,
        COMPLETED_TRAINING_JOB,
        COMPLETED_TRAINING_JOB,
        COMPLETED_TRAINING_JOB,
    ]
    mock_start.return_value = mock_job

    with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/") as e:
        # submit first job
        future_1 = e.submit(job_function, 1, 2, c=3, d=4)
        # submit second job
        future_2 = e.submit(job_function, 5, 6, c=7, d=8)

    assert future_1.done()
    assert future_2.done()


@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_executor_describe_job_failed_permanently(mock_start, *args):
    """A persistent describe() failure propagates when futures are inspected."""
    mock_job = Mock()
    mock_job.describe.side_effect = RuntimeError()
    mock_start.return_value = mock_job

    with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/") as e:
        # submit first job
        future_1 = e.submit(job_function, 1, 2, c=3, d=4)
        # submit second job
        future_2 = e.submit(job_function, 5, 6, c=7, d=8)

    with pytest.raises(RuntimeError):
        future_1.done()
    with pytest.raises(RuntimeError):
        future_2.done()


@pytest.mark.parametrize(
    "args, kwargs, error_message",
    [
        ((1, 2), {}, "job_function() missing 1 required keyword-only argument: 'c'"),
        (
            (1, 2),
            {"c": 3, "d": 4, "e": "extra_arg"},
            "job_function() got an unexpected keyword argument 'e'",
        ),
        ((), {"c": 3, "d": 4}, "job_function() missing 1 required positional argument: 'a'"),
        (
            (1, 2, "extra_Arg"),
            {"c": 3, "d": 4},
            "job_function() takes 2 positional arguments but 3 were given.",
        ),
    ],
)
@patch("sagemaker.remote_function.client._JobSettings")
def test_executor_submit_invalid_function_args(mock_job_settings, args, kwargs, error_message):
    """submit() validates argument binding against the function signature."""
    with pytest.raises(TypeError) as e:
        with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/") as executor:
            executor.submit(job_function, *args, **kwargs)
    assert error_message in str(e.value)


@patch("sagemaker.remote_function.client._Job.start")
def test_future_cancel_before_job_starts(mock_start):
    """Cancelling before start means the job is never launched nor stopped."""
    mock_job = Mock()
    mock_start.return_value = mock_job

    future = Future()

    # cancel
    assert future.cancel()

    # try to start the job
    future._start_and_notify(Mock(), job_function, None, None)

    assert future.cancelled()
    assert not future.done()
    assert future.result() is None
    mock_job.stop.assert_not_called()


@patch("sagemaker.remote_function.client._Job.start")
def test_future_cancel_after_job_starts(mock_start):
    """Cancelling a running future stops the underlying job."""
    mock_job = Mock()
    mock_start.return_value = mock_job

    future = Future()

    # try to start the job
    future._start_and_notify(Mock(), job_function, None, None)
    assert future.running()

    # cancel
    assert future.cancel()

    assert future.cancelled()
    assert not future.done()
    assert future.result() is None
    mock_job.stop.assert_called_once()


@patch("sagemaker.remote_function.client._Job.start")
def test_future_cancel_when_job_starting(mock_start):
    """Cancel racing with a concurrent start still ends cancelled."""
    mock_job = Mock()
    mock_start.return_value = mock_job

    future = Future()

    t = threading.Thread(
        target=lambda f: f._start_and_notify(Mock(), job_function, None, None),
        args=[future],
    )
    t.start()

    future.cancel()

    t.join()

    assert future.cancelled()


@patch("sagemaker.remote_function.client._Job.start")
def test_future_cancel_after_job_fails_to_start(mock_start):
    """A future whose start failed is done and can no longer be cancelled."""
    mock_start.side_effect = TypeError()

    future = Future()

    # try to start the job
    future._start_and_notify(Mock(), job_function, None, None)
    assert future.done()

    # cancel
    assert not future.cancel()

    assert not future.cancelled()
    assert future.done()


@patch("sagemaker.remote_function.client._Job.start")
def test_future_wait_after_job_start(mock_start):
    """wait() on a running future delegates to the job's wait()."""
    mock_job = Mock()
    mock_start.return_value = mock_job

    future = Future()

    # try to start the job
    future._start_and_notify(Mock(), job_function, None, None)
    assert future.running()

    future.wait()

    mock_job.wait.assert_called_once()


@patch("sagemaker.remote_function.client._Job.start")
def test_future_wait_before_job_start(mock_start):
    """wait() before start times out without touching the job."""
    mock_job = Mock()
    mock_start.return_value = mock_job

    future = Future()

    # wait for the future to resolve until timeout
    future.wait(timeout=0.01)
    mock_job.wait.assert_not_called()

    # start the job
    future._start_and_notify(Mock(), job_function, None, None)
    assert future.running()

    future.wait()
    mock_job.wait.assert_called_once()


@patch(
    "sagemaker.remote_function.client.serialization.deserialize_obj_from_s3",
    return_value=EXPECTED_JOB_RESULT,
)
@patch("sagemaker.remote_function.client._Job.start")
def test_future_get_result_from_completed_job(mock_start, mock_deserialize):
    """result() of a completed job returns the deserialized S3 payload."""
    mock_job = Mock(job_name=TRAINING_JOB_NAME, s3_uri=S3_URI)
    mock_job.describe.return_value = COMPLETED_TRAINING_JOB

    mock_start.return_value = mock_job

    future = Future()

    # try to start the job
    future._start_and_notify(Mock(), job_function, None, None)
    assert future.running()

    result = future.result()

    assert result is EXPECTED_JOB_RESULT
    assert future.done()
    mock_job.wait.assert_called_once()


@patch(
    "sagemaker.remote_function.client.serialization.deserialize_exception_from_s3",
    return_value=ZeroDivisionError("division by zero"),
)
@patch("sagemaker.remote_function.client._Job.start")
def test_future_get_result_from_failed_job_remote_error_client_function(
    mock_start, mock_deserialize
):
    """result() of a failed job re-raises the exception serialized in S3."""
    mock_job = Mock(job_name=TRAINING_JOB_NAME, s3_uri=S3_URI)
    mock_start.return_value = mock_job
    mock_job.describe.return_value = FAILED_TRAINING_JOB

    future = Future()

    # try to start the job
    future._start_and_notify(Mock(), job_function, None, None)
    assert future.running()

    with pytest.raises(ZeroDivisionError, match=r"division by zero"):
        future.result()

    assert future.done()
    mock_job.wait.assert_called_once()
    mock_deserialize.assert_called_with(sagemaker_session=ANY, s3_uri=f"{S3_URI}/exception")


@patch("sagemaker.s3.S3Downloader.read_bytes")
@patch("sagemaker.remote_function.client._Job.start")
def test_future_get_result_from_failed_job_no_exception_in_s3(mock_start, read_bytes):
    """A failed job with no exception object in S3 raises RemoteFunctionError."""
    mock_job = Mock(job_name=TRAINING_JOB_NAME, s3_uri=S3_URI)
    mock_start.return_value = mock_job
    mock_job.describe.return_value = FAILED_TRAINING_JOB

    read_bytes.side_effect = ClientError(
        error_response={"Error": {"Code": "404", "Message": "Not Found"}},
        operation_name="HeadObject",
    )

    future = Future()

    # try to start the job
    future._start_and_notify(Mock(), job_function, None, None)
    assert future.running()

    with pytest.raises(
        RemoteFunctionError,
        match=r"Failed to execute remote function. Check corresponding job for details.",
    ):
        future.result()

    assert future.done()
    mock_job.wait.assert_called_once()


@patch("sagemaker.s3.S3Downloader.read_bytes")
@patch("sagemaker.remote_function.client._Job.start")
def test_future_get_result_from_failed_job_runtime_environment_error(mock_start, read_bytes):
    """A RuntimeEnvironmentError-tagged FailureReason is raised as that type."""
    mock_job = Mock(job_name=TRAINING_JOB_NAME, s3_uri=S3_URI)
    mock_start.return_value = mock_job
    failed_training_job = FAILED_TRAINING_JOB.copy()
    failed_training_job.update(
        {"FailureReason": "RuntimeEnvironmentError: failure while installing dependencies."}
    )
    mock_job.describe.return_value = failed_training_job

    read_bytes.side_effect = ClientError(
        error_response={"Error": {"Code": "404", "Message": "Not Found"}},
        operation_name="HeadObject",
    )

    future = Future()

    # try to start the job
    future._start_and_notify(Mock(), job_function, None, None)
    assert future.running()

    with pytest.raises(
        RuntimeEnvironmentError,
        match=r"failure while installing dependencies.",
    ):
        future.result()

    assert future.done()
    mock_job.wait.assert_called_once()


@patch("sagemaker.s3.S3Downloader.read_bytes")
@patch("sagemaker.remote_function.client._Job.start")
def test_future_get_result_from_failed_job_local_error_service_error(mock_start, read_bytes):
    """A local S3 read failure is wrapped in ServiceError."""
    mock_job = Mock(job_name=TRAINING_JOB_NAME, s3_uri=S3_URI)
    mock_start.return_value = mock_job
    mock_job.describe.return_value = FAILED_TRAINING_JOB

    read_bytes.side_effect = RuntimeError("some error when reading from s3")

    future = Future()

    # try to start the job
    future._start_and_notify(Mock(), job_function, None, None)
    assert future.running()

    with pytest.raises(
        ServiceError,
        match=r"Failed to read serialized bytes from .+: RuntimeError\('some error when reading from s3'\)",
    ):
        future.result()

    assert future.done()
    mock_job.wait.assert_called_once()


@patch(
    "sagemaker.remote_function.client.serialization.deserialize_exception_from_s3",
    side_effect=DeserializationError("Failed to deserialize the exception."),
)
@patch("sagemaker.remote_function.client._Job.start")
def test_future_get_result_from_failed_job_local_error_remote_function_error(
    mock_start, mock_deserialize
):
    """Deserialization failure of the remote exception propagates as-is."""
    mock_job = Mock(job_name=TRAINING_JOB_NAME, s3_uri=S3_URI)
    mock_start.return_value = mock_job
    mock_job.describe.return_value = FAILED_TRAINING_JOB

    future = Future()

    # try to start the job
    future._start_and_notify(Mock(), job_function, None, None)
    assert future.running()

    with pytest.raises(
        DeserializationError,
        match=r"Failed to deserialize the exception.",
    ):
        future.result()

    assert future.done()
    mock_job.wait.assert_called_once()


@patch(
    "sagemaker.remote_function.client.serialization.deserialize_obj_from_s3",
    return_value=EXPECTED_JOB_RESULT,
)
@patch("sagemaker.remote_function.client._Job.start")
def test_future_get_result_from_in_progress_job(mock_start, mock_deserialize):
    """result() on a job that never terminates raises TimeoutError; the future
    stays in the running state."""
    mock_job = Mock()
    mock_start.return_value = mock_job
    mock_job.describe.return_value = INPROGRESS_TRAINING_JOB
    mock_job.wait.side_effect = UnexpectedStatusException(
        message="some message",
        allowed_statuses=["Completed", "Stopped"],
        actual_status="InProgress",
    )

    future = Future()

    # try to start the job
    future._start_and_notify(Mock(), job_function, None, None)
    assert future.running()

    with pytest.raises(
        TimeoutError,
        match=r"Job for remote function timed out before reaching a termination status.",
    ):
        future.result()

    assert future.running()
    mock_job.wait.assert_called_once()


@patch(
    "sagemaker.remote_function.client.serialization.deserialize_obj_from_s3",
    return_value=EXPECTED_JOB_RESULT,
)
@patch("sagemaker.remote_function.client._Job.start")
def test_future_get_result_from_stopped_job(mock_start, mock_deserialize):
    """result() on a stopped job raises RemoteFunctionError; the future is
    marked cancelled."""
    mock_job = Mock()
    mock_start.return_value = mock_job
    mock_job.describe.return_value = CANCELLED_TRAINING_JOB
    mock_job.wait.side_effect = UnexpectedStatusException(
        message="some message",
        allowed_statuses=["Completed", "Stopped"],
        actual_status="Stopped",
    )

    future = Future()

    # try to start the job
    future._start_and_notify(Mock(), job_function, None, None)
    assert future.running()

    with pytest.raises(RemoteFunctionError, match=r"Job for remote function has been aborted."):
        future.result()

    assert future.cancelled()
    mock_job.wait.assert_called_once()


@patch(
    "sagemaker.remote_function.client.serialization.deserialize_obj_from_s3",
    return_value=EXPECTED_JOB_RESULT,
)
@patch("sagemaker.remote_function.client._Job.start")
def test_future_get_result_from_job_failed_to_start(mock_start, mock_deserialize):
    """result() re-raises the exception that prevented the job from starting."""
    mock_start.side_effect = TypeError()

    future = Future()

    # try to start the job
    future._start_and_notify(Mock(), job_function, None, None)
    assert future.done()

    with pytest.raises(TypeError):
        future.result()


def test_future_get_result_from_not_yet_started_job():
    """result() with a timeout on an unstarted future raises RuntimeError."""
    future = Future()

    # wait for the future to resolve until timeout
    with pytest.raises(RuntimeError):
        future.result(timeout=0.01)


@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
@patch("sagemaker.remote_function.client.serialization.deserialize_obj_from_s3")
def test_executor_map_happy_case(mock_deserialized, mock_start, mock_job_settings):
    """map() zips the iterables into per-call positional arguments and returns
    results in submission order."""
    mock_job_1 = create_mock_job("job_1", COMPLETED_TRAINING_JOB)
    mock_job_2 = create_mock_job("job_2", COMPLETED_TRAINING_JOB)
    mock_start.side_effect = [mock_job_1, mock_job_2]

    mock_deserialized.side_effect = [1, 16]

    with RemoteExecutor(max_parallel_jobs=1, s3_root_uri="s3://bucket/") as executor:
        results = executor.map(job_function2, [1, 2], [3, 4])

    mock_start.assert_has_calls(
        [
            call(ANY, job_function2, (1, 3), {}, None),
            call(ANY, job_function2, (2, 4), {}, None),
        ]
    )
    mock_job_1.describe.assert_called()
    mock_job_2.describe.assert_called()

    assert results[0] == 1
    assert results[1] == 16


@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
@patch("sagemaker.remote_function.client.serialization.deserialize_obj_from_s3")
def test_executor_map_with_run(mock_deserialized, mock_start, mock_job_settings, run_obj):
    """map() forwards the active experiment Run regardless of context nesting."""
    mock_job_1 = create_mock_job("job_1", COMPLETED_TRAINING_JOB)
    mock_job_2 = create_mock_job("job_2", COMPLETED_TRAINING_JOB)
    mock_job_3 = create_mock_job("job_3", COMPLETED_TRAINING_JOB)
    mock_job_4 = create_mock_job("job_4", COMPLETED_TRAINING_JOB)
    mock_start.side_effect = [mock_job_1, mock_job_2, mock_job_3, mock_job_4]

    mock_deserialized.side_effect = [1, 16]

    run_info = _RunInfo(run_obj.experiment_name, run_obj.run_name)

    # Run context outside the executor context.
    with run_obj:
        with RemoteExecutor(max_parallel_jobs=2, s3_root_uri="s3://bucket/") as executor:
            results_12 = executor.map(job_function2, [1, 2], [3, 4])

    mock_start.assert_has_calls(
        [
            call(ANY, job_function2, (1, 3), {}, run_info),
            call(ANY, job_function2, (2, 4), {}, run_info),
        ]
    )
    mock_job_1.describe.assert_called()
    mock_job_2.describe.assert_called()

    assert results_12[0] == 1
    assert results_12[1] == 16

    mock_deserialized.side_effect = [1, 16]

    # Run context inside the executor context.
    with RemoteExecutor(max_parallel_jobs=2, s3_root_uri="s3://bucket/") as executor:
        with run_obj:
            results_34 = executor.map(job_function2, [1, 2], [3, 4])

    mock_start.assert_has_calls(
        [
            call(ANY, job_function2, (1, 3), {}, run_info),
            call(ANY, job_function2, (2, 4), {}, run_info),
        ]
    )
    mock_job_3.describe.assert_called()
    mock_job_4.describe.assert_called()

    assert results_34[0] == 1
    assert results_34[1] == 16
@patch("sagemaker.remote_function.client.Session")
@patch("sagemaker.remote_function.client.serialization.deserialize_obj_from_s3")
def test_get_future_completed_job(mock_deserialized, mock_session):
    """get_future on a completed job yields a done future holding the deserialized result."""
    expected_result = "4.666"

    sm_client = mock_session.return_value.sagemaker_client
    sm_client.describe_training_job.return_value = COMPLETED_TRAINING_JOB
    mock_deserialized.return_value = expected_result

    future = get_future(TRAINING_JOB_NAME)

    assert future.done()
    assert future.result() == expected_result


@patch(
    "sagemaker.remote_function.client.serialization.deserialize_exception_from_s3",
    return_value=ZeroDivisionError("division by zero"),
)
@patch("sagemaker.remote_function.client.Session")
def test_get_future_failed_job(mock_session, *args):
    """A failed job's future re-raises the exception deserialized from S3."""
    sm_client = mock_session.return_value.sagemaker_client
    sm_client.describe_training_job.return_value = FAILED_TRAINING_JOB

    future = get_future(TRAINING_JOB_NAME)

    assert future.done()
    with pytest.raises(ZeroDivisionError, match=r"division by zero"):
        future.result()


@patch(
    "sagemaker.remote_function.client.serialization.deserialize_obj_from_s3",
    side_effect=DeserializationError("Failed to deserialize the results."),
)
@patch("sagemaker.remote_function.client.Session")
def test_get_future_completed_job_deserialization_error(mock_session, mock_deserialize):
    """Deserialization failures on a completed job surface as DeserializationError."""
    sm_client = mock_session.return_value.sagemaker_client
    sm_client.describe_training_job.return_value = COMPLETED_TRAINING_JOB

    future = get_future(TRAINING_JOB_NAME)

    assert future.done()
    with pytest.raises(DeserializationError, match=r"Failed to deserialize the results."):
        future.result()

    mock_deserialize.assert_called_with(
        sagemaker_session=ANY, s3_uri="s3://sagemaker-123/image_uri/output/results"
    )


@patch("sagemaker.s3.S3Downloader.read_bytes")
@patch("sagemaker.remote_function.client.Session")
def test_get_future_completed_job_s3_read_error(mock_session, read_bytes):
    """A raw S3 read failure on a completed job surfaces as ServiceError."""
    sm_client = mock_session.return_value.sagemaker_client
    sm_client.describe_training_job.return_value = COMPLETED_TRAINING_JOB
    read_bytes.side_effect = RuntimeError("some error when reading from s3")

    future = get_future(TRAINING_JOB_NAME)

    assert future.done()
    expected_match = (
        r"Failed to read serialized bytes from .+: "
        r"RuntimeError\('some error when reading from s3'\)"
    )
    with pytest.raises(ServiceError, match=expected_match):
        future.result()


@patch("sagemaker.s3.S3Downloader.read_bytes")
@patch("sagemaker.remote_function.client.Session")
def test_get_future_failed_job_S3_404_service_error(mock_session, read_bytes):
    """A missing exception object (S3 404) on a failed job maps to RemoteFunctionError."""
    sm_client = mock_session.return_value.sagemaker_client
    sm_client.describe_training_job.return_value = FAILED_TRAINING_JOB
    read_bytes.side_effect = ClientError(
        error_response={"Error": {"Code": "404", "Message": "Not Found"}},
        operation_name="HeadObject",
    )

    future = get_future(TRAINING_JOB_NAME)

    assert future.done()
    with pytest.raises(
        RemoteFunctionError,
        match=r"Failed to execute remote function. Check corresponding job for details.",
    ):
        future.result()


@patch("sagemaker.s3.S3Downloader.read_bytes")
@patch("sagemaker.remote_function.client.Session")
def test_get_future_failed_job_S3_404_runtime_environment_error(mock_session, read_bytes):
    """A runtime-environment failure reason is rehydrated as RuntimeEnvironmentError."""
    failed_training_job = FAILED_TRAINING_JOB.copy()
    failed_training_job.update(
        {"FailureReason": "RuntimeEnvironmentError: failure while installing dependencies."}
    )

    sm_client = mock_session.return_value.sagemaker_client
    sm_client.describe_training_job.return_value = failed_training_job
    read_bytes.side_effect = ClientError(
        error_response={"Error": {"Code": "404", "Message": "Not Found"}},
        operation_name="HeadObject",
    )

    future = get_future(TRAINING_JOB_NAME)

    assert future.done()
    with pytest.raises(
        RuntimeEnvironmentError,
        match=r"failure while installing dependencies.",
    ):
        future.result()


@patch("sagemaker.remote_function.client.Session")
def test_get_future_incomplete_job(mock_session):
    """A future attached to an in-progress job reports running, not done."""
    sm_client = mock_session.return_value.sagemaker_client
    sm_client.describe_training_job.return_value = INPROGRESS_TRAINING_JOB

    future = get_future(TRAINING_JOB_NAME)

    assert future.running()


@patch("sagemaker.remote_function.client.Session")
def test_list_future(mock_session):
    """list_futures paginates through training jobs and maps each to a future."""
    job_name_prefix = "foobarbaz"
    next_token = "next-token-1"

    sm_client = mock_session.return_value.sagemaker_client
    sm_client.list_training_jobs.side_effect = [
        {
            "TrainingJobSummaries": [
                {"TrainingJobName": "job-1"},
                {"TrainingJobName": "job-2"},
            ],
            "NextToken": next_token,
        },
        {"TrainingJobSummaries": [{"TrainingJobName": "job-3"}]},
    ]
    sm_client.describe_training_job.side_effect = [
        INPROGRESS_TRAINING_JOB,
        COMPLETED_TRAINING_JOB,
        FAILED_TRAINING_JOB,
    ]

    futures = list(list_futures(job_name_prefix))

    assert futures[0].running()
    assert futures[1].done()
    assert futures[2].done()

    sm_client.list_training_jobs.assert_has_calls(
        [
            call(NameContains=job_name_prefix),
            call(NameContains=job_name_prefix, NextToken=next_token),
        ]
    )
    sm_client.describe_training_job.assert_has_calls(
        [
            call(TrainingJobName="job-1"),
            call(TrainingJobName="job-2"),
            call(TrainingJobName="job-3"),
        ]
    )
+from __future__ import absolute_import + +import pytest +import errno +from mock import patch, Mock, mock_open + +from sagemaker.remote_function.errors import SerializationError, handle_error + +TEST_S3_BASE_URI = "s3://my-bucket/" +TEST_S3_KMS_KEY = "my-kms-key" + + +class _InvalidErrorNumberException(Exception): + def __init__(self, *args, **kwargs): # real signature unknown + self.errno = "invalid" + + +@pytest.fixture() +def sagemaker_session(): + return Mock() + + +@pytest.mark.parametrize( + "error, expected_exit_code, error_string", + [ + ( + SerializationError("some failure reason"), + 1, + "SerializationError('some failure reason')", + ), + ( + FileNotFoundError(errno.ENOENT, "No such file or directory"), + errno.ENOENT, + "FileNotFoundError(2, 'No such file or directory')", + ), + ( + Exception("No such file or directory"), + 1, + "Exception('No such file or directory')", + ), + ( + _InvalidErrorNumberException("No such file or directory"), + 1, + "_InvalidErrorNumberException('No such file or directory')", + ), + ], +) +@patch("sagemaker.remote_function.client.serialization.serialize_exception_to_s3") +@patch("builtins.open", new_callable=mock_open()) +@patch("os.path.exists", return_value=False) +def test_handle_error( + exists, + mock_open_file, + serialize_exception_to_s3, + sagemaker_session, + error, + expected_exit_code, + error_string, +): + err = error + exit_code = handle_error(err, sagemaker_session, TEST_S3_BASE_URI, TEST_S3_KMS_KEY) + + assert exit_code == expected_exit_code + exists.assert_called_once_with("/opt/ml/output/failure") + mock_open_file.assert_called_with("/opt/ml/output/failure", "w") + mock_open_file.return_value.__enter__().write.assert_called_with(error_string) + serialize_exception_to_s3.assert_called_with( + err, sagemaker_session, TEST_S3_BASE_URI + "exception", TEST_S3_KMS_KEY + ) diff --git a/tests/unit/sagemaker/remote_function/test_invoke_function.py b/tests/unit/sagemaker/remote_function/test_invoke_function.py new 
from __future__ import absolute_import

from mock import patch, Mock
from sagemaker.remote_function import invoke_function
from sagemaker.remote_function.errors import SerializationError

TEST_REGION = "us-west-2"
TEST_S3_BASE_URI = "s3://my-bucket/"
TEST_S3_KMS_KEY = "my-kms-key"
TEST_RUN_IN_CONTEXT = '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}'


def _make_mock_args(run_in_context=None):
    """Build a parsed-args mock with the standard test values.

    Single source of truth for the arg fixture; run_in_context is the only
    field that varies between tests.
    """
    args = Mock()
    args.region = TEST_REGION
    args.s3_base_uri = TEST_S3_BASE_URI
    args.s3_kms_key = TEST_S3_KMS_KEY
    args.run_in_context = run_in_context
    return args


def mock_args():
    return _make_mock_args()


def mock_args_with_run_in_context():
    return _make_mock_args(TEST_RUN_IN_CONTEXT)


def mock_session():
    return Mock()


# NOTE(review): "_parse_agrs" mirrors what looks like a typo in
# invoke_function itself — confirm and rename in both places together;
# fixing only the patch target here would break the tests.
@patch("sagemaker.remote_function.invoke_function._parse_agrs", new=mock_args)
@patch("sagemaker.remote_function.invoke_function._load_run_object")
@patch("sys.exit")
@patch("sagemaker.remote_function.core.stored_function.StoredFunction.load_and_invoke")
@patch(
    "sagemaker.remote_function.invoke_function._get_sagemaker_session",
    return_value=mock_session(),
)
def test_main_success(_get_sagemaker_session, load_and_invoke, _exit_process, _load_run_object):
    """main() without a run context invokes the stored function and exits 0."""
    invoke_function.main()

    _get_sagemaker_session.assert_called_with(TEST_REGION)
    load_and_invoke.assert_called()
    _load_run_object.assert_not_called()
    _exit_process.assert_called_with(0)


@patch("sagemaker.remote_function.invoke_function._parse_agrs", new=mock_args_with_run_in_context)
@patch("sagemaker.remote_function.invoke_function._load_run_object")
@patch("sys.exit")
@patch("sagemaker.remote_function.core.stored_function.StoredFunction.load_and_invoke")
@patch(
    "sagemaker.remote_function.invoke_function._get_sagemaker_session",
    return_value=mock_session(),
)
def test_main_success_with_run(
    _get_sagemaker_session, load_and_invoke, _exit_process, _load_run_object
):
    """main() with a run context additionally rehydrates the experiment Run."""
    invoke_function.main()

    _get_sagemaker_session.assert_called_with(TEST_REGION)
    load_and_invoke.assert_called()
    _load_run_object.assert_called_once_with(TEST_RUN_IN_CONTEXT, _get_sagemaker_session())
    _exit_process.assert_called_with(0)


@patch("sagemaker.remote_function.invoke_function._parse_agrs", new=mock_args)
@patch("sagemaker.remote_function.invoke_function._load_run_object")
@patch("sagemaker.remote_function.invoke_function.handle_error")
@patch("sys.exit")
@patch("sagemaker.remote_function.core.stored_function.StoredFunction.load_and_invoke")
@patch(
    "sagemaker.remote_function.invoke_function._get_sagemaker_session",
    return_value=mock_session(),
)
def test_main_failure(
    _get_sagemaker_session, load_and_invoke, _exit_process, handle_error, _load_run_object
):
    """When invocation raises, main() routes the error through handle_error and
    exits with the code handle_error returns."""
    ser_err = SerializationError("some failure reason")
    load_and_invoke.side_effect = ser_err
    handle_error.return_value = 1

    invoke_function.main()

    _get_sagemaker_session.assert_called_with(TEST_REGION)
    load_and_invoke.assert_called()
    _load_run_object.assert_not_called()
    handle_error.assert_called_with(
        ser_err, _get_sagemaker_session(), TEST_S3_BASE_URI, TEST_S3_KMS_KEY
    )
    _exit_process.assert_called_with(1)
a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py new file mode 100644 index 0000000000..fb019875ad --- /dev/null +++ b/tests/unit/sagemaker/remote_function/test_job.py @@ -0,0 +1,558 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os +import sys + +import pytest +from mock import patch, Mock, ANY + +from sagemaker.config import load_sagemaker_config +from tests.unit import DATA_DIR +from sagemaker.remote_function.job import ( + _JobSettings, + _Job, + _convert_run_to_json, + _prepare_and_upload_runtime_scripts, + _prepare_and_upload_dependencies, + _filter_non_python_files, +) + + +REGION = "us-west-2" +TRAINING_JOB_ARN = "training-job-arn" +IMAGE = "image_uri" +BUCKET = "my-s3-bucket" +S3_URI = f"s3://{BUCKET}/keyprefix" +ROLE_ARN = "my_execution_role_arn" +KMS_KEY_ARN = "kms-key-arn" +DEFAULT_ROLE_ARN = "default_execution_role_arn" +TEST_REGION = "us-west-2" +RUNTIME_SCRIPTS_CHANNEL_NAME = "sagemaker_remote_function_bootstrap" +REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws" + +EXPECTED_FUNCTION_URI = S3_URI + "/function.pkl" +EXPECTED_OUTPUT_URI = S3_URI + "/output" +EXPECTED_DEPENDENCIES_URI = S3_URI + "/additional_dependencies/requirements.txt" + +DESCRIBE_TRAINING_JOB_RESPONSE = { + "TrainingJobArn": TRAINING_JOB_ARN, + "TrainingJobStatus": "{}", + "ResourceConfig": { + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + "VolumeSizeInGB": 30, + 
}, + "OutputDataConfig": {"S3OutputPath": "s3://sagemaker-123/image_uri/output"}, +} + +TEST_EXP_NAME = "my-exp-name" +TEST_RUN_NAME = "my-run-name" +TEST_EXP_DISPLAY_NAME = "my-exp-display-name" +TEST_RUN_DISPLAY_NAME = "my-run-display-name" +TEST_TAGS = [{"Key": "some-key", "Value": "some-value"}] + + +def mock_get_current_run(): + current_run = Mock() + current_run.experiment_name = TEST_EXP_NAME + current_run.run_name = TEST_RUN_NAME + current_run.experiment_display_name = TEST_EXP_DISPLAY_NAME + current_run.run_display_name = TEST_RUN_DISPLAY_NAME + current_run.tags = TEST_TAGS + return current_run + + +def describe_training_job_response(job_status): + return { + "TrainingJobArn": TRAINING_JOB_ARN, + "TrainingJobStatus": job_status, + "ResourceConfig": { + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + "VolumeSizeInGB": 30, + }, + "OutputDataConfig": {"S3OutputPath": "s3://sagemaker-123/image_uri/output"}, + } + + +COMPLETED_TRAINING_JOB = describe_training_job_response("Completed") +INPROGRESS_TRAINING_JOB = describe_training_job_response("InProgress") +CANCELLED_TRAINING_JOB = describe_training_job_response("Stopped") +FAILED_TRAINING_JOB = describe_training_job_response("Failed") + + +def mock_session(): + session = Mock() + session.sagemaker_client.create_training_job.return_value = {"TrainingJobArn": TRAINING_JOB_ARN} + session.sagemaker_client.describe_training_job.return_value = COMPLETED_TRAINING_JOB + + session.default_bucket.return_value = BUCKET + session.expand_role.return_value = ROLE_ARN + session.boto_region_name = TEST_REGION + session.sagemaker_config = None + session._append_sagemaker_config_tags.return_value = [] + + return session + + +def job_function(a, b=1, *, c, d=3): + return a * b * c * d + + +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN) +def test_sagemaker_config_job_settings(get_execution_role, 
session): + + job_settings = _JobSettings(image_uri="image_uri", instance_type="ml.m5.xlarge") + assert job_settings.image_uri == "image_uri" + assert job_settings.s3_root_uri == f"s3://{BUCKET}" + assert job_settings.role == DEFAULT_ROLE_ARN + assert job_settings.environment_variables == {"AWS_DEFAULT_REGION": "us-west-2"} + assert job_settings.include_local_workdir is False + assert job_settings.instance_type == "ml.m5.xlarge" + + +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN) +def test_sagemaker_config_job_settings_with_configuration_file(get_execution_role, session): + config_tags = [ + {"Key": "someTagKey", "Value": "someTagValue"}, + {"Key": "someTagKey2", "Value": "someTagValue2"}, + ] + session().sagemaker_config = load_sagemaker_config( + additional_config_paths=[os.path.join(DATA_DIR, "remote_function")] + ) + session()._append_sagemaker_config_tags.return_value = config_tags + + job_settings = _JobSettings(image_uri="image_uri") + assert job_settings.image_uri == "image_uri" + assert job_settings.s3_root_uri == f"s3://{BUCKET}" + assert job_settings.role == DEFAULT_ROLE_ARN + assert job_settings.tags == config_tags + assert job_settings.vpc_config == {"Subnets": ["subnet-1234"], "SecurityGroupIds": ["sg123"]} + assert job_settings.pre_execution_commands == ["command_1", "command_2"] + assert job_settings.environment_variables == { + "AWS_DEFAULT_REGION": "us-west-2", + "EnvVarKey": "EnvVarValue", + } + assert job_settings.job_conda_env == "my_conda_env" + assert job_settings.include_local_workdir is True + assert job_settings.volume_kms_key == "someVolumeKmsKey" + assert job_settings.s3_kms_key == "someS3KmsKey" + assert job_settings.instance_type == "ml.m5.large" + assert job_settings.enable_network_isolation is False + assert job_settings.encrypt_inter_container_traffic is True + + +@patch("sagemaker.remote_function.job.Session", 
return_value=mock_session()) +@patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN) +def test_sagemaker_config_job_settings_exclusive_pre_exec_cmd_or_script( + get_execution_role, session +): + + with pytest.raises( + ValueError, + match="Only one of pre_execution_commands or pre_execution_script can be specified!", + ): + _JobSettings( + image_uri="image_uri", + instance_type="ml.m5.xlarge", + pre_execution_commands=["command_1", "command_2"], + pre_execution_script="path/to/local/script", + ) + + session().sagemaker_config = load_sagemaker_config( + additional_config_paths=[os.path.join(DATA_DIR, "remote_function")] + ) + + with pytest.raises( + ValueError, + match="Only one of pre_execution_commands or pre_execution_script can be specified!", + ): + _JobSettings( + image_uri="image_uri", + instance_type="ml.m5.xlarge", + pre_execution_script="path/to/local/script", + ) + + +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN) +def test_sagemaker_config_job_settings_missing_image_uri(get_execution_role, session): + session().sagemaker_config = load_sagemaker_config( + additional_config_paths=[os.path.join(DATA_DIR, "remote_function")] + ) + + py_version = str(sys.version_info[0]) + str(sys.version_info[1]) + if py_version not in ["310", "38"]: + with pytest.raises( + ValueError, + match="Default image is supported only for Python versions 3.8 and 3.10. 
" + "If you are using any other python version, you must provide a compatible image_uri.", + ): + _JobSettings() + else: + job_settings = _JobSettings() + assert ( + job_settings.image_uri + == f"236514542706.dkr.ecr.{TEST_REGION}.amazonaws.com/sagemaker-base-python-{py_version}:1.0" + ) + + +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN) +def test_sagemaker_config_job_settings_studio_image_uri(get_execution_role, session, monkeypatch): + monkeypatch.setenv("SAGEMAKER_INTERNAL_IMAGE_URI", "studio_image_uri") + + session().sagemaker_config = load_sagemaker_config( + additional_config_paths=[os.path.join(DATA_DIR, "remote_function")] + ) + + job_settings = _JobSettings() + assert job_settings.image_uri == "studio_image_uri" + + monkeypatch.delenv("SAGEMAKER_INTERNAL_IMAGE_URI") + + +@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("sagemaker.remote_function.job._prepare_and_upload_dependencies", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start( + session, mock_stored_function, mock_runtime_manager, mock_script_upload, mock_dependency_upload +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + assert mock_stored_function.called_once_with( + sagemaker_session=session(), 
s3_base_uri=f"{S3_URI}/{job.job_name}", s3_kms_key=None + ) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + + mock_script_upload.assert_called_once_with( + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": f"{S3_URI}/{job.job_name}/sm_rf_user_ws", + "S3DataType": "S3Prefix", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=1, + InstanceType="ml.m5.large", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + 
Environment={"AWS_DEFAULT_REGION": "us-west-2"}, + ) + + +@patch("sagemaker.remote_function.job._prepare_and_upload_dependencies", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_complete_job_settings( + session, mock_stored_function, mock_runtime_manager, mock_script_upload, mock_dependency_upload +): + + job_settings = _JobSettings( + dependencies="path/to/dependencies/req.txt", + pre_execution_script="path/to/script.sh", + environment_variables={"AWS_DEFAULT_REGION": "us-east-2"}, + image_uri=IMAGE, + s3_root_uri=S3_URI, + s3_kms_key=KMS_KEY_ARN, + role=ROLE_ARN, + instance_type="ml.m5.xlarge", + job_conda_env="conda_env", + keep_alive_period_in_seconds=120, + volume_size=120, + volume_kms_key=KMS_KEY_ARN, + subnets=["subnet"], + security_group_ids=["sg"], + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + assert mock_stored_function.called_once_with( + sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", s3_kms_key=None + ) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + + mock_script_upload.assert_called_once_with( + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=job_settings.s3_kms_key, + sagemaker_session=session(), + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=False, + pre_execution_commands=None, + pre_execution_script_local_path="path/to/script.sh", + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=job_settings.s3_kms_key, + 
sagemaker_session=session(), + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": f"{S3_URI}/{job.job_name}/sm_rf_user_ws", + "S3DataType": "S3Prefix", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}", "KmsKeyId": KMS_KEY_ARN}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--s3_kms_key", + KMS_KEY_ARN, + "--job_conda_env", + job_settings.job_conda_env, + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=120, + InstanceCount=1, + InstanceType="ml.m5.xlarge", + VolumeKmsKeyId=KMS_KEY_ARN, + KeepAlivePeriodInSeconds=120, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=False, + VpcConfig=dict(Subnets=["subnet"], SecurityGroupIds=["sg"]), + Environment={"AWS_DEFAULT_REGION": "us-west-2"}, + ) + + +@patch("sagemaker.remote_function.job._prepare_and_upload_runtime_scripts") +@patch("sagemaker.remote_function.job._prepare_and_upload_dependencies") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_describe(session, *args): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + instance_type="ml.m5.large", + ) + job = 
_Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + job.describe() + assert job.describe() == COMPLETED_TRAINING_JOB + + session().sagemaker_client.describe_training_job.assert_called_once() + + +@patch("sagemaker.remote_function.job._prepare_and_upload_runtime_scripts") +@patch("sagemaker.remote_function.job._prepare_and_upload_dependencies") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_stop(session, *args): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + instance_type="ml.m5.large", + ) + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + job.stop() + + session().sagemaker_client.stop_training_job.assert_called_once_with( + TrainingJobName=job.job_name + ) + + +@patch("sagemaker.remote_function.job._prepare_and_upload_runtime_scripts") +@patch("sagemaker.remote_function.job._prepare_and_upload_dependencies") +@patch("sagemaker.remote_function.job._logs_for_job") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_wait(session, mock_stored_function, mock_logs_for_job, *args): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + instance_type="ml.m5.large", + ) + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + job.wait(timeout=10) + + mock_logs_for_job.assert_called_with( + boto_session=ANY, job_name=job.job_name, wait=True, timeout=10 + ) + + +@patch("sagemaker.s3.S3Uploader.upload", return_value="some_uri") +@patch("shutil.copy2") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_prepare_and_upload_runtime_scripts(session, mock_copy, mock_s3_upload): + s3_path = _prepare_and_upload_runtime_scripts( + s3_base_uri=S3_URI, + 
s3_kms_key=KMS_KEY_ARN, + sagemaker_session=session(), + ) + + assert s3_path == mock_s3_upload.return_value + + assert mock_copy.call_count == 2 + mock_s3_upload.assert_called_once() + + +@patch("sagemaker.s3.S3Uploader.upload", return_value="some_uri") +@patch("shutil.copy2") +@patch("shutil.copytree") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_prepare_and_upload_dependencies(session, mock_copytree, mock_copy, mock_s3_upload): + s3_path = _prepare_and_upload_dependencies( + local_dependencies_path="some/path/to/dependency", + include_local_workdir=True, + pre_execution_commands=["cmd_1", "cmd_2"], + pre_execution_script_local_path=None, + s3_base_uri=S3_URI, + s3_kms_key=KMS_KEY_ARN, + sagemaker_session=session, + ) + + assert s3_path == mock_s3_upload.return_value + + mock_copytree.assert_called_with(os.getcwd(), ANY, ignore=_filter_non_python_files) + mock_copy.assert_called_with("some/path/to/dependency", ANY) + mock_s3_upload.assert_called_once_with( + ANY, S3_URI + "/" + REMOTE_FUNCTION_WORKSPACE, KMS_KEY_ARN, session + ) + + +def test_convert_run_to_json(): + run = Mock() + run.experiment_name = TEST_EXP_NAME + run.run_name = TEST_RUN_NAME + + assert _convert_run_to_json(run) == ( + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}' + ) diff --git a/tests/unit/sagemaker/remote_function/test_logging_config.py b/tests/unit/sagemaker/remote_function/test_logging_config.py new file mode 100644 index 0000000000..f3b65e005a --- /dev/null +++ b/tests/unit/sagemaker/remote_function/test_logging_config.py @@ -0,0 +1,28 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import logging + +from sagemaker.remote_function.logging_config import get_logger + + +def test_logger_config(): + logging.basicConfig(level=logging.INFO) + + logger_1 = get_logger() + assert len(logger_1.handlers) == 1 + + logger_2 = get_logger() + assert logger_2 is logger_1 + assert len(logger_2.handlers) == 1 diff --git a/tests/unit/test_exception_on_bad_status.py b/tests/unit/test_exception_on_bad_status.py index 471cb3b9b6..2ef017efd3 100644 --- a/tests/unit/test_exception_on_bad_status.py +++ b/tests/unit/test_exception_on_bad_status.py @@ -15,6 +15,7 @@ import pytest from mock import Mock, MagicMock import sagemaker +from sagemaker.session import _check_job_status EXPANDED_ROLE = "arn:aws:iam::111111111111:role/ExpandedRole" REGION = "us-west-2" @@ -59,10 +60,7 @@ def test_raise_when_failed_created_package(): def test_does_not_raise_when_correct_job_status(): try: job = Mock() - sagemaker_session = get_sagemaker_session(returns_status="Stopped") - sagemaker_session._check_job_status( - job, {"TransformationJobStatus": "Stopped"}, "TransformationJobStatus" - ) + _check_job_status(job, {"TransformationJobStatus": "Stopped"}, "TransformationJobStatus") except sagemaker.exceptions.UnexpectedStatusException: pytest.fail("UnexpectedStatusException was thrown while it should not") @@ -70,10 +68,7 @@ def test_does_not_raise_when_correct_job_status(): def test_does_raise_when_incorrect_job_status(): try: job = Mock() - sagemaker_session = get_sagemaker_session(returns_status="Failed") - sagemaker_session._check_job_status( - job, {"TransformationJobStatus": "Failed"}, "TransformationJobStatus" - ) + _check_job_status(job, {"TransformationJobStatus": "Failed"}, "TransformationJobStatus") assert ( False 
), "sagemaker.exceptions.UnexpectedStatusException should have been raised but was not" @@ -87,8 +82,7 @@ def test_does_raise_when_incorrect_job_status(): def test_does_raise_capacity_error_when_incorrect_job_status(): try: job = Mock() - sagemaker_session = get_sagemaker_session(returns_status="Failed") - sagemaker_session._check_job_status( + _check_job_status( job, { "TransformationJobStatus": "Failed", diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 9c3f38572f..20d8c675f2 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1911,11 +1911,20 @@ def __init__(self, code): @pytest.fixture() -def sagemaker_session_complete(): +def boto_session_complete(): boto_mock = MagicMock(name="boto_session") boto_mock.client("logs").describe_log_streams.return_value = DEFAULT_LOG_STREAMS boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS - ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=MagicMock()) + boto_mock.client("sagemaker").describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT + boto_mock.client( + "sagemaker" + ).describe_transform_job.return_value = COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT + return boto_mock + + +@pytest.fixture() +def sagemaker_session_complete(boto_session_complete): + ims = sagemaker.Session(boto_session=boto_session_complete, sagemaker_client=MagicMock()) ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT ims.sagemaker_client.describe_transform_job.return_value = ( COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT @@ -1924,22 +1933,46 @@ def sagemaker_session_complete(): @pytest.fixture() -def sagemaker_session_stopped(): +def boto_session_stopped(): boto_mock = MagicMock(name="boto_session") boto_mock.client("logs").describe_log_streams.return_value = DEFAULT_LOG_STREAMS boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS - ims = sagemaker.Session(boto_session=boto_mock, 
sagemaker_client=MagicMock()) + boto_mock.client("sagemaker").describe_training_job.return_value = STOPPED_DESCRIBE_JOB_RESULT + boto_mock.client( + "sagemaker" + ).describe_transform_job.return_value = STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT + return boto_mock + + +@pytest.fixture() +def sagemaker_session_stopped(boto_session_stopped): + ims = sagemaker.Session(boto_session=boto_session_stopped, sagemaker_client=MagicMock()) ims.sagemaker_client.describe_training_job.return_value = STOPPED_DESCRIBE_JOB_RESULT ims.sagemaker_client.describe_transform_job.return_value = STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT return ims @pytest.fixture() -def sagemaker_session_ready_lifecycle(): +def boto_session_ready_lifecycle(): boto_mock = MagicMock(name="boto_session") boto_mock.client("logs").describe_log_streams.return_value = DEFAULT_LOG_STREAMS boto_mock.client("logs").get_log_events.side_effect = STREAM_LOG_EVENTS - ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=MagicMock()) + boto_mock.client("sagemaker").describe_training_job.side_effect = [ + IN_PROGRESS_DESCRIBE_JOB_RESULT, + IN_PROGRESS_DESCRIBE_JOB_RESULT, + COMPLETED_DESCRIBE_JOB_RESULT, + ] + boto_mock.client("sagemaker").describe_transform_job.side_effect = [ + IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT, + IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT, + COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT, + ] + return boto_mock + + +@pytest.fixture() +def sagemaker_session_ready_lifecycle(boto_session_ready_lifecycle): + ims = sagemaker.Session(boto_session=boto_session_ready_lifecycle, sagemaker_client=MagicMock()) ims.sagemaker_client.describe_training_job.side_effect = [ IN_PROGRESS_DESCRIBE_JOB_RESULT, IN_PROGRESS_DESCRIBE_JOB_RESULT, @@ -1954,11 +1987,26 @@ def sagemaker_session_ready_lifecycle(): @pytest.fixture() -def sagemaker_session_full_lifecycle(): +def boto_session_full_lifecycle(): boto_mock = MagicMock(name="boto_session") boto_mock.client("logs").describe_log_streams.side_effect = 
LIFECYCLE_LOG_STREAMS boto_mock.client("logs").get_log_events.side_effect = STREAM_LOG_EVENTS - ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=MagicMock()) + boto_mock.client("sagemaker").describe_training_job.side_effect = [ + IN_PROGRESS_DESCRIBE_JOB_RESULT, + IN_PROGRESS_DESCRIBE_JOB_RESULT, + COMPLETED_DESCRIBE_JOB_RESULT, + ] + boto_mock.client("sagemaker").describe_transform_job.side_effect = [ + IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT, + IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT, + COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT, + ] + return boto_mock + + +@pytest.fixture() +def sagemaker_session_full_lifecycle(boto_session_full_lifecycle): + ims = sagemaker.Session(boto_session=boto_session_full_lifecycle, sagemaker_client=MagicMock()) + ims.sagemaker_client.describe_training_job.side_effect = [ + IN_PROGRESS_DESCRIBE_JOB_RESULT, + IN_PROGRESS_DESCRIBE_JOB_RESULT, @@ -1976,7 +2024,9 @@ def test_logs_for_job_no_wait(cw, sagemaker_session_complete): def test_logs_for_job_no_wait(cw, sagemaker_session_complete): ims = sagemaker_session_complete ims.logs_for_job(JOB_NAME) - ims.sagemaker_client.describe_training_job.assert_called_once_with(TrainingJobName=JOB_NAME) + ims.boto_session.client.return_value.describe_training_job.assert_called_once_with( + TrainingJobName=JOB_NAME + ) cw().assert_called_with(0, "hi there #1") @@ -1984,7 +2034,9 @@ def test_logs_for_job_no_wait(cw, sagemaker_session_complete): def test_logs_for_job_no_wait_stopped_job(cw, sagemaker_session_stopped): ims = sagemaker_session_stopped ims.logs_for_job(JOB_NAME) - ims.sagemaker_client.describe_training_job.assert_called_once_with(TrainingJobName=JOB_NAME) + ims.boto_session.client.return_value.describe_training_job.assert_called_once_with( + TrainingJobName=JOB_NAME + ) cw().assert_called_with(0, "hi there #1") @@ -1992,7 +2044,7 @@ def test_logs_for_job_no_wait_stopped_job(cw, sagemaker_session_stopped): def test_logs_for_job_wait_on_completed(cw, sagemaker_session_complete): ims =
sagemaker_session_complete ims.logs_for_job(JOB_NAME, wait=True, poll=0) - assert ims.sagemaker_client.describe_training_job.call_args_list == [ + assert ims.boto_session.client.return_value.describe_training_job.call_args_list == [ call(TrainingJobName=JOB_NAME) ] cw().assert_called_with(0, "hi there #1") @@ -2002,7 +2054,7 @@ def test_logs_for_job_wait_on_completed(cw, sagemaker_session_complete): def test_logs_for_job_wait_on_stopped(cw, sagemaker_session_stopped): ims = sagemaker_session_stopped ims.logs_for_job(JOB_NAME, wait=True, poll=0) - assert ims.sagemaker_client.describe_training_job.call_args_list == [ + assert ims.boto_session.client.return_value.describe_training_job.call_args_list == [ call(TrainingJobName=JOB_NAME) ] cw().assert_called_with(0, "hi there #1") @@ -2012,7 +2064,7 @@ def test_logs_for_job_wait_on_stopped(cw, sagemaker_session_stopped): def test_logs_for_job_no_wait_on_running(cw, sagemaker_session_ready_lifecycle): ims = sagemaker_session_ready_lifecycle ims.logs_for_job(JOB_NAME) - assert ims.sagemaker_client.describe_training_job.call_args_list == [ + assert ims.boto_session.client.return_value.describe_training_job.call_args_list == [ call(TrainingJobName=JOB_NAME) ] cw().assert_called_with(0, "hi there #1") @@ -2024,7 +2076,7 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle) ims = sagemaker_session_full_lifecycle ims.logs_for_job(JOB_NAME, wait=True, poll=0) assert ( - ims.sagemaker_client.describe_training_job.call_args_list + ims.boto_session.client.return_value.describe_training_job.call_args_list == [call(TrainingJobName=JOB_NAME)] * 3 ) assert cw().call_args_list == [