Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ case class SortMergeJoinExec(
}

private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
case _: InnerLike => ((left, leftKeys), (right, rightKeys))
case _: InnerLike | LeftOuter => ((left, leftKeys), (right, rightKeys))
case RightOuter => ((right, rightKeys), (left, leftKeys))
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.streamedPlan/bufferedPlan should not take $x as the JoinType")
Expand All @@ -363,8 +364,9 @@ case class SortMergeJoinExec(
private lazy val streamedOutput = streamedPlan.output
private lazy val bufferedOutput = bufferedPlan.output

override def supportCodegen: Boolean = {
joinType.isInstanceOf[InnerLike]
override def supportCodegen: Boolean = joinType match {
case _: InnerLike | LeftOuter | RightOuter => true
case _ => false
}

override def inputRDDs(): Seq[RDD[InternalRow]] = {
Expand Down Expand Up @@ -431,6 +433,69 @@ case class SortMergeJoinExec(
// Copy the streamed keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, streamedKeyVars)

// Handle the case when streamed rows has any NULL keys.
val handleStreamedAnyNull = joinType match {
case _: InnerLike =>
// Skip streamed row.
s"""
|$streamedRow = null;
|continue;
""".stripMargin
case LeftOuter | RightOuter =>
// Eagerly return streamed row. Only call `matches.clear()` when `matches.isEmpty()` is
// false, to reduce unnecessary computation.
s"""
|if (!$matches.isEmpty()) {
| $matches.clear();
|}
|return false;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

        // Eagerly return streamed row.
        s"""
           |$matches.clear();
           |return false;
         """.stripMargin

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wanted to avoid clear() if isEmpty() is true. ExternalAppendOnlyUnsafeRowArray.isEmpty() is very cheap but clear() sets multiple variables.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Could you leave some comments about it there?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maropu - added comment.

""".stripMargin
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.genScanner should not take $x as the JoinType")
}

// Handle the case when streamed keys has no match with buffered side.
val handleStreamedWithoutMatch = joinType match {
case _: InnerLike =>
// Skip streamed row.
s"$streamedRow = null;"
case LeftOuter | RightOuter =>
// Eagerly return with streamed row.
"return false;"
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.genScanner should not take $x as the JoinType")
}

// Generate a function to scan both streamed and buffered sides to find a match.
// Return whether a match is found.
//
// `streamedIter`: the iterator for streamed side.
// `bufferedIter`: the iterator for buffered side.
// `streamedRow`: the current row from streamed side.
// When `streamedIter` is empty, `streamedRow` is null.
// `matches`: the rows from buffered side already matched with `streamedRow`.
// `matches` is buffered and reused for all `streamedRow`s having same join keys.
// If there is no match with `streamedRow`, `matches` is empty.
// `bufferedRow`: the current matched row from buffered side.
//
// The function has the following step:
// - Step 1: Find the next `streamedRow` with non-null join keys.
// For `streamedRow` with null join keys (`handleStreamedAnyNull`):
// 1. Inner join: skip the row. `matches` will be cleared later when hitting the
// next `streamedRow` with non-null join keys.
// 2. Left/Right Outer join: clear the previous `matches` if needed, keep the row,
// and return false.
//
// - Step 2: Find the `matches` from buffered side having same join keys with `streamedRow`.
// Clear `matches` if we hit a new `streamedRow`, as we need to find new matches.
// Use `bufferedRow` to iterate buffered side to put all matched rows into
// `matches`. Return true when getting all matched rows.
// For `streamedRow` without `matches` (`handleStreamedWithoutMatch`):
// 1. Inner join: skip the row.
// 2. Left/Right Outer join: keep the row and return false (with `matches` being
// empty).
ctx.addNewFunction("findNextJoinRows",
s"""
|private boolean findNextJoinRows(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the outer case, a return value is not used?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks reusing the inner-case code makes the outer-case code inefficient. For example, if there are too many matched duplicate rows in the buffered side, it seems we don't need to put all the rows in matches, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the outer case, a return value is not used?

Yes. Otherwise it's very hard to re-use code in findNextJoinRows. I can further make more change to not return anything for findNextJoinRows in case it's an outer join. Do we want to do that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, if there are too many matched duplicate rows in the buffered side, it seems we don't need to put all the rows in matches, right?

Why we don't need to put all the rows? We anyway need to evaluate all the rows on buffered side for join, right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we don't need to put all the rows? We anyway need to evaluate all the rows on buffered side for join, right?

Oh, my bad. ya, you're right. I misunderstood it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the outer case, a return value is not used?
Yes. Otherwise it's very hard to re-use code in findNextJoinRows. I can further make more change to not return anything for findNextJoinRows in case it's an outer join. Do we want to do that?

okay, the current one looks fine. Let's just wait for a @cloud-fan comment here.

Copy link
Member

@maropu maropu May 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw, in the current generated code, it seems conditionCheck is evaluated outside findNextJoinRows. We cannot evaluate it inside findNextJoinRows to avoid putting unmached rows in matches?

Copy link
Contributor Author

@c21 c21 May 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maropu - No I think we need buffer anyway. The buffered rows has same join keys with current streamed row. But there can be multiple followed streamed rows having same join keys, as the buffered rows. Even though buffered rows cannot match condition with current streamed row, they may match condition with followed streamed rows. I think this is how current sort merge join (code-gen & iterator) is designed.

Expand All @@ -443,8 +508,7 @@ case class SortMergeJoinExec(
| $streamedRow = (InternalRow) streamedIter.next();
| ${streamedKeyVars.map(_.code).mkString("\n")}
| if ($streamedAnyNull) {
| $streamedRow = null;
| continue;
| $handleStreamedAnyNull
| }
| if (!$matches.isEmpty()) {
| ${genComparison(ctx, streamedKeyVars, matchedKeyVars)}
Expand Down Expand Up @@ -475,8 +539,9 @@ case class SortMergeJoinExec(
| if (!$matches.isEmpty()) {
| ${matchedKeyVars.map(_.code).mkString("\n")}
| return true;
| } else {
| $handleStreamedWithoutMatch
| }
| $streamedRow = null;
| } else {
| $matches.add((UnsafeRow) $bufferedRow);
| $bufferedRow = null;
Expand All @@ -501,7 +566,7 @@ case class SortMergeJoinExec(
ctx: CodegenContext,
streamedRow: String): (Seq[ExprCode], Seq[String]) = {
ctx.INPUT_ROW = streamedRow
left.output.zipWithIndex.map { case (a, i) =>
streamedPlan.output.zipWithIndex.map { case (a, i) =>
Copy link
Contributor Author

@c21 c21 May 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry forgot to change this in #32495, fix it now here. cc @maropu.

val value = ctx.freshName("value")
val valueCode = CodeGenerator.getValue(streamedRow, a.dataType, i.toString)
val javaType = CodeGenerator.javaType(a.dataType)
Expand Down Expand Up @@ -569,7 +634,15 @@ case class SortMergeJoinExec(

val iterator = ctx.freshName("iterator")
val numOutput = metricTerm(ctx, "numOutputRows")
val resultVars = streamedVars ++ bufferedVars
val resultVars = joinType match {
case _: InnerLike | LeftOuter =>
streamedVars ++ bufferedVars
case RightOuter =>
bufferedVars ++ streamedVars
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.doProduce should not take $x as the JoinType")
}

val (beforeLoop, condCheck) = if (condition.isDefined) {
// Split the code of creating variables based on whether it's used by condition or not.
Expand All @@ -580,21 +653,27 @@ case class SortMergeJoinExec(
ctx.currentVars = resultVars
val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
// evaluate the columns those used by condition before loop
val before = s"""
val before =
s"""
|boolean $loaded = false;
|$streamedBefore
""".stripMargin

val checking = s"""
|$bufferedBefore
|${cond.code}
|if (${cond.isNull} || !${cond.value}) continue;
|if (!$loaded) {
| $loaded = true;
| $streamedAfter
|}
|$bufferedAfter
""".stripMargin
val checking =
s"""
|$bufferedBefore
|if ($bufferedRow != null) {
| ${cond.code}
| if (${cond.isNull} || !${cond.value}) {
| continue;
| }
|}
|if (!$loaded) {
| $loaded = true;
| $streamedAfter
|}
|$bufferedAfter
""".stripMargin
(before, checking)
} else {
(evaluateVariables(streamedVars), "")
Expand All @@ -603,21 +682,55 @@ case class SortMergeJoinExec(
val thisPlan = ctx.addReferenceObj("plan", this)
val eagerCleanup = s"$thisPlan.cleanupResources();"

s"""
|while (findNextJoinRows($streamedInput, $bufferedInput)) {
| ${streamedVarDecl.mkString("\n")}
| ${beforeLoop.trim}
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
| while ($iterator.hasNext()) {
| InternalRow $bufferedRow = (InternalRow) $iterator.next();
| ${condCheck.trim}
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
lazy val innerJoin =
s"""
|while (findNextJoinRows($streamedInput, $bufferedInput)) {
| ${streamedVarDecl.mkString("\n")}
| ${beforeLoop.trim}
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
| while ($iterator.hasNext()) {
| InternalRow $bufferedRow = (InternalRow) $iterator.next();
| ${condCheck.trim}
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
""".stripMargin

lazy val outerJoin = {
val hasOutputRow = ctx.freshName("hasOutputRow")
s"""
|while ($streamedInput.hasNext()) {
| findNextJoinRows($streamedInput, $bufferedInput);
| ${streamedVarDecl.mkString("\n")}
| ${beforeLoop.trim}
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
| boolean $hasOutputRow = false;
|
| // the last iteration of this loop is to emit an empty row if there is no matched rows.
| while ($iterator.hasNext() || !$hasOutputRow) {
| InternalRow $bufferedRow = $iterator.hasNext() ?
| (InternalRow) $iterator.next() : null;
| ${condCheck.trim}
| $hasOutputRow = true;
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
| }
| if (shouldStop()) return;
|}
|$eagerCleanup
""".stripMargin
}

joinType match {
case _: InnerLike => innerJoin
case LeftOuter | RightOuter => outerJoin
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.doProduce should not take $x as the JoinType")
}
}

override protected def withNewChildrenInternal(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ TakeOrderedAndProject (36)
: :- * Project (20)
: : +- * BroadcastHashJoin Inner BuildRight (19)
: : :- * Project (13)
: : : +- SortMergeJoin LeftOuter (12)
: : : +- * SortMergeJoin LeftOuter (12)
: : : :- * Sort (5)
: : : : +- Exchange (4)
: : : : +- * Filter (3)
Expand Down Expand Up @@ -86,7 +86,7 @@ Arguments: hashpartitioning(cr_order_number#9, cr_item_sk#8, 5), ENSURE_REQUIREM
Input [3]: [cr_item_sk#8, cr_order_number#9, cr_refunded_cash#10]
Arguments: [cr_order_number#9 ASC NULLS FIRST, cr_item_sk#8 ASC NULLS FIRST], false, 0

(12) SortMergeJoin
(12) SortMergeJoin [codegen id : 8]
Left keys [2]: [cs_order_number#3, cs_item_sk#2]
Right keys [2]: [cr_order_number#9, cr_item_sk#8]
Join condition: None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ TakeOrderedAndProject [w_state,i_item_id,sales_before,sales_after]
Project [cs_warehouse_sk,cs_sales_price,cs_sold_date_sk,cr_refunded_cash,i_item_id]
BroadcastHashJoin [cs_item_sk,i_item_sk]
Project [cs_warehouse_sk,cs_item_sk,cs_sales_price,cs_sold_date_sk,cr_refunded_cash]
InputAdapter
SortMergeJoin [cs_order_number,cs_item_sk,cr_order_number,cr_item_sk]
SortMergeJoin [cs_order_number,cs_item_sk,cr_order_number,cr_item_sk]
InputAdapter
WholeStageCodegen (2)
Sort [cs_order_number,cs_item_sk]
InputAdapter
Expand All @@ -25,6 +25,7 @@ TakeOrderedAndProject [w_state,i_item_id,sales_before,sales_after]
Scan parquet default.catalog_sales [cs_warehouse_sk,cs_item_sk,cs_order_number,cs_sales_price,cs_sold_date_sk]
SubqueryBroadcast [d_date_sk] #1
ReusedExchange [d_date_sk,d_date] #3
InputAdapter
WholeStageCodegen (4)
Sort [cr_order_number,cr_item_sk]
InputAdapter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ TakeOrderedAndProject (36)
: :- * Project (19)
: : +- * BroadcastHashJoin Inner BuildRight (18)
: : :- * Project (13)
: : : +- SortMergeJoin LeftOuter (12)
: : : +- * SortMergeJoin LeftOuter (12)
: : : :- * Sort (5)
: : : : +- Exchange (4)
: : : : +- * Filter (3)
Expand Down Expand Up @@ -86,7 +86,7 @@ Arguments: hashpartitioning(cr_order_number#9, cr_item_sk#8, 5), ENSURE_REQUIREM
Input [3]: [cr_item_sk#8, cr_order_number#9, cr_refunded_cash#10]
Arguments: [cr_order_number#9 ASC NULLS FIRST, cr_item_sk#8 ASC NULLS FIRST], false, 0

(12) SortMergeJoin
(12) SortMergeJoin [codegen id : 8]
Left keys [2]: [cs_order_number#3, cs_item_sk#2]
Right keys [2]: [cr_order_number#9, cr_item_sk#8]
Join condition: None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ TakeOrderedAndProject [w_state,i_item_id,sales_before,sales_after]
Project [cs_item_sk,cs_sales_price,cs_sold_date_sk,cr_refunded_cash,w_state]
BroadcastHashJoin [cs_warehouse_sk,w_warehouse_sk]
Project [cs_warehouse_sk,cs_item_sk,cs_sales_price,cs_sold_date_sk,cr_refunded_cash]
InputAdapter
SortMergeJoin [cs_order_number,cs_item_sk,cr_order_number,cr_item_sk]
SortMergeJoin [cs_order_number,cs_item_sk,cr_order_number,cr_item_sk]
InputAdapter
WholeStageCodegen (2)
Sort [cs_order_number,cs_item_sk]
InputAdapter
Expand All @@ -25,6 +25,7 @@ TakeOrderedAndProject [w_state,i_item_id,sales_before,sales_after]
Scan parquet default.catalog_sales [cs_warehouse_sk,cs_item_sk,cs_order_number,cs_sales_price,cs_sold_date_sk]
SubqueryBroadcast [d_date_sk] #1
ReusedExchange [d_date_sk,d_date] #3
InputAdapter
WholeStageCodegen (4)
Sort [cr_order_number,cr_item_sk]
InputAdapter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ TakeOrderedAndProject (80)
+- Exchange (78)
+- * HashAggregate (77)
+- * Project (76)
+- SortMergeJoin LeftOuter (75)
+- * SortMergeJoin LeftOuter (75)
:- * Sort (68)
: +- Exchange (67)
: +- * Project (66)
Expand Down Expand Up @@ -410,7 +410,7 @@ Arguments: hashpartitioning(cr_item_sk#43, cr_order_number#44, 5), ENSURE_REQUIR
Input [2]: [cr_item_sk#43, cr_order_number#44]
Arguments: [cr_item_sk#43 ASC NULLS FIRST, cr_order_number#44 ASC NULLS FIRST], false, 0

(75) SortMergeJoin
(75) SortMergeJoin [codegen id : 20]
Left keys [2]: [cs_item_sk#4, cs_order_number#6]
Right keys [2]: [cr_item_sk#43, cr_order_number#44]
Join condition: None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ TakeOrderedAndProject [total_cnt,i_item_desc,w_warehouse_name,d_week_seq,no_prom
WholeStageCodegen (20)
HashAggregate [i_item_desc,w_warehouse_name,d_week_seq] [count,count]
Project [w_warehouse_name,i_item_desc,d_week_seq]
InputAdapter
SortMergeJoin [cs_item_sk,cs_order_number,cr_item_sk,cr_order_number]
SortMergeJoin [cs_item_sk,cs_order_number,cr_item_sk,cr_order_number]
InputAdapter
WholeStageCodegen (17)
Sort [cs_item_sk,cs_order_number]
InputAdapter
Expand Down Expand Up @@ -121,6 +121,7 @@ TakeOrderedAndProject [total_cnt,i_item_desc,w_warehouse_name,d_week_seq,no_prom
ColumnarToRow
InputAdapter
Scan parquet default.promotion [p_promo_sk]
InputAdapter
WholeStageCodegen (19)
Sort [cr_item_sk,cr_order_number]
InputAdapter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ TakeOrderedAndProject (74)
+- Exchange (72)
+- * HashAggregate (71)
+- * Project (70)
+- SortMergeJoin LeftOuter (69)
+- * SortMergeJoin LeftOuter (69)
:- * Sort (62)
: +- Exchange (61)
: +- * Project (60)
Expand Down Expand Up @@ -380,7 +380,7 @@ Arguments: hashpartitioning(cr_item_sk#41, cr_order_number#42, 5), ENSURE_REQUIR
Input [2]: [cr_item_sk#41, cr_order_number#42]
Arguments: [cr_item_sk#41 ASC NULLS FIRST, cr_order_number#42 ASC NULLS FIRST], false, 0

(69) SortMergeJoin
(69) SortMergeJoin [codegen id : 14]
Left keys [2]: [cs_item_sk#4, cs_order_number#6]
Right keys [2]: [cr_item_sk#41, cr_order_number#42]
Join condition: None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ TakeOrderedAndProject [total_cnt,i_item_desc,w_warehouse_name,d_week_seq,no_prom
WholeStageCodegen (14)
HashAggregate [i_item_desc,w_warehouse_name,d_week_seq] [count,count]
Project [w_warehouse_name,i_item_desc,d_week_seq]
InputAdapter
SortMergeJoin [cs_item_sk,cs_order_number,cr_item_sk,cr_order_number]
SortMergeJoin [cs_item_sk,cs_order_number,cr_item_sk,cr_order_number]
InputAdapter
WholeStageCodegen (11)
Sort [cs_item_sk,cs_order_number]
InputAdapter
Expand Down Expand Up @@ -103,6 +103,7 @@ TakeOrderedAndProject [total_cnt,i_item_desc,w_warehouse_name,d_week_seq,no_prom
ColumnarToRow
InputAdapter
Scan parquet default.promotion [p_promo_sk]
InputAdapter
WholeStageCodegen (13)
Sort [cr_item_sk,cr_order_number]
InputAdapter
Expand Down
Loading