Skip to content

Commit 4464f16

Browse files
committed
fix error
1 parent 880d8e9 commit 4464f16

File tree

1 file changed

+53
-37
lines changed

1 file changed

+53
-37
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala

Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
2121
import org.apache.spark.sql.Row
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.plans._
24-
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredOrderedDistribution, Partitioning}
24+
import org.apache.spark.sql.catalyst.plans.physical._
2525
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
2626
import org.apache.spark.util.collection.CompactBuffer
2727

@@ -41,7 +41,7 @@ case class SortMergeJoin(
4141

4242
override def outputPartitioning: Partitioning = left.outputPartitioning
4343

44-
override def requiredChildDistribution: Seq[ClusteredOrderedDistribution] =
44+
override def requiredChildDistribution: Seq[Distribution] =
4545
ClusteredOrderedDistribution(leftKeys) :: ClusteredOrderedDistribution(rightKeys) :: Nil
4646

4747
private val orders: Seq[SortOrder] = leftKeys.map(s => SortOrder(s, Ascending))
@@ -62,15 +62,14 @@ case class SortMergeJoin(
6262
private[this] var rightElement: Row = _
6363
private[this] var leftKey: Row = _
6464
private[this] var rightKey: Row = _
65-
private[this] var read: Boolean = false
6665
private[this] var currentlMatches: CompactBuffer[Row] = _
6766
private[this] var currentrMatches: CompactBuffer[Row] = _
6867
private[this] var currentlPosition: Int = -1
6968
private[this] var currentrPosition: Int = -1
7069

7170
override final def hasNext: Boolean =
7271
(currentlPosition != -1 && currentlPosition < currentlMatches.size) ||
73-
(leftIter.hasNext && rightIter.hasNext && nextMatchingPair)
72+
nextMatchingPair
7473

7574
override final def next(): Row = {
7675
val joinedRow =
@@ -83,6 +82,32 @@ case class SortMergeJoin(
8382
joinedRow
8483
}
8584

85+
private def fetchLeft() = {
86+
if (leftIter.hasNext) {
87+
leftElement = leftIter.next()
88+
leftKey = leftKeyGenerator(leftElement)
89+
} else {
90+
leftElement = null
91+
}
92+
}
93+
94+
private def fetchRight() = {
95+
if (rightIter.hasNext) {
96+
rightElement = rightIter.next()
97+
rightKey = rightKeyGenerator(rightElement)
98+
} else {
99+
rightElement = null
100+
}
101+
}
102+
103+
// initialize iterator
104+
private def initialize() = {
105+
fetchLeft()
106+
fetchRight()
107+
}
108+
109+
initialize()
110+
86111
/**
87112
* Searches the left/right iterator for the next rows that matches.
88113
*
@@ -92,42 +117,33 @@ case class SortMergeJoin(
92117
private def nextMatchingPair(): Boolean = {
93118
currentlPosition = -1
94119
currentlMatches = null
95-
if (rightElement == null) {
96-
rightElement = rightIter.next()
97-
rightKey = rightKeyGenerator(rightElement)
120+
var stop: Boolean = false
121+
while (!stop && leftElement != null && rightElement != null) {
122+
if (ordering.compare(leftKey, rightKey) > 0)
123+
fetchRight()
124+
else if (ordering.compare(leftKey, rightKey) < 0)
125+
fetchLeft()
126+
else
127+
stop = true
98128
}
99-
while (currentlMatches == null && leftIter.hasNext) {
100-
if (!read) {
101-
leftElement = leftIter.next()
102-
leftKey = leftKeyGenerator(leftElement)
103-
}
104-
while (ordering.compare(leftKey, rightKey) > 0 && rightIter.hasNext) {
105-
rightElement = rightIter.next()
106-
rightKey = rightKeyGenerator(rightElement)
107-
}
108-
currentrMatches = new CompactBuffer[Row]()
109-
while (ordering.compare(leftKey, rightKey) == 0 && rightIter.hasNext) {
129+
currentrMatches = new CompactBuffer[Row]()
130+
while (stop && rightElement != null) {
131+
if (!rightKey.anyNull)
110132
currentrMatches += rightElement
111-
rightElement = rightIter.next()
112-
rightKey = rightKeyGenerator(rightElement)
113-
}
114-
if (ordering.compare(leftKey, rightKey) == 0) {
115-
currentrMatches += rightElement
116-
}
117-
if (currentrMatches.size > 0) {
118-
// there exists rows match in right table, should search left table
119-
currentlMatches = new CompactBuffer[Row]()
120-
val leftMatch = leftKey.copy()
121-
while (ordering.compare(leftKey, leftMatch) == 0 && leftIter.hasNext) {
122-
currentlMatches += leftElement
123-
leftElement = leftIter.next()
124-
leftKey = leftKeyGenerator(leftElement)
125-
}
126-
if (ordering.compare(leftKey, leftMatch) == 0) {
133+
fetchRight()
134+
if (ordering.compare(leftKey, rightKey) != 0)
135+
stop = false
136+
}
137+
if (currentrMatches.size > 0) {
138+
stop = false
139+
currentlMatches = new CompactBuffer[Row]()
140+
val leftMatch = leftKey.copy()
141+
while (!stop && leftElement != null) {
142+
if (!leftKey.anyNull)
127143
currentlMatches += leftElement
128-
} else {
129-
read = true
130-
}
144+
fetchLeft()
145+
if (ordering.compare(leftKey, leftMatch) != 0)
146+
stop = true
131147
}
132148
}
133149

0 commit comments

Comments
 (0)