From 08309e7e87fb61cb14130d07196beec294f96a58 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Sat, 18 Mar 2023 01:05:38 +0000 Subject: [PATCH 1/3] SPARK-42851: guard EquivalentExpressions.addExpr() with supportedExpression() --- .../expressions/EquivalentExpressions.scala | 6 +++++- .../SubexpressionEliminationSuite.scala | 16 +++++++++++++++- .../spark/sql/DataFrameAggregateSuite.scala | 7 +++++++ 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 330d66a21bea5..12def60042d80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -40,7 +40,11 @@ class EquivalentExpressions { * Returns true if there was already a matching expression. */ def addExpr(expr: Expression): Boolean = { - updateExprInMap(expr, equivalenceMap) + if (supportedExpression(expr)) { + updateExprInMap(expr, equivalenceMap) + } else { + false + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index b16629f59aa2d..4f4294bf874be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, IntegerType} class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHelper { test("Semantic equals and hash") { @@ -449,6 +449,20 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel assert(e2.getCommonSubexpressions.size == 1) assert(e2.getCommonSubexpressions.head == add) } + + test("SPARK-42851: Handle supportExpressions consistently across add and get") { + val tx = { + val arr = Literal(Array(1, 2)) + val ArrayType(et, cn) = arr.dataType + val lv = NamedLambdaVariable("x", et, cn) + val lambda = LambdaFunction(lv, Seq(lv)) + ArrayTransform(arr, lambda) + } + val equivalence = new EquivalentExpressions + val isNewExpr = equivalence.addExpr(tx) + val cseState = equivalence.getExprState(tx) + assert(isNewExpr == cseState.isDefined) + } } case class CodegenFallbackExpression(child: Expression) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 737d31cc6e913..2ba9039166f48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1538,6 +1538,13 @@ class DataFrameAggregateSuite extends QueryTest ) checkAnswer(res, Row(1, 1, 1) :: Row(4, 1, 2) :: Nil) } + + test("SPARK-42851: common subexpression should consistently handle aggregate and result exprs") { + val res = sql( + "select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)" + ) + checkAnswer(res, Row(Array(1), Array(1))) + } } case class B(c: Option[Double]) From 28d101ee6765c5453189fa62d6b8ade1568d99d2 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Tue, 21 Mar 2023 00:30:22 +0000 Subject: [PATCH 2/3] adjust the test case --- .../expressions/SubexpressionEliminationSuite.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 4f4294bf874be..7da218241e0bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -450,7 +450,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel assert(e2.getCommonSubexpressions.head == add) } - test("SPARK-42851: Handle supportExpressions consistently across add and get") { + test("SPARK-42851: Handle supportExpression consistently across add and get") { val tx = { val arr = Literal(Array(1, 2)) val ArrayType(et, cn) = arr.dataType @@ -459,9 +459,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel ArrayTransform(arr, lambda) } val equivalence = new EquivalentExpressions - val isNewExpr = equivalence.addExpr(tx) + equivalence.addExpr(tx) + val hasMatching = equivalence.addExpr(tx) val cseState = equivalence.getExprState(tx) - assert(isNewExpr == cseState.isDefined) + assert(hasMatching == cseState.isDefined) } } From d9f8bbdeb004a48ddef344eddedaf85590075250 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Tue, 21 Mar 2023 06:57:10 +0000 Subject: [PATCH 3/3] address comments, adjust test case to use LambdaVariable instead of NamedLambdaVariable --- .../SubexpressionEliminationSuite.scala | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 7da218241e0bb..44d8ea3a112e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, IntegerType} +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, ObjectType} class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHelper { test("Semantic equals and hash") { @@ -451,17 +451,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel } test("SPARK-42851: Handle supportExpression consistently across add and get") { - val tx = { - val arr = Literal(Array(1, 2)) - val ArrayType(et, cn) = arr.dataType - val lv = NamedLambdaVariable("x", et, cn) - val lambda = LambdaFunction(lv, Seq(lv)) - ArrayTransform(arr, lambda) + val expr = { + val function = (lambda: Expression) => Add(lambda, Literal(1)) + val elementType = IntegerType + val colClass = classOf[Array[Int]] + val inputType = ObjectType(colClass) + val inputObject = BoundReference(0, inputType, nullable = true) + objects.MapObjects(function, inputObject, elementType, true, Option(colClass)) } val equivalence = new EquivalentExpressions - equivalence.addExpr(tx) - val hasMatching = equivalence.addExpr(tx) - val cseState = equivalence.getExprState(tx) + equivalence.addExpr(expr) + val hasMatching = equivalence.addExpr(expr) + val cseState = equivalence.getExprState(expr) assert(hasMatching == cseState.isDefined) } }