-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-35349][SQL] Add code-gen for left/right outer sort merge join #32476
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
982af12
4a57664
44b210f
765b247
617f89c
429edcc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -354,7 +354,8 @@ case class SortMergeJoinExec( | |
| } | ||
|
|
||
| private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match { | ||
| case _: InnerLike => ((left, leftKeys), (right, rightKeys)) | ||
| case _: InnerLike | LeftOuter => ((left, leftKeys), (right, rightKeys)) | ||
| case RightOuter => ((right, rightKeys), (left, leftKeys)) | ||
| case x => | ||
| throw new IllegalArgumentException( | ||
| s"SortMergeJoin.streamedPlan/bufferedPlan should not take $x as the JoinType") | ||
|
|
@@ -363,8 +364,9 @@ case class SortMergeJoinExec( | |
| private lazy val streamedOutput = streamedPlan.output | ||
| private lazy val bufferedOutput = bufferedPlan.output | ||
|
|
||
| override def supportCodegen: Boolean = { | ||
| joinType.isInstanceOf[InnerLike] | ||
| override def supportCodegen: Boolean = joinType match { | ||
| case _: InnerLike | LeftOuter | RightOuter => true | ||
| case _ => false | ||
| } | ||
|
|
||
| override def inputRDDs(): Seq[RDD[InternalRow]] = { | ||
|
|
@@ -431,6 +433,69 @@ case class SortMergeJoinExec( | |
| // Copy the streamed keys as class members so they could be used in next function call. | ||
| val matchedKeyVars = copyKeys(ctx, streamedKeyVars) | ||
|
|
||
| // Handle the case when the streamed row has any NULL keys. | ||
| val handleStreamedAnyNull = joinType match { | ||
| case _: InnerLike => | ||
| // Skip streamed row. | ||
| s""" | ||
| |$streamedRow = null; | ||
| |continue; | ||
| """.stripMargin | ||
| case LeftOuter | RightOuter => | ||
| // Eagerly return streamed row. Only call `matches.clear()` when `matches.isEmpty()` is | ||
| // false, to reduce unnecessary computation. | ||
| s""" | ||
| |if (!$matches.isEmpty()) { | ||
| | $matches.clear(); | ||
| |} | ||
| |return false; | ||
| """.stripMargin | ||
| case x => | ||
| throw new IllegalArgumentException( | ||
| s"SortMergeJoin.genScanner should not take $x as the JoinType") | ||
| } | ||
|
|
||
| // Handle the case when the streamed keys have no match with the buffered side. | ||
| val handleStreamedWithoutMatch = joinType match { | ||
| case _: InnerLike => | ||
| // Skip streamed row. | ||
| s"$streamedRow = null;" | ||
| case LeftOuter | RightOuter => | ||
| // Eagerly return with streamed row. | ||
| "return false;" | ||
| case x => | ||
| throw new IllegalArgumentException( | ||
| s"SortMergeJoin.genScanner should not take $x as the JoinType") | ||
| } | ||
|
|
||
| // Generate a function to scan both streamed and buffered sides to find a match. | ||
| // Return whether a match is found. | ||
| // | ||
| // `streamedIter`: the iterator for streamed side. | ||
| // `bufferedIter`: the iterator for buffered side. | ||
| // `streamedRow`: the current row from streamed side. | ||
| // When `streamedIter` is empty, `streamedRow` is null. | ||
| // `matches`: the rows from buffered side already matched with `streamedRow`. | ||
| // `matches` is buffered and reused for all `streamedRow`s having same join keys. | ||
| // If there is no match with `streamedRow`, `matches` is empty. | ||
| // `bufferedRow`: the current matched row from buffered side. | ||
| // | ||
| // The function has the following steps: | ||
| // - Step 1: Find the next `streamedRow` with non-null join keys. | ||
| // For `streamedRow` with null join keys (`handleStreamedAnyNull`): | ||
| // 1. Inner join: skip the row. `matches` will be cleared later when hitting the | ||
| // next `streamedRow` with non-null join keys. | ||
| // 2. Left/Right Outer join: clear the previous `matches` if needed, keep the row, | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| // and return false. | ||
| // | ||
| // - Step 2: Find the `matches` from buffered side having same join keys with `streamedRow`. | ||
| // Clear `matches` if we hit a new `streamedRow`, as we need to find new matches. | ||
| // Use `bufferedRow` to iterate buffered side to put all matched rows into | ||
| // `matches`. Return true when getting all matched rows. | ||
| // For `streamedRow` without `matches` (`handleStreamedWithoutMatch`): | ||
| // 1. Inner join: skip the row. | ||
| // 2. Left/Right Outer join: keep the row and return false (with `matches` being | ||
| // empty). | ||
| ctx.addNewFunction("findNextJoinRows", | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| s""" | ||
| |private boolean findNextJoinRows( | ||
|
||
|
|
@@ -443,8 +508,7 @@ case class SortMergeJoinExec( | |
| | $streamedRow = (InternalRow) streamedIter.next(); | ||
| | ${streamedKeyVars.map(_.code).mkString("\n")} | ||
| | if ($streamedAnyNull) { | ||
| | $streamedRow = null; | ||
| | continue; | ||
| | $handleStreamedAnyNull | ||
| | } | ||
| | if (!$matches.isEmpty()) { | ||
| | ${genComparison(ctx, streamedKeyVars, matchedKeyVars)} | ||
|
|
@@ -475,8 +539,9 @@ case class SortMergeJoinExec( | |
| | if (!$matches.isEmpty()) { | ||
| | ${matchedKeyVars.map(_.code).mkString("\n")} | ||
| | return true; | ||
| | } else { | ||
| | $handleStreamedWithoutMatch | ||
| | } | ||
| | $streamedRow = null; | ||
| | } else { | ||
| | $matches.add((UnsafeRow) $bufferedRow); | ||
| | $bufferedRow = null; | ||
|
|
@@ -501,7 +566,7 @@ case class SortMergeJoinExec( | |
| ctx: CodegenContext, | ||
| streamedRow: String): (Seq[ExprCode], Seq[String]) = { | ||
| ctx.INPUT_ROW = streamedRow | ||
| left.output.zipWithIndex.map { case (a, i) => | ||
| streamedPlan.output.zipWithIndex.map { case (a, i) => | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| val value = ctx.freshName("value") | ||
| val valueCode = CodeGenerator.getValue(streamedRow, a.dataType, i.toString) | ||
| val javaType = CodeGenerator.javaType(a.dataType) | ||
|
|
@@ -569,7 +634,15 @@ case class SortMergeJoinExec( | |
|
|
||
| val iterator = ctx.freshName("iterator") | ||
| val numOutput = metricTerm(ctx, "numOutputRows") | ||
| val resultVars = streamedVars ++ bufferedVars | ||
| val resultVars = joinType match { | ||
| case _: InnerLike | LeftOuter => | ||
| streamedVars ++ bufferedVars | ||
| case RightOuter => | ||
| bufferedVars ++ streamedVars | ||
| case x => | ||
| throw new IllegalArgumentException( | ||
| s"SortMergeJoin.doProduce should not take $x as the JoinType") | ||
| } | ||
|
|
||
| val (beforeLoop, condCheck) = if (condition.isDefined) { | ||
| // Split the code of creating variables based on whether it's used by condition or not. | ||
|
|
@@ -580,21 +653,27 @@ case class SortMergeJoinExec( | |
| ctx.currentVars = resultVars | ||
| val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) | ||
| // evaluate the columns that are used by the condition before the loop | ||
| val before = s""" | ||
| val before = | ||
| s""" | ||
| |boolean $loaded = false; | ||
| |$streamedBefore | ||
| """.stripMargin | ||
|
|
||
| val checking = s""" | ||
| |$bufferedBefore | ||
| |${cond.code} | ||
| |if (${cond.isNull} || !${cond.value}) continue; | ||
| |if (!$loaded) { | ||
| | $loaded = true; | ||
| | $streamedAfter | ||
| |} | ||
| |$bufferedAfter | ||
| """.stripMargin | ||
| val checking = | ||
| s""" | ||
| |$bufferedBefore | ||
| |if ($bufferedRow != null) { | ||
| | ${cond.code} | ||
| | if (${cond.isNull} || !${cond.value}) { | ||
| | continue; | ||
| | } | ||
| |} | ||
| |if (!$loaded) { | ||
| | $loaded = true; | ||
| | $streamedAfter | ||
| |} | ||
| |$bufferedAfter | ||
| """.stripMargin | ||
| (before, checking) | ||
| } else { | ||
| (evaluateVariables(streamedVars), "") | ||
|
|
@@ -603,21 +682,55 @@ case class SortMergeJoinExec( | |
| val thisPlan = ctx.addReferenceObj("plan", this) | ||
| val eagerCleanup = s"$thisPlan.cleanupResources();" | ||
|
|
||
| s""" | ||
| |while (findNextJoinRows($streamedInput, $bufferedInput)) { | ||
| | ${streamedVarDecl.mkString("\n")} | ||
| | ${beforeLoop.trim} | ||
| | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator(); | ||
| | while ($iterator.hasNext()) { | ||
| | InternalRow $bufferedRow = (InternalRow) $iterator.next(); | ||
| | ${condCheck.trim} | ||
| | $numOutput.add(1); | ||
| | ${consume(ctx, resultVars)} | ||
| | } | ||
| | if (shouldStop()) return; | ||
| |} | ||
| |$eagerCleanup | ||
| lazy val innerJoin = | ||
| s""" | ||
| |while (findNextJoinRows($streamedInput, $bufferedInput)) { | ||
| | ${streamedVarDecl.mkString("\n")} | ||
| | ${beforeLoop.trim} | ||
| | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator(); | ||
| | while ($iterator.hasNext()) { | ||
| | InternalRow $bufferedRow = (InternalRow) $iterator.next(); | ||
| | ${condCheck.trim} | ||
| | $numOutput.add(1); | ||
| | ${consume(ctx, resultVars)} | ||
| | } | ||
| | if (shouldStop()) return; | ||
| |} | ||
| |$eagerCleanup | ||
| """.stripMargin | ||
|
|
||
| lazy val outerJoin = { | ||
| val hasOutputRow = ctx.freshName("hasOutputRow") | ||
| s""" | ||
| |while ($streamedInput.hasNext()) { | ||
| | findNextJoinRows($streamedInput, $bufferedInput); | ||
| | ${streamedVarDecl.mkString("\n")} | ||
| | ${beforeLoop.trim} | ||
| | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator(); | ||
| | boolean $hasOutputRow = false; | ||
| | | ||
| | // the last iteration of this loop is to emit an empty row if there is no matched rows. | ||
| | while ($iterator.hasNext() || !$hasOutputRow) { | ||
| | InternalRow $bufferedRow = $iterator.hasNext() ? | ||
| | (InternalRow) $iterator.next() : null; | ||
| | ${condCheck.trim} | ||
| | $hasOutputRow = true; | ||
| | $numOutput.add(1); | ||
| | ${consume(ctx, resultVars)} | ||
| | } | ||
| | if (shouldStop()) return; | ||
| |} | ||
| |$eagerCleanup | ||
| """.stripMargin | ||
| } | ||
|
|
||
| joinType match { | ||
| case _: InnerLike => innerJoin | ||
| case LeftOuter | RightOuter => outerJoin | ||
| case x => | ||
| throw new IllegalArgumentException( | ||
| s"SortMergeJoin.doProduce should not take $x as the JoinType") | ||
| } | ||
| } | ||
|
|
||
| override protected def withNewChildrenInternal( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wanted to avoid `clear()` if `isEmpty()` is true. `ExternalAppendOnlyUnsafeRowArray.isEmpty()` is very cheap, but `clear()` sets multiple variables.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. Could you leave some comments about it there?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@maropu - added comment.