1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -467,6 +467,7 @@ def __hash__(self):
"pyspark.sql.observation",
# unittests
"pyspark.sql.tests.test_arrow",
"pyspark.sql.tests.test_arrow_python_udf",
"pyspark.sql.tests.test_catalog",
"pyspark.sql.tests.test_column",
"pyspark.sql.tests.test_conf",
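This one-line change registers the new test file with Spark's Python test-runner module list, so it runs in CI alongside the other pyspark.sql test modules. Assuming the usual PySpark developer workflow (command shape from Spark's developer docs, not verified against this branch), it can also be run on its own with python/run-tests --testnames 'pyspark.sql.tests.test_arrow_python_udf'.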
24 changes: 20 additions & 4 deletions python/pyspark/sql/functions.py
@@ -44,7 +44,7 @@
from pyspark.sql.types import ArrayType, DataType, StringType, StructType, _from_numpy_type

# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
- from pyspark.sql.udf import UserDefinedFunction, _create_udf # noqa: F401
+ from pyspark.sql.udf import UserDefinedFunction, _create_py_udf # noqa: F401

# Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264
from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType # noqa: F401
@@ -9980,14 +9980,19 @@ def unwrap_udt(col: "ColumnOrName") -> Column:

@overload
def udf(
- f: Callable[..., Any], returnType: "DataTypeOrString" = StringType()
+ f: Callable[..., Any],
+ returnType: "DataTypeOrString" = StringType(),
+ *,
+ useArrow: Optional[bool] = None,
) -> "UserDefinedFunctionLike":
...


@overload
def udf(
f: Optional["DataTypeOrString"] = None,
*,
useArrow: Optional[bool] = None,
) -> Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]:
...

@@ -9996,13 +10001,16 @@ def udf(
def udf(
*,
returnType: "DataTypeOrString" = StringType(),
useArrow: Optional[bool] = None,
) -> Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]:
...


def udf(
f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None,
returnType: "DataTypeOrString" = StringType(),
*,
useArrow: Optional[bool] = None,
) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]:
"""Creates a user defined function (UDF).

@@ -10015,6 +10023,9 @@ def udf(
returnType : :class:`pyspark.sql.types.DataType` or str
the return type of the user-defined function. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
useArrow : bool or None
whether to use Arrow to optimize the (de)serialization. When it is None, the
Spark config "spark.sql.execution.pythonUDF.arrow.enabled" takes effect.

Examples
--------
@@ -10093,10 +10104,15 @@ def udf(
# for decorator use it as a returnType
return_type = f or returnType
return functools.partial(
- _create_udf, returnType=return_type, evalType=PythonEvalType.SQL_BATCHED_UDF
+ _create_py_udf,
+ returnType=return_type,
+ evalType=PythonEvalType.SQL_BATCHED_UDF,
+ useArrow=useArrow,
)
else:
- return _create_udf(f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF)
+ return _create_py_udf(
+ f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, useArrow=useArrow
+ )


def _test() -> None:
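Taken together, the functions.py changes thread a new keyword-only useArrow flag through every udf() overload and delegate to _create_py_udf, which picks between the pickled and Arrow-backed execution paths. A hedged usage sketch (assuming an active SparkSession bound to spark; the decorator form goes through the returnType-only overload above):

from pyspark.sql.functions import udf

# Per-UDF opt-in to Arrow-optimized (de)serialization; useArrow=True/False
# overrides the "spark.sql.execution.pythonUDF.arrow.enabled" conf, while
# useArrow=None (the default) defers to that conf.
@udf(returnType="string", useArrow=True)
def stringify(v):
    return str(v)

df = spark.range(3).selectExpr("array(id, id + 1) as arr")
df.select(stringify("arr")).show()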
130 changes: 130 additions & 0 deletions python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -0,0 +1,130 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

from pyspark.sql.functions import udf
from pyspark.sql.tests.test_udf import BaseUDFTests
from pyspark.testing.sqlutils import (
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
ReusedSQLTestCase,
)


@unittest.skipIf(
not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message
)
class PythonUDFArrowTests(BaseUDFTests, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
super(PythonUDFArrowTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

@unittest.skip("Unrelated test, and it fails when it runs duplicatedly.")
def test_broadcast_in_udf(self):
super(PythonUDFArrowTests, self).test_broadcast_in_udf()

@unittest.skip("Unrelated test, and it fails when it runs duplicatedly.")
def test_register_java_function(self):
super(PythonUDFArrowTests, self).test_register_java_function()

@unittest.skip("Unrelated test, and it fails when it runs duplicatedly.")
def test_register_java_udaf(self):
super(PythonUDFArrowTests, self).test_register_java_udaf()

@unittest.skip("Struct input types are not supported with Arrow optimization")
def test_udf_input_serialization_valuecompare_disabled(self):
super(PythonUDFArrowTests, self).test_udf_input_serialization_valuecompare_disabled()

def test_nested_input_error(self):
with self.assertRaisesRegexp(
Exception, "NotImplementedError: Struct input type are not supported"
):
self.spark.range(1).selectExpr("struct(1, 2) as struct").select(
udf(lambda x: x)("struct")
).collect()

def test_complex_input_types(self):
row = (
self.spark.range(1)
.selectExpr("array(1, 2, 3) as array", "map('a', 'b') as map")
.select(
udf(lambda x: str(x))("array"),
udf(lambda x: str(x))("map"),
)
.first()
)

# The input is a NumPy array when the Arrow optimization is on.
self.assertEquals(row[0], "[1 2 3]")
self.assertEquals(row[1], "{'a': 'b'}")

def test_use_arrow(self):
# useArrow=True
row_true = (
self.spark.range(1)
.selectExpr(
"array(1, 2, 3) as array",
)
.select(
udf(lambda x: str(x), useArrow=True)("array"),
)
.first()
)

# useArrow=None
row_none = (
self.spark.range(1)
.selectExpr(
"array(1, 2, 3) as array",
)
.select(
udf(lambda x: str(x), useArrow=None)("array"),
)
.first()
)

# The input is a NumPy array when the Arrow optimization is on.
self.assertEquals(row_true[0], row_none[0]) # "[1 2 3]"

# useArrow=False
row_false = (
self.spark.range(1)
.selectExpr(
"array(1, 2, 3) as array",
)
.select(
udf(lambda x: str(x), useArrow=False)("array"),
)
.first()
)
self.assertEquals(row_false[0], "[1, 2, 3]")


if __name__ == "__main__":
from pyspark.sql.tests.test_arrow_python_udf import * # noqa: F401

try:
import xmlrunner

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
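The string literals these tests assert against fall out of how Python renders the two input representations; a minimal standalone sketch of that difference (only NumPy is assumed):

import numpy as np

# Arrow path: complex inputs reach the UDF as NumPy arrays, whose str() has no commas.
assert str(np.array([1, 2, 3])) == "[1 2 3]"

# Pickled path (useArrow=False): inputs arrive as plain Python lists.
assert str([1, 2, 3]) == "[1, 2, 3]"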
50 changes: 49 additions & 1 deletion python/pyspark/sql/tests/test_udf.py
@@ -43,7 +43,7 @@
from pyspark.testing.utils import QuietTest


- class UDFTests(ReusedSQLTestCase):
+ class BaseUDFTests(object):
def test_udf_with_callable(self):
d = [Row(number=i, squared=i**2) for i in range(10)]
rdd = self.sc.parallelize(d)
@@ -804,6 +804,54 @@ def test_udf_with_rand(self):
)


class UDFTests(BaseUDFTests, ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
super(BaseUDFTests, cls).setUpClass()
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false")


def test_use_arrow(self):
Member:

@xinrong-meng Who runs this test?

Member (Author):

Good catch! Removed it in 63ef94e. test_use_arrow is duplicated in PythonUDFArrowTestsMixin of test_arrow_python_udf.py.

# useArrow=True
row_true = (
self.spark.range(1)
.selectExpr(
"array(1, 2, 3) as array",
)
.select(
udf(lambda x: str(x), useArrow=True)("array"),
)
.first()
)
# The input is a NumPy array when the Arrow optimization is on.
self.assertEquals(row_true[0], "[1 2 3]")

# useArrow=None
row_none = (
self.spark.range(1)
.selectExpr(
"array(1, 2, 3) as array",
)
.select(
udf(lambda x: str(x), useArrow=None)("array"),
)
.first()
)

# useArrow=False
row_false = (
self.spark.range(1)
.selectExpr(
"array(1, 2, 3) as array",
)
.select(
udf(lambda x: str(x), useArrow=False)("array"),
)
.first()
)
self.assertEquals(row_false[0], row_none[0]) # "[1, 2, 3]"


class UDFInitializationTests(unittest.TestCase):
def tearDown(self):
if SparkSession._instantiatedSession is not None:
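One subtlety in UDFTests.setUpClass above: super(BaseUDFTests, cls) starts the attribute lookup after the mixin in the MRO, so the call lands on ReusedSQLTestCase.setUpClass even though BaseUDFTests defines none of its own. A minimal sketch of that behavior, with illustrative names not taken from the diff:

class Fixture:
    @classmethod
    def setUpClass(cls):
        print("fixture setUpClass")

class Mixin(object):
    pass  # shared tests live here; deliberately no setUpClass

class Suite(Mixin, Fixture):
    @classmethod
    def setUpClass(cls):
        # MRO is Suite -> Mixin -> Fixture: starting the lookup after Mixin reaches Fixture.
        super(Mixin, cls).setUpClass()

Suite.setUpClass()  # prints "fixture setUpClass"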