@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
2121import org .apache .spark .sql .Row
2222import org .apache .spark .sql .catalyst .expressions ._
2323import org .apache .spark .sql .catalyst .plans ._
24- import org .apache .spark .sql .catalyst .plans .physical .{ ClusteredOrderedDistribution , Partitioning }
24+ import org .apache .spark .sql .catalyst .plans .physical ._
2525import org .apache .spark .sql .execution .{BinaryNode , SparkPlan }
2626import org .apache .spark .util .collection .CompactBuffer
2727
@@ -41,7 +41,7 @@ case class SortMergeJoin(
4141
4242 override def outputPartitioning : Partitioning = left.outputPartitioning
4343
44- override def requiredChildDistribution : Seq [ClusteredOrderedDistribution ] =
44+ override def requiredChildDistribution : Seq [Distribution ] =
4545 ClusteredOrderedDistribution (leftKeys) :: ClusteredOrderedDistribution (rightKeys) :: Nil
4646
4747 private val orders : Seq [SortOrder ] = leftKeys.map(s => SortOrder (s, Ascending ))
@@ -62,15 +62,14 @@ case class SortMergeJoin(
6262 private [this ] var rightElement : Row = _
6363 private [this ] var leftKey : Row = _
6464 private [this ] var rightKey : Row = _
65- private [this ] var read : Boolean = false
6665 private [this ] var currentlMatches : CompactBuffer [Row ] = _
6766 private [this ] var currentrMatches : CompactBuffer [Row ] = _
6867 private [this ] var currentlPosition : Int = - 1
6968 private [this ] var currentrPosition : Int = - 1
7069
7170 override final def hasNext : Boolean =
7271 (currentlPosition != - 1 && currentlPosition < currentlMatches.size) ||
73- (leftIter.hasNext && rightIter.hasNext && nextMatchingPair)
72+ nextMatchingPair
7473
7574 override final def next (): Row = {
7675 val joinedRow =
@@ -83,6 +82,32 @@ case class SortMergeJoin(
8382 joinedRow
8483 }
8584
85+ private def fetchLeft () = {
86+ if (leftIter.hasNext) {
87+ leftElement = leftIter.next()
88+ leftKey = leftKeyGenerator(leftElement)
89+ } else {
90+ leftElement = null
91+ }
92+ }
93+
94+ private def fetchRight () = {
95+ if (rightIter.hasNext) {
96+ rightElement = rightIter.next()
97+ rightKey = rightKeyGenerator(rightElement)
98+ } else {
99+ rightElement = null
100+ }
101+ }
102+
103+ // Prime the merge: read the first row (and its join key) from each side; a side with no rows leaves its element null
104+ private def initialize () = {
105+ fetchLeft()
106+ fetchRight()
107+ }
108+
109+ initialize()
110+
86111 /**
87112 * Searches the left/right iterators for the next rows that match.
88113 *
@@ -92,42 +117,33 @@ case class SortMergeJoin(
92117 private def nextMatchingPair (): Boolean = {
93118 currentlPosition = - 1
94119 currentlMatches = null
95- if (rightElement == null ) {
96- rightElement = rightIter.next()
97- rightKey = rightKeyGenerator(rightElement)
120+ var stop : Boolean = false
121+ while (! stop && leftElement != null && rightElement != null ) {
122+ if (ordering.compare(leftKey, rightKey) > 0 )
123+ fetchRight()
124+ else if (ordering.compare(leftKey, rightKey) < 0 )
125+ fetchLeft()
126+ else
127+ stop = true
98128 }
99- while (currentlMatches == null && leftIter.hasNext) {
100- if (! read) {
101- leftElement = leftIter.next()
102- leftKey = leftKeyGenerator(leftElement)
103- }
104- while (ordering.compare(leftKey, rightKey) > 0 && rightIter.hasNext) {
105- rightElement = rightIter.next()
106- rightKey = rightKeyGenerator(rightElement)
107- }
108- currentrMatches = new CompactBuffer [Row ]()
109- while (ordering.compare(leftKey, rightKey) == 0 && rightIter.hasNext) {
129+ currentrMatches = new CompactBuffer [Row ]()
130+ while (stop && rightElement != null ) {
131+ if (! rightKey.anyNull)
110132 currentrMatches += rightElement
111- rightElement = rightIter.next()
112- rightKey = rightKeyGenerator(rightElement)
113- }
114- if (ordering.compare(leftKey, rightKey) == 0 ) {
115- currentrMatches += rightElement
116- }
117- if (currentrMatches.size > 0 ) {
118- // there exists rows match in right table, should search left table
119- currentlMatches = new CompactBuffer [Row ]()
120- val leftMatch = leftKey.copy()
121- while (ordering.compare(leftKey, leftMatch) == 0 && leftIter.hasNext) {
122- currentlMatches += leftElement
123- leftElement = leftIter.next()
124- leftKey = leftKeyGenerator(leftElement)
125- }
126- if (ordering.compare(leftKey, leftMatch) == 0 ) {
133+ fetchRight()
134+ if (ordering.compare(leftKey, rightKey) != 0 )
135+ stop = false
136+ }
137+ if (currentrMatches.size > 0 ) {
138+ stop = false
139+ currentlMatches = new CompactBuffer [Row ]()
140+ val leftMatch = leftKey.copy()
141+ while (! stop && leftElement != null ) {
142+ if (! leftKey.anyNull)
127143 currentlMatches += leftElement
128- } else {
129- read = true
130- }
144+ fetchLeft()
145+ if (ordering.compare(leftKey, leftMatch) != 0 )
146+ stop = true
131147 }
132148 }
133149
0 commit comments