From 5526efc206e7f98f73a141151477a5022efff05e Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Fri, 18 Aug 2023 13:17:09 -0700
Subject: [PATCH 1/2] Enable and fix test_parity_arrow_python_udf.

---
 dev/sparktestsupport/modules.py   |  1 +
 python/pyspark/sql/connect/udf.py | 20 +++++------
 python/pyspark/sql/udf.py         | 60 ++++++-------------------------
 python/pyspark/worker.py          | 54 ++++++++++++++++++++++++++--
 4 files changed, 73 insertions(+), 62 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index c5be1957a7dcb..bca700186258a 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -840,6 +840,7 @@ def __hash__(self):
         "pyspark.sql.tests.connect.test_connect_function",
         "pyspark.sql.tests.connect.test_connect_column",
         "pyspark.sql.tests.connect.test_parity_arrow",
+        "pyspark.sql.tests.connect.test_parity_arrow_python_udf",
         "pyspark.sql.tests.connect.test_parity_datasources",
         "pyspark.sql.tests.connect.test_parity_errors",
         "pyspark.sql.tests.connect.test_parity_catalog",
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index eb0541b936925..1737d6ec5eb18 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -54,8 +54,6 @@ def _create_py_udf(
     returnType: "DataTypeOrString",
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedFunctionLike":
-    from pyspark.sql.udf import _create_arrow_py_udf
-
     if useArrow is None:
         is_arrow_enabled = False
         try:
@@ -74,22 +72,22 @@ def _create_py_udf(
     else:
         is_arrow_enabled = useArrow
 
-    regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
-    try:
-        is_func_with_args = len(getfullargspec(f).args) > 0
-    except TypeError:
-        is_func_with_args = False
+    eva_type: int = PythonEvalType.SQL_BATCHED_UDF
+
     if is_arrow_enabled:
+        try:
+            is_func_with_args = len(getfullargspec(f).args) > 0
+        except TypeError:
+            is_func_with_args = False
         if is_func_with_args:
-            return _create_arrow_py_udf(regular_udf)
+            eva_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
         else:
             warnings.warn(
                 "Arrow optimization for Python UDFs cannot be enabled.",
                 UserWarning,
             )
-            return regular_udf
-    else:
-        return regular_udf
+
+    return _create_udf(f, returnType, eva_type)
 
 
 def _create_udf(
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index f25f525e33be3..7d7784dd5226d 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -32,7 +32,6 @@
 from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
 from pyspark.sql.column import Column, _to_java_column, _to_java_expr, _to_seq
 from pyspark.sql.types import (
-    BinaryType,
     DataType,
     StringType,
     StructType,
@@ -131,58 +130,24 @@ def _create_py_udf(
     else:
         is_arrow_enabled = useArrow
 
-    regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
-    try:
-        is_func_with_args = len(getfullargspec(f).args) > 0
-    except TypeError:
-        is_func_with_args = False
+    eval_type: int = PythonEvalType.SQL_BATCHED_UDF
+
     if is_arrow_enabled:
+        try:
+            is_func_with_args = len(getfullargspec(f).args) > 0
+        except TypeError:
+            is_func_with_args = False
         if is_func_with_args:
-            return _create_arrow_py_udf(regular_udf)
+            require_minimum_pandas_version()
+            require_minimum_pyarrow_version()
+            eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
         else:
             warnings.warn(
                 "Arrow optimization for Python UDFs cannot be enabled.",
                 UserWarning,
             )
-            return regular_udf
-    else:
-        return regular_udf
-
-
-def _create_arrow_py_udf(regular_udf):  # type: ignore
-    """Create an Arrow-optimized Python UDF out of a regular Python UDF."""
-    require_minimum_pandas_version()
-    require_minimum_pyarrow_version()
-
-    import pandas as pd
-    from pyspark.sql.pandas.functions import _create_pandas_udf
-
-    f = regular_udf.func
-    return_type = regular_udf.returnType
-
-    # "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
-    # optimization.
-    # Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a
-    # string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns
-    # successfully.
-    result_func = lambda pdf: pdf  # noqa: E731
-    if type(return_type) == StringType:
-        result_func = lambda r: str(r) if r is not None else r  # noqa: E731
-    elif type(return_type) == BinaryType:
-        result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
-
-    def vectorized_udf(*args: pd.Series) -> pd.Series:
-        return pd.Series(result_func(f(*a)) for a in zip(*args))
-
-    # Regular UDFs can take callable instances too.
-    vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__
-    vectorized_udf.__module__ = f.__module__ if hasattr(f, "__module__") else f.__class__.__module__
-    vectorized_udf.__doc__ = f.__doc__
-    pudf = _create_pandas_udf(vectorized_udf, return_type, PythonEvalType.SQL_ARROW_BATCHED_UDF)
-    # Keep the attributes as if this is a regular Python UDF.
-    pudf.func = f
-    pudf.returnType = return_type
-    return pudf
+
+    return _create_udf(f, returnType, eval_type)
 
 
 class UserDefinedFunction:
@@ -637,10 +602,7 @@ def register(
                 evalType=f.evalType,
                 deterministic=f.deterministic,
             )
-            if f.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF:
-                register_udf = _create_arrow_py_udf(source_udf)._unwrapped
-            else:
-                register_udf = source_udf._unwrapped  # type: ignore[attr-defined]
+            register_udf = source_udf._unwrapped  # type: ignore[attr-defined]
             return_udf = register_udf
         else:
             if returnType is None:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 63c286e7fb04a..d2b41fef70853 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -53,7 +53,7 @@
     ApplyInPandasWithStateSerializer,
 )
 from pyspark.sql.pandas.types import to_arrow_type
-from pyspark.sql.types import StructType, _parse_datatype_json_string
+from pyspark.sql.types import BinaryType, StringType, StructType, _parse_datatype_json_string
 from pyspark.util import fail_on_stopiteration, try_simplify_traceback
 from pyspark import shuffle
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError
@@ -121,6 +121,54 @@ def verify_result_length(result, length):
         )
 
 
+def wrap_arrow_batch_udf(f, return_type):
+    import pandas as pd
+
+    arrow_return_type = to_arrow_type(return_type)
+
+    # "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
+    # optimization.
+    # Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a
+    # string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns
+    # successfully.
+    result_func = lambda pdf: pdf  # noqa: E731
+    if type(return_type) == StringType:
+        result_func = lambda r: str(r) if r is not None else r  # noqa: E731
+    elif type(return_type) == BinaryType:
+        result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
+
+    def evaluate(*args: pd.Series) -> pd.Series:
+        return pd.Series(result_func(f(*a)) for a in zip(*args))
+
+    def verify_result_type(result):
+        if not hasattr(result, "__len__"):
+            pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
+            raise PySparkTypeError(
+                error_class="UDF_RETURN_TYPE",
+                message_parameters={
+                    "expected": pd_type,
+                    "actual": type(result).__name__,
+                },
+            )
+        return result
+
+    def verify_result_length(result, length):
+        if len(result) != length:
+            raise PySparkRuntimeError(
+                error_class="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
+                message_parameters={
+                    "expected": str(length),
+                    "actual": str(len(result)),
+                },
+            )
+        return result
+
+    return lambda *a: (
+        verify_result_length(verify_result_type(evaluate(*a)), len(a[0])),
+        arrow_return_type,
+    )
+
+
 def wrap_pandas_batch_iter_udf(f, return_type):
     arrow_return_type = to_arrow_type(return_type)
     iter_type_label = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
@@ -486,8 +534,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
         func = fail_on_stopiteration(chained_func)
 
     # the last returnType will be the return type of UDF
-    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF):
+    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
         return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
+    elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
+        return arg_offsets, wrap_arrow_batch_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
        return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
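
To make the worker-side change concrete, the following is a minimal, self-contained sketch of the coercion that wrap_arrow_batch_udf applies before results are handed to Arrow. The helper name coerced_series is invented for illustration and is not part of the patch; it only needs pandas and the pyspark type classes, not a running Spark cluster. For StringType and BinaryType return types, non-None results are passed through str()/bytes(), so an Arrow-optimized Python UDF accepts the same loosely typed results a pickled UDF does.

import pandas as pd
from pyspark.sql.types import BinaryType, DataType, StringType


def coerced_series(f, return_type: DataType, *cols: pd.Series) -> pd.Series:
    # Mirrors the result_func selection in wrap_arrow_batch_udf: identity by
    # default, str()/bytes() coercion for string/binary return types.
    result_func = lambda r: r  # noqa: E731
    if type(return_type) == StringType:
        result_func = lambda r: str(r) if r is not None else r  # noqa: E731
    elif type(return_type) == BinaryType:
        result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
    # Row-at-a-time evaluation over the batch, like "evaluate" in the patch.
    return pd.Series(result_func(f(*a)) for a in zip(*cols))


# A UDF declared to return StringType but actually returning int: without the
# coercion, pyarrow would reject the result batch with "ArrowTypeError:
# Expected a string or bytes dtype, got int64".
print(coerced_series(lambda x: x * 2, StringType(), pd.Series([1, 2, 3])))
# 0    2
# 1    4
# 2    6
# dtype: object
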
From 5e2e3419b72049d7f3ddf2f4acc38494709cbf14 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Fri, 18 Aug 2023 14:21:50 -0700
Subject: [PATCH 2/2] Fix.

---
 python/pyspark/sql/connect/udf.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 1737d6ec5eb18..2636777e5f6fb 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -72,7 +72,7 @@ def _create_py_udf(
     else:
         is_arrow_enabled = useArrow
 
-    eva_type: int = PythonEvalType.SQL_BATCHED_UDF
+    eval_type: int = PythonEvalType.SQL_BATCHED_UDF
 
     if is_arrow_enabled:
         try:
@@ -80,14 +80,14 @@ def _create_py_udf(
         except TypeError:
             is_func_with_args = False
         if is_func_with_args:
-            eva_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
+            eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
         else:
             warnings.warn(
                 "Arrow optimization for Python UDFs cannot be enabled.",
                 UserWarning,
             )
 
-    return _create_udf(f, returnType, eva_type)
+    return _create_udf(f, returnType, eval_type)
 
 
 def _create_udf(
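
As a usage sketch of how the two patches surface to users: this assumes a Spark build containing these commits with pandas and pyarrow installed, and an active SparkSession bound to spark; the UDF name double_str is illustrative.

from pyspark.sql.functions import udf

# With useArrow=True, _create_py_udf now tags the function as
# SQL_ARROW_BATCHED_UDF up front, and the worker dispatches it to
# wrap_arrow_batch_udf instead of the client wrapping it in a pandas UDF.
double_str = udf(lambda x: x * 2, returnType="string", useArrow=True)

spark.range(3).select(double_str("id")).show()
# The str() coercion keeps results consistent with the non-Arrow path:
# the rows are "0", "2", "4" rather than raising pyarrow.lib.ArrowTypeError.

# register() no longer special-cases SQL_ARROW_BATCHED_UDF; the Arrow-tagged
# UDF unwraps directly and can be called from SQL as-is.
spark.udf.register("double_str", double_str)
spark.sql("SELECT double_str(id) FROM range(3)").show()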