9 changes: 5 additions & 4 deletions python/pyspark/context.py
@@ -34,8 +34,9 @@
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway, local_connect_and_auth
-from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
-    PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream
+from pyspark.serializers import CloudPickleSerializer, BatchedSerializer, \
+    UTF8Deserializer, PairDeserializer, AutoBatchedSerializer, \
+    NoOpSerializer, ChunkedStream
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.traceback_utils import CallSite, first_spark_call
@@ -75,8 +76,8 @@ class SparkContext(object):
     PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar')

     def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
-                 environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
-                 gateway=None, jsc=None, profiler_cls=BasicProfiler):
+                 environment=None, batchSize=0, serializer=CloudPickleSerializer(),
+                 conf=None, gateway=None, jsc=None, profiler_cls=BasicProfiler):
         """
         Create a new SparkContext. At least the master and app name should be set,
         either through the named parameters here or through C{conf}.
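
Note on the hunk above: the default `serializer` argument of `SparkContext.__init__` changes from `PickleSerializer()` to `CloudPickleSerializer()`, making cloudpickle the default for data serialization. A minimal sketch of how a caller could keep the old behavior, assuming `PickleSerializer` remains exported from `pyspark.serializers` (the class itself is untouched by this patch); the snippet is illustrative, not part of the change:

# Illustrative sketch: pin the previous default explicitly if a job
# depends on plain-pickle semantics.
from pyspark import SparkContext
from pyspark.serializers import PickleSerializer

sc = SparkContext(master="local[2]", appName="pickle-compat",
                  serializer=PickleSerializer())
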
84 changes: 0 additions & 84 deletions python/pyspark/serializers.py
@@ -474,90 +474,6 @@ def dumps(self, obj):
         return obj


-# Hack namedtuple, make it picklable
-
-__cls = {}
-
-
-def _restore(name, fields, value):
-    """ Restore an object of namedtuple"""
-    k = (name, fields)
-    cls = __cls.get(k)
-    if cls is None:
-        cls = collections.namedtuple(name, fields)
-        __cls[k] = cls
-    return cls(*value)
-
-
-def _hack_namedtuple(cls):
-    """ Make class generated by namedtuple picklable """
-    name = cls.__name__
-    fields = cls._fields
-
-    def __reduce__(self):
-        return (_restore, (name, fields, tuple(self)))
-    cls.__reduce__ = __reduce__
-    cls._is_namedtuple_ = True
-    return cls
-
-
-def _hijack_namedtuple():
-    """ Hack namedtuple() to make it picklable """
-    # hijack only one time
-    if hasattr(collections.namedtuple, "__hijack"):
-        return
-
-    global _old_namedtuple  # or it will put in closure
-    global _old_namedtuple_kwdefaults  # or it will put in closure too
-
-    def _copy_func(f):
-        return types.FunctionType(f.__code__, f.__globals__, f.__name__,
-                                  f.__defaults__, f.__closure__)
-
-    def _kwdefaults(f):
-        # __kwdefaults__ contains the default values of keyword-only arguments which are
-        # introduced from Python 3. The possible cases for __kwdefaults__ in namedtuple
-        # are as below:
-        #
-        # - Does not exist in Python 2.
-        # - Returns None in <= Python 3.5.x.
-        # - Returns a dictionary containing the default values to the keys from Python 3.6.x
-        #   (See https://bugs.python.org/issue25628).
-        kargs = getattr(f, "__kwdefaults__", None)
-        if kargs is None:
-            return {}
-        else:
-            return kargs
-
-    _old_namedtuple = _copy_func(collections.namedtuple)
-    _old_namedtuple_kwdefaults = _kwdefaults(collections.namedtuple)
-
-    def namedtuple(*args, **kwargs):
-        for k, v in _old_namedtuple_kwdefaults.items():
-            kwargs[k] = kwargs.get(k, v)
-        cls = _old_namedtuple(*args, **kwargs)
-        return _hack_namedtuple(cls)
-
-    # replace namedtuple with the new one
-    collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults
-    collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple
-    collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple
-    collections.namedtuple.__code__ = namedtuple.__code__
-    collections.namedtuple.__hijack = 1
-
-    # hack the cls already generated by namedtuple.
-    # Those created in other modules can be pickled as normal,
-    # so only hack those in __main__ module
-    for n, o in sys.modules["__main__"].__dict__.items():
-        if (type(o) is type and o.__base__ is tuple
-                and hasattr(o, "_fields")
-                and "__reduce__" not in o.__dict__):
-            _hack_namedtuple(o)  # hack inplace
-
-
-_hijack_namedtuple()
-
-
 class PickleSerializer(FramedSerializer):

     """
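
Context for the deletion above: plain pickle serializes a class by reference (module path plus name), so a namedtuple class defined in the driver's `__main__` cannot be reconstructed on a worker whose `__main__` is different. The deleted code worked around this by patching `__reduce__` onto every namedtuple so that instances ship their type name and fields instead. cloudpickle serializes classes defined in `__main__` by value, which is why the workaround can be dropped once it becomes the default serializer. A standalone sketch of the difference, using the standalone `cloudpickle` package for illustration (not `pyspark.cloudpickle`):

# Sketch of why the __reduce__ hack becomes unnecessary.
# Requires the standalone cloudpickle package (pip install cloudpickle).
import collections
import pickle

import cloudpickle

# Imagine this class is defined in the driver's __main__ module.
Point = collections.namedtuple("Point", "x y")

# Plain pickle would record only a reference ("__main__.Point"), which a
# worker process lacking that definition could not resolve. cloudpickle
# embeds the class definition itself, so the payload is self-contained.
blob = cloudpickle.dumps(Point(1, 3))
print(pickle.loads(blob))  # Point(x=1, y=3)
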
3 changes: 1 addition & 2 deletions python/pyspark/tests.py
@@ -238,13 +238,12 @@ class SerializationTestCase(unittest.TestCase):

     def test_namedtuple(self):
         from collections import namedtuple
-        from pickle import dumps, loads
+        from pyspark.cloudpickle import dumps, loads
         P = namedtuple("P", "x y")
         p1 = P(1, 3)
         p2 = loads(dumps(p1, 2))
         self.assertEqual(p1, p2)

-        from pyspark.cloudpickle import dumps
         P2 = loads(dumps(P))
         p3 = P2(1, 3)
         self.assertEqual(p1, p3)
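
The rewritten test now round-trips both the instance (`dumps(p1, 2)`) and the class object itself (`dumps(P)`) through `pyspark.cloudpickle`, which is exactly the case the deleted hack used to cover. A condensed standalone version of the class round-trip, again using the standalone `cloudpickle` package for illustration:

# Condensed, self-contained version of the class round-trip check.
import collections
import pickle

import cloudpickle

P = collections.namedtuple("P", "x y")
P2 = pickle.loads(cloudpickle.dumps(P))  # the class is shipped by value
assert P2(1, 3) == P(1, 3)
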