diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1d756a2dcb744..580840b357624 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, Stack} import scala.util.control.NonFatal @@ -113,15 +114,13 @@ object ConstantPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( _.containsAllPatterns(LITERAL, FILTER), ruleId) { case f: Filter => - val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true) - if (newCondition.isDefined) { - f.copy(condition = newCondition.get) - } else { - f - } + f.mapExpressions(traverse(_, nullIsFalse = true, None)) } - type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)] + // The keys are always canonicalized `AttributeReference`s, but it is easier to use `Expression` + // type keys here instead of casting `AttributeReference.canonicalized` to `AttributeReference` at + // the calling sites. + type EqualityPredicates = mutable.Map[Expression, (Literal, BinaryComparison)] /** * Traverse a condition as a tree and replace attributes with constant values. @@ -133,61 +132,57 @@ object ConstantPropagation extends Rule[LogicalPlan] { * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping. * - Otherwise, stop traversal and propagate empty mapping. * @param condition condition to be traversed - * @param replaceChildren whether to replace attributes with constant values in children * @param nullIsFalse whether a boolean expression result can be considered to false e.g. in the * case of `WHERE e`, null result of expression `e` means the same as if it * resulted false - * @return A tuple including: - * 1. Option[Expression]: optional changed condition after traversal - * 2. EqualityPredicates: propagated mapping of attribute => constant + * @param equalityPredicates optional [[EqualityPredicates]] map to collect attribute => constant + * mapping in adjacent [[And]]]s + * @return changed condition after traversal */ - private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean) - : (Option[Expression], EqualityPredicates) = + private def traverse( + condition: Expression, + nullIsFalse: Boolean, + equalityPredicates: Option[EqualityPredicates]): Expression = condition match { case e @ EqualTo(left: AttributeReference, right: Literal) if safeToReplace(left, nullIsFalse) => - (None, Seq(((left, right), e))) + equalityPredicates.foreach(_ += left.canonicalized -> (right, e)) + e case e @ EqualTo(left: Literal, right: AttributeReference) if safeToReplace(right, nullIsFalse) => - (None, Seq(((right, left), e))) + equalityPredicates.foreach(_ += right.canonicalized -> (left, e)) + e case e @ EqualNullSafe(left: AttributeReference, right: Literal) if safeToReplace(left, nullIsFalse) => - (None, Seq(((left, right), e))) + equalityPredicates.foreach(_ += left.canonicalized -> (right, e)) + e case e @ EqualNullSafe(left: Literal, right: AttributeReference) if safeToReplace(right, nullIsFalse) => - (None, Seq(((right, left), e))) - case a: And => - val (newLeft, equalityPredicatesLeft) = - traverse(a.left, replaceChildren = false, nullIsFalse) - val (newRight, equalityPredicatesRight) = - traverse(a.right, replaceChildren = false, nullIsFalse) - val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight - val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) { - Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates), - replaceConstants(newRight.getOrElse(a.right), equalityPredicates))) + equalityPredicates.foreach(_ += right.canonicalized -> (left, e)) + e + case a @ And(left, right) => + val newEqualityPredicates: Option[EqualityPredicates] = + equalityPredicates.orElse(Some(mutable.Map.empty)) + val newLeft = traverse(left, nullIsFalse, newEqualityPredicates) + val newRight = traverse(right, nullIsFalse, newEqualityPredicates) + // We could recognize when conflicting constants are coming from the left and right sides + // and immediately shortcut the `And` expression to `Literal.FalseLiteral`, but that case is + // not so common and actually it is the job of `ConstantFolding` and `BooleanSimplification` + // rules to deal with those optimizations. + a.withNewChildren(if (equalityPredicates.isEmpty && newEqualityPredicates.get.nonEmpty) { + val replacedNewLeft = replaceConstants(newLeft, newEqualityPredicates.get) + val replacedNewRight = replaceConstants(newRight, newEqualityPredicates.get) + Seq(replacedNewLeft, replacedNewRight) } else { - if (newLeft.isDefined || newRight.isDefined) { - Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right))) - } else { - None - } - } - (newSelf, equalityPredicates) + Seq(newLeft, newRight) + }) case o: Or => // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newLeft, _) = traverse(o.left, replaceChildren = true, nullIsFalse) - val (newRight, _) = traverse(o.right, replaceChildren = true, nullIsFalse) - val newSelf = if (newLeft.isDefined || newRight.isDefined) { - Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right)))) - } else { - None - } - (newSelf, Seq.empty) + o.mapChildren(traverse(_, nullIsFalse, None)) case n: Not => // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false) - (newChild.map(Not), Seq.empty) - case _ => (None, Seq.empty) + n.mapChildren(traverse(_, nullIsFalse = false, None)) + case o => o } // We need to take into account if an attribute is nullable and the context of the conjunctive @@ -198,16 +193,14 @@ object ConstantPropagation extends Rule[LogicalPlan] { private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) = !ar.nullable || nullIsFalse - private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates) - : Expression = { - val constantsMap = AttributeMap(equalityPredicates.map(_._1)) - val predicates = equalityPredicates.map(_._2).toSet - def replaceConstants0(expression: Expression) = expression transform { - case a: AttributeReference => constantsMap.getOrElse(a, a) - } + private def replaceConstants( + condition: Expression, + equalityPredicates: EqualityPredicates): Expression = { + val predicates = equalityPredicates.values.map(_._2).toSet condition transform { - case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e) - case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e) + case b: BinaryComparison if !predicates.contains(b) => b transform { + case a: AttributeReference => equalityPredicates.get(a.canonicalized).map(_._1).getOrElse(a) + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala index f5f1455f94611..106af71a9d653 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -159,8 +159,9 @@ class ConstantPropagationSuite extends PlanTest { columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3))) val correctAnswer = testRelation - .select(columnA) - .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze + .select(columnA, columnB) + .where(Literal.FalseLiteral) + .select(columnA).analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -186,4 +187,31 @@ class ConstantPropagationSuite extends PlanTest { .analyze comparePlans(Optimize.execute(query2), correctAnswer2) } + + test("SPARK-42500: ConstantPropagation supports more cases") { + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnB > columnA + 2).analyze), + testRelation.where(columnA === 1 && columnB > 3).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === 2).analyze), + testRelation.where(Literal.FalseLiteral).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === columnA + 2).analyze), + testRelation.where(Literal.FalseLiteral).analyze) + + comparePlans( + Optimize.execute( + testRelation.where((columnA === 1 || columnB === 2) && columnB === 1).analyze), + testRelation.where(columnA === 1 && columnB === 1).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === 1).analyze), + testRelation.where(columnA === 1).analyze) + + comparePlans( + Optimize.execute(testRelation.where(Not(columnA === 1 && columnA === columnA + 2)).analyze), + testRelation.where(Not(columnA === 1) || Not(columnA === columnA + 2)).analyze) + } }