@@ -77,12 +77,25 @@ def inferSchema(self, rdd):
7777 """Infer and apply a schema to an RDD of L{dict}s.
7878
7979 We peek at the first row of the RDD to determine the fields names
80- and types, and then use that to extract all the dictionaries.
80+ and types, and then use that to extract all the dictionaries. Nested
81+ collections are supported, which include array, dict, list, set, and
82+ tuple.
8183
8284 >>> srdd = sqlCtx.inferSchema(rdd)
8385 >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
8486 ... {"field1" : 3, "field2": "row3"}]
8587 True
88+
89+ >>> from array import array
90+ >>> srdd = sqlCtx.inferSchema(nestedRdd1)
91+ >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
92+ ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
93+ True
94+
95+ >>> srdd = sqlCtx.inferSchema(nestedRdd2)
96+ >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
97+ ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
98+ True
8699 """
87100 if (rdd .__class__ is SchemaRDD ):
88101 raise ValueError ("Cannot apply schema to %s" % SchemaRDD .__name__ )
@@ -413,6 +426,7 @@ def subtract(self, other, numPartitions=None):
413426
414427def _test ():
415428 import doctest
429+ from array import array
416430 from pyspark .context import SparkContext
417431 globs = globals ().copy ()
418432 # The small batch size here ensures that we see multiple batches,
@@ -422,6 +436,12 @@ def _test():
422436 globs ['sqlCtx' ] = SQLContext (sc )
423437 globs ['rdd' ] = sc .parallelize ([{"field1" : 1 , "field2" : "row1" },
424438 {"field1" : 2 , "field2" : "row2" }, {"field1" : 3 , "field2" : "row3" }])
439+ globs ['nestedRdd1' ] = sc .parallelize ([
440+ {"f1" : array ('i' , [1 , 2 ]), "f2" : {"row1" : 1.0 }},
441+ {"f1" : array ('i' , [2 , 3 ]), "f2" : {"row2" : 2.0 }}])
442+ globs ['nestedRdd2' ] = sc .parallelize ([
443+ {"f1" : [[1 , 2 ], [2 , 3 ]], "f2" : set ([1 , 2 ]), "f3" : (1 , 2 )},
444+ {"f1" : [[2 , 3 ], [3 , 4 ]], "f2" : set ([2 , 3 ]), "f3" : (2 , 3 )}])
425445 (failure_count , test_count ) = doctest .testmod (globs = globs ,optionflags = doctest .ELLIPSIS )
426446 globs ['sc' ].stop ()
427447 if failure_count :
0 commit comments