
Commit 504f27e

[SPARK-2010] Support for nested data in PySpark SQL

1 parent 8919685

File tree

2 files changed: +38 -10 lines changed


python/pyspark/sql.py

Lines changed: 20 additions & 0 deletions
@@ -82,6 +82,19 @@ def inferSchema(self, rdd):
         >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
         ... {"field1" : 3, "field2": "row3"}]
         True
+
+        Nested collections are supported, which include array, dict, list, set, and tuple.
+
+        >>> from array import array
+        >>> srdd = sqlCtx.inferSchema(nestedRdd1)
+        >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
+        ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
+        True
+
+        >>> srdd = sqlCtx.inferSchema(nestedRdd2)
+        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
+        ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
+        True
         """
         if (rdd.__class__ is SchemaRDD):
             raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
@@ -411,6 +424,7 @@ def subtract(self, other, numPartitions=None):
 
 def _test():
     import doctest
+    from array import array
     from pyspark.context import SparkContext
     globs = globals().copy()
     # The small batch size here ensures that we see multiple batches,
@@ -420,6 +434,12 @@ def _test():
     globs['sqlCtx'] = SQLContext(sc)
     globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
         {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
+    globs['nestedRdd1'] = sc.parallelize([
+        {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
+        {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
+    globs['nestedRdd2'] = sc.parallelize([
+        {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
+        {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
     (failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
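
As a quick illustration of what the new doctests exercise, here is a minimal standalone PySpark sketch that builds a nested RDD and runs it through the updated inferSchema (Spark 1.0-era SchemaRDD API). The local SparkContext setup and app name are assumptions added for illustration; the data and calls mirror the nestedRdd1 fixture above.

from array import array

from pyspark.context import SparkContext
from pyspark.sql import SQLContext

# A throwaway local context for the demo; any existing SparkContext would do.
sc = SparkContext("local", "nested-schema-demo")
sqlCtx = SQLContext(sc)

# Rows whose values are themselves collections (arrays and dicts here),
# matching the nestedRdd1 fixture registered in _test() above.
nested = sc.parallelize([
    {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
    {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])

# With this commit, inferSchema maps the nested values to array/map columns
# instead of rejecting them, so the collected rows keep their structure.
srdd = sqlCtx.inferSchema(nested)
print(srdd.collect())

sc.stop()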

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 18 additions & 10 deletions
@@ -298,19 +298,27 @@ class SQLContext(@transient val sparkContext: SparkContext)
 
   /**
    * Peek at the first row of the RDD and infer its schema.
-   * TODO: We only support primitive types, add support for nested types.
    */
   private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
+    import scala.collection.JavaConversions._
+    def typeFor(obj: Any): DataType = obj match {
+      case c: java.lang.String => StringType
+      case c: java.lang.Integer => IntegerType
+      case c: java.lang.Long => LongType
+      case c: java.lang.Double => DoubleType
+      case c: java.lang.Boolean => BooleanType
+      case c: java.util.List[_] => ArrayType(typeFor(c.head))
+      case c: java.util.Set[_] => ArrayType(typeFor(c.head))
+      case c: java.util.Map[_, _] =>
+        val (key, value) = c.head
+        MapType(typeFor(key), typeFor(value))
+      case c if c.getClass.isArray =>
+        val elem = c.asInstanceOf[Array[_]].head
+        ArrayType(typeFor(elem))
+      case c => throw new Exception(s"Object of type $c cannot be used")
+    }
     val schema = rdd.first().map { case (fieldName, obj) =>
-      val dataType = obj.getClass match {
-        case c: Class[_] if c == classOf[java.lang.String] => StringType
-        case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
-        case c: Class[_] if c == classOf[java.lang.Long] => LongType
-        case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
-        case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
-        case c => throw new Exception(s"Object of type $c cannot be used")
-      }
-      AttributeReference(fieldName, dataType, true)()
+      AttributeReference(fieldName, typeFor(obj), true)()
     }.toSeq
 
     val rowRdd = rdd.mapPartitions { iter =>
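
To make the recursion in typeFor concrete, here is a hypothetical pure-Python analogue of the same idea: look at one sample value and recurse into containers. The name type_for and the string labels standing in for Catalyst's DataType objects are illustrative assumptions, not part of this commit; like the Scala code, it infers element and value types from a single sample entry.

from array import array

def type_for(obj):
    # Order matters: bool is a subclass of int, so check it first.
    if isinstance(obj, bool):
        return "BooleanType"
    if isinstance(obj, int):
        return "IntegerType"
    if isinstance(obj, float):
        return "DoubleType"
    if isinstance(obj, str):
        return "StringType"
    if isinstance(obj, dict):
        # Maps: infer key and value types from one arbitrary entry.
        key, value = next(iter(obj.items()))
        return ("MapType", type_for(key), type_for(value))
    if isinstance(obj, (list, set, tuple, array)):
        # Sequence-like collections all become array types of their element type.
        return ("ArrayType", type_for(next(iter(obj))))
    raise TypeError("Object of type %r cannot be used" % type(obj))

# Shapes taken from the doctests above:
print(type_for({"row1": 1.0}))        # ('MapType', 'StringType', 'DoubleType')
print(type_for(array('i', [1, 2])))   # ('ArrayType', 'IntegerType')
print(type_for((1, 2)))               # ('ArrayType', 'IntegerType')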
