Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,12 @@ case class SortMergeJoinExec(
}

/**
* Generate a function to scan both sides to find a match, returns the term for
* matched one row from streamed side and buffered rows from buffered side.
* Generate a function to scan both sides to find a match, returns:
* 1. the function name
* 2. the term for matched one row from streamed side
* 3. the term for buffered rows from buffered side
*/
private def genScanner(ctx: CodegenContext): (String, String) = {
private def genScanner(ctx: CodegenContext): (String, String, String) = {
// Create class member for next row from both sides.
// Inline mutable state since not many join operations in a task
val streamedRow = ctx.addMutableState("InternalRow", "streamedRow", forceInline = true)
Expand Down Expand Up @@ -518,9 +520,10 @@ case class SortMergeJoinExec(
// 1. Inner and Left Semi join: skip the row.
// 2. Left/Right Outer join: keep the row and return false (with `matches` being
// empty).
ctx.addNewFunction("findNextJoinRows",
val findNextJoinRowsFuncName = ctx.freshName("findNextJoinRows")
ctx.addNewFunction(findNextJoinRowsFuncName,
s"""
|private boolean findNextJoinRows(
|private boolean $findNextJoinRowsFuncName(
| scala.collection.Iterator streamedIter,
| scala.collection.Iterator bufferedIter) {
| $streamedRow = null;
Expand Down Expand Up @@ -574,7 +577,7 @@ case class SortMergeJoinExec(
|}
""".stripMargin, inlineToOuterClass = true)

(streamedRow, matches)
(findNextJoinRowsFuncName, streamedRow, matches)
}

/**
Expand Down Expand Up @@ -647,7 +650,7 @@ case class SortMergeJoinExec(
val bufferedInput = ctx.addMutableState("scala.collection.Iterator", "bufferedInput",
v => s"$v = inputs[1];", forceInline = true)

val (streamedRow, matches) = genScanner(ctx)
val (findNextJoinRowsFuncName, streamedRow, matches) = genScanner(ctx)

// Create variables for row from both sides.
val (streamedVars, streamedVarDecl) = createStreamedVars(ctx, streamedRow)
Expand Down Expand Up @@ -715,7 +718,7 @@ case class SortMergeJoinExec(
|$numOutput.add(1);
|${consume(ctx, resultVars)}
""".stripMargin
val findNextJoinRows = s"findNextJoinRows($streamedInput, $bufferedInput)"
val findNextJoinRows = s"$findNextJoinRowsFuncName($streamedInput, $bufferedInput)"
val thisPlan = ctx.addReferenceObj("plan", this)
val eagerCleanup = s"$thisPlan.cleanupResources();"

Expand Down