
Commit 3adb095

Author: Davies Liu

createDataFrame from dict and Row
1 parent b73a570 commit 3adb095

4 files changed, +59 -30 lines

python/pyspark/sql/context.py

Lines changed: 5 additions & 2 deletions
@@ -215,7 +215,7 @@ def _inferSchema(self, rdd, samplingRatio=None):
 
     @since(1.3)
     @ignore_unicode_prefix
-    def createDataFrame(self, data, schema=None, samplingRatio=None):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
         """
         Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
 
@@ -245,6 +245,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
         ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`.
         We can also use ``int`` as a short name for :class:`pyspark.sql.types.IntegerType`.
         :param samplingRatio: the sample ratio of rows used for inferring
+        :param verifySchema: verify data types of every row against schema.
         :return: :class:`DataFrame`
 
         .. versionchanged:: 2.0
@@ -253,6 +254,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
           If it's not a :class:`pyspark.sql.types.StructType`, it will be wrapped into a
           :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.
 
+           Added verifySchema.
+
         >>> l = [('Alice', 1)]
         >>> sqlContext.createDataFrame(l).collect()
         [Row(_1=u'Alice', _2=1)]
@@ -300,7 +303,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
         ...
         Py4JJavaError: ...
         """
-        return self.sparkSession.createDataFrame(data, schema, samplingRatio)
+        return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema)
 
     @since(1.3)
     def registerDataFrameAsTable(self, df, tableName):
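
The new flag is simply forwarded to SparkSession.createDataFrame. A rough usage sketch (assuming an existing SQLContext named sqlContext; the column names and data are illustrative only):

    from pyspark.sql.types import StructType, StructField, StringType, IntegerType

    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", IntegerType(), True)])

    # Default behaviour: every row is checked against the schema up front, so a
    # value of the wrong type (e.g. a string in the "age" field) raises a TypeError.
    df = sqlContext.createDataFrame([("Alice", 1), ("Bob", 2)], schema)

    # verifySchema=False skips the per-row check; malformed values only surface
    # later, when the rows are converted or the DataFrame is evaluated.
    df_unverified = sqlContext.createDataFrame([("Alice", 1)], schema, verifySchema=False)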

python/pyspark/sql/session.py

Lines changed: 12 additions & 18 deletions
@@ -384,17 +384,15 @@ def _createFromLocal(self, data, schema):
 
         if schema is None or isinstance(schema, (list, tuple)):
             struct = self._inferSchemaFromList(data)
+            converter = _create_converter(struct)
+            data = map(converter, data)
             if isinstance(schema, (list, tuple)):
                 for i, name in enumerate(schema):
                     struct.fields[i].name = name
                     struct.names[i] = name
             schema = struct
 
-        elif isinstance(schema, StructType):
-            for row in data:
-                _verify_type(row, schema)
-
-        else:
+        elif not isinstance(schema, StructType):
             raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
 
         # convert python objects to sql data
@@ -403,7 +401,7 @@ def _createFromLocal(self, data, schema):
 
     @since(2.0)
     @ignore_unicode_prefix
-    def createDataFrame(self, data, schema=None, samplingRatio=None):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
         """
         Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
 
@@ -432,14 +430,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
         ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use
         ``int`` as a short name for ``IntegerType``.
         :param samplingRatio: the sample ratio of rows used for inferring
+        :param verifySchema: verify data types of every row against schema.
         :return: :class:`DataFrame`
 
-        .. versionchanged:: 2.0
-           The ``schema`` parameter can be a :class:`pyspark.sql.types.DataType` or a
-           datatype string after 2.0. If it's not a
-           :class:`pyspark.sql.types.StructType`, it will be wrapped into a
-           :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.
-
         >>> l = [('Alice', 1)]
         >>> spark.createDataFrame(l).collect()
         [Row(_1=u'Alice', _2=1)]
@@ -503,17 +496,18 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
             schema = [str(x) for x in data.columns]
             data = [r.tolist() for r in data.to_records(index=False)]
 
+        verify_func = _verify_type if verifySchema else lambda _, t: True
         if isinstance(schema, StructType):
             def prepare(obj):
-                _verify_type(obj, schema)
+                verify_func(obj, schema)
                 return obj
-        elif isinstance(schema, DataType):
-            datatype = schema
+        if isinstance(schema, DataType):
+            dataType = schema
+            schema = StructType().add("value", schema)
 
             def prepare(obj):
-                _verify_type(obj, datatype)
-                return (obj, )
-            schema = StructType().add("value", datatype)
+                verify_func(obj, dataType)
+                return obj,
         else:
             if isinstance(schema, list):
                 schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
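
For reference, a small usage sketch of the two prepare() paths above (assuming an active SparkSession named spark; the comments describe the expected shapes, not captured output):

    from pyspark.sql.types import IntegerType, StringType, StructType

    # Plain DataType schema: each bare value is wrapped into a one-element tuple
    # and the schema becomes a single-field struct named "value".
    df_values = spark.createDataFrame([1, 2, 3], IntegerType())

    # StructType schema: dicts (and kwargs-created Rows) are now verified
    # field-by-field, so keys may be missing as long as the field is nullable.
    schema = StructType().add("b", StringType()).add("a", IntegerType())
    df_dicts = spark.createDataFrame([{"a": 1}, {"b": "coffee"}], schema)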

python/pyspark/sql/tests.py

Lines changed: 15 additions & 0 deletions
@@ -411,6 +411,21 @@ def test_infer_schema_to_local(self):
         df3 = self.spark.createDataFrame(rdd, df.schema)
         self.assertEqual(10, df3.count())
 
+    def test_apply_schema_to_dict_and_rows(self):
+        schema = StructType().add("b", StringType()).add("a", IntegerType())
+        input = [{"a": 1}, {"b": "coffee"}]
+        rdd = self.sc.parallelize(input)
+        df = self.spark.createDataFrame(input, schema)
+        df2 = self.spark.createDataFrame(rdd, schema)
+        self.assertEqual(df.schema, df2.schema)
+
+        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
+        df3 = self.spark.createDataFrame(rdd, schema)
+        self.assertEqual(10, df3.count())
+        input = [Row(a=x, b=str(x)) for x in range(10)]
+        df4 = self.spark.createDataFrame(input, schema)
+        self.assertEqual(10, df4.count())
+
     def test_create_dataframe_schema_mismatch(self):
         input = [Row(a=1)]
         rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
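
The dict inputs in the new test exercise the name-based, nullable path: a key that is absent from a dict simply becomes None in the resulting row. An illustration of what the first DataFrame in the test is expected to contain (field order follows the schema, not the dict; assuming the same session and imports as the test module):

    schema = StructType().add("b", StringType()).add("a", IntegerType())
    df = spark.createDataFrame([{"a": 1}, {"b": "coffee"}], schema)
    # Expected rows:
    #   Row(b=None, a=1)
    #   Row(b=u'coffee', a=None)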

python/pyspark/sql/types.py

Lines changed: 27 additions & 10 deletions
@@ -582,6 +582,8 @@ def toInternal(self, obj):
         else:
             if isinstance(obj, dict):
                 return tuple(obj.get(n) for n in self.names)
+            elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+                return tuple(obj[n] for n in self.names)
             elif isinstance(obj, (list, tuple)):
                 return tuple(obj)
             elif hasattr(obj, "__dict__"):
@@ -1243,7 +1245,7 @@ def _infer_schema_type(obj, dataType):
     TimestampType: (datetime.datetime,),
     ArrayType: (list, tuple, array),
     MapType: (dict,),
-    StructType: (tuple, list),
+    StructType: (tuple, list, dict),
 }
 
 
@@ -1314,10 +1316,10 @@ def _verify_type(obj, dataType, nullable=True):
     assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj)
 
     if _type is StructType:
-        if not isinstance(obj, (tuple, list)):
-            raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
+        # check the type and fields later
+        pass
     else:
-        # subclass of them can not be fromInternald in JVM
+        # subclass of them can not be fromInternal in JVM
         if type(obj) not in _acceptable_types[_type]:
             raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))
 
@@ -1343,11 +1345,25 @@ def _verify_type(obj, dataType, nullable=True):
             _verify_type(v, dataType.valueType, dataType.valueContainsNull)
 
     elif isinstance(dataType, StructType):
-        if len(obj) != len(dataType.fields):
-            raise ValueError("Length of object (%d) does not match with "
-                             "length of fields (%d)" % (len(obj), len(dataType.fields)))
-        for v, f in zip(obj, dataType.fields):
-            _verify_type(v, f.dataType, f.nullable)
+        if isinstance(obj, dict):
+            for f in dataType.fields:
+                _verify_type(obj.get(f.name), f.dataType, f.nullable)
+        elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+            # the order in obj could be different than dataType.fields
+            for f in dataType.fields:
+                _verify_type(obj[f.name], f.dataType, f.nullable)
+        elif isinstance(obj, (tuple, list)):
+            if len(obj) != len(dataType.fields):
+                raise ValueError("Length of object (%d) does not match with "
+                                 "length of fields (%d)" % (len(obj), len(dataType.fields)))
+            for v, f in zip(obj, dataType.fields):
+                _verify_type(v, f.dataType, f.nullable)
+        elif hasattr(obj, "__dict__"):
+            d = obj.__dict__
+            for f in dataType.fields:
+                _verify_type(d.get(f.name), f.dataType, f.nullable)
+        else:
+            raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
 
 
 # This is used to unpickle a Row from JVM
@@ -1410,6 +1426,7 @@ def __new__(self, *args, **kwargs):
             names = sorted(kwargs.keys())
             row = tuple.__new__(self, [kwargs[n] for n in names])
             row.__fields__ = names
+            row.__from_dict__ = True
             return row
 
         else:
@@ -1485,7 +1502,7 @@ def __getattr__(self, item):
             raise AttributeError(item)
 
     def __setattr__(self, key, value):
-        if key != '__fields__':
+        if key != '__fields__' and key != "__from_dict__":
             raise Exception("Row is read-only")
         self.__dict__[key] = value
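
The __from_dict__ marker exists because Row(**kwargs) stores its fields in sorted name order, so matching values to schema fields by position could silently pair them with the wrong columns; the new _verify_type branch therefore looks fields up by name. A minimal sketch (using the private pyspark.sql.types helper patched above; no SparkSession is needed):

    from pyspark.sql.types import Row, StructType, StringType, IntegerType, _verify_type

    schema = StructType().add("b", StringType()).add("a", IntegerType())

    row = Row(b="coffee", a=1)   # kwargs Row: fields are stored sorted, i.e. ('a', 'b')
    # row.__fields__ == ['a', 'b'] even though the schema order is ("b", "a")
    _verify_type(row, schema)    # passes: fields are matched by name, not by position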
