1818package org .apache .spark .sql .catalyst .optimizer
1919
2020import org .apache .spark .sql .catalyst .analysis .EliminateAnalysisOperators
21- import org .apache .spark .sql .catalyst .expressions .{ Or , And , Literal , Expression }
21+ import org .apache .spark .sql .catalyst .expressions ._
2222import org .apache .spark .sql .catalyst .plans .logical ._
2323import org .apache .spark .sql .catalyst .plans .PlanTest
2424import org .apache .spark .sql .catalyst .rules ._
2525import org .apache .spark .sql .catalyst .dsl .plans ._
2626import org .apache .spark .sql .catalyst .dsl .expressions ._
2727
28- class BooleanSimplificationSuite extends PlanTest {
28+ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
2929
3030 object Optimize extends RuleExecutor [LogicalPlan ] {
3131 val batches =
@@ -40,14 +40,21 @@ class BooleanSimplificationSuite extends PlanTest {
4040
4141 val testRelation = LocalRelation (' a .int, ' b .int, ' c .int, ' d .string)
4242
43+ // The `foldLeft` is required to handle cases like comparing `a && (b && c)` and `(a && b) && c`
4344 def compareConditions (e1 : Expression , e2 : Expression ): Boolean = (e1, e2) match {
44- case (And (l1, l2), And (r1, r2)) =>
45- compareConditions(l1, r1) && compareConditions(l2, r2) ||
46- compareConditions(l1, r2) && compareConditions(l2, r1)
47-
48- case (Or (l1, l2), Or (r1, r2)) =>
49- compareConditions(l1, r1) && compareConditions(l2, r2) ||
50- compareConditions(l1, r2) && compareConditions(l2, r1)
45+ case (lhs : And , rhs : And ) =>
46+ val lhsSet = splitConjunctivePredicates(lhs).toSet
47+ val rhsSet = splitConjunctivePredicates(rhs).toSet
48+ lhsSet.foldLeft(rhsSet) { (set, e) =>
49+ set.find(compareConditions(_, e)).map(set - _).getOrElse(set)
50+ }.isEmpty
51+
52+ case (lhs : Or , rhs : Or ) =>
53+ val lhsSet = splitDisjunctivePredicates(lhs).toSet
54+ val rhsSet = splitDisjunctivePredicates(rhs).toSet
55+ lhsSet.foldLeft(rhsSet) { (set, e) =>
56+ set.find(compareConditions(_, e)).map(set - _).getOrElse(set)
57+ }.isEmpty
5158
5259 case (l, r) => l == r
5360 }
0 commit comments