From 935c1676af9c182613a9e0ed0f7aea766bebe2e1 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 11 Dec 2015 18:00:38 -0800 Subject: [PATCH 1/4] Look at all elements of the local collection when inferring the types --- python/pyspark/sql/context.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b05aa2f5c4cd7..8f35a50c4ea60 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -236,14 +236,9 @@ def _inferSchemaFromList(self, data): if type(first) is dict: warnings.warn("inferring schema from dict is deprecated," "please use pyspark.sql.Row instead") - schema = _infer_schema(first) + schema = reduce(_merge_type, map(_infer_schema, data)) if _has_nulltype(schema): - for r in data: - schema = _merge_type(schema, _infer_schema(r)) - if not _has_nulltype(schema): - break - else: - raise ValueError("Some of types cannot be determined after inferring") + raise ValueError("Some of types cannot be determined after inferring") return schema def _inferSchema(self, rdd, samplingRatio=None): From 7ca7ecad1faea3b3746aec10130ebf07f4e2a3ad Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 11 Dec 2015 18:00:48 -0800 Subject: [PATCH 2/4] Add a test --- python/pyspark/sql/tests.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9f5f7cfdf7a69..46461263d4df3 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -353,6 +353,16 @@ def test_apply_schema_to_row(self): df3 = self.sqlCtx.createDataFrame(rdd, df.schema) self.assertEqual(10, df3.count()) + def test_infer_schema_to_local(self): + input = [{"a": 1}, {"b": "coffee"}] + df = self.sqlCtx.createDataFrame(input) + df2 = self.sqlCtx.createDataFrame(sc.parallelize(input), samplingRatio=1.0) + self.assertEqual(df.schema(), df2.schema()) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + df3 = 
self.sqlCtx.createDataFrame(rdd, df.schema) + self.assertEqual(10, df3.count()) + def test_serialize_nested_array_and_map(self): d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] rdd = self.sc.parallelize(d) From 6e463fdec6dae7f5b9b2f712e751bf12b995cfff Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 11 Dec 2015 20:21:54 -0800 Subject: [PATCH 3/4] Fix test --- python/pyspark/sql/tests.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 46461263d4df3..10b99175ad952 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -355,9 +355,10 @@ def test_apply_schema_to_row(self): def test_infer_schema_to_local(self): input = [{"a": 1}, {"b": "coffee"}] + rdd = self.sc.parallelize(input) df = self.sqlCtx.createDataFrame(input) - df2 = self.sqlCtx.createDataFrame(sc.parallelize(input), samplingRatio=1.0) - self.assertEqual(df.schema(), df2.schema()) + df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) + self.assertEqual(df.schema, df2.schema) rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) df3 = self.sqlCtx.createDataFrame(rdd, df.schema) From b863cc93eb2a5ddbeaae2d2d2233530772e0f94c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 12 Dec 2015 00:00:37 -0800 Subject: [PATCH 4/4] python 3 import reduce from functools --- python/pyspark/sql/context.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 8f35a50c4ea60..ba6915a12347e 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -18,6 +18,7 @@ import sys import warnings import json +from functools import reduce if sys.version >= '3': basestring = unicode = str