diff --git a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
index 94a12bfb3f656..485bb880437aa 100644
--- a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
@@ -17,8 +17,9 @@
 import unittest
 
-from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf
+from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf, sum
 from pyspark.sql.types import DoubleType, StructType, StructField, Row
+from pyspark.sql.window import Window
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
@@ -215,6 +216,50 @@ def test_self_join(self):
 
         self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
 
+    def test_with_window_function(self):
+        # SPARK-42168: a window function with the same partition keys, but in a different key order
+        ids = 2
+        days = 100
+        vals = 10000
+        parts = 10
+
+        id_df = self.spark.range(ids)
+        day_df = self.spark.range(days).withColumnRenamed("id", "day")
+        vals_df = self.spark.range(vals).withColumnRenamed("id", "value")
+        df = id_df.join(day_df).join(vals_df)
+
+        left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
+        # SPARK-42132: this bug requires us to alias all columns from df here
+        right_df = df.select(
+            col("id").alias("id"), col("day").alias("day"), col("value").alias("right")
+        ).repartition(parts).cache()
+
+        # note that the column order differs from the groupBy("id", "day") column order below
+        window = Window.partitionBy("day", "id")
+
+        left_grouped_df = left_df.groupBy("id", "day")
+        right_grouped_df = right_df \
+            .withColumn("day_sum", sum(col("day")).over(window)) \
+            .groupBy("id", "day")
+
+        def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
+            return pd.DataFrame([{
+                "id": left["id"][0] if not left.empty else (
+                    right["id"][0] if not right.empty else None
+                ),
+                "day": left["day"][0] if not left.empty else (
+                    right["day"][0] if not right.empty else None
+                ),
+                "lefts": len(left.index),
+                "rights": len(right.index)
+            }])
+
+        df = left_grouped_df.cogroup(right_grouped_df) \
+            .applyInPandas(cogroup, schema="id long, day long, lefts integer, rights integer")
+
+        actual = df.orderBy("id", "day").take(days)
+        self.assertEqual(actual, [Row(0, day, vals, vals) for day in range(days)])
+
     @staticmethod
     def _test_with_key(left, right, isLeft):
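
For context, the regression this test pins down can be reduced to a standalone PySpark script. The following is a minimal sketch, not part of the patch: the SparkSession setup, data sizes, and the count_sides helper are illustrative assumptions, and the real test above uses far more rows so the misplaced groups show up reliably.

import pandas as pd

from pyspark.sql import SparkSession
from pyspark.sql.functions import sum as sum_
from pyspark.sql.window import Window

spark = SparkSession.builder.master("local[4]").getOrCreate()

df = spark.createDataFrame(
    [(i, d, v) for i in range(2) for d in range(100) for v in range(100)],
    "id long, day long, value long",
)

left = df.repartition(10)
# The window partitions by ("day", "id"), the reverse of the cogroup keys
# below, so the right child arrives hash-partitioned on (day, id).
right = df.withColumn("day_sum", sum_("value").over(Window.partitionBy("day", "id")))

def count_sides(l: pd.DataFrame, r: pd.DataFrame) -> pd.DataFrame:
    # Both sides hold the same rows, so matching groups must have equal counts.
    src = l if not l.empty else r
    return pd.DataFrame([{
        "id": src["id"][0], "day": src["day"][0],
        "lefts": len(l.index), "rights": len(r.index),
    }])

result = (
    left.groupBy("id", "day")
    .cogroup(right.groupBy("id", "day"))
    .applyInPandas(count_sides, "id long, day long, lefts long, rights long")
)
# Without the fix, groups could report lefts != rights because the two
# children were hashed on differently ordered keys and did not co-locate.
result.filter("lefts <> rights").show()
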
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
index e830ea6b54662..e4503bdd9f4d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -21,7 +21,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, HashClusteredDistribution, Partitioning}
 import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
 import org.apache.spark.sql.execution.python.PandasGroupUtils._
 import org.apache.spark.sql.types.StructType
@@ -66,8 +66,8 @@ case class FlatMapCoGroupsInPandasExec(
   override def outputPartitioning: Partitioning = left.outputPartitioning
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    val leftDist = if (leftGroup.isEmpty) AllTuples else ClusteredDistribution(leftGroup)
-    val rightDist = if (rightGroup.isEmpty) AllTuples else ClusteredDistribution(rightGroup)
+    val leftDist = if (leftGroup.isEmpty) AllTuples else HashClusteredDistribution(leftGroup)
+    val rightDist = if (rightGroup.isEmpty) AllTuples else HashClusteredDistribution(rightGroup)
     leftDist :: rightDist :: Nil
   }
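
The one-line distribution change above is the heart of the fix. ClusteredDistribution(keys) is satisfied by any existing HashPartitioning whose expressions form a subset of the required keys, in any order, so the right child could reuse the window's partitioning on (day, id) while the left child was shuffled on (id, day); rows of one logical group then landed in different partitions on each side. HashClusteredDistribution(keys) is only satisfied by a HashPartitioning over exactly the same expressions in the same order, which forces both children of the cogroup onto identical partitionings. The effect is visible from PySpark in the physical plan; a hypothetical snippet reusing the names from test_with_window_function above (exact plan text varies by version):

# Sketch: left_grouped_df, right_grouped_df and cogroup as defined in the test.
plan_df = left_grouped_df.cogroup(right_grouped_df).applyInPandas(
    cogroup, schema="id long, day long, lefts integer, rights integer")
# Without the fix, the right child of FlatMapCoGroupsInPandas keeps the
# window's hashpartitioning(day, id); with it, both children are exchanged
# to a hash partitioning over the group keys in matching order.
plan_df.explain()
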
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 0425be6f9a79e..e61619fac6be0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -17,13 +17,18 @@
 
 package org.apache.spark.sql.execution.exchange
 
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
 import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec
+import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 
 class EnsureRequirementsSuite extends SharedSparkSession {
   private val exprA = Literal(1)
@@ -135,4 +140,55 @@ class EnsureRequirementsSuite extends SharedSparkSession {
       }.size == 2)
     }
   }
+
+  test("SPARK-42168: FlatMapCoGroupsInPandas and Window function with differing key order") {
+    val lKey = AttributeReference("key", IntegerType)()
+    val lKey2 = AttributeReference("key2", IntegerType)()
+
+    val rKey = AttributeReference("key", IntegerType)()
+    val rKey2 = AttributeReference("key2", IntegerType)()
+    val rValue = AttributeReference("value", IntegerType)()
+
+    val left = DummySparkPlan()
+    val right = WindowExec(
+      Alias(
+        WindowExpression(
+          Sum(rValue).toAggregateExpression(),
+          WindowSpecDefinition(
+            Seq(rKey2, rKey),
+            Nil,
+            SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
+          )
+        ), "sum")() :: Nil,
+      Seq(rKey2, rKey),
+      Nil,
+      DummySparkPlan()
+    )
+
+    val pythonUdf = PythonUDF("pyUDF", null,
+      StructType(Seq(StructField("value", IntegerType))),
+      Seq.empty,
+      PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+      true)
+
+    val flatMapCoGroup = FlatMapCoGroupsInPandasExec(
+      Seq(lKey, lKey2),
+      Seq(rKey, rKey2),
+      pythonUdf,
+      AttributeReference("value", IntegerType)() :: Nil,
+      left,
+      right
+    )
+
+    val result = EnsureRequirements.apply(flatMapCoGroup)
+    result match {
+      case FlatMapCoGroupsInPandasExec(leftKeys, rightKeys, _, _,
+          SortExec(leftOrder, false, _, _), SortExec(rightOrder, false, _, _)) =>
+        assert(leftKeys === Seq(lKey, lKey2))
+        assert(rightKeys === Seq(rKey, rKey2))
+        assert(leftKeys.map(k => SortOrder(k, Ascending)) === leftOrder)
+        assert(rightKeys.map(k => SortOrder(k, Ascending)) === rightOrder)
+      case other => fail(other.toString)
+    }
+  }
 }
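
Both new tests can be run in isolation with the usual Spark dev tooling, e.g. python/run-tests --testnames 'pyspark.sql.tests.test_pandas_cogrouped_map' for the PySpark suite and build/sbt "sql/testOnly *EnsureRequirementsSuite" for the planner suite (exact invocations depend on the checked-out branch's dev scripts).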