From 9c0d20ca37b119f036e495c055513e7c5bb2471c Mon Sep 17 00:00:00 2001 From: David Gingrich Date: Tue, 28 Feb 2017 00:05:00 -0800 Subject: [PATCH 1/6] Remove "# noqa" comment from docstring --- python/pyspark/rdd.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a5e6e2b054963..a9d12119b892b 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -627,7 +627,6 @@ def sortPartition(iterator): def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): """ Sorts this RDD, which is assumed to consist of (key, value) pairs. - # noqa >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] >>> sc.parallelize(tmp).sortByKey().first() From 34dbb78628e261dcc5088b1b978d52d5c2b64ba6 Mon Sep 17 00:00:00 2001 From: David Gingrich Date: Tue, 28 Feb 2017 00:09:59 -0800 Subject: [PATCH 2/6] WIP: Add name parameter and better debugging to _verify_types * Add name paramter to _verify_types * Include name parameter in debug messages * Build name message for nested structs, arrays, and maps * Add detailed tests to flesh out spec for _verify_types (WIP) --- python/pyspark/sql/tests.py | 135 +++++++++++++++++++++++++++++++++++- python/pyspark/sql/types.py | 50 +++++++------ 2 files changed, 164 insertions(+), 21 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f0a9a0400e392..9bfcac0d2f94f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -31,6 +31,18 @@ import time import datetime +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + if sys.version_info[0] >= 3: + xrange = range + basestring = str + import py4j try: import xmlrunner @@ -49,7 +61,7 @@ from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * -from pyspark.sql.types import UserDefinedType, _infer_type +from pyspark.sql.types import UserDefinedType, _infer_type, _verify_type from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window @@ -2367,6 +2379,127 @@ def range_frame_match(): importlib.reload(window) + +class TypesTest(unittest.TestCase): + + def test_verify_type_ok_nullable(self): + for obj, data_type in [ + (None, IntegerType()), + (None, FloatType()), + (None, StringType()), + (None, StructType([]))]: + _verify_type(obj, data_type, nullable=True) + msg = "_verify_type(%s, %s, nullable=True)" % (obj, data_type) + self.assertIsTrue(True, msg) + + def test_verify_type_not_nullable(self): + import array + import datetime + import decimal + + MyStructType = StructType([ + StuctField('s', StringType(), nullable=False), + StructField('i', IntegerType(), nullable=True)]) + + class MyDictObj: + def __init__(self, s, i): + self.__dict__ = {'s': s, 'i': i} + + # exp: None for success, Exception subclass for error + for obj, data_type, exp in [ + # Strings (match anything but None) + ("", StringType(), None), + (1, StringType(), None), + (1.0, StringType(), None), + ([], StringType(), None), + ({}, StringType(), None), + (None, StringType(), ValueError), # Only None test + + # UDT + (ExamplePoint(1.0, 2.0), ExamplePointUDT(), None), + (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError), + + # Boolean + (True, BooleanType(), None), + (1, BooleanType(), TypeError), + ("True", BooleanType(), TypeError), + ([1], BooleanType(), TypeError), + + # Bytes + (-(2**7) - 1, ByteType(), ValueError) + (-(2**7), ByteType(), None), + (2**7 - 1 , ByteType(), None), + (2**7, ByteType(), ValueError) + ("1", ByteType(), TypeError), + (1.0, ByteType(), TypeError), + + # Shorts + (-(2**15) - 1, ShortType(), ValueError), + (-(2**15), ShortType(), None), + (2**17 - 1, ShortType(), None), + (2**17, ShortType(), ValueError), + + # Integer + (-(2**31) - 1, IntegerType(), ValueError), + (-(2**31), IntegerType(), None), + (2**31 - 1, IntegerType(), None), + (2**31, IntegerType(), ValueError), + + # Long + (2**64, LongType(), None), + + # Float & Double + (1.0, FloatType(), None), + (1, FloatType(), TypeError), + (1.0, DoubleType(), None), + (1, DoubleType(), TypeError), + + # Decimal + (decimal.Decimal("1.0"), DecimalType(), None), + + # TODO: Finish tests + + # String + ("string", StringType(), None), + (u"unicode", UnicodeType(), None), + + # Binary + (bytearray([1, 2]), BinaryType(), None), + + # Date/Time + (datetime.date(2000, 1, 2), DateType(), None), + (datetime.datetime(2000, 1, 2, 3, 4), DateType(), None), + (datetime.datetime(2000, 1, 2, 3, 4), TimestampType(), None), + + # Array + ([], ArrayType(IntegerType()), None), + (["1", None], ArrayType(StringType(), containsNull=True), None), + ([1, 2], ArrayType(IntegerType()), None), + ((1, 2), ArrayType(IntegerType()), None), + (array.array('h', [1, 2]), ArrayType(IntegerType()), None), + + # Map + ({}, MapType(StringType(), InetgerType()), None), + ({"a": 1}, MapType(StringType(), IntegerType()), None), + ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True), None), + + # Struct + ({'s': 'a', 'i': 1}, MyStructType, None), + ({'s': 'a'}, MyStructType, None), + (['a', 1], MyStructType, None), + (('a', 1), MyStructType, None), + (Row(s=a, i=1), MyStructType, None), + (MyDictObj(s=1, i=1), MyStructType, None), + ]: + msg = "_verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp) + if exp is None: + _verify_type(obj, data_type, nullable=False) + self.assertIsTrue(True, msg) + else: + with self.assertRaises(exp, msg): + _verify_type(obj, data_type, nullable=False) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 26b54a7fb3709..4cd0075954596 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1249,7 +1249,7 @@ def _infer_schema_type(obj, dataType): } -def _verify_type(obj, dataType, nullable=True): +def _verify_type(obj, dataType, nullable=True, name="obj"): """ Verify the type of obj against dataType, raise a TypeError if they do not match. @@ -1300,7 +1300,7 @@ def _verify_type(obj, dataType, nullable=True): if nullable: return else: - raise ValueError("This field is not nullable, but got None") + raise ValueError("%s: This field is not nullable, but got None" % name) # StringType can work with any types if isinstance(dataType, StringType): @@ -1308,12 +1308,13 @@ def _verify_type(obj, dataType, nullable=True): if isinstance(dataType, UserDefinedType): if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): - raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.toInternal(obj), dataType.sqlType()) + raise ValueError("%s: %r is not an instance of type %r" % (name, obj, dataType)) + _verify_type(dataType.toInternal(obj), dataType.sqlType(), name=name) return _type = type(dataType) - assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj) + assert _type in _acceptable_types, \ + "%s: unknown datatype: %s for object %r" % (name, dataType, obj) if _type is StructType: # check the type and fields later @@ -1321,49 +1322,58 @@ def _verify_type(obj, dataType, nullable=True): else: # subclass of them can not be fromInternal in JVM if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj))) + raise TypeError("%s: %s can not accept object %r in type %s" + % (name, dataType, obj, type(obj))) if isinstance(dataType, ByteType): if obj < -128 or obj > 127: - raise ValueError("object of ByteType out of range, got: %s" % obj) + raise ValueError("%s: object of ByteType out of range, got: %s" % (name, obj)) elif isinstance(dataType, ShortType): if obj < -32768 or obj > 32767: - raise ValueError("object of ShortType out of range, got: %s" % obj) + raise ValueError("%s: object of ShortType out of range, got: %s" % (name, obj)) elif isinstance(dataType, IntegerType): if obj < -2147483648 or obj > 2147483647: - raise ValueError("object of IntegerType out of range, got: %s" % obj) + raise ValueError("%s: object of IntegerType out of range, got: %s" % (name, obj)) elif isinstance(dataType, ArrayType): - for i in obj: - _verify_type(i, dataType.elementType, dataType.containsNull) + for i, value in enumerate(obj): + new_name = "%s[%d]" % (name, i) + _verify_type(value, dataType.elementType, dataType.containsNull, name=new_name) elif isinstance(dataType, MapType): for k, v in obj.items(): - _verify_type(k, dataType.keyType, False) - _verify_type(v, dataType.valueType, dataType.valueContainsNull) + new_name = "%s[%s](key)" % (name, k) + _verify_type(k, dataType.keyType, False, name=new_name) + new_name = "%s[%s]" % (name, k) + _verify_type(v, dataType.valueType, dataType.valueContainsNull, name=new_name) elif isinstance(dataType, StructType): if isinstance(obj, dict): for f in dataType.fields: - _verify_type(obj.get(f.name), f.dataType, f.nullable) + new_name = "%s.%s" % (name, f.name) + _verify_type(obj.get(f.name), f.dataType, f.nullable, name=new_name) elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): # the order in obj could be different than dataType.fields for f in dataType.fields: - _verify_type(obj[f.name], f.dataType, f.nullable) + new_name = "%s.%s" % (name, f.name) + _verify_type(obj[f.name], f.dataType, f.nullable, name=new_name) elif isinstance(obj, (tuple, list)): if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with " - "length of fields (%d)" % (len(obj), len(dataType.fields))) + raise ValueError("%s: Length of object (%d) does not match with " + "length of fields (%d)" % (name, len(obj), len(dataType.fields))) for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType, f.nullable) + new_name = "%s.%s" % (name, f.name) + _verify_type(v, f.dataType, f.nullable, name=new_name) elif hasattr(obj, "__dict__"): d = obj.__dict__ for f in dataType.fields: - _verify_type(d.get(f.name), f.dataType, f.nullable) + new_name = "%s.%s" % (name, f.name) + _verify_type(d.get(f.name), f.dataType, f.nullable, name=new_name) else: - raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj))) + raise TypeError("%s: StructType can not accept object %r in type %s" + % (name, obj, type(obj))) # This is used to unpickle a Row from JVM From fcd2067c97e60a72653cd34905727e827c95ccbd Mon Sep 17 00:00:00 2001 From: David Gingrich Date: Sun, 12 Mar 2017 23:49:10 -0700 Subject: [PATCH 3/6] Finish types._verify_type tests. --- python/pyspark/sql/tests.py | 215 ++++++++++++++++++++---------------- 1 file changed, 120 insertions(+), 95 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9bfcac0d2f94f..488b27f638df9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -30,6 +30,7 @@ import functools import time import datetime +import traceback if sys.version_info[:2] <= (2, 6): try: @@ -2390,7 +2391,7 @@ def test_verify_type_ok_nullable(self): (None, StructType([]))]: _verify_type(obj, data_type, nullable=True) msg = "_verify_type(%s, %s, nullable=True)" % (obj, data_type) - self.assertIsTrue(True, msg) + self.assertTrue(True, msg) def test_verify_type_not_nullable(self): import array @@ -2398,105 +2399,129 @@ def test_verify_type_not_nullable(self): import decimal MyStructType = StructType([ - StuctField('s', StringType(), nullable=False), + StructField('s', StringType(), nullable=False), StructField('i', IntegerType(), nullable=True)]) - class MyDictObj: - def __init__(self, s, i): - self.__dict__ = {'s': s, 'i': i} - - # exp: None for success, Exception subclass for error - for obj, data_type, exp in [ - # Strings (match anything but None) - ("", StringType(), None), - (1, StringType(), None), - (1.0, StringType(), None), - ([], StringType(), None), - ({}, StringType(), None), - (None, StringType(), ValueError), # Only None test - - # UDT - (ExamplePoint(1.0, 2.0), ExamplePointUDT(), None), - (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError), - - # Boolean - (True, BooleanType(), None), - (1, BooleanType(), TypeError), - ("True", BooleanType(), TypeError), - ([1], BooleanType(), TypeError), - - # Bytes - (-(2**7) - 1, ByteType(), ValueError) - (-(2**7), ByteType(), None), - (2**7 - 1 , ByteType(), None), - (2**7, ByteType(), ValueError) - ("1", ByteType(), TypeError), - (1.0, ByteType(), TypeError), - - # Shorts - (-(2**15) - 1, ShortType(), ValueError), - (-(2**15), ShortType(), None), - (2**17 - 1, ShortType(), None), - (2**17, ShortType(), ValueError), - - # Integer - (-(2**31) - 1, IntegerType(), ValueError), - (-(2**31), IntegerType(), None), - (2**31 - 1, IntegerType(), None), - (2**31, IntegerType(), ValueError), - - # Long - (2**64, LongType(), None), - - # Float & Double - (1.0, FloatType(), None), - (1, FloatType(), TypeError), - (1.0, DoubleType(), None), - (1, DoubleType(), TypeError), - - # Decimal - (decimal.Decimal("1.0"), DecimalType(), None), - - # TODO: Finish tests - - # String - ("string", StringType(), None), - (u"unicode", UnicodeType(), None), - - # Binary - (bytearray([1, 2]), BinaryType(), None), - - # Date/Time - (datetime.date(2000, 1, 2), DateType(), None), - (datetime.datetime(2000, 1, 2, 3, 4), DateType(), None), - (datetime.datetime(2000, 1, 2, 3, 4), TimestampType(), None), - - # Array - ([], ArrayType(IntegerType()), None), - (["1", None], ArrayType(StringType(), containsNull=True), None), - ([1, 2], ArrayType(IntegerType()), None), - ((1, 2), ArrayType(IntegerType()), None), - (array.array('h', [1, 2]), ArrayType(IntegerType()), None), - - # Map - ({}, MapType(StringType(), InetgerType()), None), - ({"a": 1}, MapType(StringType(), IntegerType()), None), - ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True), None), - - # Struct - ({'s': 'a', 'i': 1}, MyStructType, None), - ({'s': 'a'}, MyStructType, None), - (['a', 1], MyStructType, None), - (('a', 1), MyStructType, None), - (Row(s=a, i=1), MyStructType, None), - (MyDictObj(s=1, i=1), MyStructType, None), - ]: + class MyObj: + def __init__(self, **ka): + for k, v in ka.items(): + setattr(self, k, v) + + # obj, data_type, exception (None for success or Exception subclass for error) + spec = [ + # Strings (match anything but None) + ("", StringType(), None), + (u"", StringType(), None), + (1, StringType(), None), + (1.0, StringType(), None), + ([], StringType(), None), + ({}, StringType(), None), + (None, StringType(), ValueError), # Only None test + + # UDT + (ExamplePoint(1.0, 2.0), ExamplePointUDT(), None), + (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError), + + # Boolean + (True, BooleanType(), None), + (1, BooleanType(), TypeError), + ("True", BooleanType(), TypeError), + ([1], BooleanType(), TypeError), + + # Bytes + (-(2**7) - 1, ByteType(), ValueError), + (-(2**7), ByteType(), None), + (2**7 - 1, ByteType(), None), + (2**7, ByteType(), ValueError), + ("1", ByteType(), TypeError), + (1.0, ByteType(), TypeError), + + # Shorts + (-(2**15) - 1, ShortType(), ValueError), + (-(2**15), ShortType(), None), + (2**15 - 1, ShortType(), None), + (2**15, ShortType(), ValueError), + + # Integer + (-(2**31) - 1, IntegerType(), ValueError), + (-(2**31), IntegerType(), None), + (2**31 - 1, IntegerType(), None), + (2**31, IntegerType(), ValueError), + + # Long + (2**64, LongType(), None), + + # Float & Double + (1.0, FloatType(), None), + (1, FloatType(), TypeError), + (1.0, DoubleType(), None), + (1, DoubleType(), TypeError), + + # Decimal + (decimal.Decimal("1.0"), DecimalType(), None), + (1.0, DecimalType(), TypeError), + (1, DecimalType(), TypeError), + ("1.0", DecimalType(), TypeError), + + # Binary + (bytearray([1, 2]), BinaryType(), None), + (1, BinaryType(), TypeError), + + # Date/Time + (datetime.date(2000, 1, 2), DateType(), None), + (datetime.datetime(2000, 1, 2, 3, 4), DateType(), None), + ("2000-01-02", DateType(), TypeError), + (datetime.datetime(2000, 1, 2, 3, 4), TimestampType(), None), + (946811040, TimestampType(), TypeError), + + # Array + ([], ArrayType(IntegerType()), None), + (["1", None], ArrayType(StringType(), containsNull=True), None), + ([1, 2], ArrayType(IntegerType()), None), + ([1, "2"], ArrayType(IntegerType()), TypeError), + ((1, 2), ArrayType(IntegerType()), None), + (array.array('h', [1, 2]), ArrayType(IntegerType()), None), + + # Map + ({}, MapType(StringType(), IntegerType()), None), + ({"a": 1}, MapType(StringType(), IntegerType()), None), + ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError), + ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError), + ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True), None), + + # Struct + ({"s": "a", "i": 1}, MyStructType, None), + ({"s": "a", "i": None}, MyStructType, None), + ({"s": "a"}, MyStructType, None), + ({"s": "a", "f": 1.0}, MyStructType, None), # Extra fields OK + ({"s": "a", "i": "1"}, MyStructType, TypeError), + (Row(s="a", i=1), MyStructType, None), + (Row(s="a", i=None), MyStructType, None), + (Row(s="a", i=1, f=1.0), MyStructType, None), # Extra fields OK + (Row(s="a"), MyStructType, ValueError), # Row can't have missing field + (Row(s="a", i="1"), MyStructType, TypeError), + (["a", 1], MyStructType, None), + (["a", None], MyStructType, None), + (["a"], MyStructType, ValueError), + (["a", "1"], MyStructType, TypeError), + (("a", 1), MyStructType, None), + (MyObj(s="a", i=1), MyStructType, None), + (MyObj(s="a", i=None), MyStructType, None), + (MyObj(s="a"), MyStructType, None), + (MyObj(s="a", i="1"), MyStructType, TypeError), + ] + + for obj, data_type, exp in spec: msg = "_verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp) if exp is None: - _verify_type(obj, data_type, nullable=False) - self.assertIsTrue(True, msg) + try: + _verify_type(obj, data_type, nullable=False) + except Exception as e: + traceback.print_exc() + self.fail(msg) + self.assertTrue(True, msg) else: - with self.assertRaises(exp, msg): + with self.assertRaises(exp, msg=msg): _verify_type(obj, data_type, nullable=False) From 5b4324e966a71fff787625b60001566f2895ac6a Mon Sep 17 00:00:00 2001 From: David Gingrich Date: Tue, 20 Jun 2017 14:47:21 -0700 Subject: [PATCH 4/6] Cleanup (PR feedback) * Change "test_verify_type_ok_nullable" to match "test_verify_type_not_nullable" * Add more null value tests * Remove unused "self.assertTrue(true)" lines --- python/pyspark/sql/tests.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 488b27f638df9..c26eea4c0c9a8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2389,9 +2389,12 @@ def test_verify_type_ok_nullable(self): (None, FloatType()), (None, StringType()), (None, StructType([]))]: - _verify_type(obj, data_type, nullable=True) msg = "_verify_type(%s, %s, nullable=True)" % (obj, data_type) - self.assertTrue(True, msg) + try: + _verify_type(obj, data_type, nullable=True) + except Exception as e: + traceback.print_exc() + self.fail(msg) def test_verify_type_not_nullable(self): import array @@ -2477,6 +2480,7 @@ def __init__(self, **ka): # Array ([], ArrayType(IntegerType()), None), (["1", None], ArrayType(StringType(), containsNull=True), None), + (["1", None], ArrayType(StringType(), containsNull=False), ValueError), ([1, 2], ArrayType(IntegerType()), None), ([1, "2"], ArrayType(IntegerType()), TypeError), ((1, 2), ArrayType(IntegerType()), None), @@ -2488,6 +2492,8 @@ def __init__(self, **ka): ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError), ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError), ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True), None), + ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False), + ValueError), # Struct ({"s": "a", "i": 1}, MyStructType, None), @@ -2509,6 +2515,7 @@ def __init__(self, **ka): (MyObj(s="a", i=None), MyStructType, None), (MyObj(s="a"), MyStructType, None), (MyObj(s="a", i="1"), MyStructType, TypeError), + (MyObj(s=None, i="1"), MyStructType, ValueError), ] for obj, data_type, exp in spec: @@ -2519,7 +2526,6 @@ def __init__(self, **ka): except Exception as e: traceback.print_exc() self.fail(msg) - self.assertTrue(True, msg) else: with self.assertRaises(exp, msg=msg): _verify_type(obj, data_type, nullable=False) From 23511537966577bd6d2b4bbe9dd898faa0e72e97 Mon Sep 17 00:00:00 2001 From: David Gingrich Date: Thu, 22 Jun 2017 13:36:13 -0700 Subject: [PATCH 5/6] Cleanup (PR feedback) * Test that _verify_type exceptions start with passed name * Remove unused Exception variable --- python/pyspark/sql/tests.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c26eea4c0c9a8..52a897dd61a85 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2383,6 +2383,14 @@ def range_frame_match(): class TypesTest(unittest.TestCase): + def test_verify_type_exception_msg(self): + name = "test_name" + try: + _verify_type(None, StringType(), nullable=False, name=name) + self.fail('Expected _verify_type() to throw so test can check exception message') + except Exception as e: + self.assertTrue(str(e).startswith(name)) + def test_verify_type_ok_nullable(self): for obj, data_type in [ (None, IntegerType()), @@ -2523,7 +2531,7 @@ def __init__(self, **ka): if exp is None: try: _verify_type(obj, data_type, nullable=False) - except Exception as e: + except Exception: traceback.print_exc() self.fail(msg) else: From 6c1e0b690bdd1914b5056c8b2934614534c622cb Mon Sep 17 00:00:00 2001 From: David Gingrich Date: Mon, 26 Jun 2017 15:14:39 -0700 Subject: [PATCH 6/6] Cleanup (PR feedback) * Set name=value when calling _verify_type from createDataFrame for a DataType * Remove Nones from testing tuples in test_verify_type_ok_nullable --- python/pyspark/sql/session.py | 2 +- python/pyspark/sql/tests.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 9f4772eec9f2a..fa9e28e2829ef 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -513,7 +513,7 @@ def prepare(obj): schema = StructType().add("value", schema) def prepare(obj): - verify_func(obj, dataType) + verify_func(obj, dataType, name='value') return obj, else: if isinstance(schema, list): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 52a897dd61a85..2bd7ab1b87134 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2392,11 +2392,8 @@ def test_verify_type_exception_msg(self): self.assertTrue(str(e).startswith(name)) def test_verify_type_ok_nullable(self): - for obj, data_type in [ - (None, IntegerType()), - (None, FloatType()), - (None, StringType()), - (None, StructType([]))]: + obj = None + for data_type in [IntegerType(), FloatType(), StringType(), StructType([])]: msg = "_verify_type(%s, %s, nullable=True)" % (obj, data_type) try: _verify_type(obj, data_type, nullable=True)