From 60e2abd9616016dce8e5dc2faf5c75be8e07335f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 7 Oct 2016 04:59:37 +0000 Subject: [PATCH] Decrease the batch size for repartition. --- python/pyspark/rdd.py | 13 ++++++++++--- python/pyspark/tests.py | 10 ++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ed81eb16df3c..2de2c2fd1a60 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2017,8 +2017,7 @@ def repartition(self, numPartitions): >>> len(rdd.repartition(10).glom().collect()) 10 """ - jrdd = self._jrdd.repartition(numPartitions) - return RDD(jrdd, self.ctx, self._jrdd_deserializer) + return self.coalesce(numPartitions, shuffle=True) def coalesce(self, numPartitions, shuffle=False): """ @@ -2029,7 +2028,15 @@ def coalesce(self, numPartitions, shuffle=False): >>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(1).glom().collect() [[1, 2, 3, 4, 5]] """ - jrdd = self._jrdd.coalesce(numPartitions, shuffle) + if shuffle: + # Decrease the batch size in order to distribute evenly the elements across output + # partitions. Otherwise, repartition will possibly produce highly skewed partitions. + batchSize = min(10, self.ctx._batchSize or 1024) + ser = BatchedSerializer(PickleSerializer(), batchSize) + selfCopy = self._reserialize(ser) + jrdd = selfCopy._jrdd.coalesce(numPartitions, shuffle) + else: + jrdd = self._jrdd.coalesce(numPartitions, shuffle) return RDD(jrdd, self.ctx, self._jrdd_deserializer) def zip(self, other): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b0756911bfc1..3e0bd16d85ca 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -914,6 +914,16 @@ def test_repartitionAndSortWithinPartitions(self): self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) + def test_repartition_no_skewed(self): + num_partitions = 20 + a = self.sc.parallelize(range(int(1000)), 2) + l = a.repartition(num_partitions).glom().map(len).collect() + zeros = len([x for x in l if x == 0]) + self.assertTrue(zeros == 0) + l = a.coalesce(num_partitions, True).glom().map(len).collect() + zeros = len([x for x in l if x == 0]) + self.assertTrue(zeros == 0) + def test_distinct(self): rdd = self.sc.parallelize((1, 2, 3)*10, 10) self.assertEqual(rdd.getNumPartitions(), 10)