diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a5e6e2b054963..a9d12119b892b 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -627,7 +627,6 @@ def sortPartition(iterator): def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): """ Sorts this RDD, which is assumed to consist of (key, value) pairs. - # noqa >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] >>> sc.parallelize(tmp).sortByKey().first() diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 9f4772eec9f2a..fa9e28e2829ef 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -513,7 +513,7 @@ def prepare(obj): schema = StructType().add("value", schema) def prepare(obj): - verify_func(obj, dataType) + verify_func(obj, dataType, name='value') return obj, else: if isinstance(schema, list): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f0a9a0400e392..2bd7ab1b87134 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -30,6 +30,19 @@ import functools import time import datetime +import traceback + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + if sys.version_info[0] >= 3: + xrange = range + basestring = str import py4j try: @@ -49,7 +62,7 @@ from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * -from pyspark.sql.types import UserDefinedType, _infer_type +from pyspark.sql.types import UserDefinedType, _infer_type, _verify_type from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window @@ -2367,6 +2380,162 @@ def range_frame_match(): importlib.reload(window) + +class TypesTest(unittest.TestCase): + + def test_verify_type_exception_msg(self): + name = "test_name" + try: + _verify_type(None, StringType(), nullable=False, name=name) + self.fail('Expected _verify_type() to throw so test can check exception message') + except Exception as e: + self.assertTrue(str(e).startswith(name)) + + def test_verify_type_ok_nullable(self): + obj = None + for data_type in [IntegerType(), FloatType(), StringType(), StructType([])]: + msg = "_verify_type(%s, %s, nullable=True)" % (obj, data_type) + try: + _verify_type(obj, data_type, nullable=True) + except Exception as e: + traceback.print_exc() + self.fail(msg) + + def test_verify_type_not_nullable(self): + import array + import datetime + import decimal + + MyStructType = StructType([ + StructField('s', StringType(), nullable=False), + StructField('i', IntegerType(), nullable=True)]) + + class MyObj: + def __init__(self, **ka): + for k, v in ka.items(): + setattr(self, k, v) + + # obj, data_type, exception (None for success or Exception subclass for error) + spec = [ + # Strings (match anything but None) + ("", StringType(), None), + (u"", StringType(), None), + (1, StringType(), None), + (1.0, StringType(), None), + ([], StringType(), None), + ({}, StringType(), None), + (None, StringType(), ValueError), # Only None test + + # UDT + (ExamplePoint(1.0, 2.0), ExamplePointUDT(), None), + (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError), + + # Boolean + (True, BooleanType(), None), + (1, BooleanType(), TypeError), + ("True", BooleanType(), TypeError), + ([1], BooleanType(), TypeError), + + # Bytes + (-(2**7) - 1, ByteType(), ValueError), + (-(2**7), ByteType(), None), + (2**7 - 1, ByteType(), None), + (2**7, ByteType(), ValueError), + ("1", ByteType(), TypeError), + (1.0, ByteType(), TypeError), + + # Shorts + (-(2**15) - 1, ShortType(), ValueError), + (-(2**15), ShortType(), None), + (2**15 - 1, ShortType(), None), + (2**15, ShortType(), ValueError), + + # Integer + (-(2**31) - 1, IntegerType(), ValueError), + (-(2**31), IntegerType(), None), + (2**31 - 1, IntegerType(), None), + (2**31, IntegerType(), ValueError), + + # Long + (2**64, LongType(), None), + + # Float & Double + (1.0, FloatType(), None), + (1, FloatType(), TypeError), + (1.0, DoubleType(), None), + (1, DoubleType(), TypeError), + + # Decimal + (decimal.Decimal("1.0"), DecimalType(), None), + (1.0, DecimalType(), TypeError), + (1, DecimalType(), TypeError), + ("1.0", DecimalType(), TypeError), + + # Binary + (bytearray([1, 2]), BinaryType(), None), + (1, BinaryType(), TypeError), + + # Date/Time + (datetime.date(2000, 1, 2), DateType(), None), + (datetime.datetime(2000, 1, 2, 3, 4), DateType(), None), + ("2000-01-02", DateType(), TypeError), + (datetime.datetime(2000, 1, 2, 3, 4), TimestampType(), None), + (946811040, TimestampType(), TypeError), + + # Array + ([], ArrayType(IntegerType()), None), + (["1", None], ArrayType(StringType(), containsNull=True), None), + (["1", None], ArrayType(StringType(), containsNull=False), ValueError), + ([1, 2], ArrayType(IntegerType()), None), + ([1, "2"], ArrayType(IntegerType()), TypeError), + ((1, 2), ArrayType(IntegerType()), None), + (array.array('h', [1, 2]), ArrayType(IntegerType()), None), + + # Map + ({}, MapType(StringType(), IntegerType()), None), + ({"a": 1}, MapType(StringType(), IntegerType()), None), + ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError), + ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError), + ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True), None), + ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False), + ValueError), + + # Struct + ({"s": "a", "i": 1}, MyStructType, None), + ({"s": "a", "i": None}, MyStructType, None), + ({"s": "a"}, MyStructType, None), + ({"s": "a", "f": 1.0}, MyStructType, None), # Extra fields OK + ({"s": "a", "i": "1"}, MyStructType, TypeError), + (Row(s="a", i=1), MyStructType, None), + (Row(s="a", i=None), MyStructType, None), + (Row(s="a", i=1, f=1.0), MyStructType, None), # Extra fields OK + (Row(s="a"), MyStructType, ValueError), # Row can't have missing field + (Row(s="a", i="1"), MyStructType, TypeError), + (["a", 1], MyStructType, None), + (["a", None], MyStructType, None), + (["a"], MyStructType, ValueError), + (["a", "1"], MyStructType, TypeError), + (("a", 1), MyStructType, None), + (MyObj(s="a", i=1), MyStructType, None), + (MyObj(s="a", i=None), MyStructType, None), + (MyObj(s="a"), MyStructType, None), + (MyObj(s="a", i="1"), MyStructType, TypeError), + (MyObj(s=None, i="1"), MyStructType, ValueError), + ] + + for obj, data_type, exp in spec: + msg = "_verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp) + if exp is None: + try: + _verify_type(obj, data_type, nullable=False) + except Exception: + traceback.print_exc() + self.fail(msg) + else: + with self.assertRaises(exp, msg=msg): + _verify_type(obj, data_type, nullable=False) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 26b54a7fb3709..4cd0075954596 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1249,7 +1249,7 @@ def _infer_schema_type(obj, dataType): } -def _verify_type(obj, dataType, nullable=True): +def _verify_type(obj, dataType, nullable=True, name="obj"): """ Verify the type of obj against dataType, raise a TypeError if they do not match. @@ -1300,7 +1300,7 @@ def _verify_type(obj, dataType, nullable=True): if nullable: return else: - raise ValueError("This field is not nullable, but got None") + raise ValueError("%s: This field is not nullable, but got None" % name) # StringType can work with any types if isinstance(dataType, StringType): @@ -1308,12 +1308,13 @@ def _verify_type(obj, dataType, nullable=True): if isinstance(dataType, UserDefinedType): if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): - raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.toInternal(obj), dataType.sqlType()) + raise ValueError("%s: %r is not an instance of type %r" % (name, obj, dataType)) + _verify_type(dataType.toInternal(obj), dataType.sqlType(), name=name) return _type = type(dataType) - assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj) + assert _type in _acceptable_types, \ + "%s: unknown datatype: %s for object %r" % (name, dataType, obj) if _type is StructType: # check the type and fields later @@ -1321,49 +1322,58 @@ def _verify_type(obj, dataType, nullable=True): else: # subclass of them can not be fromInternal in JVM if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj))) + raise TypeError("%s: %s can not accept object %r in type %s" + % (name, dataType, obj, type(obj))) if isinstance(dataType, ByteType): if obj < -128 or obj > 127: - raise ValueError("object of ByteType out of range, got: %s" % obj) + raise ValueError("%s: object of ByteType out of range, got: %s" % (name, obj)) elif isinstance(dataType, ShortType): if obj < -32768 or obj > 32767: - raise ValueError("object of ShortType out of range, got: %s" % obj) + raise ValueError("%s: object of ShortType out of range, got: %s" % (name, obj)) elif isinstance(dataType, IntegerType): if obj < -2147483648 or obj > 2147483647: - raise ValueError("object of IntegerType out of range, got: %s" % obj) + raise ValueError("%s: object of IntegerType out of range, got: %s" % (name, obj)) elif isinstance(dataType, ArrayType): - for i in obj: - _verify_type(i, dataType.elementType, dataType.containsNull) + for i, value in enumerate(obj): + new_name = "%s[%d]" % (name, i) + _verify_type(value, dataType.elementType, dataType.containsNull, name=new_name) elif isinstance(dataType, MapType): for k, v in obj.items(): - _verify_type(k, dataType.keyType, False) - _verify_type(v, dataType.valueType, dataType.valueContainsNull) + new_name = "%s[%s](key)" % (name, k) + _verify_type(k, dataType.keyType, False, name=new_name) + new_name = "%s[%s]" % (name, k) + _verify_type(v, dataType.valueType, dataType.valueContainsNull, name=new_name) elif isinstance(dataType, StructType): if isinstance(obj, dict): for f in dataType.fields: - _verify_type(obj.get(f.name), f.dataType, f.nullable) + new_name = "%s.%s" % (name, f.name) + _verify_type(obj.get(f.name), f.dataType, f.nullable, name=new_name) elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): # the order in obj could be different than dataType.fields for f in dataType.fields: - _verify_type(obj[f.name], f.dataType, f.nullable) + new_name = "%s.%s" % (name, f.name) + _verify_type(obj[f.name], f.dataType, f.nullable, name=new_name) elif isinstance(obj, (tuple, list)): if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with " - "length of fields (%d)" % (len(obj), len(dataType.fields))) + raise ValueError("%s: Length of object (%d) does not match with " + "length of fields (%d)" % (name, len(obj), len(dataType.fields))) for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType, f.nullable) + new_name = "%s.%s" % (name, f.name) + _verify_type(v, f.dataType, f.nullable, name=new_name) elif hasattr(obj, "__dict__"): d = obj.__dict__ for f in dataType.fields: - _verify_type(d.get(f.name), f.dataType, f.nullable) + new_name = "%s.%s" % (name, f.name) + _verify_type(d.get(f.name), f.dataType, f.nullable, name=new_name) else: - raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj))) + raise TypeError("%s: StructType can not accept object %r in type %s" + % (name, obj, type(obj))) # This is used to unpickle a Row from JVM