
Commit 4fdb491

kanzhang authored and rxin committed
[SPARK-2010] Support for nested data in PySpark SQL
JIRA issue: https://issues.apache.org/jira/browse/SPARK-2010

This PR adds support for nested collection types in PySpark SQL, including array, dict, list, set, and tuple. Example:

```
>>> from array import array
>>> from pyspark.sql import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([
...     {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
...     {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
...     {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
True
>>> rdd = sc.parallelize([
...     {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
...     {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.collect() == \
...     [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
...      {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
True
```

Author: Kan Zhang <[email protected]>

Closes apache#1041 from kanzhang/SPARK-2010 and squashes the following commits:

1b2891d [Kan Zhang] [SPARK-2010] minor doc change and adding a TODO
504f27e [Kan Zhang] [SPARK-2010] Support for nested data in PySpark SQL
1 parent 716c88a commit 4fdb491

File tree

2 files changed: +40 -11 lines changed


python/pyspark/sql.py

Lines changed: 21 additions & 1 deletion
```diff
@@ -77,12 +77,25 @@ def inferSchema(self, rdd):
         """Infer and apply a schema to an RDD of L{dict}s.
 
         We peek at the first row of the RDD to determine the fields names
-        and types, and then use that to extract all the dictionaries.
+        and types, and then use that to extract all the dictionaries. Nested
+        collections are supported, which include array, dict, list, set, and
+        tuple.
 
         >>> srdd = sqlCtx.inferSchema(rdd)
         >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
         ... {"field1" : 3, "field2": "row3"}]
         True
+
+        >>> from array import array
+        >>> srdd = sqlCtx.inferSchema(nestedRdd1)
+        >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
+        ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
+        True
+
+        >>> srdd = sqlCtx.inferSchema(nestedRdd2)
+        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
+        ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
+        True
         """
         if (rdd.__class__ is SchemaRDD):
             raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
@@ -413,6 +426,7 @@ def subtract(self, other, numPartitions=None):
 
 def _test():
     import doctest
+    from array import array
     from pyspark.context import SparkContext
     globs = globals().copy()
     # The small batch size here ensures that we see multiple batches,
@@ -422,6 +436,12 @@ def _test():
     globs['sqlCtx'] = SQLContext(sc)
     globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
         {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
+    globs['nestedRdd1'] = sc.parallelize([
+        {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
+        {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
+    globs['nestedRdd2'] = sc.parallelize([
+        {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
+        {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
     (failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
```
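
To exercise the new nested-type support outside the doctest harness, a minimal standalone script along these lines should work once this patch is applied; the `local` master string and the app name are illustrative choices, not taken from the commit.

```python
# Minimal standalone sketch (assumes this patch is applied); the "local"
# master and the app name are illustrative, not part of the commit.
from array import array

from pyspark.context import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext("local", "nested-sql-demo")
sqlCtx = SQLContext(sc)

# Rows mixing nested collections, mirroring the nestedRdd1/nestedRdd2
# fixtures registered in _test() above.
rdd = sc.parallelize([
    {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}, "f3": (1, 2)},
    {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}, "f3": (2, 3)}])

srdd = sqlCtx.inferSchema(rdd)  # schema is inferred from the first row
print(srdd.collect())

sc.stop()
```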

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 19 additions & 10 deletions
```diff
@@ -298,19 +298,28 @@ class SQLContext(@transient val sparkContext: SparkContext)
 
   /**
    * Peek at the first row of the RDD and infer its schema.
-   * TODO: We only support primitive types, add support for nested types.
+   * TODO: consolidate this with the type system developed in SPARK-2060.
    */
   private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
+    import scala.collection.JavaConversions._
+    def typeFor(obj: Any): DataType = obj match {
+      case c: java.lang.String => StringType
+      case c: java.lang.Integer => IntegerType
+      case c: java.lang.Long => LongType
+      case c: java.lang.Double => DoubleType
+      case c: java.lang.Boolean => BooleanType
+      case c: java.util.List[_] => ArrayType(typeFor(c.head))
+      case c: java.util.Set[_] => ArrayType(typeFor(c.head))
+      case c: java.util.Map[_, _] =>
+        val (key, value) = c.head
+        MapType(typeFor(key), typeFor(value))
+      case c if c.getClass.isArray =>
+        val elem = c.asInstanceOf[Array[_]].head
+        ArrayType(typeFor(elem))
+      case c => throw new Exception(s"Object of type $c cannot be used")
+    }
     val schema = rdd.first().map { case (fieldName, obj) =>
-      val dataType = obj.getClass match {
-        case c: Class[_] if c == classOf[java.lang.String] => StringType
-        case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
-        case c: Class[_] if c == classOf[java.lang.Long] => LongType
-        case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
-        case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
-        case c => throw new Exception(s"Object of type $c cannot be used")
-      }
-      AttributeReference(fieldName, dataType, true)()
+      AttributeReference(fieldName, typeFor(obj), true)()
     }.toSeq
 
     val rowRdd = rdd.mapPartitions { iter =>
```