import json
import os
import sys
+import pickle
+from enum import Enum

import cloudpickle

from typing import Any, Callable
+
+from sagemaker.s3 import s3_path_join
+
from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError
from sagemaker.s3 import S3Downloader, S3Uploader
from tblib import pickling_support

+METADATA_FILE = "metadata.json"
+PAYLOAD_FILE = "payload.pkl"
+HEADER_FILE = "headers.pkl"
+FRAME_FILE = "frame-{}.dat"
+

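Note: taken together, the added constants describe the S3 layout the serializers in this file write under a caller-supplied s3_uri. As a rough sketch (the exact file set depends on which serializer is chosen, and the number of frame files depends on the object being serialized):

    <s3_uri>/metadata.json      # _MetaData record: serialization module, version, Python version
    <s3_uri>/payload.pkl        # cloudpickle path: the pickled object
    <s3_uri>/headers.pkl        # Dask path: pickled header dict
    <s3_uri>/frame-0.dat, ...   # Dask path: one object per data frame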
def _get_python_version():
    return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"


+class SerializationModule(str, Enum):
+    """Represents various serialization modules used."""
+
+    CLOUDPICKLE = "cloudpickle"
+    DASK = "dask"
+
+
@dataclasses.dataclass
class _MetaData:
    """Metadata about the serialized data or functions."""

+    serialization_module: SerializationModule
    version: str = "2023-04-24"
    python_version: str = _get_python_version()
-    serialization_module: str = "cloudpickle"

    def to_json(self):
        return json.dumps(dataclasses.asdict(self)).encode()
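For reference, to_json() simply JSON-encodes the dataclass fields, so the metadata.json object written to S3 looks roughly like the following (illustrative only; the python_version value depends on the runtime that wrote it):

    {"serialization_module": "dask", "version": "2023-04-24", "python_version": "3.10.8"}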
@@ -45,16 +62,13 @@ def to_json(self):
    def from_json(s):
        try:
            obj = json.loads(s)
-        except json.decoder.JSONDecodeError:
+            metadata = _MetaData(**obj)
+        except (json.decoder.JSONDecodeError, TypeError):
            raise DeserializationError("Corrupt metadata file. It is not a valid json file.")

-        metadata = _MetaData()
-        metadata.version = obj.get("version")
-        metadata.python_version = obj.get("python_version")
-        metadata.serialization_module = obj.get("serialization_module")
-
-        if not (
-            metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
+        if (
+            metadata.version != "2023-04-24"
+            or metadata.serialization_module not in SerializationModule.__members__.values()
        ):
            raise DeserializationError(
                f"Corrupt metadata file. Serialization approach {s} is not supported."
@@ -79,6 +93,12 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
        Raises:
            SerializationError: when fail to serialize object to bytes.
        """
+        _upload_bytes_to_s3(
+            _MetaData(SerializationModule.CLOUDPICKLE).to_json(),
+            os.path.join(s3_uri, METADATA_FILE),
+            s3_kms_key,
+            sagemaker_session,
+        )
        try:
            bytes_to_upload = cloudpickle.dumps(obj)
        except Exception as e:
@@ -95,7 +115,76 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
                "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
            ) from e

-        _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
+        _upload_bytes_to_s3(
+            bytes_to_upload, os.path.join(s3_uri, PAYLOAD_FILE), s3_kms_key, sagemaker_session
+        )
+
+    @staticmethod
+    def deserialize(sagemaker_session, s3_uri) -> Any:
+        """Downloads from S3 and then deserializes data objects.
+
+        Args:
+            sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which
+                AWS service calls are delegated to.
+            s3_uri (str): S3 root uri from which the serialized artifacts are downloaded.
+        Returns:
+            The deserialized Python object.
+        Raises:
+            DeserializationError: when it fails to deserialize the downloaded bytes.
+        """
+        bytes_to_deserialize = _read_bytes_from_s3(
+            os.path.join(s3_uri, PAYLOAD_FILE), sagemaker_session
+        )
+
+        try:
+            return cloudpickle.loads(bytes_to_deserialize)
+        except Exception as e:
+            raise DeserializationError(
+                "Error when deserializing bytes downloaded from {}: {}".format(
+                    os.path.join(s3_uri, PAYLOAD_FILE), repr(e)
+                )
+            ) from e
+
+
+class DaskSerializer:
+    """Serializer using Dask."""
+
+    @staticmethod
+    def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
+        """Serializes data object and uploads it to S3.
+
+        Args:
+            sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS
+                service calls are delegated to.
+            s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+            s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+            obj: object to be serialized and persisted
+        Raises:
+            SerializationError: when fail to serialize object to bytes.
+        """
+        import distributed.protocol as dask
+
+        _upload_bytes_to_s3(
+            _MetaData(SerializationModule.DASK).to_json(),
+            os.path.join(s3_uri, METADATA_FILE),
+            s3_kms_key,
+            sagemaker_session,
+        )
+        try:
+            header, frames = dask.serialize(obj, on_error="raise")
+        except Exception as e:
+            raise SerializationError(
+                "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
+            ) from e
+
+        _upload_bytes_to_s3(
+            pickle.dumps(header), s3_path_join(s3_uri, HEADER_FILE), s3_kms_key, sagemaker_session
+        )
+        for idx, frame in enumerate(frames):
+            frame = bytes(frame) if isinstance(frame, memoryview) else frame
+            _upload_bytes_to_s3(
+                frame, s3_path_join(s3_uri, FRAME_FILE.format(idx)), s3_kms_key, sagemaker_session
+            )

    @staticmethod
    def deserialize(sagemaker_session, s3_uri) -> Any:
@@ -110,19 +199,29 @@ def deserialize(sagemaker_session, s3_uri) -> Any:
        Raises:
            DeserializationError: when fail to serialize object to bytes.
        """
-        bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
+        import distributed.protocol as dask

+        header_to_deserialize = _read_bytes_from_s3(
+            s3_path_join(s3_uri, HEADER_FILE), sagemaker_session
+        )
+        headers = pickle.loads(header_to_deserialize)
+        num_frames = len(headers["frame-lengths"]) if "frame-lengths" in headers else 1
+        frames = []
+        for idx in range(num_frames):
+            frame = _read_bytes_from_s3(
+                s3_path_join(s3_uri, FRAME_FILE.format(idx)), sagemaker_session
+            )
+            frames.append(frame)
        try:
-            return cloudpickle.loads(bytes_to_deserialize)
+            return dask.deserialize(headers, frames)
        except Exception as e:
            raise DeserializationError(
                "Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
            ) from e


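For context on the Dask path above: distributed.protocol splits an object into a small header dict plus a list of binary frames, which is what ends up in headers.pkl and the frame-{}.dat objects. A minimal local round trip, assuming the distributed package is installed, looks like:

    import distributed.protocol as dask

    # serialize() returns a header describing how to rebuild the object and a
    # list of bytes-like frames; the header keys and the frame count vary with
    # the object type and the serializer Dask picks.
    header, frames = dask.serialize({"a": 1}, on_error="raise")
    restored = dask.deserialize(header, frames)
    assert restored == {"a": 1}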
-# TODO: use dask serializer in case dask distributed is installed in users' environment.
def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
-    """Serializes function and uploads it to S3.
+    """Serializes function using cloudpickle and uploads it to S3.

    Args:
        sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
@@ -133,13 +232,7 @@ def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
    Raises:
        SerializationError: when fail to serialize function to bytes.
    """
-
-    _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
-    )
-    CloudpickleSerializer.serialize(
-        func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
-    )
+    CloudpickleSerializer.serialize(func, sagemaker_session, s3_uri, s3_kms_key)


def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
@@ -157,16 +250,16 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
    Raises:
        DeserializationError: when fail to serialize function to bytes.
    """
-    _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
-    )
-
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    _MetaData.from_json(_read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session))
+    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)


def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
    """Serializes data object and uploads it to S3.

+    This method uses the Dask library to perform serialization if it is already installed;
+    otherwise, it uses cloudpickle.
+
    Args:
        sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
            calls are delegated to.
@@ -177,12 +270,12 @@ def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
        SerializationError: when fail to serialize object to bytes.
    """

-    _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
-    )
-    CloudpickleSerializer.serialize(
-        obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
-    )
+    try:
+        import distributed.protocol as dask  # noqa: F401
+
+        DaskSerializer.serialize(obj, sagemaker_session, s3_uri, s3_kms_key)
+    except ModuleNotFoundError:
+        CloudpickleSerializer.serialize(obj, sagemaker_session, s3_uri, s3_kms_key)


def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
@@ -197,12 +290,12 @@ def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
    Raises:
        DeserializationError: when fail to serialize object to bytes.
    """
-
-    _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    metadata = _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session)
    )
-
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    if metadata.serialization_module == SerializationModule.DASK:
+        return DaskSerializer.deserialize(sagemaker_session, s3_uri)
+    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)


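A rough usage sketch of the two object helpers above (the session and S3 URI below are placeholders): serialize_obj_to_s3 writes the Dask format when distributed is importable and falls back to cloudpickle otherwise, while deserialize_obj_from_s3 dispatches on the serialization_module recorded in metadata.json:

    from sagemaker.session import Session

    session = Session()  # placeholder; any configured sagemaker session
    s3_uri = "s3://my-bucket/remote-function/results"  # placeholder URI

    serialize_obj_to_s3({"scores": [0.1, 0.2]}, session, s3_uri)
    restored = deserialize_obj_from_s3(session, s3_uri)
    assert restored == {"scores": [0.1, 0.2]}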
def serialize_exception_to_s3(
@@ -220,12 +313,7 @@ def serialize_exception_to_s3(
        SerializationError: when fail to serialize object to bytes.
    """
    pickling_support.install()
-    _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
-    )
-    CloudpickleSerializer.serialize(
-        exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
-    )
+    CloudpickleSerializer.serialize(exc, sagemaker_session, s3_uri, s3_kms_key)


def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
@@ -240,12 +328,8 @@ def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
    Raises:
        DeserializationError: when fail to serialize object to bytes.
    """
-
-    _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
-    )
-
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    _MetaData.from_json(_read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session))
+    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)


def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):