From 63bb36b548e6ce125e99103b0c7a653481b6dfa5 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 10 Apr 2023 10:49:57 -0700 Subject: [PATCH 01/13] _create_arrow_py_udf --- python/pyspark/sql/udf.py | 81 +++++++++++++++++++++------------------ 1 file changed, 43 insertions(+), 38 deletions(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 52d02dc00c258..fde2a7298e007 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -139,48 +139,53 @@ def _create_py_udf( and not isinstance(return_type, ArrayType) ) if is_arrow_enabled and is_output_atomic_type and is_func_with_args: - require_minimum_pandas_version() - require_minimum_pyarrow_version() - - import pandas as pd - from pyspark.sql.pandas.functions import _create_pandas_udf # type: ignore[attr-defined] - - # "result_func" ensures the result of a Python UDF to be consistent with/without Arrow - # optimization. - # Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a - # string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns - # successfully. - result_func = lambda pdf: pdf # noqa: E731 - if type(return_type) == StringType: - result_func = lambda r: str(r) if r is not None else r # noqa: E731 - elif type(return_type) == BinaryType: - result_func = lambda r: bytes(r) if r is not None else r # noqa: E731 - - def vectorized_udf(*args: pd.Series) -> pd.Series: - if any(map(lambda arg: isinstance(arg, pd.DataFrame), args)): - raise NotImplementedError( - "Struct input type are not supported with Arrow optimization " - "enabled in Python UDFs. Disable " - "'spark.sql.execution.pythonUDF.arrow.enabled' to workaround." - ) - return pd.Series(result_func(f(*a)) for a in zip(*args)) - - # Regular UDFs can take callable instances too. - vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__ - vectorized_udf.__module__ = ( - f.__module__ if hasattr(f, "__module__") else f.__class__.__module__ - ) - vectorized_udf.__doc__ = f.__doc__ - pudf = _create_pandas_udf(vectorized_udf, returnType, None) - # Keep the attributes as if this is a regular Python UDF. - pudf.func = f - pudf.returnType = return_type - pudf.evalType = regular_udf.evalType - return pudf + return _create_arrow_py_udf(f, regular_udf) else: return regular_udf +def _create_arrow_py_udf(f, regular_udf): # type: ignore + print("entered _create_arrow_py_udf") + require_minimum_pandas_version() + require_minimum_pyarrow_version() + + import pandas as pd + from pyspark.sql.pandas.functions import _create_pandas_udf + + return_type = regular_udf.returnType + + # "result_func" ensures the result of a Python UDF to be consistent with/without Arrow + # optimization. + # Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a + # string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns + # successfully. + result_func = lambda pdf: pdf # noqa: E731 + if type(return_type) == StringType: + result_func = lambda r: str(r) if r is not None else r # noqa: E731 + elif type(return_type) == BinaryType: + result_func = lambda r: bytes(r) if r is not None else r # noqa: E731 + + def vectorized_udf(*args: pd.Series) -> pd.Series: + if any(map(lambda arg: isinstance(arg, pd.DataFrame), args)): + raise NotImplementedError( + "Struct input type are not supported with Arrow optimization " + "enabled in Python UDFs. Disable " + "'spark.sql.execution.pythonUDF.arrow.enabled' to workaround." + ) + return pd.Series(result_func(f(*a)) for a in zip(*args)) + + # Regular UDFs can take callable instances too. + vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__ + vectorized_udf.__module__ = f.__module__ if hasattr(f, "__module__") else f.__class__.__module__ + vectorized_udf.__doc__ = f.__doc__ + pudf = _create_pandas_udf(vectorized_udf, return_type, None) + # Keep the attributes as if this is a regular Python UDF. + pudf.func = f + pudf.returnType = return_type + pudf.evalType = regular_udf.evalType + return pudf + + class UserDefinedFunction: """ User defined function in Python From d702b67e651ca089280b1d1074b321619049fc45 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 10 Apr 2023 10:50:15 -0700 Subject: [PATCH 02/13] in Connect --- python/pyspark/sql/connect/functions.py | 12 ++++++-- python/pyspark/sql/connect/udf.py | 39 ++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 1c73bcdc84922..e7ecf267abc1c 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -50,7 +50,7 @@ LambdaFunction, UnresolvedNamedLambdaVariable, ) -from pyspark.sql.connect.udf import _create_udf +from pyspark.sql.connect.udf import _create_py_udf from pyspark.sql import functions as pysparkfuncs from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType, StringType @@ -2461,6 +2461,7 @@ def unwrap_udt(col: "ColumnOrName") -> Column: def udf( f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None, returnType: "DataTypeOrString" = StringType(), + useArrow: Optional[bool] = None, ) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]: from pyspark.rdd import PythonEvalType @@ -2469,10 +2470,15 @@ def udf( # for decorator use it as a returnType return_type = f or returnType return functools.partial( - _create_udf, returnType=return_type, evalType=PythonEvalType.SQL_BATCHED_UDF + _create_py_udf, + returnType=return_type, + evalType=PythonEvalType.SQL_BATCHED_UDF, + useArrow=useArrow, ) else: - return _create_udf(f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF) + return _create_py_udf( + f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, useArrow=useArrow + ) udf.__doc__ = pysparkfuncs.udf.__doc__ diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 9afc6e0e626a5..6983bae5bf35d 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -23,6 +23,7 @@ import sys import functools +from inspect import getfullargspec from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union from pyspark.rdd import PythonEvalType @@ -33,7 +34,7 @@ ) from pyspark.sql.connect.column import Column from pyspark.sql.connect.types import UnparsedDataType -from pyspark.sql.types import DataType, StringType +from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration @@ -47,6 +48,42 @@ from pyspark.sql.types import StringType +def _create_py_udf( + f: Callable[..., Any], + returnType: "DataTypeOrString", + evalType: int, + useArrow: Optional[bool] = None, +) -> "UserDefinedFunctionLike": + from pyspark.sql.udf import _create_arrow_py_udf + from pyspark.sql.connect.session import _active_spark_session + + if _active_spark_session is None: + is_arrow_enabled = False + else: + is_arrow_enabled = ( + _active_spark_session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled") == "true" + if useArrow is None + else useArrow + ) + + regular_udf = _create_udf(f, returnType, evalType) + return_type = regular_udf.returnType + try: + is_func_with_args = len(getfullargspec(f).args) > 0 + except TypeError: + is_func_with_args = False + is_output_atomic_type = ( + not isinstance(return_type, StructType) + and not isinstance(return_type, MapType) + and not isinstance(return_type, ArrayType) + ) + if is_arrow_enabled and is_output_atomic_type and is_func_with_args: + print("entering _create_arrow_py_udf") + return _create_arrow_py_udf(f, regular_udf) + else: + return regular_udf + + def _create_udf( f: Callable[..., Any], returnType: "DataTypeOrString", From 01f7190327cf47d31f762bbb745b9495e58bd961 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 10 Apr 2023 10:51:58 -0700 Subject: [PATCH 03/13] tests --- .../connect/test_parity_arrow_python_udf.py | 41 +++++++++++++++++++ .../sql/tests/test_arrow_python_udf.py | 14 ++++--- 2 files changed, 49 insertions(+), 6 deletions(-) create mode 100644 python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py new file mode 100644 index 0000000000000..a7c37bde39f83 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.sql.tests.connect.test_parity_udf import UDFParityTests +from pyspark.sql.tests.test_arrow_python_udf import PythonUDFArrowTestsMixin + + +class ArrowPythonUDFParityTests(UDFParityTests, PythonUDFArrowTestsMixin): + @classmethod + def setUpClass(cls): + super(ArrowPythonUDFParityTests, cls).setUpClass() + cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true") + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_arrow_python_udf import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py index 14d00633cc647..f34123da6a634 100644 --- a/python/pyspark/sql/tests/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/test_arrow_python_udf.py @@ -31,12 +31,7 @@ @unittest.skipIf( not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message ) -class PythonUDFArrowTests(BaseUDFTestsMixin, ReusedSQLTestCase): - @classmethod - def setUpClass(cls): - super(PythonUDFArrowTests, cls).setUpClass() - cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true") - +class PythonUDFArrowTestsMixin(BaseUDFTestsMixin): @unittest.skip("Unrelated test, and it fails when it runs duplicatedly.") def test_broadcast_in_udf(self): super(PythonUDFArrowTests, self).test_broadcast_in_udf() @@ -118,6 +113,13 @@ def test_use_arrow(self): self.assertEquals(row_false[0], "[1, 2, 3]") +class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase): + @classmethod + def setUpClass(cls): + super(PythonUDFArrowTests, cls).setUpClass() + cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true") + + if __name__ == "__main__": from pyspark.sql.tests.test_arrow_python_udf import * # noqa: F401 From 0fb7712ae93662172db19d403badb511ff8b6214 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 10 Apr 2023 11:04:16 -0700 Subject: [PATCH 04/13] - debug --- python/pyspark/sql/connect/udf.py | 1 - python/pyspark/sql/udf.py | 1 - 2 files changed, 2 deletions(-) diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 6983bae5bf35d..d1508ced7b175 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -78,7 +78,6 @@ def _create_py_udf( and not isinstance(return_type, ArrayType) ) if is_arrow_enabled and is_output_atomic_type and is_func_with_args: - print("entering _create_arrow_py_udf") return _create_arrow_py_udf(f, regular_udf) else: return regular_udf diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index fde2a7298e007..024aaca98ae39 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -145,7 +145,6 @@ def _create_py_udf( def _create_arrow_py_udf(f, regular_udf): # type: ignore - print("entered _create_arrow_py_udf") require_minimum_pandas_version() require_minimum_pyarrow_version() From f46d0062ab4974655e9cefd24f924430b943f589 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 10 Apr 2023 11:28:57 -0700 Subject: [PATCH 05/13] docstrings --- python/pyspark/sql/udf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 024aaca98ae39..8d69682a3a1bb 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -75,6 +75,7 @@ def _create_udf( name: Optional[str] = None, deterministic: bool = True, ) -> "UserDefinedFunctionLike": + """Create a regular(non-Arrow-optimized) Python UDF.""" # Set the name of the UserDefinedFunction object to be the name of function f udf_obj = UserDefinedFunction( f, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic @@ -88,6 +89,7 @@ def _create_py_udf( evalType: int, useArrow: Optional[bool] = None, ) -> "UserDefinedFunctionLike": + """Create a regular/Arrow-optimized Python UDF.""" # The following table shows the results when the type coercion in Arrow is needed, that is, # when the user-specified return type(SQL Type) of the UDF and the actual instance(Python # Value(Type)) that the UDF returns are different. @@ -145,6 +147,7 @@ def _create_py_udf( def _create_arrow_py_udf(f, regular_udf): # type: ignore + """Create an Arrow-optimized Python UDF out of a regular Python UDF.""" require_minimum_pandas_version() require_minimum_pyarrow_version() From 3abeef439548e9b23b48f1251a6ae9c94194b936 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 10 Apr 2023 16:10:00 -0700 Subject: [PATCH 06/13] TEST From f6fc6e172358aee33daffc57bd684a768fa91d82 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Tue, 11 Apr 2023 17:54:26 -0700 Subject: [PATCH 07/13] TEST From 63ef94e6855a9b9cbf7f20282f28cebaf06cd3c6 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Wed, 19 Apr 2023 14:45:56 -0700 Subject: [PATCH 08/13] rmv duplicate test --- python/pyspark/sql/tests/test_udf.py | 41 ---------------------------- 1 file changed, 41 deletions(-) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 26fe735c9c375..d8a464b006f66 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -838,47 +838,6 @@ def setUpClass(cls): cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false") -def test_use_arrow(self): - # useArrow=True - row_true = ( - self.spark.range(1) - .selectExpr( - "array(1, 2, 3) as array", - ) - .select( - udf(lambda x: str(x), useArrow=True)("array"), - ) - .first() - ) - # The input is a NumPy array when the Arrow optimization is on. - self.assertEquals(row_true[0], "[1 2 3]") - - # useArrow=None - row_none = ( - self.spark.range(1) - .selectExpr( - "array(1, 2, 3) as array", - ) - .select( - udf(lambda x: str(x), useArrow=None)("array"), - ) - .first() - ) - - # useArrow=False - row_false = ( - self.spark.range(1) - .selectExpr( - "array(1, 2, 3) as array", - ) - .select( - udf(lambda x: str(x), useArrow=False)("array"), - ) - .first() - ) - self.assertEquals(row_false[0], row_none[0]) # "[1, 2, 3]" - - class UDFInitializationTests(unittest.TestCase): def tearDown(self): if SparkSession._instantiatedSession is not None: From 86938d56dc2493c366b1088b6a623e324c69ebfe Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Thu, 20 Apr 2023 10:51:16 -0700 Subject: [PATCH 09/13] tearDownClass --- .../sql/tests/connect/test_parity_arrow_python_udf.py | 5 +++++ python/pyspark/sql/tests/test_arrow_python_udf.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py index a7c37bde39f83..ac4826781bc66 100644 --- a/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py +++ b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py @@ -27,6 +27,11 @@ def setUpClass(cls): super(ArrowPythonUDFParityTests, cls).setUpClass() cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true") + @classmethod + def tearDownClass(cls): + cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled") + super(ArrowPythonUDFParityTests, cls).tearDownClass() + if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py index f34123da6a634..b5b44168fd714 100644 --- a/python/pyspark/sql/tests/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/test_arrow_python_udf.py @@ -119,6 +119,11 @@ def setUpClass(cls): super(PythonUDFArrowTests, cls).setUpClass() cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true") + @classmethod + def tearDownClass(cls): + cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled") + super().tearDownClass() + if __name__ == "__main__": from pyspark.sql.tests.test_arrow_python_udf import * # noqa: F401 From f5aef182ef108f22f138a5c19690c4b10c98551d Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Thu, 20 Apr 2023 10:54:20 -0700 Subject: [PATCH 10/13] rmv f from _create_arrow_py_udf --- python/pyspark/sql/connect/udf.py | 2 +- python/pyspark/sql/udf.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index d1508ced7b175..09d5ff3083e9e 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -78,7 +78,7 @@ def _create_py_udf( and not isinstance(return_type, ArrayType) ) if is_arrow_enabled and is_output_atomic_type and is_func_with_args: - return _create_arrow_py_udf(f, regular_udf) + return _create_arrow_py_udf(regular_udf) else: return regular_udf diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 8d69682a3a1bb..be667a4230513 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -141,12 +141,12 @@ def _create_py_udf( and not isinstance(return_type, ArrayType) ) if is_arrow_enabled and is_output_atomic_type and is_func_with_args: - return _create_arrow_py_udf(f, regular_udf) + return _create_arrow_py_udf(regular_udf) else: return regular_udf -def _create_arrow_py_udf(f, regular_udf): # type: ignore +def _create_arrow_py_udf(regular_udf): # type: ignore """Create an Arrow-optimized Python UDF out of a regular Python UDF.""" require_minimum_pandas_version() require_minimum_pyarrow_version() @@ -154,6 +154,7 @@ def _create_arrow_py_udf(f, regular_udf): # type: ignore import pandas as pd from pyspark.sql.pandas.functions import _create_pandas_udf + f = regular_udf.func return_type = regular_udf.returnType # "result_func" ensures the result of a Python UDF to be consistent with/without Arrow From f313063850675c1fb779cb129da29c9cfbd30df8 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Thu, 20 Apr 2023 11:18:28 -0700 Subject: [PATCH 11/13] UserWarning --- python/pyspark/sql/connect/udf.py | 11 +++++++++-- python/pyspark/sql/udf.py | 10 ++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 09d5ff3083e9e..0a3750a06abd9 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -23,6 +23,7 @@ import sys import functools +import warnings from inspect import getfullargspec from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union @@ -77,8 +78,14 @@ def _create_py_udf( and not isinstance(return_type, MapType) and not isinstance(return_type, ArrayType) ) - if is_arrow_enabled and is_output_atomic_type and is_func_with_args: - return _create_arrow_py_udf(regular_udf) + if is_arrow_enabled: + if is_output_atomic_type and is_func_with_args: + return _create_arrow_py_udf(regular_udf) + else: + warnings.warn( + "Arrow optimization for Python UDFs cannot be enabled.", + UserWarning, + ) else: return regular_udf diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index be667a4230513..036acfb4c4290 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -140,8 +140,14 @@ def _create_py_udf( and not isinstance(return_type, MapType) and not isinstance(return_type, ArrayType) ) - if is_arrow_enabled and is_output_atomic_type and is_func_with_args: - return _create_arrow_py_udf(regular_udf) + if is_arrow_enabled: + if is_output_atomic_type and is_func_with_args: + return _create_arrow_py_udf(regular_udf) + else: + warnings.warn( + "Arrow optimization for Python UDFs cannot be enabled.", + UserWarning, + ) else: return regular_udf From 5e786328817b550b49ff3aac13771917567ba641 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Thu, 20 Apr 2023 11:22:16 -0700 Subject: [PATCH 12/13] finally super tearDownClass --- .../sql/tests/connect/test_parity_arrow_python_udf.py | 6 ++++-- python/pyspark/sql/tests/test_arrow_python_udf.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py index ac4826781bc66..e4a64a7d5913e 100644 --- a/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py +++ b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py @@ -29,8 +29,10 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled") - super(ArrowPythonUDFParityTests, cls).tearDownClass() + try: + cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled") + finally: + super(ArrowPythonUDFParityTests, cls).tearDownClass() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py index b5b44168fd714..681c42c6a5cd8 100644 --- a/python/pyspark/sql/tests/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/test_arrow_python_udf.py @@ -121,8 +121,10 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled") - super().tearDownClass() + try: + cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled") + finally: + super(PythonUDFArrowTests, cls).tearDownClass() if __name__ == "__main__": From ac86bf1f252bfc27f0d097ee6eb0cd6ac77663b3 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Thu, 20 Apr 2023 13:16:13 -0700 Subject: [PATCH 13/13] fallback to regular udf --- python/pyspark/sql/connect/udf.py | 1 + python/pyspark/sql/udf.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 0a3750a06abd9..aab7bb3c0d3f8 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -86,6 +86,7 @@ def _create_py_udf( "Arrow optimization for Python UDFs cannot be enabled.", UserWarning, ) + return regular_udf else: return regular_udf diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 036acfb4c4290..c486d869cba96 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -148,6 +148,7 @@ def _create_py_udf( "Arrow optimization for Python UDFs cannot be enabled.", UserWarning, ) + return regular_udf else: return regular_udf