1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -840,6 +840,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_connect_function",
"pyspark.sql.tests.connect.test_connect_column",
"pyspark.sql.tests.connect.test_parity_arrow",
"pyspark.sql.tests.connect.test_parity_arrow_python_udf",
"pyspark.sql.tests.connect.test_parity_datasources",
"pyspark.sql.tests.connect.test_parity_errors",
"pyspark.sql.tests.connect.test_parity_catalog",
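To run just the newly registered parity suite locally, something along these lines should work from a Spark checkout (this follows the usual PySpark test-runner pattern; adjust the invocation to your branch and environment):

python/run-tests --testnames 'pyspark.sql.tests.connect.test_parity_arrow_python_udf'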
20 changes: 9 additions & 11 deletions python/pyspark/sql/connect/udf.py
@@ -54,8 +54,6 @@ def _create_py_udf(
returnType: "DataTypeOrString",
useArrow: Optional[bool] = None,
) -> "UserDefinedFunctionLike":
from pyspark.sql.udf import _create_arrow_py_udf

if useArrow is None:
is_arrow_enabled = False
try:
@@ -74,22 +72,22 @@ def _create_py_udf(
else:
is_arrow_enabled = useArrow

regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
try:
is_func_with_args = len(getfullargspec(f).args) > 0
except TypeError:
is_func_with_args = False
eval_type: int = PythonEvalType.SQL_BATCHED_UDF

if is_arrow_enabled:
try:
is_func_with_args = len(getfullargspec(f).args) > 0
except TypeError:
is_func_with_args = False
if is_func_with_args:
return _create_arrow_py_udf(regular_udf)
eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
else:
warnings.warn(
"Arrow optimization for Python UDFs cannot be enabled.",
UserWarning,
)
return regular_udf
else:
return regular_udf

return _create_udf(f, returnType, eval_type)


def _create_udf(
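With this change the Connect-side _create_py_udf no longer wraps the function itself; it only picks SQL_BATCHED_UDF or SQL_ARROW_BATCHED_UDF and leaves the wrapping to the worker. A minimal usage sketch, assuming a running session named spark and that functions.udf exposes the useArrow flag shown in this diff (column names are illustrative):

from pyspark.sql.functions import udf

# Requesting Arrow explicitly; the UDF is still built by _create_udf,
# just with eval_type = SQL_ARROW_BATCHED_UDF when the function takes arguments.
@udf(returnType="string", useArrow=True)
def to_upper(s):
    return s.upper() if s is not None else None

# With useArrow=None, the choice falls back to the session config
# spark.sql.execution.pythonUDF.arrow.enabled.
df = spark.range(3).selectExpr("string(id) AS s")
df.select(to_upper("s")).show()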
60 changes: 11 additions & 49 deletions python/pyspark/sql/udf.py
@@ -32,7 +32,6 @@
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
from pyspark.sql.column import Column, _to_java_column, _to_java_expr, _to_seq
from pyspark.sql.types import (
BinaryType,
DataType,
StringType,
StructType,
@@ -131,58 +130,24 @@ def _create_py_udf(
else:
is_arrow_enabled = useArrow

regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
try:
is_func_with_args = len(getfullargspec(f).args) > 0
except TypeError:
is_func_with_args = False
eval_type: int = PythonEvalType.SQL_BATCHED_UDF

if is_arrow_enabled:
try:
is_func_with_args = len(getfullargspec(f).args) > 0
except TypeError:
is_func_with_args = False
if is_func_with_args:
return _create_arrow_py_udf(regular_udf)
require_minimum_pandas_version()
require_minimum_pyarrow_version()
eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
else:
warnings.warn(
"Arrow optimization for Python UDFs cannot be enabled.",
UserWarning,
)
return regular_udf
else:
return regular_udf


def _create_arrow_py_udf(regular_udf): # type: ignore
"""Create an Arrow-optimized Python UDF out of a regular Python UDF."""
require_minimum_pandas_version()
require_minimum_pyarrow_version()

import pandas as pd
from pyspark.sql.pandas.functions import _create_pandas_udf

f = regular_udf.func
return_type = regular_udf.returnType

# "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
# optimization.
# Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a
# string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns
# successfully.
result_func = lambda pdf: pdf # noqa: E731
if type(return_type) == StringType:
result_func = lambda r: str(r) if r is not None else r # noqa: E731
elif type(return_type) == BinaryType:
result_func = lambda r: bytes(r) if r is not None else r # noqa: E731

def vectorized_udf(*args: pd.Series) -> pd.Series:
return pd.Series(result_func(f(*a)) for a in zip(*args))

# Regular UDFs can take callable instances too.
vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__
vectorized_udf.__module__ = f.__module__ if hasattr(f, "__module__") else f.__class__.__module__
vectorized_udf.__doc__ = f.__doc__
pudf = _create_pandas_udf(vectorized_udf, return_type, PythonEvalType.SQL_ARROW_BATCHED_UDF)
# Keep the attributes as if this is a regular Python UDF.
pudf.func = f
pudf.returnType = return_type
return pudf
return _create_udf(f, returnType, eval_type)


class UserDefinedFunction:
@@ -637,10 +602,7 @@ def register(
evalType=f.evalType,
deterministic=f.deterministic,
)
if f.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF:
register_udf = _create_arrow_py_udf(source_udf)._unwrapped
else:
register_udf = source_udf._unwrapped # type: ignore[attr-defined]
register_udf = source_udf._unwrapped # type: ignore[attr-defined]
return_udf = register_udf
else:
if returnType is None:
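Because the Arrow wrapping now happens in the worker, register no longer needs the SQL_ARROW_BATCHED_UDF special case: the UDF is registered as-is and keeps its eval type. A short sketch of the user-facing effect, under the same assumptions as the example above:

from pyspark.sql.functions import udf

@udf(returnType="string", useArrow=True)
def add_suffix(s):
    return (s + "_x") if s is not None else None

# The registered UDF keeps evalType == SQL_ARROW_BATCHED_UDF; the worker
# applies wrap_arrow_batch_udf when the query actually runs.
spark.udf.register("add_suffix", add_suffix)
spark.sql("SELECT add_suffix(string(id)) FROM range(3)").show()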
54 changes: 52 additions & 2 deletions python/pyspark/worker.py
@@ -53,7 +53,7 @@
ApplyInPandasWithStateSerializer,
)
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import StructType, _parse_datatype_json_string
from pyspark.sql.types import BinaryType, StringType, StructType, _parse_datatype_json_string
from pyspark.util import fail_on_stopiteration, try_simplify_traceback
from pyspark import shuffle
from pyspark.errors import PySparkRuntimeError, PySparkTypeError
@@ -121,6 +121,54 @@ def verify_result_length(result, length):
)


def wrap_arrow_batch_udf(f, return_type):
import pandas as pd

arrow_return_type = to_arrow_type(return_type)

# "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
# optimization.
# Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a
# string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns
# successfully.
result_func = lambda pdf: pdf # noqa: E731
if type(return_type) == StringType:
result_func = lambda r: str(r) if r is not None else r # noqa: E731
elif type(return_type) == BinaryType:
result_func = lambda r: bytes(r) if r is not None else r # noqa: E731

def evaluate(*args: pd.Series) -> pd.Series:
return pd.Series(result_func(f(*a)) for a in zip(*args))

def verify_result_type(result):
if not hasattr(result, "__len__"):
pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
raise PySparkTypeError(
error_class="UDF_RETURN_TYPE",
message_parameters={
"expected": pd_type,
"actual": type(result).__name__,
},
)
return result

def verify_result_length(result, length):
if len(result) != length:
raise PySparkRuntimeError(
error_class="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
message_parameters={
"expected": str(length),
"actual": str(len(result)),
},
)
return result

return lambda *a: (
verify_result_length(verify_result_type(evaluate(*a)), len(a[0])),
arrow_return_type,
)


def wrap_pandas_batch_iter_udf(f, return_type):
arrow_return_type = to_arrow_type(return_type)
iter_type_label = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
@@ -486,8 +534,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
func = fail_on_stopiteration(chained_func)

# the last returnType will be the return type of UDF
if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF):
if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
return arg_offsets, wrap_arrow_batch_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
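The result_func coercion above is what keeps Arrow and non-Arrow results consistent: for StringType and BinaryType return types the worker casts each row's result before building the pandas.Series, so a UDF that returns, say, ints under a declared string type no longer fails with pyarrow's "Expected a string or bytes dtype" error. A rough standalone sketch of that behavior (not the worker's actual entry point):

import pandas as pd

# Hypothetical row-wise UDF that returns ints although the declared return type is string.
f = lambda x: int(x) * 10

# Mirrors the StringType branch of result_func above.
result_func = lambda r: str(r) if r is not None else r

args = (pd.Series([1, 2, 3]),)
out = pd.Series(result_func(f(*a)) for a in zip(*args))
print(out.tolist())  # ['10', '20', '30']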