Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ object ConstantPropagation extends Rule[LogicalPlan] {
}
}

type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)]

/**
* Traverse a condition as a tree and replace attributes with constant values.
* - On matching [[And]], recursively traverse each children and get propagated mappings.
Expand All @@ -140,23 +138,23 @@ object ConstantPropagation extends Rule[LogicalPlan] {
* resulted false
* @return A tuple including:
* 1. Option[Expression]: optional changed condition after traversal
* 2. EqualityPredicates: propagated mapping of attribute => constant
* 2. AttributeMap: propagated mapping of attribute => constant
*/
private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean)
: (Option[Expression], EqualityPredicates) =
: (Option[Expression], AttributeMap[(Literal, BinaryComparison)]) =
condition match {
case e @ EqualTo(left: AttributeReference, right: Literal)
if safeToReplace(left, nullIsFalse) =>
(None, Seq(((left, right), e)))
(None, AttributeMap(Map(left -> (right, e))))
case e @ EqualTo(left: Literal, right: AttributeReference)
if safeToReplace(right, nullIsFalse) =>
(None, Seq(((right, left), e)))
(None, AttributeMap(Map(right -> (left, e))))
case e @ EqualNullSafe(left: AttributeReference, right: Literal)
if safeToReplace(left, nullIsFalse) =>
(None, Seq(((left, right), e)))
(None, AttributeMap(Map(left -> (right, e))))
case e @ EqualNullSafe(left: Literal, right: AttributeReference)
if safeToReplace(right, nullIsFalse) =>
(None, Seq(((right, left), e)))
(None, AttributeMap(Map(right -> (left, e))))
case a: And =>
val (newLeft, equalityPredicatesLeft) =
traverse(a.left, replaceChildren = false, nullIsFalse)
Expand All @@ -183,12 +181,12 @@ object ConstantPropagation extends Rule[LogicalPlan] {
} else {
None
}
(newSelf, Seq.empty)
(newSelf, AttributeMap.empty)
case n: Not =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false)
(newChild.map(Not), Seq.empty)
case _ => (None, Seq.empty)
(newChild.map(Not), AttributeMap.empty)
case _ => (None, AttributeMap.empty)
}

// We need to take into account if an attribute is nullable and the context of the conjunctive
Expand All @@ -199,16 +197,15 @@ object ConstantPropagation extends Rule[LogicalPlan] {
private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) =
!ar.nullable || nullIsFalse

private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates)
: Expression = {
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
val predicates = equalityPredicates.map(_._2).toSet
def replaceConstants0(expression: Expression) = expression transform {
case a: AttributeReference => constantsMap.getOrElse(a, a)
}
private def replaceConstants(
condition: Expression,
equalityPredicates: AttributeMap[(Literal, BinaryComparison)]): Expression = {
val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit })
val predicates = equalityPredicates.values.map(_._2).toSet
condition transform {
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e)
case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e)
case b: BinaryComparison if !predicates.contains(b) => b transform {
case a: AttributeReference => constantsMap.getOrElse(a, a)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,9 @@ class ConstantPropagationSuite extends PlanTest {
columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3)))

val correctAnswer = testRelation
.select(columnA)
.where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze
.select(columnA, columnB)
.where(Literal.FalseLiteral)
.select(columnA).analyze

comparePlans(Optimize.execute(query.analyze), correctAnswer)
}
Expand All @@ -186,4 +187,31 @@ class ConstantPropagationSuite extends PlanTest {
.analyze
comparePlans(Optimize.execute(query2), correctAnswer2)
}

test("SPARK-42500: ConstantPropagation supports more cases") {
comparePlans(
Optimize.execute(testRelation.where(columnA === 1 && columnB > columnA + 2).analyze),
testRelation.where(columnA === 1 && columnB > 3).analyze)

comparePlans(
Optimize.execute(testRelation.where(columnA === 1 && columnA === 2).analyze),
testRelation.where(Literal.FalseLiteral).analyze)

comparePlans(
Optimize.execute(testRelation.where(columnA === 1 && columnA === columnA + 2).analyze),
testRelation.where(Literal.FalseLiteral).analyze)

comparePlans(
Optimize.execute(
testRelation.where((columnA === 1 || columnB === 2) && columnB === 1).analyze),
testRelation.where(columnA === 1 && columnB === 1).analyze)

comparePlans(
Optimize.execute(testRelation.where(columnA === 1 && columnA === 1).analyze),
testRelation.where(columnA === 1).analyze)

comparePlans(
Optimize.execute(testRelation.where(Not(columnA === 1 && columnA === columnA + 2)).analyze),
testRelation.where(Not(columnA === 1) || Not(columnA === columnA + 2)).analyze)
}
}