Commit 9a0ea4c
Author: Davies Liu
Parent: 1e340c3

    fix big closure with shuffle

File tree: 2 files changed, 7 additions & 14 deletions

    python/pyspark/rdd.py
    python/pyspark/tests.py
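What the commit fixes, in brief: when a task's pickled closure exceeds 1 MB, _prepare_for_python_RDD ships it to executors through a broadcast variable. Previously the PipelinedRDD held that broadcast itself and unpersisted it in __del__, so the broadcast could be destroyed while a shuffle stage that reuses the closure still needed it (SPARK-6886). This commit ties the broadcast's life cycle to the created PythonRDD instead. A minimal reproduction sketch, modeled on the updated regression test; the SparkContext `sc` and the list size are assumptions for illustration, not part of this commit:

    # Modeled on the regression test below; `sc` and the list size are
    # illustrative assumptions, chosen to push the pickled closure past 1 MB.
    data = [float(i) for i in range(300000)]
    rdd = sc.parallelize(range(1), 1).map(lambda x: len(data))
    assert rdd.first() == len(data)
    # Before this fix, the broadcast behind the oversized closure could be
    # unpersisted too early, so adding a shuffle over the same closure
    # could fail at execution time:
    assert rdd.map(lambda x: (x, 1)).groupByKey().count() == 1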

python/pyspark/rdd.py

Lines changed: 5 additions & 10 deletions
@@ -1197,7 +1197,7 @@ def take(self, num):
         [91, 92, 93]
         """
         items = []
-        totalParts = self._jrdd.partitions().size()
+        totalParts = self.getNumPartitions()
         partsScanned = 0

         while len(items) < num and partsScanned < totalParts:
@@ -1260,7 +1260,7 @@ def isEmpty(self):
         >>> sc.parallelize([1]).isEmpty()
         False
         """
-        return self._jrdd.partitions().size() == 0 or len(self.take(1)) == 0
+        return self.getNumPartitions() == 0 or len(self.take(1)) == 0

     def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
         """
@@ -2235,11 +2235,9 @@ def _prepare_for_python_RDD(sc, command, obj=None):
     ser = CloudPickleSerializer()
     pickled_command = ser.dumps((command, sys.version_info[:2]))
     if len(pickled_command) > (1 << 20):  # 1M
+        # The broadcast will have same life cycle as created PythonRDD
         broadcast = sc.broadcast(pickled_command)
         pickled_command = ser.dumps(broadcast)
-        # tracking the life cycle by obj
-        if obj is not None:
-            obj._broadcast = broadcast
     broadcast_vars = ListConverter().convert(
         [x._jbroadcast for x in sc._pickled_broadcast_vars],
         sc._gateway._gateway_client)
@@ -2294,12 +2292,9 @@ def pipeline_func(split, iterator):
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
         self.partitioner = prev.partitioner if self.preservesPartitioning else None
-        self._broadcast = None

-    def __del__(self):
-        if self._broadcast:
-            self._broadcast.unpersist()
-            self._broadcast = None
+    def getNumPartitions(self):
+        return self._prev_jrdd.partitions().size()

     @property
     def _jrdd(self):
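A note on the getNumPartitions() changes above: take() and isEmpty() used to call self._jrdd.partitions().size(), which forces the pipelined Java RDD (and, for a big closure, its broadcast) to be created just to count partitions. The new PipelinedRDD.getNumPartitions() answers from the parent _prev_jrdd instead. A short usage sketch, assuming a running SparkContext bound to sc:

    # `sc` is an assumed SparkContext; getNumPartitions() is the method
    # added in the diff above.
    rdd = sc.parallelize(range(100), 4).map(lambda x: x * 2)  # a PipelinedRDD
    # The count comes from the parent _prev_jrdd, so the pipelined _jrdd
    # does not have to be materialized first:
    assert rdd.getNumPartitions() == 4
    print(rdd.take(3))  # take() now relies on getNumPartitions() internally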

python/pyspark/tests.py

Lines changed: 2 additions & 4 deletions
@@ -550,10 +550,8 @@ def test_large_closure(self):
         data = [float(i) for i in xrange(N)]
         rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
         self.assertEquals(N, rdd.first())
-        self.assertTrue(rdd._broadcast is not None)
-        rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1)
-        self.assertEqual(1, rdd.first())
-        self.assertTrue(rdd._broadcast is None)
+        # regression test for SPARK-6886
+        self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count())

     def test_zip_with_different_serializers(self):
         a = self.sc.parallelize(range(5))
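For context on the test's oversized closure: _prepare_for_python_RDD only switches to a broadcast when the pickled command is larger than 1 << 20 bytes. A sketch for checking whether a closure crosses that threshold, using PySpark's CloudPickleSerializer; the list size here is an illustrative stand-in for the test's N, which this diff does not show:

    # The list size is an assumption for illustration, not the test's N.
    from pyspark.serializers import CloudPickleSerializer

    ser = CloudPickleSerializer()
    data = [float(i) for i in range(300000)]
    command = lambda x: len(data)                 # closure captures the list
    print(len(ser.dumps(command)) > (1 << 20))    # expect True at this size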
