1 change: 0 additions & 1 deletion python/pyspark/rdd.py
@@ -627,7 +627,6 @@ def sortPartition(iterator):
     def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         """
         Sorts this RDD, which is assumed to consist of (key, value) pairs.
-        # noqa
 
         >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
         >>> sc.parallelize(tmp).sortByKey().first()
12 changes: 8 additions & 4 deletions python/pyspark/sql/session.py
@@ -33,7 +33,7 @@
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
+from pyspark.sql.types import Row, DataType, StringType, StructType, _make_type_verifier, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
 from pyspark.sql.utils import install_exception_handler
 
@@ -514,17 +514,21 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
                 schema = [str(x) for x in data.columns]
             data = [r.tolist() for r in data.to_records(index=False)]
 
-        verify_func = _verify_type if verifySchema else lambda _, t: True
         if isinstance(schema, StructType):
+            verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
+
             def prepare(obj):
-                verify_func(obj, schema)
+                verify_func(obj)
                 return obj
         elif isinstance(schema, DataType):
             dataType = schema
             schema = StructType().add("value", schema)
 
+            verify_func = _make_type_verifier(
+                dataType, name="field value") if verifySchema else lambda _: True
[Review thread on the name="field value" argument]

Contributor:
Is "field value" useful in the error message?

Member Author (@HyukjinKwon):
I don't think so. I think it should return None, but I just wanted to avoid other changes for reviewing.

Member Author (@HyukjinKwon, Jul 4, 2017):
Oh wait, you mean "field value". Yes, this is printed as part of the exception message and not used otherwise.
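
For illustration, a minimal sketch of how the name surfaces, assuming a PySpark build that includes this change (the exact message text in the comment is illustrative):

    from pyspark.sql.types import IntegerType, _make_type_verifier

    # Build the verifier once for a non-nullable integer, labeled "field value".
    verify = _make_type_verifier(IntegerType(), nullable=False, name="field value")

    verify(1)  # passes silently

    try:
        verify(None)  # nullable=False, so None is rejected
    except ValueError as e:
        # The supplied name prefixes the message, e.g.:
        #   field value: This field is not nullable, but got None
        print(e)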


+
             def prepare(obj):
-                verify_func(obj, dataType)
+                verify_func(obj)
                 return obj,
         else:
             if isinstance(schema, list):
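
The shape of this refactor: the old _verify_type(obj, dataType) re-dispatched on the data type for every row, whereas _make_type_verifier(dataType) resolves the checks once and returns a closure that is applied to each row. A minimal sketch of the pattern, with hypothetical names (not the actual pyspark.sql.types implementation; only two types covered):

    from pyspark.sql.types import IntegerType, StringType

    def make_verifier(data_type, nullable=True):
        """Return a closure that validates one value against data_type."""
        # Resolve the per-type check once, up front.
        if isinstance(data_type, IntegerType):
            def verify_value(obj):
                if not isinstance(obj, int):
                    raise TypeError("IntegerType can not accept %r" % (obj,))
                if not (-2**31 <= obj <= 2**31 - 1):
                    raise ValueError("value out of range for IntegerType: %r" % (obj,))
        elif isinstance(data_type, StringType):
            def verify_value(obj):
                pass  # strings accept any non-null object (see the tests below)
        else:
            raise NotImplementedError("sketch covers IntegerType and StringType only")

        # The per-row closure does no type dispatch of its own.
        def verify(obj):
            if obj is None:
                if not nullable:
                    raise ValueError("This field is not nullable, but got None")
                return
            verify_value(obj)

        return verify

    verify_i = make_verifier(IntegerType(), nullable=False)
    verify_i(42)      # passes
    # verify_i(None) raises ValueError; verify_i("42") raises TypeError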
203 changes: 197 additions & 6 deletions python/pyspark/sql/tests.py
@@ -57,7 +57,7 @@
 from pyspark import SparkContext
 from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
-from pyspark.sql.types import UserDefinedType, _infer_type
+from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
 from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests
 from pyspark.sql.functions import UserDefinedFunction, sha2, lit
 from pyspark.sql.window import Window
@@ -852,7 +852,7 @@ def test_convert_row_to_dict(self):
         self.assertEqual(1.0, row.asDict()['d']['key'].c)
 
     def test_udt(self):
-        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type
+        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier
         from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
 
         def check_datatype(datatype):
@@ -868,17 +868,19 @@ def check_datatype(datatype):
         check_datatype(structtype_with_udt)
         p = ExamplePoint(1.0, 2.0)
         self.assertEqual(_infer_type(p), ExamplePointUDT())
-        _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
-        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT()))
+        _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
+        self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0]))
 
         check_datatype(PythonOnlyUDT())
         structtype_with_udt = StructType([StructField("label", DoubleType(), False),
                                           StructField("point", PythonOnlyUDT(), False)])
         check_datatype(structtype_with_udt)
         p = PythonOnlyPoint(1.0, 2.0)
         self.assertEqual(_infer_type(p), PythonOnlyUDT())
-        _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
-        self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
+        _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
+        self.assertRaises(
+            ValueError,
+            lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0]))
 
     def test_simple_udt_in_df(self):
         schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
@@ -2620,6 +2622,195 @@ def range_frame_match():

         importlib.reload(window)
 
+
+class DataTypeVerificationTests(unittest.TestCase):
+
+    def test_verify_type_exception_msg(self):
+        self.assertRaisesRegexp(
+            ValueError,
+            "test_name",
+            lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None))
+
+        schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
+        self.assertRaisesRegexp(
+            TypeError,
+            "field b in field a",
+            lambda: _make_type_verifier(schema)([["data"]]))
+
+    def test_verify_type_ok_nullable(self):
+        obj = None
+        types = [IntegerType(), FloatType(), StringType(), StructType([])]
+        for data_type in types:
+            try:
+                _make_type_verifier(data_type, nullable=True)(obj)
+            except Exception:
+                self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type))
+
+    def test_verify_type_not_nullable(self):
+        import array
+        import datetime
+        import decimal
+
+        schema = StructType([
+            StructField('s', StringType(), nullable=False),
+            StructField('i', IntegerType(), nullable=True)])
+
+        class MyObj:
+            def __init__(self, **kwargs):
+                for k, v in kwargs.items():
+                    setattr(self, k, v)
+
+        # obj, data_type
+        success_spec = [
+            # String
+            ("", StringType()),
+            (u"", StringType()),
+            (1, StringType()),
+            (1.0, StringType()),
+            ([], StringType()),
+            ({}, StringType()),
+
+            # UDT
+            (ExamplePoint(1.0, 2.0), ExamplePointUDT()),
+
+            # Boolean
+            (True, BooleanType()),
+
+            # Byte
+            (-(2**7), ByteType()),
+            (2**7 - 1, ByteType()),
+
+            # Short
+            (-(2**15), ShortType()),
+            (2**15 - 1, ShortType()),
+
+            # Integer
+            (-(2**31), IntegerType()),
+            (2**31 - 1, IntegerType()),
+
+            # Long
+            (2**64, LongType()),
+
+            # Float & Double
+            (1.0, FloatType()),
+            (1.0, DoubleType()),
+
+            # Decimal
+            (decimal.Decimal("1.0"), DecimalType()),
+
+            # Binary
+            (bytearray([1, 2]), BinaryType()),
+
+            # Date/Timestamp
+            (datetime.date(2000, 1, 2), DateType()),
+            (datetime.datetime(2000, 1, 2, 3, 4), DateType()),
+            (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),
+
+            # Array
+            ([], ArrayType(IntegerType())),
+            (["1", None], ArrayType(StringType(), containsNull=True)),
+            ([1, 2], ArrayType(IntegerType())),
+            ((1, 2), ArrayType(IntegerType())),
+            (array.array('h', [1, 2]), ArrayType(IntegerType())),
+
+            # Map
+            ({}, MapType(StringType(), IntegerType())),
+            ({"a": 1}, MapType(StringType(), IntegerType())),
+            ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)),
+
+            # Struct
+            ({"s": "a", "i": 1}, schema),
+            ({"s": "a", "i": None}, schema),
+            ({"s": "a"}, schema),
+            ({"s": "a", "f": 1.0}, schema),
+            (Row(s="a", i=1), schema),
+            (Row(s="a", i=None), schema),
+            (Row(s="a", i=1, f=1.0), schema),
+            (["a", 1], schema),
+            (["a", None], schema),
+            (("a", 1), schema),
+            (MyObj(s="a", i=1), schema),
+            (MyObj(s="a", i=None), schema),
+            (MyObj(s="a"), schema),
+        ]
+
+        # obj, data_type, exception class
+        failure_spec = [
+            # String (match anything but None)
+            (None, StringType(), ValueError),
+
+            # UDT
+            (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
+
+            # Boolean
+            (1, BooleanType(), TypeError),
+            ("True", BooleanType(), TypeError),
+            ([1], BooleanType(), TypeError),
+
+            # Byte
+            (-(2**7) - 1, ByteType(), ValueError),
+            (2**7, ByteType(), ValueError),
+            ("1", ByteType(), TypeError),
+            (1.0, ByteType(), TypeError),
+
+            # Short
+            (-(2**15) - 1, ShortType(), ValueError),
+            (2**15, ShortType(), ValueError),
+
+            # Integer
+            (-(2**31) - 1, IntegerType(), ValueError),
+            (2**31, IntegerType(), ValueError),
+
+            # Float & Double
+            (1, FloatType(), TypeError),
+            (1, DoubleType(), TypeError),
+
+            # Decimal
+            (1.0, DecimalType(), TypeError),
+            (1, DecimalType(), TypeError),
+            ("1.0", DecimalType(), TypeError),
+
+            # Binary
+            (1, BinaryType(), TypeError),
+
+            # Date/Timestamp
+            ("2000-01-02", DateType(), TypeError),
+            (946811040, TimestampType(), TypeError),
+
+            # Array
+            (["1", None], ArrayType(StringType(), containsNull=False), ValueError),
+            ([1, "2"], ArrayType(IntegerType()), TypeError),
+
+            # Map
+            ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError),
+            ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError),
+            ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False),
+             ValueError),
+
+            # Struct
+            ({"s": "a", "i": "1"}, schema, TypeError),
+            (Row(s="a"), schema, ValueError),  # Row can't have missing field
+            (Row(s="a", i="1"), schema, TypeError),
+            (["a"], schema, ValueError),
+            (["a", "1"], schema, TypeError),
+            (MyObj(s="a", i="1"), schema, TypeError),
+            (MyObj(s=None, i="1"), schema, ValueError),
+        ]
+
+        # Check success cases
+        for obj, data_type in success_spec:
+            try:
+                _make_type_verifier(data_type, nullable=False)(obj)
+            except Exception:
+                self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type))
+
+        # Check failure cases
+        for obj, data_type, exp in failure_spec:
+            msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
+            with self.assertRaises(exp, msg=msg):
+                _make_type_verifier(data_type, nullable=False)(obj)
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests import *
     if xmlrunner: