73 changes: 73 additions & 0 deletions python/pyspark/sql/tests.py
@@ -485,6 +485,79 @@ def test_udf_with_array_type(self):
        self.assertEqual(list(range(3)), l1)
        self.assertEqual(1, l2)

    def test_udf_returning_date_time(self):
        from pyspark.sql.functions import udf, lit
        from pyspark.sql.types import DateType

        data = self.spark.createDataFrame([(2017, 10, 30)], ['year', 'month', 'day'])

        expected_date = datetime.date(2017, 10, 30)
        expected_datetime = datetime.datetime(2017, 10, 30)

        # test Python UDF with the default returnType=StringType()
        # Returning a date or datetime object at runtime with such a returnType declaration
        # is a mismatch, which results in a null, as PySpark treats it as unconvertible.
        py_date_str, py_datetime_str = udf(datetime.date), udf(datetime.datetime)
        query = data.select(
            py_date_str(data.year, data.month, data.day).isNull(),
            py_datetime_str(data.year, data.month, data.day).isNull())
        [row] = query.collect()
        self.assertEqual(row[0], True)
        self.assertEqual(row[1], True)

        query = data.select(
            py_date_str(data.year, data.month, data.day),
            py_datetime_str(data.year, data.month, data.day))
        [row] = query.collect()
        self.assertEqual(row[0], None)
        self.assertEqual(row[1], None)

        # test Python UDF with a specific returnType matching the actual result
        py_date = udf(datetime.date, DateType())
        py_datetime = udf(datetime.datetime, 'timestamp')
        query = data.select(
            py_date(data.year, data.month, data.day) == lit(expected_date),
            py_datetime(data.year, data.month, data.day) == lit(expected_datetime))
        [row] = query.collect()
        self.assertEqual(row[0], True)
        self.assertEqual(row[1], True)

        query = data.select(
            py_date(data.year, data.month, data.day),
            py_datetime(data.year, data.month, data.day))
        [row] = query.collect()
        self.assertEqual(row[0], expected_date)
        self.assertEqual(row[1], expected_datetime)

        # test semantic matching of a datetime carrying a timezone
        # (a class defined in __main__ is not serializable, so import it from the module)
        from pyspark.sql.tests import UTCOffsetTimezone
        datetime_with_utc0 = datetime.datetime(2017, 10, 30, tzinfo=UTCOffsetTimezone(0))
        datetime_with_utc1 = datetime.datetime(2017, 10, 30, tzinfo=UTCOffsetTimezone(1))
        test_udf = udf(lambda: datetime_with_utc0, 'timestamp')
        query = data.select(
            test_udf() == lit(datetime_with_utc0),
            test_udf() > lit(datetime_with_utc1),
            test_udf()
        )
        [row] = query.collect()
        self.assertEqual(row[0], True)
        self.assertEqual(row[1], True)
        # Note: datetimes returned from PySpark are always naive (timezone-unaware);
        # the value is expressed in Python's current local timezone.
        self.assertEqual(row[2].tzinfo, None)

        # tzinfo=None is really the same as not specifying it: a naive datetime object.
        # A test case is added here for completeness.
        datetime_with_null_timezone = datetime.datetime(2017, 10, 30, tzinfo=None)
        test_udf = udf(lambda: datetime_with_null_timezone, 'timestamp')
        query = data.select(
            test_udf() == lit(datetime_with_null_timezone),
            test_udf()
        )
        [row] = query.collect()
        self.assertEqual(row[0], True)
        self.assertEqual(row[1], datetime_with_null_timezone)

    def test_broadcast_in_udf(self):
        bar = {"a": "aa", "b": "bb", "c": "abc"}
        foo = self.sc.broadcast(bar)

sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python

import java.io.OutputStream
import java.nio.charset.StandardCharsets
import java.util.Calendar

import scala.collection.JavaConverters._

@@ -144,6 +145,7 @@ object EvaluatePython {
    }

    case StringType => (obj: Any) => nullSafeConvert(obj) {
      case _: Calendar => null
      case _ => UTF8String.fromString(obj.toString)

Contributor:
shall we blacklist more types? e.g. if a udf returns a decimal but marks the return type as string, is it a mismatch?
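
For what it's worth, a quick sketch of that decimal case (an illustrative example, not part of the patch), assuming the same pyspark shell setup as the snippets further down. Per the Pyrolite mapping below, a decimal unpickles to java.math.BigDecimal, so the plain obj.toString fallback would render it as an ordinary number rather than null:

from decimal import Decimal
from pyspark.sql.functions import udf

# Declared return type is string, but the UDF actually returns a Decimal.
f = udf(lambda x: Decimal("1.5"), "string")
spark.range(1).select(f("id")).show()
# With the toString fallback this is expected to show 1.5 (BigDecimal.toString),
# which happens to look consistent, unlike the Calendar case this patch nulls out.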

Contributor Author:
I was mulling over that yesterday, too... somehow I have the feeling that no matter which direction we take, there's no good answer to type-mismatch situations.

Say we blacklist more types: should we document the list so that it's clear what will definitely NOT work?

Member:
To be thorough, I think we should check all the types that Pyrolite can hand us (https://github.com/irmen/Pyrolite):

PYTHON    ---->     JAVA
------              ----
None                null
bool                boolean
int                 int
long                long or BigInteger  (depending on size)
string              String
unicode             String
complex             net.razorvine.pickle.objects.ComplexNumber
datetime.date       java.util.Calendar
datetime.datetime   java.util.Calendar
datetime.time       net.razorvine.pickle.objects.Time
datetime.timedelta  net.razorvine.pickle.objects.TimeDelta
float               double   (float isn't used) 
array.array         array of appropriate primitive type (char, int, short, long, float, double)
list                java.util.List<Object>
tuple               Object[]
set                 java.util.Set
dict                java.util.Map
bytes               byte[]
bytearray           byte[]
decimal             BigDecimal    
custom class        Map<String, Object>  (dict with class attributes including its name in "__class__")
Pyro4.core.URI      net.razorvine.pyro.PyroURI
Pyro4.core.Proxy    net.razorvine.pyro.PyroProxy
Pyro4.errors.*      net.razorvine.pyro.PyroException
Pyro4.utils.flame.FlameBuiltin     net.razorvine.pyro.FlameBuiltin 
Pyro4.utils.flame.FlameModule      net.razorvine.pyro.FlameModule 
Pyro4.utils.flame.RemoteInteractiveConsole    net.razorvine.pyro.FlameRemoteConsole 

and then check whether the string conversion via obj.toString looks reasonably consistent. If not, we add that type to the blacklist.

Another possibility is to whitelist String only, but I guess that would be a rather radical change.
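
To make the inconsistency concrete, here is a small sketch (illustrative, not from the patch) of the Calendar case that the new case _: Calendar => null line targets, again in a pyspark shell:

from datetime import date
from pyspark.sql.functions import udf

# Declared return type is string, but the UDF returns a datetime.date,
# which Pyrolite unpickles to a java.util.GregorianCalendar on the JVM side.
f = udf(lambda x: date(2017, 10, 30), "string")
spark.range(1).select(f("id")).show()
# Before this patch the toString fallback printed something like
# java.util.GregorianCalendar[time=?,areFieldsSet=false,...]; with the
# Calendar check it now comes back as null instead.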

Contributor:

> check whether the string conversion via obj.toString looks reasonably consistent. If not, we add that type to the blacklist.

hmm, this seems weird, as the type mismatch is now defined by the Pyrolite object's toString behavior...

@HyukjinKwon (Member), Jan 13, 2018:
So, for now, I think it's fine as a small fix as-is... We are going to document that the return type and return value should match anyway.

So, expected return values will be (including dict, list, tuple and array):

# Mapping Python types to Spark SQL DataType
_type_mappings = {
    type(None): NullType,
    bool: BooleanType,
    int: LongType,
    float: DoubleType,
    str: StringType,
    bytearray: BinaryType,
    decimal.Decimal: DecimalType,
    datetime.date: DateType,
    datetime.datetime: TimestampType,
    datetime.time: TimestampType,
}
if sys.version < "3":
    _type_mappings.update({
        unicode: StringType,
        long: LongType,
    })

It seems we can also check whether the string conversion looks reasonable and blacklist net.razorvine.pickle.objects.Time if not...

How does this sound to you @cloud-fan and @rednaxelafx?

@HyukjinKwon (Member), Jan 13, 2018:

BTW, it seems there is another hole when we actually do the internal conversion with unexpected types:

>>> from pyspark.sql.functions import udf
>>> f = udf(lambda x: x, "date")
>>> spark.range(1).select(f("id")).show()
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "./python/pyspark/worker.py", line 229, in main
    process()
  File "./python/pyspark/worker.py", line 224, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "./python/pyspark/worker.py", line 149, in <lambda>
    func = lambda _, it: map(mapper, it)
  File "<string>", line 1, in <lambda>
  File "./python/pyspark/worker.py", line 72, in <lambda>
    return lambda *a: toInternal(f(*a))
  File "/.../pyspark/sql/types.py", line 175, in toInternal
    return d.toordinal() - self.EPOCH_ORDINAL
AttributeError: 'int' object has no attribute 'toordinal'
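
As a side note, a minimal sketch of the conversion that traceback points at: DateType.toInternal subtracts an epoch ordinal from d.toordinal(), so anything that is not a date-like object fails in exactly this way (the 1970-01-01 epoch below is an assumption for illustration):

import datetime

EPOCH_ORDINAL = datetime.date(1970, 1, 1).toordinal()

def to_internal(d):
    # Fine for datetime.date / datetime.datetime; an int has no .toordinal()
    # and raises the AttributeError shown above.
    return d.toordinal() - EPOCH_ORDINAL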

Another hole:

>>> from pyspark.sql.functions import udf, struct
>>> f = udf(lambda x: x, "string")
>>> spark.range(1).select(f(struct("id"))).show()
net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for pyspark.sql.types._create_row)
	at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
	at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:707)
	at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:175)
	at net.razorvine.pickle.Unpickler.load(Unpickler.java:99)
	at net.razorvine.pickle.Unpickler.loads(Unpickler.java:112)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$evaluate$1.apply(BatchEvalPythonExec.scala:86)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$evaluate$1.apply(BatchEvalPythonExec.scala:85)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)

Member:

I think there is no perfect solution... I think #20163 (comment) sounds good enough as a fix for this issue for now.

@HyukjinKwon (Member), Jan 13, 2018:

@cloud-fan, how about something like this then?

    case StringType => (obj: Any) => nullSafeConvert(obj) {
      // Shortcut for string conversion
      case c: String => UTF8String.fromString(c)

      // Here, we return null for 'array', 'tuple', 'dict', 'list', 'datetime.datetime',
      // 'datetime.date' and 'datetime.time' because those string conversions are
      // not quite consistent with SQL string representation of data.
      case _: java.util.Calendar | _: net.razorvine.pickle.objects.Time |
           _: java.util.List[_] | _: java.util.Map[_, _] =>
        null
      case c if c.getClass.isArray => null

      // Here, we keep the string conversion fall back for compatibility.
      // TODO: We should revisit this and rewrite the type conversion logic in Spark 3.x.
      case c => UTF8String.fromString(c.toString)
    }

My few tests:

datetime.time:

from pyspark.sql.functions import udf
from datetime import time

f = udf(lambda x: time(0, 0), "string")
spark.range(1).select(f("id")).show()
+--------------------+
|        <lambda>(id)|
+--------------------+
|Time: 0 hours, 0 ...|
+--------------------+

array:

from pyspark.sql.functions import udf
import array

f = udf(lambda x: array.array("c", "aaa"), "string")
spark.range(1).select(f("id")).show()
+------------+
|<lambda>(id)|
+------------+
| [C@11618d9e|
+------------+

tuple:

from pyspark.sql.functions import udf

f = udf(lambda x: (x,), "string")
spark.range(1).select(f("id")).show()
+--------------------+
|        <lambda>(id)|
+--------------------+
|[Ljava.lang.Objec...|
+--------------------+

list:

from pyspark.sql.functions import udf
from datetime import datetime

f = udf(lambda x: [datetime(1990, 1, 1)], "string")
spark.range(1).select(f("id")).show()
+--------------------+
|        <lambda>(id)|
+--------------------+
|[java.util.Gregor...|
+--------------------+

dict:

from pyspark.sql.functions import udf
from datetime import datetime

f = udf(lambda x: {1: datetime(1990, 1, 1)}, "string")
spark.range(1).select(f("id")).show()
+--------------------+
|        <lambda>(id)|
+--------------------+
|{1=java.util.Greg...|
+--------------------+

Contributor:

looks good

Member:

btw, the array case seems a bit weird?
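
(Likely because array.array unpickles to a primitive char[] on the JVM side, so the fallback prints the default Object.toString, i.e. a type descriptor plus an identity hash code such as [C@11618d9e, rather than the array contents.)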

    }
