-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-19507][PySpark][SQL] Show field name in _verify_type error #17227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9c0d20c
34dbb78
fcd2067
5b4324e
2351153
6c1e0b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,19 @@ | |
| import functools | ||
| import time | ||
| import datetime | ||
| import traceback | ||
|
|
||
| if sys.version_info[:2] <= (2, 6): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not a big deal but I guess we dropped 2.6 support.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like most of the other tests still have the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea, let's leave it then. Not a big deal. |
||
| try: | ||
| import unittest2 as unittest | ||
| except ImportError: | ||
| sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') | ||
| sys.exit(1) | ||
| else: | ||
| import unittest | ||
| if sys.version_info[0] >= 3: | ||
| xrange = range | ||
| basestring = str | ||
|
|
||
| import py4j | ||
| try: | ||
|
|
@@ -49,7 +62,7 @@ | |
| from pyspark import SparkContext | ||
| from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row | ||
| from pyspark.sql.types import * | ||
| from pyspark.sql.types import UserDefinedType, _infer_type | ||
| from pyspark.sql.types import UserDefinedType, _infer_type, _verify_type | ||
| from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests | ||
| from pyspark.sql.functions import UserDefinedFunction, sha2, lit | ||
| from pyspark.sql.window import Window | ||
|
|
@@ -2367,6 +2380,162 @@ def range_frame_match(): | |
|
|
||
| importlib.reload(window) | ||
|
|
||
|
|
||
| class TypesTest(unittest.TestCase): | ||
|
|
||
| def test_verify_type_exception_msg(self): | ||
| name = "test_name" | ||
| try: | ||
| _verify_type(None, StringType(), nullable=False, name=name) | ||
| self.fail('Expected _verify_type() to throw so test can check exception message') | ||
| except Exception as e: | ||
| self.assertTrue(str(e).startswith(name)) | ||
|
|
||
| def test_verify_type_ok_nullable(self): | ||
| obj = None | ||
| for data_type in [IntegerType(), FloatType(), StringType(), StructType([])]: | ||
| msg = "_verify_type(%s, %s, nullable=True)" % (obj, data_type) | ||
| try: | ||
| _verify_type(obj, data_type, nullable=True) | ||
| except Exception as e: | ||
| traceback.print_exc() | ||
| self.fail(msg) | ||
|
|
||
| def test_verify_type_not_nullable(self): | ||
| import array | ||
| import datetime | ||
| import decimal | ||
|
|
||
| MyStructType = StructType([ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we make the first character this lower-cased? (or maybe just simply |
||
| StructField('s', StringType(), nullable=False), | ||
| StructField('i', IntegerType(), nullable=True)]) | ||
|
|
||
| class MyObj: | ||
| def __init__(self, **ka): | ||
| for k, v in ka.items(): | ||
| setattr(self, k, v) | ||
|
|
||
| # obj, data_type, exception (None for success or Exception subclass for error) | ||
| spec = [ | ||
| # Strings (match anything but None) | ||
| ("", StringType(), None), | ||
| (u"", StringType(), None), | ||
| (1, StringType(), None), | ||
| (1.0, StringType(), None), | ||
| ([], StringType(), None), | ||
| ({}, StringType(), None), | ||
| (None, StringType(), ValueError), # Only None test | ||
|
|
||
| # UDT | ||
| (ExamplePoint(1.0, 2.0), ExamplePointUDT(), None), | ||
| (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError), | ||
|
|
||
| # Boolean | ||
| (True, BooleanType(), None), | ||
| (1, BooleanType(), TypeError), | ||
| ("True", BooleanType(), TypeError), | ||
| ([1], BooleanType(), TypeError), | ||
|
|
||
| # Bytes | ||
| (-(2**7) - 1, ByteType(), ValueError), | ||
| (-(2**7), ByteType(), None), | ||
| (2**7 - 1, ByteType(), None), | ||
| (2**7, ByteType(), ValueError), | ||
| ("1", ByteType(), TypeError), | ||
| (1.0, ByteType(), TypeError), | ||
|
|
||
| # Shorts | ||
| (-(2**15) - 1, ShortType(), ValueError), | ||
| (-(2**15), ShortType(), None), | ||
| (2**15 - 1, ShortType(), None), | ||
| (2**15, ShortType(), ValueError), | ||
|
|
||
| # Integer | ||
| (-(2**31) - 1, IntegerType(), ValueError), | ||
| (-(2**31), IntegerType(), None), | ||
| (2**31 - 1, IntegerType(), None), | ||
| (2**31, IntegerType(), ValueError), | ||
|
|
||
| # Long | ||
| (2**64, LongType(), None), | ||
|
|
||
| # Float & Double | ||
| (1.0, FloatType(), None), | ||
| (1, FloatType(), TypeError), | ||
| (1.0, DoubleType(), None), | ||
| (1, DoubleType(), TypeError), | ||
|
|
||
| # Decimal | ||
| (decimal.Decimal("1.0"), DecimalType(), None), | ||
| (1.0, DecimalType(), TypeError), | ||
| (1, DecimalType(), TypeError), | ||
| ("1.0", DecimalType(), TypeError), | ||
|
|
||
| # Binary | ||
| (bytearray([1, 2]), BinaryType(), None), | ||
| (1, BinaryType(), TypeError), | ||
|
|
||
| # Date/Time | ||
| (datetime.date(2000, 1, 2), DateType(), None), | ||
| (datetime.datetime(2000, 1, 2, 3, 4), DateType(), None), | ||
| ("2000-01-02", DateType(), TypeError), | ||
| (datetime.datetime(2000, 1, 2, 3, 4), TimestampType(), None), | ||
| (946811040, TimestampType(), TypeError), | ||
|
|
||
| # Array | ||
| ([], ArrayType(IntegerType()), None), | ||
| (["1", None], ArrayType(StringType(), containsNull=True), None), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd like you to add
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added |
||
| (["1", None], ArrayType(StringType(), containsNull=False), ValueError), | ||
| ([1, 2], ArrayType(IntegerType()), None), | ||
| ([1, "2"], ArrayType(IntegerType()), TypeError), | ||
| ((1, 2), ArrayType(IntegerType()), None), | ||
| (array.array('h', [1, 2]), ArrayType(IntegerType()), None), | ||
|
|
||
| # Map | ||
| ({}, MapType(StringType(), IntegerType()), None), | ||
| ({"a": 1}, MapType(StringType(), IntegerType()), None), | ||
| ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError), | ||
| ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError), | ||
| ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True), None), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd also like you to add
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added |
||
| ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False), | ||
| ValueError), | ||
|
|
||
| # Struct | ||
| ({"s": "a", "i": 1}, MyStructType, None), | ||
| ({"s": "a", "i": None}, MyStructType, None), | ||
| ({"s": "a"}, MyStructType, None), | ||
| ({"s": "a", "f": 1.0}, MyStructType, None), # Extra fields OK | ||
| ({"s": "a", "i": "1"}, MyStructType, TypeError), | ||
| (Row(s="a", i=1), MyStructType, None), | ||
| (Row(s="a", i=None), MyStructType, None), | ||
| (Row(s="a", i=1, f=1.0), MyStructType, None), # Extra fields OK | ||
| (Row(s="a"), MyStructType, ValueError), # Row can't have missing field | ||
| (Row(s="a", i="1"), MyStructType, TypeError), | ||
| (["a", 1], MyStructType, None), | ||
| (["a", None], MyStructType, None), | ||
| (["a"], MyStructType, ValueError), | ||
| (["a", "1"], MyStructType, TypeError), | ||
| (("a", 1), MyStructType, None), | ||
| (MyObj(s="a", i=1), MyStructType, None), | ||
| (MyObj(s="a", i=None), MyStructType, None), | ||
| (MyObj(s="a"), MyStructType, None), | ||
| (MyObj(s="a", i="1"), MyStructType, TypeError), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here,
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added |
||
| (MyObj(s=None, i="1"), MyStructType, ValueError), | ||
| ] | ||
|
|
||
| for obj, data_type, exp in spec: | ||
| msg = "_verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp) | ||
| if exp is None: | ||
| try: | ||
| _verify_type(obj, data_type, nullable=False) | ||
| except Exception: | ||
| traceback.print_exc() | ||
| self.fail(msg) | ||
| else: | ||
| with self.assertRaises(exp, msg=msg): | ||
| _verify_type(obj, data_type, nullable=False) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| from pyspark.sql.tests import * | ||
| if xmlrunner: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1249,7 +1249,7 @@ def _infer_schema_type(obj, dataType): | |
| } | ||
|
|
||
|
|
||
| def _verify_type(obj, dataType, nullable=True): | ||
| def _verify_type(obj, dataType, nullable=True, name="obj"): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a question. @dgingrich Do you maybe know if there is any change that "obj" is printed instead? It is rather a nitpick but I would think it is odds if it prints "obj".
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will print "obj" when called from
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's fix this case. >>> from pyspark.sql.types import *
>>> spark.createDataFrame(["a"], StringType()).printSchema()>>> from pyspark.sql.types import *
>>> spark.createDataFrame(["a"], IntegerType()).printSchema()It sounds "obj" should be "value". It looks we should specify the name around https://github.com/dgingrich/spark/blob/topic-spark-19507-verify-types/python/pyspark/sql/session.py#L516.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess this is only place that we print "obj" maybe? If so, let's set
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Set It will still print Right now changing the default name to None would make the error message worse: The best way to make the error message pretty is probably:
That would make your exmple: IMO
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we maybe then |
||
| """ | ||
| Verify the type of obj against dataType, raise a TypeError if they do not match. | ||
|
|
||
|
|
@@ -1300,70 +1300,80 @@ def _verify_type(obj, dataType, nullable=True): | |
| if nullable: | ||
| return | ||
| else: | ||
| raise ValueError("This field is not nullable, but got None") | ||
| raise ValueError("%s: This field is not nullable, but got None" % name) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably, I missed something. However, is there any test case that actually checks this message change?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, I did't test the actual exception message. I normally don't check the contents of exception messages since they shouldn't be used programmatically (the tests are mostly to exercise all code paths to make sure I didn't break something). But here it makes sense to check that the prefix is set since that's the main point of the PR. Added a test looking for the exception message prefix. |
||
|
|
||
| # StringType can work with any types | ||
| if isinstance(dataType, StringType): | ||
| return | ||
|
|
||
| if isinstance(dataType, UserDefinedType): | ||
| if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): | ||
| raise ValueError("%r is not an instance of type %r" % (obj, dataType)) | ||
| _verify_type(dataType.toInternal(obj), dataType.sqlType()) | ||
| raise ValueError("%s: %r is not an instance of type %r" % (name, obj, dataType)) | ||
| _verify_type(dataType.toInternal(obj), dataType.sqlType(), name=name) | ||
| return | ||
|
|
||
| _type = type(dataType) | ||
| assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj) | ||
| assert _type in _acceptable_types, \ | ||
| "%s: unknown datatype: %s for object %r" % (name, dataType, obj) | ||
|
|
||
| if _type is StructType: | ||
| # check the type and fields later | ||
| pass | ||
| else: | ||
| # subclass of them can not be fromInternal in JVM | ||
| if type(obj) not in _acceptable_types[_type]: | ||
| raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj))) | ||
| raise TypeError("%s: %s can not accept object %r in type %s" | ||
| % (name, dataType, obj, type(obj))) | ||
|
|
||
| if isinstance(dataType, ByteType): | ||
| if obj < -128 or obj > 127: | ||
| raise ValueError("object of ByteType out of range, got: %s" % obj) | ||
| raise ValueError("%s: object of ByteType out of range, got: %s" % (name, obj)) | ||
|
|
||
| elif isinstance(dataType, ShortType): | ||
| if obj < -32768 or obj > 32767: | ||
| raise ValueError("object of ShortType out of range, got: %s" % obj) | ||
| raise ValueError("%s: object of ShortType out of range, got: %s" % (name, obj)) | ||
|
|
||
| elif isinstance(dataType, IntegerType): | ||
| if obj < -2147483648 or obj > 2147483647: | ||
| raise ValueError("object of IntegerType out of range, got: %s" % obj) | ||
| raise ValueError("%s: object of IntegerType out of range, got: %s" % (name, obj)) | ||
|
|
||
| elif isinstance(dataType, ArrayType): | ||
| for i in obj: | ||
| _verify_type(i, dataType.elementType, dataType.containsNull) | ||
| for i, value in enumerate(obj): | ||
| new_name = "%s[%d]" % (name, i) | ||
| _verify_type(value, dataType.elementType, dataType.containsNull, name=new_name) | ||
|
|
||
| elif isinstance(dataType, MapType): | ||
| for k, v in obj.items(): | ||
| _verify_type(k, dataType.keyType, False) | ||
| _verify_type(v, dataType.valueType, dataType.valueContainsNull) | ||
| new_name = "%s[%s](key)" % (name, k) | ||
| _verify_type(k, dataType.keyType, False, name=new_name) | ||
| new_name = "%s[%s]" % (name, k) | ||
| _verify_type(v, dataType.valueType, dataType.valueContainsNull, name=new_name) | ||
|
|
||
| elif isinstance(dataType, StructType): | ||
| if isinstance(obj, dict): | ||
| for f in dataType.fields: | ||
| _verify_type(obj.get(f.name), f.dataType, f.nullable) | ||
| new_name = "%s.%s" % (name, f.name) | ||
| _verify_type(obj.get(f.name), f.dataType, f.nullable, name=new_name) | ||
| elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): | ||
| # the order in obj could be different than dataType.fields | ||
| for f in dataType.fields: | ||
| _verify_type(obj[f.name], f.dataType, f.nullable) | ||
| new_name = "%s.%s" % (name, f.name) | ||
| _verify_type(obj[f.name], f.dataType, f.nullable, name=new_name) | ||
| elif isinstance(obj, (tuple, list)): | ||
| if len(obj) != len(dataType.fields): | ||
| raise ValueError("Length of object (%d) does not match with " | ||
| "length of fields (%d)" % (len(obj), len(dataType.fields))) | ||
| raise ValueError("%s: Length of object (%d) does not match with " | ||
| "length of fields (%d)" % (name, len(obj), len(dataType.fields))) | ||
| for v, f in zip(obj, dataType.fields): | ||
| _verify_type(v, f.dataType, f.nullable) | ||
| new_name = "%s.%s" % (name, f.name) | ||
| _verify_type(v, f.dataType, f.nullable, name=new_name) | ||
| elif hasattr(obj, "__dict__"): | ||
| d = obj.__dict__ | ||
| for f in dataType.fields: | ||
| _verify_type(d.get(f.name), f.dataType, f.nullable) | ||
| new_name = "%s.%s" % (name, f.name) | ||
| _verify_type(d.get(f.name), f.dataType, f.nullable, name=new_name) | ||
| else: | ||
| raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj))) | ||
| raise TypeError("%s: StructType can not accept object %r in type %s" | ||
| % (name, obj, type(obj))) | ||
|
|
||
|
|
||
| # This is used to unpickle a Row from JVM | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(I have no idea why this was added in the first place ...)