Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,7 +1249,7 @@ def _infer_schema_type(obj, dataType):
}


def _verify_type(obj, dataType, nullable=True):
def _verify_type(obj, dataType, nullable=True, name=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need a doctest below.

"""
Verify the type of obj against dataType, raise a TypeError if they do not match.

Expand Down Expand Up @@ -1300,70 +1300,71 @@ def _verify_type(obj, dataType, nullable=True):
if nullable:
return
else:
raise ValueError("This field is not nullable, but got None")
raise ValueError("This field ({}, of type {}) is not nullable, but got None".format(
name, dataType))

# StringType can work with any types
if isinstance(dataType, StringType):
return

if isinstance(dataType, UserDefinedType):
if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
raise ValueError("%r is not an instance of type %r" % (obj, dataType))
raise ValueError("%r is not an instance of type %r for field %s" % (obj, dataType, name))
_verify_type(dataType.toInternal(obj), dataType.sqlType())
return

_type = type(dataType)
assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj)
assert _type in _acceptable_types, "unknown datatype: %s for object %r for field %s" % (dataType, obj, name)

if _type is StructType:
# check the type and fields later
pass
else:
# subclass of them can not be fromInternal in JVM
if type(obj) not in _acceptable_types[_type]:
raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))
raise TypeError("%s can not accept object %r in type %s for field %s" % (dataType, obj, type(obj), name))

if isinstance(dataType, ByteType):
if obj < -128 or obj > 127:
raise ValueError("object of ByteType out of range, got: %s" % obj)
raise ValueError("object of ByteType out of range, got: %s for field %s" % (obj, name))

elif isinstance(dataType, ShortType):
if obj < -32768 or obj > 32767:
raise ValueError("object of ShortType out of range, got: %s" % obj)
raise ValueError("object of ShortType out of range, got: %s for field %s" % (obj, name))

elif isinstance(dataType, IntegerType):
if obj < -2147483648 or obj > 2147483647:
raise ValueError("object of IntegerType out of range, got: %s" % obj)
raise ValueError("object of IntegerType out of range, got: %s for field %s" % (obj, name))

elif isinstance(dataType, ArrayType):
for i in obj:
_verify_type(i, dataType.elementType, dataType.containsNull)
_verify_type(i, dataType.elementType, dataType.containsNull, name)

elif isinstance(dataType, MapType):
for k, v in obj.items():
_verify_type(k, dataType.keyType, False)
_verify_type(v, dataType.valueType, dataType.valueContainsNull)
_verify_type(k, dataType.keyType, False, name)
_verify_type(v, dataType.valueType, dataType.valueContainsNull, name)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might also want to flag individual array/map elements — the recursive calls pass the parent field's name unchanged, so the error won't say which element index or key failed.


elif isinstance(dataType, StructType):
if isinstance(obj, dict):
for f in dataType.fields:
_verify_type(obj.get(f.name), f.dataType, f.nullable)
_verify_type(obj.get(f.name), f.dataType, f.nullable, f.name)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't work well for nested structs — the reported field name is ambiguous because only the innermost field name is passed down, not the full path:

MySubType = StructType([StructField('value', StringType(), nullable=False)])
MyType = StructType([
    StructField('one', MySubType),
    StructField('two', MySubType)])

_verify_type({'one': {'value': 'good'}, 'two': {'value': None}}, MyType)
# "This field (value, of type StringType) is not nullable, but got None"
# But is it one.value or two.value?

elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
# the order in obj could be different than dataType.fields
for f in dataType.fields:
_verify_type(obj[f.name], f.dataType, f.nullable)
_verify_type(obj[f.name], f.dataType, f.nullable, f.name)
elif isinstance(obj, (tuple, list)):
if len(obj) != len(dataType.fields):
raise ValueError("Length of object (%d) does not match with "
"length of fields (%d)" % (len(obj), len(dataType.fields)))
"length of fields (%d) for field %s" % (len(obj), len(dataType.fields), name))
for v, f in zip(obj, dataType.fields):
_verify_type(v, f.dataType, f.nullable)
_verify_type(v, f.dataType, f.nullable, f.name)
elif hasattr(obj, "__dict__"):
d = obj.__dict__
for f in dataType.fields:
_verify_type(d.get(f.name), f.dataType, f.nullable)
_verify_type(d.get(f.name), f.dataType, f.nullable, f.name)
else:
raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
raise TypeError("StructType can not accept object %r in type %s for field %s" % (obj, type(obj), name))


# This is used to unpickle a Row from JVM
Expand Down