Skip to content

Commit f8187e9

Browse files
committed
Some cleanups, fix memoryview support
1 parent 2da4c24 commit f8187e9

File tree

2 files changed

+24
-48
lines changed

2 files changed

+24
-48
lines changed

cloudpickle/cloudpickle.py

Lines changed: 20 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -264,15 +264,12 @@ def dump(self, obj):
264264
raise pickle.PicklingError(msg)
265265

266266
def save_memoryview(self, obj):
267-
"""Fallback to save_string"""
268-
Pickler.save_string(self, str(obj))
267+
self.save(obj.tobytes())
268+
dispatch[memoryview] = save_memoryview
269269

270-
def save_buffer(self, obj):
271-
"""Fallback to save_string"""
272-
Pickler.save_string(self,str(obj))
273-
if PY3:
274-
dispatch[memoryview] = save_memoryview
275-
else:
270+
if not PY3:
271+
def save_buffer(self, obj):
272+
self.save(str(obj))
276273
dispatch[buffer] = save_buffer
277274

278275
def save_unsupported(self, obj):
@@ -387,7 +384,7 @@ def save_function(self, obj, name=None):
387384
rv = (getattr, (obj.__self__, name))
388385
else:
389386
raise pickle.PicklingError("Can't pickle %r" % obj)
390-
return Pickler.save_reduce(self, obj=obj, *rv)
387+
return self.save_reduce(obj=obj, *rv)
391388

392389
# if func is lambda, def'ed at prompt, is in main, or is nested, then
393390
# we'll pickle the actual function object rather than simply saving a
@@ -477,18 +474,12 @@ def save_dynamic_class(self, obj):
477474
# Push the rehydration function.
478475
save(_rehydrate_skeleton_class)
479476

480-
# Mark the start of the args for the rehydration function.
477+
# Mark the start of the args tuple for the rehydration function.
481478
write(pickle.MARK)
482479

483-
# Create and memoize an empty class with obj's name and bases.
484-
save(type(obj))
485-
save((
486-
obj.__name__,
487-
obj.__bases__,
488-
type_kwargs,
489-
))
490-
write(pickle.REDUCE)
491-
self.memoize(obj)
480+
# Create and memoize an skeleton class with obj's name and bases.
481+
tp = type(obj)
482+
self.save_reduce(tp, (obj.__name__, obj.__bases__, type_kwargs), obj=obj)
492483

493484
# Now save the rest of obj's __dict__. Any references to obj
494485
# encountered while saving will point to the skeleton class.
@@ -627,37 +618,18 @@ def save_global(self, obj, name=None, pack=struct.pack):
627618
The name of this method is somewhat misleading: all types get
628619
dispatched here.
629620
"""
630-
if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
631-
if obj in _BUILTIN_TYPE_NAMES:
632-
return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
633-
634-
if name is None:
635-
name = obj.__name__
636-
637-
modname = getattr(obj, "__module__", None)
638-
if modname is None:
639-
try:
640-
# whichmodule() could fail, see
641-
# https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling
642-
modname = pickle.whichmodule(obj, name)
643-
except Exception:
644-
modname = '__main__'
645-
646-
if modname == '__main__':
647-
themodule = None
648-
else:
649-
__import__(modname)
650-
themodule = sys.modules[modname]
651-
self.modules.add(themodule)
621+
try:
622+
return Pickler.save_global(self, obj, name=name)
623+
except Exception:
624+
if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
625+
if obj in _BUILTIN_TYPE_NAMES:
626+
return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
652627

653-
if hasattr(themodule, name) and getattr(themodule, name) is obj:
654-
return Pickler.save_global(self, obj, name)
628+
typ = type(obj)
629+
if typ is not obj and isinstance(obj, (type, types.ClassType)):
630+
return self.save_dynamic_class(obj)
655631

656-
typ = type(obj)
657-
if typ is not obj and isinstance(obj, (type, types.ClassType)):
658-
self.save_dynamic_class(obj)
659-
else:
660-
raise pickle.PicklingError("Can't pickle %r" % obj)
632+
raise
661633

662634
dispatch[type] = save_global
663635
dispatch[types.ClassType] = save_global

tests/cloudpickle_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ def test_buffer(self):
130130
except NameError: # Python 3 does no longer support buffers
131131
pass
132132

133+
def test_memoryview(self):
134+
buffer_obj = memoryview(b"Hello")
135+
self.assertEqual(pickle_depickle(buffer_obj), buffer_obj.tobytes())
136+
133137
def test_lambda(self):
134138
self.assertEqual(pickle_depickle(lambda: 1)(), 1)
135139

0 commit comments

Comments
 (0)