 import json
 import os
 import sys
-import pickle
-from enum import Enum

 import cloudpickle

 from typing import Any, Callable
-
-from sagemaker.s3 import s3_path_join
-
 from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError
 from sagemaker.s3 import S3Downloader, S3Uploader
 from tblib import pickling_support

-METADATA_FILE = "metadata.json"
-PAYLOAD_FILE = "payload.pkl"
-HEADER_FILE = "headers.pkl"
-FRAME_FILE = "frame-{}.dat"
-

 def _get_python_version():
     return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"


-class SerializationModule(str, Enum):
-    """Represents various serialization modules used."""
-
-    CLOUDPICKLE = "cloudpickle"
-    DASK = "dask"
-
-
 @dataclasses.dataclass
 class _MetaData:
     """Metadata about the serialized data or functions."""

-    serialization_module: SerializationModule
     version: str = "2023-04-24"
     python_version: str = _get_python_version()
+    serialization_module: str = "cloudpickle"

     def to_json(self):
         return json.dumps(dataclasses.asdict(self)).encode()
@@ -62,13 +45,16 @@ def to_json(self):
     def from_json(s):
         try:
             obj = json.loads(s)
-            metadata = _MetaData(**obj)
-        except (json.decoder.JSONDecodeError, TypeError):
+        except json.decoder.JSONDecodeError:
             raise DeserializationError("Corrupt metadata file. It is not a valid json file.")

-        if (
-            metadata.version != "2023-04-24"
-            or metadata.serialization_module not in SerializationModule.__members__.values()
+        metadata = _MetaData()
+        metadata.version = obj.get("version")
+        metadata.python_version = obj.get("python_version")
+        metadata.serialization_module = obj.get("serialization_module")
+
+        if not (
+            metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
         ):
             raise DeserializationError(
                 f"Corrupt metadata file. Serialization approach {s} is not supported."
@@ -93,12 +79,6 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
         Raises:
             SerializationError: when fail to serialize object to bytes.
         """
-        _upload_bytes_to_s3(
-            _MetaData(SerializationModule.CLOUDPICKLE).to_json(),
-            os.path.join(s3_uri, METADATA_FILE),
-            s3_kms_key,
-            sagemaker_session,
-        )
         try:
             bytes_to_upload = cloudpickle.dumps(obj)
         except Exception as e:
@@ -116,76 +96,7 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
                 "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
             ) from e

-        _upload_bytes_to_s3(
-            bytes_to_upload, os.path.join(s3_uri, PAYLOAD_FILE), s3_kms_key, sagemaker_session
-        )
-
-    @staticmethod
-    def deserialize(sagemaker_session, s3_uri) -> Any:
-        """Downloads from S3 and then deserializes data objects.
-
-        Args:
-            sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which
-                AWS service calls are delegated to.
-            s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
-        Returns:
-            List of deserialized python objects.
-        Raises:
-            DeserializationError: when fail to serialize object to bytes.
-        """
-        bytes_to_deserialize = _read_bytes_from_s3(
-            os.path.join(s3_uri, PAYLOAD_FILE), sagemaker_session
-        )
-
-        try:
-            return cloudpickle.loads(bytes_to_deserialize)
-        except Exception as e:
-            raise DeserializationError(
-                "Error when deserializing bytes downloaded from {}: {}".format(
-                    os.path.join(s3_uri, PAYLOAD_FILE), repr(e)
-                )
-            ) from e
-
-
-class DaskSerializer:
-    """Serializer using Dask."""
-
-    @staticmethod
-    def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
-        """Serializes data object and uploads it to S3.
-
-        Args:
-            sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS
-                service calls are delegated to.
-            s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
-            s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
-            obj: object to be serialized and persisted
-        Raises:
-            SerializationError: when fail to serialize object to bytes.
-        """
-        import distributed.protocol as dask
-
-        _upload_bytes_to_s3(
-            _MetaData(SerializationModule.DASK).to_json(),
-            os.path.join(s3_uri, METADATA_FILE),
-            s3_kms_key,
-            sagemaker_session,
-        )
-        try:
-            header, frames = dask.serialize(obj, on_error="raise")
-        except Exception as e:
-            raise SerializationError(
-                "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
-            ) from e
-
-        _upload_bytes_to_s3(
-            pickle.dumps(header), s3_path_join(s3_uri, HEADER_FILE), s3_kms_key, sagemaker_session
-        )
-        for idx, frame in enumerate(frames):
-            frame = bytes(frame) if isinstance(frame, memoryview) else frame
-            _upload_bytes_to_s3(
-                frame, s3_path_join(s3_uri, FRAME_FILE.format(idx)), s3_kms_key, sagemaker_session
-            )
+        _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)

     @staticmethod
     def deserialize(sagemaker_session, s3_uri) -> Any:
@@ -200,29 +111,19 @@ def deserialize(sagemaker_session, s3_uri) -> Any:
         Raises:
             DeserializationError: when fail to serialize object to bytes.
         """
-        import distributed.protocol as dask
+        bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)

-        header_to_deserialize = _read_bytes_from_s3(
-            s3_path_join(s3_uri, HEADER_FILE), sagemaker_session
-        )
-        headers = pickle.loads(header_to_deserialize)
-        num_frames = len(headers["frame-lengths"]) if "frame-lengths" in headers else 1
-        frames = []
-        for idx in range(num_frames):
-            frame = _read_bytes_from_s3(
-                s3_path_join(s3_uri, FRAME_FILE.format(idx)), sagemaker_session
-            )
-            frames.append(frame)
         try:
-            return dask.deserialize(headers, frames)
+            return cloudpickle.loads(bytes_to_deserialize)
         except Exception as e:
             raise DeserializationError(
                 "Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
             ) from e


+# TODO: use dask serializer in case dask distributed is installed in users' environment.
 def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
-    """Serializes function using cloudpickle and uploads it to S3.
+    """Serializes function and uploads it to S3.

     Args:
         sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
@@ -233,7 +134,13 @@ def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=N
     Raises:
         SerializationError: when fail to serialize function to bytes.
     """
-    CloudpickleSerializer.serialize(func, sagemaker_session, s3_uri, s3_kms_key)
+
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )


 def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
@@ -251,16 +158,16 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
     Raises:
         DeserializationError: when fail to serialize function to bytes.
     """
-    _MetaData.from_json(_read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session))
-    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))


 def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
     """Serializes data object and uploads it to S3.

-    This method uses the Dask library to perform serialization if its already installed, otherwise,
-    it uses cloudpickle.
-
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
             calls are delegated to.
@@ -271,12 +178,12 @@ def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: st
         SerializationError: when fail to serialize object to bytes.
     """

-    try:
-        import distributed.protocol as dask  # noqa: F401
-
-        DaskSerializer.serialize(obj, sagemaker_session, s3_uri, s3_kms_key)
-    except ModuleNotFoundError:
-        CloudpickleSerializer.serialize(obj, sagemaker_session, s3_uri, s3_kms_key)
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )


 def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
@@ -291,12 +198,12 @@ def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
     Raises:
         DeserializationError: when fail to serialize object to bytes.
     """
-    metadata = _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session)
+
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
     )
-    if metadata.serialization_module == SerializationModule.DASK:
-        return DaskSerializer.deserialize(sagemaker_session, s3_uri)
-    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))


 def serialize_exception_to_s3(
@@ -314,7 +221,12 @@ def serialize_exception_to_s3(
         SerializationError: when fail to serialize object to bytes.
     """
     pickling_support.install()
-    CloudpickleSerializer.serialize(exc, sagemaker_session, s3_uri, s3_kms_key)
+    _upload_bytes_to_s3(
+        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
+    )
+    CloudpickleSerializer.serialize(
+        exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
+    )


 def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
@@ -329,8 +241,12 @@ def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
     Raises:
         DeserializationError: when fail to serialize object to bytes.
     """
-    _MetaData.from_json(_read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session))
-    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)
+
+    _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    )
+
+    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))


 def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):
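
For illustration, a minimal usage sketch of the helpers changed above: serialization now always writes a metadata.json marker plus a cloudpickle payload.pkl under the given S3 prefix, and deserialization validates the metadata before unpickling the payload. The import path, bucket URI, and sample object below are assumptions for the sketch, not taken from this diff.

# Illustrative sketch only; the module path and S3 prefix are assumed, not part of the diff.
from sagemaker.session import Session
from sagemaker.remote_function.core.serialization import (  # assumed import path
    serialize_obj_to_s3,
    deserialize_obj_from_s3,
)

session = Session()
s3_uri = "s3://my-bucket/remote-function/run-1234"  # hypothetical prefix

# Writes <s3_uri>/metadata.json (version + "cloudpickle" marker) and <s3_uri>/payload.pkl.
serialize_obj_to_s3({"scores": [0.1, 0.2]}, session, s3_uri, s3_kms_key=None)

# Reads and validates metadata.json, then cloudpickle-loads payload.pkl.
restored = deserialize_obj_from_s3(session, s3_uri)
assert restored == {"scores": [0.1, 0.2]}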