diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 73be7902b998e..4c41042c53853 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1215,6 +1215,15 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] } } + // Whether the result of this expression may be null. For example: CAST(strCol AS double) + // We will infer an IsNotNull expression for this expression to avoid skew join. + private def resultMayBeNull(exp: Expression): Boolean = exp match { + case e if !e.nullable => false + case Cast(child: Attribute, dataType, _, _) => !Cast.canUpCast(child.dataType, dataType) + case c: Coalesce if c.children.forall(_.isInstanceOf[Attribute]) => true + case _ => false + } + private def inferFilters(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsAnyPattern(FILTER, JOIN)) { case filter @ Filter(condition, child) => @@ -1227,25 +1236,42 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] } case join @ Join(left, right, joinType, conditionOpt, _) => + val leftKeys = new mutable.HashSet[Expression] + val rightKeys = new mutable.HashSet[Expression] + conditionOpt.foreach { condition => + splitConjunctivePredicates(condition).foreach { + case EqualTo(l, r) => + if (resultMayBeNull(l)) { + if (canEvaluate(l, left)) leftKeys.add(l) + if (canEvaluate(l, right)) rightKeys.add(l) + } + if (resultMayBeNull(r)) { + if (canEvaluate(r, left)) leftKeys.add(r) + if (canEvaluate(r, right)) rightKeys.add(r) + } + case _ => + } + } + joinType match { // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an // inner join, it just drops the right side in the final output. case _: InnerLike | LeftSemi => val allConstraints = getAllConstraints(left, right, conditionOpt) - val newLeft = inferNewFilter(left, allConstraints) - val newRight = inferNewFilter(right, allConstraints) + val newLeft = inferNewFilter(left, allConstraints, leftKeys) + val newRight = inferNewFilter(right, allConstraints, rightKeys) join.copy(left = newLeft, right = newRight) // For right outer join, we can only infer additional filters for left side. case RightOuter => val allConstraints = getAllConstraints(left, right, conditionOpt) - val newLeft = inferNewFilter(left, allConstraints) + val newLeft = inferNewFilter(left, allConstraints, leftKeys) join.copy(left = newLeft) // For left join, we can only infer additional filters for right side. case LeftOuter | LeftAnti => val allConstraints = getAllConstraints(left, right, conditionOpt) - val newRight = inferNewFilter(right, allConstraints) + val newRight = inferNewFilter(right, allConstraints, rightKeys) join.copy(right = newRight) case _ => join @@ -1261,9 +1287,13 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] baseConstraints.union(inferAdditionalConstraints(baseConstraints)) } - private def inferNewFilter(plan: LogicalPlan, constraints: ExpressionSet): LogicalPlan = { + private def inferNewFilter( + plan: LogicalPlan, + constraints: ExpressionSet, + joinKeys: mutable.HashSet[Expression]): LogicalPlan = { val newPredicates = constraints .union(constructIsNotNullConstraints(constraints, plan.output)) + .union(ExpressionSet(joinKeys.map(IsNotNull))) .filter { c => c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic } -- plan.constraints diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 72ad6ca24c1f1..3d5d2ff37ecf7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -271,12 +271,15 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val originalLeft = testRelation1.subquery('left) val originalRight = testRelation2.where('b === 1L).subquery('right) - val left = testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left) - val right = testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right) - Seq(Some("left.a".attr.cast(LongType) === "right.b".attr), Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition => - testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition) + testConstraintsAfterJoin( + originalLeft, + originalRight, + testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left), + testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right), + Inner, + condition) } Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)), @@ -285,7 +288,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest { originalLeft, originalRight, testRelation1.where(IsNotNull('a)).subquery('left), - right, + testRelation2.where(IsNotNull('b) && IsNotNull('b.cast(IntegerType)) && + 'b === 1L).subquery('right), Inner, condition) } @@ -302,7 +306,13 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Seq(Some("left.a".attr.cast(LongType) === "right.b".attr), Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition => - testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition) + testConstraintsAfterJoin( + originalLeft, + originalRight, + testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left), + testRelation2.where(IsNotNull('b)).subquery('right), + Inner, + condition) } Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)), @@ -310,8 +320,9 @@ class InferFiltersFromConstraintsSuite extends PlanTest { testConstraintsAfterJoin( originalLeft, originalRight, - left, - testRelation2.where(IsNotNull('b) && 'b.attr.cast(IntegerType) === 1).subquery('right), + testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left), + testRelation2.where(IsNotNull('b) && IsNotNull('b.cast(IntegerType)) && + 'b.attr.cast(IntegerType) === 1).subquery('right), Inner, condition) } @@ -361,4 +372,32 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } + + test("SPARK-31809: Infer IsNotNull for join condition") { + val testRelation2 = LocalRelation('a.string, 'b.int) + + testConstraintsAfterJoin( + testRelation.subquery('left), + testRelation2.subquery('right), + testRelation.where(IsNotNull('a)).subquery('left), + testRelation2.where(IsNotNull('a.cast(IntegerType)) && IsNotNull('a)).subquery('right), + Inner, + Some("left.a".attr === "right.a".attr)) + + testConstraintsAfterJoin( + testRelation.subquery('left), + testRelation2.subquery('right), + testRelation.where(IsNotNull('a)).subquery('left), + testRelation2.subquery('right), + RightOuter, + Some("left.a".attr === "right.a".attr)) + + testConstraintsAfterJoin( + testRelation.subquery('left), + testRelation.subquery('right), + testRelation.where(IsNotNull(Coalesce(Seq('a, 'b)))).subquery('left), + testRelation.where(IsNotNull('c)).subquery('right), + Inner, + Some(Coalesce(Seq("left.a".attr, "left.b".attr)) === "right.c".attr)) + } }