From a9eb7de5a76a407c6dfda3ce1d9a91cfb956f3a2 Mon Sep 17 00:00:00 2001
From: Yuming Wang
Date: Wed, 27 Oct 2021 19:42:01 +0800
Subject: [PATCH 1/4] [SPARK-31809][SQL] Infer IsNotNull from join condition

---
 .../sql/catalyst/optimizer/Optimizer.scala    | 40 ++++++++++++++--
 .../InferFiltersFromConstraintsSuite.scala    | 46 +++++++++++++++----
 2 files changed, 73 insertions(+), 13 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 73be7902b998e..b95a1350aae07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1215,6 +1215,14 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
     }
   }
 
+  // Whether the result of this expression may be null, e.g. CAST(strCol AS double).
+  // We will infer an IsNotNull predicate for such an expression to avoid a skewed join.
+  private def resultMayBeNull(e: Expression): Boolean = e match {
+    case Cast(child, dataType, _, _) => !Cast.canUpCast(child.dataType, dataType)
+    case _: Coalesce => true
+    case _ => false
+  }
+
   private def inferFilters(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
     _.containsAnyPattern(FILTER, JOIN)) {
     case filter @ Filter(condition, child) =>
@@ -1227,25 +1235,43 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
       }
 
     case join @ Join(left, right, joinType, conditionOpt, _) =>
+      val leftKeys = new mutable.HashSet[Expression]
+      val rightKeys = new mutable.HashSet[Expression]
+      conditionOpt.foreach { condition =>
+        splitConjunctivePredicates(condition).foreach {
+          case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty =>
+          case EqualTo(l, r) =>
+            if (resultMayBeNull(l)) {
+              if (canEvaluate(l, left)) leftKeys.add(l)
+              if (canEvaluate(l, right)) rightKeys.add(l)
+            }
+            if (resultMayBeNull(r)) {
+              if (canEvaluate(r, left)) leftKeys.add(r)
+              if (canEvaluate(r, right)) rightKeys.add(r)
+            }
+          case _ =>
+        }
+      }
+
       joinType match {
         // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an
         // inner join, it just drops the right side in the final output.
         case _: InnerLike | LeftSemi =>
           val allConstraints = getAllConstraints(left, right, conditionOpt)
-          val newLeft = inferNewFilter(left, allConstraints)
-          val newRight = inferNewFilter(right, allConstraints)
+          val newLeft = inferNewFilter(left, allConstraints, leftKeys)
+          val newRight = inferNewFilter(right, allConstraints, rightKeys)
           join.copy(left = newLeft, right = newRight)
 
         // For right outer join, we can only infer additional filters for left side.
         case RightOuter =>
           val allConstraints = getAllConstraints(left, right, conditionOpt)
-          val newLeft = inferNewFilter(left, allConstraints)
+          val newLeft = inferNewFilter(left, allConstraints, leftKeys)
           join.copy(left = newLeft)
 
         // For left join, we can only infer additional filters for right side.
         case LeftOuter | LeftAnti =>
           val allConstraints = getAllConstraints(left, right, conditionOpt)
-          val newRight = inferNewFilter(right, allConstraints)
+          val newRight = inferNewFilter(right, allConstraints, rightKeys)
           join.copy(right = newRight)
 
         case _ => join
@@ -1261,9 +1287,13 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
     baseConstraints.union(inferAdditionalConstraints(baseConstraints))
   }
 
-  private def inferNewFilter(plan: LogicalPlan, constraints: ExpressionSet): LogicalPlan = {
+  private def inferNewFilter(
+      plan: LogicalPlan,
+      constraints: ExpressionSet,
+      joinKeys: mutable.HashSet[Expression]): LogicalPlan = {
     val newPredicates = constraints
       .union(constructIsNotNullConstraints(constraints, plan.output))
+      .union(ExpressionSet(joinKeys.map(IsNotNull)))
       .filter { c =>
         c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic
       } -- plan.constraints
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 72ad6ca24c1f1..08d416fada5f8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -271,12 +271,15 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     val originalLeft = testRelation1.subquery('left)
     val originalRight = testRelation2.where('b === 1L).subquery('right)
 
-    val left = testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left)
-    val right = testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right)
-
     Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
       Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
-      testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
+      testConstraintsAfterJoin(
+        originalLeft,
+        originalRight,
+        testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left),
+        testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right),
+        Inner,
+        condition)
     }
 
     Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
@@ -285,7 +288,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
         originalLeft,
         originalRight,
         testRelation1.where(IsNotNull('a)).subquery('left),
-        right,
+        testRelation2.where(IsNotNull('b) && IsNotNull('b.cast(IntegerType)) &&
+          'b === 1L).subquery('right),
         Inner,
         condition)
     }
@@ -302,7 +306,13 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
 
     Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
       Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
-      testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
+      testConstraintsAfterJoin(
+        originalLeft,
+        originalRight,
+        testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left),
+        testRelation2.where(IsNotNull('b)).subquery('right),
+        Inner,
+        condition)
     }
 
     Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
@@ -310,8 +320,9 @@
       testConstraintsAfterJoin(
         originalLeft,
         originalRight,
-        left,
-        testRelation2.where(IsNotNull('b) && 'b.attr.cast(IntegerType) === 1).subquery('right),
+        testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left),
+        testRelation2.where(IsNotNull('b) && IsNotNull('b.cast(IntegerType)) &&
+          'b.attr.cast(IntegerType) === 1).subquery('right),
         Inner,
         condition)
     }
@@ -361,4 +372,23 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     val optimized = Optimize.execute(originalQuery)
     comparePlans(optimized, correctAnswer)
   }
+
+  test("SPARK-31809: Infer IsNotNull for join condition") {
+    testConstraintsAfterJoin(
+      testRelation.subquery('left),
+      testRelation.subquery('right),
+      testRelation.where(IsNotNull('a.cast(StringType).cast(DoubleType)) && IsNotNull('a))
+        .subquery('left),
+      testRelation.where(IsNotNull('c)).subquery('right),
+      Inner,
+      Some("left.a".attr.cast(StringType).cast(DoubleType) === "right.c".attr.cast(DoubleType)))
+
+    testConstraintsAfterJoin(
+      testRelation.subquery('left),
+      testRelation.subquery('right),
+      testRelation.where(IsNotNull(Coalesce(Seq('a, 'b)))).subquery('left),
+      testRelation.where(IsNotNull('c)).subquery('right),
+      Inner,
+      Some(Coalesce(Seq("left.a".attr, "left.b".attr)) === "right.c".attr))
+  }
 }

From 7796e5ceed08209f7ea6e4da15659418913a84a1 Mon Sep 17 00:00:00 2001
From: Yuming Wang
Date: Wed, 27 Oct 2021 21:23:51 +0800
Subject: [PATCH 2/4] Fix

---
 .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b95a1350aae07..0016da1f64e5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1239,7 +1239,6 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
       val rightKeys = new mutable.HashSet[Expression]
       conditionOpt.foreach { condition =>
         splitConjunctivePredicates(condition).foreach {
-          case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty =>
           case EqualTo(l, r) =>
             if (resultMayBeNull(l)) {
               if (canEvaluate(l, left)) leftKeys.add(l)

From c88566ae4c9b1d14cfc732031d14fca331e7f51f Mon Sep 17 00:00:00 2001
From: Yuming Wang
Date: Thu, 28 Oct 2021 08:15:17 +0800
Subject: [PATCH 3/4] Fix test in JoinSuite

---
 .../sql/catalyst/optimizer/Optimizer.scala    |  6 +++---
 .../InferFiltersFromConstraintsSuite.scala    | 19 ++++++++++++++-----
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0016da1f64e5c..68834e76233be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1217,9 +1217,9 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
 
   // Whether the result of this expression may be null, e.g. CAST(strCol AS double).
   // We will infer an IsNotNull predicate for such an expression to avoid a skewed join.
-  private def resultMayBeNull(e: Expression): Boolean = e match {
-    case Cast(child, dataType, _, _) => !Cast.canUpCast(child.dataType, dataType)
-    case _: Coalesce => true
+  private def resultMayBeNull(exp: Expression): Boolean = exp match {
+    case Cast(child: Attribute, dataType, _, _) => !Cast.canUpCast(child.dataType, dataType)
+    case c: Coalesce if c.children.forall(_.isInstanceOf[Attribute]) => true
     case _ => false
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 08d416fada5f8..3d5d2ff37ecf7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -374,14 +374,23 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
   }
 
   test("SPARK-31809: Infer IsNotNull for join condition") {
+    val testRelation2 = LocalRelation('a.string, 'b.int)
+
     testConstraintsAfterJoin(
       testRelation.subquery('left),
-      testRelation.subquery('right),
-      testRelation.where(IsNotNull('a.cast(StringType).cast(DoubleType)) && IsNotNull('a))
-        .subquery('left),
-      testRelation.where(IsNotNull('c)).subquery('right),
+      testRelation2.subquery('right),
+      testRelation.where(IsNotNull('a)).subquery('left),
+      testRelation2.where(IsNotNull('a.cast(IntegerType)) && IsNotNull('a)).subquery('right),
       Inner,
-      Some("left.a".attr.cast(StringType).cast(DoubleType) === "right.c".attr.cast(DoubleType)))
+      Some("left.a".attr === "right.a".attr))
+
+    testConstraintsAfterJoin(
+      testRelation.subquery('left),
+      testRelation2.subquery('right),
+      testRelation.where(IsNotNull('a)).subquery('left),
+      testRelation2.subquery('right),
+      RightOuter,
+      Some("left.a".attr === "right.a".attr))
 
     testConstraintsAfterJoin(
       testRelation.subquery('left),

From 919492ee6095a007077119a11425abd7f08b02f6 Mon Sep 17 00:00:00 2001
From: Yuming Wang
Date: Thu, 28 Oct 2021 14:59:44 +0800
Subject: [PATCH 4/4] nullability

---
 .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 68834e76233be..4c41042c53853 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1218,6 +1218,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
   // Whether the result of this expression may be null, e.g. CAST(strCol AS double).
   // We will infer an IsNotNull predicate for such an expression to avoid a skewed join.
   private def resultMayBeNull(exp: Expression): Boolean = exp match {
+    case e if !e.nullable => false
     case Cast(child: Attribute, dataType, _, _) => !Cast.canUpCast(child.dataType, dataType)
     case c: Coalesce if c.children.forall(_.isInstanceOf[Attribute]) => true
     case _ => false
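For context, a minimal sketch of the query shape this series targets. The relation names t1/t2 and the sample data are illustrative only and are not part of the patch, and the expected plan change is a hedged reading of the rule above rather than output copied from a real run.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // CAST(s AS DOUBLE) can evaluate to null even when s itself is not null ("oops" does not
    // parse), and string -> double is not an up-cast, so resultMayBeNull returns true for it.
    val t1 = Seq("1.5", "oops", "2.0").toDF("s")
    val t2 = Seq(1.5, 2.0).toDF("d")

    // The left join key is an expression rather than a plain attribute, so the pre-existing
    // constraint logic infers IsNotNull(s) and IsNotNull(d) but not IsNotNull(CAST(s AS DOUBLE)).
    val joined = t1.join(t2, t1("s").cast("double") === t2("d"))

    // With this series applied, the optimized plan should additionally contain
    // Filter isnotnull(cast(s as double)) below the join's left side, so rows whose join key
    // evaluates to null are dropped before the shuffle instead of piling up in a single task.
    joined.explain(true)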