diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 4c4ec000d0930..248e24bbd19f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -55,12 +55,25 @@ trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan => trait ConstraintHelper { + /** + * Infers an additional set of constraints from a given set of constraints. + */ + def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + var inferred = inferEqualityConstraints(constraints) + var lastInequalityInferred = Set.empty[Expression] + do { + lastInequalityInferred = inferInequalityConstraints(constraints ++ inferred) + inferred ++= lastInequalityInferred + } while (lastInequalityInferred.nonEmpty) + inferred + } + /** * Infers an additional set of constraints from a given set of equality constraints. * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an * additional constraint of the form `b = 5`. */ - def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + def inferEqualityConstraints(constraints: Set[Expression]): Set[Expression] = { var inferredConstraints = Set.empty[Expression] // IsNotNull should be constructed by `constructIsNotNullConstraints`. val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull]) @@ -78,6 +91,72 @@ trait ConstraintHelper { inferredConstraints -- constraints } + /** + * Infers an additional set of constraints from a given set of inequality constraints. + * For e.g., if an operator has constraints of the form (`a > b`, `b > 5`), this returns an + * additional constraint of the form `a > 5`. + */ + def inferInequalityConstraints(constraints: Set[Expression]): Set[Expression] = { + val binaryComparisons = constraints.filter { + case _: GreaterThan => true + case _: GreaterThanOrEqual => true + case _: LessThan => true + case _: LessThanOrEqual => true + case _: EqualTo => true + case _ => false + } + + val greaterThans = binaryComparisons.map { + case EqualTo(l, r) if l.foldable => EqualTo(r, l) + case LessThan(l, r) => GreaterThan(r, l) + case LessThanOrEqual(l, r) => GreaterThanOrEqual(r, l) + case other => other + } + + val lessThans = binaryComparisons.map { + case EqualTo(l, r) if l.foldable => EqualTo(r, l) + case GreaterThan(l, r) => LessThan(r, l) + case GreaterThanOrEqual(l, r) => LessThanOrEqual(r, l) + case other => other + } + + var inferredConstraints = Set.empty[Expression] + greaterThans.foreach { + case op @ BinaryComparison(source: Attribute, destination: Expression) + if destination.foldable => + inferredConstraints ++= (greaterThans - op).map { + case GreaterThan(l, r) if r.semanticEquals(source) => + GreaterThan(l, destination) + case GreaterThanOrEqual(l, r) + if r.semanticEquals(source) && op.isInstanceOf[GreaterThan] => + GreaterThan(l, destination) + case GreaterThanOrEqual(l, r) if r.semanticEquals(source) => + GreaterThanOrEqual(l, destination) + case other => other + } + case _ => // No inference + } + + lessThans.foreach { + case op @ BinaryComparison(source: Attribute, destination: Expression) + if destination.foldable => + inferredConstraints ++= (lessThans - op).map { + case LessThan(l, r) if r.semanticEquals(source) => + LessThan(l, destination) + case LessThanOrEqual(l, r) + if r.semanticEquals(source) && op.isInstanceOf[LessThan] => + LessThan(l, destination) + case LessThanOrEqual(l, r) if r.semanticEquals(source) => + LessThanOrEqual(l, destination) + case other => other + } + case _ => // No inference + } + + (inferredConstraints -- constraints -- greaterThans -- lessThans) + .filterNot(i => constraints.exists(_.semanticEquals(i))) + } + private def replaceConstraints( constraints: Set[Expression], source: Expression, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 79bd573f1d84a..5cb5697cbaf06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -316,4 +316,126 @@ class InferFiltersFromConstraintsSuite extends PlanTest { condition) } } + + test("Constraints inferred from inequality constraints: basic") { + Seq(('a < 'b && 'b < 3, 'a < 'b && 'b < 3 && 'a < 3), // a < b && b < 3 => a < 3 + ('a < 'b && 'b <= 3, 'a < 'b && 'b <= 3 && 'a < 3), // a < b && b <= 3 => a < 3 + ('a < 'b && 'b === 3, 'a < 'b && 'b === 3 && 'a < 3), // a < b && b = 3 => a < 3 + ('a <= 'b && 'b < 3, 'a <= 'b && 'b < 3 && 'a < 3), // a <= b && b < 3 => a < 3 + ('a <= 'b && 'b <= 3, 'a <= 'b && 'b <= 3 && 'a <= 3), // a <= b && b <= 3 => a <= 3 + ('a <= 'b && 'b === 3, 'a <= 'b && 'b === 3 && 'a <= 3), // a <= b && b = 3 => a <= 3 + ('a > 'b && 'b > 3, 'a > 'b && 'b > 3 && 'a > 3), // a > b && b > 3 => a > 3 + ('a > 'b && 'b >= 3, 'a > 'b && 'b >= 3 && 'a > 3), // a > b && b >= 3 => a > 3 + ('a > 'b && 'b === 3, 'a > 'b && 'b === 3 && 'a > 3), // a > b && b = 3 => a > 3 + ('a >= 'b && 'b > 3, 'a >= 'b && 'b > 3 && 'a > 3), // a >= b && b > 3 => a > 3 + ('a >= 'b && 'b >= 3, 'a >= 'b && 'b >= 3 && 'a >= 3), // a >= b && b >= 3 => a >= 3 + ('a >= 'b && 'b === 3, 'a >= 'b && 'b === 3 && 'a >= 3) // a >= b && b = 3 => a >= 3 + ).foreach { + case (filter, inferred) => + val original = testRelation.where(filter) + val optimized = testRelation.where(IsNotNull('a) && IsNotNull('b) && inferred) + comparePlans(Optimize.execute(original.analyze), optimized.analyze) + } + } + + test("Constraints inferred from inequality constraints: join") { + Seq(("left.b".attr < "right.b".attr, 'b < 1, 'b < 1), + ("left.b".attr < "right.b".attr, 'b === 1, 'b < 1), + ("left.b".attr < "right.b".attr, 'b <= 1, 'b < 1), + ("left.b".attr <= "right.b".attr, 'b <= 1, 'b <= 1), + ("left.b".attr <= "right.b".attr, 'b === 1, 'b <= 1), + ("left.b".attr > "right.b".attr, 'b > 1, 'b > 1), + ("left.b".attr > "right.b".attr, 'b === 1, 'b > 1), + ("left.b".attr > "right.b".attr, 'b >= 1, 'b > 1), + ("left.b".attr >= "right.b".attr, 'b >= 1, 'b >= 1), + ("left.b".attr >= "right.b".attr, 'b === 1, 'b >= 1) + ).foreach { + case (cond, filter, inferred) => + val originalLeft = testRelation.subquery('left) + val originalRight = testRelation.where(filter).subquery('right) + + val left = testRelation.where(IsNotNull('a) && IsNotNull('b) && inferred).subquery('left) + val right = testRelation.where(IsNotNull('a) && IsNotNull('b) && filter).subquery('right) + val condition = Some("left.a".attr === "right.a".attr && cond) + testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition) + } + } + + test("Constraints inferred from inequality constraints with cast") { + Seq(('a < 'b && 'b < 3L, 'a.cast(LongType) < 'b && 'b < 3L && 'a.cast(LongType) < 3L), + ('a < 'b && 'b <= 3L, 'a.cast(LongType) < 'b && 'b <= 3L && 'a.cast(LongType) < 3L), + ('a < 'b && 'b === 3L, 'a.cast(LongType) < 'b && 'b === 3L && 'a.cast(LongType) < 3L), + ('a <= 'b && 'b < 3L, 'a.cast(LongType) <= 'b && 'b < 3L && 'a.cast(LongType) < 3L), + ('a <= 'b && 'b <= 3L, 'a.cast(LongType) <= 'b && 'b <= 3L && 'a.cast(LongType) <= 3L), + ('a <= 'b && 'b === 3L, 'a.cast(LongType) <= 'b && 'b === 3L && 'a.cast(LongType) <= 3L), + ('a < 'b && 'b < 3, 'a.cast(LongType) < 'b && 'b < Literal(3).cast(LongType) + && 'a.cast(LongType) < Literal(3).cast(LongType)), + ('a > 'b && 'b > 3L, 'a.cast(LongType) > 'b && 'b > 3L && 'a.cast(LongType) > 3L), + ('a > 'b && 'b >= 3L, 'a.cast(LongType) > 'b && 'b >= 3L && 'a.cast(LongType) > 3L), + ('a > 'b && 'b === 3L, 'a.cast(LongType) > 'b && 'b === 3L && 'a.cast(LongType) > 3L), + ('a >= 'b && 'b > 3L, 'a.cast(LongType) >= 'b && 'b > 3L && 'a.cast(LongType) > 3L), + ('a >= 'b && 'b >= 3L, 'a.cast(LongType) >= 'b && 'b >= 3L && 'a.cast(LongType) >= 3L), + ('a >= 'b && 'b === 3L, 'a.cast(LongType) >= 'b && 'b === 3L && 'a.cast(LongType) >= 3L), + ('a > 'b && 'b > 3, 'a.cast(LongType) > 'b && 'b > Literal(3).cast(LongType) + && 'a.cast(LongType) > Literal(3).cast(LongType)) + ).foreach { + case (filter, inferred) => + val testRelation = LocalRelation('a.int, 'b.long) + val original = testRelation.where(filter) + val optimized = testRelation.where(IsNotNull('a) && IsNotNull('b) && inferred) + comparePlans(Optimize.execute(original.analyze), optimized.analyze) + } + } + + test("Constraints inferred from inequality attributes: case1") { + val condition = Some("x.a".attr > "y.a".attr) + val optimizedLeft = testRelation.where(IsNotNull('a) && 'a === 1).as("x") + val optimizedRight = testRelation.where('a < 1 && IsNotNull('a) ).as("y") + val correct = optimizedLeft.join(optimizedRight, Inner, condition) + + Seq(Literal(1) === 'a, 'a === Literal(1)).foreach { filter => + val original = testRelation.where(filter).as("x").join(testRelation.as("y"), Inner, condition) + comparePlans(Optimize.execute(original.analyze), correct.analyze) + } + } + + test("Constraints inferred from inequality attributes: case2") { + val original = testRelation.where('a < 'b && 'b < 'c && 'c < 5) + val optimized = testRelation.where(IsNotNull('a) && IsNotNull('b) && IsNotNull('c) + && 'a < 'b && 'b < 'c && 'a < 5 && 'b < 5 && 'c < 5) + comparePlans(Optimize.execute(original.analyze), optimized.analyze) + } + + test("Constraints inferred from inequality attributes: case3") { + val left = testRelation.where('b >= 3 && 'b <= 13).as("x") + val right = testRelation.as("y") + + val optimizedLeft = testRelation.where(IsNotNull('a) && IsNotNull('b) + && 'b >= 3 && 'b <= 13).as("x") + val optimizedRight = testRelation.where(IsNotNull('a) && IsNotNull('b) && IsNotNull('c) + && 'c > 3 && 'b <= 13).as("y") + val condition = Some("x.a".attr === "y.a".attr + && "x.b".attr >= "y.b".attr && "x.b".attr < "y.c".attr) + val original = left.join(right, Inner, condition) + val optimized = optimizedLeft.join(optimizedRight, Inner, condition) + comparePlans(Optimize.execute(original.analyze), optimized.analyze) + } + + test("Constraints inferred from inequality attributes: case4") { + val testRelation1 = LocalRelation('a.long, 'b.long, 'c.long).as("x") + val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int).as("y") + + // y.b < 13 inferred from y.b < x.b && x.b <= 13 + val left = testRelation1.where('b <= 13L).as("x") + val right = testRelation2.as("y") + + val optimizedLeft = testRelation1.where(IsNotNull('a) && IsNotNull('b) && 'b <= 13L).as("x") + val optimizedRight = testRelation2.where(IsNotNull('a) && IsNotNull('b) + && 'b.cast(LongType) < 13L).as("y") + + val condition = Some("x.a".attr === "y.a".attr && "y.b".attr < "x.b".attr) + val original = left.join(right, Inner, condition) + val optimized = optimizedLeft.join(optimizedRight, Inner, condition) + comparePlans(Optimize.execute(original.analyze), optimized.analyze) + } }