diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 80a94a91a87b3..d9656866f833f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -485,6 +485,79 @@ def test_udf_with_array_type(self):
         self.assertEqual(list(range(3)), l1)
         self.assertEqual(1, l2)
 
+    def test_udf_returning_date_time(self):
+        from pyspark.sql.functions import udf
+        from pyspark.sql.types import DateType
+
+        data = self.spark.createDataFrame([(2017, 10, 30)], ['year', 'month', 'day'])
+
+        expected_date = datetime.date(2017, 10, 30)
+        expected_datetime = datetime.datetime(2017, 10, 30)
+
+        # test Python UDF with default returnType=StringType()
+        # Returning a date or datetime object at runtime with such returnType declaration
+        # is a mismatch, which results in a null, as PySpark treats it as unconvertible.
+        py_date_str, py_datetime_str = udf(datetime.date), udf(datetime.datetime)
+        query = data.select(
+            py_date_str(data.year, data.month, data.day).isNull(),
+            py_datetime_str(data.year, data.month, data.day).isNull())
+        [row] = query.collect()
+        self.assertEqual(row[0], True)
+        self.assertEqual(row[1], True)
+
+        query = data.select(
+            py_date_str(data.year, data.month, data.day),
+            py_datetime_str(data.year, data.month, data.day))
+        [row] = query.collect()
+        self.assertEqual(row[0], None)
+        self.assertEqual(row[1], None)
+
+        # test Python UDF with specific returnType matching actual result
+        py_date, py_datetime = udf(datetime.date, DateType()), udf(datetime.datetime, 'timestamp')
+        query = data.select(
+            py_date(data.year, data.month, data.day) == lit(expected_date),
+            py_datetime(data.year, data.month, data.day) == lit(expected_datetime))
+        [row] = query.collect()
+        self.assertEqual(row[0], True)
+        self.assertEqual(row[1], True)
+
+        query = data.select(
+            py_date(data.year, data.month, data.day),
+            py_datetime(data.year, data.month, data.day))
+        [row] = query.collect()
+        self.assertEqual(row[0], expected_date)
+        self.assertEqual(row[1], expected_datetime)
+
+        # test semantic matching of datetime with timezone
+        # class in __main__ is not serializable
+        from pyspark.sql.tests import UTCOffsetTimezone
+        datetime_with_utc0 = datetime.datetime(2017, 10, 30, tzinfo=UTCOffsetTimezone(0))
+        datetime_with_utc1 = datetime.datetime(2017, 10, 30, tzinfo=UTCOffsetTimezone(1))
+        test_udf = udf(lambda: datetime_with_utc0, 'timestamp')
+        query = data.select(
+            test_udf() == lit(datetime_with_utc0),
+            test_udf() > lit(datetime_with_utc1),
+            test_udf()
+        )
+        [row] = query.collect()
+        self.assertEqual(row[0], True)
+        self.assertEqual(row[1], True)
+        # Note: datetime returned from PySpark is always naive (timezone unaware).
+        # It currently respects Python's current local timezone.
+        self.assertEqual(row[2].tzinfo, None)
+
+        # tzinfo=None is really the same as not specifying it: a naive datetime object
+        # Just adding a test case for it here for completeness
+        datetime_with_null_timezone = datetime.datetime(2017, 10, 30, tzinfo=None)
+        test_udf = udf(lambda: datetime_with_null_timezone, 'timestamp')
+        query = data.select(
+            test_udf() == lit(datetime_with_null_timezone),
+            test_udf()
+        )
+        [row] = query.collect()
+        self.assertEqual(row[0], True)
+        self.assertEqual(row[1], datetime_with_null_timezone)
+
     def test_broadcast_in_udf(self):
         bar = {"a": "aa", "b": "bb", "c": "abc"}
         foo = self.sc.broadcast(bar)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 520afad287648..ec9fc2612bf95 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 
 import java.io.OutputStream
 import java.nio.charset.StandardCharsets
+import java.util.Calendar
 
 import scala.collection.JavaConverters._
 
@@ -144,6 +145,7 @@ object EvaluatePython {
     }
 
     case StringType => (obj: Any) => nullSafeConvert(obj) {
+      case _: Calendar => null
      case _ => UTF8String.fromString(obj.toString)
     }
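
The following standalone PySpark sketch is not part of the patch; it only illustrates the return-type behaviour the new test asserts after this change. The SparkSession setup, aliases, and variable names are illustrative assumptions.

import datetime

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import DateType

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(2017, 10, 30)], ['year', 'month', 'day'])

# Default returnType is StringType(); a date/datetime result does not match it,
# so the value is treated as unconvertible and comes back as null (post-patch behaviour).
as_string = udf(datetime.date)
row = df.select(as_string(df.year, df.month, df.day).alias('mismatched')).head()
assert row.mismatched is None

# With a matching returnType, the Python objects round-trip unchanged.
as_date = udf(datetime.date, DateType())
as_timestamp = udf(datetime.datetime, 'timestamp')
row = df.select(
    as_date(df.year, df.month, df.day).alias('d'),
    as_timestamp(df.year, df.month, df.day).alias('ts')).head()
assert row.d == datetime.date(2017, 10, 30)
assert row.ts == datetime.datetime(2017, 10, 30)

spark.stop()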
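
A second sketch, again an assumption rather than code from the patch, covers the timezone part of the test. Python 3's datetime.timezone.utc stands in for the UTCOffsetTimezone helper defined in tests.py. A timezone-aware datetime returned from a UDF compares by its absolute instant, while the value collected back into Python is a naive datetime in the local timezone.

import datetime

from pyspark.sql import SparkSession
from pyspark.sql.functions import lit, udf

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(1)

aware = datetime.datetime(2017, 10, 30, tzinfo=datetime.timezone.utc)
return_aware = udf(lambda: aware, 'timestamp')

row = df.select(
    (return_aware() == lit(aware)).alias('same_instant'),
    return_aware().alias('collected')).head()
assert row.same_instant             # equality holds on the underlying instant
assert row.collected.tzinfo is None # collected value is naive, in local time

spark.stop()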