
Commit 2a37f22

EnricoMi authored and HyukjinKwon committed
[SPARK-42168][3.2][SQL][PYTHON] Fix required child distribution of FlatMapCoGroupsInPandas (as in CoGroup)
### What changes were proposed in this pull request?

Make `FlatMapCoGroupsInPandas` (used by PySpark) report its required child distribution as `HashClusteredDistribution`, rather than `ClusteredDistribution`. That is the same distribution as reported by `CoGroup` (used by Scala).

### Why are the changes needed?

This allows the `EnsureRequirements` rule to correctly recognize that `FlatMapCoGroupsInPandas` requiring `HashClusteredDistribution(id, day)` is not compatible with `HashPartitioning(day, id)`, while `ClusteredDistribution(id, day)` is compatible with `HashPartitioning(day, id)`.

The following example returns an incorrect result in Spark 3.0, 3.1, and 3.2.

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, lit, sum}

val ids = 1000
val days = 1000
val parts = 10

val id_df = spark.range(ids)
val day_df = spark.range(days).withColumnRenamed("id", "day")
val id_day_df = id_df.join(day_df)
// these redundant aliases are needed to work around bug SPARK-42132
val left_df = id_day_df.select($"id".as("id"), $"day".as("day"), lit("left").as("side")).repartition(parts).cache()
val right_df = id_day_df.select($"id".as("id"), $"day".as("day"), lit("right").as("side")).repartition(parts).cache()  //.withColumnRenamed("id", "id2")

// note the column order is different to the groupBy("id", "day") column order below
val window = Window.partitionBy("day", "id")

case class Key(id: BigInt, day: BigInt)
case class Value(id: BigInt, day: BigInt, side: String)
case class Sum(id: BigInt, day: BigInt, side: String, day_sum: BigInt)

val left_grouped_df = left_df.groupBy("id", "day").as[Key, Value]
val right_grouped_df = right_df.withColumn("day_sum", sum(col("day")).over(window)).groupBy("id", "day").as[Key, Sum]

val df = left_grouped_df.cogroup(right_grouped_df)((key: Key, left: Iterator[Value], right: Iterator[Sum]) => left)

df.explain()
df.show(5)
```

Output was

```
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- FlatMapCoGroupsInPandas [id#8L, day#9L], [id#29L, day#30L], cogroup(id#8L, day#9L, side#10, id#29L, day#30L, side#31, day_sum#54L), [id#64L, day#65L, lefts#66, rights#67]
   :- Sort [id#8L ASC NULLS FIRST, day#9L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#8L, day#9L, 200), ENSURE_REQUIREMENTS, [plan_id=117]
   :     +- ...
   +- Sort [id#29L ASC NULLS FIRST, day#30L ASC NULLS FIRST], false, 0
      +- Project [id#29L, day#30L, id#29L, day#30L, side#31, day_sum#54L]
         +- Window [sum(day#30L) windowspecdefinition(day#30L, id#29L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS day_sum#54L], [day#30L, id#29L]
            +- Sort [day#30L ASC NULLS FIRST, id#29L ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(day#30L, id#29L, 200), ENSURE_REQUIREMENTS, [plan_id=112]
                  +- ...

+---+---+-----+------+
| id|day|lefts|rights|
+---+---+-----+------+
|  0|  3|    0|     1|
|  0|  4|    0|     1|
|  0| 13|    1|     0|
|  0| 27|    0|     1|
|  0| 31|    0|     1|
+---+---+-----+------+
only showing top 5 rows
```

Output now is

```
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- FlatMapCoGroupsInPandas [id#8L, day#9L], [id#29L, day#30L], cogroup(id#8L, day#9L, side#10, id#29L, day#30L, side#31, day_sum#54L), [id#64L, day#65L, lefts#66, rights#67]
   :- Sort [id#8L ASC NULLS FIRST, day#9L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#8L, day#9L, 200), ENSURE_REQUIREMENTS, [plan_id=117]
   :     +- ...
   +- Sort [id#29L ASC NULLS FIRST, day#30L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(id#29L, day#30L, 200), ENSURE_REQUIREMENTS, [plan_id=118]
         +- Project [id#29L, day#30L, id#29L, day#30L, side#31, day_sum#54L]
            +- Window [sum(day#30L) windowspecdefinition(day#30L, id#29L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS day_sum#54L], [day#30L, id#29L]
               +- Sort [day#30L ASC NULLS FIRST, id#29L ASC NULLS FIRST], false, 0
                  +- Exchange hashpartitioning(day#30L, id#29L, 200), ENSURE_REQUIREMENTS, [plan_id=112]
                     +- ...

+---+---+-----+------+
| id|day|lefts|rights|
+---+---+-----+------+
|  0| 13|    1|     1|
|  0| 63|    1|     1|
|  0| 89|    1|     1|
|  0| 95|    1|     1|
|  0| 96|    1|     1|
+---+---+-----+------+
only showing top 5 rows
```

Spark 3.3 has [reworked](https://github.com/apache/spark/pull/32875/files#diff-e938569a4ca4eba8f7e10fe473d4f9c306ea253df151405bcaba880a601f075fR75-R76) `HashClusteredDistribution` and is therefore not sensitive to this issue when using `ClusteredDistribution`: #32875

### Does this PR introduce _any_ user-facing change?

This fixes correctness.

### How was this patch tested?

A unit test in `EnsureRequirementsSuite`.

Closes #39717 from EnricoMi/branch-3.2-cogroup-window-bug.

Authored-by: Enrico Minack <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
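To make the incompatibility concrete, the following is a minimal sketch (not part of this commit) written against Spark 3.2's internal catalyst API, where `HashClusteredDistribution` still exists. It shows how each distribution judges the `HashPartitioning(day, id)` that the window leaves behind on the right side:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashClusteredDistribution, HashPartitioning}
import org.apache.spark.sql.types.LongType

val id  = AttributeReference("id", LongType)()
val day = AttributeReference("day", LongType)()

// The window on the right side has already produced hashpartitioning(day, id, 200).
val windowPartitioning = HashPartitioning(Seq(day, id), 200)

// ClusteredDistribution ignores key order: every partitioning expression only has to
// appear somewhere in the clustering keys, so EnsureRequirements adds no Exchange even
// though the left side is hash-partitioned by (id, day) and matching keys therefore
// land in different partitions.
windowPartitioning.satisfies(ClusteredDistribution(Seq(id, day)))      // true  (old requirement)

// HashClusteredDistribution requires the same expressions in the same order, so
// EnsureRequirements now inserts Exchange hashpartitioning(id, day) on the right side.
windowPartitioning.satisfies(HashClusteredDistribution(Seq(id, day)))  // false (new requirement)
```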
1 parent fed407a, commit 2a37f22

File tree

3 files changed (+106, -5 lines)


python/pyspark/sql/tests/test_pandas_cogrouped_map.py

Lines changed: 46 additions & 1 deletion

```diff
@@ -17,8 +17,9 @@
 
 import unittest
 
-from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf
+from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf, sum
 from pyspark.sql.types import DoubleType, StructType, StructField, Row
+from pyspark.sql.window import Window
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
@@ -215,6 +216,50 @@ def test_self_join(self):
 
         self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
 
+    def test_with_window_function(self):
+        # SPARK-42168: a window function with same partition keys but differing key order
+        ids = 2
+        days = 100
+        vals = 10000
+        parts = 10
+
+        id_df = self.spark.range(ids)
+        day_df = self.spark.range(days).withColumnRenamed("id", "day")
+        vals_df = self.spark.range(vals).withColumnRenamed("id", "value")
+        df = id_df.join(day_df).join(vals_df)
+
+        left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
+        # SPARK-42132: this bug requires us to alias all columns from df here
+        right_df = df.select(
+            col("id").alias("id"), col("day").alias("day"), col("value").alias("right")
+        ).repartition(parts).cache()
+
+        # note the column order is different to the groupBy("id", "day") column order below
+        window = Window.partitionBy("day", "id")
+
+        left_grouped_df = left_df.groupBy("id", "day")
+        right_grouped_df = right_df \
+            .withColumn("day_sum", sum(col("day")).over(window)) \
+            .groupBy("id", "day")
+
+        def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
+            return pd.DataFrame([{
+                "id": left["id"][0] if not left.empty else (
+                    right["id"][0] if not right.empty else None
+                ),
+                "day": left["day"][0] if not left.empty else (
+                    right["day"][0] if not right.empty else None
+                ),
+                "lefts": len(left.index),
+                "rights": len(right.index)
+            }])
+
+        df = left_grouped_df.cogroup(right_grouped_df) \
+            .applyInPandas(cogroup, schema="id long, day long, lefts integer, rights integer")
+
+        actual = df.orderBy("id", "day").take(days)
+        self.assertEqual(actual, [Row(0, day, vals, vals) for day in range(days)])
+
     @staticmethod
     def _test_with_key(left, right, isLeft):
```

sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala

Lines changed: 3 additions & 3 deletions

```diff
@@ -21,7 +21,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, HashClusteredDistribution, Partitioning}
 import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
 import org.apache.spark.sql.execution.python.PandasGroupUtils._
 import org.apache.spark.sql.types.StructType
@@ -66,8 +66,8 @@ case class FlatMapCoGroupsInPandasExec(
   override def outputPartitioning: Partitioning = left.outputPartitioning
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    val leftDist = if (leftGroup.isEmpty) AllTuples else ClusteredDistribution(leftGroup)
-    val rightDist = if (rightGroup.isEmpty) AllTuples else ClusteredDistribution(rightGroup)
+    val leftDist = if (leftGroup.isEmpty) AllTuples else HashClusteredDistribution(leftGroup)
+    val rightDist = if (rightGroup.isEmpty) AllTuples else HashClusteredDistribution(rightGroup)
     leftDist :: rightDist :: Nil
   }
```

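For comparison, the Scala Dataset cogroup (`CoGroupExec`) already declares the order-sensitive requirement in branch-3.2, which is what the title's "as in CoGroup" refers to. The snippet below is only a rough sketch of that requirement reproduced from memory, not part of this commit's diff, and the standalone helper name `requiredCoGroupDistribution` is invented for illustration:

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution}

// Roughly what CoGroupExec.requiredChildDistribution yields for its two children in
// branch-3.2 (a sketch; the real member is a parameterless override on the exec node).
def requiredCoGroupDistribution(
    leftGroup: Seq[Attribute],
    rightGroup: Seq[Attribute]): Seq[Distribution] =
  HashClusteredDistribution(leftGroup) :: HashClusteredDistribution(rightGroup) :: Nil
```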
sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala

Lines changed: 57 additions & 1 deletion

```diff
@@ -17,13 +17,18 @@
 
 package org.apache.spark.sql.execution.exchange
 
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
 import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec
+import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 
 class EnsureRequirementsSuite extends SharedSparkSession {
   private val exprA = Literal(1)
@@ -135,4 +140,55 @@ class EnsureRequirementsSuite extends SharedSparkSession {
       }.size == 2)
     }
   }
+
+  test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing key order") {
+    val lKey = AttributeReference("key", IntegerType)()
+    val lKey2 = AttributeReference("key2", IntegerType)()
+
+    val rKey = AttributeReference("key", IntegerType)()
+    val rKey2 = AttributeReference("key2", IntegerType)()
+    val rValue = AttributeReference("value", IntegerType)()
+
+    val left = DummySparkPlan()
+    val right = WindowExec(
+      Alias(
+        WindowExpression(
+          Sum(rValue).toAggregateExpression(),
+          WindowSpecDefinition(
+            Seq(rKey2, rKey),
+            Nil,
+            SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
+          )
+        ), "sum")() :: Nil,
+      Seq(rKey2, rKey),
+      Nil,
+      DummySparkPlan()
+    )
+
+    val pythonUdf = PythonUDF("pyUDF", null,
+      StructType(Seq(StructField("value", IntegerType))),
+      Seq.empty,
+      PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+      true)
+
+    val flapMapCoGroup = FlatMapCoGroupsInPandasExec(
+      Seq(lKey, lKey2),
+      Seq(rKey, rKey2),
+      pythonUdf,
+      AttributeReference("value", IntegerType)() :: Nil,
+      left,
+      right
+    )
+
+    val result = EnsureRequirements.apply(flapMapCoGroup)
+    result match {
+      case FlatMapCoGroupsInPandasExec(leftKeys, rightKeys, _, _,
+          SortExec(leftOrder, false, _, _), SortExec(rightOrder, false, _, _)) =>
+        assert(leftKeys === Seq(lKey, lKey2))
+        assert(rightKeys === Seq(rKey, rKey2))
+        assert(leftKeys.map(k => SortOrder(k, Ascending)) === leftOrder)
+        assert(rightKeys.map(k => SortOrder(k, Ascending)) === rightOrder)
+      case other => fail(other.toString)
+    }
+  }
 }
```
