@@ -776,7 +776,7 @@ private[spark] object PythonRDD extends Logging {
}

/**
* Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by
* Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
* PySpark.
*/
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
47 changes: 42 additions & 5 deletions python/pyspark/sql.py
Expand Up @@ -30,6 +30,7 @@
from pyspark.rdd import RDD, PipelinedRDD
from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync

from itertools import chain, ifilter, imap

@@ -1550,6 +1551,18 @@ def id(self):
self._id = self._jrdd.id()
return self._id

def limit(self, num):
"""Limit the result count to the number specified.

>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.limit(2).collect()
[Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
>>> srdd.limit(0).collect()
[]
"""
rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD()
return SchemaRDD(rdd, self.sql_ctx)

def saveAsParquetFile(self, path):
"""Save the contents as a Parquet file, preserving the schema.

@@ -1626,15 +1639,39 @@ def count(self):
return self._jschema_rdd.count()

def collect(self):
"""
Return a list that contains all of the rows in this RDD.
"""Return a list that contains all of the rows in this RDD.

Each object in the list is on Row, the fields can be accessed as
Each object in the list is a Row, the fields can be accessed as
attributes.

Unlike the base RDD implementation of collect, this implementation
leverages the query optimizer to perform a collect on the SchemaRDD,
which supports features such as filter pushdown.

>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.collect()
[Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
"""
rows = RDD.collect(self)
with SCCallSiteSync(self.context) as css:
bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator()
cls = _create_cls(self.schema())
return map(cls, rows)
return map(cls, self._collect_iterator_through_file(bytesInJava))

def take(self, num):
"""Take the first num rows of the RDD.

Each object in the list is a Row, the fields can be accessed as
attributes.

Unlike the base RDD implementation of take, this implementation
leverages the query optimizer to perform a collect on a SchemaRDD,
which supports features such as filter pushdown.

>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.take(2)
[Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
"""
return self.limit(num).collect()

# Convert each object in the RDD to a Row with the right class
# for this SchemaRDD, so that fields can be accessed as attributes.
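Taken together, the sql.py changes make limit, take, and collect all go through the Scala-side query optimizer: take(num) is now just limit(num).collect(), and collect fetches pre-pickled batches via collectToPython instead of the base RDD path. A minimal usage sketch of the new behaviour, assuming a live SparkContext sc and the same three-row fixture the doctests above rely on (the dict-based inferSchema input is an assumption about that fixture, not part of this diff):

    # Illustrative only: requires a running SparkContext `sc`.
    from pyspark.sql import SQLContext

    sqlCtx = SQLContext(sc)
    rdd = sc.parallelize([{"field1": 1, "field2": "row1"},
                          {"field1": 2, "field2": "row2"},
                          {"field1": 3, "field2": "row3"}])
    srdd = sqlCtx.inferSchema(rdd)

    srdd.limit(2).collect()  # [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
    srdd.take(2)             # same rows; take() now delegates to limit(num).collect()
    srdd.collect()           # all three rows, served by collectToPython on the JVM side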
37 changes: 26 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -377,15 +377,15 @@ class SchemaRDD(
def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan)

/**
* Converts a JavaRDD to a PythonRDD. It is used by pyspark.
* Helper for converting a Row to a simple Array suitable for pyspark serialization.
*/
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
private def rowToJArray(row: Row, structType: StructType): Array[Any] = {
import scala.collection.Map

def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
case (null, _) => null

case (obj: Row, struct: StructType) => rowToArray(obj, struct)
case (obj: Row, struct: StructType) => rowToJArray(obj, struct)

case (seq: Seq[Any], array: ArrayType) =>
seq.map(x => toJava(x, array.elementType)).asJava
@@ -402,22 +402,37 @@ class SchemaRDD(
case (other, _) => other
}

def rowToArray(row: Row, structType: StructType): Array[Any] = {
val fields = structType.fields.map(field => field.dataType)
row.zip(fields).map {
case (obj, dataType) => toJava(obj, dataType)
}.toArray
}
val fields = structType.fields.map(field => field.dataType)
row.zip(fields).map {
case (obj, dataType) => toJava(obj, dataType)
}.toArray
}

/**
* Converts a JavaRDD to a PythonRDD. It is used by pyspark.
*/
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
this.mapPartitions { iter =>
val pickle = new Pickler
iter.map { row =>
rowToArray(row, rowSchema)
rowToJArray(row, rowSchema)
}.grouped(100).map(batched => pickle.dumps(batched.toArray))
}
}

/**
* Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same
* format as javaToPython. It is used by pyspark.
*/
private[sql] def collectToPython: JList[Array[Byte]] = {
val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
val pickle = new Pickler
new java.util.ArrayList(collect().map { row =>
rowToJArray(row, rowSchema)
}.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
}

/**
* Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
* of base RDD functions that do not change schema.
@@ -433,7 +448,7 @@
}

// =======================================================================
// Overriden RDD actions
// Overridden RDD actions
// =======================================================================

override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
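Both javaToPython and the new collectToPython share one wire format: each Row is flattened to a plain array with rowToJArray, rows are grouped into batches of 100, and each batch is pickled into a single byte array that PySpark unpickles back into a list of rows. A rough Python sketch of that framing, using the standard pickle module as a stand-in for the JVM-side Pickler (the helper names are illustrative, not part of this diff):

    import pickle
    from itertools import chain

    def to_batches(rows, batch_size=100):
        # One pickled byte string per batch of up to batch_size rows.
        for i in range(0, len(rows), batch_size):
            yield pickle.dumps(rows[i:i + batch_size])

    def from_batches(batches):
        # Unpickle each batch and flatten back into a single list of rows.
        return list(chain.from_iterable(pickle.loads(b) for b in batches))

    rows = [[1, u'row1'], [2, u'row2'], [3, u'row3']]
    assert from_batches(to_batches(rows)) == rows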
@@ -112,6 +112,8 @@ class JavaSchemaRDD(
new java.util.ArrayList(arr)
}

override def count(): Long = baseSchemaRDD.count

override def take(num: Int): JList[Row] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[Row] = baseSchemaRDD.take(num).toSeq.map(new Row(_))