diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2bd7ab1b87134..08f4b4dc7b637 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2391,6 +2391,13 @@ def test_verify_type_exception_msg(self):
             except Exception as e:
                 self.assertTrue(str(e).startswith(name))

+        schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
+        try:
+            _verify_type([["data"]], schema)
+            self.fail('Expected _verify_type() to throw so test can check exception message')
+        except Exception as e:
+            self.assertTrue(str(e).startswith("a.b:"))
+
     def test_verify_type_ok_nullable(self):
         obj = None
         for data_type in [IntegerType(), FloatType(), StringType(), StructType([])]:
@@ -2406,7 +2413,7 @@ def test_verify_type_not_nullable(self):
         import datetime
         import decimal

-        MyStructType = StructType([
+        schema = StructType([
             StructField('s', StringType(), nullable=False),
             StructField('i', IntegerType(), nullable=True)])

@@ -2501,26 +2508,26 @@ def __init__(self, **ka):
              ValueError),

             # Struct
-            ({"s": "a", "i": 1}, MyStructType, None),
-            ({"s": "a", "i": None}, MyStructType, None),
-            ({"s": "a"}, MyStructType, None),
-            ({"s": "a", "f": 1.0}, MyStructType, None),  # Extra fields OK
-            ({"s": "a", "i": "1"}, MyStructType, TypeError),
-            (Row(s="a", i=1), MyStructType, None),
-            (Row(s="a", i=None), MyStructType, None),
-            (Row(s="a", i=1, f=1.0), MyStructType, None),  # Extra fields OK
-            (Row(s="a"), MyStructType, ValueError),  # Row can't have missing field
-            (Row(s="a", i="1"), MyStructType, TypeError),
-            (["a", 1], MyStructType, None),
-            (["a", None], MyStructType, None),
-            (["a"], MyStructType, ValueError),
-            (["a", "1"], MyStructType, TypeError),
-            (("a", 1), MyStructType, None),
-            (MyObj(s="a", i=1), MyStructType, None),
-            (MyObj(s="a", i=None), MyStructType, None),
-            (MyObj(s="a"), MyStructType, None),
-            (MyObj(s="a", i="1"), MyStructType, TypeError),
-            (MyObj(s=None, i="1"), MyStructType, ValueError),
+            ({"s": "a", "i": 1}, schema, None),
+            ({"s": "a", "i": None}, schema, None),
+            ({"s": "a"}, schema, None),
+            ({"s": "a", "f": 1.0}, schema, None),  # Extra fields OK
+            ({"s": "a", "i": "1"}, schema, TypeError),
+            (Row(s="a", i=1), schema, None),
+            (Row(s="a", i=None), schema, None),
+            (Row(s="a", i=1, f=1.0), schema, None),  # Extra fields OK
+            (Row(s="a"), schema, ValueError),  # Row can't have missing field
+            (Row(s="a", i="1"), schema, TypeError),
+            (["a", 1], schema, None),
+            (["a", None], schema, None),
+            (["a"], schema, ValueError),
+            (["a", "1"], schema, TypeError),
+            (("a", 1), schema, None),
+            (MyObj(s="a", i=1), schema, None),
+            (MyObj(s="a", i=None), schema, None),
+            (MyObj(s="a"), schema, None),
+            (MyObj(s="a", i="1"), schema, TypeError),
+            (MyObj(s=None, i="1"), schema, ValueError),
         ]

         for obj, data_type, exp in spec:
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 4cd0075954596..4d0c682dc14ac 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1249,7 +1249,7 @@ def _infer_schema_type(obj, dataType):
 }


-def _verify_type(obj, dataType, nullable=True, name="obj"):
+def _verify_type(obj, dataType, nullable=True, name=None):
     """
     Verify the type of obj against dataType, raise a TypeError if they do not
     match.
@@ -1296,11 +1296,25 @@ def _verify_type(obj, dataType, nullable=True, name="obj"):
         ...
     ValueError:...
     """
+
+    if name is None:
+        new_msg = lambda msg: msg
+        new_element_name = lambda idx: "[%d]" % idx
+        new_key_name = lambda key: "[%s](key)" % key
+        new_value_name = lambda key: "[%s]" % key
+        new_name = lambda n: n
+    else:
+        new_msg = lambda msg: "%s: %s" % (name, msg)
+        new_element_name = lambda idx: "%s[%d]" % (name, idx)
+        new_key_name = lambda key: "%s[%s](key)" % (name, key)
+        new_value_name = lambda key: "%s[%s]" % (name, key)
+        new_name = lambda n: "%s.%s" % (name, n)
+
     if obj is None:
         if nullable:
             return
         else:
-            raise ValueError("%s: This field is not nullable, but got None" % name)
+            raise ValueError(new_msg("This field is not nullable, but got None"))

     # StringType can work with any types
     if isinstance(dataType, StringType):
@@ -1308,13 +1322,13 @@ def _verify_type(obj, dataType, nullable=True, name="obj"):

     if isinstance(dataType, UserDefinedType):
         if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
-            raise ValueError("%s: %r is not an instance of type %r" % (name, obj, dataType))
+            raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType)))
         _verify_type(dataType.toInternal(obj), dataType.sqlType(), name=name)
         return

     _type = type(dataType)
     assert _type in _acceptable_types, \
-        "%s: unknown datatype: %s for object %r" % (name, dataType, obj)
+        new_msg("unknown datatype: %s for object %r" % (dataType, obj))

     if _type is StructType:
         # check the type and fields later
@@ -1322,58 +1336,54 @@ def _verify_type(obj, dataType, nullable=True, name="obj"):
     else:
         # subclass of them can not be fromInternal in JVM
         if type(obj) not in _acceptable_types[_type]:
-            raise TypeError("%s: %s can not accept object %r in type %s"
-                            % (name, dataType, obj, type(obj)))
+            raise TypeError(new_msg("%s can not accept object %r in type %s"
+                                    % (dataType, obj, type(obj))))

     if isinstance(dataType, ByteType):
         if obj < -128 or obj > 127:
-            raise ValueError("%s: object of ByteType out of range, got: %s" % (name, obj))
+            raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj))

     elif isinstance(dataType, ShortType):
         if obj < -32768 or obj > 32767:
-            raise ValueError("%s: object of ShortType out of range, got: %s" % (name, obj))
+            raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj))

     elif isinstance(dataType, IntegerType):
         if obj < -2147483648 or obj > 2147483647:
-            raise ValueError("%s: object of IntegerType out of range, got: %s" % (name, obj))
+            raise ValueError(new_msg("object of IntegerType out of range, got: %s" % obj))

     elif isinstance(dataType, ArrayType):
         for i, value in enumerate(obj):
-            new_name = "%s[%d]" % (name, i)
-            _verify_type(value, dataType.elementType, dataType.containsNull, name=new_name)
+            _verify_type(
+                value, dataType.elementType, dataType.containsNull, name=new_element_name(i))

     elif isinstance(dataType, MapType):
         for k, v in obj.items():
-            new_name = "%s[%s](key)" % (name, k)
-            _verify_type(k, dataType.keyType, False, name=new_name)
-            new_name = "%s[%s]" % (name, k)
-            _verify_type(v, dataType.valueType, dataType.valueContainsNull, name=new_name)
+            _verify_type(k, dataType.keyType, False, name=new_key_name(k))
+            _verify_type(
+                v, dataType.valueType, dataType.valueContainsNull, name=new_value_name(k))

     elif isinstance(dataType, StructType):
         if isinstance(obj, dict):
             for f in dataType.fields:
-                new_name = "%s.%s" % (name, f.name)
-                _verify_type(obj.get(f.name), f.dataType, f.nullable, name=new_name)
+                _verify_type(obj.get(f.name), f.dataType, f.nullable, name=new_name(f.name))
         elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
             # the order in obj could be different than dataType.fields
             for f in dataType.fields:
-                new_name = "%s.%s" % (name, f.name)
-                _verify_type(obj[f.name], f.dataType, f.nullable, name=new_name)
+                _verify_type(obj[f.name], f.dataType, f.nullable, name=new_name(f.name))
         elif isinstance(obj, (tuple, list)):
             if len(obj) != len(dataType.fields):
-                raise ValueError("%s: Length of object (%d) does not match with "
-                                 "length of fields (%d)" % (name, len(obj), len(dataType.fields)))
+                raise ValueError(
+                    new_msg("Length of object (%d) does not match with "
+                            "length of fields (%d)" % (len(obj), len(dataType.fields))))
             for v, f in zip(obj, dataType.fields):
-                new_name = "%s.%s" % (name, f.name)
-                _verify_type(v, f.dataType, f.nullable, name=new_name)
+                _verify_type(v, f.dataType, f.nullable, name=new_name(f.name))
         elif hasattr(obj, "__dict__"):
             d = obj.__dict__
             for f in dataType.fields:
-                new_name = "%s.%s" % (name, f.name)
-                _verify_type(d.get(f.name), f.dataType, f.nullable, name=new_name)
+                _verify_type(d.get(f.name), f.dataType, f.nullable, name=new_name(f.name))
         else:
-            raise TypeError("%s: StructType can not accept object %r in type %s"
-                            % (name, obj, type(obj)))
+            raise TypeError(new_msg("StructType can not accept object %r in type %s"
+                                    % (obj, type(obj))))


 # This is used to unpickle a Row from JVM