
Commit 01a7d33

cloud-fan authored and hvanhovell committed
[SPARK-18711][SQL] should disable subexpression elimination for LambdaVariable
## What changes were proposed in this pull request?

This is a long-standing bug that stayed hidden until #15780, which may add `AssertNotNull` on top of `LambdaVariable` and thus enables subexpression elimination. However, subexpression elimination evaluates the common expressions up front, which is invalid for `LambdaVariable`: a `LambdaVariable` usually represents a loop variable, which can't be evaluated ahead of the loop. This PR skips expressions containing `LambdaVariable` when doing subexpression elimination.

## How was this patch tested?

Updated test in `DatasetAggregatorSuite`.

Author: Wenchen Fan <[email protected]>

Closes #16143 from cloud-fan/aggregator.
1 parent 2460128 commit 01a7d33
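
The core of the fix is a guard in `EquivalentExpressions` that refuses to treat any subtree containing a `LambdaVariable` as a candidate for subexpression elimination, since hoisting such a subtree would evaluate the loop variable before the loop assigns it. Below is a minimal, self-contained Scala sketch of that idea using a toy expression tree rather than Spark's actual `Expression` classes; the names `Expr`, `LoopVar`, `containsNode`, and `skipForElimination` are illustrative only.

```scala
// Toy model (not Spark's classes) of the guard added in this patch: an expression
// subtree is skipped for subexpression elimination if it contains a loop variable,
// because hoisting it out of the loop would read the variable before it is set.
object LambdaVariableGuardSketch {
  sealed trait Expr { def children: Seq[Expr] }
  case class LoopVar(name: String) extends Expr { def children: Seq[Expr] = Nil }
  case class Literal(value: Int) extends Expr { def children: Seq[Expr] = Nil }
  case class Add(left: Expr, right: Expr) extends Expr {
    def children: Seq[Expr] = Seq(left, right)
  }

  // Mirrors the spirit of Expression.find: does any node in the tree satisfy the predicate?
  def containsNode(e: Expr)(p: Expr => Boolean): Boolean =
    p(e) || e.children.exists(containsNode(_)(p))

  // The candidate subtree is skipped when it refers to a loop variable anywhere below it.
  def skipForElimination(root: Expr): Boolean =
    containsNode(root)(_.isInstanceOf[LoopVar])

  def main(args: Array[String]): Unit = {
    // `item + 1` may appear twice in the loop body, but it cannot be hoisted ahead of the loop.
    val common = Add(LoopVar("item"), Literal(1))
    assert(skipForElimination(common))
    // A constant subexpression contains no loop variable, so it is still eligible.
    assert(!skipForElimination(Add(Literal(2), Literal(1))))
    println("guard sketch ok")
  }
}
```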

File tree

2 files changed: +9, -5 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala

Lines changed: 5 additions & 1 deletion
```diff
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
 
 /**
  * This class is used to compute equality of (sub)expression trees. Expressions can be added
@@ -72,7 +73,10 @@ class EquivalentExpressions {
       root: Expression,
       ignoreLeaf: Boolean = true,
       skipReferenceToExpressions: Boolean = true): Unit = {
-    val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
+    val skip = (root.isInstanceOf[LeafExpression] && ignoreLeaf) ||
+      // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
+      // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
+      root.find(_.isInstanceOf[LambdaVariable]).isDefined
     // There are some special expressions that we should not recurse into children.
     // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
     // 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination.
```

sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala

Lines changed: 4 additions & 4 deletions
```diff
@@ -92,13 +92,13 @@ object NameAgg extends Aggregator[AggData, String, String] {
 }
 
 
-object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[Int]] {
+object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[(Int, Int)]] {
   def zero: Seq[Int] = Nil
   def reduce(b: Seq[Int], a: AggData): Seq[Int] = a.a +: b
   def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2
-  def finish(r: Seq[Int]): Seq[Int] = r
+  def finish(r: Seq[Int]): Seq[(Int, Int)] = r.map(i => i -> i)
   override def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
-  override def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+  override def outputEncoder: Encoder[Seq[(Int, Int)]] = ExpressionEncoder()
 }
 
 
@@ -281,7 +281,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
 
     checkDataset(
       ds.groupByKey(_.b).agg(SeqAgg.toColumn),
-      "a" -> Seq(1, 2)
+      "a" -> Seq(1 -> 1, 2 -> 2)
     )
   }
```
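
For reference, here is a minimal, self-contained sketch of how an aggregator like the updated `SeqAgg` is exercised end to end, assuming a local `SparkSession`; the names `PairSeqAgg`, `PairSeqAggExample`, and the sample rows are illustrative and not part of the commit. It mirrors the updated test, where the `Seq[(Int, Int)]` output type is what drives the code path described above.

```scala
// Illustrative usage of a SeqAgg-style aggregator on a typed Dataset, mirroring the
// updated test in DatasetAggregatorSuite. PairSeqAgg and the sample data are hypothetical.
import org.apache.spark.sql.{Encoder, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

case class AggData(a: Int, b: String)

object PairSeqAgg extends Aggregator[AggData, Seq[Int], Seq[(Int, Int)]] {
  def zero: Seq[Int] = Nil
  def reduce(b: Seq[Int], a: AggData): Seq[Int] = a.a +: b
  def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2
  def finish(r: Seq[Int]): Seq[(Int, Int)] = r.map(i => i -> i)
  override def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
  override def outputEncoder: Encoder[Seq[(Int, Int)]] = ExpressionEncoder()
}

object PairSeqAggExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("seq-agg-example").getOrCreate()
    import spark.implicits._

    val ds = Seq(AggData(1, "a"), AggData(2, "a")).toDS()
    // Group by the string key and collect the values as (i, i) pairs; with this fix,
    // the generated code no longer hoists subexpressions over the loop variable.
    val result = ds.groupByKey(_.b).agg(PairSeqAgg.toColumn).collect()
    result.foreach(println)  // e.g. (a, List((2,2), (1,1))), ordering may vary

    spark.stop()
  }
}
```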
