diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f32f2c7986dc..d6d79c2508ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -81,6 +81,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) TransposeWindow, NullPropagation, ConstantPropagation, + FilterReduction, FoldablePropagation, OptimizeIn, ConstantFolding, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 39709529c00d..386542583d84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -151,6 +151,249 @@ object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Substitutes expressions which can be statically reduced by constraints. + * eg. + * {{{ + * SELECT * FROM table WHERE i <= 5 AND i = 5 => ... WHERE i = 5 + * SELECT * FROM table WHERE i < j AND ... AND i > j => ... WHERE false + * }}} + */ +object FilterReduction extends Rule[LogicalPlan] with ConstraintHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f: Filter => + val newCondition = normalizeAndReduceWithConstraints(f.condition) + if (newCondition fastEquals f.condition) { + f + } else { + f.copy(condition = newCondition) + } + } + + private def normalizeAndReduceWithConstraints(expression: Expression): Expression = + reduceWithConstraints(normalize(expression), true)._1 + + private def normalize(expression: Expression) = expression transform { + case GreaterThan(x, y) => LessThan(y, x) + case GreaterThanOrEqual(x, y) => LessThanOrEqual(y, x) + } + + /** + * Traverse a condition as a tree and simplify expressions with constraints. + * - This functions assumes that the plan has been normalized using [[normalize()]] + * - On matching [[And]], recursively traverse both children, simplify child expressions with + * propagated constraints from sibling and propagate up union of constraints. + * - If a child of [[And]] is [[LessThan]], [[LessThanOrEqual]], [[EqualTo]] or [[EqualNullSafe]], + * propagate the constraint. + * - On matching [[Or]], [[If]], [[CaseWhen]] or [[Not]] recursively traverse each children, but + * propagate up no constraints. + * - Starting off from a condition expression of a [[Filter]] as top node, to the bottom of the + * expression tree, [[And]], [[Or]], [[If]] and [[CaseWhen]] nodes are considered as safe, + * non-[[NullIntolerant]] nodes where reduction rules can be executed without nullability check. + * - Otherwise, stop traversal and propagate no constraints. + * @param expression expression to be traversed + * @param nullIsFalse defines if a null value can be considered as false + * @return A tuple including: + * 1. Expression: optionally changed expression after traversal + * 2. Seq[Expression]: propagated constraints + */ + private def reduceWithConstraints( + expression: Expression, + nullIsFalse: Boolean): (Expression, Seq[Expression]) = + expression match { + case e @ (_: LessThan | _: LessThanOrEqual | _: EqualTo | _: EqualNullSafe) + if e.deterministic => (e, Seq(e)) + case a @ And(left, right) => + val (newLeft, leftConstraints) = reduceWithConstraints(left, nullIsFalse) + val reducedRight = reduceWithConstraints(right, leftConstraints, nullIsFalse) + val (reducedNewRight, rightConstraints) = + reduceWithConstraints(reducedRight, nullIsFalse) + val reducedNewLeft = reduceWithConstraints(newLeft, rightConstraints, nullIsFalse) + val newAnd = if ((reducedNewLeft fastEquals left) && + (reducedNewRight fastEquals right)) { + a + } else { + And(reducedNewLeft, reducedNewRight) + } + (newAnd, leftConstraints ++ rightConstraints) + case o @ (_: Or | _: If | _: CaseWhen) => + (o.mapChildren(reduceWithConstraints(_, nullIsFalse)._1), Seq.empty) + case n: Not => + (n.mapChildren(reduceWithConstraints(_, false)._1), Seq.empty) + case _ => (expression, Seq.empty) + } + + private def reduceWithConstraints( + expression: Expression, + constraints: Seq[Expression], + nullIsFalse: Boolean) = + constraints + .foldLeft(expression)((e, constraint) => reduceWithConstraint(e, constraint, nullIsFalse)) + + private def planEqual(x: Expression, y: Expression) = + !x.foldable && !y.foldable && x.canonicalized == y.canonicalized + + private def valueEqual(x: Expression, y: Expression) = + x.foldable && y.foldable && EqualTo(x, y).eval(EmptyRow).asInstanceOf[Boolean] + + private def valueLessThan(x: Expression, y: Expression) = + x.foldable && y.foldable && LessThan(x, y).eval(EmptyRow).asInstanceOf[Boolean] + + private def valueLessThanOrEqual(x: Expression, y: Expression) = + x.foldable && y.foldable && LessThanOrEqual(x, y).eval(EmptyRow).asInstanceOf[Boolean] + + private def reduceWithConstraint( + expression: Expression, + constraint: Expression, + nullIsFalse: Boolean): Expression = + if (nullIsFalse || constraint.children.forall(!_.nullable)) { + constraint match { + case a LessThan b => expression transformUp { + case c LessThan d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.TrueLiteral + case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.TrueLiteral + case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c LessThanOrEqual d + if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.TrueLiteral + case c LessThanOrEqual d + if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c LessThanOrEqual d + if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.TrueLiteral + case c LessThanOrEqual d + if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c EqualTo d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c EqualNullSafe d + if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.FalseLiteral + case c EqualNullSafe d + if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c EqualNullSafe d + if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.FalseLiteral + case c EqualNullSafe d + if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) + } + case a LessThanOrEqual b => expression transformUp { + case c LessThan d if planEqual(b, d) && valueLessThan(c, a) => + Literal.TrueLiteral + case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && valueLessThan(b, d) => + Literal.TrueLiteral + case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c LessThanOrEqual d + if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(b, c) && valueLessThan(d, a) => + Literal.FalseLiteral + case c LessThanOrEqual d if planEqual(b, c) && (planEqual(a, d) || valueEqual(a, d)) => + EqualTo(c, d) + case c LessThanOrEqual d + if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(a, d) && valueLessThan(b, c) => + Literal.FalseLiteral + case c LessThanOrEqual d if planEqual(a, d) && (planEqual(b, c) || valueEqual(b, c)) => + EqualTo(c, d) + + case c EqualTo d if planEqual(b, d) && valueLessThan(c, a) => Literal.FalseLiteral + case c EqualTo d if planEqual(b, c) && valueLessThan(d, a) => Literal.FalseLiteral + case c EqualTo d if planEqual(a, c) && valueLessThan(b, d) => Literal.FalseLiteral + case c EqualTo d if planEqual(a, d) && valueLessThan(b, c) => Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) && valueLessThan(c, a) => Literal.FalseLiteral + case c EqualNullSafe d if planEqual(b, c) && valueLessThan(d, a) => Literal.FalseLiteral + case c EqualNullSafe d if planEqual(a, c) && valueLessThan(b, d) => Literal.FalseLiteral + case c EqualNullSafe d if planEqual(a, d) && valueLessThan(b, c) => Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) + } + case a EqualTo b => expression transformUp { + case c LessThan d if planEqual(b, d) && planEqual(a, c) => Literal.FalseLiteral + case c LessThan d if planEqual(b, c) && planEqual(a, d) => Literal.FalseLiteral + case c LessThan d if planEqual(a, d) && planEqual(b, c) => Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && planEqual(b, d) => Literal.FalseLiteral + + case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + + case c EqualTo d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral + case c EqualTo d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral + case c EqualTo d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral + case c EqualTo d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + + case c EqualNullSafe d if planEqual(b, d) => + if (planEqual(a, c)) Literal.TrueLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => + if (planEqual(a, d)) Literal.TrueLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => + if (planEqual(b, c)) Literal.TrueLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => + if (planEqual(b, d)) Literal.TrueLiteral else EqualTo(c, d) + } + case a EqualNullSafe b => expression transformUp { + case c LessThan d if planEqual(b, d) && planEqual(a, c) => Literal.FalseLiteral + case c LessThan d if planEqual(b, c) && planEqual(d, a) => Literal.FalseLiteral + case c LessThan d if planEqual(a, d) && planEqual(b, c) => Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && planEqual(d, b) => Literal.FalseLiteral + + case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => EqualTo(c, d) + case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => EqualTo(c, d) + case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => EqualTo(c, d) + case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => EqualTo(c, d) + + case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + } + case _ => expression + } + } else { + constraint match { + case a EqualNullSafe b => expression transformUp { + case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + } + case _ => expression + } + } +} + /** * Reorder associative integral-type operators and fold all constants into one. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala new file mode 100644 index 000000000000..fad2ab17c577 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala @@ -0,0 +1,460 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +/** + * Unit tests for constant propagation in expressions. + */ +class FilterReductionSuite extends PlanTest { + + trait OptimizeBase extends RuleExecutor[LogicalPlan] { + protected def batches = + Batch("AnalysisNodes", Once, + EliminateSubqueryAliases) :: + Batch("FilterReduction", FixedPoint(10), + ConstantPropagation, + FilterReduction, + ConstantFolding, + BooleanSimplification, + SimplifyBinaryComparison, + PruneFilters) :: Nil + } + + object Optimize extends OptimizeBase + + object OptimizeWithoutFilterReduction extends OptimizeBase { + override protected def batches = + super.batches.map(b => Batch(b.name, b.strategy, b.rules.filterNot(_ == FilterReduction): _*)) + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int.notNull, 'd.int.notNull, 'x.boolean) + + val data = { + val intElementsWithNull = Seq(null, 1, 2, 3, 4) + val intElementsWithoutNull = Seq(null, 1, 2, 3, 4) + val booleanElementsWithNull = Seq(null, true, false) + for { + a <- intElementsWithNull + b <- intElementsWithNull + c <- intElementsWithoutNull + d <- intElementsWithoutNull + x <- booleanElementsWithNull + } yield (a, b, c, d, x) + } + + val testRelationWithData = LocalRelation.fromExternalRows(testRelation.output, data.map(Row(_))) + + private def testFilterReduction( + input: Expression, + expectEmptyRelation: Boolean, + expectedConstraints: Seq[Expression] = Seq.empty) = { + val originalQuery = testRelationWithData.where(input).analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = if (expectEmptyRelation) { + testRelation + } else if (expectedConstraints.isEmpty) { + testRelationWithData + } else { + testRelationWithData.where(expectedConstraints.reduce(And)).analyze + } + comparePlans(optimized, correctAnswer) + } + + private def testSameAsWithoutFilterReduction(input: Expression) = { + val originalQuery = testRelationWithData.where(input).analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = OptimizeWithoutFilterReduction.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("Filter reduction with nullable attributes") { + testFilterReduction('a < 2 && Literal(2) < 'a, true) + testFilterReduction('a < 2 && Literal(2) <= 'a, true) + testFilterReduction('a < 2 && Literal(2) === 'a, true) + testFilterReduction('a < 2 && Literal(2) <=> 'a, true) + testFilterReduction('a < 2 && Literal(2) >= 'a, false, Seq('a < 2)) + testFilterReduction('a < 2 && Literal(2) > 'a, false, Seq('a < 2)) + testFilterReduction('a <= 2 && Literal(2) < 'a, true) + testFilterReduction('a <= 2 && Literal(2) <= 'a, false, Seq('a === 2)) + testFilterReduction('a <= 2 && Literal(2) === 'a, false, Seq('a === 2)) + testFilterReduction('a <= 2 && Literal(2) <=> 'a, false, Seq('a <=> 2)) + testFilterReduction('a <= 2 && Literal(2) >= 'a, false, Seq('a <= 2)) + testFilterReduction('a <= 2 && Literal(2) > 'a, false, Seq('a < 2)) + testFilterReduction('a === 2 && Literal(2) < 'a, true) + testFilterReduction('a === 2 && Literal(2) <= 'a, false, Seq('a === 2)) + testFilterReduction('a === 2 && Literal(2) === 'a, false, Seq('a === 2)) + testFilterReduction('a === 2 && Literal(2) <=> 'a, false, Seq('a === 2)) + testFilterReduction('a === 2 && Literal(2) >= 'a, false, Seq('a === 2)) + testFilterReduction('a === 2 && Literal(2) > 'a, true) + testFilterReduction('a <=> 2 && Literal(2) < 'a, true) + testFilterReduction('a <=> 2 && Literal(2) <= 'a, false, Seq('a <=> 2)) + testFilterReduction('a <=> 2 && Literal(2) === 'a, false, Seq('a <=> 2)) + testFilterReduction('a <=> 2 && Literal(2) <=> 'a, false, Seq('a <=> 2)) + testFilterReduction('a <=> 2 && Literal(2) >= 'a, false, Seq('a <=> 2)) + testFilterReduction('a <=> 2 && Literal(2) > 'a, true) + testFilterReduction('a >= 2 && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) + testFilterReduction('a >= 2 && Literal(2) <= 'a, false, Seq(Literal(2) <= 'a)) + testFilterReduction('a >= 2 && Literal(2) === 'a, false, Seq('a === 2)) + testFilterReduction('a >= 2 && Literal(2) <=> 'a, false, Seq('a <=> 2)) + testFilterReduction('a >= 2 && Literal(2) >= 'a, false, Seq('a === 2)) + testFilterReduction('a >= 2 && Literal(2) > 'a, true) + testFilterReduction('a > 2 && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) + testFilterReduction('a > 2 && Literal(2) <= 'a, false, Seq(Literal(2) < 'a)) + testFilterReduction('a > 2 && Literal(2) === 'a, true) + testFilterReduction('a > 2 && Literal(2) <=> 'a, true) + testFilterReduction('a > 2 && Literal(2) >= 'a, true) + testFilterReduction('a > 2 && Literal(2) > 'a, true) + + testFilterReduction(('x || 'a < 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) + testFilterReduction(('x || 'a < 2) && Literal(2) <= 'a, false, Seq('x, Literal(2) <= 'a)) + testFilterReduction(('x || 'a < 2) && Literal(2) === 'a, false, Seq('x, Literal(2) === 'a)) + testFilterReduction(('x || 'a < 2) && Literal(2) <=> 'a, false, Seq('x, Literal(2) <=> 'a)) + testFilterReduction(('x || 'a < 2) && Literal(2) >= 'a, false, Seq('x || 'a < 2, 'a <= 2)) + testFilterReduction(('x || 'a < 2) && Literal(2) > 'a, false, Seq('a < 2)) + testFilterReduction(('x || 'a <= 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(2) <= 'a, false, + Seq('x || 'a === 2, Literal(2) <= 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(2) >= 'a, false, Seq('a <= 2)) + testFilterReduction(('x || 'a <= 2) && Literal(2) > 'a, false, Seq('a < 2)) + testFilterReduction(('x || 'a === 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) + testFilterReduction(('x || 'a === 2) && Literal(2) <= 'a, false, + Seq('x || 'a === 2, Literal(2) <= 'a)) + testFilterReduction(('x || 'a === 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) + testFilterReduction(('x || 'a === 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) + testFilterReduction(('x || 'a === 2) && Literal(2) >= 'a, false, Seq('x || 'a === 2, 'a <= 2)) + testFilterReduction(('x || 'a === 2) && Literal(2) > 'a, false, Seq('x, 'a < 2)) + testFilterReduction(('x || 'a <=> 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(2) <= 'a, false, + Seq('x || 'a === 2, Literal(2) <= 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(2) >= 'a, false, Seq('x || 'a === 2, 'a <= 2)) + testFilterReduction(('x || 'a <=> 2) && Literal(2) > 'a, false, Seq('x, 'a < 2)) + testFilterReduction(('x || 'a >= 2) && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(2) <= 'a, false, Seq(Literal(2) <= 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(2) >= 'a, false, + Seq('x || 'a === 2, 'a <= Literal(2))) + testFilterReduction(('x || 'a >= 2) && Literal(2) > 'a, false, Seq('x, 'a < Literal(2))) + testFilterReduction(('x || 'a > 2) && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) + testFilterReduction(('x || 'a > 2) && Literal(2) <= 'a, false, + Seq('x || Literal(2) < 'a, Literal(2) <= 'a)) + testFilterReduction(('x || 'a > 2) && Literal(2) === 'a, false, Seq('x, Literal(2) === 'a)) + testFilterReduction(('x || 'a > 2) && Literal(2) <=> 'a, false, Seq('x, Literal(2) <=> 'a)) + testFilterReduction(('x || 'a > 2) && Literal(2) >= 'a, false, Seq('x, 'a <= Literal(2))) + testFilterReduction(('x || 'a > 2) && Literal(2) > 'a, false, Seq('x, 'a < Literal(2))) + + testFilterReduction('a < 2 && Literal(3) < 'a, true) + testFilterReduction('a < 2 && Literal(3) <= 'a, true) + testFilterReduction('a < 2 && Literal(3) === 'a, true) + testFilterReduction('a < 2 && Literal(3) <=> 'a, true) + testFilterReduction('a < 2 && Literal(3) >= 'a, false, Seq('a < 2)) + testFilterReduction('a < 2 && Literal(3) > 'a, false, Seq('a < 2)) + testFilterReduction('a <= 2 && Literal(3) < 'a, true) + testFilterReduction('a <= 2 && Literal(3) <= 'a, true) + testFilterReduction('a <= 2 && Literal(3) === 'a, true) + testFilterReduction('a <= 2 && Literal(3) <=> 'a, true) + testFilterReduction('a <= 2 && Literal(3) >= 'a, false, Seq('a <= 2)) + testFilterReduction('a <= 2 && Literal(3) > 'a, false, Seq('a <= 2)) + testFilterReduction('a === 2 && Literal(3) < 'a, true) + testFilterReduction('a === 2 && Literal(3) <= 'a, true) + testFilterReduction('a === 2 && Literal(3) === 'a, true) + testFilterReduction('a === 2 && Literal(3) <=> 'a, true) + testFilterReduction('a === 2 && Literal(3) >= 'a, false, Seq('a === 2)) + testFilterReduction('a === 2 && Literal(3) > 'a, false, Seq('a === 2)) + testFilterReduction('a <=> 2 && Literal(3) < 'a, true) + testFilterReduction('a <=> 2 && Literal(3) <= 'a, true) + testFilterReduction('a <=> 2 && Literal(3) === 'a, true) + testFilterReduction('a <=> 2 && Literal(3) <=> 'a, true) + testFilterReduction('a <=> 2 && Literal(3) >= 'a, false, Seq('a <=> 2)) + testFilterReduction('a <=> 2 && Literal(3) > 'a, false, Seq('a <=> 2)) + testFilterReduction('a >= 2 && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) + testFilterReduction('a >= 2 && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) + testFilterReduction('a >= 2 && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) + testFilterReduction('a >= 2 && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) + testFilterReduction('a >= 2 && Literal(3) >= 'a, false, Seq(Literal(2) <= 'a, 'a <= 3)) + testFilterReduction('a >= 2 && Literal(3) > 'a, false, Seq(Literal(2) <= 'a, 'a < 3)) + testFilterReduction('a > 2 && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) + testFilterReduction('a > 2 && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) + testFilterReduction('a > 2 && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) + testFilterReduction('a > 2 && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) + testFilterReduction('a > 2 && Literal(3) >= 'a, false, Seq(Literal(2) < 'a, 'a <= 3)) + testFilterReduction('a > 2 && Literal(3) > 'a, false, Seq(Literal(2) < 'a, 'a < 3)) + + testFilterReduction(('x || 'a < 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) + testFilterReduction(('x || 'a < 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) + testFilterReduction(('x || 'a < 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) + testFilterReduction(('x || 'a < 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) + testFilterReduction(('x || 'a < 2) && Literal(3) >= 'a, false, Seq('x || 'a < 2, 'a <= 3)) + testFilterReduction(('x || 'a < 2) && Literal(3) > 'a, false, Seq('x || 'a < 2, 'a < 3)) + testFilterReduction(('x || 'a <= 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(3) >= 'a, false, Seq('x || 'a <= 2, 'a <= 3)) + testFilterReduction(('x || 'a <= 2) && Literal(3) > 'a, false, Seq('x || 'a <= 2, 'a < 3)) + testFilterReduction(('x || 'a === 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) + testFilterReduction(('x || 'a === 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) + testFilterReduction(('x || 'a === 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) + testFilterReduction(('x || 'a === 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) + testFilterReduction(('x || 'a === 2) && Literal(3) >= 'a, false, Seq('x || 'a === 2, 'a <= 3)) + testFilterReduction(('x || 'a === 2) && Literal(3) > 'a, false, Seq('x || 'a === 2, 'a < 3)) + testFilterReduction(('x || 'a <=> 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(3) >= 'a, false, Seq('x || 'a === 2, 'a <= 3)) + testFilterReduction(('x || 'a <=> 2) && Literal(3) > 'a, false, Seq('x || 'a === 2, 'a < 3)) + testFilterReduction(('x || 'a >= 2) && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(3) >= 'a, false, + Seq('x || Literal(2) <= 'a, 'a <= 3)) + testFilterReduction(('x || 'a >= 2) && Literal(3) > 'a, false, + Seq('x || Literal(2) <= 'a, 'a < 3)) + testFilterReduction(('x || 'a > 2) && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) + testFilterReduction(('x || 'a > 2) && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) + testFilterReduction(('x || 'a > 2) && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) + testFilterReduction(('x || 'a > 2) && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) + testFilterReduction(('x || 'a > 2) && Literal(3) >= 'a, false, + Seq('x || Literal(2) < 'a, 'a <= 3)) + testFilterReduction(('x || 'a > 2) && Literal(3) > 'a, false, + Seq('x || Literal(2) < 'a, 'a < 3)) + + testFilterReduction('a < 'b && 'b < 'a, true) + testFilterReduction('a < 'b && 'b <= 'a, true) + testFilterReduction('a < 'b && 'b === 'a, true) + testFilterReduction('a < 'b && 'b <=> 'a, true) + testFilterReduction('a < 'b && 'b >= 'a, false, Seq('a < 'b)) + testFilterReduction('a < 'b && 'b > 'a, false, Seq('a < 'b)) + testFilterReduction('a <= 'b && 'b < 'a, true) + testFilterReduction('a <= 'b && 'b <= 'a, false, Seq('a === 'b)) + testFilterReduction('a <= 'b && 'b === 'a, false, Seq('a === 'b)) + testFilterReduction('a <= 'b && 'b <=> 'a, false, Seq('a === 'b)) + testFilterReduction('a <= 'b && 'b >= 'a, false, Seq('a <= 'b)) + testFilterReduction('a <= 'b && 'b > 'a, false, Seq('a < 'b)) + testFilterReduction('a === 'b && 'b < 'a, true) + testFilterReduction('a === 'b && 'b <= 'a, false, Seq('a === 'b)) + testFilterReduction('a === 'b && 'b === 'a, false, Seq('a === 'b)) + testFilterReduction('a === 'b && 'b <=> 'a, false, Seq('a === 'b)) + testFilterReduction('a === 'b && 'b >= 'a, false, Seq('a === 'b)) + testFilterReduction('a === 'b && 'b > 'a, true) + testFilterReduction('a <=> 'b && 'b < 'a, true) + testFilterReduction('a <=> 'b && 'b <= 'a, false, Seq('a === 'b)) + testFilterReduction('a <=> 'b && 'b === 'a, false, Seq('a === 'b)) + testFilterReduction('a <=> 'b && 'b <=> 'a, false, Seq('a <=> 'b)) + testFilterReduction('a <=> 'b && 'b >= 'a, false, Seq('a === 'b)) + testFilterReduction('a <=> 'b && 'b > 'a, true) + testFilterReduction('a >= 'b && 'b < 'a, false, Seq('b < 'a)) + testFilterReduction('a >= 'b && 'b <= 'a, false, Seq('b <= 'a)) + testFilterReduction('a >= 'b && 'b === 'a, false, Seq('a === 'b)) + testFilterReduction('a >= 'b && 'b <=> 'a, false, Seq('a === 'b)) + testFilterReduction('a >= 'b && 'b >= 'a, false, Seq('a === 'b)) + testFilterReduction('a >= 'b && 'b > 'a, true) + testFilterReduction('a > 'b && 'b < 'a, false, Seq('b < 'a)) + testFilterReduction('a > 'b && 'b <= 'a, false, Seq('b < 'a)) + testFilterReduction('a > 'b && 'b === 'a, true) + testFilterReduction('a > 'b && 'b <=> 'a, true) + testFilterReduction('a > 'b && 'b >= 'a, true) + testFilterReduction('a > 'b && 'b > 'a, true) + + testFilterReduction('a < abs('b) && abs('b) < 'a, true) + testFilterReduction('a < abs('b) && abs('b) <= 'a, true) + testFilterReduction('a < abs('b) && abs('b) === 'a, true) + testFilterReduction('a < abs('b) && abs('b) <=> 'a, true) + testFilterReduction('a < abs('b) && abs('b) >= 'a, false, Seq('a < abs('b))) + testFilterReduction('a < abs('b) && abs('b) > 'a, false, Seq('a < abs('b))) + testFilterReduction('a <= abs('b) && abs('b) < 'a, true) + testFilterReduction('a <= abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) + testFilterReduction('a <= abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testFilterReduction('a <= abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) + testFilterReduction('a <= abs('b) && abs('b) >= 'a, false, Seq('a <= abs('b))) + testFilterReduction('a <= abs('b) && abs('b) > 'a, false, Seq('a < abs('b))) + testFilterReduction('a === abs('b) && abs('b) < 'a, true) + testFilterReduction('a === abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) + testFilterReduction('a === abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testFilterReduction('a === abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) + testFilterReduction('a === abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) + testFilterReduction('a === abs('b) && abs('b) > 'a, true) + testFilterReduction('a <=> abs('b) && abs('b) < 'a, true) + testFilterReduction('a <=> abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) + testFilterReduction('a <=> abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testFilterReduction('a <=> abs('b) && abs('b) <=> 'a, false, Seq('a <=> abs('b))) + testFilterReduction('a <=> abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) + testFilterReduction('a <=> abs('b) && abs('b) > 'a, true) + testFilterReduction('a >= abs('b) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testFilterReduction('a >= abs('b) && abs('b) <= 'a, false, Seq(abs('b) <= 'a)) + testFilterReduction('a >= abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testFilterReduction('a >= abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) + testFilterReduction('a >= abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) + testFilterReduction('a >= abs('b) && abs('b) > 'a, true) + testFilterReduction('a > abs('b) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testFilterReduction('a > abs('b) && abs('b) <= 'a, false, Seq(abs('b) < 'a)) + testFilterReduction('a > abs('b) && abs('b) === 'a, true) + testFilterReduction('a > abs('b) && abs('b) <=> 'a, true) + testFilterReduction('a > abs('b) && abs('b) >= 'a, true) + testFilterReduction('a > abs('b) && abs('b) > 'a, true) + + testFilterReduction(('x || 'a < abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) + testFilterReduction(('x || 'a < abs('b)) && abs('b) <= 'a, false, Seq('x, abs('b) <= 'a)) + testFilterReduction(('x || 'a < abs('b)) && abs('b) === 'a, false, Seq('x, abs('b) === 'a)) + testFilterReduction(('x || 'a < abs('b)) && abs('b) <=> 'a, false, Seq('x, abs('b) <=> 'a)) + testFilterReduction(('x || 'a < abs('b)) && abs('b) >= 'a, false, + Seq('x || 'a < abs('b), 'a <= abs('b))) + testFilterReduction(('x || 'a < abs('b)) && abs('b) > 'a, false, Seq('a < abs('b))) + testFilterReduction(('x || 'a <= abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) + testFilterReduction(('x || 'a <= abs('b)) && abs('b) <= 'a, false, + Seq('x || 'a === abs('b), abs('b) <= 'a)) + testFilterReduction(('x || 'a <= abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testFilterReduction(('x || 'a <= abs('b)) && abs('b) <=> 'a, false, + Seq('x || 'a === abs('b), abs('b) <=> 'a)) + testFilterReduction(('x || 'a <= abs('b)) && abs('b) >= 'a, false, Seq('a <= abs('b))) + testFilterReduction(('x || 'a <= abs('b)) && abs('b) > 'a, false, Seq('a < abs('b))) + testFilterReduction(('x || 'a === abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) + testFilterReduction(('x || 'a === abs('b)) && abs('b) <= 'a, false, + Seq('x || 'a === abs('b), abs('b) <= 'a)) + testFilterReduction(('x || 'a === abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testFilterReduction(('x || 'a === abs('b)) && abs('b) <=> 'a, false, + Seq('x || 'a === abs('b), abs('b) <=> 'a)) + testFilterReduction(('x || 'a === abs('b)) && abs('b) >= 'a, false, + Seq('x || 'a === abs('b), 'a <= abs('b))) + testFilterReduction(('x || 'a === abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) + testFilterReduction(('x || 'a <=> abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) + testFilterReduction(('x || 'a <=> abs('b)) && abs('b) <= 'a, false, + Seq('x || 'a === abs('b), abs('b) <= 'a)) + testFilterReduction(('x || 'a <=> abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testFilterReduction(('x || 'a <=> abs('b)) && abs('b) <=> 'a, false, Seq(abs('b) <=> 'a)) + testFilterReduction(('x || 'a <=> abs('b)) && abs('b) >= 'a, false, + Seq('x || 'a === abs('b), 'a <= abs('b))) + testFilterReduction(('x || 'a <=> abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) + testFilterReduction(('x || 'a >= abs('b)) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testFilterReduction(('x || 'a >= abs('b)) && abs('b) <= 'a, false, Seq(abs('b) <= 'a)) + testFilterReduction(('x || 'a >= abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testFilterReduction(('x || 'a >= abs('b)) && abs('b) <=> 'a, false, + Seq('x || 'a === abs('b), abs('b) <=> 'a)) + testFilterReduction(('x || 'a >= abs('b)) && abs('b) >= 'a, false, + Seq('x || 'a === abs('b), 'a <= abs('b))) + testFilterReduction(('x || 'a >= abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) + testFilterReduction(('x || 'a > abs('b)) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testFilterReduction(('x || 'a > abs('b)) && abs('b) <= 'a, false, + Seq('x || abs('b) < 'a, abs('b) <= 'a)) + testFilterReduction(('x || 'a > abs('b)) && abs('b) === 'a, false, Seq('x, abs('b) === 'a)) + testFilterReduction(('x || 'a > abs('b)) && abs('b) <=> 'a, false, Seq('x, abs('b) <=> 'a)) + testFilterReduction(('x || 'a > abs('b)) && abs('b) >= 'a, false, Seq('x, 'a <= abs('b))) + testFilterReduction(('x || 'a > abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) + } + + // These cases test scenarios when there is NullIntolerant node (ex. Not) between Filter and the + // subtree of And nodes and the expression to be reduced is nullable. + // For example in these cases the following reduction does not hold when X and Y is null and so + // FilterReduction should do nothing: + // Not(X < Y && Y < X) => Not(X < Y && false) + // Not(X < Y && Not(X < Y)) => Not(X < Y && Not(true)) + // Not(X <=> Y && Y < X) => Not(X <=> Y && false) + // Not(X < Y && Y <=> X) => Not(X < Y && false) + test("Filter reduction with nullable attributes and NullIntolerant nodes") { + testSameAsWithoutFilterReduction(Not(('x || 'a < abs('b)) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a < abs('b)) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a < abs('b)) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a < abs('b)) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a < abs('b)) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a < abs('b)) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <= abs('b)) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <= abs('b)) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <= abs('b)) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <= abs('b)) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <= abs('b)) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <= abs('b)) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a === abs('b)) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a === abs('b)) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a === abs('b)) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a === abs('b)) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a === abs('b)) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a === abs('b)) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <=> abs('b)) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <=> abs('b)) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <=> abs('b)) && abs('b) === 'a)) + // the only exception: + // testSameAsWithoutFilterReduction(Not(('x || 'a <=> abs('b)) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <=> abs('b)) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <=> abs('b)) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a >= abs('b)) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a >= abs('b)) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a >= abs('b)) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a >= abs('b)) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a >= abs('b)) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a >= abs('b)) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a > abs('b)) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a > abs('b)) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a > abs('b)) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a > abs('b)) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a > abs('b)) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a > abs('b)) && abs('b) > 'a)) + + testSameAsWithoutFilterReduction(Not(('x || Not('a < abs('b))) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a < abs('b))) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a < abs('b))) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a < abs('b))) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a < abs('b))) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a < abs('b))) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <= abs('b))) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <= abs('b))) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <= abs('b))) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <= abs('b))) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <= abs('b))) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <= abs('b))) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a === abs('b))) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a === abs('b))) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a === abs('b))) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a === abs('b))) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a === abs('b))) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a === abs('b))) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <=> abs('b))) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <=> abs('b))) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <=> abs('b))) && abs('b) === 'a)) + // the only exception: + // testSameAsWithoutFilterReduction(Not(('x || Not('a <=> abs('b))) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <=> abs('b))) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <=> abs('b))) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a >= abs('b))) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a >= abs('b))) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a >= abs('b))) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a >= abs('b))) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a >= abs('b))) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a >= abs('b))) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a > abs('b))) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a > abs('b))) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a > abs('b))) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a > abs('b))) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a > abs('b))) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a > abs('b))) && abs('b) > 'a)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 5394732f41f2..32de8b6d88cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -95,30 +95,49 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite => protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { case Filter(condition: Expression, child: LogicalPlan) => - Filter(splitConjunctivePredicates(condition).map(rewriteEqual).sortBy(_.hashCode()) - .reduce(And), child) + Filter(normalize(condition), child) case sample: Sample => sample.copy(seed = 0L) case Join(left, right, joinType, condition, hint) if condition.isDefined => - val newCondition = - splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode()) - .reduce(And) - Join(left, right, joinType, Some(newCondition), hint) + Join(left, right, joinType, condition.map(normalize), hint) } } /** - * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be - * equivalent: - * 1. (a = b), (b = a); - * 2. (a <=> b), (b <=> a). + * Rewrite [[EqualTo]], [[EqualNullSafe]], [[GreaterThan]], [[GreaterThanOrEqual]], [[And]] and + * [[Or]] operators to keep order. + * The following pairs will be equivalent: + * 1. (a = b) and (b = a), + * 2. (a <=> b) and (b <=> a), + * 3. (a > b) and (b < a), + * 4. (a >= b) and (b <= a), + * 5. (a <= b AND b <= a) and (b <= a AND a <= b), + * 6. (a <= b OR b <= a) and (b <= a OR a <= b) */ - private def rewriteEqual(condition: Expression): Expression = condition match { - case eq @ EqualTo(l: Expression, r: Expression) => - Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) - case eq @ EqualNullSafe(l: Expression, r: Expression) => - Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) - case _ => condition // Don't reorder. + private def normalize(expression: Expression): Expression = expression match { + case EqualTo(l: Expression, r: Expression) => + Seq(l, r) + .map(normalize) + .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) + .reduce(EqualTo) + case EqualNullSafe(l: Expression, r: Expression) => + Seq(l, r) + .map(normalize) + .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) + .reduce(EqualNullSafe) + case GreaterThan(l, r) => LessThan(normalize(r), normalize(l)) + case GreaterThanOrEqual(l, r) => LessThanOrEqual(normalize(r), normalize(l)) + case and: And => + splitConjunctivePredicates(and) + .map(normalize) + .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) + .reduce(And) + case or: Or => + splitDisjunctivePredicates(or) + .map(normalize) + .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) + .reduce(Or) + case _ => expression.mapChildren(normalize) } /** Fails the test if the two plans do not match */