1 change: 0 additions & 1 deletion python/pyspark/rdd.py
@@ -627,7 +627,6 @@ def sortPartition(iterator):
def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
"""
Sorts this RDD, which is assumed to consist of (key, value) pairs.
# noqa
Member: (I have no idea why this was added in the first place ...)


>>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
>>> sc.parallelize(tmp).sortByKey().first()
2 changes: 1 addition & 1 deletion python/pyspark/sql/session.py
@@ -513,7 +513,7 @@ def prepare(obj):
schema = StructType().add("value", schema)

def prepare(obj):
verify_func(obj, dataType)
verify_func(obj, dataType, name='value')
return obj,
else:
if isinstance(schema, list):
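With name='value' threaded through here, the createDataFrame failure discussed in the types.py review thread below should report the field name. A hypothetical transcript (abbreviated traceback, assuming a running `spark` session):

>>> from pyspark.sql.types import IntegerType
>>> spark.createDataFrame(["a"], IntegerType()).printSchema()
Traceback (most recent call last):
  ...
TypeError: value: IntegerType can not accept object 'a' in type <type 'str'>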
171 changes: 170 additions & 1 deletion python/pyspark/sql/tests.py
@@ -30,6 +30,19 @@
import functools
import time
import datetime
import traceback

if sys.version_info[:2] <= (2, 6):
Member: Not a big deal, but I guess we dropped 2.6 support.

Author: Looks like most of the other tests still have the <= (2, 6) check (see python/pyspark/ml/tests.py), so leaving it in place.

Member: Yeah, let's leave it then. Not a big deal.

try:
import unittest2 as unittest
except ImportError:
sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
sys.exit(1)
else:
import unittest
if sys.version_info[0] >= 3:
xrange = range
basestring = str

import py4j
try:
@@ -49,7 +62,7 @@
from pyspark import SparkContext
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.sql.types import UserDefinedType, _infer_type, _verify_type
from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests
from pyspark.sql.functions import UserDefinedFunction, sha2, lit
from pyspark.sql.window import Window
@@ -2367,6 +2380,162 @@ def range_frame_match():

importlib.reload(window)


class TypesTest(unittest.TestCase):

def test_verify_type_exception_msg(self):
name = "test_name"
try:
_verify_type(None, StringType(), nullable=False, name=name)
self.fail('Expected _verify_type() to throw so test can check exception message')
except Exception as e:
self.assertTrue(str(e).startswith(name))

def test_verify_type_ok_nullable(self):
obj = None
for data_type in [IntegerType(), FloatType(), StringType(), StructType([])]:
msg = "_verify_type(%s, %s, nullable=True)" % (obj, data_type)
try:
_verify_type(obj, data_type, nullable=True)
except Exception as e:
traceback.print_exc()
self.fail(msg)

def test_verify_type_not_nullable(self):
import array
import datetime
import decimal

MyStructType = StructType([
Member: Could we make the first character here lower-case? (Or maybe just call it schema?)
StructField('s', StringType(), nullable=False),
StructField('i', IntegerType(), nullable=True)])

class MyObj:
def __init__(self, **ka):
for k, v in ka.items():
setattr(self, k, v)

# obj, data_type, exception (None for success or Exception subclass for error)
spec = [
# Strings (match anything but None)
("", StringType(), None),
(u"", StringType(), None),
(1, StringType(), None),
(1.0, StringType(), None),
([], StringType(), None),
({}, StringType(), None),
(None, StringType(), ValueError), # Only None test

# UDT
(ExamplePoint(1.0, 2.0), ExamplePointUDT(), None),
(ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),

# Boolean
(True, BooleanType(), None),
(1, BooleanType(), TypeError),
("True", BooleanType(), TypeError),
([1], BooleanType(), TypeError),

# Bytes
(-(2**7) - 1, ByteType(), ValueError),
(-(2**7), ByteType(), None),
(2**7 - 1, ByteType(), None),
(2**7, ByteType(), ValueError),
("1", ByteType(), TypeError),
(1.0, ByteType(), TypeError),

# Shorts
(-(2**15) - 1, ShortType(), ValueError),
(-(2**15), ShortType(), None),
(2**15 - 1, ShortType(), None),
(2**15, ShortType(), ValueError),

# Integer
(-(2**31) - 1, IntegerType(), ValueError),
(-(2**31), IntegerType(), None),
(2**31 - 1, IntegerType(), None),
(2**31, IntegerType(), ValueError),

# Long
(2**64, LongType(), None),

# Float & Double
(1.0, FloatType(), None),
(1, FloatType(), TypeError),
(1.0, DoubleType(), None),
(1, DoubleType(), TypeError),

# Decimal
(decimal.Decimal("1.0"), DecimalType(), None),
(1.0, DecimalType(), TypeError),
(1, DecimalType(), TypeError),
("1.0", DecimalType(), TypeError),

# Binary
(bytearray([1, 2]), BinaryType(), None),
(1, BinaryType(), TypeError),

# Date/Time
(datetime.date(2000, 1, 2), DateType(), None),
(datetime.datetime(2000, 1, 2, 3, 4), DateType(), None),
("2000-01-02", DateType(), TypeError),
(datetime.datetime(2000, 1, 2, 3, 4), TimestampType(), None),
(946811040, TimestampType(), TypeError),

# Array
([], ArrayType(IntegerType()), None),
(["1", None], ArrayType(StringType(), containsNull=True), None),
Member: I'd like you to add a containsNull=False case too, one that contains None in the list, to verify that it raises ValueError correctly.

Author: Added.
(["1", None], ArrayType(StringType(), containsNull=False), ValueError),
([1, 2], ArrayType(IntegerType()), None),
([1, "2"], ArrayType(IntegerType()), TypeError),
((1, 2), ArrayType(IntegerType()), None),
(array.array('h', [1, 2]), ArrayType(IntegerType()), None),

# Map
({}, MapType(StringType(), IntegerType()), None),
({"a": 1}, MapType(StringType(), IntegerType()), None),
({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError),
({"a": "1"}, MapType(StringType(), IntegerType()), TypeError),
({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True), None),
Member: I'd also like you to add a valueContainsNull=False case.

Author: Added.
({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False),
ValueError),

# Struct
({"s": "a", "i": 1}, MyStructType, None),
({"s": "a", "i": None}, MyStructType, None),
({"s": "a"}, MyStructType, None),
({"s": "a", "f": 1.0}, MyStructType, None), # Extra fields OK
({"s": "a", "i": "1"}, MyStructType, TypeError),
(Row(s="a", i=1), MyStructType, None),
(Row(s="a", i=None), MyStructType, None),
(Row(s="a", i=1, f=1.0), MyStructType, None), # Extra fields OK
(Row(s="a"), MyStructType, ValueError), # Row can't have missing field
(Row(s="a", i="1"), MyStructType, TypeError),
(["a", 1], MyStructType, None),
(["a", None], MyStructType, None),
(["a"], MyStructType, ValueError),
(["a", "1"], MyStructType, TypeError),
(("a", 1), MyStructType, None),
(MyObj(s="a", i=1), MyStructType, None),
(MyObj(s="a", i=None), MyStructType, None),
(MyObj(s="a"), MyStructType, None),
(MyObj(s="a", i="1"), MyStructType, TypeError),
Member: Same here, None for the s field.

Author: Added.
(MyObj(s=None, i="1"), MyStructType, ValueError),
]

for obj, data_type, exp in spec:
msg = "_verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
if exp is None:
try:
_verify_type(obj, data_type, nullable=False)
except Exception:
traceback.print_exc()
self.fail(msg)
else:
with self.assertRaises(exp, msg=msg):
_verify_type(obj, data_type, nullable=False)


if __name__ == "__main__":
from pyspark.sql.tests import *
if xmlrunner:
50 changes: 30 additions & 20 deletions python/pyspark/sql/types.py
@@ -1249,7 +1249,7 @@ def _infer_schema_type(obj, dataType):
}


def _verify_type(obj, dataType, nullable=True):
def _verify_type(obj, dataType, nullable=True, name="obj"):
Member: Just a question. @dgingrich Do you maybe know if there is any chance that "obj" is printed instead? It is rather a nitpick, but I would think it is odd if it prints "obj".

Author: This will print "obj" when called from session.createDataFrame (https://github.com/dgingrich/spark/blob/topic-spark-19507-verify-types/python/pyspark/sql/session.py#L408). It'd be easy to set the name where it's called, but it wasn't clear what to set it to. The input can be either an RDD, list, or pandas.DataFrame.

Member: Let's fix this case.

>>> from pyspark.sql.types import *
>>> spark.createDataFrame(["a"], StringType()).printSchema()
root
 |-- value: string (nullable = true)
>>> from pyspark.sql.types import *
>>> spark.createDataFrame(["a"], IntegerType()).printSchema()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../spark/python/pyspark/sql/session.py", line 526, in createDataFrame
    rdd, schema = self._createFromLocal(map(prepare, data), schema)
  File ".../spark/python/pyspark/sql/session.py", line 387, in _createFromLocal
    data = list(data)
  File ".../spark/python/pyspark/sql/session.py", line 516, in prepare
    verify_func(obj, dataType)
  File ".../spark/python/pyspark/sql/types.py", line 1326, in _verify_type
    % (name, dataType, obj, type(obj)))
TypeError: obj: IntegerType can not accept object 'a' in type <type 'str'>

It sounds like "obj" should be "value". It looks like we should specify the name around https://github.com/dgingrich/spark/blob/topic-spark-19507-verify-types/python/pyspark/sql/session.py#L516.

Member: I guess this is the only place where we print "obj", maybe? If so, let's set name=None.

Author: Set name='value' in the call at session.py line 516.

It will still print obj if the schema is a StructType: TypeError: obj.a: MyStructType can not accept object 'a' in type <type 'str'>. Would you like to change that too?

Right now, changing the default name to None would make the error message worse: TypeError: None: IntegerType can not accept object 'a' in type <type 'str'>.

The best way to make the error message pretty is probably:

  • Set the default name to None
  • If name is None, don't prepend the "%s: " string to the error messages

That would make your example: TypeError: IntegerType can not accept object 'a' in type <type 'str'>.

IMO obj is not as pretty but reasonable since it's so simple. Let me know what you prefer. My only goal is that next time I get a schema failure it tells me what field to look at :)

Member: Could we maybe then use None and not print the prefix?
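A minimal sketch of the formatting proposed above (a hypothetical helper; the name _format_verify_error is invented for illustration and is not part of this diff):

def _format_verify_error(name, message):
    # Sketch of the proposal: with name=None the prefix is dropped entirely,
    # so the default case reads "IntegerType can not accept ..." rather than
    # the misleading "None: IntegerType can not accept ...".
    if name is None:
        return message
    return "%s: %s" % (name, message)

>>> _format_verify_error(None, "IntegerType can not accept object 'a'")
"IntegerType can not accept object 'a'"
>>> _format_verify_error("value", "IntegerType can not accept object 'a'")
"value: IntegerType can not accept object 'a'"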

"""
Verify the type of obj against dataType, raise a TypeError if they do not match.

@@ -1300,70 +1300,80 @@ def _verify_type(obj, dataType, nullable=True):
if nullable:
return
else:
raise ValueError("This field is not nullable, but got None")
raise ValueError("%s: This field is not nullable, but got None" % name)
Member: Probably I missed something, but is there any test case that actually checks this message change?

Author (@dgingrich, Jun 22, 2017): No, I didn't test the actual exception message. I normally don't check the contents of exception messages, since they shouldn't be used programmatically (the tests are mostly there to exercise all code paths and make sure I didn't break something).

But here it makes sense to check that the prefix is set, since that's the main point of the PR. Added a test looking for the exception message prefix (test_verify_type_exception_msg above).


# StringType can work with any types
if isinstance(dataType, StringType):
return

if isinstance(dataType, UserDefinedType):
if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
raise ValueError("%r is not an instance of type %r" % (obj, dataType))
_verify_type(dataType.toInternal(obj), dataType.sqlType())
raise ValueError("%s: %r is not an instance of type %r" % (name, obj, dataType))
_verify_type(dataType.toInternal(obj), dataType.sqlType(), name=name)
return

_type = type(dataType)
assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj)
assert _type in _acceptable_types, \
"%s: unknown datatype: %s for object %r" % (name, dataType, obj)

if _type is StructType:
# check the type and fields later
pass
else:
# subclass of them can not be fromInternal in JVM
if type(obj) not in _acceptable_types[_type]:
raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))
raise TypeError("%s: %s can not accept object %r in type %s"
% (name, dataType, obj, type(obj)))

if isinstance(dataType, ByteType):
if obj < -128 or obj > 127:
raise ValueError("object of ByteType out of range, got: %s" % obj)
raise ValueError("%s: object of ByteType out of range, got: %s" % (name, obj))

elif isinstance(dataType, ShortType):
if obj < -32768 or obj > 32767:
raise ValueError("object of ShortType out of range, got: %s" % obj)
raise ValueError("%s: object of ShortType out of range, got: %s" % (name, obj))

elif isinstance(dataType, IntegerType):
if obj < -2147483648 or obj > 2147483647:
raise ValueError("object of IntegerType out of range, got: %s" % obj)
raise ValueError("%s: object of IntegerType out of range, got: %s" % (name, obj))

elif isinstance(dataType, ArrayType):
for i in obj:
_verify_type(i, dataType.elementType, dataType.containsNull)
for i, value in enumerate(obj):
new_name = "%s[%d]" % (name, i)
_verify_type(value, dataType.elementType, dataType.containsNull, name=new_name)

elif isinstance(dataType, MapType):
for k, v in obj.items():
_verify_type(k, dataType.keyType, False)
_verify_type(v, dataType.valueType, dataType.valueContainsNull)
new_name = "%s[%s](key)" % (name, k)
_verify_type(k, dataType.keyType, False, name=new_name)
new_name = "%s[%s]" % (name, k)
_verify_type(v, dataType.valueType, dataType.valueContainsNull, name=new_name)

elif isinstance(dataType, StructType):
if isinstance(obj, dict):
for f in dataType.fields:
_verify_type(obj.get(f.name), f.dataType, f.nullable)
new_name = "%s.%s" % (name, f.name)
_verify_type(obj.get(f.name), f.dataType, f.nullable, name=new_name)
elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
# the order in obj could be different than dataType.fields
for f in dataType.fields:
_verify_type(obj[f.name], f.dataType, f.nullable)
new_name = "%s.%s" % (name, f.name)
_verify_type(obj[f.name], f.dataType, f.nullable, name=new_name)
elif isinstance(obj, (tuple, list)):
if len(obj) != len(dataType.fields):
raise ValueError("Length of object (%d) does not match with "
"length of fields (%d)" % (len(obj), len(dataType.fields)))
raise ValueError("%s: Length of object (%d) does not match with "
"length of fields (%d)" % (name, len(obj), len(dataType.fields)))
for v, f in zip(obj, dataType.fields):
_verify_type(v, f.dataType, f.nullable)
new_name = "%s.%s" % (name, f.name)
_verify_type(v, f.dataType, f.nullable, name=new_name)
elif hasattr(obj, "__dict__"):
d = obj.__dict__
for f in dataType.fields:
_verify_type(d.get(f.name), f.dataType, f.nullable)
new_name = "%s.%s" % (name, f.name)
_verify_type(d.get(f.name), f.dataType, f.nullable, name=new_name)
else:
raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
raise TypeError("%s: StructType can not accept object %r in type %s"
% (name, obj, type(obj)))


# This is used to unpickle a Row from JVM
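To see the name threading end to end, a hypothetical interactive session (abbreviated tracebacks; message wording follows the patched code above, with Python 2 type reprs as in the review thread):

>>> from pyspark.sql.types import *
>>> from pyspark.sql.types import _verify_type
>>> schema = StructType([StructField('a', ArrayType(IntegerType()), True)])
>>> _verify_type({'a': [1, 'x']}, schema)
Traceback (most recent call last):
  ...
TypeError: obj.a[1]: IntegerType can not accept object 'x' in type <type 'str'>
>>> m = MapType(StringType(), IntegerType(), valueContainsNull=False)
>>> _verify_type({'k': None}, m, name='m')
Traceback (most recent call last):
  ...
ValueError: m[k]: This field is not nullable, but got None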