@@ -82,6 +82,19 @@ def inferSchema(self, rdd):
         >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
         ...                    {"field1" : 3, "field2": "row3"}]
         True
+
+        Nested collections (array, dict, list, set, and tuple) are supported.
+
+        >>> from array import array
+        >>> srdd = sqlCtx.inferSchema(nestedRdd1)
+        >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
+        ...                    {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
+        True
+
+        >>> srdd = sqlCtx.inferSchema(nestedRdd2)
+        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
+        ...                    {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
+        True
         """
         if (rdd.__class__ is SchemaRDD):
             raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
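The new doctest lines assert that rows whose fields are themselves collections round-trip through inferSchema. As intuition for what "nested collections are supported" means, here is a minimal, self-contained sketch of the recursive classification such inference has to perform; infer_type is a hypothetical helper name, not Spark's implementation, which maps these values to its own SQL types.

    from array import array

    def infer_type(value):
        # Hypothetical sketch: classify a Python value the way schema
        # inference must, recursing into nested collections.
        if isinstance(value, dict):
            k, v = next(iter(value.items()))   # sample one key/value pair
            return ("map", infer_type(k), infer_type(v))
        if isinstance(value, (list, set, tuple, array)):
            return ("array", infer_type(next(iter(value))))  # sample one element
        return type(value).__name__

    row = {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}, "f3": (1, 2)}
    print({k: infer_type(v) for k, v in row.items()})
    # {'f1': ('array', 'int'), 'f2': ('map', 'str', 'float'), 'f3': ('array', 'int')}

Note how list, set, tuple, and array all collapse into one sequence-like case, which is why the doctests exercise each of them.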
@@ -411,6 +424,7 @@ def subtract(self, other, numPartitions=None):

 def _test():
     import doctest
+    from array import array
     from pyspark.context import SparkContext
     globs = globals().copy()
     # The small batch size here ensures that we see multiple batches,
@@ -420,6 +434,12 @@ def _test():
     globs['sqlCtx'] = SQLContext(sc)
     globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
         {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
+    globs['nestedRdd1'] = sc.parallelize([
+        {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
+        {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
+    globs['nestedRdd2'] = sc.parallelize([
+        {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
+        {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
     (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
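To try the new behavior outside the doctest harness, something along these lines should work against a build that includes this change; the master URL and app name below are arbitrary placeholders.

    from array import array
    from pyspark.context import SparkContext
    from pyspark.sql import SQLContext

    sc = SparkContext("local[2]", "nested-schema-demo")  # placeholder master/app name
    sqlCtx = SQLContext(sc)

    # Rows mixing an array-typed field with a map-typed field, as in the doctests.
    rdd = sc.parallelize([
        {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
        {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])

    srdd = sqlCtx.inferSchema(rdd)
    print(srdd.collect())

    sc.stop()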