From a400d5534641813c3ebe75ca553361337766f04b Mon Sep 17 00:00:00 2001
From: David Gingrich
Date: Tue, 28 Feb 2017 00:05:00 -0800
Subject: [PATCH 01/11] Show field name in _verify_type error

---
 python/pyspark/rdd.py         |   1 -
 python/pyspark/sql/session.py |   2 +-
 python/pyspark/sql/tests.py   | 171 +++++++++++++++++++++++++++++++++-
 python/pyspark/sql/types.py   |  50 ++++++----
 4 files changed, 201 insertions(+), 23 deletions(-)

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 60141792d499b..7dfa17f68a943 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -627,7 +627,6 @@ def sortPartition(iterator):
     def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         """
         Sorts this RDD, which is assumed to consist of (key, value) pairs.
-        # noqa
 
         >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
         >>> sc.parallelize(tmp).sortByKey().first()
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index e3bf0f35ea15e..e8447ca64b563 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -524,7 +524,7 @@ def prepare(obj):
             schema = StructType().add("value", schema)
 
             def prepare(obj):
-                verify_func(obj, dataType)
+                verify_func(obj, dataType, name='value')
                 return obj,
         else:
             if isinstance(schema, list):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 0a1cd6856b8e8..d382016655c95 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -30,6 +30,19 @@
 import functools
 import time
 import datetime
+import traceback
+
+if sys.version_info[:2] <= (2, 6):
+    try:
+        import unittest2 as unittest
+    except ImportError:
+        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+        sys.exit(1)
+else:
+    import unittest
+    if sys.version_info[0] >= 3:
+        xrange = range
+        basestring = str
 
 import py4j
 try:
@@ -57,7 +70,7 @@
 from pyspark import SparkContext
 from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
-from pyspark.sql.types import UserDefinedType, _infer_type
+from pyspark.sql.types import UserDefinedType, _infer_type, _verify_type
 from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests
 from pyspark.sql.functions import UserDefinedFunction, sha2, lit
 from pyspark.sql.window import Window
@@ -2620,6 +2633,162 @@ def range_frame_match():
         importlib.reload(window)
 
 
+class TypesTest(unittest.TestCase):
+
+    def test_verify_type_exception_msg(self):
+        name = "test_name"
+        try:
+            _verify_type(None, StringType(), nullable=False, name=name)
+            self.fail('Expected _verify_type() to throw so test can check exception message')
+        except Exception as e:
+            self.assertTrue(str(e).startswith(name))
+
+    def test_verify_type_ok_nullable(self):
+        obj = None
+        for data_type in [IntegerType(), FloatType(), StringType(), StructType([])]:
+            msg = "_verify_type(%s, %s, nullable=True)" % (obj, data_type)
+            try:
+                _verify_type(obj, data_type, nullable=True)
+            except Exception as e:
+                traceback.print_exc()
+                self.fail(msg)
+
+    def test_verify_type_not_nullable(self):
+        import array
+        import datetime
+        import decimal
+
+        MyStructType = StructType([
+            StructField('s', StringType(), nullable=False),
+            StructField('i', IntegerType(), nullable=True)])
+
+        class MyObj:
+            def __init__(self, **ka):
+                for k, v in ka.items():
+                    setattr(self, k, v)
+
+        # obj, data_type, exception (None for success or Exception subclass for error)
+        spec = [
+            # Strings (match anything but None)
+            ("", StringType(), None),
+            (u"", StringType(), None),
+            (1, StringType(), None),
+            (1.0, StringType(), None),
+            ([], StringType(), None),
+            ({}, StringType(), None),
+            (None, StringType(), ValueError),  # Only None test
+
+            # UDT
+            (ExamplePoint(1.0, 2.0), ExamplePointUDT(), None),
+            (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
+
+            # Boolean
+            (True, BooleanType(), None),
+            (1, BooleanType(), TypeError),
+            ("True", BooleanType(), TypeError),
+            ([1], BooleanType(), TypeError),
+
+            # Bytes
+            (-(2**7) - 1, ByteType(), ValueError),
+            (-(2**7), ByteType(), None),
+            (2**7 - 1, ByteType(), None),
+            (2**7, ByteType(), ValueError),
+            ("1", ByteType(), TypeError),
+            (1.0, ByteType(), TypeError),
+
+            # Shorts
+            (-(2**15) - 1, ShortType(), ValueError),
+            (-(2**15), ShortType(), None),
+            (2**15 - 1, ShortType(), None),
+            (2**15, ShortType(), ValueError),
+
+            # Integer
+            (-(2**31) - 1, IntegerType(), ValueError),
+            (-(2**31), IntegerType(), None),
+            (2**31 - 1, IntegerType(), None),
+            (2**31, IntegerType(), ValueError),
+
+            # Long
+            (2**64, LongType(), None),
+
+            # Float & Double
+            (1.0, FloatType(), None),
+            (1, FloatType(), TypeError),
+            (1.0, DoubleType(), None),
+            (1, DoubleType(), TypeError),
+
+            # Decimal
+            (decimal.Decimal("1.0"), DecimalType(), None),
+            (1.0, DecimalType(), TypeError),
+            (1, DecimalType(), TypeError),
+            ("1.0", DecimalType(), TypeError),
+
+            # Binary
+            (bytearray([1, 2]), BinaryType(), None),
+            (1, BinaryType(), TypeError),
+
+            # Date/Time
+            (datetime.date(2000, 1, 2), DateType(), None),
+            (datetime.datetime(2000, 1, 2, 3, 4), DateType(), None),
+            ("2000-01-02", DateType(), TypeError),
+            (datetime.datetime(2000, 1, 2, 3, 4), TimestampType(), None),
+            (946811040, TimestampType(), TypeError),
+
+            # Array
+            ([], ArrayType(IntegerType()), None),
+            (["1", None], ArrayType(StringType(), containsNull=True), None),
+            (["1", None], ArrayType(StringType(), containsNull=False), ValueError),
+            ([1, 2], ArrayType(IntegerType()), None),
+            ([1, "2"], ArrayType(IntegerType()), TypeError),
+            ((1, 2), ArrayType(IntegerType()), None),
+            (array.array('h', [1, 2]), ArrayType(IntegerType()), None),
+
+            # Map
+            ({}, MapType(StringType(), IntegerType()), None),
+            ({"a": 1}, MapType(StringType(), IntegerType()), None),
+            ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError),
+            ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError),
+            ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True), None),
+            ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False),
+             ValueError),
+
+            # Struct
+            ({"s": "a", "i": 1}, MyStructType, None),
+            ({"s": "a", "i": None}, MyStructType, None),
+            ({"s": "a"}, MyStructType, None),
+            ({"s": "a", "f": 1.0}, MyStructType, None),  # Extra fields OK
+            ({"s": "a", "i": "1"}, MyStructType, TypeError),
+            (Row(s="a", i=1), MyStructType, None),
+            (Row(s="a", i=None), MyStructType, None),
+            (Row(s="a", i=1, f=1.0), MyStructType, None),  # Extra fields OK
+            (Row(s="a"), MyStructType, ValueError),  # Row can't have missing field
+            (Row(s="a", i="1"), MyStructType, TypeError),
+            (["a", 1], MyStructType, None),
+            (["a", None], MyStructType, None),
+            (["a"], MyStructType, ValueError),
+            (["a", "1"], MyStructType, TypeError),
+            (("a", 1), MyStructType, None),
+            (MyObj(s="a", i=1), MyStructType, None),
+            (MyObj(s="a", i=None), MyStructType, None),
+            (MyObj(s="a"), MyStructType, None),
+            (MyObj(s="a", i="1"), MyStructType, TypeError),
+            (MyObj(s=None, i="1"), MyStructType, ValueError),
+        ]
+
+        for obj, data_type, exp in spec:
+            msg = "_verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
+            if exp is None:
+                try:
+                    _verify_type(obj, data_type, nullable=False)
+                except Exception:
+                    traceback.print_exc()
+                    self.fail(msg)
+            else:
+                with self.assertRaises(exp, msg=msg):
+                    _verify_type(obj, data_type, nullable=False)
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests import *
     if xmlrunner:
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 26b54a7fb3709..4cd0075954596 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1249,7 +1249,7 @@ def _infer_schema_type(obj, dataType):
 }
 
 
-def _verify_type(obj, dataType, nullable=True):
+def _verify_type(obj, dataType, nullable=True, name="obj"):
     """
     Verify the type of obj against dataType, raise a TypeError if they do not match.
 
@@ -1300,7 +1300,7 @@
         if nullable:
             return
         else:
-            raise ValueError("This field is not nullable, but got None")
+            raise ValueError("%s: This field is not nullable, but got None" % name)
 
     # StringType can work with any types
     if isinstance(dataType, StringType):
@@ -1308,12 +1308,13 @@
 
     if isinstance(dataType, UserDefinedType):
         if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
-            raise ValueError("%r is not an instance of type %r" % (obj, dataType))
-        _verify_type(dataType.toInternal(obj), dataType.sqlType())
+            raise ValueError("%s: %r is not an instance of type %r" % (name, obj, dataType))
+        _verify_type(dataType.toInternal(obj), dataType.sqlType(), name=name)
         return
 
     _type = type(dataType)
-    assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj)
+    assert _type in _acceptable_types, \
+        "%s: unknown datatype: %s for object %r" % (name, dataType, obj)
 
     if _type is StructType:
         # check the type and fields later
@@ -1321,49 +1322,58 @@
     else:
         # subclass of them can not be fromInternal in JVM
         if type(obj) not in _acceptable_types[_type]:
-            raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))
+            raise TypeError("%s: %s can not accept object %r in type %s"
+                            % (name, dataType, obj, type(obj)))
 
     if isinstance(dataType, ByteType):
         if obj < -128 or obj > 127:
-            raise ValueError("object of ByteType out of range, got: %s" % obj)
+            raise ValueError("%s: object of ByteType out of range, got: %s" % (name, obj))
 
     elif isinstance(dataType, ShortType):
         if obj < -32768 or obj > 32767:
-            raise ValueError("object of ShortType out of range, got: %s" % obj)
+            raise ValueError("%s: object of ShortType out of range, got: %s" % (name, obj))
 
     elif isinstance(dataType, IntegerType):
         if obj < -2147483648 or obj > 2147483647:
-            raise ValueError("object of IntegerType out of range, got: %s" % obj)
+            raise ValueError("%s: object of IntegerType out of range, got: %s" % (name, obj))
 
     elif isinstance(dataType, ArrayType):
-        for i in obj:
-            _verify_type(i, dataType.elementType, dataType.containsNull)
+        for i, value in enumerate(obj):
+            new_name = "%s[%d]" % (name, i)
+            _verify_type(value, dataType.elementType, dataType.containsNull, name=new_name)
 
     elif isinstance(dataType, MapType):
         for k, v in obj.items():
-            _verify_type(k, dataType.keyType, False)
-            _verify_type(v, dataType.valueType, dataType.valueContainsNull)
+            new_name = "%s[%s](key)" % (name, k)
+            _verify_type(k, dataType.keyType, False, name=new_name)
+            new_name = "%s[%s]" % (name, k)
+            _verify_type(v, dataType.valueType, dataType.valueContainsNull, name=new_name)
 
     elif isinstance(dataType, StructType):
         if isinstance(obj, dict):
             for f in dataType.fields:
-                _verify_type(obj.get(f.name), f.dataType, f.nullable)
+                new_name = "%s.%s" % (name, f.name)
+                _verify_type(obj.get(f.name), f.dataType, f.nullable, name=new_name)
         elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
             # the order in obj could be different than dataType.fields
             for f in dataType.fields:
-                _verify_type(obj[f.name], f.dataType, f.nullable)
+                new_name = "%s.%s" % (name, f.name)
+                _verify_type(obj[f.name], f.dataType, f.nullable, name=new_name)
         elif isinstance(obj, (tuple, list)):
             if len(obj) != len(dataType.fields):
-                raise ValueError("Length of object (%d) does not match with "
-                                 "length of fields (%d)" % (len(obj), len(dataType.fields)))
+                raise ValueError("%s: Length of object (%d) does not match with "
+                                 "length of fields (%d)" % (name, len(obj), len(dataType.fields)))
             for v, f in zip(obj, dataType.fields):
-                _verify_type(v, f.dataType, f.nullable)
+                new_name = "%s.%s" % (name, f.name)
+                _verify_type(v, f.dataType, f.nullable, name=new_name)
         elif hasattr(obj, "__dict__"):
             d = obj.__dict__
             for f in dataType.fields:
-                _verify_type(d.get(f.name), f.dataType, f.nullable)
+                new_name = "%s.%s" % (name, f.name)
+                _verify_type(d.get(f.name), f.dataType, f.nullable, name=new_name)
         else:
-            raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
+            raise TypeError("%s: StructType can not accept object %r in type %s"
+                            % (name, obj, type(obj)))
 
 
 # This is used to unpickle a Row from JVM

From 22311f1b4a1a51813904e0be1673b9e8466cac13 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Mon, 3 Jul 2017 14:41:01 +0900
Subject: [PATCH 02/11] Fix default obj parent name issue

---
 python/pyspark/sql/tests.py | 49 ++++++++++++++++------------
 python/pyspark/sql/types.py | 64 +++++++++++++++++++++----------------
 2 files changed, 65 insertions(+), 48 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d382016655c95..670a7944fc7cc 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2644,6 +2644,13 @@ def test_verify_type_exception_msg(self):
         except Exception as e:
             self.assertTrue(str(e).startswith(name))
 
+        schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
+        try:
+            _verify_type([["data"]], schema)
+            self.fail('Expected _verify_type() to throw so test can check exception message')
+        except Exception as e:
+            self.assertTrue(str(e).startswith("a.b:"))
+
     def test_verify_type_ok_nullable(self):
         obj = None
         for data_type in [IntegerType(), FloatType(), StringType(), StructType([])]:
@@ -2659,7 +2666,7 @@ def test_verify_type_not_nullable(self):
         import datetime
         import decimal
 
-        MyStructType = StructType([
+        schema = StructType([
             StructField('s', StringType(), nullable=False),
             StructField('i', IntegerType(), nullable=True)])
 
@@ -2754,26 +2761,26 @@ def __init__(self, **ka):
             ValueError),
 
             # Struct
-            ({"s": "a", "i": 1}, MyStructType, None),
-            ({"s": "a", "i": None}, MyStructType, None),
-            ({"s": "a"}, MyStructType, None),
-            ({"s": "a", "f": 1.0}, MyStructType, None),  # Extra fields OK
-            ({"s": "a", "i": "1"}, MyStructType, TypeError),
-            (Row(s="a", i=1), MyStructType, None),
-            (Row(s="a", i=None), MyStructType, None),
-            (Row(s="a", i=1, f=1.0), MyStructType, None),  # Extra fields OK
-            (Row(s="a"), MyStructType, ValueError),  # Row can't have missing field
-            (Row(s="a", i="1"), MyStructType, TypeError),
-            (["a", 1], MyStructType, None),
-            (["a", None], MyStructType, None),
-            (["a"], MyStructType, ValueError),
-            (["a", "1"], MyStructType, TypeError),
-            (("a", 1), MyStructType, None),
-            (MyObj(s="a", i=1), MyStructType, None),
-            (MyObj(s="a", i=None), MyStructType, None),
-            (MyObj(s="a"), MyStructType, None),
-            (MyObj(s="a", i="1"), MyStructType, TypeError),
-            (MyObj(s=None, i="1"), MyStructType, ValueError),
+            ({"s": "a", "i": 1}, schema, None),
+            ({"s": "a", "i": None}, schema, None),
+            ({"s": "a"}, schema, None),
+            ({"s": "a", "f": 1.0}, schema, None),  # Extra fields OK
+            ({"s": "a", "i": "1"}, schema, TypeError),
+            (Row(s="a", i=1), schema, None),
+            (Row(s="a", i=None), schema, None),
+            (Row(s="a", i=1, f=1.0), schema, None),  # Extra fields OK
+            (Row(s="a"), schema, ValueError),  # Row can't have missing field
+            (Row(s="a", i="1"), schema, TypeError),
+            (["a", 1], schema, None),
+            (["a", None], schema, None),
+            (["a"], schema, ValueError),
+            (["a", "1"], schema, TypeError),
+            (("a", 1), schema, None),
+            (MyObj(s="a", i=1), schema, None),
+            (MyObj(s="a", i=None), schema, None),
+            (MyObj(s="a"), schema, None),
+            (MyObj(s="a", i="1"), schema, TypeError),
+            (MyObj(s=None, i="1"), schema, ValueError),
         ]
 
         for obj, data_type, exp in spec:
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 4cd0075954596..4d0c682dc14ac 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1249,7 +1249,7 @@ def _infer_schema_type(obj, dataType):
 }
 
 
-def _verify_type(obj, dataType, nullable=True, name="obj"):
+def _verify_type(obj, dataType, nullable=True, name=None):
     """
     Verify the type of obj against dataType, raise a TypeError if they do not match.
 
@@ -1296,11 +1296,25 @@ def _verify_type(obj, dataType, nullable=True, name=None):
     ...
     ValueError:...
     """
+
+    if name is None:
+        new_msg = lambda msg: msg
+        new_element_name = lambda idx: "[%d]" % idx
+        new_key_name = lambda key: "[%s](key)" % key
+        new_value_name = lambda key: "[%s]" % key
+        new_name = lambda n: n
+    else:
+        new_msg = lambda msg: "%s: %s" % (name, msg)
+        new_element_name = lambda idx: "%s[%d]" % (name, idx)
+        new_key_name = lambda key: "%s[%s](key)" % (name, key)
+        new_value_name = lambda key: "%s[%s]" % (name, key)
+        new_name = lambda n: "%s.%s" % (name, n)
+
     if obj is None:
         if nullable:
             return
         else:
-            raise ValueError("%s: This field is not nullable, but got None" % name)
+            raise ValueError(new_msg("This field is not nullable, but got None"))
 
     # StringType can work with any types
     if isinstance(dataType, StringType):
@@ -1308,13 +1322,13 @@
 
     if isinstance(dataType, UserDefinedType):
         if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
-            raise ValueError("%s: %r is not an instance of type %r" % (name, obj, dataType))
+            raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType)))
         _verify_type(dataType.toInternal(obj), dataType.sqlType(), name=name)
         return
 
     _type = type(dataType)
     assert _type in _acceptable_types, \
-        "%s: unknown datatype: %s for object %r" % (name, dataType, obj)
+        new_msg("unknown datatype: %s for object %r" % (dataType, obj))
 
     if _type is StructType:
         # check the type and fields later
@@ -1322,58 +1336,54 @@
     else:
         # subclass of them can not be fromInternal in JVM
         if type(obj) not in _acceptable_types[_type]:
-            raise TypeError("%s: %s can not accept object %r in type %s"
-                            % (name, dataType, obj, type(obj)))
+            raise TypeError(new_msg("%s can not accept object %r in type %s"
+                                    % (dataType, obj, type(obj))))
 
     if isinstance(dataType, ByteType):
         if obj < -128 or obj > 127:
-            raise ValueError("%s: object of ByteType out of range, got: %s" % (name, obj))
+            raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj))
 
     elif isinstance(dataType, ShortType):
         if obj < -32768 or obj > 32767:
-            raise ValueError("%s: object of ShortType out of range, got: %s" % (name, obj))
+            raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj))
 
     elif isinstance(dataType, IntegerType):
         if obj < -2147483648 or obj > 2147483647:
-            raise ValueError("%s: object of IntegerType out of range, got: %s" % (name, obj))
+            raise ValueError(new_msg("object of IntegerType out of range, got: %s" % obj))
 
     elif isinstance(dataType, ArrayType):
         for i, value in enumerate(obj):
-            new_name = "%s[%d]" % (name, i)
-            _verify_type(value, dataType.elementType, dataType.containsNull, name=new_name)
+            _verify_type(
+                value, dataType.elementType, dataType.containsNull, name=new_element_name(i))
 
     elif isinstance(dataType, MapType):
         for k, v in obj.items():
-            new_name = "%s[%s](key)" % (name, k)
-            _verify_type(k, dataType.keyType, False, name=new_name)
-            new_name = "%s[%s]" % (name, k)
-            _verify_type(v, dataType.valueType, dataType.valueContainsNull, name=new_name)
+            _verify_type(k, dataType.keyType, False, name=new_key_name(k))
+            _verify_type(
+                v, dataType.valueType, dataType.valueContainsNull, name=new_value_name(k))
 
     elif isinstance(dataType, StructType):
         if isinstance(obj, dict):
             for f in dataType.fields:
-                new_name = "%s.%s" % (name, f.name)
-                _verify_type(obj.get(f.name), f.dataType, f.nullable, name=new_name)
+                _verify_type(obj.get(f.name), f.dataType, f.nullable, name=new_name(f.name))
         elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
             # the order in obj could be different than dataType.fields
             for f in dataType.fields:
-                new_name = "%s.%s" % (name, f.name)
-                _verify_type(obj[f.name], f.dataType, f.nullable, name=new_name)
+                _verify_type(obj[f.name], f.dataType, f.nullable, name=new_name(f.name))
         elif isinstance(obj, (tuple, list)):
             if len(obj) != len(dataType.fields):
-                raise ValueError("%s: Length of object (%d) does not match with "
-                                 "length of fields (%d)" % (name, len(obj), len(dataType.fields)))
+                raise ValueError(
+                    new_msg("Length of object (%d) does not match with "
+                            "length of fields (%d)" % (len(obj), len(dataType.fields))))
             for v, f in zip(obj, dataType.fields):
-                new_name = "%s.%s" % (name, f.name)
-                _verify_type(v, f.dataType, f.nullable, name=new_name)
+                _verify_type(v, f.dataType, f.nullable, name=new_name(f.name))
         elif hasattr(obj, "__dict__"):
             d = obj.__dict__
             for f in dataType.fields:
-                new_name = "%s.%s" % (name, f.name)
-                _verify_type(d.get(f.name), f.dataType, f.nullable, name=new_name)
+                _verify_type(d.get(f.name), f.dataType, f.nullable, name=new_name(f.name))
         else:
-            raise TypeError("%s: StructType can not accept object %r in type %s"
-                            % (name, obj, type(obj)))
+            raise TypeError(new_msg("StructType can not accept object %r in type %s"
+                                    % (obj, type(obj))))
 
 
 # This is used to unpickle a Row from JVM

From d7f677830cb423a5da5e428bc3211f348795a2b2 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 4 Jul 2017 01:03:50 +0900
Subject: [PATCH 03/11] Fix type dispatch

---
 python/pyspark/sql/session.py |  12 +-
 python/pyspark/sql/tests.py   |  43 ++++---
 python/pyspark/sql/types.py   | 206 ++++++++++++++++++++++------------
 3 files changed, 168 insertions(+), 93 deletions(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index e8447ca64b563..2cc0e2d1d7b8d 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -33,7 +33,7 @@
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
+from pyspark.sql.types import Row, DataType, StringType, StructType, _make_type_verifier, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
 from pyspark.sql.utils import install_exception_handler
 
@@ -514,17 +514,21 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
             schema = [str(x) for x in data.columns]
             data = [r.tolist() for r in data.to_records(index=False)]
 
-        verify_func = _verify_type if verifySchema else lambda _, t: True
         if isinstance(schema, StructType):
+            verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
+
             def prepare(obj):
-                verify_func(obj, schema)
+                verify_func(obj)
                 return obj
         elif isinstance(schema, DataType):
             dataType = schema
             schema = StructType().add("value", schema)
+
+            verify_func = _make_type_verifier(
+                dataType, name="field value") if verifySchema else lambda _: True
+
             def prepare(obj):
-                verify_func(obj, dataType, name='value')
+                verify_func(obj)
                 return obj,
         else:
             if isinstance(schema, list):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 670a7944fc7cc..59cae81609eea 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -70,7 +70,7 @@
 from pyspark import SparkContext
 from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
-from pyspark.sql.types import UserDefinedType, _infer_type, _verify_type
+from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
 from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests
 from pyspark.sql.functions import UserDefinedFunction, sha2, lit
 from pyspark.sql.window import Window
@@ -865,7 +865,7 @@ def test_convert_row_to_dict(self):
         self.assertEqual(1.0, row.asDict()['d']['key'].c)
 
     def test_udt(self):
-        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type
+        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier
         from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
 
         def check_datatype(datatype):
@@ -881,8 +881,8 @@ def check_datatype(datatype):
         check_datatype(structtype_with_udt)
         p = ExamplePoint(1.0, 2.0)
         self.assertEqual(_infer_type(p), ExamplePointUDT())
-        _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
-        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT()))
+        _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
+        self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0]))
 
         check_datatype(PythonOnlyUDT())
         structtype_with_udt = StructType([StructField("label", DoubleType(), False),
@@ -890,8 +890,10 @@ def check_datatype(datatype):
         check_datatype(structtype_with_udt)
         p = PythonOnlyPoint(1.0, 2.0)
         self.assertEqual(_infer_type(p), PythonOnlyUDT())
-        _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
-        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
+        _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
+        self.assertRaises(
+            ValueError,
+            lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0]))
 
     def test_simple_udt_in_df(self):
         schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
@@ -2637,31 +2639,40 @@ def range_frame_match():
 
 class TypesTest(unittest.TestCase):
 
     def test_verify_type_exception_msg(self):
+        def verify_type(obj, dataType, nullable=True, name=None):
+            _make_type_verifier(dataType, nullable, name)(obj)
+
         name = "test_name"
         try:
-            _verify_type(None, StringType(), nullable=False, name=name)
-            self.fail('Expected _verify_type() to throw so test can check exception message')
+            verify_type(None, StringType(), nullable=False, name=name)
+            self.fail('Expected verify_type() to throw so test can check exception message')
         except Exception as e:
             self.assertTrue(str(e).startswith(name))
 
         schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
         try:
-            _verify_type([["data"]], schema)
-            self.fail('Expected _verify_type() to throw so test can check exception message')
+            verify_type([["data"]], schema)
+            self.fail('Expected verify_type() to throw so test can check exception message')
         except Exception as e:
-            self.assertTrue(str(e).startswith("a.b:"))
+            self.assertTrue(str(e).startswith("field b in field a"))
 
     def test_verify_type_ok_nullable(self):
+        def verify_type(obj, dataType, nullable=True, name=None):
+            _make_type_verifier(dataType, nullable, name)(obj)
+
         obj = None
         for data_type in [IntegerType(), FloatType(), StringType(), StructType([])]:
-            msg = "_verify_type(%s, %s, nullable=True)" % (obj, data_type)
+            msg = "verify_type(%s, %s, nullable=True)" % (obj, data_type)
             try:
-                _verify_type(obj, data_type, nullable=True)
+                verify_type(obj, data_type, nullable=True)
             except Exception as e:
                 traceback.print_exc()
                 self.fail(msg)
 
     def test_verify_type_not_nullable(self):
+        def verify_type(obj, dataType, nullable=True, name=None):
+            _make_type_verifier(dataType, nullable, name)(obj)
+
         import array
         import datetime
         import decimal
@@ -2784,16 +2795,16 @@ def __init__(self, **ka):
         ]
 
         for obj, data_type, exp in spec:
-            msg = "_verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
+            msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
             if exp is None:
                 try:
-                    _verify_type(obj, data_type, nullable=False)
+                    verify_type(obj, data_type, nullable=False)
                 except Exception:
                     traceback.print_exc()
                     self.fail(msg)
             else:
                 with self.assertRaises(exp, msg=msg):
-                    _verify_type(obj, data_type, nullable=False)
+                    verify_type(obj, data_type, nullable=False)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 4d0c682dc14ac..849cb224e1e79 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1249,7 +1249,7 @@ def _infer_schema_type(obj, dataType):
 }
 
 
-def _verify_type(obj, dataType, nullable=True, name=None):
+def _make_type_verifier(dataType, nullable=True, name=None):
     """
     Verify the type of obj against dataType, raise a TypeError if they do not match.
 
@@ -1257,41 +1257,42 @@ def _verify_type(obj, dataType, nullable=True, name=None):
     range, e.g. using 128 as ByteType will overflow. Note that, Python float is not checked, so it
     will become infinity when cast to Java float if it overflows.
 
-    >>> _verify_type(None, StructType([]))
-    >>> _verify_type("", StringType())
-    >>> _verify_type(0, LongType())
-    >>> _verify_type(list(range(3)), ArrayType(ShortType()))
-    >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
+    >>> _make_type_verifier(StructType([]))(None)
+    >>> _make_type_verifier(StringType())("")
+    >>> _make_type_verifier(LongType())(0)
+    >>> _make_type_verifier(ArrayType(ShortType()))(list(range(3)))
+    >>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
     TypeError:...
-    >>> _verify_type({}, MapType(StringType(), IntegerType()))
-    >>> _verify_type((), StructType([]))
-    >>> _verify_type([], StructType([]))
-    >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL
+    >>> _make_type_verifier(MapType(StringType(), IntegerType()))({})
+    >>> _make_type_verifier(StructType([]))(())
+    >>> _make_type_verifier(StructType([]))([])
+    >>> _make_type_verifier(StructType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
     ValueError:...
     >>> # Check if numeric values are within the allowed range.
-    >>> _verify_type(12, ByteType())
-    >>> _verify_type(1234, ByteType()) # doctest: +IGNORE_EXCEPTION_DETAIL
+    >>> _make_type_verifier(ByteType())(12)
+    >>> _make_type_verifier(ByteType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
     ValueError:...
-    >>> _verify_type(None, ByteType(), False) # doctest: +IGNORE_EXCEPTION_DETAIL
+    >>> _make_type_verifier(ByteType(), False)(None) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
     ValueError:...
-    >>> _verify_type([1, None], ArrayType(ShortType(), False)) # doctest: +IGNORE_EXCEPTION_DETAIL
+    >>> _make_type_verifier(
+    ...     ArrayType(ShortType(), False))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
     ValueError:...
-    >>> _verify_type({None: 1}, MapType(StringType(), IntegerType()))
+    >>> _make_type_verifier(MapType(StringType(), IntegerType()))({None: 1})
     Traceback (most recent call last):
         ...
     ValueError:...
     >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False)
-    >>> _verify_type((1, None), schema) # doctest: +IGNORE_EXCEPTION_DETAIL
+    >>> _make_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
     ValueError:...
@@ -1299,91 +1300,150 @@ def _verify_type(obj, dataType, nullable=True, name=None):
 
     if name is None:
         new_msg = lambda msg: msg
-        new_element_name = lambda idx: "[%d]" % idx
-        new_key_name = lambda key: "[%s](key)" % key
-        new_value_name = lambda key: "[%s]" % key
-        new_name = lambda n: n
+        new_name = lambda n: "field %s" % n
     else:
         new_msg = lambda msg: "%s: %s" % (name, msg)
-        new_element_name = lambda idx: "%s[%d]" % (name, idx)
-        new_key_name = lambda key: "%s[%s](key)" % (name, key)
-        new_value_name = lambda key: "%s[%s]" % (name, key)
-        new_name = lambda n: "%s.%s" % (name, n)
+        new_name = lambda n: "field %s in %s" % (n, name)
 
-    if obj is None:
-        if nullable:
-            return
+    def verify_nullability(obj):
+        if obj is None:
+            if nullable:
+                return True
+            else:
+                raise ValueError(new_msg("This field is not nullable, but got None"))
         else:
-            raise ValueError(new_msg("This field is not nullable, but got None"))
+            return False
 
     # StringType can work with any types
     if isinstance(dataType, StringType):
-        return
+        def verify_string(obj):
+            if verify_nullability(obj):
+                return None
+        return verify_string
 
     if isinstance(dataType, UserDefinedType):
-        if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
-            raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType)))
-        _verify_type(dataType.toInternal(obj), dataType.sqlType(), name=name)
-        return
+        verifier = _make_type_verifier(dataType.sqlType(), name=name)
+
+        def verify_udf(obj):
+            if verify_nullability(obj):
+                return None
+            if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
+                raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType)))
+            verifier(dataType.toInternal(obj))
+        return verify_udf
 
     _type = type(dataType)
-    assert _type in _acceptable_types, \
-        new_msg("unknown datatype: %s for object %r" % (dataType, obj))
 
-    if _type is StructType:
-        # check the type and fields later
-        pass
-    else:
+    def assert_acceptable_types(obj):
+        assert _type in _acceptable_types, \
+            new_msg("unknown datatype: %s for object %r" % (dataType, obj))
+
+    def verify_acceptable_types(obj):
         # subclass of them can not be fromInternal in JVM
         if type(obj) not in _acceptable_types[_type]:
             raise TypeError(new_msg("%s can not accept object %r in type %s"
                                     % (dataType, obj, type(obj))))
 
     if isinstance(dataType, ByteType):
-        if obj < -128 or obj > 127:
-            raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj))
+        def verify_byte(obj):
+            if verify_nullability(obj):
+                return None
+            assert_acceptable_types(obj)
+            verify_acceptable_types(obj)
+            if obj < -128 or obj > 127:
+                raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj))
+        return verify_byte
 
     elif isinstance(dataType, ShortType):
-        if obj < -32768 or obj > 32767:
-            raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj))
+        def verify_short(obj):
+            if verify_nullability(obj):
+                return None
+            assert_acceptable_types(obj)
+            verify_acceptable_types(obj)
+            if obj < -32768 or obj > 32767:
+                raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj))
+        return verify_short
 
     elif isinstance(dataType, IntegerType):
-        if obj < -2147483648 or obj > 2147483647:
-            raise ValueError(new_msg("object of IntegerType out of range, got: %s" % obj))
+        def verify_integer(obj):
+            if verify_nullability(obj):
+                return None
+            assert_acceptable_types(obj)
+            verify_acceptable_types(obj)
+            if obj < -2147483648 or obj > 2147483647:
+                raise ValueError(
+                    new_msg("object of IntegerType out of range, got: %s" % obj))
+        return verify_integer
 
     elif isinstance(dataType, ArrayType):
-        for i, value in enumerate(obj):
-            _verify_type(
-                value, dataType.elementType, dataType.containsNull, name=new_element_name(i))
+        element_verifier = _make_type_verifier(
+            dataType.elementType, dataType.containsNull, name="element in array %s" % name)
+
+        def verify_array(obj):
+            if verify_nullability(obj):
+                return None
+            assert_acceptable_types(obj)
+            verify_acceptable_types(obj)
+            for i in obj:
+                element_verifier(i)
+        return verify_array
 
     elif isinstance(dataType, MapType):
-        for k, v in obj.items():
-            _verify_type(k, dataType.keyType, False, name=new_key_name(k))
-            _verify_type(
-                v, dataType.valueType, dataType.valueContainsNull, name=new_value_name(k))
+        key_verifier = _make_type_verifier(dataType.keyType, False, name="key of map %s" % name)
+        value_verifier = _make_type_verifier(
+            dataType.valueType, dataType.valueContainsNull, name="value of map %s" % name)
+
+        def verify_map(obj):
+            if verify_nullability(obj):
+                return None
+            assert_acceptable_types(obj)
+            verify_acceptable_types(obj)
+            for k, v in obj.items():
+                key_verifier(k)
+                value_verifier(v)
+        return verify_map
 
     elif isinstance(dataType, StructType):
-        if isinstance(obj, dict):
-            for f in dataType.fields:
-                _verify_type(obj.get(f.name), f.dataType, f.nullable, name=new_name(f.name))
-        elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
-            # the order in obj could be different than dataType.fields
-            for f in dataType.fields:
-                _verify_type(obj[f.name], f.dataType, f.nullable, name=new_name(f.name))
-        elif isinstance(obj, (tuple, list)):
-            if len(obj) != len(dataType.fields):
-                raise ValueError(
-                    new_msg("Length of object (%d) does not match with "
-                            "length of fields (%d)" % (len(obj), len(dataType.fields))))
-            for v, f in zip(obj, dataType.fields):
-                _verify_type(v, f.dataType, f.nullable, name=new_name(f.name))
-        elif hasattr(obj, "__dict__"):
-            d = obj.__dict__
-            for f in dataType.fields:
-                _verify_type(d.get(f.name), f.dataType, f.nullable, name=new_name(f.name))
-        else:
-            raise TypeError(new_msg("StructType can not accept object %r in type %s"
-                                    % (obj, type(obj))))
+        verifiers = []
+        for f in dataType.fields:
+            verifier = _make_type_verifier(f.dataType, f.nullable, name=new_name(f.name))
+            verifiers.append((f.name, verifier))
+
+        def verify_struct(obj):
+            if verify_nullability(obj):
+                return None
+            assert_acceptable_types(obj)
+
+            if isinstance(obj, dict):
+                for f, verifier in verifiers:
+                    verifier(obj.get(f))
+            elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+                # the order in obj could be different than dataType.fields
+                for f, verifier in verifiers:
+                    verifier(obj[f])
+            elif isinstance(obj, (tuple, list)):
+                if len(obj) != len(verifiers):
+                    raise ValueError(
+                        new_msg("Length of object (%d) does not match with "
+                                "length of fields (%d)" % (len(obj), len(verifiers))))
+                for v, (_, verifier) in zip(obj, verifiers):
+                    verifier(v)
+            elif hasattr(obj, "__dict__"):
+                d = obj.__dict__
+                for f, verifier in verifiers:
+                    verifier(d.get(f))
+            else:
+                raise TypeError(new_msg("StructType can not accept object %r in type %s"
+                                        % (obj, type(obj))))
+        return verify_struct
+
+    else:
+        def verify_default(obj):
+            if verify_nullability(obj):
+                return None
+            assert_acceptable_types(obj)
+            verify_acceptable_types(obj)
+        return verify_default
 
 
 # This is used to unpickle a Row from JVM

From 9564e37c9cc638f7a00c71306167087775019f85 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 4 Jul 2017 13:43:37 +0900
Subject: [PATCH 04/11] Clean up and remove the confusing 'return None'
 statements

---
 python/pyspark/sql/types.py | 74 +++++++++++++++++--------------------
 1 file changed, 34 insertions(+), 40 deletions(-)

diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 849cb224e1e79..56d9bb8fda0d2 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1314,24 +1314,6 @@ def verify_nullability(obj):
         else:
             return False
 
-    # StringType can work with any types
-    if isinstance(dataType, StringType):
-        def verify_string(obj):
-            if verify_nullability(obj):
-                return None
-        return verify_string
-
-    if isinstance(dataType, UserDefinedType):
-        verifier = _make_type_verifier(dataType.sqlType(), name=name)
-
-        def verify_udf(obj):
-            if verify_nullability(obj):
-                return None
-            if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
-                raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType)))
-            verifier(dataType.toInternal(obj))
-        return verify_udf
-
     _type = type(dataType)
 
     def assert_acceptable_types(obj):
@@ -1344,49 +1326,59 @@ def verify_acceptable_types(obj):
             raise TypeError(new_msg("%s can not accept object %r in type %s"
                                     % (dataType, obj, type(obj))))
 
-    if isinstance(dataType, ByteType):
+    if isinstance(dataType, StringType):
+        # StringType can work with any types
+        verify_value = lambda _: _
+
+    elif isinstance(dataType, UserDefinedType):
+        verifier = _make_type_verifier(dataType.sqlType(), name=name)
+
+        def verify_udf(obj):
+            if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
+                raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType)))
+            verifier(dataType.toInternal(obj))
+
+        verify_value = verify_udf
+
+    elif isinstance(dataType, ByteType):
         def verify_byte(obj):
-            if verify_nullability(obj):
-                return None
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             if obj < -128 or obj > 127:
                 raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj))
-        return verify_byte
+
+        verify_value = verify_byte
 
     elif isinstance(dataType, ShortType):
         def verify_short(obj):
-            if verify_nullability(obj):
-                return None
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             if obj < -32768 or obj > 32767:
                 raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj))
-        return verify_short
+
+        verify_value = verify_short
 
     elif isinstance(dataType, IntegerType):
        def verify_integer(obj):
-            if verify_nullability(obj):
-                return None
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             if obj < -2147483648 or obj > 2147483647:
                 raise ValueError(
                     new_msg("object of IntegerType out of range, got: %s" % obj))
-        return verify_integer
+
+        verify_value = verify_integer
 
     elif isinstance(dataType, ArrayType):
         element_verifier = _make_type_verifier(
             dataType.elementType, dataType.containsNull, name="element in array %s" % name)
 
         def verify_array(obj):
-            if verify_nullability(obj):
-                return None
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             for i in obj:
                 element_verifier(i)
-        return verify_array
+
+        verify_value = verify_array
 
     elif isinstance(dataType, MapType):
         key_verifier = _make_type_verifier(dataType.keyType, False, name="key of map %s" % name)
@@ -1394,14 +1386,13 @@ def verify_array(obj):
             dataType.valueType, dataType.valueContainsNull, name="value of map %s" % name)
 
         def verify_map(obj):
-            if verify_nullability(obj):
-                return None
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             for k, v in obj.items():
                 key_verifier(k)
                 value_verifier(v)
-        return verify_map
+
+        verify_value = verify_map
 
     elif isinstance(dataType, StructType):
 
         verifiers = []
@@ -1410,8 +1401,6 @@ def verify_map(obj):
             verifiers.append((f.name, verifier))
 
         def verify_struct(obj):
-            if verify_nullability(obj):
-                return None
             assert_acceptable_types(obj)
 
             if isinstance(obj, dict):
@@ -1435,15 +1424,20 @@ def verify_struct(obj):
             else:
                 raise TypeError(new_msg("StructType can not accept object %r in type %s"
                                         % (obj, type(obj))))
-        return verify_struct
+        verify_value = verify_struct
 
     else:
         def verify_default(obj):
-            if verify_nullability(obj):
-                return None
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
-        return verify_default
+
+        verify_value = verify_default
+
+    def verify(obj):
+        if not verify_nullability(obj):
+            verify_value(obj)
+
+    return verify
 
 
 # This is used to unpickle a Row from JVM

From 5b80a8b92273e9abf6ce8b28dcd70fbb32d4613c Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 4 Jul 2017 13:48:32 +0900
Subject: [PATCH 05/11] Fix comments too accordingly

---
 python/pyspark/sql/types.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 56d9bb8fda0d2..f5505ed4722ad 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1251,11 +1251,12 @@ def _infer_schema_type(obj, dataType):
 
 def _make_type_verifier(dataType, nullable=True, name=None):
     """
-    Verify the type of obj against dataType, raise a TypeError if they do not match.
+    Make a verifier that checks the type of obj against dataType and raises a TypeError if they do
+    not match.
 
-    Also verify the value of obj against datatype, raise a ValueError if it's not within the allowed
-    range, e.g. using 128 as ByteType will overflow. Note that, Python float is not checked, so it
-    will become infinity when cast to Java float if it overflows.
+    This verifier also checks the value of obj against datatype and raises a ValueError if it's not
+    within the allowed range, e.g. using 128 as ByteType will overflow. Note that, Python float is
+    not checked, so it will become infinity when cast to Java float if it overflows.
 
     >>> _make_type_verifier(StructType([]))(None)
     >>> _make_type_verifier(StringType())("")

From 8765a1a52f110e961345a1bc34520e980d137bc1 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 4 Jul 2017 15:59:56 +0900
Subject: [PATCH 06/11] Shorter/cleaner and reorganise some tests for
 readability

---
 python/pyspark/sql/tests.py | 196 +++++++++++++++++++-----------------
 1 file changed, 106 insertions(+), 90 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 59cae81609eea..ab56eb4886192 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2639,40 +2639,26 @@
 class TypesTest(unittest.TestCase):
 
     def test_verify_type_exception_msg(self):
-        def verify_type(obj, dataType, nullable=True, name=None):
-            _make_type_verifier(dataType, nullable, name)(obj)
-
-        name = "test_name"
-        try:
-            verify_type(None, StringType(), nullable=False, name=name)
-            self.fail('Expected verify_type() to throw so test can check exception message')
-        except Exception as e:
-            self.assertTrue(str(e).startswith(name))
+        msg = "Expected verify_type() to throw so test can check exception message."
+        with self.assertRaises(Exception, msg=msg) as cm:
+            _make_type_verifier(StringType(), nullable=False, name="test_name")(None)
+        self.assertTrue(str(cm.exception).startswith("test_name"))
 
         schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
-        try:
-            verify_type([["data"]], schema)
-            self.fail('Expected verify_type() to throw so test can check exception message')
-        except Exception as e:
-            self.assertTrue(str(e).startswith("field b in field a"))
+        with self.assertRaises(Exception, msg=msg) as cm:
+            _make_type_verifier(schema)([["data"]])
+        self.assertTrue(str(cm.exception).startswith("field b in field a"))
 
     def test_verify_type_ok_nullable(self):
-        def verify_type(obj, dataType, nullable=True, name=None):
-            _make_type_verifier(dataType, nullable, name)(obj)
-
         obj = None
-        for data_type in [IntegerType(), FloatType(), StringType(), StructType([])]:
-            msg = "verify_type(%s, %s, nullable=True)" % (obj, data_type)
+        types = [IntegerType(), FloatType(), StringType(), StructType([])]
+        for data_type in types:
             try:
-                verify_type(obj, data_type, nullable=True)
-            except Exception as e:
-                traceback.print_exc()
-                self.fail(msg)
+                _make_type_verifier(data_type, nullable=True)(obj)
+            except Exception:
+                self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type))
 
     def test_verify_type_not_nullable(self):
-        def verify_type(obj, dataType, nullable=True, name=None):
-            _make_type_verifier(dataType, nullable, name)(obj)
-
         import array
         import datetime
         import decimal
@@ -2682,129 +2668,159 @@ def verify_type(obj, dataType, nullable=True, name=None):
             StructField('i', IntegerType(), nullable=True)])
 
         class MyObj:
-            def __init__(self, **ka):
-                for k, v in ka.items():
+            def __init__(self, **kwargs):
+                for k, v in kwargs.items():
                     setattr(self, k, v)
 
-        # obj, data_type, exception (None for success or Exception subclass for error)
-        spec = [
-            # Strings (match anything but None)
-            ("", StringType(), None),
-            (u"", StringType(), None),
-            (1, StringType(), None),
-            (1.0, StringType(), None),
-            ([], StringType(), None),
-            ({}, StringType(), None),
-            (None, StringType(), ValueError),  # Only None test
+        # obj, data_type
+        success_spec = [
+            # String
+            ("", StringType()),
+            (u"", StringType()),
+            (1, StringType()),
+            (1.0, StringType()),
+            ([], StringType()),
+            ({}, StringType()),
+
+            # UDT
+            (ExamplePoint(1.0, 2.0), ExamplePointUDT()),
+
+            # Boolean
+            (True, BooleanType()),
+
+            # Byte
+            (-(2**7), ByteType()),
+            (2**7 - 1, ByteType()),
+
+            # Short
+            (-(2**15), ShortType()),
+            (2**15 - 1, ShortType()),
+
+            # Integer
+            (-(2**31), IntegerType()),
+            (2**31 - 1, IntegerType()),
+
+            # Long
+            (2**64, LongType()),
+
+            # Float & Double
+            (1.0, FloatType()),
+            (1.0, DoubleType()),
+
+            # Decimal
+            (decimal.Decimal("1.0"), DecimalType()),
+
+            # Binary
+            (bytearray([1, 2]), BinaryType()),
+
+            # Date/Timestamp
+            (datetime.date(2000, 1, 2), DateType()),
+            (datetime.datetime(2000, 1, 2, 3, 4), DateType()),
+            (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),
+
+            # Array
+            ([], ArrayType(IntegerType())),
+            (["1", None], ArrayType(StringType(), containsNull=True)),
+            ([1, 2], ArrayType(IntegerType())),
+            ((1, 2), ArrayType(IntegerType())),
+            (array.array('h', [1, 2]), ArrayType(IntegerType())),
+
+            # Map
+            ({}, MapType(StringType(), IntegerType())),
+            ({"a": 1}, MapType(StringType(), IntegerType())),
+            ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)),
+
+            # Struct
+            ({"s": "a", "i": 1}, schema),
+            ({"s": "a", "i": None}, schema),
+            ({"s": "a"}, schema),
+            ({"s": "a", "f": 1.0}, schema),
+            (Row(s="a", i=1), schema),
+            (Row(s="a", i=None), schema),
+            (Row(s="a", i=1, f=1.0), schema),
+            (["a", 1], schema),
+            (["a", None], schema),
+            (("a", 1), schema),
+            (MyObj(s="a", i=1), schema),
+            (MyObj(s="a", i=None), schema),
+            (MyObj(s="a"), schema),
+        ]
+
+        # obj, data_type, exception class
+        failure_spec = [
+            # String (match anything but None)
+            (None, StringType(), ValueError),
 
             # UDT
-            (ExamplePoint(1.0, 2.0), ExamplePointUDT(), None),
             (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
 
             # Boolean
-            (True, BooleanType(), None),
             (1, BooleanType(), TypeError),
             ("True", BooleanType(), TypeError),
             ([1], BooleanType(), TypeError),
 
-            # Bytes
+            # Byte
             (-(2**7) - 1, ByteType(), ValueError),
-            (-(2**7), ByteType(), None),
-            (2**7 - 1, ByteType(), None),
             (2**7, ByteType(), ValueError),
             ("1", ByteType(), TypeError),
             (1.0, ByteType(), TypeError),
 
-            # Shorts
+            # Short
             (-(2**15) - 1, ShortType(), ValueError),
-            (-(2**15), ShortType(), None),
-            (2**15 - 1, ShortType(), None),
             (2**15, ShortType(), ValueError),
 
             # Integer
             (-(2**31) - 1, IntegerType(), ValueError),
-            (-(2**31), IntegerType(), None),
-            (2**31 - 1, IntegerType(), None),
             (2**31, IntegerType(), ValueError),
 
-            # Long
-            (2**64, LongType(), None),
-
             # Float & Double
-            (1.0, FloatType(), None),
             (1, FloatType(), TypeError),
-            (1.0, DoubleType(), None),
             (1, DoubleType(), TypeError),
 
             # Decimal
-            (decimal.Decimal("1.0"), DecimalType(), None),
             (1.0, DecimalType(), TypeError),
             (1, DecimalType(), TypeError),
             ("1.0", DecimalType(), TypeError),
 
             # Binary
-            (bytearray([1, 2]), BinaryType(), None),
             (1, BinaryType(), TypeError),
 
-            # Date/Time
-            (datetime.date(2000, 1, 2), DateType(), None),
-            (datetime.datetime(2000, 1, 2, 3, 4), DateType(), None),
+            # Date/Timestamp
             ("2000-01-02", DateType(), TypeError),
-            (datetime.datetime(2000, 1, 2, 3, 4), TimestampType(), None),
             (946811040, TimestampType(), TypeError),
 
             # Array
-            ([], ArrayType(IntegerType()), None),
-            (["1", None], ArrayType(StringType(), containsNull=True), None),
             (["1", None], ArrayType(StringType(), containsNull=False), ValueError),
-            ([1, 2], ArrayType(IntegerType()), None),
             ([1, "2"], ArrayType(IntegerType()), TypeError),
-            ((1, 2), ArrayType(IntegerType()), None),
-            (array.array('h', [1, 2]), ArrayType(IntegerType()), None),
 
             # Map
-            ({}, MapType(StringType(), IntegerType()), None),
-            ({"a": 1}, MapType(StringType(), IntegerType()), None),
             ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError),
             ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError),
-            ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True), None),
             ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False),
              ValueError),
 
             # Struct
-            ({"s": "a", "i": 1}, schema, None),
-            ({"s": "a", "i": None}, schema, None),
-            ({"s": "a"}, schema, None),
-            ({"s": "a", "f": 1.0}, schema, None),  # Extra fields OK
             ({"s": "a", "i": "1"}, schema, TypeError),
-            (Row(s="a", i=1), schema, None),
-            (Row(s="a", i=None), schema, None),
-            (Row(s="a", i=1, f=1.0), schema, None),  # Extra fields OK
             (Row(s="a"), schema, ValueError),  # Row can't have missing field
             (Row(s="a", i="1"), schema, TypeError),
-            (["a", 1], schema, None),
-            (["a", None], schema, None),
             (["a"], schema, ValueError),
             (["a", "1"], schema, TypeError),
-            (("a", 1), schema, None),
-            (MyObj(s="a", i=1), schema, None),
-            (MyObj(s="a", i=None), schema, None),
-            (MyObj(s="a"), schema, None),
             (MyObj(s="a", i="1"), schema, TypeError),
             (MyObj(s=None, i="1"), schema, ValueError),
         ]
 
-        for obj, data_type, exp in spec:
+        # Check success cases
+        for obj, data_type in success_spec:
+            try:
+                _make_type_verifier(data_type, nullable=False)(obj)
+            except Exception:
+                self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type))
+
+        # Check failure cases
+        for obj, data_type, exp in failure_spec:
             msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
-            if exp is None:
-                try:
-                    verify_type(obj, data_type, nullable=False)
-                except Exception:
-                    traceback.print_exc()
-                    self.fail(msg)
-            else:
-                with self.assertRaises(exp, msg=msg):
-                    verify_type(obj, data_type, nullable=False)
+            with self.assertRaises(exp, msg=msg):
+                _make_type_verifier(data_type, nullable=False)(obj)
 
 
 if __name__ == "__main__":

From 9ee8d03a951f06a7426f8e924a084d20b25e26d1 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 4 Jul 2017 16:01:55 +0900
Subject: [PATCH 07/11] TypesTests -> DataTypeVerificationTests

---
 python/pyspark/sql/tests.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ab56eb4886192..fdb6c3df9cfab 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2636,7 +2636,7 @@ def range_frame_match():
         importlib.reload(window)
 
 
-class TypesTest(unittest.TestCase):
+class DataTypeVerificationTests(unittest.TestCase):
 
     def test_verify_type_exception_msg(self):
         msg = "Expected verify_type() to throw so test can check exception message."

From 420b4bfbe2e0e5f0125df119dd1aed29cfab499a Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 4 Jul 2017 16:23:39 +0900
Subject: [PATCH 08/11] Don't forget to check the exception from context
 manager

---
 python/pyspark/sql/tests.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index fdb6c3df9cfab..e3cbf3af6244e 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2642,12 +2642,12 @@ def test_verify_type_exception_msg(self):
         msg = "Expected verify_type() to throw so test can check exception message."
         with self.assertRaises(Exception, msg=msg) as cm:
             _make_type_verifier(StringType(), nullable=False, name="test_name")(None)
-            self.assertTrue(str(cm.exception).startswith("test_name"))
+        self.assertTrue(str(cm.exception).startswith("test_name"))
 
         schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
         with self.assertRaises(Exception, msg=msg) as cm:
             _make_type_verifier(schema)([["data"]])
-            self.assertTrue(str(cm.exception).startswith("field b in field a"))
+        self.assertTrue(str(cm.exception).startswith("field b in field a"))

From 15c575fea9b0adaa58ef1b4c553c9d5bb11990d3 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 4 Jul 2017 16:59:03 +0900
Subject: [PATCH 09/11] Specify the exception type and make exception message
 tests simpler

---
 python/pyspark/sql/tests.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index e3cbf3af6244e..edaf3383ebd32 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2639,15 +2639,16 @@ def range_frame_match():
 class DataTypeVerificationTests(unittest.TestCase):
 
     def test_verify_type_exception_msg(self):
-        msg = "Expected verify_type() to throw so test can check exception message."
-        with self.assertRaises(Exception, msg=msg) as cm:
-            _make_type_verifier(StringType(), nullable=False, name="test_name")(None)
-        self.assertTrue(str(cm.exception).startswith("test_name"))
+        self.assertRaisesRegexp(
+            ValueError,
+            "test_name",
+            lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None))
 
         schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
-        with self.assertRaises(Exception, msg=msg) as cm:
-            _make_type_verifier(schema)([["data"]])
-        self.assertTrue(str(cm.exception).startswith("field b in field a"))
+        self.assertRaisesRegexp(
+            TypeError,
+            "field b in field a",
+            lambda: _make_type_verifier(schema)([["data"]]))

From c49098e5e25f83af7257e8a659b2307eb150d737 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 4 Jul 2017 17:19:41 +0900
Subject: [PATCH 10/11] Revert unrelated changes

---
 python/pyspark/sql/tests.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index edaf3383ebd32..8f39ed3e8bae8 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -40,9 +40,6 @@
         sys.exit(1)
 else:
     import unittest
-    if sys.version_info[0] >= 3:
-        xrange = range
-        basestring = str
 
 import py4j
 try:

From 826dcfd43d6da28ca5b9bfa2d719d6c9ba090e5a Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 4 Jul 2017 17:28:36 +0900
Subject: [PATCH 11/11] Get rid of useless imports

---
 python/pyspark/sql/tests.py | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 8f39ed3e8bae8..d89b1015339a9 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -30,16 +30,6 @@
 import functools
 import time
 import datetime
-import traceback
-
-if sys.version_info[:2] <= (2, 6):
-    try:
-        import unittest2 as unittest
-    except ImportError:
-        sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
-        sys.exit(1)
-else:
-    import unittest
 
 import py4j
 try:
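
A note on the resulting design: after patch 04, _make_type_verifier(dataType, nullable, name)
performs all type dispatch once per schema and returns a verify closure that is applied per
value, so per-row work reduces to a nullability check plus the pre-selected type check. The
standalone sketch below illustrates that pattern under simplified assumptions; IntType,
ListType, and make_verifier are hypothetical stand-ins, not the pyspark implementation.

    # Illustrative sketch of the verifier-factory pattern. IntType and ListType
    # are hypothetical stand-ins for pyspark.sql.types classes.
    class IntType(object):
        """Stand-in for a 32-bit integer type."""

    class ListType(object):
        """Stand-in for an array type with per-element nullability."""
        def __init__(self, element_type, contains_null=True):
            self.element_type = element_type
            self.contains_null = contains_null

    def make_verifier(data_type, nullable=True, name=None):
        # The message prefix and the per-type check are built once, not per value.
        new_msg = (lambda m: m) if name is None else (lambda m: "%s: %s" % (name, m))

        if isinstance(data_type, IntType):
            def verify_value(obj):
                if not isinstance(obj, int):
                    raise TypeError(new_msg("IntType can not accept object %r" % (obj,)))
                if obj < -2147483648 or obj > 2147483647:
                    raise ValueError(new_msg("object of IntType out of range, got: %s" % obj))
        elif isinstance(data_type, ListType):
            # Recurse once; the element verifier closure is reused for every element.
            element_verifier = make_verifier(
                data_type.element_type, data_type.contains_null,
                name="element in array %s" % name)

            def verify_value(obj):
                for element in obj:
                    element_verifier(element)
        else:
            raise AssertionError("unknown datatype: %s" % data_type)

        def verify(obj):
            if obj is None:
                if not nullable:
                    raise ValueError(new_msg("This field is not nullable, but got None"))
            else:
                verify_value(obj)
        return verify

    # Dispatch happens once; the returned closure runs per value, and failures
    # carry the field path in the message.
    verifier = make_verifier(ListType(IntType(), contains_null=False), name="field value")
    verifier([1, 2, 3])  # passes
    try:
        verifier([1, None])
    except ValueError as e:
        print(e)  # element in array field value: This field is not nullable, but got None

This is also why the tests switch from calling _verify_type(obj, data_type, ...) to
_make_type_verifier(data_type, ...)(obj): construction cost is paid once per schema rather
than once per row.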