# Reviewer summary (free text placed before the first "diff --git" header, i.e. outside
# every hunk; patch tools conventionally skip leading text here -- confirm for your tool).
#
# core/.../api/python/PythonRDD.scala
#   - Scaladoc typo fix on javaToPython: "and RDD" -> "an RDD" (two occurrences).
#
# python/pyspark/sql.py
#   - Imports SCCallSiteSync from pyspark.traceback_utils.
#   - Adds limit(num): calls baseSchemaRDD().limit(num).toJavaSchemaRDD() on the wrapped
#     JVM object and returns the result wrapped in a new SchemaRDD with the same sql_ctx.
#   - collect(): no longer calls RDD.collect(self); it now obtains an iterator from
#     baseSchemaRDD().collectToPython() inside a SCCallSiteSync block and feeds it to
#     self._collect_iterator_through_file before mapping rows to the schema's Row class.
#   - Adds take(num), implemented as self.limit(num).collect().
#
# sql/core/.../sql/SchemaRDD.scala
#   - The body that used to live inside javaToPython (the toJava converter plus the nested
#     rowToArray helper) becomes the private method rowToJArray(row, structType).
#   - javaToPython is re-declared after it and calls rowToJArray per row.
#   - Adds collectToPython: pickles the rows returned by collect() in groups of 100 and
#     returns them as a java.util.ArrayList.
#   - Comment typo fix: "Overriden" -> "Overridden".
#
# sql/core/.../sql/api/java/JavaSchemaRDD.scala
#   - Adds an override of count() that delegates to baseSchemaRDD.count.
#
# NOTE(review): the patch arrived with all line breaks and indentation collapsed. The line
# structure below matches every hunk's old/new line counts, but leading whitespace on
# context/removed lines is reconstructed by convention -- verify against the target files
# before applying.
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index d5002fa02992b..12b345a8fa7c3 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -776,7 +776,7 @@ private[spark] object PythonRDD extends Logging {
   }
 
   /**
-   * Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by
+   * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
    * PySpark.
    */
   def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index fc9310fef318c..aeed5f4d28918 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -30,6 +30,7 @@
 from pyspark.rdd import RDD, PipelinedRDD
 from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
 from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
 
 from itertools import chain, ifilter, imap
 
@@ -1550,6 +1551,18 @@ def id(self):
             self._id = self._jrdd.id()
         return self._id
 
+    def limit(self, num):
+        """Limit the result count to the number specified.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.limit(2).collect()
+        [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+        >>> srdd.limit(0).collect()
+        []
+        """
+        rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD()
+        return SchemaRDD(rdd, self.sql_ctx)
+
     def saveAsParquetFile(self, path):
         """Save the contents as a Parquet file, preserving the schema.
 
@@ -1626,15 +1639,39 @@ def count(self):
         return self._jschema_rdd.count()
 
     def collect(self):
-        """
-        Return a list that contains all of the rows in this RDD.
+        """Return a list that contains all of the rows in this RDD.
 
-        Each object in the list is on Row, the fields can be accessed as
+        Each object in the list is a Row, the fields can be accessed as
         attributes.
+
+        Unlike the base RDD implementation of collect, this implementation
+        leverages the query optimizer to perform a collect on the SchemaRDD,
+        which supports features such as filter pushdown.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.collect()
+        [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
         """
-        rows = RDD.collect(self)
+        with SCCallSiteSync(self.context) as css:
+            bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator()
         cls = _create_cls(self.schema())
-        return map(cls, rows)
+        return map(cls, self._collect_iterator_through_file(bytesInJava))
+
+    def take(self, num):
+        """Take the first num rows of the RDD.
+
+        Each object in the list is a Row, the fields can be accessed as
+        attributes.
+
+        Unlike the base RDD implementation of take, this implementation
+        leverages the query optimizer to perform a collect on a SchemaRDD,
+        which supports features such as filter pushdown.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.take(2)
+        [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+        """
+        return self.limit(num).collect()
 
     # Convert each object in the RDD to a Row with the right class
     # for this SchemaRDD, so that fields can be accessed as attributes.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index d2ceb4a2b0b25..3bc5dce095511 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -377,15 +377,15 @@ class SchemaRDD(
   def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan)
 
   /**
-   * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
+   * Helper for converting a Row to a simple Array suitable for pyspark serialization.
    */
-  private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+  private def rowToJArray(row: Row, structType: StructType): Array[Any] = {
     import scala.collection.Map
 
     def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
       case (null, _) => null
 
-      case (obj: Row, struct: StructType) => rowToArray(obj, struct)
+      case (obj: Row, struct: StructType) => rowToJArray(obj, struct)
 
       case (seq: Seq[Any], array: ArrayType) =>
         seq.map(x => toJava(x, array.elementType)).asJava
@@ -402,22 +402,37 @@ class SchemaRDD(
       case (other, _) => other
     }
 
-    def rowToArray(row: Row, structType: StructType): Array[Any] = {
-      val fields = structType.fields.map(field => field.dataType)
-      row.zip(fields).map {
-        case (obj, dataType) => toJava(obj, dataType)
-      }.toArray
-    }
+    val fields = structType.fields.map(field => field.dataType)
+    row.zip(fields).map {
+      case (obj, dataType) => toJava(obj, dataType)
+    }.toArray
+  }
 
+  /**
+   * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
+   */
+  private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
     val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
     this.mapPartitions { iter =>
       val pickle = new Pickler
       iter.map { row =>
-        rowToArray(row, rowSchema)
+        rowToJArray(row, rowSchema)
       }.grouped(100).map(batched => pickle.dumps(batched.toArray))
     }
   }
 
+  /**
+   * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same
+   * format as javaToPython. It is used by pyspark.
+   */
+  private[sql] def collectToPython: JList[Array[Byte]] = {
+    val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
+    val pickle = new Pickler
+    new java.util.ArrayList(collect().map { row =>
+      rowToJArray(row, rowSchema)
+    }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
+  }
+
   /**
    * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
    * of base RDD functions that do not change schema.
@@ -433,7 +448,7 @@ class SchemaRDD(
   }
 
   // =======================================================================
-  // Overriden RDD actions
+  // Overridden RDD actions
   // =======================================================================
 
   override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
index 4d799b4038fdd..e7faba0c7f620 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
@@ -112,6 +112,8 @@ class JavaSchemaRDD(
     new java.util.ArrayList(arr)
   }
 
+  override def count(): Long = baseSchemaRDD.count
+
   override def take(num: Int): JList[Row] = {
     import scala.collection.JavaConversions._
     val arr: java.util.Collection[Row] = baseSchemaRDD.take(num).toSeq.map(new Row(_))