
Commit 76791b8

d80tb7 authored and HyukjinKwon committed
[SPARK-27463][PYTHON][FOLLOW-UP] Miscellaneous documentation and code cleanup of cogroup pandas UDF
Follow-up from #24981, incorporating some comments from HyukjinKwon. Specifically:

- Added `CoGroupedData` to `pyspark/sql/__init__.py __all__` so that documentation is generated.
- Added pydoc, including an example, for the use case where the user supplies a cogrouping function that includes a key.
- Added the boilerplate for doctests to cogroup.py. Note that cogroup.py only contains the apply() function, which has doctests disabled, as with the other pandas UDFs.
- Restricted the newly exposed RelationalGroupedDataset constructor parameters to access only by the sql package.
- Some minor formatting tweaks.

This was tested by running the appropriate unit tests. I'm unsure how to check that my change will cause the documentation to be generated correctly, but if someone can describe how I can do this I'd be happy to check.

Closes #25939 from d80tb7/SPARK-27463-fixes.

Authored-by: Chris Martin <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
1 parent 39eb79a commit 76791b8
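
For context, a minimal end-to-end sketch of the cogrouped-map pandas UDF API that this commit documents. It mirrors the doctest added in cogroup.py below, but without the doctest skip markers; it assumes a Spark build containing this patch, with pandas and PyArrow installed:

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType

spark = SparkSession.builder.master("local[4]").appName("cogroup example").getOrCreate()

df1 = spark.createDataFrame(
    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
    ("time", "id", "v1"))
df2 = spark.createDataFrame(
    [(20000101, 1, "x"), (20000101, 2, "y")],
    ("time", "id", "v2"))

# The UDF receives one pandas.DataFrame per side of the cogroup and must
# return a pandas.DataFrame matching the declared schema.
@pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
def asof_join(l, r):
    return pd.merge_asof(l, r, on="time", by="id")

df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()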

File tree

5 files changed: +64 −28 lines

python/pyspark/sql/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -51,11 +51,12 @@
 from pyspark.sql.group import GroupedData
 from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
 from pyspark.sql.window import Window, WindowSpec
+from pyspark.sql.cogroup import CoGroupedData


 __all__ = [
     'SparkSession', 'SQLContext', 'UDFRegistration',
     'DataFrame', 'GroupedData', 'Column', 'Catalog', 'Row',
     'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
-    'DataFrameReader', 'DataFrameWriter'
+    'DataFrameReader', 'DataFrameWriter', 'CoGroupedData'
 ]
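
With the class imported at package level and listed in `__all__`, it is picked up by star-imports and by the API documentation generation. A quick sanity check, assuming a build containing this commit:

from pyspark.sql import CoGroupedData
help(CoGroupedData)  # the pydoc added in cogroup.py below should now render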

python/pyspark/sql/cogroup.py

Lines changed: 54 additions & 9 deletions
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import sys

 from pyspark import since
 from pyspark.rdd import PythonEvalType
@@ -43,9 +44,9 @@ def apply(self, udf):
         as a `DataFrame`.

         The user-defined function should take two `pandas.DataFrame` and return another
-        `pandas.DataFrame`. For each side of the cogroup, all columns are passed together
-        as a `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame`
-        are combined as a :class:`DataFrame`.
+        `pandas.DataFrame`. For each side of the cogroup, all columns are passed together as a
+        `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame` are combined as
+        a :class:`DataFrame`.

         The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
         returnType of the pandas udf.
@@ -61,15 +62,16 @@ def apply(self, udf):

         >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
         >>> df1 = spark.createDataFrame(
-        ...     [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
-        ...     ("time", "id", "v1"))
+        ...    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
+        ...    ("time", "id", "v1"))
         >>> df2 = spark.createDataFrame(
-        ...     [(20000101, 1, "x"), (20000101, 2, "y")],
-        ...     ("time", "id", "v2"))
-        >>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
+        ...    [(20000101, 1, "x"), (20000101, 2, "y")],
+        ...    ("time", "id", "v2"))
+        >>> @pandas_udf("time int, id int, v1 double, v2 string",
+        ...     PandasUDFType.COGROUPED_MAP)  # doctest: +SKIP
         ... def asof_join(l, r):
         ...     return pd.merge_asof(l, r, on="time", by="id")
-        >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
+        >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()  # doctest: +SKIP
         +--------+---+---+---+
         |    time| id| v1| v2|
         +--------+---+---+---+
@@ -79,6 +81,27 @@ def apply(self, udf):
         |20000102|  2|4.0|  y|
         +--------+---+---+---+

+        Alternatively, the user can define a function that takes three arguments. In this case,
+        the grouping key(s) will be passed as the first argument and the data will be passed as the
+        second and third arguments. The grouping key(s) will be passed as a tuple of numpy data
+        types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in as two
+        `pandas.DataFrame` containing all columns from the original Spark DataFrames.
+
+        >>> @pandas_udf("time int, id int, v1 double, v2 string",
+        ...     PandasUDFType.COGROUPED_MAP)  # doctest: +SKIP
+        ... def asof_join(k, l, r):
+        ...     if k == (1,):
+        ...         return pd.merge_asof(l, r, on="time", by="id")
+        ...     else:
+        ...         return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
+        >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()  # doctest: +SKIP
+        +--------+---+---+---+
+        |    time| id| v1| v2|
+        +--------+---+---+---+
+        |20000101|  1|1.0|  x|
+        |20000102|  1|3.0|  x|
+        +--------+---+---+---+
+
         .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`

         """
@@ -96,3 +119,25 @@ def apply(self, udf):
 def _extract_cols(gd):
     df = gd._df
     return [df[col] for col in df.columns]
+
+
+def _test():
+    import doctest
+    from pyspark.sql import SparkSession
+    import pyspark.sql.cogroup
+    globs = pyspark.sql.cogroup.__dict__.copy()
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("sql.cogroup tests")\
+        .getOrCreate()
+    globs['spark'] = spark
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.cogroup, globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
+    spark.stop()
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
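
Because the doctests above are tagged # doctest: +SKIP, the new key-aware form is easiest to try directly. A minimal sketch, reusing df1 and df2 from the earlier example (same assumptions: a build with this commit, pandas, and PyArrow):

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

# With a three-argument function the grouping key arrives first, as a tuple of
# numpy scalars -- (1,) for the cogroup where id == 1.
@pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
def keyed_asof_join(k, l, r):
    if k == (1,):
        return pd.merge_asof(l, r, on="time", by="id")
    return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])

df1.groupby("id").cogroup(df2.groupby("id")).apply(keyed_asof_join).show()
# Only the id == 1 group produces rows; every other group returns an empty frame.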

python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py

Lines changed: 3 additions & 8 deletions
@@ -32,14 +32,9 @@
 import pyarrow as pa


-"""
-Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
-from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
-"""
-if sys.version < '3':
-    _check_column_type = False
-else:
-    _check_column_type = True
+# Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
+# from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
+_check_column_type = sys.version >= '3'


 @unittest.skipIf(

python/pyspark/sql/tests/test_pandas_udf_grouped_map.py

Lines changed: 3 additions & 8 deletions
@@ -37,14 +37,9 @@
 import pyarrow as pa


-"""
-Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
-from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
-"""
-if sys.version < '3':
-    _check_column_type = False
-else:
-    _check_column_type = True
+# Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
+# from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
+_check_column_type = sys.version >= '3'


 @unittest.skipIf(
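
The one-liner above replaces the if/else with identical semantics; `_check_column_type` is then fed to pandas' frame comparison in these suites (presumably via assert_frame_equal). A sketch of that usage pattern, with hypothetical frames:

import sys
import pandas as pd
from pandas.testing import assert_frame_equal

_check_column_type = sys.version >= '3'

# assign() takes column names from **kwargs; under Python 2 those arrive as
# unicode while plain dict keys are str, so the column-type check is relaxed.
expected = pd.DataFrame({'id': [1, 2]}).assign(v1=[1.0, 2.0])
result = pd.DataFrame({'id': [1, 2], 'v1': [1.0, 2.0]})
assert_frame_equal(expected, result, check_column_type=_check_column_type)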

sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala

Lines changed: 2 additions & 2 deletions
@@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType}
  */
 @Stable
 class RelationalGroupedDataset protected[sql](
-    val df: DataFrame,
-    val groupingExprs: Seq[Expression],
+    private[sql] val df: DataFrame,
+    private[sql] val groupingExprs: Seq[Expression],
     groupType: RelationalGroupedDataset.GroupType) {

   private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
