From 4784085c7f9d2a4680ca0b56fb0aa8949091a7d0 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Wed, 10 May 2023 13:27:51 -0700 Subject: [PATCH 1/5] fix --- python/pyspark/sql/connect/udf.py | 26 +++++++++++++++++++------- python/pyspark/sql/udf.py | 16 +++++++++++----- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 012c6c0d2d503..dde904b76c552 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -36,7 +36,7 @@ from pyspark.sql.connect.column import Column from pyspark.sql.connect.types import UnparsedDataType from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType -from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration +from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration, _create_arrow_py_udf from pyspark.errors import PySparkTypeError @@ -55,7 +55,6 @@ def _create_py_udf( returnType: "DataTypeOrString", useArrow: Optional[bool] = None, ) -> "UserDefinedFunctionLike": - from pyspark.sql.udf import _create_arrow_py_udf from pyspark.sql.connect.session import _active_spark_session if _active_spark_session is None: @@ -252,6 +251,7 @@ def register( f = cast("UserDefinedFunctionLike", f) if f.evalType not in [ PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, @@ -259,14 +259,26 @@ def register( raise PySparkTypeError( error_class="INVALID_UDF_EVAL_TYPE", message_parameters={ - "eval_type": "SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, " - "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF" + "eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, " + "SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or " + "SQL_GROUPED_AGG_PANDAS_UDF" }, ) - return_udf = f + source_udf = _create_udf( + f.func, + returnType=f.returnType, + name=name, + evalType=f.evalType, + deterministic=f.deterministic, + ) + if f.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF: + register_udf = _create_arrow_py_udf(source_udf)._unwrapped + else: + register_udf = source_udf._unwrapped self.sparkSession._client.register_udf( - f.func, f.returnType, name, f.evalType, f.deterministic + register_udf.func, f.returnType, name, f.evalType, f.deterministic ) + return_udf = f else: if returnType is None: returnType = StringType() @@ -274,7 +286,7 @@ def register( f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, name=name ) - self.sparkSession._client.register_udf(f, returnType, name) + self.sparkSession._client.register_udf(return_udf.func, returnType, name) return return_udf diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 374e8c1bcbbe5..32ea80d0fe079 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -623,6 +623,7 @@ def register( f = cast("UserDefinedFunctionLike", f) if f.evalType not in [ PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, @@ -630,18 +631,23 @@ def register( raise PySparkTypeError( error_class="INVALID_UDF_EVAL_TYPE", message_parameters={ - "eval_type": "SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, " - "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF" + "eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, " + "SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or " + "SQL_GROUPED_AGG_PANDAS_UDF" }, ) - register_udf = _create_udf( + source_udf = _create_udf( f.func, returnType=f.returnType, name=name, evalType=f.evalType, deterministic=f.deterministic, - )._unwrapped # type: ignore[attr-defined] - return_udf = f + ) + if f.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF: + register_udf = _create_arrow_py_udf(source_udf)._unwrapped + else: + register_udf = source_udf._unwrapped # type: ignore[attr-defined] + return_udf = register_udf else: if returnType is None: returnType = StringType() From 9e1b53494beab69e9381139c341a36bebbba78a4 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Wed, 10 May 2023 14:02:31 -0700 Subject: [PATCH 2/5] test --- .../pyspark/sql/tests/test_arrow_python_udf.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py index 51112beadec01..3266168f29071 100644 --- a/python/pyspark/sql/tests/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/test_arrow_python_udf.py @@ -119,6 +119,24 @@ def test_eval_type(self): udf(lambda x: str(x), useArrow=False).evalType, PythonEvalType.SQL_BATCHED_UDF ) + def test_register(self): + df = self.spark.range(1).selectExpr( + "array(1, 2, 3) as array", + ) + str_repr_func = self.spark.udf.register("str_repr", udf(lambda x: str(x), useArrow=True)) + + # To verify that Arrow optimization is on + self.assertEquals( + df.selectExpr("str_repr(array) AS str_id").first()[0], + "[1 2 3]", # The input is a NumPy array when the Arrow optimization is on + ) + + # To verify that a UserDefinedFunction is returned + self.assertListEqual( + df.selectExpr("str_repr(array) AS str_id").collect(), + df.select(str_repr_func("array").alias("str_id")).collect(), + ) + class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase): @classmethod From 4db78a94c65a781777ba1052b63cb0acbcc24d97 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Thu, 11 May 2023 11:31:44 -0700 Subject: [PATCH 3/5] restore Connect --- python/pyspark/sql/connect/udf.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index dde904b76c552..ca6448e666864 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -264,31 +264,19 @@ def register( "SQL_GROUPED_AGG_PANDAS_UDF" }, ) - source_udf = _create_udf( - f.func, - returnType=f.returnType, - name=name, - evalType=f.evalType, - deterministic=f.deterministic, - ) - if f.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF: - register_udf = _create_arrow_py_udf(source_udf)._unwrapped - else: - register_udf = source_udf._unwrapped self.sparkSession._client.register_udf( - register_udf.func, f.returnType, name, f.evalType, f.deterministic + f.func, f.returnType, name, f.evalType, f.deterministic ) - return_udf = f + return f else: if returnType is None: returnType = StringType() - return_udf = _create_udf( + py_udf = _create_udf( f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, name=name ) - self.sparkSession._client.register_udf(return_udf.func, returnType, name) - - return return_udf + self.sparkSession._client.register_udf(py_udf.func, returnType, name) + return py_udf register.__doc__ = PySparkUDFRegistration.register.__doc__ From d0cd0e5b08f478110301c38ccaf41c303e827352 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Thu, 11 May 2023 11:37:14 -0700 Subject: [PATCH 4/5] restore Connect --- python/pyspark/sql/connect/udf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index ca6448e666864..bc3a17dca3c5f 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -36,7 +36,7 @@ from pyspark.sql.connect.column import Column from pyspark.sql.connect.types import UnparsedDataType from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType -from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration, _create_arrow_py_udf +from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration from pyspark.errors import PySparkTypeError @@ -55,6 +55,7 @@ def _create_py_udf( returnType: "DataTypeOrString", useArrow: Optional[bool] = None, ) -> "UserDefinedFunctionLike": + from pyspark.sql.udf import _create_arrow_py_udf from pyspark.sql.connect.session import _active_spark_session if _active_spark_session is None: From 23e6cba9ffb7e56bdcc624eeb174086ae01c101d Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Fri, 12 May 2023 11:04:12 -0700 Subject: [PATCH 5/5] fix --- python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py | 2 +- python/pyspark/sql/udf.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py index 0a2f81c6e5ba9..8464e5f083c1c 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py @@ -219,7 +219,7 @@ def test_register_grouped_map_udf(self): exception=pe.exception, error_class="INVALID_UDF_EVAL_TYPE", message_parameters={ - "eval_type": "SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, " + "eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, " "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF" }, ) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 32ea80d0fe079..458281872950e 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -654,7 +654,7 @@ def register( return_udf = _create_udf( f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, name=name ) - register_udf = return_udf._unwrapped # type: ignore[attr-defined] + register_udf = return_udf._unwrapped self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf) return return_udf