-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-19309][SQL] disable common subexpression elimination for conditional expressions #16659
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,28 +67,34 @@ class EquivalentExpressions { | |
| /** | ||
| * Adds the expression to this data structure recursively. Stops if a matching expression | ||
| * is found. That is, if `expr` has already been added, its children are not added. | ||
| * If ignoreLeaf is true, leaf nodes are ignored. | ||
| */ | ||
| def addExprTree( | ||
| root: Expression, | ||
| ignoreLeaf: Boolean = true, | ||
| skipReferenceToExpressions: Boolean = true): Unit = { | ||
| val skip = (root.isInstanceOf[LeafExpression] && ignoreLeaf) || | ||
| def addExprTree(expr: Expression): Unit = { | ||
| val skip = expr.isInstanceOf[LeafExpression] || | ||
| // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the | ||
| // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. | ||
| root.find(_.isInstanceOf[LambdaVariable]).isDefined | ||
| // There are some special expressions that we should not recurse into children. | ||
| expr.find(_.isInstanceOf[LambdaVariable]).isDefined | ||
|
|
||
| // There are some special expressions that we should not recurse into all of their children. | ||
| // 1. CodegenFallback: its children will not be used to generate code (call eval() instead) | ||
| // 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination. | ||
| val shouldRecurse = root match { | ||
| // TODO: some expressions implement `CodegenFallback` but can still do codegen, | ||
| // e.g. `CaseWhen`, we should support them. | ||
| case _: CodegenFallback => false | ||
| case _: ReferenceToExpressions if skipReferenceToExpressions => false | ||
| case _ => true | ||
| // 2. If: common subexpressions will always be evaluated at the beginning, but the true and | ||
|
||
| // false expressions in `If` may not get accessed, according to the predicate | ||
| // expression. We should only recurse into the predicate expression. | ||
| // 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain | ||
| // condition. We should only recurse into the first condition expression as it | ||
| // will always get accessed. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Compared with the previous impl, will we miss some expression elimination chances?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nvm, |
||
| // 4. Coalesce: it's also a conditional expression, we should only recurse into the first | ||
| // child, because the others may not get accessed. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Although Could you update the comments?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| def childrenToRecurse: Seq[Expression] = expr match { | ||
| case _: CodegenFallback => Nil | ||
| case i: If => i.predicate :: Nil | ||
| // `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here. | ||
| case c: CaseWhenCodegen => c.children.head :: Nil | ||
| case c: Coalesce => c.children.head :: Nil | ||
| case other => other.children | ||
| } | ||
| if (!skip && !addExpr(root) && shouldRecurse) { | ||
| root.children.foreach(addExprTree(_, ignoreLeaf)) | ||
|
|
||
| if (!skip && !addExpr(expr)) { | ||
| childrenToRecurse.foreach(addExprTree) | ||
| } | ||
| } | ||
|
|
||
|
|
||
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -726,18 +726,18 @@ class CodegenContext { | |
| val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] | ||
|
|
||
| // Add each expression tree and compute the common subexpressions. | ||
| expressions.foreach(equivalentExpressions.addExprTree(_, true, false)) | ||
| expressions.foreach(equivalentExpressions.addExprTree) | ||
|
|
||
| // Get all the expressions that appear at least twice and set up the state for subexpression | ||
| // elimination. | ||
| val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) | ||
| val codes = commonExprs.map { e => | ||
| val expr = e.head | ||
| // Generate the code for this expression tree. | ||
| val code = expr.genCode(this) | ||
| val state = SubExprEliminationState(code.isNull, code.value) | ||
| val eval = expr.genCode(this) | ||
| val state = SubExprEliminationState(eval.isNull, eval.value) | ||
| e.foreach(subExprEliminationExprs.put(_, state)) | ||
| code.code.trim | ||
| eval.code.trim | ||
| } | ||
| SubExprCodes(codes, subExprEliminationExprs.toMap) | ||
| } | ||
|
|
@@ -747,7 +747,7 @@ class CodegenContext { | |
| * common subexpressions, generates the functions that evaluate those expressions and populates | ||
| * the mapping of common subexpressions to the generated functions. | ||
| */ | ||
| private def subexpressionElimination(expressions: Seq[Expression]) = { | ||
| private def subexpressionElimination(expressions: Seq[Expression]): Unit = { | ||
| // Add each expression tree and compute the common subexpressions. | ||
| expressions.foreach(equivalentExpressions.addExprTree(_)) | ||
|
|
||
|
|
@@ -761,13 +761,13 @@ class CodegenContext { | |
| val value = s"${fnName}Value" | ||
|
|
||
| // Generate the code for this expression tree and wrap it in a function. | ||
| val code = expr.genCode(this) | ||
| val eval = expr.genCode(this) | ||
| val fn = | ||
| s""" | ||
| |private void $fnName(InternalRow $INPUT_ROW) { | ||
| | ${code.code.trim} | ||
| | $isNull = ${code.isNull}; | ||
| | $value = ${code.value}; | ||
| | ${eval.code.trim} | ||
| | $isNull = ${eval.isNull}; | ||
| | $value = ${eval.value}; | ||
| |} | ||
| """.stripMargin | ||
|
|
||
|
|
@@ -780,9 +780,6 @@ class CodegenContext { | |
| // The cost of doing subexpression elimination is: | ||
| // 1. Extra function call, although this is probably *good* as the JIT can decide to | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: we removed
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but we do have an extra function call to evaluate common subexpression at the beginning.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. : ) Just removed
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh i see :)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe we should still keep it, to make the indent consistent between the "cost" part and the "benefit" part. It also makes it more obvious that we only have one "cost".
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am fine to keep it. |
||
| // inline or not. | ||
| // 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly | ||
| // very often. The reason it is not loaded is because of a prior branch. | ||
| // 3. Extra store into isLoaded. | ||
| // The benefit of doing subexpression elimination is: | ||
| // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 | ||
| // above. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -97,9 +97,9 @@ class SubexpressionEliminationSuite extends SparkFunSuite { | |
| val add2 = Add(add, add) | ||
|
|
||
| var equivalence = new EquivalentExpressions | ||
| equivalence.addExprTree(add, true) | ||
| equivalence.addExprTree(abs, true) | ||
| equivalence.addExprTree(add2, true) | ||
| equivalence.addExprTree(add) | ||
| equivalence.addExprTree(abs) | ||
| equivalence.addExprTree(add2) | ||
|
|
||
| // Should only have one equivalence for `one + two` | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 1) | ||
|
|
@@ -115,41 +115,17 @@ class SubexpressionEliminationSuite extends SparkFunSuite { | |
| val mul2 = Multiply(mul, mul) | ||
| val sqrt = Sqrt(mul2) | ||
| val sum = Add(mul2, sqrt) | ||
| equivalence.addExprTree(mul, true) | ||
| equivalence.addExprTree(mul2, true) | ||
| equivalence.addExprTree(sqrt, true) | ||
| equivalence.addExprTree(sum, true) | ||
| equivalence.addExprTree(mul) | ||
| equivalence.addExprTree(mul2) | ||
| equivalence.addExprTree(sqrt) | ||
| equivalence.addExprTree(sum) | ||
|
|
||
| // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 3) | ||
| assert(equivalence.getEquivalentExprs(mul).size == 3) | ||
| assert(equivalence.getEquivalentExprs(mul2).size == 3) | ||
| assert(equivalence.getEquivalentExprs(sqrt).size == 2) | ||
| assert(equivalence.getEquivalentExprs(sum).size == 1) | ||
|
|
||
| // Some expressions inspired by TPCH-Q1 | ||
| // sum(l_quantity) as sum_qty, | ||
| // sum(l_extendedprice) as sum_base_price, | ||
| // sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, | ||
| // sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, | ||
| // avg(l_extendedprice) as avg_price, | ||
| // avg(l_discount) as avg_disc | ||
| equivalence = new EquivalentExpressions | ||
| val quantity = Literal(1) | ||
| val price = Literal(1.1) | ||
| val discount = Literal(.24) | ||
| val tax = Literal(0.1) | ||
| equivalence.addExprTree(quantity, false) | ||
| equivalence.addExprTree(price, false) | ||
| equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false) | ||
| equivalence.addExprTree( | ||
| Multiply( | ||
| Multiply(price, Subtract(Literal(1), discount)), | ||
| Add(Literal(1), tax)), false) | ||
| equivalence.addExprTree(price, false) | ||
| equivalence.addExprTree(discount, false) | ||
| // quantity, price, discount and (price * (1 - discount)) | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 4) | ||
| } | ||
|
|
||
| test("Expression equivalence - non deterministic") { | ||
|
|
@@ -167,11 +143,24 @@ class SubexpressionEliminationSuite extends SparkFunSuite { | |
| val add = Add(two, fallback) | ||
|
|
||
| val equivalence = new EquivalentExpressions | ||
| equivalence.addExprTree(add, true) | ||
| equivalence.addExprTree(add) | ||
| // the `two` inside `fallback` should not be added | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode | ||
| } | ||
|
|
||
| test("Children of conditional expressions") { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To other reviewers: the new |
||
| val condition = And(Literal(true), Literal(false)) | ||
| val add = Add(Literal(1), Literal(2)) | ||
| val ifExpr = If(condition, add, add) | ||
|
|
||
| val equivalence = new EquivalentExpressions | ||
| equivalence.addExprTree(ifExpr) | ||
| // the `add` inside `If` should not be added | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0) | ||
| // only ifExpr and its predicate expression | ||
| assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 2) | ||
| } | ||
| } | ||
|
|
||
| case class CodegenFallbackExpression(child: Expression) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -143,9 +143,15 @@ case class SimpleTypedAggregateExpression( | |
| override lazy val aggBufferAttributes: Seq[AttributeReference] = | ||
| bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference]) | ||
|
|
||
| private def serializeToBuffer(expr: Expression): Seq[Expression] = { | ||
| bufferSerializer.map(_.transform { | ||
| case _: BoundReference => expr | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry, typo... |
||
| }) | ||
| } | ||
|
|
||
| override lazy val initialValues: Seq[Expression] = { | ||
| val zero = Literal.fromObject(aggregator.zero, bufferExternalType) | ||
| bufferSerializer.map(ReferenceToExpressions(_, zero :: Nil)) | ||
| serializeToBuffer(zero) | ||
| } | ||
|
|
||
| override lazy val updateExpressions: Seq[Expression] = { | ||
|
|
@@ -154,8 +160,7 @@ case class SimpleTypedAggregateExpression( | |
| "reduce", | ||
| bufferExternalType, | ||
| bufferDeserializer :: inputDeserializer.get :: Nil) | ||
|
|
||
| bufferSerializer.map(ReferenceToExpressions(_, reduced :: Nil)) | ||
| serializeToBuffer(reduced) | ||
| } | ||
|
|
||
| override lazy val mergeExpressions: Seq[Expression] = { | ||
|
|
@@ -170,8 +175,7 @@ case class SimpleTypedAggregateExpression( | |
| "merge", | ||
| bufferExternalType, | ||
| leftBuffer :: rightBuffer :: Nil) | ||
|
|
||
| bufferSerializer.map(ReferenceToExpressions(_, merged :: Nil)) | ||
| serializeToBuffer(merged) | ||
| } | ||
|
|
||
| override lazy val evaluateExpression: Expression = { | ||
|
|
@@ -181,19 +185,17 @@ case class SimpleTypedAggregateExpression( | |
| outputExternalType, | ||
| bufferDeserializer :: Nil) | ||
|
|
||
| val outputSerializeExprs = outputSerializer.map(_.transform { | ||
| case _: BoundReference => resultObj | ||
| }) | ||
|
|
||
| dataType match { | ||
| case s: StructType => | ||
| case _: StructType => | ||
| val objRef = outputSerializer.head.find(_.isInstanceOf[BoundReference]).get | ||
| val struct = If( | ||
| IsNull(objRef), | ||
| Literal.create(null, dataType), | ||
| CreateStruct(outputSerializer)) | ||
| ReferenceToExpressions(struct, resultObj :: Nil) | ||
| If(IsNull(objRef), Literal.create(null, dataType), CreateStruct(outputSerializeExprs)) | ||
| case _ => | ||
| assert(outputSerializer.length == 1) | ||
| outputSerializer.head transform { | ||
| case b: BoundReference => resultObj | ||
| } | ||
| assert(outputSerializeExprs.length == 1) | ||
| outputSerializeExprs.head | ||
| } | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From the code change, I don't see any place other than tests using ignoreLeaf = false. Curious why we have it.