Skip to content

Commit a91ab70

Browse files
Davies Liudavies
authored andcommitted
[SPARK-17474] [SQL] fix python udf in TakeOrderedAndProjectExec
## What changes were proposed in this pull request? When there is any Python UDF in the Project between Sort and Limit, it will be collected into TakeOrderedAndProjectExec, ExtractPythonUDFs failed to pull the Python UDFs out because QueryPlan.expressions does not include the expression inside Option[Seq[Expression]]. Ideally, we should fix the `QueryPlan.expressions`, but tried with no luck (it always run into infinite loop). In PR, I changed the TakeOrderedAndProjectExec to no use Option[Seq[Expression]] to workaround it. cc JoshRosen ## How was this patch tested? Added regression test. Author: Davies Liu <[email protected]> Closes #15030 from davies/all_expr.
1 parent f9c580f commit a91ab70

File tree

4 files changed

+20
-12
lines changed

4 files changed

+20
-12
lines changed

python/pyspark/sql/tests.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,14 @@ def test_udf_in_generate(self):
376376
row = df.select(explode(f(*df))).groupBy().sum().first()
377377
self.assertEqual(row[0], 10)
378378

379+
def test_udf_with_order_by_and_limit(self):
380+
from pyspark.sql.functions import udf
381+
my_copy = udf(lambda x: x, IntegerType())
382+
df = self.spark.range(10).orderBy("id")
383+
res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
384+
res.explain(True)
385+
self.assertEqual(res.collect(), [Row(id=0, copy=0)])
386+
379387
def test_basic_functions(self):
380388
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
381389
df = self.spark.read.json(rdd)

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,22 +66,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
6666
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
6767
case logical.ReturnAnswer(rootPlan) => rootPlan match {
6868
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
69-
execution.TakeOrderedAndProjectExec(limit, order, None, planLater(child)) :: Nil
69+
execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
7070
case logical.Limit(
7171
IntegerLiteral(limit),
7272
logical.Project(projectList, logical.Sort(order, true, child))) =>
7373
execution.TakeOrderedAndProjectExec(
74-
limit, order, Some(projectList), planLater(child)) :: Nil
74+
limit, order, projectList, planLater(child)) :: Nil
7575
case logical.Limit(IntegerLiteral(limit), child) =>
7676
execution.CollectLimitExec(limit, planLater(child)) :: Nil
7777
case other => planLater(other) :: Nil
7878
}
7979
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
80-
execution.TakeOrderedAndProjectExec(limit, order, None, planLater(child)) :: Nil
80+
execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
8181
case logical.Limit(
8282
IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) =>
8383
execution.TakeOrderedAndProjectExec(
84-
limit, order, Some(projectList), planLater(child)) :: Nil
84+
limit, order, projectList, planLater(child)) :: Nil
8585
case _ => Nil
8686
}
8787
}

sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,20 +114,20 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
114114
case class TakeOrderedAndProjectExec(
115115
limit: Int,
116116
sortOrder: Seq[SortOrder],
117-
projectList: Option[Seq[NamedExpression]],
117+
projectList: Seq[NamedExpression],
118118
child: SparkPlan) extends UnaryExecNode {
119119

120120
override def output: Seq[Attribute] = {
121-
projectList.map(_.map(_.toAttribute)).getOrElse(child.output)
121+
projectList.map(_.toAttribute)
122122
}
123123

124124
override def outputPartitioning: Partitioning = SinglePartition
125125

126126
override def executeCollect(): Array[InternalRow] = {
127127
val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
128128
val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
129-
if (projectList.isDefined) {
130-
val proj = UnsafeProjection.create(projectList.get, child.output)
129+
if (projectList != child.output) {
130+
val proj = UnsafeProjection.create(projectList, child.output)
131131
data.map(r => proj(r).copy())
132132
} else {
133133
data
@@ -148,8 +148,8 @@ case class TakeOrderedAndProjectExec(
148148
localTopK, child.output, SinglePartition, serializer))
149149
shuffled.mapPartitions { iter =>
150150
val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
151-
if (projectList.isDefined) {
152-
val proj = UnsafeProjection.create(projectList.get, child.output)
151+
if (projectList != child.output) {
152+
val proj = UnsafeProjection.create(projectList, child.output)
153153
topK.map(r => proj(r))
154154
} else {
155155
topK

sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
5959
checkThatPlansAgree(
6060
generateRandomInputData(),
6161
input =>
62-
noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, None, input)),
62+
noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
6363
input =>
6464
GlobalLimitExec(limit,
6565
LocalLimitExec(limit,
@@ -74,7 +74,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
7474
generateRandomInputData(),
7575
input =>
7676
noOpFilter(
77-
TakeOrderedAndProjectExec(limit, sortOrder, Some(Seq(input.output.last)), input)),
77+
TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
7878
input =>
7979
GlobalLimitExec(limit,
8080
LocalLimitExec(limit,

0 commit comments

Comments
 (0)