[SPARK-27463][PYTHON] Support Dataframe Cogroup via Pandas UDFs #24981
New file added by this PR (diff `@@ -0,0 +1,98 @@`):

```python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark import since
from pyspark.rdd import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame


class CoGroupedData(object):
    """
    A logical grouping of two :class:`GroupedData`,
    created by :func:`GroupedData.cogroup`.

    .. note:: Experimental

    .. versionadded:: 3.0
    """

    def __init__(self, gd1, gd2):
        self._gd1 = gd1
        self._gd2 = gd2
        self.sql_ctx = gd1.sql_ctx

    @since(3.0)
    def apply(self, udf):
        """
        Applies a function to each cogroup using a pandas udf and returns the result
        as a `DataFrame`.

        The user-defined function should take two `pandas.DataFrame`\\s and return
        another `pandas.DataFrame`. For each side of the cogroup, all columns are
        passed together as a `pandas.DataFrame` to the user-function and the returned
        `pandas.DataFrame`\\s are combined as a :class:`DataFrame`.

        The returned `pandas.DataFrame` can be of arbitrary length and its schema must
        match the returnType of the pandas udf.

        .. note:: This function requires a full shuffle. All the data of a cogroup will
            be loaded into memory, so the user should be aware of the potential OOM risk
            if data is skewed and certain groups are too large to fit in memory.

        .. note:: Experimental

        :param udf: a cogrouped map user-defined function returned by
            :func:`pyspark.sql.functions.pandas_udf`.

        >>> import pandas as pd
        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
        >>> df1 = spark.createDataFrame(
        ...     [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
        ...     ("time", "id", "v1"))
        >>> df2 = spark.createDataFrame(
        ...     [(20000101, 1, "x"), (20000101, 2, "y")],
        ...     ("time", "id", "v2"))
        >>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
        ... def asof_join(l, r):
        ...     return pd.merge_asof(l, r, on="time", by="id")
        >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
        +--------+---+---+---+
        |    time| id| v1| v2|
        +--------+---+---+---+
        |20000101|  1|1.0|  x|
        |20000102|  1|3.0|  x|
        |20000101|  2|2.0|  y|
        |20000102|  2|4.0|  y|
        +--------+---+---+---+

        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
        """
        # Columns are special because hasattr always returns True
        if isinstance(udf, Column) or not hasattr(udf, 'func') \
                or udf.evalType != PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
            raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
                             "COGROUPED_MAP.")
        all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
        udf_column = udf(*all_cols)
        jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr())
        return DataFrame(jdf, self.sql_ctx)

    @staticmethod
    def _extract_cols(gd):
        df = gd._df
        return [df[col] for col in df.columns]
```

Review comments left on the diff:

- On the doctest example: indentation nit.
- We should skip this test and run the doctests.
- We are currently skipping all doctests for Pandas UDFs right? We could add the module but then need to skip each test individually, which might be more consistent with the rest of PySpark.
- Can we document when arguments are three? (when it includes the grouping key)
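The cogrouped `asof_join` from the doctest can be reproduced with plain pandas, which may help clarify what the UDF receives and returns for each side of the cogroup. This is an illustrative sketch only, not Spark code; the data mirrors `df1` and `df2` above:

```python
import pandas as pd

# The two sides of the cogroup, mirroring df1 and df2 from the doctest.
df1 = pd.DataFrame({"time": [20000101, 20000101, 20000102, 20000102],
                    "id": [1, 2, 1, 2],
                    "v1": [1.0, 2.0, 3.0, 4.0]})
df2 = pd.DataFrame({"time": [20000101, 20000101],
                    "id": [1, 2],
                    "v2": ["x", "y"]})

def asof_join(l, r):
    # merge_asof requires both inputs to be sorted on the "on" key;
    # by="id" restricts matching to rows with the same id.
    return pd.merge_asof(l.sort_values("time"), r.sort_values("time"),
                         on="time", by="id")

result = asof_join(df1, df2)
print(result)
```

Each `20000102` row picks up the most recent `v2` value for its `id`, matching the `show()` output in the doctest (row order aside).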
> Seems like we don't generate documentation for this: cannot click. It should be either documented at `python/docs/pyspark.sql.rst` or imported at `pyspark/sql/__init__.py` with including it at `__all__`.

> +1 adding it to `pyspark/sql/__init__.py` with including it at `__all__`, since this is what `group.py` does.