diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 3688a149443c..d417303bb147 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -36,6 +36,7 @@ private[spark] object PythonEvalType { val NON_UDF = 0 val SQL_BATCHED_UDF = 1 val SQL_PANDAS_UDF = 2 + val SQL_PANDAS_GROUPED_UDF = 3 } /** diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ad18bd0c81ea..a0adeed99445 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -86,6 +86,7 @@ class PythonEvalType(object): NON_UDF = 0 SQL_BATCHED_UDF = 1 SQL_PANDAS_UDF = 2 + SQL_PANDAS_GROUPED_UDF = 3 class Serializer(object): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9bc12c3b7a16..a5f36838bbc9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -28,7 +28,7 @@ from pyspark import since, SparkContext from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.sql.types import StringType, DataType, _parse_datatype_string +from pyspark.sql.types import StringType, StructType, DataType, _parse_datatype_string from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame @@ -2038,13 +2038,22 @@ def _wrap_function(sc, func, returnType): sc.pythonVer, broadcast_vars, sc._javaAccumulator) +class PythonUdfType(object): + # row-based UDFs + NORMAL_UDF = 0 + # scalar vectorized UDFs + PANDAS_UDF = 1 + # grouped vectorized UDFs + PANDAS_GROUPED_UDF = 2 + + class UserDefinedFunction(object): """ User defined function in Python .. versionadded:: 1.3 """ - def __init__(self, func, returnType, name=None, vectorized=False): + def __init__(self, func, returnType, name=None, pythonUdfType=PythonUdfType.NORMAL_UDF): if not callable(func): raise TypeError( "Not a function or callable (__call__ is not defined): " @@ -2058,7 +2067,7 @@ def __init__(self, func, returnType, name=None, vectorized=False): self._name = name or ( func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) - self.vectorized = vectorized + self.pythonUdfType = pythonUdfType @property def returnType(self): @@ -2090,7 +2099,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.vectorized) + self._name, wrapped_func, jdt, self.pythonUdfType) return judf def __call__(self, *cols): @@ -2121,15 +2130,15 @@ def wrapper(*args): wrapper.func = self.func wrapper.returnType = self.returnType - wrapper.vectorized = self.vectorized + wrapper.pythonUdfType = self.pythonUdfType return wrapper -def _create_udf(f, returnType, vectorized): +def _create_udf(f, returnType, pythonUdfType): - def _udf(f, returnType=StringType(), vectorized=vectorized): - if vectorized: + def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType): + if pythonUdfType == PythonUdfType.PANDAS_UDF: import inspect argspec = inspect.getargspec(f) if len(argspec.args) == 0 and argspec.varargs is None: @@ -2137,17 +2146,24 @@ def _udf(f, returnType=StringType(), vectorized=vectorized): "0-arg pandas_udfs are not supported. 
" "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) + elif pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF: + import inspect + argspec = inspect.getargspec(f) + if len(argspec.args) != 1 and argspec.varargs is None: + raise ValueError("Only 1-arg pandas_grouped_udfs are supported.") + + udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=pythonUdfType) return udf_obj._wrapped() - # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf + # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf and @pandas_grouped_udf if f is None or isinstance(f, (str, DataType)): # If DataType has been passed as a positional argument # for decorator use it as a returnType return_type = f or returnType - return functools.partial(_udf, returnType=return_type, vectorized=vectorized) + return functools.partial( + _udf, returnType=return_type, pythonUdfType=pythonUdfType) else: - return _udf(f=f, returnType=returnType, vectorized=vectorized) + return _udf(f=f, returnType=returnType, pythonUdfType=pythonUdfType) @since(1.3) @@ -2181,7 +2197,7 @@ def udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - return _create_udf(f, returnType=returnType, vectorized=False) + return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.NORMAL_UDF) @since(2.3) @@ -2192,67 +2208,82 @@ def pandas_udf(f=None, returnType=StringType()): :param f: user-defined function. A python function if used as a standalone function :param returnType: a :class:`pyspark.sql.types.DataType` object - The user-defined function can define one of the following transformations: - - 1. One or more `pandas.Series` -> A `pandas.Series` - - This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and - :meth:`pyspark.sql.DataFrame.select`. - The returnType should be a primitive data type, e.g., `DoubleType()`. - The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. - - >>> from pyspark.sql.types import IntegerType, StringType - >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) - >>> @pandas_udf(returnType=StringType()) - ... def to_upper(s): - ... return s.str.upper() - ... - >>> @pandas_udf(returnType="integer") - ... def add_one(x): - ... return x + 1 - ... - >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) - >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ - ... .show() # doctest: +SKIP - +----------+--------------+------------+ - |slen(name)|to_upper(name)|add_one(age)| - +----------+--------------+------------+ - | 8| JOHN DOE| 22| - +----------+--------------+------------+ - - 2. A `pandas.DataFrame` -> A `pandas.DataFrame` - - This udf is only used with :meth:`pyspark.sql.GroupedData.apply`. - The returnType should be a :class:`StructType` describing the schema of the returned - `pandas.DataFrame`. - - >>> df = spark.createDataFrame( - ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) - >>> @pandas_udf(returnType=df.schema) - ... def normalize(pdf): - ... v = pdf.v - ... 
return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP - +---+-------------------+ - | id| v| - +---+-------------------+ - | 1|-0.7071067811865475| - | 1| 0.7071067811865475| - | 2|-0.8320502943378437| - | 2|-0.2773500981126146| - | 2| 1.1094003924504583| - +---+-------------------+ - - .. note:: This type of udf cannot be used with functions such as `withColumn` or `select` - because it defines a `DataFrame` transformation rather than a `Column` - transformation. - - .. seealso:: :meth:`pyspark.sql.GroupedData.apply` + The user-defined function can define the following transformation: + + One or more `pandas.Series` -> A `pandas.Series` + + This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and + :meth:`pyspark.sql.DataFrame.select`. + The returnType should be a primitive data type, e.g., `DoubleType()`. + The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + + >>> from pyspark.sql.types import IntegerType, StringType + >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) + >>> @pandas_udf(returnType=StringType()) + ... def to_upper(s): + ... return s.str.upper() + ... + >>> @pandas_udf(returnType="integer") + ... def add_one(x): + ... return x + 1 + ... + >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) + >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ + ... .show() # doctest: +SKIP + +----------+--------------+------------+ + |slen(name)|to_upper(name)|add_one(age)| + +----------+--------------+------------+ + | 8| JOHN DOE| 22| + +----------+--------------+------------+ + + .. note:: The user-defined function must be deterministic. + """ + return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF) + + +@since(2.3) +def pandas_grouped_udf(f=None, returnType=StructType()): + """ + Creates a grouped vectorized user defined function (UDF). + + :param f: user-defined function. A python function if used as a standalone function + :param returnType: a :class:`pyspark.sql.types.StructType` object + + The grouped user-defined function can define the following transformation: + + A `pandas.DataFrame` -> A `pandas.DataFrame` + + This udf is only used with :meth:`pyspark.sql.GroupedData.apply`. + The returnType should be a :class:`StructType` describing the schema of the returned + `pandas.DataFrame`. + + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_grouped_udf(returnType=df.schema) + ... def normalize(pdf): + ... v = pdf.v + ... return pdf.assign(v=(v - v.mean()) / v.std()) + >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.8320502943378437| + | 2|-0.2773500981126146| + | 2| 1.1094003924504583| + +---+-------------------+ + + .. note:: This type of udf cannot be used with functions such as `withColumn` or `select` + because it defines a `DataFrame` transformation rather than a `Column` + transformation. + + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` .. note:: The user-defined function must be deterministic. 
""" - return _create_udf(f, returnType=returnType, vectorized=True) + return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF) blacklist = ['map', 'since', 'ignore_unicode_prefix'] diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 817d0bc83bb7..71fbb2fd56a6 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -19,6 +19,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame +from pyspark.sql.functions import PythonUdfType from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -206,18 +207,18 @@ def apply(self, udf): to the user-function and the returned `pandas.DataFrame`s are combined as a :class:`DataFrame`. The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the - returnType of the pandas udf. + returnType of the pandas grouped udf. This function does not support partial aggregation, and requires shuffling all the data in the :class:`DataFrame`. - :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` + :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_grouped_udf` - >>> from pyspark.sql.functions import pandas_udf + >>> from pyspark.sql.functions import pandas_grouped_udf >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf(returnType=df.schema) + >>> @pandas_grouped_udf(returnType=df.schema) ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) @@ -232,16 +233,17 @@ def apply(self, udf): | 2| 1.1094003924504583| +---+-------------------+ - .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` + .. 
seealso:: :meth:`pyspark.sql.functions.pandas_grouped_udf` """ - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_grouped_udf # Columns are special because hasattr always return True - if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized: - raise ValueError("The argument to apply must be a pandas_udf") + if isinstance(udf, Column) or not hasattr(udf, 'func') \ + or udf.pythonUdfType != PythonUdfType.PANDAS_GROUPED_UDF: + raise ValueError("The argument to apply must be a pandas_grouped_udf") if not isinstance(udf.returnType, StructType): - raise ValueError("The returnType of the pandas_udf must be a StructType") + raise ValueError("The returnType of the pandas_grouped_udf must be a StructType") df = self._df func = udf.func @@ -268,7 +270,7 @@ def wrapped(*cols): return [(result[result.columns[i]], arrow_type) for i, arrow_type in enumerate(arrow_return_types)] - wrapped_udf_obj = pandas_udf(wrapped, returnType) + wrapped_udf_obj = pandas_grouped_udf(wrapped, returnType) udf_column = wrapped_udf_obj(*[df[col] for col in df.columns]) jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.sql_ctx) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index bac2ef84ae7a..fd3aaeb83468 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3383,6 +3383,15 @@ def test_vectorized_udf_varargs(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_unsupported_types(self): + from pyspark.sql.functions import pandas_udf, col + schema = StructType([StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) + f = pandas_udf(lambda x: x, DateType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.select(f(col('dt'))).collect() + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedPySparkTestCase): @@ -3410,10 +3419,10 @@ def data(self): .withColumn("v", explode(col('vs'))).drop('vs') def test_simple(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_grouped_udf df = self.data - foo_udf = pandas_udf( + foo_udf = pandas_grouped_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), StructType( [StructField('id', LongType()), @@ -3426,10 +3435,10 @@ def test_simple(self): self.assertFramesEqual(expected, result) def test_decorator(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_grouped_udf df = self.data - @pandas_udf(StructType( + @pandas_grouped_udf(StructType( [StructField('id', LongType()), StructField('v', IntegerType()), StructField('v1', DoubleType()), @@ -3442,10 +3451,10 @@ def foo(pdf): self.assertFramesEqual(expected, result) def test_coerce(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_grouped_udf df = self.data - foo = pandas_udf( + foo = pandas_grouped_udf( lambda pdf: pdf, StructType([StructField('id', LongType()), StructField('v', DoubleType())])) @@ -3455,10 +3464,10 @@ def test_coerce(self): self.assertFramesEqual(expected, result) def test_complex_groupby(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_grouped_udf, col df = self.data - @pandas_udf(StructType( + @pandas_grouped_udf(StructType( [StructField('id', 
LongType()), StructField('v', IntegerType()), StructField('norm', DoubleType())])) @@ -3474,10 +3483,10 @@ def normalize(pdf): self.assertFramesEqual(expected, result) def test_empty_groupby(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_grouped_udf, col df = self.data - @pandas_udf(StructType( + @pandas_grouped_udf(StructType( [StructField('id', LongType()), StructField('v', IntegerType()), StructField('norm', DoubleType())])) @@ -3492,11 +3501,23 @@ def normalize(pdf): expected = expected.assign(norm=expected.norm.astype('float64')) self.assertFramesEqual(expected, result) + def test_datatype_string(self): + from pyspark.sql.functions import pandas_grouped_udf + df = self.data + + foo_udf = pandas_grouped_udf( + lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + "id long, v int, v1 double, v2 long") + + result = df.groupby('id').apply(foo_udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + def test_wrong_return_type(self): - from pyspark.sql.functions import pandas_udf + from pyspark.sql.functions import pandas_grouped_udf df = self.data - foo = pandas_udf( + foo = pandas_grouped_udf( lambda pdf: pdf, StructType([StructField('id', LongType()), StructField('v', StringType())])) @@ -3504,21 +3525,60 @@ def test_wrong_return_type(self): with self.assertRaisesRegexp(Exception, 'Invalid.*type'): df.groupby('id').apply(foo).sort('id').toPandas() + def test_zero_or_more_than_1_parameters(self): + from pyspark.sql.functions import pandas_grouped_udf + error_str = 'Only 1-arg pandas_grouped_udfs are supported.' + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, error_str): + pandas_grouped_udf(lambda: 1, 'one long') + with self.assertRaisesRegexp(ValueError, error_str): + @pandas_grouped_udf + def zero_no_type(): + return 1 + with self.assertRaisesRegexp(ValueError, error_str): + @pandas_grouped_udf("one long") + def zero_with_type(): + return 1 + + with self.assertRaisesRegexp(ValueError, error_str): + pandas_grouped_udf(lambda pdf, x: pdf, 'one long') + with self.assertRaisesRegexp(ValueError, error_str): + @pandas_grouped_udf + def zero_no_type(pdf, x): + return pdf + with self.assertRaisesRegexp(ValueError, error_str): + @pandas_grouped_udf("one long") + def zero_with_type(pdf, x): + return pdf + def test_wrong_args(self): - from pyspark.sql.functions import udf, pandas_udf, sum + from pyspark.sql.functions import udf, pandas_udf, pandas_grouped_udf, sum df = self.data with QuietTest(self.sc): - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'pandas_grouped_udf'): df.groupby('id').apply(lambda x: x) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'pandas_grouped_udf'): df.groupby('id').apply(udf(lambda x: x, DoubleType())) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'pandas_grouped_udf'): df.groupby('id').apply(sum(df.v)) - with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + with self.assertRaisesRegexp(ValueError, 'pandas_grouped_udf'): df.groupby('id').apply(df.v + 1) + with self.assertRaisesRegexp(ValueError, 'pandas_grouped_udf'): + df.groupby('id').apply( + pandas_udf(lambda x: x, StructType([StructField("d", DoubleType())]))) with self.assertRaisesRegexp(ValueError, 'returnType'): - df.groupby('id').apply(pandas_udf(lambda x: x, 
DoubleType())) + df.groupby('id').apply(pandas_grouped_udf(lambda x: x, DoubleType())) + + def test_unsupported_types(self): + from pyspark.sql.functions import pandas_grouped_udf, col + schema = StructType( + [StructField("id", LongType(), True), StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema) + f = pandas_grouped_udf(lambda x: x, df.schema) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.groupby('id').apply(f).collect() if __name__ == "__main__": diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index eb6d48688dc0..5e100e0a9a95 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -32,7 +32,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer -from pyspark.sql.types import to_arrow_type, StructType +from pyspark.sql.types import to_arrow_type from pyspark import shuffle pickleSer = PickleSerializer() @@ -74,28 +74,19 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): - # If the return_type is a StructType, it indicates this is a groupby apply udf, - # and has already been wrapped under apply(), otherwise, it's a vectorized column udf. - # We can distinguish these two by return type because in groupby apply, we always specify - # returnType as a StructType, and in vectorized column udf, StructType is not supported. - # - # TODO: Look into refactoring use of StructType to be more flexible for future pandas_udfs - if isinstance(return_type, StructType): - return lambda *a: f(*a) - else: - arrow_return_type = to_arrow_type(return_type) + arrow_return_type = to_arrow_type(return_type) - def verify_result_length(*a): - result = f(*a) - if not hasattr(result, "__len__"): - raise TypeError("Return type of the user-defined functon should be " - "Pandas.Series, but is {}".format(type(result))) - if len(result) != len(a[0]): - raise RuntimeError("Result vector from pandas_udf was not the required length: " - "expected %d, got %d" % (len(a[0]), len(result))) - return result + def verify_result_length(*a): + result = f(*a) + if not hasattr(result, "__len__"): + raise TypeError("Return type of the user-defined function should be " + "Pandas.Series, but is {}".format(type(result))) + if len(result) != len(a[0]): + raise RuntimeError("Result vector from pandas_udf was not the required length: " + "expected %d, got %d" % (len(a[0]), len(result))) + return result - return lambda *a: (verify_result_length(*a), arrow_return_type) + return lambda *a: (verify_result_length(*a), arrow_return_type) def read_single_udf(pickleSer, infile, eval_type): @@ -111,6 +102,9 @@ def read_single_udf(pickleSer, infile, eval_type): # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_PANDAS_UDF: return arg_offsets, wrap_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: + # a groupby apply udf has already been wrapped under apply() + return arg_offsets, row_func else: return arg_offsets, wrap_udf(row_func, return_type) @@ -133,7 +127,8 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_UDF: + if eval_type == PythonEvalType.SQL_PANDAS_UDF \ + or eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: ser = ArrowStreamPandasSerializer() else:
ser = BatchedSerializer(PickleSerializer(), 100) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 8abab24bc9b4..254687ec0088 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expre * This is used by DataFrame.groupby().apply(). */ case class FlatMapGroupsInPandas( - groupingAttributes: Seq[Attribute], - functionExpr: Expression, - output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { + groupingAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + /** * This is needed because output attributes are considered `references` when * passed through the constructor. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index cd0ac1feffa5..22faccb6f42f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.python.PythonUDF +import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructField, StructType} @@ -437,7 +437,7 @@ class RelationalGroupedDataset protected[sql]( } /** - * Applies a vectorized python user-defined function to each group of data. + * Applies a grouped vectorized python user-defined function to each group of data. * The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`. * For each group, all elements in the group are passed as a `pandas.DataFrame` and the results * for all groups are combined into a new [[DataFrame]]. @@ -449,7 +449,8 @@ class RelationalGroupedDataset protected[sql]( * workers. 
*/ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { - require(expr.vectorized, "Must pass a vectorized python udf") + require(expr.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF, + "Must pass a grouped vectorized python udf") require(expr.dataType.isInstanceOf[StructType], "The returnType of the vectorized python udf must be a StructType") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index e3f952e221d5..d6825369f737 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -137,11 +137,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { + if (validUdfs.exists(_.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF)) { + throw new IllegalArgumentException("Can not use grouped vectorized UDFs") + } + val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() } - val evaluation = validUdfs.partition(_.vectorized) match { + val evaluation = validUdfs.partition(_.pythonUdfType == PythonUdfType.PANDAS_UDF) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index b996b5bb38ba..5ed88ada428c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -94,7 +94,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 84a6d9e5be59..9c07c7638de5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -29,7 +29,7 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression], - vectorized: Boolean) + pythonUdfType: Int) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index a30a80acf5c2..bd298abb1564 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -22,6 +22,15 @@ import 
org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.DataType +private[spark] object PythonUdfType { + // row-based UDFs + val NORMAL_UDF = 0 + // scalar vectorized UDFs + val PANDAS_UDF = 1 + // grouped vectorized UDFs + val PANDAS_GROUPED_UDF = 2 +} + /** * A user-defined Python function. This is used by the Python API. */ @@ -29,10 +38,10 @@ case class UserDefinedPythonFunction( name: String, func: PythonFunction, dataType: DataType, - vectorized: Boolean) { + pythonUdfType: Int) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e, vectorized) + PythonUDF(name, func, dataType, e, pythonUdfType) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 153e6e1f88c7..95b21fc9f16a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -109,4 +109,4 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( name = "dummyUDF", func = new DummyUDF, dataType = BooleanType, - vectorized = false) + pythonUdfType = PythonUdfType.NORMAL_UDF)
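For reference, a minimal usage sketch of the two vectorized UDF kinds after this change, assuming a running SparkSession bound to `spark` (as in the doctests above); the data and column names are illustrative only and not part of the patch:

    from pyspark.sql.functions import pandas_udf, pandas_grouped_udf
    from pyspark.sql.types import DoubleType

    df = spark.createDataFrame(
        [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

    # Scalar vectorized UDF (PythonUdfType.PANDAS_UDF):
    # pandas.Series -> pandas.Series, used with select()/withColumn(),
    # and the returnType must be a primitive data type.
    @pandas_udf(returnType=DoubleType())
    def plus_one(v):
        return v + 1

    df.select(plus_one(df.v).alias('v_plus_one')).show()

    # Grouped vectorized UDF (PythonUdfType.PANDAS_GROUPED_UDF):
    # pandas.DataFrame -> pandas.DataFrame, only valid inside groupby().apply(),
    # and the returnType must be a StructType matching the returned frame's schema.
    @pandas_grouped_udf(returnType=df.schema)
    def normalize(pdf):
        v = pdf.v
        return pdf.assign(v=(v - v.mean()) / v.std())

    df.groupby('id').apply(normalize).show()

    # Mixing the two kinds fails fast: groupby().apply() raises ValueError for
    # anything other than a pandas_grouped_udf, and ExtractPythonUDFs rejects a
    # grouped UDF used in a column expression ("Can not use grouped vectorized UDFs").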