Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1215,6 +1215,15 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
}
}

// Whether the result of this expression may be null. For example: CAST(strCol AS double)
// We will infer an IsNotNull expression for this expression to avoid skew join.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it better to infer IsNotNull(col) instead of IsNotNull(CAST(col AS other_type))?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can infer IsNotNull(col) already. For example:

spark.sql("create table t1 (id string, value int) using parquet")
spark.sql("create table t2 (id int, value int) using parquet")

spark.sql("select * from t1 join t2 on t1.id = t2.id").explain("extended")

Before this pr:

== Optimized Logical Plan ==
Join Inner, (cast(id#0 as int) = id#2)
:- Filter isnotnull(id#0)
:  +- Relation default.t1[id#0,value#1] parquet
+- Filter isnotnull(id#2)
   +- Relation default.t2[id#2,value#3] parquet

After this pr:

== Optimized Logical Plan ==
Join Inner, (cast(id#0 as int) = id#2)
:- Filter (isnotnull(id#0) AND isnotnull(cast(id#0 as int)))
:  +- Relation default.t1[id#0,value#1] parquet
+- Filter isnotnull(id#2)
   +- Relation default.t2[id#2,value#3] parquet

Infer isnotnull(cast(t1.id as int)) may filter out many strings that can not be casted to int.

private def resultMayBeNull(exp: Expression): Boolean = exp match {
case e if !e.nullable => false
case Cast(child: Attribute, dataType, _, _) => !Cast.canUpCast(child.dataType, dataType)
case c: Coalesce if c.children.forall(_.isInstanceOf[Attribute]) => true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we rely on the NullIntolerant interface?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can infer NullIntolerant already. For example:

spark.sql("create table t1 (id string, value int) using parquet")
spark.sql("create table t2 (id int, value int) using parquet")

spark.sql("select * from t1 join t2 on t1.id = t2.id").explain("extended")

== Optimized Logical Plan ==
Join Inner, (cast(id#0 as int) = id#2)
:- Filter isnotnull(id#0)
:  +- Relation default.t1[id#0,value#1] parquet
+- Filter isnotnull(id#2)
   +- Relation default.t2[id#2,value#3] parquet

Cast is NullIntolerant. We can infer IsNotNull(t1.id) already. But I also want to Infer isnotnull(cast(t1.id as int)) because t1.id may contains many strings that can not be casted to int.

case _ => false
}

private def inferFilters(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsAnyPattern(FILTER, JOIN)) {
case filter @ Filter(condition, child) =>
Expand All @@ -1227,25 +1236,42 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
}

case join @ Join(left, right, joinType, conditionOpt, _) =>
val leftKeys = new mutable.HashSet[Expression]
val rightKeys = new mutable.HashSet[Expression]
conditionOpt.foreach { condition =>
splitConjunctivePredicates(condition).foreach {
case EqualTo(l, r) =>
if (resultMayBeNull(l)) {
if (canEvaluate(l, left)) leftKeys.add(l)
if (canEvaluate(l, right)) rightKeys.add(l)
}
if (resultMayBeNull(r)) {
if (canEvaluate(r, left)) leftKeys.add(r)
if (canEvaluate(r, right)) rightKeys.add(r)
}
case _ =>
}
}

joinType match {
// For inner join, we can infer additional filters for both sides. LeftSemi is kind of an
// inner join, it just drops the right side in the final output.
case _: InnerLike | LeftSemi =>
val allConstraints = getAllConstraints(left, right, conditionOpt)
val newLeft = inferNewFilter(left, allConstraints)
val newRight = inferNewFilter(right, allConstraints)
val newLeft = inferNewFilter(left, allConstraints, leftKeys)
val newRight = inferNewFilter(right, allConstraints, rightKeys)
join.copy(left = newLeft, right = newRight)

// For right outer join, we can only infer additional filters for left side.
case RightOuter =>
val allConstraints = getAllConstraints(left, right, conditionOpt)
val newLeft = inferNewFilter(left, allConstraints)
val newLeft = inferNewFilter(left, allConstraints, leftKeys)
join.copy(left = newLeft)

// For left join, we can only infer additional filters for right side.
case LeftOuter | LeftAnti =>
val allConstraints = getAllConstraints(left, right, conditionOpt)
val newRight = inferNewFilter(right, allConstraints)
val newRight = inferNewFilter(right, allConstraints, rightKeys)
join.copy(right = newRight)

case _ => join
Expand All @@ -1261,9 +1287,13 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
baseConstraints.union(inferAdditionalConstraints(baseConstraints))
}

private def inferNewFilter(plan: LogicalPlan, constraints: ExpressionSet): LogicalPlan = {
private def inferNewFilter(
plan: LogicalPlan,
constraints: ExpressionSet,
joinKeys: mutable.HashSet[Expression]): LogicalPlan = {
val newPredicates = constraints
.union(constructIsNotNullConstraints(constraints, plan.output))
.union(ExpressionSet(joinKeys.map(IsNotNull)))
.filter { c =>
c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic
} -- plan.constraints
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,15 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
val originalLeft = testRelation1.subquery('left)
val originalRight = testRelation2.where('b === 1L).subquery('right)

val left = testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left)
val right = testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right)

Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
testConstraintsAfterJoin(
originalLeft,
originalRight,
testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left),
testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right),
Inner,
condition)
}

Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
Expand All @@ -285,7 +288,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
originalLeft,
originalRight,
testRelation1.where(IsNotNull('a)).subquery('left),
right,
testRelation2.where(IsNotNull('b) && IsNotNull('b.cast(IntegerType)) &&
'b === 1L).subquery('right),
Inner,
condition)
}
Expand All @@ -302,16 +306,23 @@ class InferFiltersFromConstraintsSuite extends PlanTest {

Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
testConstraintsAfterJoin(
originalLeft,
originalRight,
testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left),
testRelation2.where(IsNotNull('b)).subquery('right),
Inner,
condition)
}

Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
testConstraintsAfterJoin(
originalLeft,
originalRight,
left,
testRelation2.where(IsNotNull('b) && 'b.attr.cast(IntegerType) === 1).subquery('right),
testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left),
testRelation2.where(IsNotNull('b) && IsNotNull('b.cast(IntegerType)) &&
'b.attr.cast(IntegerType) === 1).subquery('right),
Inner,
condition)
}
Expand Down Expand Up @@ -361,4 +372,32 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("SPARK-31809: Infer IsNotNull for join condition") {
val testRelation2 = LocalRelation('a.string, 'b.int)

testConstraintsAfterJoin(
testRelation.subquery('left),
testRelation2.subquery('right),
testRelation.where(IsNotNull('a)).subquery('left),
testRelation2.where(IsNotNull('a.cast(IntegerType)) && IsNotNull('a)).subquery('right),
Inner,
Some("left.a".attr === "right.a".attr))

testConstraintsAfterJoin(
testRelation.subquery('left),
testRelation2.subquery('right),
testRelation.where(IsNotNull('a)).subquery('left),
testRelation2.subquery('right),
RightOuter,
Some("left.a".attr === "right.a".attr))

testConstraintsAfterJoin(
testRelation.subquery('left),
testRelation.subquery('right),
testRelation.where(IsNotNull(Coalesce(Seq('a, 'b)))).subquery('left),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hive> EXPLAIN SELECT t1.* FROM t1 JOIN t2 ON coalesce(t1.a, t1.b)=t2.a;
OK
STAGE DEPENDENCIES:
  Stage-4 is a root stage
  Stage-3 depends on stages: Stage-4
  Stage-0 depends on stages: Stage-3

STAGE PLANS:
  Stage: Stage-4
    Map Reduce Local Work
      Alias -> Map Local Tables:
        $hdt$_0:t1
          Fetch Operator
            limit: -1
      Alias -> Map Local Operator Tree:
        $hdt$_0:t1
          TableScan
            alias: t1
            Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: NONE
            Filter Operator
              predicate: COALESCE(a,b) is not null (type: boolean)
              Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: NONE
              Select Operator
                expressions: a (type: string), b (type: string), c (type: string)
                outputColumnNames: _col0, _col1, _col2
                Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: NONE
                HashTable Sink Operator
                  keys:
                    0 COALESCE(_col0,_col1) (type: string)
                    1 _col0 (type: string)

  Stage: Stage-3
    Map Reduce
      Map Operator Tree:
          TableScan
            alias: t2
            Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: NONE
            Filter Operator
              predicate: a is not null (type: boolean)
              Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: NONE
              Select Operator
                expressions: a (type: string)
                outputColumnNames: _col0
                Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: NONE
                Map Join Operator
                  condition map:
                       Inner Join 0 to 1
                  keys:
                    0 COALESCE(_col0,_col1) (type: string)
                    1 _col0 (type: string)
                  outputColumnNames: _col0, _col1, _col2
                  Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: NONE
                  File Output Operator
                    compressed: false
                    Statistics: Num rows: 1 Data size: 0 Basic stats: PARTIAL Column stats: NONE
                    table:
                        input format: org.apache.hadoop.mapred.SequenceFileInputFormat
                        output format: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat
                        serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe
      Execution mode: vectorized
      Local Work:
        Map Reduce Local Work

  Stage: Stage-0
    Fetch Operator
      limit: -1
      Processor Tree:
        ListSink

testRelation.where(IsNotNull('c)).subquery('right),
Inner,
Some(Coalesce(Seq("left.a".attr, "left.b".attr)) === "right.c".attr))
}
}