47 changes: 46 additions & 1 deletion python/pyspark/sql/tests/test_pandas_cogrouped_map.py
@@ -17,8 +17,9 @@

import unittest

from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf
from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf, sum
from pyspark.sql.types import DoubleType, StructType, StructField, Row
from pyspark.sql.window import Window
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
@@ -215,6 +216,50 @@ def test_self_join(self):

self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())

def test_with_window_function(self):
# SPARK-42168: a window function with the same partition keys but in a different key order
ids = 2
days = 100
vals = 10000
parts = 10

id_df = self.spark.range(ids)
day_df = self.spark.range(days).withColumnRenamed("id", "day")
vals_df = self.spark.range(vals).withColumnRenamed("id", "value")
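# cross-joining ids, days and values yields `vals` rows for every (id, day) combination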
df = id_df.join(day_df).join(vals_df)

left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
# SPARK-42132: this bug requires us to alias all columns from df here
right_df = df.select(
col("id").alias("id"), col("day").alias("day"), col("value").alias("right")
).repartition(parts).cache()

# note the column order is different from the groupBy("id", "day") column order below
window = Window.partitionBy("day", "id")

left_grouped_df = left_df.groupBy("id", "day")
right_grouped_df = right_df \
.withColumn("day_sum", sum(col("day")).over(window)) \
.groupBy("id", "day")

def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
return pd.DataFrame([{
"id": left["id"][0] if not left.empty else (
right["id"][0] if not right.empty else None
),
"day": left["day"][0] if not left.empty else (
right["day"][0] if not right.empty else None
),
"lefts": len(left.index),
"rights": len(right.index)
}])

df = left_grouped_df.cogroup(right_grouped_df) \
.applyInPandas(cogroup, schema="id long, day long, lefts integer, rights integer")

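# ordered by (id, day), the first `days` rows cover every day of id 0; each (id, day)
# group should contribute all `vals` rows on both the left and the right side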
actual = df.orderBy("id", "day").take(days)
self.assertEqual(actual, [Row(0, day, vals, vals) for day in range(days)])

@staticmethod
def _test_with_key(left, right, isLeft):

sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -21,7 +21,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, HashClusteredDistribution, Partitioning}
import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
import org.apache.spark.sql.execution.python.PandasGroupUtils._
import org.apache.spark.sql.types.StructType
@@ -66,8 +66,8 @@ case class FlatMapCoGroupsInPandasExec(
override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution: Seq[Distribution] = {
val leftDist = if (leftGroup.isEmpty) AllTuples else ClusteredDistribution(leftGroup)
val rightDist = if (rightGroup.isEmpty) AllTuples else ClusteredDistribution(rightGroup)
val leftDist = if (leftGroup.isEmpty) AllTuples else HashClusteredDistribution(leftGroup)
val rightDist = if (rightGroup.isEmpty) AllTuples else HashClusteredDistribution(rightGroup)
leftDist :: rightDist :: Nil
}
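Why this change fixes SPARK-42168: ClusteredDistribution is satisfied by a hash partitioning over any subset of its keys in any order, so the (day, id) partitioning left behind by a window could be reused for one side of the cogroup while the other side was shuffled on (id, day). HashClusteredDistribution requires the exact key sequence, forcing both children onto the same partitioning. A minimal sketch, assuming Spark 3.2-era catalyst internals (the `DistributionSketch` object below is illustrative only and not part of the PR):

```scala
// Sketch, not part of the PR: requires spark-catalyst (Spark 3.2.x) on the classpath.
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashClusteredDistribution, HashPartitioning}
import org.apache.spark.sql.types.IntegerType

object DistributionSketch extends App {
  val key  = AttributeReference("key", IntegerType)()
  val key2 = AttributeReference("key2", IntegerType)()

  // Partitioning left behind by a window over (key2, key).
  val windowPartitioning = HashPartitioning(Seq(key2, key), numPartitions = 10)

  // ClusteredDistribution is order-insensitive, so no exchange would be inserted...
  println(windowPartitioning.satisfies(ClusteredDistribution(Seq(key, key2))))      // true
  // ...while HashClusteredDistribution demands the exact key sequence, forcing a shuffle.
  println(windowPartitioning.satisfies(HashClusteredDistribution(Seq(key, key2))))  // false
}
```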

sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -17,13 +17,18 @@

package org.apache.spark.sql.execution.exchange

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class EnsureRequirementsSuite extends SharedSparkSession {
private val exprA = Literal(1)
@@ -135,4 +140,55 @@ class EnsureRequirementsSuite extends SharedSparkSession {
}.size == 2)
}
}

test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing key order") {
val lKey = AttributeReference("key", IntegerType)()
val lKey2 = AttributeReference("key2", IntegerType)()

val rKey = AttributeReference("key", IntegerType)()
val rKey2 = AttributeReference("key2", IntegerType)()
val rValue = AttributeReference("value", IntegerType)()

val left = DummySparkPlan()
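// The right child is already partitioned by the window over (key2, key), i.e. the cogroup
// keys in reverse order; that partitioning must not satisfy the cogroup's (key, key2) requirement.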
val right = WindowExec(
Alias(
WindowExpression(
Sum(rValue).toAggregateExpression(),
WindowSpecDefinition(
Seq(rKey2, rKey),
Nil,
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
)
), "sum")() :: Nil,
Seq(rKey2, rKey),
Nil,
DummySparkPlan()
)

val pythonUdf = PythonUDF("pyUDF", null,
StructType(Seq(StructField("value", IntegerType))),
Seq.empty,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
true)

val flatMapCoGroup = FlatMapCoGroupsInPandasExec(
Seq(lKey, lKey2),
Seq(rKey, rKey2),
pythonUdf,
AttributeReference("value", IntegerType)() :: Nil,
left,
right
)

val result = EnsureRequirements.apply(flatMapCoGroup)
result match {
case FlatMapCoGroupsInPandasExec(leftKeys, rightKeys, _, _,
SortExec(leftOrder, false, _, _), SortExec(rightOrder, false, _, _)) =>
assert(leftKeys === Seq(lKey, lKey2))
assert(rightKeys === Seq(rKey, rKey2))
assert(leftKeys.map(k => SortOrder(k, Ascending)) === leftOrder)
assert(rightKeys.map(k => SortOrder(k, Ascending)) === rightOrder)
case other => fail(other.toString)
}
}
}