From b59a81075d4b2692e1a68974a452143cb94e2778 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 10 Feb 2020 13:57:58 +0800 Subject: [PATCH 1/9] Constraints should be inferred from inequality attributes --- .../plans/logical/QueryPlanConstraints.scala | 6 +++++- .../optimizer/InferFiltersFromConstraintsSuite.scala | 12 ++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 1355003358b9f..1ab485120d292 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -67,6 +67,10 @@ trait ConstraintHelper { val candidateConstraints = constraints - eq inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) + case eq @ EqualTo(l: Attribute, r : Literal) => + inferredConstraints ++= replaceConstraints(constraints - eq, l, r) + case eq @ EqualTo(l : Literal, r: Attribute) => + inferredConstraints ++= replaceConstraints(constraints - eq, r, l) case _ => // No inference } inferredConstraints -- constraints @@ -75,7 +79,7 @@ trait ConstraintHelper { private def replaceConstraints( constraints: Set[Expression], source: Expression, - destination: Attribute): Set[Expression] = constraints.map(_ transform { + destination: Expression): Set[Expression] = constraints.map(_ transform { case e: Expression if e.semanticEquals(source) => destination }) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 974bc781d36ab..e7ca05abd1153 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -263,4 +263,16 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val y = testRelation.subquery('y) testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) } + + test("SPARK-30768: Constraints should be inferred from inequality attributes") { + val condition = Some("x.a".attr > "y.a".attr) + val optimizedLeft = testRelation.where(IsNotNull('a) && 'a === 1).as("x") + val optimizedRight = testRelation.where(Literal(1) > 'a && IsNotNull('a) ).as("y") + val correct = optimizedLeft.join(optimizedRight, Inner, condition) + + Seq(Literal(1) === 'a, 'a === Literal(1)).foreach { filter => + val original = testRelation.where(filter).as("x").join(testRelation.as("y"), Inner, condition) + comparePlans(Optimize.execute(original.analyze), correct.analyze) + } + } } From 999126cf899816e4c4f2035f70c9099fceb63516 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 12 Feb 2020 23:47:55 +0800 Subject: [PATCH 2/9] Constraints should be inferred from inequality constraints --- .../plans/logical/QueryPlanConstraints.scala | 61 +++++++++++++++++-- .../InferFiltersFromConstraintsSuite.scala | 31 ++++++++-- 2 files changed, 83 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 1ab485120d292..abc2d7a64f00e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -61,19 +61,52 @@ trait ConstraintHelper { * additional constraint of the form `b = 5`. */ def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + val binaryComparisons = constraints.filter { + case _: GreaterThan => true + case _: GreaterThanOrEqual => true + case _: LessThan => true + case _: LessThanOrEqual => true + case _: EqualTo => true + case _ => false + } + + val greaterThans = binaryComparisons.map { + case LessThan(l, r) => GreaterThan(r, l) + case LessThanOrEqual(l, r) => GreaterThanOrEqual(r, l) + case other => other + } + + val lessThans = binaryComparisons.map { + case GreaterThan(l, r) => LessThan(r, l) + case GreaterThanOrEqual(l, r) => LessThanOrEqual(r, l) + case other => other + } + var inferredConstraints = Set.empty[Expression] constraints.foreach { case eq @ EqualTo(l: Attribute, r: Attribute) => val candidateConstraints = constraints - eq inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) - case eq @ EqualTo(l: Attribute, r : Literal) => - inferredConstraints ++= replaceConstraints(constraints - eq, l, r) - case eq @ EqualTo(l : Literal, r: Attribute) => - inferredConstraints ++= replaceConstraints(constraints - eq, r, l) case _ => // No inference } - inferredConstraints -- constraints + + greaterThans.foreach { + case gt @ GreaterThan(l: Attribute, r: Attribute) => + inferredConstraints ++= inferInequalityConstraints(greaterThans - gt, r, l, gt) + case gt @ GreaterThanOrEqual(l: Attribute, r: Attribute) => + inferredConstraints ++= inferInequalityConstraints(greaterThans - gt, r, l, gt) + case _ => // No inference + } + + lessThans.foreach { + case lt @ LessThan(l: Attribute, r: Attribute) => + inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt) + case lt @ LessThanOrEqual(l: Attribute, r: Attribute) => + inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt) + case _ => // No inference + } + inferredConstraints -- constraints -- greaterThans -- lessThans } private def replaceConstraints( @@ -83,6 +116,24 @@ trait ConstraintHelper { case e: Expression if e.semanticEquals(source) => destination }) + private def inferInequalityConstraints( + constraints: Set[Expression], + source: Expression, + destination: Expression, + binaryComparison: BinaryComparison): Set[Expression] = constraints.map { + case EqualTo(l, r) if l.semanticEquals(source) => + binaryComparison.makeCopy(Array(destination, r)) + case EqualTo(l, r) if r.semanticEquals(source) => + binaryComparison.makeCopy(Array(destination, l)) + case gt @ GreaterThan(l, r) if l.semanticEquals(source) => + gt.makeCopy(Array(destination, r)) + case gt @ LessThan(l, r) if l.semanticEquals(source) => + gt.makeCopy(Array(destination, r)) + case BinaryComparison(l, r) if l.semanticEquals(source) => + binaryComparison.makeCopy(Array(destination, r)) + case other => other + } + /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index e7ca05abd1153..54fed5a06824b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -46,8 +46,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest { y: LogicalPlan, expectedLeft: LogicalPlan, expectedRight: LogicalPlan, - joinType: JoinType) = { - val condition = Some("x.a".attr === "y.a".attr) + joinType: JoinType, + condition: Option[Expression] = Some("x.a".attr === "y.a".attr)) = { val originalQuery = x.join(y, joinType, condition).analyze val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze val optimized = Optimize.execute(originalQuery) @@ -264,10 +264,33 @@ class InferFiltersFromConstraintsSuite extends PlanTest { testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) } - test("SPARK-30768: Constraints should be inferred from inequality attributes") { + test("Constraints should be inferred from inequality constraints: basic") { + Seq(("left.b".attr < "right.b".attr, 'b < 1, 'b < 1), // a < b && b < c => a < c + ("left.b".attr < "right.b".attr, 'b === 1, 'b < 1), // a < b && b = c => a < c + ("left.b".attr < "right.b".attr, 'b <= 1, 'b < 1), // a < b && b <= c => a < c + ("left.b".attr <= "right.b".attr, 'b <= 1, 'b <= 1), // a <= b && b <= c => a <= c + ("left.b".attr <= "right.b".attr, 'b === 1, 'b <= 1), // a <= b && b = c => a <= c + ("left.b".attr > "right.b".attr, 'b > 1, 'b > 1), // a > b && b > c => a > c + ("left.b".attr > "right.b".attr, 'b === 1, 'b > 1), // a > b && b > c => a > c + ("left.b".attr > "right.b".attr, 'b >= 1, 'b > 1), // a > b && b >= c => a > c + ("left.b".attr >= "right.b".attr, 'b >= 1, 'b >= 1), // a >= b && b >= c => a >= c + ("left.b".attr >= "right.b".attr, 'b === 1, 'b >= 1) // a >= b && b >= c => a >= c + ).foreach { + case (cond, filter, inferred) => + val originalLeft = testRelation.subquery('left) + val originalRight = testRelation.where(filter).subquery('right) + + val left = testRelation.where(IsNotNull('a) && IsNotNull('b) && inferred).subquery('left) + val right = testRelation.where(IsNotNull('a) && IsNotNull('b) && filter).subquery('right) + val condition = Some("left.a".attr === "right.a".attr && cond) + testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition) + } + } + + test("Constraints should be inferred from inequality attributes: simple case") { val condition = Some("x.a".attr > "y.a".attr) val optimizedLeft = testRelation.where(IsNotNull('a) && 'a === 1).as("x") - val optimizedRight = testRelation.where(Literal(1) > 'a && IsNotNull('a) ).as("y") + val optimizedRight = testRelation.where('a < 1 && IsNotNull('a) ).as("y") val correct = optimizedLeft.join(optimizedRight, Inner, condition) Seq(Literal(1) === 'a, 'a === Literal(1)).foreach { filter => From d4eb2d7ed17016e5c7e5983d8e0ee2411833fb92 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 14 Feb 2020 22:00:39 +0800 Subject: [PATCH 3/9] fix test errors in TPCH --- .../plans/logical/QueryPlanConstraints.scala | 41 +++++++++++++------ 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index b757cc59a3365..926c11d084384 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -61,6 +61,34 @@ trait ConstraintHelper { * additional constraint of the form `b = 5`. */ def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + var inferred = inferEqualityConstraints(constraints) + var newInferred = Set.empty[Expression] + do { + newInferred = inferInequalityConstraints(constraints ++ inferred) + inferred ++= newInferred + } while (newInferred.nonEmpty) + inferred + } + + def inferEqualityConstraints(constraints: Set[Expression]): Set[Expression] = { + var inferredConstraints = Set.empty[Expression] + // IsNotNull should be constructed by `constructIsNotNullConstraints`. + val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull]) + predicates.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + val candidateConstraints = predicates - eq + inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) + inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) + case eq @ EqualTo(l @ Cast(_: Attribute, _, _), r: Attribute) => + inferredConstraints ++= replaceConstraints(predicates - eq, r, l) + case eq @ EqualTo(l: Attribute, r @ Cast(_: Attribute, _, _)) => + inferredConstraints ++= replaceConstraints(predicates - eq, l, r) + case _ => // No inference + } + inferredConstraints -- predicates + } + + def inferInequalityConstraints(constraints: Set[Expression]): Set[Expression] = { val binaryComparisons = constraints.filter { case _: GreaterThan => true case _: GreaterThanOrEqual => true @@ -83,19 +111,6 @@ trait ConstraintHelper { } var inferredConstraints = Set.empty[Expression] - // IsNotNull should be constructed by `constructIsNotNullConstraints`. - val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull]) - predicates.foreach { - case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = predicates - eq - inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) - inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) - case eq @ EqualTo(l @ Cast(_: Attribute, _, _), r: Attribute) => - inferredConstraints ++= replaceConstraints(predicates - eq, r, l) - case eq @ EqualTo(l: Attribute, r @ Cast(_: Attribute, _, _)) => - inferredConstraints ++= replaceConstraints(predicates - eq, l, r) - case _ => // No inference - } greaterThans.foreach { case gt @ GreaterThan(l: Attribute, r: Attribute) => From df679e673887dd9b8a795c8f5447ba871a6be03f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 16 Feb 2020 23:02:35 +0800 Subject: [PATCH 4/9] Add more tests --- .../plans/logical/QueryPlanConstraints.scala | 32 +++++-- .../InferFiltersFromConstraintsSuite.scala | 94 ++++++++++++++++--- 2 files changed, 106 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 926c11d084384..8ebf7f457de42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -56,20 +56,23 @@ trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan => trait ConstraintHelper { /** - * Infers an additional set of constraints from a given set of equality constraints. - * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an - * additional constraint of the form `b = 5`. + * Infers an additional set of constraints from a given set of constraints. */ def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { var inferred = inferEqualityConstraints(constraints) - var newInferred = Set.empty[Expression] + var lastInequalityInferred = Set.empty[Expression] do { - newInferred = inferInequalityConstraints(constraints ++ inferred) - inferred ++= newInferred - } while (newInferred.nonEmpty) + lastInequalityInferred = inferInequalityConstraints(constraints ++ inferred) + inferred ++= lastInequalityInferred + } while (lastInequalityInferred.nonEmpty) inferred } + /** + * Infers an additional set of constraints from a given set of equality constraints. + * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an + * additional constraint of the form `b = 5`. + */ def inferEqualityConstraints(constraints: Set[Expression]): Set[Expression] = { var inferredConstraints = Set.empty[Expression] // IsNotNull should be constructed by `constructIsNotNullConstraints`. @@ -85,9 +88,14 @@ trait ConstraintHelper { inferredConstraints ++= replaceConstraints(predicates - eq, l, r) case _ => // No inference } - inferredConstraints -- predicates + inferredConstraints -- constraints } + /** + * Infers an additional set of constraints from a given set of inequality constraints. + * For e.g., if an operator has constraints of the form (`a > b`, `b > 5`), this returns an + * additional constraint of the form `a > 5`. + */ def inferInequalityConstraints(constraints: Set[Expression]): Set[Expression] = { val binaryComparisons = constraints.filter { case _: GreaterThan => true @@ -117,6 +125,10 @@ trait ConstraintHelper { inferredConstraints ++= inferInequalityConstraints(greaterThans - gt, r, l, gt) case gt @ GreaterThanOrEqual(l: Attribute, r: Attribute) => inferredConstraints ++= inferInequalityConstraints(greaterThans - gt, r, l, gt) + case gt @ GreaterThan(l @ Cast(_: Attribute, _, _), r: Attribute) => + inferredConstraints ++= inferInequalityConstraints(greaterThans - gt, r, l, gt) + case gt @ GreaterThanOrEqual(l @ Cast(_: Attribute, _, _), r: Attribute) => + inferredConstraints ++= inferInequalityConstraints(greaterThans - gt, r, l, gt) case _ => // No inference } @@ -125,6 +137,10 @@ trait ConstraintHelper { inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt) case lt @ LessThanOrEqual(l: Attribute, r: Attribute) => inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt) + case lt @ LessThan(l @ Cast(_: Attribute, _, _), r: Attribute) => + inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt) + case lt @ LessThanOrEqual(l @ Cast(_: Attribute, _, _), r: Attribute) => + inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt) case _ => // No inference } inferredConstraints -- constraints -- greaterThans -- lessThans diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 001abc301556a..fb8635f83b660 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -317,17 +317,38 @@ class InferFiltersFromConstraintsSuite extends PlanTest { } } - test("Constraints should be inferred from inequality constraints: basic") { - Seq(("left.b".attr < "right.b".attr, 'b < 1, 'b < 1), // a < b && b < c => a < c - ("left.b".attr < "right.b".attr, 'b === 1, 'b < 1), // a < b && b = c => a < c - ("left.b".attr < "right.b".attr, 'b <= 1, 'b < 1), // a < b && b <= c => a < c - ("left.b".attr <= "right.b".attr, 'b <= 1, 'b <= 1), // a <= b && b <= c => a <= c - ("left.b".attr <= "right.b".attr, 'b === 1, 'b <= 1), // a <= b && b = c => a <= c - ("left.b".attr > "right.b".attr, 'b > 1, 'b > 1), // a > b && b > c => a > c - ("left.b".attr > "right.b".attr, 'b === 1, 'b > 1), // a > b && b > c => a > c - ("left.b".attr > "right.b".attr, 'b >= 1, 'b > 1), // a > b && b >= c => a > c - ("left.b".attr >= "right.b".attr, 'b >= 1, 'b >= 1), // a >= b && b >= c => a >= c - ("left.b".attr >= "right.b".attr, 'b === 1, 'b >= 1) // a >= b && b >= c => a >= c + test("Constraints inferred from inequality constraints: basic") { + Seq(('a < 'b && 'b < 3, 'a < 'b && 'b < 3 && 'a < 3), // a < b && b < 3 => a < 3 + ('a < 'b && 'b <= 3, 'a < 'b && 'b <= 3 && 'a < 3), // a < b && b <= 3 => a < 3 + ('a < 'b && 'b === 3, 'a < 'b && 'b === 3 && 'a < 3), // a < b && b = 3 => a < 3 + ('a <= 'b && 'b < 3, 'a <= 'b && 'b < 3 && 'a < 3), // a <= b && b < 3 => a < 3 + ('a <= 'b && 'b <= 3, 'a <= 'b && 'b <= 3 && 'a <= 3), // a <= b && b <= 3 => a <= 3 + ('a <= 'b && 'b === 3, 'a <= 'b && 'b === 3 && 'a <= 3), // a <= b && b = 3 => a <= 3 + ('a > 'b && 'b > 3, 'a > 'b && 'b > 3 && 'a > 3), // a > b && b > 3 => a > 3 + ('a > 'b && 'b >= 3, 'a > 'b && 'b >= 3 && 'a > 3), // a > b && b >= 3 => a > 3 + ('a > 'b && 'b === 3, 'a > 'b && 'b === 3 && 'a > 3), // a > b && b = 3 => a > 3 + ('a >= 'b && 'b > 3, 'a >= 'b && 'b > 3 && 'a > 3), // a >= b && b > 3 => a > 3 + ('a >= 'b && 'b >= 3, 'a >= 'b && 'b >= 3 && 'a >= 3), // a >= b && b >= 3 => a >= 3 + ('a >= 'b && 'b === 3, 'a >= 'b && 'b === 3 && 'a >= 3) // a >= b && b = 3 => a >= 3 + ).foreach { + case (filter, inferred) => + val original = testRelation.where(filter) + val optimized = testRelation.where(IsNotNull('a) && IsNotNull('b) && inferred) + comparePlans(Optimize.execute(original.analyze), optimized.analyze) + } + } + + test("Constraints inferred from inequality constraints: join") { + Seq(("left.b".attr < "right.b".attr, 'b < 1, 'b < 1), + ("left.b".attr < "right.b".attr, 'b === 1, 'b < 1), + ("left.b".attr < "right.b".attr, 'b <= 1, 'b < 1), + ("left.b".attr <= "right.b".attr, 'b <= 1, 'b <= 1), + ("left.b".attr <= "right.b".attr, 'b === 1, 'b <= 1), + ("left.b".attr > "right.b".attr, 'b > 1, 'b > 1), + ("left.b".attr > "right.b".attr, 'b === 1, 'b > 1), + ("left.b".attr > "right.b".attr, 'b >= 1, 'b > 1), + ("left.b".attr >= "right.b".attr, 'b >= 1, 'b >= 1), + ("left.b".attr >= "right.b".attr, 'b === 1, 'b >= 1) ).foreach { case (cond, filter, inferred) => val originalLeft = testRelation.subquery('left) @@ -340,7 +361,33 @@ class InferFiltersFromConstraintsSuite extends PlanTest { } } - test("Constraints should be inferred from inequality attributes: simple case") { + test("Constraints inferred from inequality constraints with cast") { + Seq(('a < 'b && 'b < 3L, 'a.cast(LongType) < 'b && 'b < 3L && 'a.cast(LongType) < 3L), + ('a < 'b && 'b <= 3L, 'a.cast(LongType) < 'b && 'b <= 3L && 'a.cast(LongType) < 3L), + ('a < 'b && 'b === 3L, 'a.cast(LongType) < 'b && 'b === 3L && 'a.cast(LongType) < 3L), + ('a <= 'b && 'b < 3L, 'a.cast(LongType) <= 'b && 'b < 3L && 'a.cast(LongType) < 3L), + ('a <= 'b && 'b <= 3L, 'a.cast(LongType) <= 'b && 'b <= 3L && 'a.cast(LongType) <= 3L), + ('a <= 'b && 'b === 3L, 'a.cast(LongType) <= 'b && 'b === 3L && 'a.cast(LongType) <= 3L), + ('a < 'b && 'b < 3, 'a.cast(LongType) < 'b && 'b < Literal(3).cast(LongType) + && 'a.cast(LongType) < Literal(3).cast(LongType)), + ('a > 'b && 'b > 3L, 'a.cast(LongType) > 'b && 'b > 3L && 'a.cast(LongType) > 3L), + ('a > 'b && 'b >= 3L, 'a.cast(LongType) > 'b && 'b >= 3L && 'a.cast(LongType) > 3L), + ('a > 'b && 'b === 3L, 'a.cast(LongType) > 'b && 'b === 3L && 'a.cast(LongType) > 3L), + ('a >= 'b && 'b > 3L, 'a.cast(LongType) >= 'b && 'b > 3L && 'a.cast(LongType) > 3L), + ('a >= 'b && 'b >= 3L, 'a.cast(LongType) >= 'b && 'b >= 3L && 'a.cast(LongType) >= 3L), + ('a >= 'b && 'b === 3L, 'a.cast(LongType) >= 'b && 'b === 3L && 'a.cast(LongType) >= 3L), + ('a > 'b && 'b > 3, 'a.cast(LongType) > 'b && 'b > Literal(3).cast(LongType) + && 'a.cast(LongType) > Literal(3).cast(LongType)) + ).foreach { + case (filter, inferred) => + val testRelation = LocalRelation('a.int, 'b.long) + val original = testRelation.where(filter) + val optimized = testRelation.where(IsNotNull('a) && IsNotNull('b) && inferred) + comparePlans(Optimize.execute(original.analyze), optimized.analyze) + } + } + + test("Constraints inferred from inequality attributes: case1") { val condition = Some("x.a".attr > "y.a".attr) val optimizedLeft = testRelation.where(IsNotNull('a) && 'a === 1).as("x") val optimizedRight = testRelation.where('a < 1 && IsNotNull('a) ).as("y") @@ -351,4 +398,27 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(Optimize.execute(original.analyze), correct.analyze) } } + + test("Constraints inferred from inequality attributes: case2") { + val original = testRelation.where('a < 'b && 'b < 'c && 'c < 5) + val optimized = testRelation.where(IsNotNull('a) && IsNotNull('b) && IsNotNull('c) + && 'a < 'b && 'b < 'c && 'a < 5 && 'b < 5 && 'c < 5 && 'c > 'a) + comparePlans(Optimize.execute(original.analyze), optimized.analyze) + } + + test("Constraints inferred from inequality attributes: case3") { + val left = testRelation.where('b >= 3 && 'b <= 13).as("x") + val right = testRelation.as("y") + + val optimizedLeft = testRelation.where(IsNotNull('a) && IsNotNull('b) + && 'b >= 3 && 'b <= 13).as("x") + val optimizedRight = testRelation.where(IsNotNull('a) && IsNotNull('b) && IsNotNull('c) + && 'b < 'c && 'c > 3 && 'b <= 13).as("y") + + val condition = Some("x.a".attr === "y.a".attr + && "x.b".attr >= "y.b".attr && "x.b".attr < "y.c".attr) + val original = left.join(right, Inner, condition) + val optimized = optimizedLeft.join(optimizedRight, Inner, condition) + comparePlans(Optimize.execute(original.analyze), optimized.analyze) + } } From f9a90aae3545c6ef1f1451cd25f084d2cbbed5ea Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 17 Feb 2020 18:39:39 +0800 Subject: [PATCH 5/9] Remove semanticEquals Expression --- .../sql/catalyst/plans/logical/QueryPlanConstraints.scala | 5 ++++- .../optimizer/InferFiltersFromConstraintsSuite.scala | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 8ebf7f457de42..8cbfe08da7b42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -143,7 +143,10 @@ trait ConstraintHelper { inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt) case _ => // No inference } - inferredConstraints -- constraints -- greaterThans -- lessThans + (inferredConstraints -- constraints -- greaterThans -- lessThans).foldLeft(Set[Expression]()) { + case (acc, e) if acc.exists(_.semanticEquals(e)) => acc.dropWhile(_.semanticEquals(e)) + case (acc, e) => acc + e + } } private def replaceConstraints( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index fb8635f83b660..cfd2ab6541b5e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -402,7 +402,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("Constraints inferred from inequality attributes: case2") { val original = testRelation.where('a < 'b && 'b < 'c && 'c < 5) val optimized = testRelation.where(IsNotNull('a) && IsNotNull('b) && IsNotNull('c) - && 'a < 'b && 'b < 'c && 'a < 5 && 'b < 5 && 'c < 5 && 'c > 'a) + && 'a < 'b && 'b < 'c && 'a < 5 && 'b < 5 && 'c < 5) comparePlans(Optimize.execute(original.analyze), optimized.analyze) } From bfa60393789673b1cd77fbda038cdc8c17c1a866 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 26 Feb 2020 22:58:02 +0800 Subject: [PATCH 6/9] fix --- .../catalyst/plans/logical/QueryPlanConstraints.scala | 10 ++++------ .../optimizer/InferFiltersFromConstraintsSuite.scala | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 8cbfe08da7b42..40f4ddf3ed942 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -143,10 +143,8 @@ trait ConstraintHelper { inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt) case _ => // No inference } - (inferredConstraints -- constraints -- greaterThans -- lessThans).foldLeft(Set[Expression]()) { - case (acc, e) if acc.exists(_.semanticEquals(e)) => acc.dropWhile(_.semanticEquals(e)) - case (acc, e) => acc + e - } + (inferredConstraints -- constraints -- greaterThans -- lessThans) + .filterNot(i => constraints.exists(_.semanticEquals(i))) } private def replaceConstraints( @@ -167,8 +165,8 @@ trait ConstraintHelper { binaryComparison.makeCopy(Array(destination, l)) case gt @ GreaterThan(l, r) if l.semanticEquals(source) => gt.makeCopy(Array(destination, r)) - case gt @ LessThan(l, r) if l.semanticEquals(source) => - gt.makeCopy(Array(destination, r)) + case lt @ LessThan(l, r) if l.semanticEquals(source) => + lt.makeCopy(Array(destination, r)) case BinaryComparison(l, r) if l.semanticEquals(source) => binaryComparison.makeCopy(Array(destination, r)) case other => other diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index cfd2ab6541b5e..6025d15440509 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -402,7 +402,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("Constraints inferred from inequality attributes: case2") { val original = testRelation.where('a < 'b && 'b < 'c && 'c < 5) val optimized = testRelation.where(IsNotNull('a) && IsNotNull('b) && IsNotNull('c) - && 'a < 'b && 'b < 'c && 'a < 5 && 'b < 5 && 'c < 5) + && 'a < 'b && 'b < 'c && 'c > 'a && 'a < 5 && 'b < 5 && 'c < 5) comparePlans(Optimize.execute(original.analyze), optimized.analyze) } From 248e3cca6c5f77b768ff738ae5fecc1f9476defb Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 7 Mar 2020 21:29:13 +0800 Subject: [PATCH 7/9] Add support another real case --- .../plans/logical/QueryPlanConstraints.scala | 43 ++++++++++--------- .../InferFiltersFromConstraintsSuite.scala | 18 ++++++++ 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 40f4ddf3ed942..0efd61dce41c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -122,25 +122,33 @@ trait ConstraintHelper { greaterThans.foreach { case gt @ GreaterThan(l: Attribute, r: Attribute) => - inferredConstraints ++= inferInequalityConstraints(greaterThans - gt, r, l, gt) + inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt) case gt @ GreaterThanOrEqual(l: Attribute, r: Attribute) => - inferredConstraints ++= inferInequalityConstraints(greaterThans - gt, r, l, gt) + inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt) case gt @ GreaterThan(l @ Cast(_: Attribute, _, _), r: Attribute) => - inferredConstraints ++= inferInequalityConstraints(greaterThans - gt, r, l, gt) + inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt) case gt @ GreaterThanOrEqual(l @ Cast(_: Attribute, _, _), r: Attribute) => - inferredConstraints ++= inferInequalityConstraints(greaterThans - gt, r, l, gt) + inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt) + case gt @ GreaterThan(l: Attribute, r @ Cast(_: Attribute, _, _)) => + inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt) + case gt @ GreaterThanOrEqual(l: Attribute, r @ Cast(_: Attribute, _, _)) => + inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt) case _ => // No inference } lessThans.foreach { case lt @ LessThan(l: Attribute, r: Attribute) => - inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt) + inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt) case lt @ LessThanOrEqual(l: Attribute, r: Attribute) => - inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt) + inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt) case lt @ LessThan(l @ Cast(_: Attribute, _, _), r: Attribute) => - inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt) + inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt) case lt @ LessThanOrEqual(l @ Cast(_: Attribute, _, _), r: Attribute) => - inferredConstraints ++= inferInequalityConstraints(lessThans - lt, r, l, lt) + inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt) + case lt @ LessThan(l: Attribute, r @ Cast(_: Attribute, _, _)) => + inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt) + case lt @ LessThanOrEqual(l: Attribute, r @ Cast(_: Attribute, _, _)) => + inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt) case _ => // No inference } (inferredConstraints -- constraints -- greaterThans -- lessThans) @@ -154,21 +162,16 @@ trait ConstraintHelper { case e: Expression if e.semanticEquals(source) => destination }) - private def inferInequalityConstraints( + private def replaceInequalityConstraints( constraints: Set[Expression], source: Expression, destination: Expression, - binaryComparison: BinaryComparison): Set[Expression] = constraints.map { - case EqualTo(l, r) if l.semanticEquals(source) => - binaryComparison.makeCopy(Array(destination, r)) - case EqualTo(l, r) if r.semanticEquals(source) => - binaryComparison.makeCopy(Array(destination, l)) - case gt @ GreaterThan(l, r) if l.semanticEquals(source) => - gt.makeCopy(Array(destination, r)) - case lt @ LessThan(l, r) if l.semanticEquals(source) => - lt.makeCopy(Array(destination, r)) - case BinaryComparison(l, r) if l.semanticEquals(source) => - binaryComparison.makeCopy(Array(destination, r)) + op: BinaryComparison): Set[Expression] = (constraints - op).map { + case EqualTo(l, r) if l.semanticEquals(source) => op.makeCopy(Array(destination, r)) + case EqualTo(l, r) if r.semanticEquals(source) => op.makeCopy(Array(destination, l)) + case gt @ GreaterThan(l, r) if l.semanticEquals(source) => gt.makeCopy(Array(destination, r)) + case lt @ LessThan(l, r) if l.semanticEquals(source) => lt.makeCopy(Array(destination, r)) + case BinaryComparison(l, r) if l.semanticEquals(source) => op.makeCopy(Array(destination, r)) case other => other } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 6025d15440509..16383387054e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -421,4 +421,22 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val optimized = optimizedLeft.join(optimizedRight, Inner, condition) comparePlans(Optimize.execute(original.analyze), optimized.analyze) } + + test("Constraints inferred from inequality attributes: case4") { + val testRelation1 = LocalRelation('a.long, 'b.long, 'c.long).as("x") + val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int).as("y") + + // y.b < 13 inferred from y.b < x.b && x.b <= 13 + val left = testRelation1.where('b <= 13L).as("x") + val right = testRelation2.as("y") + + val optimizedLeft = testRelation1.where(IsNotNull('a) && IsNotNull('b) && 'b <= 13L).as("x") + val optimizedRight = testRelation2.where(IsNotNull('a) && IsNotNull('b) + && 'b.cast(LongType) < 13L).as("y") + + val condition = Some("x.a".attr === "y.a".attr && "y.b".attr < "x.b".attr) + val original = left.join(right, Inner, condition) + val optimized = optimizedLeft.join(optimizedRight, Inner, condition) + comparePlans(Optimize.execute(original.analyze), optimized.analyze) + } } From af55a08af7efea673e9c67a5c8224e47afd4d505 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 8 Jun 2020 22:11:11 +0800 Subject: [PATCH 8/9] Only infer foldable constraints --- .../plans/logical/QueryPlanConstraints.scala | 58 +++++++------------ .../InferFiltersFromConstraintsSuite.scala | 5 +- 2 files changed, 24 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 0efd61dce41c5..92ef8d04ca56c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -107,49 +107,28 @@ trait ConstraintHelper { } val greaterThans = binaryComparisons.map { + case EqualTo(l, r) if l.foldable => EqualTo(r, l) case LessThan(l, r) => GreaterThan(r, l) case LessThanOrEqual(l, r) => GreaterThanOrEqual(r, l) case other => other } val lessThans = binaryComparisons.map { + case EqualTo(l, r) if l.foldable => EqualTo(r, l) case GreaterThan(l, r) => LessThan(r, l) case GreaterThanOrEqual(l, r) => LessThanOrEqual(r, l) case other => other } var inferredConstraints = Set.empty[Expression] - - greaterThans.foreach { - case gt @ GreaterThan(l: Attribute, r: Attribute) => - inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt) - case gt @ GreaterThanOrEqual(l: Attribute, r: Attribute) => - inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt) - case gt @ GreaterThan(l @ Cast(_: Attribute, _, _), r: Attribute) => - inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt) - case gt @ GreaterThanOrEqual(l @ Cast(_: Attribute, _, _), r: Attribute) => - inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt) - case gt @ GreaterThan(l: Attribute, r @ Cast(_: Attribute, _, _)) => - inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt) - case gt @ GreaterThanOrEqual(l: Attribute, r @ Cast(_: Attribute, _, _)) => - inferredConstraints ++= replaceInequalityConstraints(greaterThans, r, l, gt) - case _ => // No inference - } - - lessThans.foreach { - case lt @ LessThan(l: Attribute, r: Attribute) => - inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt) - case lt @ LessThanOrEqual(l: Attribute, r: Attribute) => - inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt) - case lt @ LessThan(l @ Cast(_: Attribute, _, _), r: Attribute) => - inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt) - case lt @ LessThanOrEqual(l @ Cast(_: Attribute, _, _), r: Attribute) => - inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt) - case lt @ LessThan(l: Attribute, r @ Cast(_: Attribute, _, _)) => - inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt) - case lt @ LessThanOrEqual(l: Attribute, r @ Cast(_: Attribute, _, _)) => - inferredConstraints ++= replaceInequalityConstraints(lessThans, r, l, lt) - case _ => // No inference + Seq(greaterThans, lessThans).foreach { comparisons => + comparisons.foreach { + case b @ BinaryComparison(l: Attribute, r: Expression) if r.foldable => + inferredConstraints ++= replaceInequalityConstraints(comparisons, l, r, b) + case b @ BinaryComparison(l @ Cast(_: Attribute, _, _), r: Expression) if r.foldable => + inferredConstraints ++= replaceInequalityConstraints(comparisons, l, r, b) + case _ => // No inference + } } (inferredConstraints -- constraints -- greaterThans -- lessThans) .filterNot(i => constraints.exists(_.semanticEquals(i))) @@ -167,11 +146,18 @@ trait ConstraintHelper { source: Expression, destination: Expression, op: BinaryComparison): Set[Expression] = (constraints - op).map { - case EqualTo(l, r) if l.semanticEquals(source) => op.makeCopy(Array(destination, r)) - case EqualTo(l, r) if r.semanticEquals(source) => op.makeCopy(Array(destination, l)) - case gt @ GreaterThan(l, r) if l.semanticEquals(source) => gt.makeCopy(Array(destination, r)) - case lt @ LessThan(l, r) if l.semanticEquals(source) => lt.makeCopy(Array(destination, r)) - case BinaryComparison(l, r) if l.semanticEquals(source) => op.makeCopy(Array(destination, r)) + case gt @ GreaterThan(l, r) if r.semanticEquals(source) => + gt.copy(l, destination) + case GreaterThanOrEqual(l, r) if r.semanticEquals(source) && op.isInstanceOf[GreaterThan] => + op.makeCopy(Array(l, destination)) + case gt @ GreaterThanOrEqual(l, r) if r.semanticEquals(source) => + gt.copy(l, destination) + case lt @ LessThan(l, r) if r.semanticEquals(source) => + lt.copy(l, destination) + case LessThanOrEqual(l, r) if r.semanticEquals(source) && op.isInstanceOf[LessThan] => + op.makeCopy(Array(l, destination)) + case lt @ LessThanOrEqual(l, r) if r.semanticEquals(source) => + lt.copy(l, destination) case other => other } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 16383387054e8..5cb5697cbaf06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -402,7 +402,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("Constraints inferred from inequality attributes: case2") { val original = testRelation.where('a < 'b && 'b < 'c && 'c < 5) val optimized = testRelation.where(IsNotNull('a) && IsNotNull('b) && IsNotNull('c) - && 'a < 'b && 'b < 'c && 'c > 'a && 'a < 5 && 'b < 5 && 'c < 5) + && 'a < 'b && 'b < 'c && 'a < 5 && 'b < 5 && 'c < 5) comparePlans(Optimize.execute(original.analyze), optimized.analyze) } @@ -413,8 +413,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val optimizedLeft = testRelation.where(IsNotNull('a) && IsNotNull('b) && 'b >= 3 && 'b <= 13).as("x") val optimizedRight = testRelation.where(IsNotNull('a) && IsNotNull('b) && IsNotNull('c) - && 'b < 'c && 'c > 3 && 'b <= 13).as("y") - + && 'c > 3 && 'b <= 13).as("y") val condition = Some("x.a".attr === "y.a".attr && "x.b".attr >= "y.b".attr && "x.b".attr < "y.c".attr) val original = left.join(right, Inner, condition) From 5c76b9d700d834182760906577951bcaed49d147 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 12 Jun 2020 18:03:12 +0800 Subject: [PATCH 9/9] Simplify --- .../plans/logical/QueryPlanConstraints.scala | 57 ++++++++++--------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 92ef8d04ca56c..248e24bbd19f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -121,15 +121,38 @@ trait ConstraintHelper { } var inferredConstraints = Set.empty[Expression] - Seq(greaterThans, lessThans).foreach { comparisons => - comparisons.foreach { - case b @ BinaryComparison(l: Attribute, r: Expression) if r.foldable => - inferredConstraints ++= replaceInequalityConstraints(comparisons, l, r, b) - case b @ BinaryComparison(l @ Cast(_: Attribute, _, _), r: Expression) if r.foldable => - inferredConstraints ++= replaceInequalityConstraints(comparisons, l, r, b) - case _ => // No inference + greaterThans.foreach { + case op @ BinaryComparison(source: Attribute, destination: Expression) + if destination.foldable => + inferredConstraints ++= (greaterThans - op).map { + case GreaterThan(l, r) if r.semanticEquals(source) => + GreaterThan(l, destination) + case GreaterThanOrEqual(l, r) + if r.semanticEquals(source) && op.isInstanceOf[GreaterThan] => + GreaterThan(l, destination) + case GreaterThanOrEqual(l, r) if r.semanticEquals(source) => + GreaterThanOrEqual(l, destination) + case other => other } + case _ => // No inference } + + lessThans.foreach { + case op @ BinaryComparison(source: Attribute, destination: Expression) + if destination.foldable => + inferredConstraints ++= (lessThans - op).map { + case LessThan(l, r) if r.semanticEquals(source) => + LessThan(l, destination) + case LessThanOrEqual(l, r) + if r.semanticEquals(source) && op.isInstanceOf[LessThan] => + LessThan(l, destination) + case LessThanOrEqual(l, r) if r.semanticEquals(source) => + LessThanOrEqual(l, destination) + case other => other + } + case _ => // No inference + } + (inferredConstraints -- constraints -- greaterThans -- lessThans) .filterNot(i => constraints.exists(_.semanticEquals(i))) } @@ -141,26 +164,6 @@ trait ConstraintHelper { case e: Expression if e.semanticEquals(source) => destination }) - private def replaceInequalityConstraints( - constraints: Set[Expression], - source: Expression, - destination: Expression, - op: BinaryComparison): Set[Expression] = (constraints - op).map { - case gt @ GreaterThan(l, r) if r.semanticEquals(source) => - gt.copy(l, destination) - case GreaterThanOrEqual(l, r) if r.semanticEquals(source) && op.isInstanceOf[GreaterThan] => - op.makeCopy(Array(l, destination)) - case gt @ GreaterThanOrEqual(l, r) if r.semanticEquals(source) => - gt.copy(l, destination) - case lt @ LessThan(l, r) if r.semanticEquals(source) => - lt.copy(l, destination) - case LessThanOrEqual(l, r) if r.semanticEquals(source) && op.isInstanceOf[LessThan] => - op.makeCopy(Array(l, destination)) - case lt @ LessThanOrEqual(l, r) if r.semanticEquals(source) => - lt.copy(l, destination) - case other => other - } - /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this