20 changes: 20 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -269,6 +269,26 @@ private object SpecialLengths {
private[spark] object PythonRDD {
  val UTF8 = Charset.forName("UTF-8")

  /**
   * Adapter for calling SparkContext#runJob from Python.
   *
   * This method will return an iterator of an array that contains all elements in the RDD
   * (effectively a collect()), but allows you to run on a certain subset of partitions,
   * or to enable local execution.
   */
  def runJob(
      sc: SparkContext,
      rdd: JavaRDD[Array[Byte]],
      partitions: JArrayList[Int],
      allowLocal: Boolean): Iterator[Array[Byte]] = {
    type ByteArray = Array[Byte]
    type UnrolledPartition = Array[ByteArray]
    val allPartitions: Array[UnrolledPartition] =
      sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
    val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
    flattenedPartition.iterator
  }

  def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
  JavaRDD[Array[Byte]] = {
    val file = new DataInputStream(new FileInputStream(filename))
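As a rough illustration of the contract this Scala adapter gives the Python side, each requested partition is unrolled into an array, the arrays are concatenated in partition order, and a single iterator over the result is returned. The pure-Python sketch below models only that behavior; run_job_sketch and partition_data are hypothetical stand-ins, since the real work happens in the JVM via SparkContext#runJob and the serialized bytes are streamed back to Python.

from itertools import chain

def run_job_sketch(partition_data, partitions):
    # partition_data: list of lists, one per partition (stand-in for the RDD).
    # partitions: indices of the partitions to compute, in the order requested.
    unrolled = [list(partition_data[i]) for i in partitions]  # one "unrolled partition" per index
    return chain.from_iterable(unrolled)  # flattened, analogous to Array.concat(...).iterator

# Example: only partitions 0 and 2 are computed.
print(list(run_job_sketch([[0, 1], [2, 3], [4, 5]], [0, 2])))  # [0, 1, 4, 5]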
26 changes: 26 additions & 0 deletions python/pyspark/context.py
@@ -537,6 +537,32 @@ def cancelAllJobs(self):
"""
self._jsc.sc().cancelAllJobs()

def runJob(self, rdd, partitionFunc, partitions = None, allowLocal = False):
"""
Executes the given partitionFunc on the specified set of partitions,
returning the result as an array of elements.

If 'partitions' is not specified, this will run over all partitions.

>>> myRDD = sc.parallelize(range(6), 3)
>>> sc.runJob(myRDD, lambda part: [x * x for x in part])
[0, 1, 4, 9, 16, 25]

>>> myRDD = sc.parallelize(range(6), 3)
>>> sc.runJob(myRDD, lambda part: [x * x for x in part], [0, 2], True)
[0, 1, 16, 25]
"""
if partitions == None:
partitions = range(rdd._jrdd.splits().size())
javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client)

# Implementation note: This is implemented as a mapPartitions followed
# by runJob() in order to avoid having to pass a Python lambda into
# SparkContext#runJob.
mappedRDD = rdd.mapPartitions(partitionFunc)
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
return list(mappedRDD._collect_iterator_through_file(it))

def _test():
import atexit
import doctest
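Beyond the doctests above, a short usage sketch of the new SparkContext.runJob may help: it assumes a live SparkContext named sc (as in the doctests) and shows that partitionFunc receives an iterator over one partition and must return an iterable, and that passing an explicit partition list restricts which partitions are computed.

rdd = sc.parallelize(range(10), 4)

# One partial sum per partition; partitionFunc gets an iterator and returns an iterable.
partial_sums = sc.runJob(rdd, lambda it: [sum(it)])

# Restrict the job to two partitions; only those partitions are evaluated.
subset_sums = sc.runJob(rdd, lambda it: [sum(it)], partitions=[0, 3])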
59 changes: 38 additions & 21 deletions python/pyspark/rdd.py
@@ -841,34 +841,51 @@ def take(self, num):
"""
Take the first num elements of the RDD.

This currently scans the partitions *one by one*, so it will be slow if
a lot of partitions are required. In that case, use L{collect} to get
the whole RDD instead.
It works by first scanning one partition, and use the results from
that partition to estimate the number of additional partitions needed
to satisfy the limit.

Translated from the Scala implementation in RDD#take().

>>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
[2, 3]
>>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
[2, 3, 4, 5, 6]
>>> sc.parallelize(range(100), 100).filter(lambda x: x > 90).take(3)
[91, 92, 93]
"""
def takeUpToNum(iterator):
taken = 0
while taken < num:
yield next(iterator)
taken += 1
# Take only up to num elements from each partition we try
mapped = self.mapPartitions(takeUpToNum)
items = []
# TODO(shivaram): Similar to the scala implementation, update the take
# method to scan multiple splits based on an estimate of how many elements
# we have per-split.
with _JavaStackTrace(self.context) as st:
for partition in range(mapped._jrdd.splits().size()):
partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
partitionsToTake[0] = partition
iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
items.extend(mapped._collect_iterator_through_file(iterator))
if len(items) >= num:
break
totalParts = self._jrdd.splits().size()
partsScanned = 0

while len(items) < num and partsScanned < totalParts:
# The number of partitions to try in this iteration.
# It is ok for this number to be greater than totalParts because
# we actually cap it at totalParts in runJob.
numPartsToTry = 1
if partsScanned > 0:
# If we didn't find any rows after the first iteration, just
# try all partitions next. Otherwise, interpolate the number
# of partitions we need to try, but overestimate it by 50%.
if len(items) == 0:
numPartsToTry = totalParts - 1
else:
numPartsToTry = int(1.5 * num * partsScanned / len(items))

left = num - len(items)

def takeUpToNumLeft(iterator):
taken = 0
while taken < left:
yield next(iterator)
taken += 1

p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
res = self.context.runJob(self, takeUpToNumLeft, p, True)

items += res
partsScanned += numPartsToTry

return items[:num]

def first(self):
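To make the estimation logic in the new take() easier to follow in isolation, here is a standalone sketch that replays the same loop over plain in-memory lists instead of RDD partitions. local_take and partitions_data are hypothetical stand-ins, not Spark APIs; the point is how numPartsToTry stays at 1 for the first pass and then either jumps to all remaining partitions or interpolates from the element density observed so far, overestimating by 50%.

def local_take(partitions_data, num):
    # partitions_data: list of lists, one per "partition".
    items = []
    totalParts = len(partitions_data)
    partsScanned = 0
    while len(items) < num and partsScanned < totalParts:
        numPartsToTry = 1
        if partsScanned > 0:
            if len(items) == 0:
                # Nothing found yet: try all remaining partitions next round.
                numPartsToTry = totalParts - 1
            else:
                # Interpolate from the density seen so far, overestimating by 50%.
                numPartsToTry = int(1.5 * num * partsScanned / len(items))
        left = num - len(items)
        for p in range(partsScanned, min(partsScanned + numPartsToTry, totalParts)):
            items += partitions_data[p][:left]  # up to `left` elements from each partition in the batch
        partsScanned += numPartsToTry
    return items[:num]

# Sparse data: the first pass finds nothing, so the second pass widens to all partitions.
print(local_take([[], [], [], [], [9], [10, 11]], 3))  # [9, 10, 11]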