-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-42500][SQL] ConstantPropagation support more cases #40268
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dc04923
0bb53d8
1387247
cabb083
0244034
b9f3fbb
9acb27e
ecd650a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| package org.apache.spark.sql.catalyst.optimizer | ||
|
|
||
| import scala.collection.immutable.HashSet | ||
| import scala.collection.mutable | ||
| import scala.collection.mutable.{ArrayBuffer, Stack} | ||
| import scala.util.control.NonFatal | ||
|
|
||
|
|
@@ -113,15 +114,13 @@ object ConstantPropagation extends Rule[LogicalPlan] { | |
| def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( | ||
| _.containsAllPatterns(LITERAL, FILTER), ruleId) { | ||
| case f: Filter => | ||
| val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true) | ||
| if (newCondition.isDefined) { | ||
| f.copy(condition = newCondition.get) | ||
| } else { | ||
| f | ||
| } | ||
| f.mapExpressions(traverse(_, nullIsFalse = true, None)) | ||
| } | ||
|
|
||
| type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)] | ||
| // The keys are always canonicalized `AttributeReference`s, but it is easier to use `Expression` | ||
| // type keys here instead of casting `AttributeReference.canonicalized` to `AttributeReference` at | ||
| // the calling sites. | ||
| type EqualityPredicates = mutable.Map[Expression, (Literal, BinaryComparison)] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think Can we use type EqualityPredicates = AttributeMap[(Literal, BinaryComparison)]
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, I deliberately used mutable map here to improve their addition (
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a small improvement in cabb083 |
||
|
|
||
| /** | ||
| * Traverse a condition as a tree and replace attributes with constant values. | ||
|
|
@@ -133,61 +132,57 @@ object ConstantPropagation extends Rule[LogicalPlan] { | |
| * - On matching [[Or]] or [[Not]], recursively traverse each child, propagate empty mapping. | ||
| * - Otherwise, stop traversal and propagate empty mapping. | ||
| * @param condition condition to be traversed | ||
| * @param replaceChildren whether to replace attributes with constant values in children | ||
| * @param nullIsFalse whether a null result of a boolean expression can be considered false, e.g. in | ||
| * the case of `WHERE e`, a null result of expression `e` means the same as if it | ||
| * resulted in false | ||
| * @return A tuple including: | ||
| * 1. Option[Expression]: optional changed condition after traversal | ||
| * 2. EqualityPredicates: propagated mapping of attribute => constant | ||
| * @param equalityPredicates optional [[EqualityPredicates]] map to collect attribute => constant | ||
| * mapping in adjacent [[And]]s | ||
| * @return changed condition after traversal | ||
| */ | ||
| private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean) | ||
| : (Option[Expression], EqualityPredicates) = | ||
| private def traverse( | ||
| condition: Expression, | ||
| nullIsFalse: Boolean, | ||
| equalityPredicates: Option[EqualityPredicates]): Expression = | ||
| condition match { | ||
| case e @ EqualTo(left: AttributeReference, right: Literal) | ||
| if safeToReplace(left, nullIsFalse) => | ||
| (None, Seq(((left, right), e))) | ||
| equalityPredicates.foreach(_ += left.canonicalized -> (right, e)) | ||
| e | ||
| case e @ EqualTo(left: Literal, right: AttributeReference) | ||
| if safeToReplace(right, nullIsFalse) => | ||
| (None, Seq(((right, left), e))) | ||
| equalityPredicates.foreach(_ += right.canonicalized -> (left, e)) | ||
| e | ||
| case e @ EqualNullSafe(left: AttributeReference, right: Literal) | ||
| if safeToReplace(left, nullIsFalse) => | ||
| (None, Seq(((left, right), e))) | ||
| equalityPredicates.foreach(_ += left.canonicalized -> (right, e)) | ||
| e | ||
| case e @ EqualNullSafe(left: Literal, right: AttributeReference) | ||
| if safeToReplace(right, nullIsFalse) => | ||
| (None, Seq(((right, left), e))) | ||
| case a: And => | ||
| val (newLeft, equalityPredicatesLeft) = | ||
| traverse(a.left, replaceChildren = false, nullIsFalse) | ||
| val (newRight, equalityPredicatesRight) = | ||
| traverse(a.right, replaceChildren = false, nullIsFalse) | ||
| val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight | ||
| val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) { | ||
| Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates), | ||
| replaceConstants(newRight.getOrElse(a.right), equalityPredicates))) | ||
| equalityPredicates.foreach(_ += right.canonicalized -> (left, e)) | ||
| e | ||
| case a @ And(left, right) => | ||
| val newEqualityPredicates: Option[EqualityPredicates] = | ||
| equalityPredicates.orElse(Some(mutable.Map.empty)) | ||
| val newLeft = traverse(left, nullIsFalse, newEqualityPredicates) | ||
| val newRight = traverse(right, nullIsFalse, newEqualityPredicates) | ||
| // We could recognize when conflicting constants are coming from the left and right sides | ||
| // and immediately shortcut the `And` expression to `Literal.FalseLiteral`, but that case is | ||
| // not so common and actually it is the job of `ConstantFolding` and `BooleanSimplification` | ||
| // rules to deal with those optimizations. | ||
| a.withNewChildren(if (equalityPredicates.isEmpty && newEqualityPredicates.get.nonEmpty) { | ||
| val replacedNewLeft = replaceConstants(newLeft, newEqualityPredicates.get) | ||
| val replacedNewRight = replaceConstants(newRight, newEqualityPredicates.get) | ||
| Seq(replacedNewLeft, replacedNewRight) | ||
| } else { | ||
| if (newLeft.isDefined || newRight.isDefined) { | ||
| Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right))) | ||
| } else { | ||
| None | ||
| } | ||
| } | ||
| (newSelf, equalityPredicates) | ||
| Seq(newLeft, newRight) | ||
| }) | ||
| case o: Or => | ||
| // Ignore the EqualityPredicates from children since they are only propagated through And. | ||
| val (newLeft, _) = traverse(o.left, replaceChildren = true, nullIsFalse) | ||
| val (newRight, _) = traverse(o.right, replaceChildren = true, nullIsFalse) | ||
| val newSelf = if (newLeft.isDefined || newRight.isDefined) { | ||
| Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right)))) | ||
| } else { | ||
| None | ||
| } | ||
| (newSelf, Seq.empty) | ||
| o.mapChildren(traverse(_, nullIsFalse, None)) | ||
| case n: Not => | ||
| // Ignore the EqualityPredicates from children since they are only propagated through And. | ||
| val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false) | ||
| (newChild.map(Not), Seq.empty) | ||
| case _ => (None, Seq.empty) | ||
| n.mapChildren(traverse(_, nullIsFalse = false, None)) | ||
| case o => o | ||
| } | ||
|
|
||
| // We need to take into account if an attribute is nullable and the context of the conjunctive | ||
|
|
@@ -198,16 +193,14 @@ object ConstantPropagation extends Rule[LogicalPlan] { | |
| private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) = | ||
| !ar.nullable || nullIsFalse | ||
|
|
||
| private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates) | ||
| : Expression = { | ||
| val constantsMap = AttributeMap(equalityPredicates.map(_._1)) | ||
| val predicates = equalityPredicates.map(_._2).toSet | ||
| def replaceConstants0(expression: Expression) = expression transform { | ||
| case a: AttributeReference => constantsMap.getOrElse(a, a) | ||
| } | ||
| private def replaceConstants( | ||
| condition: Expression, | ||
| equalityPredicates: EqualityPredicates): Expression = { | ||
| val predicates = equalityPredicates.values.map(_._2).toSet | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we keep the val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit })
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason for changing The main point of that map is that we store only one |
||
| condition transform { | ||
| case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e) | ||
| case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e) | ||
| case b: BinaryComparison if !predicates.contains(b) => b transform { | ||
| case a: AttributeReference => equalityPredicates.get(a.canonicalized).map(_._1).getOrElse(a) | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is our internal change:
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is very similar to a previous state of this PR, but your version uses an immutable
`AttributeMap[(Literal, BinaryComparison)]` instead of a mutable `Map[Expression, (Literal, BinaryComparison)]` in my PR. The addition of maps `val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight` inside handling `And`s is slower when we use immutable maps. That's why I proposed to use a mutable one. But since that version, I changed my PR to eliminate map addition altogether, and I use the mutable map as an accumulator from this commit: b9f3fbb. So the current version of this PR is even more optimized.