1313"""SageMaker remote function data serializer/deserializer."""
1414from __future__ import absolute_import
1515
16+ import dataclasses
17+ import json
18+ import os
19+ import sys
20+
1621import cloudpickle
1722
1823from typing import Any , Callable
1924from sagemaker .remote_function .errors import ServiceError , SerializationError , DeserializationError
2025from sagemaker .s3 import S3Downloader , S3Uploader
26+ from tblib import pickling_support
+
+
+def _get_python_version():
+    return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
+
+
+@dataclasses.dataclass
+class _MetaData:
+    """Metadata about the serialized data or functions."""
 
+    version: str = "2023-04-24"
+    python_version: str = _get_python_version()
+    serialization_module: str = "cloudpickle"
 
-# TODO: 1) use dask serializer instead of cloudpickle for data serialization.
-#       2) set the pickle protocol properly
-#       3) serialization/deserialization scheme needs to be explicitly versioned
-#       4) handle exceptions
+    def to_json(self):
+        return json.dumps(dataclasses.asdict(self)).encode()
+
+    @staticmethod
+    def from_json(s):
+        try:
+            obj = json.loads(s)
+        except json.decoder.JSONDecodeError as e:
+            raise DeserializationError("Corrupt metadata file. It is not a valid json file.") from e
+
+        metadata = _MetaData()
+        metadata.version = obj.get("version")
+        metadata.python_version = obj.get("python_version")
+        metadata.serialization_module = obj.get("serialization_module")
+
+        if not (
+            metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
+        ):
+            raise DeserializationError(
+                f"Corrupt metadata file. Serialization approach {s} is not supported."
+            )
+
+        return metadata
+
+
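# A minimal sketch of the metadata round trip; the python_version value shown is
# illustrative and depends on the running interpreter:
#
#     raw = _MetaData().to_json()
#     # e.g. b'{"version": "2023-04-24", "python_version": "3.10.11", "serialization_module": "cloudpickle"}'
#     meta = _MetaData.from_json(raw)
#     assert meta.serialization_module == "cloudpickle"
#     _MetaData.from_json(b"{}")  # raises DeserializationError: unsupported version/module
#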
+class CloudpickleSerializer:
+    """Serializer using cloudpickle."""
+
+    @staticmethod
+    def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
+        """Serializes data object and uploads it to S3.
+
+        Args:
+            sagemaker_session (sagemaker.session.Session): The underlying Boto3 session
+                to which AWS service calls are delegated.
+            s3_uri (str): S3 root uri to which the serialized artifacts will be uploaded.
+            s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+            obj: The object to be serialized and persisted.
+        Raises:
+            SerializationError: when failing to serialize the object to bytes.
+        """
+        try:
+            bytes_to_upload = cloudpickle.dumps(obj)
+        except Exception as e:
+            if isinstance(
+                e, NotImplementedError
+            ) and "Instance of Run type is not allowed to be pickled." in str(e):
+                raise SerializationError(
+                    """You are trying to reference a sagemaker.experiments.run.Run instance from within
+                    the function or to pass it as a function argument.
+                    Instantiate a Run inside the function or use load_run instead."""
+                ) from e
+
+            raise SerializationError(
+                "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
+            ) from e
+
+        _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
+
+    @staticmethod
+    def deserialize(sagemaker_session, s3_uri) -> Any:
+        """Downloads from S3 and then deserializes data objects.
+
+        Args:
+            sagemaker_session (sagemaker.session.Session): The underlying sagemaker session
+                to which AWS service calls are delegated.
+            s3_uri (str): S3 root uri from which the serialized artifacts will be downloaded.
+        Returns:
+            The deserialized python object.
+        Raises:
+            DeserializationError: when failing to deserialize the downloaded bytes.
+        """
+        bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
+
+        try:
+            return cloudpickle.loads(bytes_to_deserialize)
+        except Exception as e:
+            raise DeserializationError(
+                "Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
+            ) from e
+
+
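# A usage sketch with placeholder names ("session" standing in for a
# sagemaker.session.Session, and an illustrative bucket URI). serialize() writes
# the pickled bytes to exactly the given key; pairing the payload with a
# metadata file is left to the module-level helpers below:
#
#     CloudpickleSerializer.serialize({"lr": 0.1}, session, "s3://my-bucket/job/data.pkl")
#     restored = CloudpickleSerializer.deserialize(session, "s3://my-bucket/job/data.pkl")
#     assert restored == {"lr": 0.1}
#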
+# TODO: use dask serializer in case dask distributed is installed in users' environment.
 def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
     """Serializes function and uploads it to S3.
 
@@ -36,16 +133,13 @@ def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
     Raises:
         SerializationError: when failing to serialize the function to bytes.
     """
-    try:
-        bytes_to_upload = cloudpickle.dumps(func)
-    except Exception as e:
-        raise SerializationError(
-            "Error when serializing function [{}]: {}".format(
-                getattr(func, "__name__", repr(func)), repr(e)
-            )
-        ) from e
 
-    _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )
 
 
 def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
@@ -63,16 +157,11 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
     Raises:
         DeserializationError: when failing to deserialize the function from bytes.
     """
-    bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )
 
-    try:
-        return cloudpickle.loads(bytes_to_deserialize)
-    except Exception as e:
-        raise DeserializationError(
-            "Error when deserializing bytes downloaded from {} to function: {}".format(
-                s3_uri, repr(e)
-            )
-        ) from e
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
 
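# A round-trip sketch with placeholder names. Both helpers now share one S3
# layout, <s3_uri>/metadata.json plus <s3_uri>/payload.pkl, and the metadata is
# validated before the payload is unpickled:
#
#     def double(x):
#         return x * 2
#
#     serialize_func_to_s3(double, session, "s3://my-bucket/func")
#     assert deserialize_func_from_s3(session, "s3://my-bucket/func")(21) == 42
#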
 
 def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
@@ -87,21 +176,13 @@ def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
     Raises:
         SerializationError: when failing to serialize the object to bytes.
     """
-    try:
-        bytes_to_upload = cloudpickle.dumps(obj)
-    except Exception as e:
-        if isinstance(
-            e, NotImplementedError
-        ) and "Instance of Run type is not allowed to be pickled." in str(e):
-            raise SerializationError(
-                "Remote function does not allow parameters of Run type."
-            ) from e
-
-        raise SerializationError(
-            "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
-        ) from e
 
-    _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )
 
 
 def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
@@ -112,18 +193,59 @@ def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
             calls are delegated to.
         s3_uri (str): S3 root uri from which the serialized artifacts will be downloaded.
     Returns:
-        List of deserialized python objects.
+        The deserialized python object.
     Raises:
         DeserializationError: when failing to deserialize the object from bytes.
     """
-    bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
 
-    try:
-        return cloudpickle.loads(bytes_to_deserialize)
-    except Exception as e:
-        raise DeserializationError(
-            "Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
-        ) from e
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+
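# The object helpers mirror the function helpers above; a sketch with
# placeholder names:
#
#     serialize_obj_to_s3({"epochs": 3}, session, "s3://my-bucket/args")
#     assert deserialize_obj_from_s3(session, "s3://my-bucket/args") == {"epochs": 3}
#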
+
+def serialize_exception_to_s3(
+    exc: Exception, sagemaker_session, s3_uri: str, s3_kms_key: str = None
+):
+    """Serializes exception with traceback and uploads it to S3.
+
+    Args:
+        sagemaker_session (sagemaker.session.Session): The underlying Boto3 session
+            to which AWS service calls are delegated.
+        s3_uri (str): S3 root uri to which the serialized artifacts will be uploaded.
+        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+        exc: The exception to be serialized and persisted.
+    Raises:
+        SerializationError: when failing to serialize the exception to bytes.
+    """
+    pickling_support.install()
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )
+
+
+def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
+    """Downloads from S3 and then deserializes exception.
+
+    Args:
+        sagemaker_session (sagemaker.session.Session): The underlying sagemaker session
+            to which AWS service calls are delegated.
+        s3_uri (str): S3 root uri from which the serialized artifacts will be downloaded.
+    Returns:
+        The deserialized exception with its traceback.
+    Raises:
+        DeserializationError: when failing to deserialize the exception from bytes.
+    """
+
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
 
 
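# A sketch of the exception round trip (placeholder names). tblib's
# pickling_support lets cloudpickle pickle the traceback, so the caller can
# re-raise the remote error with its original stack attached:
#
#     try:
#         1 / 0
#     except ZeroDivisionError as e:
#         serialize_exception_to_s3(e, session, "s3://my-bucket/error")
#
#     raise deserialize_exception_from_s3(session, "s3://my-bucket/error")
#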
 def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):