diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index db6626bd18abc..41c2be827ca72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -414,10 +414,12 @@ case class SortMergeJoinExec( } /** - * Generate a function to scan both sides to find a match, returns the term for - * matched one row from streamed side and buffered rows from buffered side. + * Generate a function to scan both sides to find a match, returns: + * 1. the function name + * 2. the term for matched one row from streamed side + * 3. the term for buffered rows from buffered side */ - private def genScanner(ctx: CodegenContext): (String, String) = { + private def genScanner(ctx: CodegenContext): (String, String, String) = { // Create class member for next row from both sides. // Inline mutable state since not many join operations in a task val streamedRow = ctx.addMutableState("InternalRow", "streamedRow", forceInline = true) @@ -518,9 +520,10 @@ case class SortMergeJoinExec( // 1. Inner and Left Semi join: skip the row. // 2. Left/Right Outer join: keep the row and return false (with `matches` being // empty). - ctx.addNewFunction("findNextJoinRows", + val findNextJoinRowsFuncName = ctx.freshName("findNextJoinRows") + ctx.addNewFunction(findNextJoinRowsFuncName, s""" - |private boolean findNextJoinRows( + |private boolean $findNextJoinRowsFuncName( | scala.collection.Iterator streamedIter, | scala.collection.Iterator bufferedIter) { | $streamedRow = null; @@ -574,7 +577,7 @@ case class SortMergeJoinExec( |} """.stripMargin, inlineToOuterClass = true) - (streamedRow, matches) + (findNextJoinRowsFuncName, streamedRow, matches) } /** @@ -647,7 +650,7 @@ case class SortMergeJoinExec( val bufferedInput = ctx.addMutableState("scala.collection.Iterator", "bufferedInput", v => s"$v = inputs[1];", forceInline = true) - val (streamedRow, matches) = genScanner(ctx) + val (findNextJoinRowsFuncName, streamedRow, matches) = genScanner(ctx) // Create variables for row from both sides. val (streamedVars, streamedVarDecl) = createStreamedVars(ctx, streamedRow) @@ -715,7 +718,7 @@ case class SortMergeJoinExec( |$numOutput.add(1); |${consume(ctx, resultVars)} """.stripMargin - val findNextJoinRows = s"findNextJoinRows($streamedInput, $bufferedInput)" + val findNextJoinRows = s"$findNextJoinRowsFuncName($streamedInput, $bufferedInput)" val thisPlan = ctx.addReferenceObj("plan", this) val eagerCleanup = s"$thisPlan.cleanupResources();"