49 changes: 28 additions & 21 deletions python/pyspark/sql/tests.py
@@ -2391,6 +2391,13 @@ def test_verify_type_exception_msg(self):
         except Exception as e:
             self.assertTrue(str(e).startswith(name))
 
+        schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
+        try:
+            _verify_type([["data"]], schema)
+            self.fail('Expected _verify_type() to throw so test can check exception message')
+        except Exception as e:
+            self.assertTrue(str(e).startswith("a.b:"))
+
     def test_verify_type_ok_nullable(self):
         obj = None
         for data_type in [IntegerType(), FloatType(), StringType(), StructType([])]:
@@ -2406,7 +2413,7 @@ def test_verify_type_not_nullable(self):
         import datetime
         import decimal
 
-        MyStructType = StructType([
+        schema = StructType([
             StructField('s', StringType(), nullable=False),
             StructField('i', IntegerType(), nullable=True)])
 
@@ -2501,26 +2508,26 @@ def __init__(self, **ka):
              ValueError),
 
             # Struct
-            ({"s": "a", "i": 1}, MyStructType, None),
-            ({"s": "a", "i": None}, MyStructType, None),
-            ({"s": "a"}, MyStructType, None),
-            ({"s": "a", "f": 1.0}, MyStructType, None),  # Extra fields OK
-            ({"s": "a", "i": "1"}, MyStructType, TypeError),
-            (Row(s="a", i=1), MyStructType, None),
-            (Row(s="a", i=None), MyStructType, None),
-            (Row(s="a", i=1, f=1.0), MyStructType, None),  # Extra fields OK
-            (Row(s="a"), MyStructType, ValueError),  # Row can't have missing field
-            (Row(s="a", i="1"), MyStructType, TypeError),
-            (["a", 1], MyStructType, None),
-            (["a", None], MyStructType, None),
-            (["a"], MyStructType, ValueError),
-            (["a", "1"], MyStructType, TypeError),
-            (("a", 1), MyStructType, None),
-            (MyObj(s="a", i=1), MyStructType, None),
-            (MyObj(s="a", i=None), MyStructType, None),
-            (MyObj(s="a"), MyStructType, None),
-            (MyObj(s="a", i="1"), MyStructType, TypeError),
-            (MyObj(s=None, i="1"), MyStructType, ValueError),
+            ({"s": "a", "i": 1}, schema, None),
+            ({"s": "a", "i": None}, schema, None),
+            ({"s": "a"}, schema, None),
+            ({"s": "a", "f": 1.0}, schema, None),  # Extra fields OK
+            ({"s": "a", "i": "1"}, schema, TypeError),
+            (Row(s="a", i=1), schema, None),
+            (Row(s="a", i=None), schema, None),
+            (Row(s="a", i=1, f=1.0), schema, None),  # Extra fields OK
+            (Row(s="a"), schema, ValueError),  # Row can't have missing field
+            (Row(s="a", i="1"), schema, TypeError),
+            (["a", 1], schema, None),
+            (["a", None], schema, None),
+            (["a"], schema, ValueError),
+            (["a", "1"], schema, TypeError),
+            (("a", 1), schema, None),
+            (MyObj(s="a", i=1), schema, None),
+            (MyObj(s="a", i=None), schema, None),
+            (MyObj(s="a"), schema, None),
+            (MyObj(s="a", i="1"), schema, TypeError),
+            (MyObj(s=None, i="1"), schema, ValueError),
         ]
 
         for obj, data_type, exp in spec:
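For reference, a minimal sketch of what the new nested-struct test exercises, assuming the patched pyspark.sql.types from this change (the expected message text is derived from the new_msg/new_name helpers added in types.py below; exact wording may differ by Python version):

from pyspark.sql.types import (IntegerType, StructField, StructType,
                               _verify_type)

# Nested schema: top-level field 'a' is a struct holding an int field 'b'.
schema = StructType([
    StructField('a', StructType([StructField('b', IntegerType())]))])

try:
    # The outer list matches the struct and the inner list matches 'a',
    # but "data" is a str where IntegerType expects an int.
    _verify_type([["data"]], schema)
except TypeError as e:
    # With this patch the message carries the full field path, e.g.:
    #   a.b: IntegerType can not accept object 'data' in type <type 'str'>
    print(str(e))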
64 changes: 37 additions & 27 deletions python/pyspark/sql/types.py
@@ -1249,7 +1249,7 @@ def _infer_schema_type(obj, dataType):
 }
 
 
-def _verify_type(obj, dataType, nullable=True, name="obj"):
+def _verify_type(obj, dataType, nullable=True, name=None):
     """
     Verify the type of obj against dataType, raise a TypeError if they do not match.
 
@@ -1296,84 +1296,94 @@ def _verify_type(obj, dataType, nullable=True, name="obj"):
     ...
     ValueError:...
     """
+
+    if name is None:
+        new_msg = lambda msg: msg
+        new_element_name = lambda idx: "[%d]" % idx
+        new_key_name = lambda key: "[%s](key)" % key
+        new_value_name = lambda key: "[%s]" % key
+        new_name = lambda n: n
+    else:
+        new_msg = lambda msg: "%s: %s" % (name, msg)
+        new_element_name = lambda idx: "%s[%d]" % (name, idx)
+        new_key_name = lambda key: "%s[%s](key)" % (name, key)
+        new_value_name = lambda key: "%s[%s]" % (name, key)
+        new_name = lambda n: "%s.%s" % (name, n)
+
     if obj is None:
         if nullable:
             return
         else:
-            raise ValueError("%s: This field is not nullable, but got None" % name)
+            raise ValueError(new_msg("This field is not nullable, but got None"))
 
     # StringType can work with any types
     if isinstance(dataType, StringType):
         return
 
     if isinstance(dataType, UserDefinedType):
         if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
-            raise ValueError("%s: %r is not an instance of type %r" % (name, obj, dataType))
+            raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType)))
         _verify_type(dataType.toInternal(obj), dataType.sqlType(), name=name)
         return
 
     _type = type(dataType)
     assert _type in _acceptable_types, \
-        "%s: unknown datatype: %s for object %r" % (name, dataType, obj)
+        new_msg("unknown datatype: %s for object %r" % (dataType, obj))
 
     if _type is StructType:
         # check the type and fields later
         pass
     else:
         # subclass of them can not be fromInternal in JVM
         if type(obj) not in _acceptable_types[_type]:
-            raise TypeError("%s: %s can not accept object %r in type %s"
-                            % (name, dataType, obj, type(obj)))
+            raise TypeError(new_msg("%s can not accept object %r in type %s"
+                                    % (dataType, obj, type(obj))))
 
     if isinstance(dataType, ByteType):
         if obj < -128 or obj > 127:
-            raise ValueError("%s: object of ByteType out of range, got: %s" % (name, obj))
+            raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj))
 
     elif isinstance(dataType, ShortType):
         if obj < -32768 or obj > 32767:
-            raise ValueError("%s: object of ShortType out of range, got: %s" % (name, obj))
+            raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj))
 
     elif isinstance(dataType, IntegerType):
         if obj < -2147483648 or obj > 2147483647:
-            raise ValueError("%s: object of IntegerType out of range, got: %s" % (name, obj))
+            raise ValueError(new_msg("object of IntegerType out of range, got: %s" % obj))
 
     elif isinstance(dataType, ArrayType):
         for i, value in enumerate(obj):
-            new_name = "%s[%d]" % (name, i)
-            _verify_type(value, dataType.elementType, dataType.containsNull, name=new_name)
+            _verify_type(
+                value, dataType.elementType, dataType.containsNull, name=new_element_name(i))
 
     elif isinstance(dataType, MapType):
         for k, v in obj.items():
-            new_name = "%s[%s](key)" % (name, k)
-            _verify_type(k, dataType.keyType, False, name=new_name)
-            new_name = "%s[%s]" % (name, k)
-            _verify_type(v, dataType.valueType, dataType.valueContainsNull, name=new_name)
+            _verify_type(k, dataType.keyType, False, name=new_key_name(k))
+            _verify_type(
+                v, dataType.valueType, dataType.valueContainsNull, name=new_value_name(k))
 
     elif isinstance(dataType, StructType):
         if isinstance(obj, dict):
             for f in dataType.fields:
-                new_name = "%s.%s" % (name, f.name)
-                _verify_type(obj.get(f.name), f.dataType, f.nullable, name=new_name)
+                _verify_type(obj.get(f.name), f.dataType, f.nullable, name=new_name(f.name))
         elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
             # the order in obj could be different than dataType.fields
             for f in dataType.fields:
-                new_name = "%s.%s" % (name, f.name)
-                _verify_type(obj[f.name], f.dataType, f.nullable, name=new_name)
+                _verify_type(obj[f.name], f.dataType, f.nullable, name=new_name(f.name))
         elif isinstance(obj, (tuple, list)):
             if len(obj) != len(dataType.fields):
-                raise ValueError("%s: Length of object (%d) does not match with "
-                                 "length of fields (%d)" % (name, len(obj), len(dataType.fields)))
+                raise ValueError(
+                    new_msg("Length of object (%d) does not match with "
                            "length of fields (%d)" % (len(obj), len(dataType.fields))))
             for v, f in zip(obj, dataType.fields):
-                new_name = "%s.%s" % (name, f.name)
-                _verify_type(v, f.dataType, f.nullable, name=new_name)
+                _verify_type(v, f.dataType, f.nullable, name=new_name(f.name))
         elif hasattr(obj, "__dict__"):
             d = obj.__dict__
             for f in dataType.fields:
-                new_name = "%s.%s" % (name, f.name)
-                _verify_type(d.get(f.name), f.dataType, f.nullable, name=new_name)
+                _verify_type(d.get(f.name), f.dataType, f.nullable, name=new_name(f.name))
         else:
-            raise TypeError("%s: StructType can not accept object %r in type %s"
-                            % (name, obj, type(obj)))
+            raise TypeError(new_msg("StructType can not accept object %r in type %s"
                                     % (obj, type(obj))))
 
 
 # This is used to unpickle a Row from JVM
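As a standalone illustration (not part of the patch), the naming helpers compose a path across recursive calls: struct fields append ".field", array elements append "[i]", and map entries append "[key]" or "[key](key)". A minimal sketch mirroring the lambdas above:

# Top-level calls (name is None) start with no prefix; nested calls
# extend the accumulated path.

def field_name(name, n):
    return n if name is None else "%s.%s" % (name, n)

def element_name(name, idx):
    return "[%d]" % idx if name is None else "%s[%d]" % (name, idx)

def key_name(name, key):
    return "[%s](key)" % key if name is None else "%s[%s](key)" % (name, key)

# A struct field 'a' holding an array of structs with field 'b', and a
# map field 'm', would report errors against paths like:
assert field_name(element_name(field_name(None, 'a'), 0), 'b') == 'a[0].b'
assert key_name(field_name(None, 'm'), 'k') == 'm[k](key)'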