8 changes: 6 additions & 2 deletions python/pyspark/sql/context.py
@@ -215,7 +215,7 @@ def _inferSchema(self, rdd, samplingRatio=None):

@since(1.3)
@ignore_unicode_prefix
def createDataFrame(self, data, schema=None, samplingRatio=None):
def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
"""
Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.

@@ -245,6 +245,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`.
We can also use ``int`` as a short name for :class:`pyspark.sql.types.IntegerType`.
:param samplingRatio: the sample ratio of rows used for inferring
:param verifySchema: verify data types of every row against schema.
:return: :class:`DataFrame`

.. versionchanged:: 2.0
@@ -253,6 +254,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
If it's not a :class:`pyspark.sql.types.StructType`, it will be wrapped into a
:class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.

.. versionchanged:: 2.1
Added verifySchema.
Contributor: Maybe say version changed 2.1 for "Added verifySchema"?

Contributor: +1. I wasn't aware of this, but it looks like it's possible to have multiple versionchanged directives in the same docstring.


>>> l = [('Alice', 1)]
>>> sqlContext.createDataFrame(l).collect()
[Row(_1=u'Alice', _2=1)]
@@ -300,7 +304,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
...
Py4JJavaError: ...
"""
return self.sparkSession.createDataFrame(data, schema, samplingRatio)
return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema)

@since(1.3)
def registerDataFrameAsTable(self, df, tableName):
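A quick usage sketch of the new parameter, hedged: it assumes a SQLContext bound to `sqlContext`, as in the doctests above, and illustrates the behavior described by the docstring rather than guaranteed output.

```python
from pyspark.sql.types import StructType, StringType, IntegerType

schema = StructType().add("name", StringType()).add("age", IntegerType())

# Default (verifySchema=True): every row is checked against the schema while
# the data is prepared, so a mismatched value is expected to fail fast with a
# TypeError instead of producing a corrupt DataFrame.
sqlContext.createDataFrame([("Alice", 1)], schema).collect()
# sqlContext.createDataFrame([("Alice", "one")], schema)  # expected: TypeError

# verifySchema=False skips the per-row check; the mismatch is not caught here
# and would only surface later, if at all, when the data is used.
df = sqlContext.createDataFrame([("Alice", "one")], schema, verifySchema=False)
```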
29 changes: 13 additions & 16 deletions python/pyspark/sql/session.py
@@ -384,17 +384,15 @@ def _createFromLocal(self, data, schema):

if schema is None or isinstance(schema, (list, tuple)):
struct = self._inferSchemaFromList(data)
converter = _create_converter(struct)
holdenk (Contributor), Aug 3, 2016: Why did we add this here?

Contributor: This _create_converter method is confusingly named: what it's actually doing here is converting data from a dict to a tuple in case the schema is a StructType and data is a Python dictionary.

Contributor (author): This was missed before.
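A minimal sketch of the dict-to-tuple conversion described in the comment above; this is an illustration of the behavior, not the actual `_create_converter` implementation:

```python
from pyspark.sql.types import StructType, StringType, IntegerType

struct = StructType().add("b", StringType()).add("a", IntegerType())

def to_tuple(row):
    # Hypothetical stand-in for the converter returned by _create_converter(struct):
    # dict records are reordered into tuples following the inferred field order.
    return tuple(row.get(name) for name in struct.names) if isinstance(row, dict) else row

print(list(map(to_tuple, [{"a": 1, "b": "coffee"}, {"a": 2}])))
# [('coffee', 1), (None, 2)]
```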

data = map(converter, data)
if isinstance(schema, (list, tuple)):
for i, name in enumerate(schema):
struct.fields[i].name = name
struct.names[i] = name
schema = struct

elif isinstance(schema, StructType):
for row in data:
_verify_type(row, schema)

else:
elif not isinstance(schema, StructType):
raise TypeError("schema should be StructType or list or None, but got: %s" % schema)

# convert python objects to sql data
@@ -403,7 +401,7 @@ def _createFromLocal(self, data, schema):

@since(2.0)
@ignore_unicode_prefix
def createDataFrame(self, data, schema=None, samplingRatio=None):
def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
"""
Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.

@@ -432,13 +430,11 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use
``int`` as a short name for ``IntegerType``.
:param samplingRatio: the sample ratio of rows used for inferring
:param verifySchema: verify data types of every row against schema.
Contributor: +1 on also adding a versionchanged directive for this.

Contributor (author): Added.

:return: :class:`DataFrame`

.. versionchanged:: 2.0
The ``schema`` parameter can be a :class:`pyspark.sql.types.DataType` or a
datatype string after 2.0. If it's not a
:class:`pyspark.sql.types.StructType`, it will be wrapped into a
:class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.
.. versionchanged:: 2.1
Added verifySchema.
Contributor: Out of interest, why are we removing this note but keeping the other 2.0 change note? Just wondering so that when I'm making my changes for 2.1 I can do the right thing.

Contributor: @davies, I'm also slightly confused by this documentation change, since it looks like the new 2.x behavior of wrapping single-field datatypes into StructTypes and values into tuples is preserved by this patch. Could you clarify?

Contributor (author): This API is new in 2.0 (for SparkSession), so I removed those notes; we could add a versionchanged note for verifySchema.
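On the wrapping question above, a hedged sketch (assuming a running SparkSession bound to `spark`): the 2.0 behavior is indeed preserved by this patch; a non-StructType schema is still wrapped into a single-field StructType (`schema = StructType().add("value", schema)` in the diff below) and each record into a one-element tuple (`return obj,`).

```python
from pyspark.sql.types import IntegerType

# A bare DataType schema is still accepted: each value becomes a one-field
# row named "value", exactly as in 2.0.
df = spark.createDataFrame([1, 2, 3], IntegerType())
# df.collect() is expected to return [Row(value=1), Row(value=2), Row(value=3)]
```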


>>> l = [('Alice', 1)]
>>> spark.createDataFrame(l).collect()
@@ -503,17 +499,18 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
schema = [str(x) for x in data.columns]
data = [r.tolist() for r in data.to_records(index=False)]

verify_func = _verify_type if verifySchema else lambda _, t: True
if isinstance(schema, StructType):
def prepare(obj):
_verify_type(obj, schema)
verify_func(obj, schema)
return obj
elif isinstance(schema, DataType):
datatype = schema
dataType = schema
schema = StructType().add("value", schema)

def prepare(obj):
_verify_type(obj, datatype)
return (obj, )
schema = StructType().add("value", datatype)
verify_func(obj, dataType)
return obj,
else:
if isinstance(schema, list):
schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
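The core of the session.py change above is choosing the row validator once and then calling it unconditionally for every row. A minimal standalone sketch of that pattern follows; the names are illustrative stand-ins, not the pyspark internals:

```python
def strict_check(value, expected_type):
    # Stand-in for the real per-row verifier (_verify_type in pyspark.sql.types).
    if not isinstance(value, expected_type):
        raise TypeError("%r is not a %s" % (value, expected_type.__name__))

def make_prepare(verify_schema=True):
    # Pick the real check or a do-nothing lambda once, outside the per-row loop.
    verify = strict_check if verify_schema else (lambda _, t: True)

    def prepare(row):
        verify(row, int)  # every row goes through the same call site
        return row
    return prepare

print(list(map(make_prepare(True), [1, 2, 3])))   # [1, 2, 3]
print(list(map(make_prepare(False), ["oops"])))   # ['oops'], check skipped
```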
16 changes: 16 additions & 0 deletions python/pyspark/sql/tests.py
@@ -411,6 +411,22 @@ def test_infer_schema_to_local(self):
df3 = self.spark.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())

def test_apply_schema_to_dict_and_rows(self):
Contributor: Should we also add a test to exercise the verifySchema=False case?

Contributor (author): Added.

schema = StructType().add("b", StringType()).add("a", IntegerType())
input = [{"a": 1}, {"b": "coffee"}]
rdd = self.sc.parallelize(input)
for verify in [False, True]:
df = self.spark.createDataFrame(input, schema, verifySchema=verify)
df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
self.assertEqual(df.schema, df2.schema)

rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
self.assertEqual(10, df3.count())
input = [Row(a=x, b=str(x)) for x in range(10)]
df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
self.assertEqual(10, df4.count())

def test_create_dataframe_schema_mismatch(self):
input = [Row(a=1)]
rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
37 changes: 27 additions & 10 deletions python/pyspark/sql/types.py
@@ -582,6 +582,8 @@ def toInternal(self, obj):
else:
if isinstance(obj, dict):
return tuple(obj.get(n) for n in self.names)
elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
return tuple(obj[n] for n in self.names)
elif isinstance(obj, (list, tuple)):
return tuple(obj)
elif hasattr(obj, "__dict__"):
@@ -1243,7 +1245,7 @@ def _infer_schema_type(obj, dataType):
TimestampType: (datetime.datetime,),
ArrayType: (list, tuple, array),
MapType: (dict,),
StructType: (tuple, list),
StructType: (tuple, list, dict),
}


@@ -1314,10 +1316,10 @@ def _verify_type(obj, dataType, nullable=True):
assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj)

if _type is StructType:
if not isinstance(obj, (tuple, list)):
raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
# check the type and fields later
pass
else:
# subclass of them can not be fromInternald in JVM
# subclass of them can not be fromInternal in JVM
if type(obj) not in _acceptable_types[_type]:
raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))

@@ -1343,11 +1345,25 @@ def _verify_type(obj, dataType, nullable=True):
_verify_type(v, dataType.valueType, dataType.valueContainsNull)

elif isinstance(dataType, StructType):
if len(obj) != len(dataType.fields):
raise ValueError("Length of object (%d) does not match with "
"length of fields (%d)" % (len(obj), len(dataType.fields)))
for v, f in zip(obj, dataType.fields):
_verify_type(v, f.dataType, f.nullable)
if isinstance(obj, dict):
for f in dataType.fields:
_verify_type(obj.get(f.name), f.dataType, f.nullable)
elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
# the order in obj could be different than dataType.fields
for f in dataType.fields:
_verify_type(obj[f.name], f.dataType, f.nullable)
elif isinstance(obj, (tuple, list)):
if len(obj) != len(dataType.fields):
raise ValueError("Length of object (%d) does not match with "
"length of fields (%d)" % (len(obj), len(dataType.fields)))
for v, f in zip(obj, dataType.fields):
_verify_type(v, f.dataType, f.nullable)
elif hasattr(obj, "__dict__"):
d = obj.__dict__
for f in dataType.fields:
_verify_type(d.get(f.name), f.dataType, f.nullable)
else:
raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))


# This is used to unpickle a Row from JVM
@@ -1410,6 +1426,7 @@ def __new__(self, *args, **kwargs):
names = sorted(kwargs.keys())
row = tuple.__new__(self, [kwargs[n] for n in names])
row.__fields__ = names
row.__from_dict__ = True
return row

else:
@@ -1485,7 +1502,7 @@ def __getattr__(self, item):
raise AttributeError(item)

def __setattr__(self, key, value):
if key != '__fields__':
if key != '__fields__' and key != "__from_dict__":
raise Exception("Row is read-only")
self.__dict__[key] = value

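Taken together, the types.py changes let kwargs-constructed Rows and plain dicts be verified and serialized by field name rather than by position. A hedged sketch of the intended behavior; it touches private helpers (`_verify_type`, `toInternal`), so treat it as illustration only:

```python
from pyspark.sql import Row
from pyspark.sql.types import StructType, StringType, IntegerType, _verify_type

schema = StructType().add("b", StringType()).add("a", IntegerType())

# Row(**kwargs) sorts its fields alphabetically (a, b), which is not the
# schema order (b, a); the new __from_dict__ flag lets _verify_type and
# StructType.toInternal match values by name instead of position.
row = Row(a=1, b="coffee")
_verify_type(row, schema)        # expected to pass: fields matched by name
_verify_type({"a": 1}, schema)   # expected to pass: missing "b" treated as null

print(schema.toInternal(row))    # expected: ('coffee', 1), i.e. schema order
```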