Changes from all commits (41 commits)
2e0b308
initial commit of cogroup
d80tb7 Jun 20, 2019
64ff5ac
minor tidy up
d80tb7 Jun 20, 2019
6d039e3
removed incorrect test
d80tb7 Jun 21, 2019
d8a5c5d
tidies up test, fixed output cols
d80tb7 Jun 25, 2019
73188f6
removed incorrect file
d80tb7 Jun 25, 2019
690fa14
Revert: removed incorrect test
d80tb7 Jun 25, 2019
c86b2bf
Merge branch 'master' of https://github.com/d80tb7/spark into SPARK-2…
d80tb7 Jun 25, 2019
e3b66ac
fix for resolving key cols
d80tb7 Jun 25, 2019
8007fa6
common trait for grouped pandas udfs
d80tb7 Jun 27, 2019
d4cf6d0
poc using arrow streams
d80tb7 Jun 27, 2019
87aeb92
more unit tests for cogroup
d80tb7 Jun 27, 2019
e7528d0
argspec includes grouping key
d80tb7 Jul 2, 2019
b85ec75
fixed tests und
d80tb7 Jul 2, 2019
6a8ecff
keys now handled properly. Validation of udf. More tests
d80tb7 Jul 2, 2019
d2da787
formatting
d80tb7 Jul 2, 2019
7321141
fixed scalastyle errors
d80tb7 Jul 2, 2019
6bbe31c
updated grouped map to new args format
d80tb7 Jul 2, 2019
b444ff7
Merge branch 'master' of https://github.com/apache/spark into SPARK-2…
d80tb7 Jul 2, 2019
94be574
some code review fixes
d80tb7 Jul 11, 2019
9241639
Merge branch 'master' of https://github.com/apache/spark into SPARK-2…
d80tb7 Jul 11, 2019
3de551f
more code review fixes
d80tb7 Jul 11, 2019
300b53a
more code review fixes
d80tb7 Jul 11, 2019
7d161ba
fix comment on PandasCogroupSerializer
d80tb7 Jul 11, 2019
d1a6366
formatting
d80tb7 Jul 11, 2019
a201161
Merge branch 'master' of https://github.com/apache/spark into SPARK-2…
d80tb7 Jul 19, 2019
3e4bc95
python style fixes
d80tb7 Jul 19, 2019
307e664
added doc
d80tb7 Jul 19, 2019
7558b8d
Merge branch 'master' of https://github.com/apache/spark into SPARK-2…
d80tb7 Jul 23, 2019
19360c4
minor formatting
d80tb7 Jul 23, 2019
28493b4
a couple more unit tests
d80tb7 Jul 23, 2019
d6d11e4
minor formatting
d80tb7 Jul 23, 2019
a62a1e3
more doc
d80tb7 Jul 25, 2019
ec78284
added comment to cogroup func
d80tb7 Jul 25, 2019
1a9ff58
fixed python style
d80tb7 Jul 25, 2019
c0d2919
review comments
d80tb7 Aug 20, 2019
4cd5c70
review comments scala
d80tb7 Aug 20, 2019
e025375
Merge branch 'master' of https://github.com/apache/spark into SPARK-2…
d80tb7 Aug 20, 2019
dd1ffaf
python formatting
d80tb7 Aug 20, 2019
733b592
review comments (mainly formatting)
d80tb7 Sep 8, 2019
51dcbdc
Merge branch 'master' of https://github.com/apache/spark into SPARK-2…
d80tb7 Sep 8, 2019
1b966fd
couple more format changes
d80tb7 Sep 15, 2019
@@ -48,6 +48,7 @@ private[spark] object PythonEvalType {
val SQL_WINDOW_AGG_PANDAS_UDF = 203
val SQL_SCALAR_PANDAS_ITER_UDF = 204
val SQL_MAP_PANDAS_ITER_UDF = 205
val SQL_COGROUPED_MAP_PANDAS_UDF = 206

def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
@@ -58,6 +59,7 @@ private[spark] object PythonEvalType {
case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF"
case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
}
}

1 change: 1 addition & 0 deletions python/pyspark/rdd.py
@@ -75,6 +75,7 @@ class PythonEvalType(object):
SQL_WINDOW_AGG_PANDAS_UDF = 203
SQL_SCALAR_PANDAS_ITER_UDF = 204
SQL_MAP_PANDAS_ITER_UDF = 205
SQL_COGROUPED_MAP_PANDAS_UDF = 206


def portable_hash(x):
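The Python constant mirrors the Scala SQL_COGROUPED_MAP_PANDAS_UDF = 206 above; the JVM and the Python worker agree on the UDF type by this numeric value. A quick illustrative check (assumes a pyspark installation that includes this change):

    from pyspark.rdd import PythonEvalType

    # Must agree with the Scala PythonEvalType constant (206).
    assert PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF == 206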
26 changes: 26 additions & 0 deletions python/pyspark/serializers.py
@@ -401,6 +401,32 @@ def __repr__(self):
return "ArrowStreamPandasUDFSerializer"


class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):

def load_stream(self, stream):
"""
Deserialize cogrouped ArrowRecordBatches and yield each cogroup as a tuple of
two lists of pandas.Series.
"""
import pyarrow as pa
dataframes_in_group = None

while dataframes_in_group is None or dataframes_in_group > 0:
dataframes_in_group = read_int(stream)

if dataframes_in_group == 2:
batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
yield (
[self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()],
[self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()]
)

elif dataframes_in_group != 0:
raise ValueError(
'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group))


class BatchedSerializer(Serializer):

"""
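For reference, a minimal self-contained sketch (not Spark code) of the framing this serializer consumes: for each cogroup the JVM writes a big-endian int 2 followed by two Arrow IPC streams (left group, then right group), and a final 0 terminates the stream. The helper names write_group and read_groups are illustrative assumptions; only pandas and pyarrow are required:

    import io
    import struct

    import pandas as pd
    import pyarrow as pa

    def write_group(sink, left_df, right_df):
        # Frame one cogroup: a count, then one Arrow IPC stream per side.
        sink.write(struct.pack("!i", 2))
        for df in (left_df, right_df):
            batch = pa.RecordBatch.from_pandas(df)
            with pa.ipc.new_stream(sink, batch.schema) as writer:
                writer.write_batch(batch)

    def read_groups(source):
        # Mirror of CogroupUDFSerializer.load_stream: 2 means a group follows,
        # 0 terminates, anything else is a protocol error.
        while True:
            (n,) = struct.unpack("!i", source.read(4))
            if n == 0:
                return
            if n != 2:
                raise ValueError("Invalid number of pandas.DataFrames in group %d" % n)
            left = pa.ipc.open_stream(source).read_pandas()
            right = pa.ipc.open_stream(source).read_pandas()
            yield left, right

    buf = io.BytesIO()
    write_group(buf,
                pd.DataFrame({"id": [1], "v1": [1.0]}),
                pd.DataFrame({"id": [1], "v2": ["x"]}))
    buf.write(struct.pack("!i", 0))  # end-of-stream marker
    buf.seek(0)
    for left, right in read_groups(buf):
        print(left)
        print(right)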
98 changes: 98 additions & 0 deletions python/pyspark/sql/cogroup.py
@@ -0,0 +1,98 @@
#
Review comment (Member): Seems like we don't generate documentation for this; the entry cannot be clicked in the generated API docs (screenshot). It should be either documented at python/docs/pyspark.sql.rst or imported at pyspark/sql/__init__.py and included in __all__.

Review comment (Member): +1 for adding it to pyspark/sql/__init__.py and including it in __all__, since this is what group.py does.

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark import since
from pyspark.rdd import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame


class CoGroupedData(object):
"""
A logical grouping of two :class:`GroupedData`,
created by :func:`GroupedData.cogroup`.

.. note:: Experimental

.. versionadded:: 3.0
"""

def __init__(self, gd1, gd2):
self._gd1 = gd1
self._gd2 = gd2
self.sql_ctx = gd1.sql_ctx

@since(3.0)
def apply(self, udf):
"""
Applies a function to each cogroup using a pandas udf and returns the result
as a `DataFrame`.

The user-defined function should take two `pandas.DataFrame` objects and return another
`pandas.DataFrame`. For each side of the cogroup, all columns are passed together
as a `pandas.DataFrame` to the user-function, and the returned `pandas.DataFrame`
objects are combined as a :class:`DataFrame`.

The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
returnType of the pandas udf.

.. note:: This function requires a full shuffle. All the data of a cogroup will be loaded
into memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.

.. note:: Experimental

:param udf: a cogrouped map user-defined function returned by
:func:`pyspark.sql.functions.pandas_udf`.

>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df1 = spark.createDataFrame(
... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
... ("time", "id", "v1"))
>>> df2 = spark.createDataFrame(
... [(20000101, 1, "x"), (20000101, 2, "y")],
... ("time", "id", "v2"))
Review comment (Member): indentation nit
>>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
Review comment (Member): We should skip this test and run the doctests:

  1. add this to dev/sparktestsupport/modules.py at pyspark_sql

  2. add:

    def main():
        doctest.testmod(...)
        ...

Review comment (Member): We are currently skipping all doctests for Pandas UDFs, right? We could add the module, but then we'd need to skip each test individually, which might be more consistent with the rest of PySpark.

Review comment (Member): Can we document the case when the arguments are three (when the function also receives the grouping key)?

(A sketch of the three-argument form follows this file's diff.)

... def asof_join(l, r):
... return pd.merge_asof(l, r, on="time", by="id")
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
+--------+---+---+---+
| time| id| v1| v2|
+--------+---+---+---+
|20000101| 1|1.0| x|
|20000102| 1|3.0| x|
|20000101| 2|2.0| y|
|20000102| 2|4.0| y|
+--------+---+---+---+

.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`

"""
# Columns are special because hasattr always returns True
if isinstance(udf, Column) or not hasattr(udf, 'func') \
or udf.evalType != PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
"COGROUPED_MAP.")
all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
udf_column = udf(*all_cols)
jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr())
return DataFrame(jdf, self.sql_ctx)

@staticmethod
def _extract_cols(gd):
df = gd._df
return [df[col] for col in df.columns]
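Regarding the review comment above about three arguments: a minimal sketch of that form, in which the UDF also receives the grouping key as its first argument (per the "argspec includes grouping key" commit). Treating the key as a tuple of values is an assumption here, not something this diff confirms:

    import pandas as pd
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
    def asof_join(key, l, r):
        # key: the grouping key for this cogroup (assumed to be a tuple, e.g. (1,))
        return pd.merge_asof(l, r, on="time", by="id")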
5 changes: 4 additions & 1 deletion python/pyspark/sql/functions.py
@@ -2814,6 +2814,8 @@ class PandasUDFType(object):

GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF

COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF

GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF

MAP_ITER = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
@@ -3320,7 +3322,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF]:
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF]:
raise ValueError("Invalid functionType: "
"functionType must be one the values from PandasUDFType")

12 changes: 11 additions & 1 deletion python/pyspark/sql/group.py
@@ -22,6 +22,7 @@
from pyspark.sql.column import Column, _to_seq
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import *
from pyspark.sql.cogroup import CoGroupedData

__all__ = ["GroupedData"]

@@ -218,6 +219,15 @@ def pivot(self, pivot_col, values=None):
jgd = self._jgd.pivot(pivot_col, values)
return GroupedData(jgd, self._df)

@since(3.0)
def cogroup(self, other):
"""
Cogroups this group with another group so that we can run cogrouped operations.

See :class:`CoGroupedData` for the operations that can be run.
"""
return CoGroupedData(self, other)
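A minimal usage sketch (assumes an active SparkSession bound to `spark`); `cogroup` itself only pairs the two groupings, and the actual work happens in the subsequent `apply`:

    left = spark.createDataFrame([(1, 1.0), (2, 2.0)], ("id", "v1"))
    right = spark.createDataFrame([(1, "x")], ("id", "v2"))
    cogrouped = left.groupby("id").cogroup(right.groupby("id"))
    # cogrouped is a CoGroupedData; call .apply(<COGROUPED_MAP pandas_udf>) to compute.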

@since(2.3)
def apply(self, udf):
"""
Expand All @@ -232,7 +242,7 @@ def apply(self, udf):
The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
returnType of the pandas udf.

- .. note:: This function requires a full shuffle. all the data of a group will be loaded
+ .. note:: This function requires a full shuffle. All the data of a group will be loaded
into memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.
