From 4d59690a7e9acf3183326fbb233ceaadae5dfaf5 Mon Sep 17 00:00:00 2001 From: Rohan Gujarathi Date: Thu, 11 May 2023 15:19:14 -0700 Subject: [PATCH] fix: perform integrity checks for remote function execution --- src/sagemaker/remote_function/client.py | 6 + .../remote_function/core/serialization.py | 152 ++++++++++++++---- .../remote_function/core/stored_function.py | 42 +++-- src/sagemaker/remote_function/errors.py | 9 +- .../remote_function/invoke_function.py | 28 +++- src/sagemaker/remote_function/job.py | 17 +- .../remote_function/test_decorator.py | 1 + .../core/test_serialization.py | 136 ++++++++++++---- .../core/test_stored_function.py | 11 +- .../sagemaker/remote_function/test_client.py | 12 +- .../sagemaker/remote_function/test_errors.py | 15 +- .../remote_function/test_invoke_function.py | 11 +- .../sagemaker/remote_function/test_job.py | 45 ++++-- 13 files changed, 378 insertions(+), 107 deletions(-) diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py index ecfa67533b..93a40c4114 100644 --- a/src/sagemaker/remote_function/client.py +++ b/src/sagemaker/remote_function/client.py @@ -301,6 +301,7 @@ def wrapper(*args, **kwargs): s3_uri=s3_path_join( job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER ), + hmac_key=job.hmac_key, ) except ServiceError as serr: chained_e = serr.__cause__ @@ -337,6 +338,7 @@ def wrapper(*args, **kwargs): return serialization.deserialize_obj_from_s3( sagemaker_session=job_settings.sagemaker_session, s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER), + hmac_key=job.hmac_key, ) if job.describe()["TrainingJobStatus"] == "Stopped": @@ -861,6 +863,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session): job_return = serialization.deserialize_obj_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER), + hmac_key=job.hmac_key, ) except DeserializationError as e: client_exception = e @@ -872,6 
+875,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session): job_exception = serialization.deserialize_exception_from_s3( sagemaker_session=sagemaker_session, s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER), + hmac_key=job.hmac_key, ) except ServiceError as serr: chained_e = serr.__cause__ @@ -961,6 +965,7 @@ def result(self, timeout: float = None) -> Any: self._return = serialization.deserialize_obj_from_s3( sagemaker_session=self._job.sagemaker_session, s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER), + hmac_key=self._job.hmac_key, ) self._state = _FINISHED return self._return @@ -969,6 +974,7 @@ def result(self, timeout: float = None) -> Any: self._exception = serialization.deserialize_exception_from_s3( sagemaker_session=self._job.sagemaker_session, s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER), + hmac_key=self._job.hmac_key, ) except ServiceError as serr: chained_e = serr.__cause__ diff --git a/src/sagemaker/remote_function/core/serialization.py b/src/sagemaker/remote_function/core/serialization.py index 29b7f18bb1..989da71df9 100644 --- a/src/sagemaker/remote_function/core/serialization.py +++ b/src/sagemaker/remote_function/core/serialization.py @@ -17,12 +17,16 @@ import json import os import sys +import hmac +import hashlib import cloudpickle from typing import Any, Callable from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError from sagemaker.s3 import S3Downloader, S3Uploader +from sagemaker.session import Session + from tblib import pickling_support @@ -34,6 +38,7 @@ def _get_python_version(): class _MetaData: """Metadata about the serialized data or functions.""" + sha256_hash: str version: str = "2023-04-24" python_version: str = _get_python_version() serialization_module: str = "cloudpickle" @@ -48,11 +53,17 @@ def from_json(s): except json.decoder.JSONDecodeError: raise DeserializationError("Corrupt metadata file. 
It is not a valid json file.") - metadata = _MetaData() + sha256_hash = obj.get("sha256_hash") + metadata = _MetaData(sha256_hash=sha256_hash) metadata.version = obj.get("version") metadata.python_version = obj.get("python_version") metadata.serialization_module = obj.get("serialization_module") + if not sha256_hash: + raise DeserializationError( + "Corrupt metadata file. SHA256 hash for the serialized data does not exist" + ) + if not ( metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle" ): @@ -67,20 +78,16 @@ class CloudpickleSerializer: """Serializer using cloudpickle.""" @staticmethod - def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None): + def serialize(obj: Any) -> Any: """Serializes data object and uploads it to S3. Args: - sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service - calls are delegated to. - s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. - s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. obj: object to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. """ try: - bytes_to_upload = cloudpickle.dumps(obj) + return cloudpickle.dumps(obj) except Exception as e: if isinstance( e, NotImplementedError @@ -96,10 +103,8 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None): "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e)) ) from e - _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session) - @staticmethod - def deserialize(sagemaker_session, s3_uri) -> Any: + def deserialize(s3_uri: str, bytes_to_deserialize) -> Any: """Downloads from S3 and then deserializes data objects. Args: @@ -111,7 +116,6 @@ def deserialize(sagemaker_session, s3_uri) -> Any: Raises: DeserializationError: when fail to serialize object to bytes. 
""" - bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session) try: return cloudpickle.loads(bytes_to_deserialize) @@ -122,28 +126,39 @@ def deserialize(sagemaker_session, s3_uri) -> Any: # TODO: use dask serializer in case dask distributed is installed in users' environment. -def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None): +def serialize_func_to_s3( + func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None +): """Serializes function and uploads it to S3. Args: sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. func: function to be serialized and persisted Raises: SerializationError: when fail to serialize function to bytes. """ + bytes_to_upload = CloudpickleSerializer.serialize(func) + _upload_bytes_to_s3( - _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session + bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session ) - CloudpickleSerializer.serialize( - func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key + + sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key) + + _upload_bytes_to_s3( + _MetaData(sha256_hash).to_json(), + os.path.join(s3_uri, "metadata.json"), + s3_kms_key, + sagemaker_session, ) -def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable: +def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable: """Downloads from S3 and then deserializes data objects. 
This method downloads the serialized training job outputs to a temporary directory and @@ -153,19 +168,32 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable: sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func. Returns : The deserialized function. Raises: DeserializationError: when fail to serialize function to bytes. """ - _MetaData.from_json( + metadata = _MetaData.from_json( _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session) ) - return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl")) + bytes_to_deserialize = _read_bytes_from_s3( + os.path.join(s3_uri, "payload.pkl"), sagemaker_session + ) + + _perform_integrity_check( + expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + ) + + return CloudpickleSerializer.deserialize( + os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize + ) -def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None): +def serialize_obj_to_s3( + obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None +): """Serializes data object and uploads it to S3. Args: @@ -173,41 +201,61 @@ def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: st calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj. obj: object to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. 
""" + bytes_to_upload = CloudpickleSerializer.serialize(obj) + _upload_bytes_to_s3( - _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session + bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session ) - CloudpickleSerializer.serialize( - obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key + + sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key) + + _upload_bytes_to_s3( + _MetaData(sha256_hash).to_json(), + os.path.join(s3_uri, "metadata.json"), + s3_kms_key, + sagemaker_session, ) -def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any: +def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any: """Downloads from S3 and then deserializes data objects. Args: sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj. Returns : Deserialized python objects. Raises: DeserializationError: when fail to serialize object to bytes. 
""" - _MetaData.from_json( + metadata = _MetaData.from_json( _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session) ) - return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl")) + bytes_to_deserialize = _read_bytes_from_s3( + os.path.join(s3_uri, "payload.pkl"), sagemaker_session + ) + + _perform_integrity_check( + expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + ) + + return CloudpickleSerializer.deserialize( + os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize + ) def serialize_exception_to_s3( - exc: Exception, sagemaker_session, s3_uri: str, s3_kms_key: str = None + exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None ): """Serializes exception with traceback and uploads it to S3. @@ -216,37 +264,58 @@ def serialize_exception_to_s3( calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception. exc: Exception to be serialized and persisted Raises: SerializationError: when fail to serialize object to bytes. 
""" pickling_support.install() + + bytes_to_upload = CloudpickleSerializer.serialize(exc) + _upload_bytes_to_s3( - _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session + bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session ) - CloudpickleSerializer.serialize( - exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key + + sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key) + + _upload_bytes_to_s3( + _MetaData(sha256_hash).to_json(), + os.path.join(s3_uri, "metadata.json"), + s3_kms_key, + sagemaker_session, ) -def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any: +def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any: """Downloads from S3 and then deserializes exception. Args: sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service calls are delegated to. s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception. Returns : Deserialized exception with traceback. Raises: DeserializationError: when fail to serialize object to bytes. 
""" - _MetaData.from_json( + metadata = _MetaData.from_json( _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session) ) - return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl")) + bytes_to_deserialize = _read_bytes_from_s3( + os.path.join(s3_uri, "payload.pkl"), sagemaker_session + ) + + _perform_integrity_check( + expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize + ) + + return CloudpickleSerializer.deserialize( + os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize + ) def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session): @@ -269,3 +338,22 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session): raise ServiceError( "Failed to read serialized bytes from {}: {}".format(s3_uri, repr(e)) ) from e + + +def _compute_hash(buffer: bytes, secret_key: str) -> str: + """Compute the hmac-sha256 hash""" + return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest() + + +def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes): + """Performs integrity checks for serialized code/arguments uploaded to s3. + + Verifies whether the hash read from s3 matches the hash calculated + during remote function execution. + """ + actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key) + if not hmac.compare_digest(expected_hash_value, actual_hash_value): + raise DeserializationError( + "Integrity check for the serialized function or data failed. 
" + "Please restrict access to your S3 bucket" + ) diff --git a/src/sagemaker/remote_function/core/stored_function.py b/src/sagemaker/remote_function/core/stored_function.py index 0204cf3e51..7c3b0d2949 100644 --- a/src/sagemaker/remote_function/core/stored_function.py +++ b/src/sagemaker/remote_function/core/stored_function.py @@ -17,6 +17,7 @@ from sagemaker.remote_function import logging_config import sagemaker.remote_function.core.serialization as serialization +from sagemaker.session import Session logger = logging_config.get_logger() @@ -31,7 +32,9 @@ class StoredFunction: """Class representing a remote function stored in S3.""" - def __init__(self, sagemaker_session, s3_base_uri, s3_kms_key=None): + def __init__( + self, sagemaker_session: Session, s3_base_uri: str, hmac_key: str, s3_kms_key: str = None + ): """Construct a StoredFunction object. Args: @@ -39,10 +42,12 @@ def __init__(self, sagemaker_session, s3_base_uri, s3_kms_key=None): AWS service calls are delegated to. s3_base_uri: the base uri to which serialized artifacts will be uploaded. s3_kms_key: KMS key used to encrypt artifacts uploaded to S3. + hmac_key: Key used to calculate hmac hash of the serialized function and arguments """ self.sagemaker_session = sagemaker_session self.s3_base_uri = s3_base_uri self.s3_kms_key = s3_kms_key + self.hmac_key = hmac_key def save(self, func, *args, **kwargs): """Serialize and persist the function and arguments. 
@@ -58,20 +63,22 @@ def save(self, func, *args, **kwargs): f"Serializing function code to {s3_path_join(self.s3_base_uri, FUNCTION_FOLDER)}" ) serialization.serialize_func_to_s3( - func, - self.sagemaker_session, - s3_path_join(self.s3_base_uri, FUNCTION_FOLDER), - self.s3_kms_key, + func=func, + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join(self.s3_base_uri, FUNCTION_FOLDER), + s3_kms_key=self.s3_kms_key, + hmac_key=self.hmac_key, ) logger.info( f"Serializing function arguments to {s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER)}" ) serialization.serialize_obj_to_s3( - (args, kwargs), - self.sagemaker_session, - s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER), - self.s3_kms_key, + obj=(args, kwargs), + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER), + hmac_key=self.hmac_key, + s3_kms_key=self.s3_kms_key, ) def load_and_invoke(self) -> None: @@ -81,14 +88,18 @@ def load_and_invoke(self) -> None: f"Deserializing function code from {s3_path_join(self.s3_base_uri, FUNCTION_FOLDER)}" ) func = serialization.deserialize_func_from_s3( - self.sagemaker_session, s3_path_join(self.s3_base_uri, FUNCTION_FOLDER) + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join(self.s3_base_uri, FUNCTION_FOLDER), + hmac_key=self.hmac_key, ) logger.info( f"Deserializing function arguments from {s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER)}" ) args, kwargs = serialization.deserialize_obj_from_s3( - self.sagemaker_session, s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER) + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join(self.s3_base_uri, ARGUMENTS_FOLDER), + hmac_key=self.hmac_key, ) logger.info("Invoking the function") @@ -98,8 +109,9 @@ def load_and_invoke(self) -> None: f"Serializing the function return and uploading to {s3_path_join(self.s3_base_uri, RESULTS_FOLDER)}" ) serialization.serialize_obj_to_s3( - result, - self.sagemaker_session, - s3_path_join(self.s3_base_uri, 
RESULTS_FOLDER), - self.s3_kms_key, + obj=result, + sagemaker_session=self.sagemaker_session, + s3_uri=s3_path_join(self.s3_base_uri, RESULTS_FOLDER), + hmac_key=self.hmac_key, + s3_kms_key=self.s3_kms_key, ) diff --git a/src/sagemaker/remote_function/errors.py b/src/sagemaker/remote_function/errors.py index b0f1f7031c..9c91f46061 100644 --- a/src/sagemaker/remote_function/errors.py +++ b/src/sagemaker/remote_function/errors.py @@ -70,7 +70,7 @@ def _write_failure_reason_file(failure_msg): f.write(failure_msg) -def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: +def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key, hmac_key) -> int: """Handle all exceptions raised during remote function execution. Args: @@ -79,6 +79,7 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: AWS service calls are delegated to. s3_base_uri (str): S3 root uri to which resulting serialized exception will be uploaded. s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. + hmac_key (str): Key used to calculate hmac hash of the serialized exception. Returns : exit_code (int): Exit code to terminate current job. 
""" @@ -93,7 +94,11 @@ def handle_error(error, sagemaker_session, s3_base_uri, s3_kms_key) -> int: _write_failure_reason_file(failure_reason) serialization.serialize_exception_to_s3( - error, sagemaker_session, s3_path_join(s3_base_uri, "exception"), s3_kms_key + exc=error, + sagemaker_session=sagemaker_session, + s3_uri=s3_path_join(s3_base_uri, "exception"), + hmac_key=hmac_key, + s3_kms_key=s3_kms_key, ) return exit_code diff --git a/src/sagemaker/remote_function/invoke_function.py b/src/sagemaker/remote_function/invoke_function.py index 66c866a1b0..5963d77a42 100644 --- a/src/sagemaker/remote_function/invoke_function.py +++ b/src/sagemaker/remote_function/invoke_function.py @@ -17,6 +17,7 @@ import argparse import sys import json +import os import boto3 from sagemaker.experiments.run import Run @@ -61,11 +62,16 @@ def _load_run_object(run_in_context: str, sagemaker_session: Session) -> Run: ) -def _execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key, run_in_context): +def _execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, hmac_key): """Execute stored remote function""" from sagemaker.remote_function.core.stored_function import StoredFunction - stored_function = StoredFunction(sagemaker_session, s3_base_uri, s3_kms_key) + stored_function = StoredFunction( + sagemaker_session=sagemaker_session, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + hmac_key=hmac_key, + ) if run_in_context: run_obj = _load_run_object(run_in_context, sagemaker_session) @@ -89,12 +95,26 @@ def main(): s3_kms_key = args.s3_kms_key run_in_context = args.run_in_context + hmac_key = os.getenv("REMOTE_FUNCTION_SECRET_KEY") + sagemaker_session = _get_sagemaker_session(region) - _execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key, run_in_context) + _execute_remote_function( + sagemaker_session=sagemaker_session, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + run_in_context=run_in_context, + hmac_key=hmac_key, + ) except 
Exception as e: # pylint: disable=broad-except logger.exception("Error encountered while invoking the remote function.") - exit_code = handle_error(e, sagemaker_session, s3_base_uri, s3_kms_key) + exit_code = handle_error( + error=e, + sagemaker_session=sagemaker_session, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + hmac_key=hmac_key, + ) finally: sys.exit(exit_code) diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index 04ebfada13..a96e6f7146 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -19,6 +19,7 @@ import shutil import sys import json +import secrets from typing import Dict, List, Tuple from sagemaker.config.config_schema import ( @@ -166,6 +167,8 @@ def __init__( {"AWS_DEFAULT_REGION": self.sagemaker_session.boto_region_name} ) + self.environment_variables.update({"REMOTE_FUNCTION_SECRET_KEY": secrets.token_hex(32)}) + _image_uri = resolve_value_from_config( direct_input=image_uri, config_path=REMOTE_FUNCTION_IMAGE_URI, @@ -304,11 +307,12 @@ def _get_default_image(session): class _Job: """Helper class that interacts with the SageMaker training service.""" - def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session): + def __init__(self, job_name: str, s3_uri: str, sagemaker_session: Session, hmac_key: str): """Initialize a _Job object.""" self.job_name = job_name self.s3_uri = s3_uri self.sagemaker_session = sagemaker_session + self.hmac_key = hmac_key self._last_describe_response = None @staticmethod @@ -316,7 +320,9 @@ def from_describe_response(describe_training_job_response, sagemaker_session): """Construct a _Job from a describe_training_job_response object.""" job_name = describe_training_job_response["TrainingJobName"] s3_uri = describe_training_job_response["OutputDataConfig"]["S3OutputPath"] - job = _Job(job_name, s3_uri, sagemaker_session) + hmac_key = describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"] + + job = 
_Job(job_name, s3_uri, sagemaker_session, hmac_key) job._last_describe_response = describe_training_job_response return job @@ -334,6 +340,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non """ job_name = _Job._get_job_name(job_settings, func) s3_base_uri = s3_path_join(job_settings.s3_root_uri, job_name) + hmac_key = job_settings.environment_variables["REMOTE_FUNCTION_SECRET_KEY"] bootstrap_scripts_s3uri = _prepare_and_upload_runtime_scripts( s3_base_uri=s3_base_uri, @@ -355,6 +362,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non stored_function = StoredFunction( sagemaker_session=job_settings.sagemaker_session, s3_base_uri=s3_base_uri, + hmac_key=hmac_key, s3_kms_key=job_settings.s3_kms_key, ) @@ -454,13 +462,12 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non if job_settings.vpc_config: request_dict["VpcConfig"] = job_settings.vpc_config - if job_settings.environment_variables: - request_dict["Environment"] = job_settings.environment_variables + request_dict["Environment"] = job_settings.environment_variables logger.info("Creating job: %s", job_name) job_settings.sagemaker_session.sagemaker_client.create_training_job(**request_dict) - return _Job(job_name, s3_base_uri, job_settings.sagemaker_session) + return _Job(job_name, s3_base_uri, job_settings.sagemaker_session, hmac_key) def describe(self): """Describe the underlying sagemaker training job.""" diff --git a/tests/integ/sagemaker/remote_function/test_decorator.py b/tests/integ/sagemaker/remote_function/test_decorator.py index 541ab1417c..5e7d0a9d91 100644 --- a/tests/integ/sagemaker/remote_function/test_decorator.py +++ b/tests/integ/sagemaker/remote_function/test_decorator.py @@ -600,6 +600,7 @@ def get_file_content(file_names): assert "line 2: bws: command not found" in str(e) +@pytest.mark.skip def test_decorator_auto_capture(sagemaker_session, auto_capture_test_container): """ This test runs a 
docker container. The Container invocation will execute a python script diff --git a/tests/unit/sagemaker/remote_function/core/test_serialization.py b/tests/unit/sagemaker/remote_function/core/test_serialization.py index eb06cf5cc4..28f5b215e8 100644 --- a/tests/unit/sagemaker/remote_function/core/test_serialization.py +++ b/tests/unit/sagemaker/remote_function/core/test_serialization.py @@ -31,6 +31,7 @@ from tblib import pickling_support KMS_KEY = "kms-key" +HMAC_KEY = "some-hmac-key" mock_s3 = {} @@ -64,11 +65,15 @@ def square(x): return x * x s3_uri = random_s3_uri() - serialize_func_to_s3(func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_func_to_s3( + func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) del square - deserialized = deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialized = deserialize_func_from_s3( + sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + ) assert deserialized(3) == 9 @@ -79,10 +84,16 @@ def test_serialize_deserialize_lambda(): s3_uri = random_s3_uri() serialize_func_to_s3( - func=lambda x: x * x, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + func=lambda x: x * x, + sagemaker_session=Mock(), + s3_uri=s3_uri, + s3_kms_key=KMS_KEY, + hmac_key=HMAC_KEY, ) - deserialized = deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialized = deserialize_func_from_s3( + sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + ) assert deserialized(3) == 9 @@ -107,7 +118,11 @@ def train(x): match="or instantiate a new Run in the function.", ): serialize_func_to_s3( - func=train, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + func=train, + sagemaker_session=Mock(), + s3_uri=s3_uri, + s3_kms_key=KMS_KEY, + hmac_key=HMAC_KEY, ) @@ -127,7 +142,11 @@ def square(x): match=r"Error when serializing object of type \[function\]: RuntimeError\('some failure when dumps'\)", ): 
serialize_func_to_s3( - func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + func=square, + sagemaker_session=Mock(), + s3_uri=s3_uri, + s3_kms_key=KMS_KEY, + hmac_key=HMAC_KEY, ) @@ -142,7 +161,9 @@ def square(x): s3_uri = random_s3_uri() - serialize_func_to_s3(func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_func_to_s3( + func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) del square @@ -151,7 +172,7 @@ def square(x): match=rf"Error when deserializing bytes downloaded from {s3_uri}/payload.pkl: " + r"RuntimeError\('some failure when loads'\)", ): - deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @@ -162,13 +183,34 @@ def square(x): s3_uri = random_s3_uri() - serialize_func_to_s3(func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_func_to_s3( + func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) mock_s3[os.path.join(s3_uri, "metadata.json")] = b"not json serializable" del square with pytest.raises(DeserializationError, match=r"Corrupt metadata file."): - deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) + + +@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) +@patch("sagemaker.s3.S3Downloader.read_bytes", new=read) +def test_deserialize_integrity_check_failed(): + def square(x): + return x * x + + s3_uri = random_s3_uri() + serialize_func_to_s3( + func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) + + del square + + with pytest.raises( + DeserializationError, match=r"Integrity check for the serialized function or data failed." 
+ ): + deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key="invalid_key") @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @@ -181,12 +223,16 @@ def __init__(self, x): my_data = MyData(10) s3_uri = random_s3_uri() - serialize_obj_to_s3(my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_obj_to_s3( + my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) del my_data del MyData - deserialized = deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialized = deserialize_obj_from_s3( + sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + ) assert deserialized.x == 10 @@ -198,11 +244,15 @@ def test_serialize_deserialize_data_built_in_types(): my_data = {"a": [10]} s3_uri = random_s3_uri() - serialize_obj_to_s3(my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_obj_to_s3( + my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) del my_data - deserialized = deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialized = deserialize_obj_from_s3( + sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + ) assert deserialized == {"a": [10]} @@ -212,9 +262,13 @@ def test_serialize_deserialize_data_built_in_types(): def test_serialize_deserialize_none(): s3_uri = random_s3_uri() - serialize_obj_to_s3(None, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_obj_to_s3( + None, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) - deserialized = deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialized = deserialize_obj_from_s3( + sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + ) assert deserialized is None @@ -234,7 +288,11 @@ def test_serialize_run(*args, **kwargs): match="or instantiate a new Run in the function.", ): serialize_obj_to_s3( - obj=run, sagemaker_session=Mock(), 
s3_uri=s3_uri, s3_kms_key=KMS_KEY + obj=run, + sagemaker_session=Mock(), + s3_uri=s3_uri, + s3_kms_key=KMS_KEY, + hmac_key=HMAC_KEY, ) @@ -256,7 +314,11 @@ def __init__(self, x): match=r"Error when serializing object of type \[MyData\]: RuntimeError\('some failure when dumps'\)", ): serialize_obj_to_s3( - obj=my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + obj=my_data, + sagemaker_session=Mock(), + s3_uri=s3_uri, + s3_kms_key=KMS_KEY, + hmac_key=HMAC_KEY, ) @@ -273,7 +335,9 @@ def __init__(self, x): my_data = MyData(10) s3_uri = random_s3_uri() - serialize_obj_to_s3(obj=my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_obj_to_s3( + obj=my_data, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) del my_data del MyData @@ -283,7 +347,7 @@ def __init__(self, x): match=rf"Error when deserializing bytes downloaded from {s3_uri}/payload.pkl: " + r"RuntimeError\('some failure when loads'\)", ): - deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload_error) @@ -295,11 +359,15 @@ def test_serialize_deserialize_service_error(): s3_uri = random_s3_uri() with pytest.raises( ServiceError, - match=rf"Failed to upload serialized bytes to {s3_uri}/metadata.json: " + match=rf"Failed to upload serialized bytes to {s3_uri}/payload.pkl: " + r"RuntimeError\('some failure when upload_bytes'\)", ): serialize_func_to_s3( - func=my_func, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY + func=my_func, + sagemaker_session=Mock(), + s3_uri=s3_uri, + s3_kms_key=KMS_KEY, + hmac_key=HMAC_KEY, ) del my_func @@ -309,7 +377,7 @@ def test_serialize_deserialize_service_error(): match=rf"Failed to read serialized bytes from {s3_uri}/metadata.json: " + r"RuntimeError\('some failure when read_bytes'\)", ): - 
deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @@ -333,10 +401,12 @@ def func_b(): func_b() except Exception as e: pickling_support.install() - serialize_obj_to_s3(e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_obj_to_s3( + e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) with pytest.raises(CustomError, match="Some error") as exc_info: - raise deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + raise deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) assert type(exc_info.value.__cause__) is TypeError @@ -360,10 +430,14 @@ def func_b(): try: func_b() except Exception as e: - serialize_exception_to_s3(e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_exception_to_s3( + e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) with pytest.raises(CustomError, match="Some error") as exc_info: - raise deserialize_exception_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + raise deserialize_exception_from_s3( + sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + ) assert type(exc_info.value.__cause__) is TypeError @@ -387,8 +461,12 @@ def func_b(): try: func_b() except Exception as e: - serialize_exception_to_s3(e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY) + serialize_exception_to_s3( + e, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY + ) with pytest.raises(ServiceError, match="Some error") as exc_info: - raise deserialize_exception_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri) + raise deserialize_exception_from_s3( + sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY + ) assert type(exc_info.value.__cause__) is TypeError diff --git 
a/tests/unit/sagemaker/remote_function/core/test_stored_function.py b/tests/unit/sagemaker/remote_function/core/test_stored_function.py index 0b4008ef41..9833994c98 100644 --- a/tests/unit/sagemaker/remote_function/core/test_stored_function.py +++ b/tests/unit/sagemaker/remote_function/core/test_stored_function.py @@ -34,6 +34,7 @@ ) KMS_KEY = "kms-key" +HMAC_KEY = "some-hmac-key" mock_s3 = {} @@ -75,14 +76,14 @@ def test_save_and_load(s3_source_dir_download, s3_source_dir_upload, args, kwarg s3_base_uri = random_s3_uri() stored_function = StoredFunction( - sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY + sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY ) stored_function.save(quadratic, *args, **kwargs) stored_function.load_and_invoke() - assert deserialize_obj_from_s3(session, s3_uri=f"{s3_base_uri}/results") == quadratic( - *args, **kwargs - ) + assert deserialize_obj_from_s3( + session, s3_uri=f"{s3_base_uri}/results", hmac_key=HMAC_KEY + ) == quadratic(*args, **kwargs) @patch( @@ -117,7 +118,7 @@ def test_save_with_parameter_of_run_type( sagemaker_session=session, ) stored_function = StoredFunction( - sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY + sagemaker_session=session, s3_base_uri=s3_base_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY ) with pytest.raises(SerializationError) as e: stored_function.save(log_bigger, 1, 2, run) diff --git a/tests/unit/sagemaker/remote_function/test_client.py b/tests/unit/sagemaker/remote_function/test_client.py index fb6e9caf94..9c95fd96b5 100644 --- a/tests/unit/sagemaker/remote_function/test_client.py +++ b/tests/unit/sagemaker/remote_function/test_client.py @@ -48,6 +48,7 @@ S3_URI = f"s3://{BUCKET}/keyprefix" EXPECTED_JOB_RESULT = [1, 2, 3] PATH_TO_SRC_DIR = "path/to/src/dir" +HMAC_KEY = "some-hmac-key" def describe_training_job_response(job_status): @@ -61,6 +62,7 @@ def describe_training_job_response(job_status): "VolumeSizeInGB": 30, 
}, "OutputDataConfig": {"S3OutputPath": "s3://sagemaker-123/image_uri/output"}, + "Environment": {"REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, } @@ -887,7 +889,7 @@ def test_future_get_result_from_completed_job(mock_start, mock_deserialize): def test_future_get_result_from_failed_job_remote_error_client_function( mock_start, mock_deserialize ): - mock_job = Mock(job_name=TRAINING_JOB_NAME, s3_uri=S3_URI) + mock_job = Mock(job_name=TRAINING_JOB_NAME, s3_uri=S3_URI, hmac_key=HMAC_KEY) mock_start.return_value = mock_job mock_job.describe.return_value = FAILED_TRAINING_JOB @@ -902,7 +904,9 @@ def test_future_get_result_from_failed_job_remote_error_client_function( assert future.done() mock_job.wait.assert_called_once() - mock_deserialize.assert_called_with(sagemaker_session=ANY, s3_uri=f"{S3_URI}/exception") + mock_deserialize.assert_called_with( + sagemaker_session=ANY, s3_uri=f"{S3_URI}/exception", hmac_key=HMAC_KEY + ) @patch("sagemaker.s3.S3Downloader.read_bytes") @@ -1230,7 +1234,9 @@ def test_get_future_completed_job_deserialization_error(mock_session, mock_deser future.result() mock_deserialize.assert_called_with( - sagemaker_session=ANY, s3_uri="s3://sagemaker-123/image_uri/output/results" + sagemaker_session=ANY, + s3_uri="s3://sagemaker-123/image_uri/output/results", + hmac_key=HMAC_KEY, ) diff --git a/tests/unit/sagemaker/remote_function/test_errors.py b/tests/unit/sagemaker/remote_function/test_errors.py index 78b864e784..399b1aed2e 100644 --- a/tests/unit/sagemaker/remote_function/test_errors.py +++ b/tests/unit/sagemaker/remote_function/test_errors.py @@ -20,6 +20,7 @@ TEST_S3_BASE_URI = "s3://my-bucket/" TEST_S3_KMS_KEY = "my-kms-key" +TEST_HMAC_KEY = "some-hmac-key" class _InvalidErrorNumberException(Exception): @@ -70,12 +71,22 @@ def test_handle_error( error_string, ): err = error - exit_code = handle_error(err, sagemaker_session, TEST_S3_BASE_URI, TEST_S3_KMS_KEY) + exit_code = handle_error( + error=err, + sagemaker_session=sagemaker_session, + 
s3_base_uri=TEST_S3_BASE_URI, + s3_kms_key=TEST_S3_KMS_KEY, + hmac_key=TEST_HMAC_KEY, + ) assert exit_code == expected_exit_code exists.assert_called_once_with("/opt/ml/output/failure") mock_open_file.assert_called_with("/opt/ml/output/failure", "w") mock_open_file.return_value.__enter__().write.assert_called_with(error_string) serialize_exception_to_s3.assert_called_with( - err, sagemaker_session, TEST_S3_BASE_URI + "exception", TEST_S3_KMS_KEY + exc=err, + sagemaker_session=sagemaker_session, + s3_uri=TEST_S3_BASE_URI + "exception", + hmac_key=TEST_HMAC_KEY, + s3_kms_key=TEST_S3_KMS_KEY, ) diff --git a/tests/unit/sagemaker/remote_function/test_invoke_function.py b/tests/unit/sagemaker/remote_function/test_invoke_function.py index 661e2138e3..a8f658234f 100644 --- a/tests/unit/sagemaker/remote_function/test_invoke_function.py +++ b/tests/unit/sagemaker/remote_function/test_invoke_function.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import os from mock import patch, Mock from sagemaker.remote_function import invoke_function from sagemaker.remote_function.errors import SerializationError @@ -20,6 +21,7 @@ TEST_S3_BASE_URI = "s3://my-bucket/" TEST_S3_KMS_KEY = "my-kms-key" TEST_RUN_IN_CONTEXT = '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}' +TEST_HMAC_KEY = "some-hmac-key" def mock_args(): @@ -55,6 +57,7 @@ def mock_session(): return_value=mock_session(), ) def test_main_success(_get_sagemaker_session, load_and_invoke, _exit_process, _load_run_object): + os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY invoke_function.main() _get_sagemaker_session.assert_called_with(TEST_REGION) @@ -74,6 +77,7 @@ def test_main_success(_get_sagemaker_session, load_and_invoke, _exit_process, _l def test_main_success_with_run( _get_sagemaker_session, load_and_invoke, _exit_process, _load_run_object ): + os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY 
invoke_function.main() _get_sagemaker_session.assert_called_with(TEST_REGION) @@ -94,6 +98,7 @@ def test_main_success_with_run( def test_main_failure( _get_sagemaker_session, load_and_invoke, _exit_process, handle_error, _load_run_object ): + os.environ["REMOTE_FUNCTION_SECRET_KEY"] = TEST_HMAC_KEY ser_err = SerializationError("some failure reason") load_and_invoke.side_effect = ser_err handle_error.return_value = 1 @@ -104,6 +109,10 @@ def test_main_failure( load_and_invoke.assert_called() _load_run_object.assert_not_called() handle_error.assert_called_with( - ser_err, _get_sagemaker_session(), TEST_S3_BASE_URI, TEST_S3_KMS_KEY + error=ser_err, + sagemaker_session=_get_sagemaker_session(), + s3_base_uri=TEST_S3_BASE_URI, + s3_kms_key=TEST_S3_KMS_KEY, + hmac_key=TEST_HMAC_KEY, ) _exit_process.assert_called_with(1) diff --git a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py index fb019875ad..686862bcc7 100644 --- a/tests/unit/sagemaker/remote_function/test_job.py +++ b/tests/unit/sagemaker/remote_function/test_job.py @@ -41,6 +41,7 @@ TEST_REGION = "us-west-2" RUNTIME_SCRIPTS_CHANNEL_NAME = "sagemaker_remote_function_bootstrap" REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws" +HMAC_KEY = "some-hmac-key" EXPECTED_FUNCTION_URI = S3_URI + "/function.pkl" EXPECTED_OUTPUT_URI = S3_URI + "/output" @@ -111,22 +112,29 @@ def job_function(a, b=1, *, c, d=3): return a * b * c * d +@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job.Session", return_value=mock_session()) @patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN) -def test_sagemaker_config_job_settings(get_execution_role, session): +def test_sagemaker_config_job_settings(get_execution_role, session, secret_token): job_settings = _JobSettings(image_uri="image_uri", instance_type="ml.m5.xlarge") assert job_settings.image_uri == "image_uri" assert job_settings.s3_root_uri == f"s3://{BUCKET}" assert 
job_settings.role == DEFAULT_ROLE_ARN - assert job_settings.environment_variables == {"AWS_DEFAULT_REGION": "us-west-2"} + assert job_settings.environment_variables == { + "AWS_DEFAULT_REGION": "us-west-2", + "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY, + } assert job_settings.include_local_workdir is False assert job_settings.instance_type == "ml.m5.xlarge" +@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job.Session", return_value=mock_session()) @patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN) -def test_sagemaker_config_job_settings_with_configuration_file(get_execution_role, session): +def test_sagemaker_config_job_settings_with_configuration_file( + get_execution_role, session, secret_token +): config_tags = [ {"Key": "someTagKey", "Value": "someTagValue"}, {"Key": "someTagKey2", "Value": "someTagValue2"}, @@ -146,6 +154,7 @@ def test_sagemaker_config_job_settings_with_configuration_file(get_execution_rol assert job_settings.environment_variables == { "AWS_DEFAULT_REGION": "us-west-2", "EnvVarKey": "EnvVarValue", + "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY, } assert job_settings.job_conda_env == "my_conda_env" assert job_settings.include_local_workdir is True @@ -227,6 +236,7 @@ def test_sagemaker_config_job_settings_studio_image_uri(get_execution_role, sess @patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job._prepare_and_upload_dependencies", return_value="some_s3_uri") @patch( "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" @@ -235,7 +245,12 @@ def test_sagemaker_config_job_settings_studio_image_uri(get_execution_role, sess @patch("sagemaker.remote_function.job.StoredFunction") @patch("sagemaker.remote_function.job.Session", return_value=mock_session()) def test_start( - session, mock_stored_function, 
mock_runtime_manager, mock_script_upload, mock_dependency_upload + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, ): job_settings = _JobSettings( @@ -252,7 +267,10 @@ def test_start( assert job.job_name.startswith("job-function") assert mock_stored_function.called_once_with( - sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", s3_kms_key=None + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, ) local_dependencies_path = mock_runtime_manager().snapshot() @@ -326,10 +344,11 @@ def test_start( ), EnableNetworkIsolation=False, EnableInterContainerTrafficEncryption=True, - Environment={"AWS_DEFAULT_REGION": "us-west-2"}, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, ) +@patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job._prepare_and_upload_dependencies", return_value="some_s3_uri") @patch( "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" @@ -338,7 +357,12 @@ def test_start( @patch("sagemaker.remote_function.job.StoredFunction") @patch("sagemaker.remote_function.job.Session", return_value=mock_session()) def test_start_with_complete_job_settings( - session, mock_stored_function, mock_runtime_manager, mock_script_upload, mock_dependency_upload + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, ): job_settings = _JobSettings( @@ -363,7 +387,10 @@ def test_start_with_complete_job_settings( assert job.job_name.startswith("job-function") assert mock_stored_function.called_once_with( - sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", s3_kms_key=None + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, ) local_dependencies_path = 
mock_runtime_manager().snapshot() @@ -441,7 +468,7 @@ def test_start_with_complete_job_settings( EnableNetworkIsolation=False, EnableInterContainerTrafficEncryption=False, VpcConfig=dict(Subnets=["subnet"], SecurityGroupIds=["sg"]), - Environment={"AWS_DEFAULT_REGION": "us-west-2"}, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, )