Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, Stack}
import scala.util.control.NonFatal

Expand Down Expand Up @@ -113,15 +114,13 @@ object ConstantPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsAllPatterns(LITERAL, FILTER), ruleId) {
case f: Filter =>
val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true)
if (newCondition.isDefined) {
f.copy(condition = newCondition.get)
} else {
f
}
f.mapExpressions(traverse(_, nullIsFalse = true, None))
Copy link
Member

@wangyum wangyum Mar 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is our internal change:

object ConstantPropagation extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
    _.containsAllPatterns(LITERAL, FILTER), ruleId) {
    case f: Filter =>
      val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true)
      if (newCondition.isDefined) {
        f.copy(condition = newCondition.get)
      } else {
        f
      }
  }

  /**
   * Traverse a condition as a tree and replace attributes with constant values.
   * - On matching [[And]], recursively traverse each children and get propagated mappings.
   *   If the current node is not child of another [[And]], replace all occurrences of the
   *   attributes with the corresponding constant values.
   * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping
   *   of attribute => constant.
   * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping.
   * - Otherwise, stop traversal and propagate empty mapping.
   * @param condition condition to be traversed
   * @param replaceChildren whether to replace attributes with constant values in children
   * @param nullIsFalse whether a boolean expression result can be considered to false e.g. in the
   *                    case of `WHERE e`, null result of expression `e` means the same as if it
   *                    resulted false
   * @return A tuple including:
   *         1. Option[Expression]: optional changed condition after traversal
   *         2. AttributeMap: propagated mapping of attribute => constant
   */
  private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean)
    : (Option[Expression], AttributeMap[(Literal, BinaryComparison)]) =
    condition match {
      case e @ EqualTo(left: AttributeReference, right: Literal)
        if safeToReplace(left, nullIsFalse) =>
        (None, AttributeMap(Map(left -> (right, e))))
      case e @ EqualTo(left: Literal, right: AttributeReference)
        if safeToReplace(right, nullIsFalse) =>
        (None, AttributeMap(Map(right -> (left, e))))
      case e @ EqualNullSafe(left: AttributeReference, right: Literal)
        if safeToReplace(left, nullIsFalse) =>
        (None, AttributeMap(Map(left -> (right, e))))
      case e @ EqualNullSafe(left: Literal, right: AttributeReference)
        if safeToReplace(right, nullIsFalse) =>
        (None, AttributeMap(Map(right -> (left, e))))
      case a: And =>
        val (newLeft, equalityPredicatesLeft) =
          traverse(a.left, replaceChildren = false, nullIsFalse)
        val (newRight, equalityPredicatesRight) =
          traverse(a.right, replaceChildren = false, nullIsFalse)
        val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
        val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
          Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
            replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
        } else {
          if (newLeft.isDefined || newRight.isDefined) {
            Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
          } else {
            None
          }
        }
        (newSelf, equalityPredicates)
      case o: Or =>
        // Ignore the EqualityPredicates from children since they are only propagated through And.
        val (newLeft, _) = traverse(o.left, replaceChildren = true, nullIsFalse)
        val (newRight, _) = traverse(o.right, replaceChildren = true, nullIsFalse)
        val newSelf = if (newLeft.isDefined || newRight.isDefined) {
          Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right))))
        } else {
          None
        }
        (newSelf, AttributeMap.empty)
      case n: Not =>
        // Ignore the EqualityPredicates from children since they are only propagated through And.
        val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false)
        (newChild.map(Not), AttributeMap.empty)
      case _ => (None, AttributeMap.empty)
    }

  // We need to take into account if an attribute is nullable and the context of the conjunctive
  // expression. E.g. `SELECT * FROM t WHERE NOT(c = 1 AND c + 1 = 1)` where attribute `c` can be
  // substituted into `1 + 1 = 1` if 'c' isn't nullable. If 'c' is nullable then the enclosing
  // NOT prevents us to do the substitution as NOT flips the context (`nullIsFalse`) of what a
  // null result of the enclosed expression means.
  private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) =
    !ar.nullable || nullIsFalse

  private def replaceConstants(
      condition: Expression,
      equalityPredicates: AttributeMap[(Literal, BinaryComparison)]): Expression = {
    val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit })
    val predicates = equalityPredicates.values.map(_._2).toSet
    condition transform {
      case b: BinaryComparison if !predicates.contains(b) => b transform {
        case a: AttributeReference => constantsMap.getOrElse(a, a)
      }
    }
  }
}

Copy link
Contributor Author

@peter-toth peter-toth Mar 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is very similar to a previous state of this PR but your version uses an immutable AttributeMap[(Literal, BinaryComparison)] instead of a mutable Map[Expression, (Literal, BinaryComparison)] in my PR. The addition of maps val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight inside handling Ands is slower when we use immutable maps. That's why I proposed to use a mutable one.

But since that version, I changed my PR to eliminate map addition at all and I use the mutable map as an accumulator from this commit: b9f3fbb. So the current version of this PR is even more optimized.

}

type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)]
// The keys are always canonicalized `AttributeReference`s, but it is easier to use `Expression`
// type keys here instead of casting `AttributeReference.canonicalized` to `AttributeReference` at
// the calling sites.
type EqualityPredicates = mutable.Map[Expression, (Literal, BinaryComparison)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think AttributeReference is enough.

Can we use AttributeMap ? In order to avoid the use of x.canonicalized later:

type EqualityPredicates = AttributeMap[(Literal, BinaryComparison)]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I deliberately used mutable map here to improve their addition (equalityPredicates ++= equalityPredicatesRight) later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a small improvement in cabb083


/**
* Traverse a condition as a tree and replace attributes with constant values.
Expand All @@ -133,61 +132,57 @@ object ConstantPropagation extends Rule[LogicalPlan] {
* - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping.
* - Otherwise, stop traversal and propagate empty mapping.
* @param condition condition to be traversed
* @param replaceChildren whether to replace attributes with constant values in children
* @param nullIsFalse whether a boolean expression result can be considered to false e.g. in the
* case of `WHERE e`, null result of expression `e` means the same as if it
* resulted false
* @return A tuple including:
* 1. Option[Expression]: optional changed condition after traversal
* 2. EqualityPredicates: propagated mapping of attribute => constant
* @param equalityPredicates optional [[EqualityPredicates]] map to collect attribute => constant
* mapping in adjacent [[And]]]s
* @return changed condition after traversal
*/
private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean)
: (Option[Expression], EqualityPredicates) =
private def traverse(
condition: Expression,
nullIsFalse: Boolean,
equalityPredicates: Option[EqualityPredicates]): Expression =
condition match {
case e @ EqualTo(left: AttributeReference, right: Literal)
if safeToReplace(left, nullIsFalse) =>
(None, Seq(((left, right), e)))
equalityPredicates.foreach(_ += left.canonicalized -> (right, e))
e
case e @ EqualTo(left: Literal, right: AttributeReference)
if safeToReplace(right, nullIsFalse) =>
(None, Seq(((right, left), e)))
equalityPredicates.foreach(_ += right.canonicalized -> (left, e))
e
case e @ EqualNullSafe(left: AttributeReference, right: Literal)
if safeToReplace(left, nullIsFalse) =>
(None, Seq(((left, right), e)))
equalityPredicates.foreach(_ += left.canonicalized -> (right, e))
e
case e @ EqualNullSafe(left: Literal, right: AttributeReference)
if safeToReplace(right, nullIsFalse) =>
(None, Seq(((right, left), e)))
case a: And =>
val (newLeft, equalityPredicatesLeft) =
traverse(a.left, replaceChildren = false, nullIsFalse)
val (newRight, equalityPredicatesRight) =
traverse(a.right, replaceChildren = false, nullIsFalse)
val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
equalityPredicates.foreach(_ += right.canonicalized -> (left, e))
e
case a @ And(left, right) =>
val newEqualityPredicates: Option[EqualityPredicates] =
equalityPredicates.orElse(Some(mutable.Map.empty))
val newLeft = traverse(left, nullIsFalse, newEqualityPredicates)
val newRight = traverse(right, nullIsFalse, newEqualityPredicates)
// We could recognize when conflicting constants are coming from the left and right sides
// and immediately shortcut the `And` expression to `Literal.FalseLiteral`, but that case is
// not so common and actually it is the job of `ConstantFolding` and `BooleanSimplification`
// rules to deal with those optimizations.
a.withNewChildren(if (equalityPredicates.isEmpty && newEqualityPredicates.get.nonEmpty) {
val replacedNewLeft = replaceConstants(newLeft, newEqualityPredicates.get)
val replacedNewRight = replaceConstants(newRight, newEqualityPredicates.get)
Seq(replacedNewLeft, replacedNewRight)
} else {
if (newLeft.isDefined || newRight.isDefined) {
Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
} else {
None
}
}
(newSelf, equalityPredicates)
Seq(newLeft, newRight)
})
case o: Or =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newLeft, _) = traverse(o.left, replaceChildren = true, nullIsFalse)
val (newRight, _) = traverse(o.right, replaceChildren = true, nullIsFalse)
val newSelf = if (newLeft.isDefined || newRight.isDefined) {
Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right))))
} else {
None
}
(newSelf, Seq.empty)
o.mapChildren(traverse(_, nullIsFalse, None))
case n: Not =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false)
(newChild.map(Not), Seq.empty)
case _ => (None, Seq.empty)
n.mapChildren(traverse(_, nullIsFalse = false, None))
case o => o
}

// We need to take into account if an attribute is nullable and the context of the conjunctive
Expand All @@ -198,16 +193,14 @@ object ConstantPropagation extends Rule[LogicalPlan] {
private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) =
!ar.nullable || nullIsFalse

private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates)
: Expression = {
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
val predicates = equalityPredicates.map(_._2).toSet
def replaceConstants0(expression: Expression) = expression transform {
case a: AttributeReference => constantsMap.getOrElse(a, a)
}
private def replaceConstants(
condition: Expression,
equalityPredicates: EqualityPredicates): Expression = {
val predicates = equalityPredicates.values.map(_._2).toSet
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we keep the constantsMap?

val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit })

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for changing EqualityPredicates to mutable.Map earlier was to avoid building a map here.

The main point of that map is that we store only one Literal (and its original BinaryComparision) assigned to an attribute key. So if we have 2 or more conflicting EqualTo then in replaceConstants() we keep only one's original form and rewrite the other conflicing ones. E.g. a = 1 AND a = 2 we store only a -> (2, a = 2) in the map and rewrite the expression to 2 = 1 AND a = 2.

condition transform {
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e)
case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e)
case b: BinaryComparison if !predicates.contains(b) => b transform {
case a: AttributeReference => equalityPredicates.get(a.canonicalized).map(_._1).getOrElse(a)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,9 @@ class ConstantPropagationSuite extends PlanTest {
columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3)))

val correctAnswer = testRelation
.select(columnA)
.where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze
.select(columnA, columnB)
.where(Literal.FalseLiteral)
.select(columnA).analyze

comparePlans(Optimize.execute(query.analyze), correctAnswer)
}
Expand All @@ -186,4 +187,31 @@ class ConstantPropagationSuite extends PlanTest {
.analyze
comparePlans(Optimize.execute(query2), correctAnswer2)
}

test("SPARK-42500: ConstantPropagation supports more cases") {
comparePlans(
Optimize.execute(testRelation.where(columnA === 1 && columnB > columnA + 2).analyze),
testRelation.where(columnA === 1 && columnB > 3).analyze)

comparePlans(
Optimize.execute(testRelation.where(columnA === 1 && columnA === 2).analyze),
testRelation.where(Literal.FalseLiteral).analyze)

comparePlans(
Optimize.execute(testRelation.where(columnA === 1 && columnA === columnA + 2).analyze),
testRelation.where(Literal.FalseLiteral).analyze)

comparePlans(
Optimize.execute(
testRelation.where((columnA === 1 || columnB === 2) && columnB === 1).analyze),
testRelation.where(columnA === 1 && columnB === 1).analyze)

comparePlans(
Optimize.execute(testRelation.where(columnA === 1 && columnA === 1).analyze),
testRelation.where(columnA === 1).analyze)

comparePlans(
Optimize.execute(testRelation.where(Not(columnA === 1 && columnA === columnA + 2)).analyze),
testRelation.where(Not(columnA === 1) || Not(columnA === columnA + 2)).analyze)
}
}