Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 90 additions & 73 deletions python/pyspark/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,58 +42,43 @@
"""
from __future__ import print_function

import io
import dis
import sys
import types
from functools import partial
import importlib
import io
import itertools
import logging
import opcode
import operator
import pickle
import struct
import logging
import weakref
import operator
import importlib
import itertools
import sys
import traceback
from functools import partial

import types
import weakref

# cloudpickle is meant for inter process communication: we expect all
# communicating processes to run the same Python version hence we favor
# communication speed over compatibility:
DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL


if sys.version < '3':
if sys.version_info[0] < 3: # pragma: no branch
from pickle import Pickler
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
string_types = (basestring,) # noqa
PY3 = False
else:
types.ClassType = type
from pickle import _Pickler as Pickler
from io import BytesIO as StringIO
string_types = (str,)
PY3 = True


# Container for the global namespace to ensure consistent unpickling of
# functions defined in dynamic modules (modules not registered in sys.modules).
_dynamic_modules_globals = weakref.WeakValueDictionary()


class _DynamicModuleFuncGlobals(dict):
"""Global variables referenced by a function defined in a dynamic module

To avoid leaking references we store such context in a WeakValueDictionary
instance. However instances of python builtin types such as dict cannot
be used directly as values in such a construct, hence the need for a
derived class.
"""
pass


def _make_cell_set_template_code():
"""Get the Python compiler to emit LOAD_FAST(arg); STORE_DEREF

Expand All @@ -112,7 +97,7 @@ def _stub(value):

return _stub

_cell_set_template_code = f()
_cell_set_template_code = f().__code__

This function is _only_ a LOAD_FAST(arg); STORE_DEREF, but that is
invalid syntax on Python 2. If we use this function we also don't need
Expand All @@ -127,7 +112,7 @@ def inner(value):
# NOTE: we are marking the cell variable as a free variable intentionally
# so that we simulate an inner function instead of the outer function. This
# is what gives us the ``nonlocal`` behavior in a Python 2 compatible way.
if not PY3:
if not PY3: # pragma: no branch
return types.CodeType(
co.co_argcount,
co.co_nlocals,
Expand Down Expand Up @@ -228,14 +213,14 @@ def _factory():
}


if sys.version_info < (3, 4):
if sys.version_info < (3, 4): # pragma: no branch
def _walk_global_ops(code):
"""
Yield (opcode, argument number) tuples for all
global-referencing instructions in *code*.
"""
code = getattr(code, 'co_code', b'')
if not PY3:
if not PY3: # pragma: no branch
code = map(ord, code)

n = len(code)
Expand Down Expand Up @@ -273,8 +258,6 @@ def __init__(self, file, protocol=None):
if protocol is None:
protocol = DEFAULT_PROTOCOL
Pickler.__init__(self, file, protocol=protocol)
# set of modules to unpickle
self.modules = set()
# map ids to dictionary. used to ensure that functions can share global env
self.globals_ref = {}

Expand All @@ -294,7 +277,7 @@ def save_memoryview(self, obj):

dispatch[memoryview] = save_memoryview

if not PY3:
if not PY3: # pragma: no branch
def save_buffer(self, obj):
    """Pickle the object by saving its ``str()`` contents instead.

    Only registered on Python 2 (guarded by ``if not PY3`` above).
    """
    as_string = str(obj)
    self.save(as_string)

Expand All @@ -304,7 +287,6 @@ def save_module(self, obj):
"""
Save a module as an import
"""
self.modules.add(obj)
if _is_dynamic(obj):
self.save_reduce(dynamic_subimport, (obj.__name__, vars(obj)),
obj=obj)
Expand All @@ -317,7 +299,7 @@ def save_codeobject(self, obj):
"""
Save a code object
"""
if PY3:
if PY3: # pragma: no branch
args = (
obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames,
Expand Down Expand Up @@ -384,7 +366,6 @@ def save_function(self, obj, name=None):
lookedup_by_name = None

if themodule:
self.modules.add(themodule)
if lookedup_by_name is obj:
return self.save_global(obj, name)

Expand All @@ -396,7 +377,7 @@ def save_function(self, obj, name=None):
# So we pickle them here using save_reduce; have to do it differently
# for different python versions.
if not hasattr(obj, '__code__'):
if PY3:
if PY3: # pragma: no branch
rv = obj.__reduce_ex__(self.proto)
else:
if hasattr(obj, '__self__'):
Expand Down Expand Up @@ -434,8 +415,31 @@ def save_function(self, obj, name=None):

def _save_subimports(self, code, top_level_dependencies):
"""
Ensure de-pickler imports any package child-modules that
are needed by the function
Save submodules used by a function but not listed in its globals.

In the example below:

```
import concurrent.futures
import cloudpickle


def func():
x = concurrent.futures.ThreadPoolExecutor


if __name__ == '__main__':
cloudpickle.dumps(func)
```

the globals extracted by cloudpickle in the function's state include
the concurrent module, but not its submodule (here,
concurrent.futures), which is the module used by func.

To ensure that calling the depickled function does not raise an
AttributeError, this function looks for any currently loaded submodule
that the function uses and whose parent is present in the function
globals, and saves it before saving the function.
"""

# check if any known dependency is an imported package
Expand Down Expand Up @@ -481,6 +485,17 @@ def save_dynamic_class(self, obj):
# doc can't participate in a cycle with the original class.
type_kwargs = {'__doc__': clsdict.pop('__doc__', None)}

if hasattr(obj, "__slots__"):
type_kwargs['__slots__'] = obj.__slots__
# pickle string length optimization: member descriptors of obj are
# created automatically from obj's __slots__ attribute, no need to
# save them in obj's state
if isinstance(obj.__slots__, string_types):
clsdict.pop(obj.__slots__)
else:
for k in obj.__slots__:
clsdict.pop(k, None)

# If type overrides __dict__ as a property, include it in the type kwargs.
# In Python 2, we can't set this attribute after construction.
__dict__ = clsdict.pop('__dict__', None)
Expand Down Expand Up @@ -639,17 +654,17 @@ def extract_func_data(self, func):
# save the dict
dct = func.__dict__

base_globals = self.globals_ref.get(id(func.__globals__), None)
if base_globals is None:
# For functions defined in a well behaved module use
# vars(func.__module__) for base_globals. This is necessary to
# share the global variables across multiple pickled functions from
# this module.
if hasattr(func, '__module__') and func.__module__ is not None:
base_globals = func.__module__
else:
base_globals = {}
self.globals_ref[id(func.__globals__)] = base_globals
# base_globals represents the future global namespace of func at
# unpickling time. Looking it up and storing it in globals_ref allows
# functions sharing the same globals at pickling time to also
# share them once unpickled, on one condition: since globals_ref is
# an attribute of a CloudPickler instance, and a new CloudPickler is
# created each time pickle.dump or pickle.dumps is called, functions
# also need to be saved within the same invocation of
# cloudpickle.dump/cloudpickle.dumps (for example: cloudpickle.dumps([f1, f2])). There
# is no such limitation when using CloudPickler.dump, as long as the
# multiple invocations are bound to the same CloudPickler.
base_globals = self.globals_ref.setdefault(id(func.__globals__), {})

return (code, f_globals, defaults, closure, dct, base_globals)

Expand Down Expand Up @@ -699,7 +714,7 @@ def save_instancemethod(self, obj):
if obj.__self__ is None:
self.save_reduce(getattr, (obj.im_class, obj.__name__))
else:
if PY3:
if PY3: # pragma: no branch
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj)
else:
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__),
Expand Down Expand Up @@ -752,7 +767,7 @@ def save_inst(self, obj):
save(stuff)
write(pickle.BUILD)

if not PY3:
if not PY3: # pragma: no branch
dispatch[types.InstanceType] = save_inst

def save_property(self, obj):
Expand Down Expand Up @@ -852,7 +867,7 @@ def save_not_implemented(self, obj):

try: # Python 2
dispatch[file] = save_file
except NameError: # Python 3
except NameError: # Python 3 # pragma: no branch
dispatch[io.TextIOWrapper] = save_file

dispatch[type(Ellipsis)] = save_ellipsis
Expand All @@ -873,6 +888,12 @@ def save_root_logger(self, obj):

dispatch[logging.RootLogger] = save_root_logger

if hasattr(types, "MappingProxyType"): # pragma: no branch
def save_mappingproxy(self, obj):
    """Pickle a ``types.MappingProxyType`` via a plain dict copy.

    At unpickling time the proxy is rebuilt by calling
    ``MappingProxyType`` on that copy.
    """
    contents = dict(obj)
    self.save_reduce(types.MappingProxyType, (contents,), obj=obj)

dispatch[types.MappingProxyType] = save_mappingproxy

"""Special functions for Add-on libraries"""
def inject_addons(self):
"""Plug in system. Register additional pickling functions if modules already loaded"""
Expand Down Expand Up @@ -1059,10 +1080,16 @@ def _fill_function(*args):
else:
raise ValueError('Unexpected _fill_value arguments: %r' % (args,))

# Only set global variables that do not exist.
for k, v in state['globals'].items():
if k not in func.__globals__:
func.__globals__[k] = v
# - At pickling time, any dynamic global variable used by func is
# serialized by value (in state['globals']).
# - At unpickling time, func's __globals__ attribute is initialized by
# first retrieving an empty isolated namespace that will be shared
# with other functions pickled from the same original module
# by the same CloudPickler instance and then updated with the
# content of state['globals'] to populate the shared isolated
# namespace with all the global variables that are specifically
# referenced for this function.
func.__globals__.update(state['globals'])

func.__defaults__ = state['defaults']
func.__dict__ = state['dict']
Expand Down Expand Up @@ -1100,21 +1127,11 @@ def _make_skel_func(code, cell_count, base_globals=None):
code and the correct number of cells in func_closure. All other
func attributes (e.g. func_globals) are empty.
"""
if base_globals is None:
# This is backward-compatibility code: for cloudpickle versions between
# 0.5.4 and 0.7, base_globals could be a string or None. base_globals
# should now always be a dictionary.
if base_globals is None or isinstance(base_globals, str):
base_globals = {}
elif isinstance(base_globals, str):
base_globals_name = base_globals
try:
# First try to reuse the globals from the module containing the
# function. If it is not possible to retrieve it, fallback to an
# empty dictionary.
base_globals = vars(importlib.import_module(base_globals))
except ImportError:
base_globals = _dynamic_modules_globals.get(
base_globals_name, None)
if base_globals is None:
base_globals = _DynamicModuleFuncGlobals()
_dynamic_modules_globals[base_globals_name] = base_globals

base_globals['__builtins__'] = __builtins__

Expand Down Expand Up @@ -1182,7 +1199,7 @@ def _getobject(modname, attribute):

""" Use copy_reg to extend global pickle definitions """

if sys.version_info < (3, 4):
if sys.version_info < (3, 4): # pragma: no branch
method_descriptor = type(str.upper)

def _reduce_method_descriptor(obj):
Expand Down
10 changes: 10 additions & 0 deletions python/pyspark/tests/test_rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
xrange = range


global_func = lambda: "Hi"


class RDDTests(ReusedPySparkTestCase):

def test_range(self):
Expand Down Expand Up @@ -726,6 +729,13 @@ def stopit(*x):
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
seq_rdd.aggregate, 0, lambda *x: 1, stopit)

def test_overwritten_global_func(self):
    # Regression test for SPARK-27000: a closure pickled after its
    # module-level global has been rebound must observe the rebound
    # value when executed on the workers.
    global global_func

    def run_job():
        return self.sc.parallelize([1]).map(lambda _: global_func()).first()

    self.assertEqual(run_job(), "Hi")
    global_func = lambda: "Yeah"
    self.assertEqual(run_job(), "Yeah")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran a test locally:

>>> global_func = lambda: "Hi"
>>> sc.parallelize([1]).map(lambda _: global_func()).first()
'Hi'
>>> global_func = lambda: "Yeah"
>>> sc.parallelize([1]).map(lambda _: global_func()).first()
'Hi'
>>> sc.parallelize([1]).map(lambda _: global_func()).first()
'Yeah'
>>> sc.parallelize([1]).map(lambda _: global_func()).first()
'Hi'

It seems to output 'Hi' or 'Yeah' nondeterministically. Is this caused by the cloudpickle issue fixed in this PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it looks so. After the fix, the result becomes consistent:

>>> global_func = lambda: "Hi"
>>> sc.parallelize([1]).map(lambda _: global_func()).first()
'Hi'
>>> global_func = lambda: "Yeah"
>>> sc.parallelize([1]).map(lambda _: global_func()).first()
'Yeah'
>>> sc.parallelize([1]).map(lambda _: global_func()).first()
'Yeah'
>>> sc.parallelize([1]).map(lambda _: global_func()).first()
'Yeah'



if __name__ == "__main__":
import unittest
Expand Down