6 changes: 6 additions & 0 deletions src/sagemaker/remote_function/client.py
@@ -301,6 +301,7 @@ def wrapper(*args, **kwargs):
s3_uri=s3_path_join(
job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER
),
hmac_key=job.hmac_key,
)
except ServiceError as serr:
chained_e = serr.__cause__
@@ -337,6 +338,7 @@ def wrapper(*args, **kwargs):
return serialization.deserialize_obj_from_s3(
sagemaker_session=job_settings.sagemaker_session,
s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER),
hmac_key=job.hmac_key,
)

if job.describe()["TrainingJobStatus"] == "Stopped":
@@ -861,6 +863,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
job_return = serialization.deserialize_obj_from_s3(
sagemaker_session=sagemaker_session,
s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER),
hmac_key=job.hmac_key,
)
except DeserializationError as e:
client_exception = e
@@ -872,6 +875,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
job_exception = serialization.deserialize_exception_from_s3(
sagemaker_session=sagemaker_session,
s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER),
hmac_key=job.hmac_key,
)
except ServiceError as serr:
chained_e = serr.__cause__
@@ -961,6 +965,7 @@ def result(self, timeout: float = None) -> Any:
self._return = serialization.deserialize_obj_from_s3(
sagemaker_session=self._job.sagemaker_session,
s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER),
hmac_key=self._job.hmac_key,
)
self._state = _FINISHED
return self._return
@@ -969,6 +974,7 @@ def result(self, timeout: float = None) -> Any:
self._exception = serialization.deserialize_exception_from_s3(
sagemaker_session=self._job.sagemaker_session,
s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER),
hmac_key=self._job.hmac_key,
)
except ServiceError as serr:
chained_e = serr.__cause__
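Taken together, the client-side changes mean every call that reads results or exceptions back from S3 must now pass the job's HMAC key. A minimal sketch of the resulting call pattern, assuming valid AWS credentials and using placeholder values for the session, S3 prefix, and key:

from sagemaker.session import Session
from sagemaker.remote_function.core import serialization

session = Session()  # assumes default AWS credentials and region
result = serialization.deserialize_obj_from_s3(
    sagemaker_session=session,
    s3_uri="s3://my-bucket/my-remote-job/results",  # placeholder results prefix
    hmac_key="per-job-hmac-key",                    # in client.py this comes from job.hmac_key
)
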
152 changes: 120 additions & 32 deletions src/sagemaker/remote_function/core/serialization.py
@@ -17,12 +17,16 @@
import json
import os
import sys
import hmac
import hashlib

import cloudpickle

from typing import Any, Callable
from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError
from sagemaker.s3 import S3Downloader, S3Uploader
from sagemaker.session import Session

from tblib import pickling_support


@@ -34,6 +38,7 @@ def _get_python_version():
class _MetaData:
"""Metadata about the serialized data or functions."""

sha256_hash: str
version: str = "2023-04-24"
python_version: str = _get_python_version()
serialization_module: str = "cloudpickle"
@@ -48,11 +53,17 @@ def from_json(s):
except json.decoder.JSONDecodeError:
raise DeserializationError("Corrupt metadata file. It is not a valid json file.")

metadata = _MetaData()
sha256_hash = obj.get("sha256_hash")
metadata = _MetaData(sha256_hash=sha256_hash)
metadata.version = obj.get("version")
metadata.python_version = obj.get("python_version")
metadata.serialization_module = obj.get("serialization_module")

if not sha256_hash:
raise DeserializationError(
"Corrupt metadata file. SHA256 hash for the serialized data does not exist"
)

if not (
metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
):
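With this change, metadata.json must carry a sha256_hash entry alongside the existing version, python_version, and serialization_module fields, and payloads written before this change are rejected. A small self-contained sketch of that contract (not the SDK code; the field values below are illustrative):

import json

def check_metadata(raw: bytes) -> dict:
    obj = json.loads(raw)
    if not obj.get("sha256_hash"):
        # mirrors the DeserializationError raised by _MetaData.from_json above
        raise ValueError("Corrupt metadata file. SHA256 hash for the serialized data does not exist")
    if obj.get("version") != "2023-04-24" or obj.get("serialization_module") != "cloudpickle":
        raise ValueError("Unsupported metadata version or serialization module")
    return obj

current = json.dumps({
    "sha256_hash": "a" * 64,  # placeholder digest
    "version": "2023-04-24",
    "python_version": "3.10",
    "serialization_module": "cloudpickle",
}).encode()
print(check_metadata(current))  # accepted

legacy = json.dumps({"version": "2023-04-24", "serialization_module": "cloudpickle"}).encode()
try:
    check_metadata(legacy)
except ValueError as err:
    print("rejected:", err)  # no sha256_hash, as with payloads written before this change
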
@@ -67,20 +78,16 @@ class CloudpickleSerializer:
"""Serializer using cloudpickle."""

@staticmethod
def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
def serialize(obj: Any) -> Any:
"""Serializes data object and uploads it to S3.

Args:
sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
obj: object to be serialized and persisted
Raises:
SerializationError: when fail to serialize object to bytes.
"""
try:
bytes_to_upload = cloudpickle.dumps(obj)
return cloudpickle.dumps(obj)
except Exception as e:
if isinstance(
e, NotImplementedError
@@ -96,10 +103,8 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
"Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
) from e

_upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)

@staticmethod
def deserialize(sagemaker_session, s3_uri) -> Any:
def deserialize(s3_uri: str, bytes_to_deserialize) -> Any:
"""Downloads from S3 and then deserializes data objects.

Args:
@@ -111,7 +116,6 @@ def deserialize(sagemaker_session, s3_uri) -> Any:
Raises:
DeserializationError: when fail to deserialize bytes to object.
"""
bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)

try:
return cloudpickle.loads(bytes_to_deserialize)
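After this refactor, CloudpickleSerializer only converts between objects and bytes; downloading, uploading, and hashing move to the module-level helpers. A rough sketch of that split, assuming only that cloudpickle is installed:

import hashlib
import hmac

import cloudpickle

payload = cloudpickle.dumps({"learning_rate": 0.1})  # what serialize() now returns
digest = hmac.new(b"job-hmac-key", payload, hashlib.sha256).hexdigest()  # computed by the caller before upload
restored = cloudpickle.loads(payload)  # what deserialize() does with the downloaded bytes
print(digest, restored)
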
@@ -122,28 +126,39 @@


# TODO: use dask serializer in case dask distributed is installed in users' environment.
def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
def serialize_func_to_s3(
func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
):
"""Serializes function and uploads it to S3.

Args:
sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
func: function to be serialized and persisted
Raises:
SerializationError: when fail to serialize function to bytes.
"""

bytes_to_upload = CloudpickleSerializer.serialize(func)

_upload_bytes_to_s3(
_MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
)
CloudpickleSerializer.serialize(
func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key

sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)

_upload_bytes_to_s3(
_MetaData(sha256_hash).to_json(),
os.path.join(s3_uri, "metadata.json"),
s3_kms_key,
sagemaker_session,
)


def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable:
"""Downloads from S3 and then deserializes data objects.

This method downloads the serialized training job outputs to a temporary directory and
@@ -153,61 +168,94 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
Returns :
The deserialized function.
Raises:
DeserializationError: when fail to deserialize bytes to function.
"""
_MetaData.from_json(
metadata = _MetaData.from_json(
_read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
)

return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
bytes_to_deserialize = _read_bytes_from_s3(
os.path.join(s3_uri, "payload.pkl"), sagemaker_session
)

_perform_integrity_check(
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
)

return CloudpickleSerializer.deserialize(
os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
)


def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
def serialize_obj_to_s3(
obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
):
"""Serializes data object and uploads it to S3.

Args:
sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
obj: object to be serialized and persisted
Raises:
SerializationError: when fail to serialize object to bytes.
"""

bytes_to_upload = CloudpickleSerializer.serialize(obj)

_upload_bytes_to_s3(
_MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
)
CloudpickleSerializer.serialize(
obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key

sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)

_upload_bytes_to_s3(
_MetaData(sha256_hash).to_json(),
os.path.join(s3_uri, "metadata.json"),
s3_kms_key,
sagemaker_session,
)


def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
"""Downloads from S3 and then deserializes data objects.

Args:
sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
Returns :
Deserialized python objects.
Raises:
DeserializationError: when fail to deserialize bytes to object.
"""

_MetaData.from_json(
metadata = _MetaData.from_json(
_read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
)

return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
bytes_to_deserialize = _read_bytes_from_s3(
os.path.join(s3_uri, "payload.pkl"), sagemaker_session
)

_perform_integrity_check(
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
)

return CloudpickleSerializer.deserialize(
os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
)
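The serialize/deserialize pairs keep their S3 layout (payload.pkl plus metadata.json) but now require the same hmac_key on both sides. An illustrative round trip with placeholder session, S3 prefix, and key (requires AWS credentials and an accessible bucket):

from sagemaker.session import Session
from sagemaker.remote_function.core import serialization

session = Session()
s3_root = "s3://my-bucket/remote-function-demo"  # placeholder S3 root uri
key = "per-job-hmac-key"                         # placeholder; generated per job in practice

serialization.serialize_obj_to_s3(
    obj={"epochs": 3}, sagemaker_session=session, s3_uri=s3_root, hmac_key=key
)
roundtripped = serialization.deserialize_obj_from_s3(
    sagemaker_session=session, s3_uri=s3_root, hmac_key=key
)
print(roundtripped)  # {'epochs': 3}; a mismatched key would raise DeserializationError
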


def serialize_exception_to_s3(
exc: Exception, sagemaker_session, s3_uri: str, s3_kms_key: str = None
exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
):
"""Serializes exception with traceback and uploads it to S3.

@@ -216,37 +264,58 @@ def serialize_exception_to_s3(
calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
exc: Exception to be serialized and persisted
Raises:
SerializationError: when fail to serialize object to bytes.
"""
pickling_support.install()

bytes_to_upload = CloudpickleSerializer.serialize(exc)

_upload_bytes_to_s3(
_MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
)
CloudpickleSerializer.serialize(
exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key

sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)

_upload_bytes_to_s3(
_MetaData(sha256_hash).to_json(),
os.path.join(s3_uri, "metadata.json"),
s3_kms_key,
sagemaker_session,
)


def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
"""Downloads from S3 and then deserializes exception.

Args:
sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which AWS service
calls are delegated to.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
Returns :
Deserialized exception with traceback.
Raises:
DeserializationError: when fail to deserialize bytes to exception.
"""

_MetaData.from_json(
metadata = _MetaData.from_json(
_read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
)

return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
bytes_to_deserialize = _read_bytes_from_s3(
os.path.join(s3_uri, "payload.pkl"), sagemaker_session
)

_perform_integrity_check(
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
)

return CloudpickleSerializer.deserialize(
os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
)


def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):
@@ -269,3 +338,22 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session):
raise ServiceError(
"Failed to read serialized bytes from {}: {}".format(s3_uri, repr(e))
) from e


def _compute_hash(buffer: bytes, secret_key: str) -> str:
"""Compute the hmac-sha256 hash"""
return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()


def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes):
"""Performs integrify checks for serialized code/arguments uploaded to s3.

Verifies whether the hash read from s3 matches the hash calculated
during remote function execution.
"""
actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key)
if not hmac.compare_digest(expected_hash_value, actual_hash_value):
raise DeserializationError(
"Integrity check for the serialized function or data failed. "
"Please restrict access to your S3 bucket"
)
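
The two helpers above are the core of the integrity check: the digest written into metadata.json at serialize time must match a digest recomputed over the downloaded payload bytes. A standalone sketch of that behaviour, re-implemented here for illustration with placeholder key and payload values:

import hashlib
import hmac

def compute_hash(buffer: bytes, secret_key: str) -> str:
    # same construction as _compute_hash above: HMAC-SHA256 over the payload bytes
    return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()

key = "per-job-hmac-key"          # placeholder key
payload = b"cloudpickled-bytes"   # placeholder payload
expected = compute_hash(payload, key)  # what serialize_*_to_s3 stores in metadata.json

print(hmac.compare_digest(expected, compute_hash(payload, key)))      # True: payload untouched
print(hmac.compare_digest(expected, compute_hash(b"tampered", key)))  # False: payload modified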