diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 012c6c0d2d503..bc3a17dca3c5f 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -252,6 +252,7 @@ def register( f = cast("UserDefinedFunctionLike", f) if f.evalType not in [ PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, @@ -259,24 +260,24 @@ def register( raise PySparkTypeError( error_class="INVALID_UDF_EVAL_TYPE", message_parameters={ - "eval_type": "SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, " - "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF" + "eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, " + "SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or " + "SQL_GROUPED_AGG_PANDAS_UDF" }, ) - return_udf = f self.sparkSession._client.register_udf( f.func, f.returnType, name, f.evalType, f.deterministic ) + return f else: if returnType is None: returnType = StringType() - return_udf = _create_udf( + py_udf = _create_udf( f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, name=name ) - self.sparkSession._client.register_udf(f, returnType, name) - - return return_udf + self.sparkSession._client.register_udf(py_udf.func, returnType, name) + return py_udf register.__doc__ = PySparkUDFRegistration.register.__doc__ diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py index 0a2f81c6e5ba9..8464e5f083c1c 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py @@ -219,7 +219,7 @@ def test_register_grouped_map_udf(self): exception=pe.exception, error_class="INVALID_UDF_EVAL_TYPE", message_parameters={ - "eval_type": "SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, " + "eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, 
" "SQL_SCALAR_PANDAS_UDF, " "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF" }, ) diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py index 51112beadec01..3266168f29071 100644 --- a/python/pyspark/sql/tests/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/test_arrow_python_udf.py @@ -119,6 +119,24 @@ def test_eval_type(self): udf(lambda x: str(x), useArrow=False).evalType, PythonEvalType.SQL_BATCHED_UDF ) + def test_register(self): + df = self.spark.range(1).selectExpr( + "array(1, 2, 3) as array", + ) + str_repr_func = self.spark.udf.register("str_repr", udf(lambda x: str(x), useArrow=True)) + + # To verify that Arrow optimization is on + self.assertEquals( + df.selectExpr("str_repr(array) AS str_id").first()[0], + "[1 2 3]", # The input is a NumPy array when the Arrow optimization is on + ) + + # To verify that a UserDefinedFunction is returned + self.assertListEqual( + df.selectExpr("str_repr(array) AS str_id").collect(), + df.select(str_repr_func("array").alias("str_id")).collect(), + ) + class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 374e8c1bcbbe5..458281872950e 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -623,6 +623,7 @@ def register( f = cast("UserDefinedFunctionLike", f) if f.evalType not in [ PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, @@ -630,25 +631,30 @@ def register( raise PySparkTypeError( error_class="INVALID_UDF_EVAL_TYPE", message_parameters={ - "eval_type": "SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, " - "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF" + "eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, " "SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or " + 
"SQL_GROUPED_AGG_PANDAS_UDF" }, ) - register_udf = _create_udf( + source_udf = _create_udf( f.func, returnType=f.returnType, name=name, evalType=f.evalType, deterministic=f.deterministic, - )._unwrapped # type: ignore[attr-defined] - return_udf = f + ) + if f.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF: + register_udf = _create_arrow_py_udf(source_udf)._unwrapped + else: + register_udf = source_udf._unwrapped # type: ignore[attr-defined] + return_udf = register_udf else: if returnType is None: returnType = StringType() return_udf = _create_udf( f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, name=name ) - register_udf = return_udf._unwrapped # type: ignore[attr-defined] + register_udf = return_udf._unwrapped self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf) return return_udf