Skip to content

Commit a67e842

Browse files
committed
[SPARK-27000][PYTHON] Upgrades cloudpickle to v0.8.0
## What changes were proposed in this pull request? After upgrading cloudpickle to 0.6.1 at #20691, one regression was found. Cloudpickle had a critical fix, cloudpipe/cloudpickle#240, for that. Basically, it currently looks like existing globals would override the globals shipped in a function's pickle, meaning: **Before:** ```python >>> def hey(): ... return "Hi" ... >>> spark.range(1).rdd.map(lambda _: hey()).collect() ['Hi'] >>> def hey(): ... return "Yeah" ... >>> spark.range(1).rdd.map(lambda _: hey()).collect() ['Hi'] ``` **After:** ```python >>> def hey(): ... return "Hi" ... >>> spark.range(1).rdd.map(lambda _: hey()).collect() ['Hi'] >>> >>> def hey(): ... return "Yeah" ... >>> spark.range(1).rdd.map(lambda _: hey()).collect() ['Yeah'] ``` Therefore, this PR upgrades cloudpickle to 0.8.0. Note that cloudpickle's release cycle is quite short. Between 0.6.1 and 0.7.0, it contains minor bug fixes. I don't see notable changes to double check and/or avoid. There is virtually only this fix between 0.7.0 and 0.8.0 - other fixes are about testing. ## How was this patch tested? Manually tested, tests were added. Verified unit tests were added in cloudpickle. Closes #23904 from HyukjinKwon/SPARK-27000. Authored-by: Hyukjin Kwon <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 5fd28e8 commit a67e842

File tree

2 files changed

+100
-73
lines changed

2 files changed

+100
-73
lines changed

python/pyspark/cloudpickle.py

Lines changed: 90 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -42,58 +42,43 @@
4242
"""
4343
from __future__ import print_function
4444

45-
import io
4645
import dis
47-
import sys
48-
import types
46+
from functools import partial
47+
import importlib
48+
import io
49+
import itertools
50+
import logging
4951
import opcode
52+
import operator
5053
import pickle
5154
import struct
52-
import logging
53-
import weakref
54-
import operator
55-
import importlib
56-
import itertools
55+
import sys
5756
import traceback
58-
from functools import partial
59-
57+
import types
58+
import weakref
6059

6160
# cloudpickle is meant for inter process communication: we expect all
6261
# communicating processes to run the same Python version hence we favor
6362
# communication speed over compatibility:
6463
DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL
6564

6665

67-
if sys.version < '3':
66+
if sys.version_info[0] < 3: # pragma: no branch
6867
from pickle import Pickler
6968
try:
7069
from cStringIO import StringIO
7170
except ImportError:
7271
from StringIO import StringIO
72+
string_types = (basestring,) # noqa
7373
PY3 = False
7474
else:
7575
types.ClassType = type
7676
from pickle import _Pickler as Pickler
7777
from io import BytesIO as StringIO
78+
string_types = (str,)
7879
PY3 = True
7980

8081

81-
# Container for the global namespace to ensure consistent unpickling of
82-
# functions defined in dynamic modules (modules not registered in sys.modules).
83-
_dynamic_modules_globals = weakref.WeakValueDictionary()
84-
85-
86-
class _DynamicModuleFuncGlobals(dict):
87-
"""Global variables referenced by a function defined in a dynamic module
88-
89-
To avoid leaking references we store such context in a WeakValueDictionary
90-
instance. However instances of python builtin types such as dict cannot
91-
be used directly as values in such a construct, hence the need for a
92-
derived class.
93-
"""
94-
pass
95-
96-
9782
def _make_cell_set_template_code():
9883
"""Get the Python compiler to emit LOAD_FAST(arg); STORE_DEREF
9984
@@ -112,7 +97,7 @@ def _stub(value):
11297
11398
return _stub
11499
115-
_cell_set_template_code = f()
100+
_cell_set_template_code = f().__code__
116101
117102
This function is _only_ a LOAD_FAST(arg); STORE_DEREF, but that is
118103
invalid syntax on Python 2. If we use this function we also don't need
@@ -127,7 +112,7 @@ def inner(value):
127112
# NOTE: we are marking the cell variable as a free variable intentionally
128113
# so that we simulate an inner function instead of the outer function. This
129114
# is what gives us the ``nonlocal`` behavior in a Python 2 compatible way.
130-
if not PY3:
115+
if not PY3: # pragma: no branch
131116
return types.CodeType(
132117
co.co_argcount,
133118
co.co_nlocals,
@@ -228,14 +213,14 @@ def _factory():
228213
}
229214

230215

231-
if sys.version_info < (3, 4):
216+
if sys.version_info < (3, 4): # pragma: no branch
232217
def _walk_global_ops(code):
233218
"""
234219
Yield (opcode, argument number) tuples for all
235220
global-referencing instructions in *code*.
236221
"""
237222
code = getattr(code, 'co_code', b'')
238-
if not PY3:
223+
if not PY3: # pragma: no branch
239224
code = map(ord, code)
240225

241226
n = len(code)
@@ -273,8 +258,6 @@ def __init__(self, file, protocol=None):
273258
if protocol is None:
274259
protocol = DEFAULT_PROTOCOL
275260
Pickler.__init__(self, file, protocol=protocol)
276-
# set of modules to unpickle
277-
self.modules = set()
278261
# map ids to dictionary. used to ensure that functions can share global env
279262
self.globals_ref = {}
280263

@@ -294,7 +277,7 @@ def save_memoryview(self, obj):
294277

295278
dispatch[memoryview] = save_memoryview
296279

297-
if not PY3:
280+
if not PY3: # pragma: no branch
298281
def save_buffer(self, obj):
299282
self.save(str(obj))
300283

@@ -304,7 +287,6 @@ def save_module(self, obj):
304287
"""
305288
Save a module as an import
306289
"""
307-
self.modules.add(obj)
308290
if _is_dynamic(obj):
309291
self.save_reduce(dynamic_subimport, (obj.__name__, vars(obj)),
310292
obj=obj)
@@ -317,7 +299,7 @@ def save_codeobject(self, obj):
317299
"""
318300
Save a code object
319301
"""
320-
if PY3:
302+
if PY3: # pragma: no branch
321303
args = (
322304
obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
323305
obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames,
@@ -384,7 +366,6 @@ def save_function(self, obj, name=None):
384366
lookedup_by_name = None
385367

386368
if themodule:
387-
self.modules.add(themodule)
388369
if lookedup_by_name is obj:
389370
return self.save_global(obj, name)
390371

@@ -396,7 +377,7 @@ def save_function(self, obj, name=None):
396377
# So we pickle them here using save_reduce; have to do it differently
397378
# for different python versions.
398379
if not hasattr(obj, '__code__'):
399-
if PY3:
380+
if PY3: # pragma: no branch
400381
rv = obj.__reduce_ex__(self.proto)
401382
else:
402383
if hasattr(obj, '__self__'):
@@ -434,8 +415,31 @@ def save_function(self, obj, name=None):
434415

435416
def _save_subimports(self, code, top_level_dependencies):
436417
"""
437-
Ensure de-pickler imports any package child-modules that
438-
are needed by the function
418+
Save submodules used by a function but not listed in its globals.
419+
420+
In the example below:
421+
422+
```
423+
import concurrent.futures
424+
import cloudpickle
425+
426+
427+
def func():
428+
x = concurrent.futures.ThreadPoolExecutor
429+
430+
431+
if __name__ == '__main__':
432+
cloudpickle.dumps(func)
433+
```
434+
435+
the globals extracted by cloudpickle in the function's state include
436+
the concurrent module, but not its submodule (here,
437+
concurrent.futures), which is the module used by func.
438+
439+
To ensure that calling the depickled function does not raise an
440+
AttributeError, this function looks for any currently loaded submodule
441+
that the function uses and whose parent is present in the function
442+
globals, and saves it before saving the function.
439443
"""
440444

441445
# check if any known dependency is an imported package
@@ -481,6 +485,17 @@ def save_dynamic_class(self, obj):
481485
# doc can't participate in a cycle with the original class.
482486
type_kwargs = {'__doc__': clsdict.pop('__doc__', None)}
483487

488+
if hasattr(obj, "__slots__"):
489+
type_kwargs['__slots__'] = obj.__slots__
490+
# pickle string length optimization: member descriptors of obj are
491+
# created automatically from obj's __slots__ attribute, no need to
492+
# save them in obj's state
493+
if isinstance(obj.__slots__, string_types):
494+
clsdict.pop(obj.__slots__)
495+
else:
496+
for k in obj.__slots__:
497+
clsdict.pop(k, None)
498+
484499
# If type overrides __dict__ as a property, include it in the type kwargs.
485500
# In Python 2, we can't set this attribute after construction.
486501
__dict__ = clsdict.pop('__dict__', None)
@@ -639,17 +654,17 @@ def extract_func_data(self, func):
639654
# save the dict
640655
dct = func.__dict__
641656

642-
base_globals = self.globals_ref.get(id(func.__globals__), None)
643-
if base_globals is None:
644-
# For functions defined in a well behaved module use
645-
# vars(func.__module__) for base_globals. This is necessary to
646-
# share the global variables across multiple pickled functions from
647-
# this module.
648-
if hasattr(func, '__module__') and func.__module__ is not None:
649-
base_globals = func.__module__
650-
else:
651-
base_globals = {}
652-
self.globals_ref[id(func.__globals__)] = base_globals
657+
# base_globals represents the future global namespace of func at
658+
# unpickling time. Looking it up and storing it in globals_ref allow
659+
# functions sharing the same globals at pickling time to also
660+
# share them once unpickled, at one condition: since globals_ref is
661+
# an attribute of a Cloudpickler instance, and that a new CloudPickler is
662+
# created each time pickle.dump or pickle.dumps is called, functions
663+
# also need to be saved within the same invocation of
664+
# cloudpickle.dump/cloudpickle.dumps (for example: cloudpickle.dumps([f1, f2])). There
665+
# is no such limitation when using Cloudpickler.dump, as long as the
666+
# multiple invocations are bound to the same Cloudpickler.
667+
base_globals = self.globals_ref.setdefault(id(func.__globals__), {})
653668

654669
return (code, f_globals, defaults, closure, dct, base_globals)
655670

@@ -699,7 +714,7 @@ def save_instancemethod(self, obj):
699714
if obj.__self__ is None:
700715
self.save_reduce(getattr, (obj.im_class, obj.__name__))
701716
else:
702-
if PY3:
717+
if PY3: # pragma: no branch
703718
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj)
704719
else:
705720
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__),
@@ -752,7 +767,7 @@ def save_inst(self, obj):
752767
save(stuff)
753768
write(pickle.BUILD)
754769

755-
if not PY3:
770+
if not PY3: # pragma: no branch
756771
dispatch[types.InstanceType] = save_inst
757772

758773
def save_property(self, obj):
@@ -852,7 +867,7 @@ def save_not_implemented(self, obj):
852867

853868
try: # Python 2
854869
dispatch[file] = save_file
855-
except NameError: # Python 3
870+
except NameError: # Python 3 # pragma: no branch
856871
dispatch[io.TextIOWrapper] = save_file
857872

858873
dispatch[type(Ellipsis)] = save_ellipsis
@@ -873,6 +888,12 @@ def save_root_logger(self, obj):
873888

874889
dispatch[logging.RootLogger] = save_root_logger
875890

891+
if hasattr(types, "MappingProxyType"): # pragma: no branch
892+
def save_mappingproxy(self, obj):
893+
self.save_reduce(types.MappingProxyType, (dict(obj),), obj=obj)
894+
895+
dispatch[types.MappingProxyType] = save_mappingproxy
896+
876897
"""Special functions for Add-on libraries"""
877898
def inject_addons(self):
878899
"""Plug in system. Register additional pickling functions if modules already loaded"""
@@ -1059,10 +1080,16 @@ def _fill_function(*args):
10591080
else:
10601081
raise ValueError('Unexpected _fill_value arguments: %r' % (args,))
10611082

1062-
# Only set global variables that do not exist.
1063-
for k, v in state['globals'].items():
1064-
if k not in func.__globals__:
1065-
func.__globals__[k] = v
1083+
# - At pickling time, any dynamic global variable used by func is
1084+
# serialized by value (in state['globals']).
1085+
# - At unpickling time, func's __globals__ attribute is initialized by
1086+
# first retrieving an empty isolated namespace that will be shared
1087+
# with other functions pickled from the same original module
1088+
# by the same CloudPickler instance and then updated with the
1089+
# content of state['globals'] to populate the shared isolated
1090+
# namespace with all the global variables that are specifically
1091+
# referenced for this function.
1092+
func.__globals__.update(state['globals'])
10661093

10671094
func.__defaults__ = state['defaults']
10681095
func.__dict__ = state['dict']
@@ -1100,21 +1127,11 @@ def _make_skel_func(code, cell_count, base_globals=None):
11001127
code and the correct number of cells in func_closure. All other
11011128
func attributes (e.g. func_globals) are empty.
11021129
"""
1103-
if base_globals is None:
1130+
# This is backward-compatibility code: for cloudpickle versions between
1131+
# 0.5.4 and 0.7, base_globals could be a string or None. base_globals
1132+
# should now always be a dictionary.
1133+
if base_globals is None or isinstance(base_globals, str):
11041134
base_globals = {}
1105-
elif isinstance(base_globals, str):
1106-
base_globals_name = base_globals
1107-
try:
1108-
# First try to reuse the globals from the module containing the
1109-
# function. If it is not possible to retrieve it, fallback to an
1110-
# empty dictionary.
1111-
base_globals = vars(importlib.import_module(base_globals))
1112-
except ImportError:
1113-
base_globals = _dynamic_modules_globals.get(
1114-
base_globals_name, None)
1115-
if base_globals is None:
1116-
base_globals = _DynamicModuleFuncGlobals()
1117-
_dynamic_modules_globals[base_globals_name] = base_globals
11181135

11191136
base_globals['__builtins__'] = __builtins__
11201137

@@ -1182,7 +1199,7 @@ def _getobject(modname, attribute):
11821199

11831200
""" Use copy_reg to extend global pickle definitions """
11841201

1185-
if sys.version_info < (3, 4):
1202+
if sys.version_info < (3, 4): # pragma: no branch
11861203
method_descriptor = type(str.upper)
11871204

11881205
def _reduce_method_descriptor(obj):

python/pyspark/tests/test_rdd.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
xrange = range
3333

3434

35+
global_func = lambda: "Hi"
36+
37+
3538
class RDDTests(ReusedPySparkTestCase):
3639

3740
def test_range(self):
@@ -726,6 +729,13 @@ def stopit(*x):
726729
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
727730
seq_rdd.aggregate, 0, lambda *x: 1, stopit)
728731

732+
def test_overwritten_global_func(self):
733+
# Regression test for SPARK-27000
734+
global global_func
735+
self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Hi")
736+
global_func = lambda: "Yeah"
737+
self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Yeah")
738+
729739

730740
if __name__ == "__main__":
731741
import unittest

0 commit comments

Comments
 (0)