diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index d2a10df7acbd3..277804ec41d98 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -48,6 +48,7 @@ private[spark] object PythonEvalType { val SQL_WINDOW_AGG_PANDAS_UDF = 203 val SQL_SCALAR_PANDAS_ITER_UDF = 204 val SQL_MAP_PANDAS_ITER_UDF = 205 + val SQL_COGROUPED_MAP_PANDAS_UDF = 206 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -58,6 +59,7 @@ private[spark] object PythonEvalType { case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF" case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF" case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF" + case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 16c226f02e633..be0244b7d13e2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -75,6 +75,7 @@ class PythonEvalType(object): SQL_WINDOW_AGG_PANDAS_UDF = 203 SQL_SCALAR_PANDAS_ITER_UDF = 204 SQL_MAP_PANDAS_ITER_UDF = 205 + SQL_COGROUPED_MAP_PANDAS_UDF = 206 def portable_hash(x): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 00f6081a3b14f..bceb92cb274ae 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -401,6 +401,32 @@ def __repr__(self): return "ArrowStreamPandasUDFSerializer" +class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer): + + def load_stream(self, stream): + """ + Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two + lists of pandas.Series. + """ + import pyarrow as pa + dataframes_in_group = None + + while dataframes_in_group is None or dataframes_in_group > 0: + dataframes_in_group = read_int(stream) + + if dataframes_in_group == 2: + batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)] + batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)] + yield ( + [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()], + [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()] + ) + + elif dataframes_in_group != 0: + raise ValueError( + 'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group)) + + class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/cogroup.py b/python/pyspark/sql/cogroup.py new file mode 100644 index 0000000000000..9b725e4bafe79 --- /dev/null +++ b/python/pyspark/sql/cogroup.py @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from pyspark import since +from pyspark.rdd import PythonEvalType +from pyspark.sql.column import Column +from pyspark.sql.dataframe import DataFrame + + +class CoGroupedData(object): + """ + A logical grouping of two :class:`GroupedData`, + created by :func:`GroupedData.cogroup`. + + .. note:: Experimental + + .. versionadded:: 3.0 + """ + + def __init__(self, gd1, gd2): + self._gd1 = gd1 + self._gd2 = gd2 + self.sql_ctx = gd1.sql_ctx + + @since(3.0) + def apply(self, udf): + """ + Applies a function to each cogroup using a pandas udf and returns the result + as a `DataFrame`. + + The user-defined function should take two `pandas.DataFrame` and return another + `pandas.DataFrame`. For each side of the cogroup, all columns are passed together + as a `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame` + are combined as a :class:`DataFrame`. + + The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the + returnType of the pandas udf. + + .. note:: This function requires a full shuffle. All the data of a cogroup will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. + + .. note:: Experimental + + :param udf: a cogrouped map user-defined function returned by + :func:`pyspark.sql.functions.pandas_udf`. + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> df1 = spark.createDataFrame( + ... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)], + ... ("time", "id", "v1")) + >>> df2 = spark.createDataFrame( + ... [(20000101, 1, "x"), (20000101, 2, "y")], + ... ("time", "id", "v2")) + >>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP) + ... def asof_join(l, r): + ... return pd.merge_asof(l, r, on="time", by="id") + >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() + +--------+---+---+---+ + | time| id| v1| v2| + +--------+---+---+---+ + |20000101| 1|1.0| x| + |20000102| 1|3.0| x| + |20000101| 2|2.0| y| + |20000102| 2|4.0| y| + +--------+---+---+---+ + + .. 
seealso:: :meth:`pyspark.sql.functions.pandas_udf` + + """ + # Columns are special because hasattr always return True + if isinstance(udf, Column) or not hasattr(udf, 'func') \ + or udf.evalType != PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type " + "COGROUPED_MAP.") + all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2) + udf_column = udf(*all_cols) + jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr()) + return DataFrame(jdf, self.sql_ctx) + + @staticmethod + def _extract_cols(gd): + df = gd._df + return [df[col] for col in df.columns] diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index c7ff2882ed95a..d96c264aa7398 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2814,6 +2814,8 @@ class PandasUDFType(object): GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF + COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF + GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF MAP_ITER = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF @@ -3320,7 +3322,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, - PythonEvalType.SQL_MAP_PANDAS_ITER_UDF]: + PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF]: raise ValueError("Invalid functionType: " "functionType must be one the values from PandasUDFType") diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ec90ba905ef66..fcad64142485e 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -22,6 +22,7 @@ from pyspark.sql.column import Column, _to_seq from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import * +from pyspark.sql.cogroup import CoGroupedData __all__ = ["GroupedData"] @@ -218,6 +219,15 @@ def pivot(self, pivot_col, values=None): jgd = self._jgd.pivot(pivot_col, values) return GroupedData(jgd, self._df) + @since(3.0) + def cogroup(self, other): + """ + Cogroups this group with another group so that we can run cogrouped operations. + + See :class:`CoGroupedData` for the operations that can be run. + """ + return CoGroupedData(self, other) + @since(2.3) def apply(self, udf): """ @@ -232,7 +242,7 @@ def apply(self, udf): The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the returnType of the pandas udf. - .. note:: This function requires a full shuffle. all the data of a group will be loaded + .. note:: This function requires a full shuffle. All the data of a group will be loaded into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory. diff --git a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py new file mode 100644 index 0000000000000..7f3f7fa3168a7 --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py @@ -0,0 +1,280 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import sys + +from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType +from pyspark.sql.types import DoubleType, StructType, StructField +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ + pandas_requirement_message, pyarrow_requirement_message +from pyspark.testing.utils import QuietTest + +if have_pandas: + import pandas as pd + from pandas.util.testing import assert_frame_equal, assert_series_equal + +if have_pyarrow: + import pyarrow as pa + + +""" +Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names +from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check +""" +if sys.version < '3': + _check_column_type = False +else: + _check_column_type = True + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message) +class CoGroupedMapPandasUDFTests(ReusedSQLTestCase): + + @property + def data1(self): + return self.spark.range(10).toDF('id') \ + .withColumn("ks", array([lit(i) for i in range(20, 30)])) \ + .withColumn("k", explode(col('ks')))\ + .withColumn("v", col('k') * 10)\ + .drop('ks') + + @property + def data2(self): + return self.spark.range(10).toDF('id') \ + .withColumn("ks", array([lit(i) for i in range(20, 30)])) \ + .withColumn("k", explode(col('ks'))) \ + .withColumn("v2", col('k') * 100) \ + .drop('ks') + + def test_simple(self): + self._test_merge(self.data1, self.data2) + + def test_left_group_empty(self): + left = self.data1.where(col("id") % 2 == 0) + self._test_merge(left, self.data2) + + def test_right_group_empty(self): + right = self.data2.where(col("id") % 2 == 0) + self._test_merge(self.data1, right) + + def test_different_schemas(self): + right = self.data2.withColumn('v3', lit('a')) + self._test_merge(self.data1, right, 'id long, k int, v int, v2 int, v3 string') + + def test_complex_group_by(self): + left = pd.DataFrame.from_dict({ + 'id': [1, 2, 3], + 'k': [5, 6, 7], + 'v': [9, 10, 11] + }) + + right = pd.DataFrame.from_dict({ + 'id': [11, 12, 13], + 'k': [5, 6, 7], + 'v2': [90, 100, 110] + }) + + left_gdf = self.spark\ + .createDataFrame(left)\ + .groupby(col('id') % 2 == 0) + + right_gdf = self.spark \ + .createDataFrame(right) \ + .groupby(col('id') % 2 == 0) + + @pandas_udf('k long, v long, v2 long', PandasUDFType.COGROUPED_MAP) + def merge_pandas(l, r): + return pd.merge(l[['k', 'v']], r[['k', 'v2']], on=['k']) + + result = left_gdf \ + .cogroup(right_gdf) \ + .apply(merge_pandas) \ + .sort(['k']) \ + .toPandas() + + expected = pd.DataFrame.from_dict({ + 'k': [5, 6, 7], + 'v': [9, 10, 11], + 'v2': [90, 100, 110] + }) + + assert_frame_equal(expected, result, check_column_type=_check_column_type) + + def test_empty_group_by(self): + left = self.data1 + right = self.data2 + + @pandas_udf('id long, k int, v int, v2 int', PandasUDFType.COGROUPED_MAP) + def merge_pandas(l, r): + return pd.merge(l, r, on=['id', 'k']) + + result = left.groupby().cogroup(right.groupby())\ + .apply(merge_pandas) \ + .sort(['id', 'k']) \ + .toPandas() + + 
left = left.toPandas() + right = right.toPandas() + + expected = pd \ + .merge(left, right, on=['id', 'k']) \ + .sort_values(by=['id', 'k']) + + assert_frame_equal(expected, result, check_column_type=_check_column_type) + + def test_mixed_scalar_udfs_followed_by_cogrouby_apply(self): + df = self.spark.range(0, 10).toDF('v1') + df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \ + .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1'])) + + result = df.groupby().cogroup(df.groupby())\ + .apply(pandas_udf(lambda x, y: pd.DataFrame([(x.sum().sum(), y.sum().sum())]), + 'sum1 int, sum2 int', + PandasUDFType.COGROUPED_MAP)).collect() + + self.assertEquals(result[0]['sum1'], 165) + self.assertEquals(result[0]['sum2'], 165) + + def test_with_key_left(self): + self._test_with_key(self.data1, self.data1, isLeft=True) + + def test_with_key_right(self): + self._test_with_key(self.data1, self.data1, isLeft=False) + + def test_with_key_left_group_empty(self): + left = self.data1.where(col("id") % 2 == 0) + self._test_with_key(left, self.data1, isLeft=True) + + def test_with_key_right_group_empty(self): + right = self.data1.where(col("id") % 2 == 0) + self._test_with_key(self.data1, right, isLeft=False) + + def test_with_key_complex(self): + + @pandas_udf('id long, k int, v int, key boolean', PandasUDFType.COGROUPED_MAP) + def left_assign_key(key, l, _): + return l.assign(key=key[0]) + + result = self.data1 \ + .groupby(col('id') % 2 == 0)\ + .cogroup(self.data2.groupby(col('id') % 2 == 0)) \ + .apply(left_assign_key) \ + .sort(['id', 'k']) \ + .toPandas() + + expected = self.data1.toPandas() + expected = expected.assign(key=expected.id % 2 == 0) + + assert_frame_equal(expected, result, check_column_type=_check_column_type) + + def test_wrong_return_type(self): + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*cogrouped map Pandas UDF.*MapType'): + pandas_udf( + lambda l, r: l, + 'id long, v map', + PandasUDFType.COGROUPED_MAP) + + def test_wrong_args(self): + # Test that we get a sensible exception invalid values passed to apply + left = self.data1 + right = self.data2 + with QuietTest(self.sc): + # Function rather than a udf + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): + left.groupby('id').cogroup(right.groupby('id')).apply(lambda l, r: l) + + # Udf missing return type + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): + left.groupby('id').cogroup(right.groupby('id'))\ + .apply(udf(lambda l, r: l, DoubleType())) + + # Pass in expression rather than udf + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): + left.groupby('id').cogroup(right.groupby('id')).apply(left.v + 1) + + # Zero arg function + with self.assertRaisesRegexp(ValueError, 'Invalid function'): + left.groupby('id').cogroup(right.groupby('id'))\ + .apply(pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) + + # Udf without PandasUDFType + with self.assertRaisesRegexp(ValueError, 'Invalid udf'): + left.groupby('id').cogroup(right.groupby('id'))\ + .apply(pandas_udf(lambda x, y: x, DoubleType())) + + # Udf with incorrect PandasUDFType + with self.assertRaisesRegexp(ValueError, 'Invalid udf.*COGROUPED_MAP'): + left.groupby('id').cogroup(right.groupby('id'))\ + .apply(pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) + + @staticmethod + def _test_with_key(left, right, isLeft): + + @pandas_udf('id long, k int, v int, key long', PandasUDFType.COGROUPED_MAP) + def right_assign_key(key, l, r): + return 
l.assign(key=key[0]) if isLeft else r.assign(key=key[0]) + + result = left \ + .groupby('id') \ + .cogroup(right.groupby('id')) \ + .apply(right_assign_key) \ + .toPandas() + + expected = left.toPandas() if isLeft else right.toPandas() + expected = expected.assign(key=expected.id) + + assert_frame_equal(expected, result, check_column_type=_check_column_type) + + @staticmethod + def _test_merge(left, right, output_schema='id long, k int, v int, v2 int'): + + @pandas_udf(output_schema, PandasUDFType.COGROUPED_MAP) + def merge_pandas(l, r): + return pd.merge(l, r, on=['id', 'k']) + + result = left \ + .groupby('id') \ + .cogroup(right.groupby('id')) \ + .apply(merge_pandas)\ + .sort(['id', 'k']) \ + .toPandas() + + left = left.toPandas() + right = right.toPandas() + + expected = pd \ + .merge(left, right, on=['id', 'k']) \ + .sort_values(by=['id', 'k']) + + assert_frame_equal(expected, result, check_column_type=_check_column_type) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_udf_cogrouped_map import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 188ec2634974a..c4d7c1ed205f1 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -42,6 +42,7 @@ def _create_udf(f, returnType, evalType): if evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF): @@ -65,6 +66,13 @@ def _create_udf(f, returnType, evalType): "Invalid function: pandas_udfs with function type GROUPED_MAP " "must take either one argument (data) or two arguments (key, data).") + if evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF \ + and len(argspec.args) not in (2, 3): + raise ValueError( + "Invalid function: pandas_udfs with function type COGROUPED_MAP " + "must take either two arguments (left, right) " + "or three arguments (key, left, right).") + # Set the name of the UserDefinedFunction object to be the name of function f udf_obj = UserDefinedFunction( f, returnType=returnType, name=None, evalType=evalType, deterministic=True) @@ -147,6 +155,17 @@ def returnType(self): else: raise TypeError("Invalid returnType for map iterator Pandas " "UDFs: returnType must be a StructType.") + elif self.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + if isinstance(self._returnType_placeholder, StructType): + try: + to_arrow_type(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with cogrouped map Pandas UDFs: " + "%s is not supported" % str(self._returnType_placeholder)) + else: + raise TypeError("Invalid returnType for cogrouped map Pandas " + "UDFs: returnType must be a StructType.") elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: try: # StructType is not yet allowed as a return type, explicitly check here to fail fast diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 7f38c27360ed9..086202de2c68b 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -39,7 +39,7 @@ from pyspark.rdd import PythonEvalType from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, 
PickleSerializer, \ - BatchedSerializer, ArrowStreamPandasUDFSerializer + BatchedSerializer, ArrowStreamPandasUDFSerializer, CogroupUDFSerializer from pyspark.sql.types import to_arrow_type, StructType from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle @@ -121,6 +121,33 @@ def verify_result_type(result): map(verify_result_type, f(*iterator))) +def wrap_cogrouped_map_pandas_udf(f, return_type, argspec): + + def wrapped(left_key_series, left_value_series, right_key_series, right_value_series): + import pandas as pd + + left_df = pd.concat(left_value_series, axis=1) + right_df = pd.concat(right_value_series, axis=1) + + if len(argspec.args) == 2: + result = f(left_df, right_df) + elif len(argspec.args) == 3: + key_series = left_key_series if not left_df.empty else right_key_series + key = tuple(s[0] for s in key_series) + result = f(key, left_df, right_df) + if not isinstance(result, pd.DataFrame): + raise TypeError("Return type of the user-defined function should be " + "pandas.DataFrame, but is {}".format(type(result))) + if not len(result.columns) == len(return_type): + raise RuntimeError( + "Number of columns of the returned pandas.DataFrame " + "doesn't match specified schema. " + "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) + return result + + return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), to_arrow_type(return_type))] + + def wrap_grouped_map_pandas_udf(f, return_type, argspec): def wrapped(key_series, value_series): @@ -244,6 +271,9 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = _get_argspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + argspec = _get_argspec(chained_func) # signature was lost when wrapping it + return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: @@ -258,6 +288,7 @@ def read_udfs(pickleSer, infile, eval_type): runner_conf = {} if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, @@ -280,13 +311,16 @@ def read_udfs(pickleSer, infile, eval_type): "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\ .lower() == "true" - # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of - # pandas Series. See SPARK-27240. - df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or - eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF or - eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF) - ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name, - df_for_struct) + if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name) + else: + # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of + # pandas Series. See SPARK-27240. 
+ df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or + eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF or + eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF) + ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name, + df_for_struct) else: ser = BatchedSerializer(PickleSerializer(), 100) @@ -343,6 +377,32 @@ def map_batch(batch): # profiling is not supported for UDF return func, None, ser, ser + def extract_key_value_indexes(grouped_arg_offsets): + """ + Helper function to extract the key and value indexes from arg_offsets for the grouped and + cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code. + + :param grouped_arg_offsets: List containing the key and value indexes of columns of the + DataFrames to be passed to the udf. It consists of n repeating groups where n is the + number of DataFrames. Each group has the following format: + group[0]: length of group + group[1]: length of key indexes + group[2.. group[1] +2]: key attributes + group[group[1] +3 group[0]]: value attributes + """ + parsed = [] + idx = 0 + while idx < len(grouped_arg_offsets): + offsets_len = grouped_arg_offsets[idx] + idx += 1 + offsets = grouped_arg_offsets[idx: idx + offsets_len] + split_index = offsets[0] + 1 + offset_keys = offsets[1: split_index] + offset_values = offsets[split_index:] + parsed.append([offset_keys, offset_values]) + idx += offsets_len + return parsed + udfs = {} call_udf = [] mapper_str = "" @@ -359,10 +419,24 @@ def map_batch(batch): arg_offsets, udf = read_single_udf( pickleSer, infile, eval_type, runner_conf, udf_index=0) udfs['f'] = udf - split_offset = arg_offsets[0] + 1 - arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] - arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]] - mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1)) + parsed_offsets = extract_key_value_indexes(arg_offsets) + keys = ["a[%d]" % (o,) for o in parsed_offsets[0][0]] + vals = ["a[%d]" % (o, ) for o in parsed_offsets[0][1]] + mapper_str = "lambda a: f([%s], [%s])" % (", ".join(keys), ", ".join(vals)) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + # We assume there is only one UDF here because cogrouped map doesn't + # support combining multiple UDFs. 
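+ # arg_offsets here is the concatenation of the left and right offset groups built by
+ # BasePandasGroupExec.resolveArgOffsets, so extract_key_value_indexes yields one
+ # [keys, values] pair per side of the cogroup.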
+ assert num_udfs == 1 + arg_offsets, udf = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0) + udfs['f'] = udf + parsed_offsets = extract_key_value_indexes(arg_offsets) + df1_keys = ["a[0][%d]" % (o, ) for o in parsed_offsets[0][0]] + df1_vals = ["a[0][%d]" % (o, ) for o in parsed_offsets[0][1]] + df2_keys = ["a[1][%d]" % (o, ) for o in parsed_offsets[1][0]] + df2_vals = ["a[1][%d]" % (o, ) for o in parsed_offsets[1][1]] + mapper_str = "lambda a: f([%s], [%s], [%s], [%s])" % ( + ", ".join(df1_keys), ", ".join(df1_vals), ", ".join(df2_keys), ", ".join(df2_vals)) else: # Create function like this: # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3])) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6f17256f8163e..fbc5c44e039b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1171,6 +1171,12 @@ class Analyzer( // To resolve duplicate expression IDs for Join and Intersect case j @ Join(left, right, _, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) + case f @ FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, _, _, left, right) => + val leftRes = leftAttributes + .map(x => resolveExpressionBottomUp(x, left).asInstanceOf[Attribute]) + val rightRes = rightAttributes + .map(x => resolveExpressionBottomUp(x, right).asInstanceOf[Attribute]) + f.copy(leftAttributes = leftRes, rightAttributes = rightRes) // intersect/except will be rewritten to join at the begininng of optimizer. Here we need to // deduplicate the right side plan, so that we won't produce an invalid self-join later. case i @ Intersect(left, right, _) if !i.duplicateResolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index dc2185194d84e..c4f741cd2cec8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF} /** - * FlatMap groups using an udf: pandas.Dataframe -> pandas.DataFrame. + * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame. * This is used by DataFrame.groupby().apply(). */ case class FlatMapGroupsInPandas( @@ -40,7 +40,7 @@ case class FlatMapGroupsInPandas( } /** - * Map partitions using an udf: iter(pandas.Dataframe) -> iter(pandas.DataFrame). + * Map partitions using a udf: iter(pandas.Dataframe) -> iter(pandas.DataFrame). * This is used by DataFrame.mapInPandas() */ case class MapInPandas( @@ -51,6 +51,21 @@ case class MapInPandas( override val producedAttributes = AttributeSet(output) } +/** + * Flatmap cogroups using a udf: pandas.Dataframe, pandas.Dataframe -> pandas.Dataframe + * This is used by DataFrame.groupby().cogroup().apply(). 
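+ * The leftAttributes and rightAttributes are resolved against the respective child plans by the
+ * Analyzer (see the FlatMapCoGroupsInPandas case added to Analyzer.scala in this patch).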
+ */ +case class FlatMapCoGroupsInPandas( + leftAttributes: Seq[Attribute], + rightAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + left: LogicalPlan, + right: LogicalPlan) extends BinaryNode { + + override val producedAttributes = AttributeSet(output) +} + trait BaseEvalPython extends UnaryNode { def udfs: Seq[PythonUDF] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index e85636d82a62c..f6d13be0e89be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType} */ @Stable class RelationalGroupedDataset protected[sql]( - df: DataFrame, - groupingExprs: Seq[Expression], + val df: DataFrame, + val groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) { private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { @@ -523,6 +523,48 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + /** + * Applies a vectorized python user-defined function to each cogrouped data. + * The user-defined function defines a transformation: + * `pandas.DataFrame`, `pandas.DataFrame` -> `pandas.DataFrame`. + * For each group in the cogrouped data, all elements in the group are passed as a + * `pandas.DataFrame` and the results for all cogroups are combined into a new [[DataFrame]]. + * + * This function uses Apache Arrow as serialization format between Java executors and Python + * workers. + */ + private[sql] def flatMapCoGroupsInPandas( + r: RelationalGroupedDataset, + expr: PythonUDF): DataFrame = { + require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + "Must pass a cogrouped map udf") + require(expr.dataType.isInstanceOf[StructType], + s"The returnType of the udf must be a ${StructType.simpleString}") + + val leftGroupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + + val rightGroupingNamedExpressions = r.groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + + val leftAttributes = leftGroupingNamedExpressions.map(_.toAttribute) + val rightAttributes = rightGroupingNamedExpressions.map(_.toAttribute) + + val leftChild = df.logicalPlan + val rightChild = r.df.logicalPlan + + val left = Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild) + val right = Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild) + + val output = expr.dataType.asInstanceOf[StructType].toAttributes + val plan = FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, expr, output, left, right) + Dataset.ofRows(df.sparkSession, plan) + } + override def toString: String = { val builder = new StringBuilder builder.append("RelationalGroupedDataset: [grouping expressions: [") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 85469bf2401d4..a2f45898d273f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -682,6 +682,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { f, p, b, is, ot, 
planLater(child)) :: Nil case logical.FlatMapGroupsInPandas(grouping, func, output, child) => execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil + case logical.FlatMapCoGroupsInPandas(leftGroup, rightGroup, func, output, left, right) => + execution.python.FlatMapCoGroupsInPandasExec( + leftGroup, rightGroup, func, output, planLater(left), planLater(right)) :: Nil case logical.MapInPandas(func, output, child) => execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil case logical.MapElements(f, _, _, objAttr, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5101f7e871af2..fcf68467460bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -19,12 +19,9 @@ package org.apache.spark.sql.execution.python import java.io._ import java.net._ -import java.util.concurrent.atomic.AtomicBoolean - -import scala.collection.JavaConverters._ import org.apache.arrow.vector.VectorSchemaRoot -import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} +import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark._ import org.apache.spark.api.python._ @@ -33,7 +30,6 @@ import org.apache.spark.sql.execution.arrow.ArrowWriter import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils /** @@ -46,7 +42,7 @@ class ArrowPythonRunner( schema: StructType, timeZoneId: String, conf: Map[String, String]) - extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( + extends BaseArrowPythonRunner[Iterator[InternalRow]]( funcs, evalType, argOffsets) { override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize @@ -119,72 +115,4 @@ class ArrowPythonRunner( } } - protected override def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[ColumnarBatch] = { - new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { - - private val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdin reader for $pythonExec", 0, Long.MaxValue) - - private var reader: ArrowStreamReader = _ - private var root: VectorSchemaRoot = _ - private var schema: StructType = _ - private var vectors: Array[ColumnVector] = _ - - context.addTaskCompletionListener[Unit] { _ => - if (reader != null) { - reader.close(false) - } - allocator.close() - } - - private var batchLoaded = true - - protected override def read(): ColumnarBatch = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - if (reader != null && batchLoaded) { - batchLoaded = reader.loadNextBatch() - if (batchLoaded) { - val batch = new ColumnarBatch(vectors) - batch.setNumRows(root.getRowCount) - batch - } else { - reader.close(false) - allocator.close() - // Reach end of stream. Call `read()` again to read control data. 
- read() - } - } else { - stream.readInt() match { - case SpecialLengths.START_ARROW_STREAM => - reader = new ArrowStreamReader(stream, allocator) - root = reader.getVectorSchemaRoot() - schema = ArrowUtils.fromArrowSchema(root.getSchema()) - vectors = root.getFieldVectors().asScala.map { vector => - new ArrowColumnVector(vector) - }.toArray[ColumnVector] - read() - case SpecialLengths.TIMING_DATA => - handleTimingData() - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - throw handlePythonException() - case SpecialLengths.END_OF_DATA_SECTION => - handleEndOfDataSection() - null - } - } - } catch handleException - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala new file mode 100644 index 0000000000000..0cee7d2f96c22 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamReader + +import org.apache.spark._ +import org.apache.spark.api.python._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} + +/** + * Common functionality for a udf runner that exchanges data with Python worker via Arrow stream. 
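+ * Subclasses provide the writer thread that streams input of type `T` to the worker; reading the
+ * resulting Arrow record batches back as ColumnarBatches is implemented here.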
+ */ +abstract class BaseArrowPythonRunner[T]( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]]) + extends BasePythonRunner[T, ColumnarBatch](funcs, evalType, argOffsets) { + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[ColumnarBatch] = { + + new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", 0, Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var schema: StructType = _ + private var vectors: Array[ColumnVector] = _ + + context.addTaskCompletionListener[Unit] { _ => + if (reader != null) { + reader.close(false) + } + allocator.close() + } + + private var batchLoaded = true + + protected override def read(): ColumnarBatch = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (reader != null && batchLoaded) { + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(vectors) + batch.setNumRows(root.getRowCount) + batch + } else { + reader.close(false) + allocator.close() + // Reach end of stream. Call `read()` again to read control data. + read() + } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + reader = new ArrowStreamReader(stream, allocator) + root = reader.getVectorSchemaRoot() + schema = ArrowUtils.fromArrowSchema(root.getSchema()) + vectors = root.getFieldVectors().asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BasePandasGroupExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BasePandasGroupExec.scala new file mode 100644 index 0000000000000..477c288ad1211 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BasePandasGroupExec.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, UnsafeProjection} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan} +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + +/** + * Base functionality for plans which execute grouped python udfs. + */ +abstract class BasePandasGroupExec( + func: Expression, + output: Seq[Attribute]) + extends SparkPlan { + + protected val sessionLocalTimeZone = conf.sessionLocalTimeZone + + protected val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + protected val pandasFunction = func.asInstanceOf[PythonUDF].func + + protected val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + + override def producedAttributes: AttributeSet = AttributeSet(output) + + /** + * passes the data to the python runner and coverts the resulting + * columnarbatch into internal rows. + */ + protected def executePython[T]( + data: Iterator[T], + runner: BasePythonRunner[T, ColumnarBatch]): Iterator[InternalRow] = { + + val context = TaskContext.get() + val columnarBatchIter = runner.compute(data, context.partitionId(), context) + val unsafeProj = UnsafeProjection.create(output, output) + + columnarBatchIter.flatMap { batch => + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = output.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + flattenedBatch.rowIterator.asScala + }.map(unsafeProj) + } + + /** + * groups according to grouping attributes and then projects into the deduplicated schema + */ + protected def groupAndProject( + input: Iterator[InternalRow], + groupingAttributes: Seq[Attribute], + inputSchema: Seq[Attribute], + dedupSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = { + val groupedIter = GroupedIterator(input, groupingAttributes, inputSchema) + val dedupProj = UnsafeProjection.create(dedupSchema, inputSchema) + groupedIter.map { + case (k, groupedRowIter) => (k, groupedRowIter.map(dedupProj)) + } + } + + /** + * Returns a the deduplicated attributes of the spark plan and the arg offsets of the + * keys and values. + * + * The deduplicated attributes are needed because the spark plan may contain an attribute + * twice; once in the key and once in the value. For any such attribute we need to + * deduplicate. + * + * The arg offsets are used to distinguish grouping grouping attributes and data attributes + * as following: + * + * argOffsets[0] is the length of the argOffsets array + * + * argOffsets[1] is the length of grouping attribute + * argOffsets[2 .. argOffsets[0]+2] is the arg offsets for grouping attributes + * + * argOffsets[argOffsets[0]+2 .. 
] is the arg offsets for data attributes + */ + protected def resolveArgOffsets( + child: SparkPlan, groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = { + + val dataAttributes = child.output.drop(groupingAttributes.length) + val groupingIndicesInData = groupingAttributes.map { attribute => + dataAttributes.indexWhere(attribute.semanticEquals) + } + + val groupingArgOffsets = new ArrayBuffer[Int] + val nonDupGroupingAttributes = new ArrayBuffer[Attribute] + val nonDupGroupingSize = groupingIndicesInData.count(_ == -1) + + groupingAttributes.zip(groupingIndicesInData).foreach { + case (attribute, index) => + if (index == -1) { + groupingArgOffsets += nonDupGroupingAttributes.length + nonDupGroupingAttributes += attribute + } else { + groupingArgOffsets += index + nonDupGroupingSize + } + } + + val dataArgOffsets = nonDupGroupingAttributes.length until + (nonDupGroupingAttributes.length + dataAttributes.length) + + val argOffsetsLength = groupingAttributes.length + dataArgOffsets.length + 1 + val argOffsets = Array(argOffsetsLength, + groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets + + // Attributes after deduplication + val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes + (dedupAttributes, argOffsets) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CogroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CogroupedArrowPythonRunner.scala new file mode 100644 index 0000000000000..8ea9881c575a1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CogroupedArrowPythonRunner.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter + +import org.apache.spark._ +import org.apache.spark.api.python._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.Utils + + +/** + * Python UDF Runner for cogrouped udfs. Although the data is exchanged with the python + * worker via arrow, we cannot use `ArrowPythonRunner` as we need to send more than one + * dataframe. 
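+ * For each cogroup, the number of DataFrames (2) is written first, followed by one Arrow stream
+ * per side; a trailing 0 marks the end of the data. CogroupUDFSerializer on the Python side
+ * mirrors this framing.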
+ */ +class CogroupedArrowPythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + leftSchema: StructType, + rightSchema: StructType, + timeZoneId: String, + conf: Map[String, String]) + extends BaseArrowPythonRunner[(Iterator[InternalRow], Iterator[InternalRow])]( + funcs, evalType, argOffsets) { + + protected def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[(Iterator[InternalRow], Iterator[InternalRow])], + partitionIndex: Int, + context: TaskContext): WriterThread = { + + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + + // Write config for the worker as a number of key -> value pairs of strings + dataOut.writeInt(conf.size) + for ((k, v) <- conf) { + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + // For each we first send the number of dataframes in each group then send + // first df, then send second df. End of data is marked by sending 0. + while (inputIterator.hasNext) { + dataOut.writeInt(2) + val (nextLeft, nextRight) = inputIterator.next() + writeGroup(nextLeft, leftSchema, dataOut, "left") + writeGroup(nextRight, rightSchema, dataOut, "right") + } + dataOut.writeInt(0) + } + + def writeGroup( + group: Iterator[InternalRow], + schema: StructType, + dataOut: DataOutputStream, + name: String) = { + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for $pythonExec ($name)", 0, Long.MaxValue) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + + Utils.tryWithSafeFinally { + val writer = new ArrowStreamWriter(root, null, dataOut) + val arrowWriter = ArrowWriter.create(root) + writer.start() + + while (group.hasNext) { + arrowWriter.write(group.next()) + } + arrowWriter.finish() + writer.writeBatch() + writer.end() + }{ + root.close() + allocator.close() + } + } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala new file mode 100644 index 0000000000000..cc83e0cecdc33 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan} +import org.apache.spark.sql.types.StructType + + +/** + * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapCoGroupsInPandas]] + * + * The input dataframes are first Cogrouped. Rows from each side of the cogroup are passed to the + * Python worker via Arrow. As each side of the cogroup may have a different schema we send every + * group in its own Arrow stream. + * The Python worker turns the resulting record batches to `pandas.DataFrame`s, invokes the + * user-defined function, and passes the resulting `pandas.DataFrame` + * as an Arrow record batch. Finally, each record batch is turned to + * Iterator[InternalRow] using ColumnarBatch. + * + * Note on memory usage: + * Both the Python worker and the Java executor need to have enough memory to + * hold the largest cogroup. The memory on the Java side is used to construct the + * record batches (off heap memory). The memory on the Python side is used for + * holding the `pandas.DataFrame`. It's possible to further split one group into + * multiple record batches to reduce the memory footprint on the Java side, this + * is left as future work. + */ +case class FlatMapCoGroupsInPandasExec( + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + func: Expression, + output: Seq[Attribute], + left: SparkPlan, + right: SparkPlan) + extends BasePandasGroupExec(func, output) with BinaryExecNode { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + val leftDist = if (leftGroup.isEmpty) AllTuples else ClusteredDistribution(leftGroup) + val rightDist = if (rightGroup.isEmpty) AllTuples else ClusteredDistribution(rightGroup) + leftDist :: rightDist :: Nil + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + leftGroup + .map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil + } + + override protected def doExecute(): RDD[InternalRow] = { + + val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup) + val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup) + + // Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty + left.execute().zipPartitions(right.execute()) { (leftData, rightData) => + if (leftData.isEmpty && rightData.isEmpty) Iterator.empty else { + + val leftGrouped = groupAndProject(leftData, leftGroup, left.output, leftDedup) + val rightGrouped = groupAndProject(rightData, rightGroup, right.output, rightDedup) + val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup) + .map { case (_, l, r) => (l, r) } + + val runner = new CogroupedArrowPythonRunner( + chainedFunc, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + Array(leftArgOffsets ++ rightArgOffsets), + StructType.fromAttributes(leftDedup), + StructType.fromAttributes(rightDedup), + sessionLocalTimeZone, + pythonRunnerConf) + + executePython(data, runner) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 267698d1bca50..22a0d1e09b12e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -17,19 +17,14 @@ package org.apache.spark.sql.execution.python -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.TaskContext -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} -import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + /** * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]] @@ -53,14 +48,10 @@ case class FlatMapGroupsInPandasExec( func: Expression, output: Seq[Attribute], child: SparkPlan) - extends UnaryExecNode { - - private val pandasFunction = func.asInstanceOf[PythonUDF].func + extends BasePandasGroupExec(func, output) with UnaryExecNode { override def outputPartitioning: Partitioning = child.outputPartitioning - override def producedAttributes: AttributeSet = AttributeSet(output) - override def requiredChildDistribution: Seq[Distribution] = { if (groupingAttributes.isEmpty) { AllTuples :: Nil @@ -75,88 +66,23 @@ case class FlatMapGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) - - // Deduplicate the grouping attributes. - // If a grouping attribute also appears in data attributes, then we don't need to send the - // grouping attribute to Python worker. If a grouping attribute is not in data attributes, - // then we need to send this grouping attribute to python worker. - // - // We use argOffsets to distinguish grouping attributes and data attributes as following: - // - // argOffsets[0] is the length of grouping attributes - // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes - // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes - - val dataAttributes = child.output.drop(groupingAttributes.length) - val groupingIndicesInData = groupingAttributes.map { attribute => - dataAttributes.indexWhere(attribute.semanticEquals) - } - - val groupingArgOffsets = new ArrayBuffer[Int] - val nonDupGroupingAttributes = new ArrayBuffer[Attribute] - val nonDupGroupingSize = groupingIndicesInData.count(_ == -1) - - // Non duplicate grouping attributes are added to nonDupGroupingAttributes and - // their offsets are 0, 1, 2 ... 
- // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and - // their offsets are n + index, where n is the total number of non duplicate grouping - // attributes and index is the index in the data attributes that the grouping attribute - // is a duplicate of. - - groupingAttributes.zip(groupingIndicesInData).foreach { - case (attribute, index) => - if (index == -1) { - groupingArgOffsets += nonDupGroupingAttributes.length - nonDupGroupingAttributes += attribute - } else { - groupingArgOffsets += index + nonDupGroupingSize - } - } - - val dataArgOffsets = nonDupGroupingAttributes.length until - (nonDupGroupingAttributes.length + dataAttributes.length) - - val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets) - - // Attributes after deduplication - val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes - val dedupSchema = StructType.fromAttributes(dedupAttributes) + val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes) // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { - val grouped = if (groupingAttributes.isEmpty) { - Iterator(iter) - } else { - val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - val dedupProj = UnsafeProjection.create(dedupAttributes, child.output) - groupedIter.map { - case (_, groupedRowIter) => groupedRowIter.map(dedupProj) - } - } - val context = TaskContext.get() + val data = groupAndProject(iter, groupingAttributes, child.output, dedupAttributes) + .map{case(_, x) => x} - val columnarBatchIter = new ArrowPythonRunner( + val runner = new ArrowPythonRunner( chainedFunc, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - argOffsets, - dedupSchema, + Array(argOffsets), + StructType.fromAttributes(dedupAttributes), sessionLocalTimeZone, - pythonRunnerConf).compute(grouped, context.partitionId(), context) - - val unsafeProj = UnsafeProjection.create(output, output) + pythonRunnerConf) - columnarBatchIter.flatMap { batch => - // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here - val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] - val outputVectors = output.indices.map(structVector.getChild) - val flattenedBatch = new ColumnarBatch(outputVectors.toArray) - flattenedBatch.setNumRows(batch.numRows()) - flattenedBatch.rowIterator.asScala - }.map(unsafeProj) + executePython(data, runner) }} } }
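For reference, a minimal usage sketch of the API added by this patch, exercising the three-argument form in which the grouping key is passed to the UDF along with both cogrouped pandas.DataFrames. The SparkSession `spark`, the column names, and the function name `merge_with_key` are illustrative assumptions, not part of the patch.

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

df1 = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v1"))
df2 = spark.createDataFrame([(1, "x"), (2, "y")], ("id", "v2"))

@pandas_udf("id long, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
def merge_with_key(key, left, right):
    # `key` is a tuple of the grouping values for this cogroup, e.g. (1,) or (2,)
    return pd.merge(left, right, on=["id"])

df1.groupby("id").cogroup(df2.groupby("id")).apply(merge_with_key).show()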