
Commit fffb0c0

Davies Liu authored and JoshRosen committed
[SPARK-16700][PYSPARK][SQL] create DataFrame from dict/Row with schema
## What changes were proposed in this pull request? In 2.0, we verify the data type against schema for every row for safety, but with performance cost, this PR make it optional. When we verify the data type for StructType, it does not support all the types we support in infer schema (for example, dict), this PR fix that to make them consistent. For Row object which is created using named arguments, the order of fields are sorted by name, they may be not different than the order in provided schema, this PR fix that by ignore the order of fields in this case. ## How was this patch tested? Created regression tests for them. Author: Davies Liu <[email protected]> Closes #14469 from davies/py_dict.
1 parent 5da6c4b commit fffb0c0
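
In short, a minimal sketch of what this change enables, assuming a running SparkSession named `spark` on Spark 2.1+ (the data and column names mirror the new regression test):

from pyspark.sql import Row
from pyspark.sql.types import StructType, StringType, IntegerType

# Schema deliberately lists "b" before "a".
schema = StructType().add("b", StringType()).add("a", IntegerType())

# dicts are accepted against a StructType; missing keys become null.
df1 = spark.createDataFrame([{"a": 1}, {"b": "coffee"}], schema)

# Rows built from keyword arguments store their fields sorted by name
# (a, b); they are matched to the schema by field name, not by position.
df2 = spark.createDataFrame([Row(a=i, b=str(i)) for i in range(3)], schema)

# Per-row type verification can be skipped when performance matters.
df3 = spark.createDataFrame([("x", 1), ("y", 2)], schema, verifySchema=False)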

File tree

4 files changed: +62 -28 lines


python/pyspark/sql/context.py

Lines changed: 6 additions & 2 deletions
@@ -215,7 +215,7 @@ def _inferSchema(self, rdd, samplingRatio=None):

    @since(1.3)
    @ignore_unicode_prefix
-    def createDataFrame(self, data, schema=None, samplingRatio=None):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
        """
        Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.

@@ -245,6 +245,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
            ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`.
            We can also use ``int`` as a short name for :class:`pyspark.sql.types.IntegerType`.
        :param samplingRatio: the sample ratio of rows used for inferring
+        :param verifySchema: verify data types of every row against schema.
        :return: :class:`DataFrame`

        .. versionchanged:: 2.0
@@ -253,6 +254,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
           If it's not a :class:`pyspark.sql.types.StructType`, it will be wrapped into a
           :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.

+        .. versionchanged:: 2.1
+           Added verifySchema.
+
        >>> l = [('Alice', 1)]
        >>> sqlContext.createDataFrame(l).collect()
        [Row(_1=u'Alice', _2=1)]
@@ -300,7 +304,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
            ...
        Py4JJavaError: ...
        """
-        return self.sparkSession.createDataFrame(data, schema, samplingRatio)
+        return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema)

    @since(1.3)
    def registerDataFrameAsTable(self, df, tableName):
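
The SQLContext wrapper above only forwards the new flag; a minimal usage sketch, assuming an existing SQLContext named `sqlContext` (data and schema are illustrative):

from pyspark.sql.types import StructType, StringType, IntegerType

schema = StructType().add("name", StringType()).add("age", IntegerType())
# verifySchema is passed straight through to SparkSession.createDataFrame,
# so the legacy entry point behaves the same as the new one.
df = sqlContext.createDataFrame([("Alice", 1)], schema, verifySchema=False)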

python/pyspark/sql/session.py

Lines changed: 13 additions & 16 deletions
@@ -384,17 +384,15 @@ def _createFromLocal(self, data, schema):

        if schema is None or isinstance(schema, (list, tuple)):
            struct = self._inferSchemaFromList(data)
+            converter = _create_converter(struct)
+            data = map(converter, data)
            if isinstance(schema, (list, tuple)):
                for i, name in enumerate(schema):
                    struct.fields[i].name = name
                    struct.names[i] = name
            schema = struct

-        elif isinstance(schema, StructType):
-            for row in data:
-                _verify_type(row, schema)
-
-        else:
+        elif not isinstance(schema, StructType):
            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)

        # convert python objects to sql data
@@ -403,7 +401,7 @@ def _createFromLocal(self, data, schema):

    @since(2.0)
    @ignore_unicode_prefix
-    def createDataFrame(self, data, schema=None, samplingRatio=None):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
        """
        Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.

@@ -432,13 +430,11 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
            ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use
            ``int`` as a short name for ``IntegerType``.
        :param samplingRatio: the sample ratio of rows used for inferring
+        :param verifySchema: verify data types of every row against schema.
        :return: :class:`DataFrame`

-        .. versionchanged:: 2.0
-           The ``schema`` parameter can be a :class:`pyspark.sql.types.DataType` or a
-           datatype string after 2.0. If it's not a
-           :class:`pyspark.sql.types.StructType`, it will be wrapped into a
-           :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.
+        .. versionchanged:: 2.1
+           Added verifySchema.

        >>> l = [('Alice', 1)]
        >>> spark.createDataFrame(l).collect()
@@ -503,17 +499,18 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
            schema = [str(x) for x in data.columns]
            data = [r.tolist() for r in data.to_records(index=False)]

+        verify_func = _verify_type if verifySchema else lambda _, t: True
        if isinstance(schema, StructType):
            def prepare(obj):
-                _verify_type(obj, schema)
+                verify_func(obj, schema)
                return obj
        elif isinstance(schema, DataType):
-            datatype = schema
+            dataType = schema
+            schema = StructType().add("value", schema)

            def prepare(obj):
-                _verify_type(obj, datatype)
-                return (obj, )
-            schema = StructType().add("value", datatype)
+                verify_func(obj, dataType)
+                return obj,
        else:
            if isinstance(schema, list):
                schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
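
The `verify_func` switch above decides whether each row is checked on the driver before the DataFrame is built. A hedged illustration, assuming a SparkSession named `spark` (the mismatching data is illustrative):

from pyspark.sql.types import StructType, StringType, IntegerType

schema = StructType().add("name", StringType()).add("age", IntegerType())
bad_rows = [("Alice", "not-an-int")]   # "age" should be an int

try:
    spark.createDataFrame(bad_rows, schema)  # verifySchema=True by default
except TypeError as e:
    print("rejected on the driver:", e)

# With verification disabled the same data is accepted at this point; any
# mismatch would only surface later, e.g. when the rows are serialized or read.
df = spark.createDataFrame(bad_rows, schema, verifySchema=False)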

python/pyspark/sql/tests.py

Lines changed: 16 additions & 0 deletions
@@ -411,6 +411,22 @@ def test_infer_schema_to_local(self):
        df3 = self.spark.createDataFrame(rdd, df.schema)
        self.assertEqual(10, df3.count())

+    def test_apply_schema_to_dict_and_rows(self):
+        schema = StructType().add("b", StringType()).add("a", IntegerType())
+        input = [{"a": 1}, {"b": "coffee"}]
+        rdd = self.sc.parallelize(input)
+        for verify in [False, True]:
+            df = self.spark.createDataFrame(input, schema, verifySchema=verify)
+            df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+            self.assertEqual(df.schema, df2.schema)
+
+            rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
+            df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+            self.assertEqual(10, df3.count())
+            input = [Row(a=x, b=str(x)) for x in range(10)]
+            df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
+            self.assertEqual(10, df4.count())
+
    def test_create_dataframe_schema_mismatch(self):
        input = [Row(a=1)]
        rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))

python/pyspark/sql/types.py

Lines changed: 27 additions & 10 deletions
@@ -582,6 +582,8 @@ def toInternal(self, obj):
        else:
            if isinstance(obj, dict):
                return tuple(obj.get(n) for n in self.names)
+            elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+                return tuple(obj[n] for n in self.names)
            elif isinstance(obj, (list, tuple)):
                return tuple(obj)
            elif hasattr(obj, "__dict__"):
@@ -1243,7 +1245,7 @@ def _infer_schema_type(obj, dataType):
    TimestampType: (datetime.datetime,),
    ArrayType: (list, tuple, array),
    MapType: (dict,),
-    StructType: (tuple, list),
+    StructType: (tuple, list, dict),
}

@@ -1314,10 +1316,10 @@ def _verify_type(obj, dataType, nullable=True):
    assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj)

    if _type is StructType:
-        if not isinstance(obj, (tuple, list)):
-            raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
+        # check the type and fields later
+        pass
    else:
-        # subclass of them can not be fromInternald in JVM
+        # subclass of them can not be fromInternal in JVM
        if type(obj) not in _acceptable_types[_type]:
            raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))

@@ -1343,11 +1345,25 @@ def _verify_type(obj, dataType, nullable=True):
            _verify_type(v, dataType.valueType, dataType.valueContainsNull)

    elif isinstance(dataType, StructType):
-        if len(obj) != len(dataType.fields):
-            raise ValueError("Length of object (%d) does not match with "
-                             "length of fields (%d)" % (len(obj), len(dataType.fields)))
-        for v, f in zip(obj, dataType.fields):
-            _verify_type(v, f.dataType, f.nullable)
+        if isinstance(obj, dict):
+            for f in dataType.fields:
+                _verify_type(obj.get(f.name), f.dataType, f.nullable)
+        elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+            # the order in obj could be different than dataType.fields
+            for f in dataType.fields:
+                _verify_type(obj[f.name], f.dataType, f.nullable)
+        elif isinstance(obj, (tuple, list)):
+            if len(obj) != len(dataType.fields):
+                raise ValueError("Length of object (%d) does not match with "
+                                 "length of fields (%d)" % (len(obj), len(dataType.fields)))
+            for v, f in zip(obj, dataType.fields):
+                _verify_type(v, f.dataType, f.nullable)
+        elif hasattr(obj, "__dict__"):
+            d = obj.__dict__
+            for f in dataType.fields:
+                _verify_type(d.get(f.name), f.dataType, f.nullable)
+        else:
+            raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))

# This is used to unpickle a Row from JVM
@@ -1410,6 +1426,7 @@ def __new__(self, *args, **kwargs):
            names = sorted(kwargs.keys())
            row = tuple.__new__(self, [kwargs[n] for n in names])
            row.__fields__ = names
+            row.__from_dict__ = True
            return row

        else:
@@ -1485,7 +1502,7 @@ def __getattr__(self, item):
            raise AttributeError(item)

    def __setattr__(self, key, value):
-        if key != '__fields__':
+        if key != '__fields__' and key != "__from_dict__":
            raise Exception("Row is read-only")
        self.__dict__[key] = value
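
A small sketch of why the `__from_dict__` marker matters: a Row built from keyword arguments keeps its fields sorted by name, so matching it to a schema positionally would shuffle columns, while the internal helpers changed above now match such Rows by field name (the field names are illustrative):

from pyspark.sql import Row
from pyspark.sql.types import StructType, StringType, IntegerType

schema = StructType().add("b", StringType()).add("a", IntegerType())

r = Row(b="coffee", a=1)
print(r.__fields__)          # ['a', 'b'] -- keyword Rows are sorted by name
print(tuple(r))              # (1, 'coffee') -- positional order differs from the schema

# StructType.toInternal (and _verify_type) now look fields up by name for
# these Rows, so 'coffee' lands in field "b" and 1 lands in field "a".
print(schema.toInternal(r))  # ('coffee', 1)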
