
Commit 9909efc

aarondav authored and rxin committed
SPARK-1839: PySpark RDD#take() shouldn't always read from driver
This patch simply ports over the Scala implementation of RDD#take(), which reads the first partition at the driver, then decides how many more partitions it needs to read, and possibly starts a real job if that number is more than 1. (Note that SparkContext#runJob(allowLocal=true) only runs the job locally if there's 1 partition selected and no parent stages.)

Author: Aaron Davidson <[email protected]>

Closes #922 from aarondav/take and squashes the following commits:

fa06df9 [Aaron Davidson] SPARK-1839: PySpark RDD#take() shouldn't always read from driver
1 parent 7d52777 commit 9909efc

3 files changed: 84 additions, 21 deletions
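Before the diffs, a minimal sketch of the scan-and-estimate strategy the commit message describes, written as plain Python over a list of lists rather than a real RDD. The names here (take_sketch, partitions) are illustrative only; the authoritative logic is the new RDD.take() in python/pyspark/rdd.py below.

    # Illustrative sketch only: mirrors the estimation loop added to RDD.take()
    # below, but runs over plain Python lists instead of Spark partitions.
    def take_sketch(partitions, num):
        items = []
        total_parts = len(partitions)
        parts_scanned = 0
        while len(items) < num and parts_scanned < total_parts:
            num_parts_to_try = 1
            if parts_scanned > 0:
                # Found nothing yet: try all remaining partitions next. Otherwise
                # interpolate how many should suffice, overestimating by 50%.
                if len(items) == 0:
                    num_parts_to_try = total_parts - 1
                else:
                    num_parts_to_try = int(1.5 * num * parts_scanned / len(items))
            left = num - len(items)
            chosen = range(parts_scanned, min(parts_scanned + num_parts_to_try, total_parts))
            for p in chosen:
                items.extend(partitions[p][:left])  # stand-in for running a job on partition p
            parts_scanned += num_parts_to_try
        return items[:num]

    print(take_sketch([[1, 2], [], [3, 4, 5], [6]], 4))  # [1, 2, 3, 4]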

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 20 additions & 0 deletions
@@ -269,6 +269,26 @@ private object SpecialLengths {
 private[spark] object PythonRDD {
   val UTF8 = Charset.forName("UTF-8")
 
+  /**
+   * Adapter for calling SparkContext#runJob from Python.
+   *
+   * This method will return an iterator of an array that contains all elements in the RDD
+   * (effectively a collect()), but allows you to run on a certain subset of partitions,
+   * or to enable local execution.
+   */
+  def runJob(
+      sc: SparkContext,
+      rdd: JavaRDD[Array[Byte]],
+      partitions: JArrayList[Int],
+      allowLocal: Boolean): Iterator[Array[Byte]] = {
+    type ByteArray = Array[Byte]
+    type UnrolledPartition = Array[ByteArray]
+    val allPartitions: Array[UnrolledPartition] =
+      sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
+    val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
+    flattenedPartition.iterator
+  }
+
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))

python/pyspark/context.py

Lines changed: 26 additions & 0 deletions
@@ -537,6 +537,32 @@ def cancelAllJobs(self):
         """
         self._jsc.sc().cancelAllJobs()
 
+    def runJob(self, rdd, partitionFunc, partitions = None, allowLocal = False):
+        """
+        Executes the given partitionFunc on the specified set of partitions,
+        returning the result as an array of elements.
+
+        If 'partitions' is not specified, this will run over all partitions.
+
+        >>> myRDD = sc.parallelize(range(6), 3)
+        >>> sc.runJob(myRDD, lambda part: [x * x for x in part])
+        [0, 1, 4, 9, 16, 25]
+
+        >>> myRDD = sc.parallelize(range(6), 3)
+        >>> sc.runJob(myRDD, lambda part: [x * x for x in part], [0, 2], True)
+        [0, 1, 16, 25]
+        """
+        if partitions == None:
+            partitions = range(rdd._jrdd.splits().size())
+        javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client)
+
+        # Implementation note: This is implemented as a mapPartitions followed
+        # by runJob() in order to avoid having to pass a Python lambda into
+        # SparkContext#runJob.
+        mappedRDD = rdd.mapPartitions(partitionFunc)
+        it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
+        return list(mappedRDD._collect_iterator_through_file(it))
+
 def _test():
     import atexit
     import doctest
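Beyond the squared-elements doctests above, a hypothetical usage sketch for the new SparkContext.runJob (not part of this patch): return one value per partition, here a per-partition element count.

    # Hypothetical example, not from the patch: count elements in each partition.
    rdd = sc.parallelize(range(10), 4)
    counts = sc.runJob(rdd, lambda it: [sum(1 for _ in it)])
    # One entry per partition, flattened into a single list, e.g. [2, 3, 2, 3].

    # Restrict the job to a subset of partitions and allow local execution:
    first_two = sc.runJob(rdd, lambda it: [sum(1 for _ in it)], [0, 1], True)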

python/pyspark/rdd.py

Lines changed: 38 additions & 21 deletions
@@ -841,34 +841,51 @@ def take(self, num):
         """
         Take the first num elements of the RDD.
 
-        This currently scans the partitions *one by one*, so it will be slow if
-        a lot of partitions are required. In that case, use L{collect} to get
-        the whole RDD instead.
+        It works by first scanning one partition, and use the results from
+        that partition to estimate the number of additional partitions needed
+        to satisfy the limit.
+
+        Translated from the Scala implementation in RDD#take().
 
         >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
         [2, 3]
         >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
         [2, 3, 4, 5, 6]
+        >>> sc.parallelize(range(100), 100).filter(lambda x: x > 90).take(3)
+        [91, 92, 93]
         """
-        def takeUpToNum(iterator):
-            taken = 0
-            while taken < num:
-                yield next(iterator)
-                taken += 1
-        # Take only up to num elements from each partition we try
-        mapped = self.mapPartitions(takeUpToNum)
         items = []
-        # TODO(shivaram): Similar to the scala implementation, update the take
-        # method to scan multiple splits based on an estimate of how many elements
-        # we have per-split.
-        with _JavaStackTrace(self.context) as st:
-            for partition in range(mapped._jrdd.splits().size()):
-                partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
-                partitionsToTake[0] = partition
-                iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
-                items.extend(mapped._collect_iterator_through_file(iterator))
-                if len(items) >= num:
-                    break
+        totalParts = self._jrdd.splits().size()
+        partsScanned = 0
+
+        while len(items) < num and partsScanned < totalParts:
+            # The number of partitions to try in this iteration.
+            # It is ok for this number to be greater than totalParts because
+            # we actually cap it at totalParts in runJob.
+            numPartsToTry = 1
+            if partsScanned > 0:
+                # If we didn't find any rows after the first iteration, just
+                # try all partitions next. Otherwise, interpolate the number
+                # of partitions we need to try, but overestimate it by 50%.
+                if len(items) == 0:
+                    numPartsToTry = totalParts - 1
+                else:
+                    numPartsToTry = int(1.5 * num * partsScanned / len(items))
+
+            left = num - len(items)
+
+            def takeUpToNumLeft(iterator):
+                taken = 0
+                while taken < left:
+                    yield next(iterator)
+                    taken += 1
+
+            p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
+            res = self.context.runJob(self, takeUpToNumLeft, p, True)
+
+            items += res
+            partsScanned += numPartsToTry
+
         return items[:num]
 
     def first(self):
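A worked example of the interpolation step in the new take() (the numbers are illustrative, not from the patch): if take(100) has scanned one partition and found only 4 rows, the next round asks for int(1.5 * 100 * 1 / 4) = 37 partitions, capped at totalParts when the job is submitted.

    # Illustrative arithmetic for the estimation step above (not from the patch).
    num, parts_scanned, found = 100, 1, 4
    num_parts_to_try = int(1.5 * num * parts_scanned / found)
    print(num_parts_to_try)  # 37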
