Skip to content

Commit c6e80a2

Browse files
committed
Address Andrew's comments
1 parent 327b718 commit c6e80a2

File tree

4 files changed

+77
-71
lines changed

4 files changed

+77
-71
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,25 +36,22 @@ case class ExpandNode(
3636

3737
override def open(): Unit = {
3838
child.open()
39-
idx = -1
4039
groups = projections.map(ee => newProjection(ee, child.output)).toArray
40+
idx = groups.length
4141
}
4242

4343
override def next(): Boolean = {
44-
if (idx < 0 || idx >= groups.length) {
44+
if (idx >= groups.length) {
4545
if (child.next()) {
4646
input = child.fetch()
47-
result = groups(0)(input)
48-
idx = 1
49-
true
47+
idx = 0
5048
} else {
51-
false
49+
return false
5250
}
53-
} else {
54-
result = groups(idx)(input)
55-
idx += 1
56-
true
5751
}
52+
result = groups(idx)(input)
53+
idx += 1
54+
true
5855
}
5956

6057
override def fetch(): InternalRow = result

sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,6 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
8686
*/
8787
final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this)
8888

89-
/**
90-
* Returns the content through the [[Iterator]] interface.
91-
*/
92-
final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this)
93-
9489
/**
9590
* Returns the content of the iterator from the beginning to the end in the form of a Scala Seq.
9691
*/
@@ -109,7 +104,8 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
109104
}
110105

111106
protected def newProjection(
112-
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
107+
expressions: Seq[Expression],
108+
inputSchema: Seq[Attribute]): Projection = {
113109
log.debug(
114110
s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
115111
if (codegenEnabled) {
@@ -152,7 +148,8 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
152148
}
153149

154150
protected def newPredicate(
155-
expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
151+
expression: Expression,
152+
inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
156153
if (codegenEnabled) {
157154
try {
158155
GeneratePredicate.generate(expression, inputSchema)

sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, JoinType}
2424
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
25-
import org.apache.spark.util.collection.BitSet
26-
import org.apache.spark.util.collection.CompactBuffer
25+
import org.apache.spark.util.collection.{BitSet, CompactBuffer}
2726

2827
case class NestedLoopJoinNode(
2928
conf: SQLConf,
@@ -77,65 +76,64 @@ case class NestedLoopJoinNode(
7776
val leftNulls = new GenericMutableRow(left.output.size)
7877
val rightNulls = new GenericMutableRow(right.output.size)
7978
val joinedRow = new JoinedRow
80-
val includedBuildTuples = new BitSet(buildRelation.size)
79+
val matchedBuildTuples = new BitSet(buildRelation.size)
8180
val resultProj = genResultProjection
8281
streamed.open()
8382

84-
val matchesOrStreamedRowsWithNulls = streamed.asIterator.flatMap { streamedRow =>
83+
// streamedRowMatches also contains null rows if using outer join
84+
val streamedRowMatches: Iterator[InternalRow] = streamed.asIterator.flatMap { streamedRow =>
8585
val matchedRows = new CompactBuffer[InternalRow]
8686

8787
var i = 0
8888
var streamRowMatched = false
8989

90+
// Scan the build relation to look for matches for each streamed row
9091
while (i < buildRelation.size) {
9192
val buildRow = buildRelation(i)
9293
buildSide match {
93-
case BuildRight if boundCondition(joinedRow(streamedRow, buildRow)) =>
94-
matchedRows += resultProj(joinedRow(streamedRow, buildRow)).copy()
95-
streamRowMatched = true
96-
includedBuildTuples.set(i)
97-
case BuildLeft if boundCondition(joinedRow(buildRow, streamedRow)) =>
98-
matchedRows += resultProj(joinedRow(buildRow, streamedRow)).copy()
99-
streamRowMatched = true
100-
includedBuildTuples.set(i)
101-
case _ =>
94+
case BuildRight => joinedRow(streamedRow, buildRow)
95+
case BuildLeft => joinedRow(buildRow, streamedRow)
96+
}
97+
if (boundCondition(joinedRow)) {
98+
matchedRows += resultProj(joinedRow).copy()
99+
streamRowMatched = true
100+
matchedBuildTuples.set(i)
102101
}
103102
i += 1
104103
}
105104

106-
(streamRowMatched, joinType, buildSide) match {
107-
case (false, LeftOuter | FullOuter, BuildRight) =>
108-
matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy()
109-
case (false, RightOuter | FullOuter, BuildLeft) =>
110-
matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy()
111-
case _ =>
105+
// If this row had no matches and we're using outer join, join it with the null rows
106+
if (!streamRowMatched) {
107+
(joinType, buildSide) match {
108+
case (LeftOuter | FullOuter, BuildRight) =>
109+
matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy()
110+
case (RightOuter | FullOuter, BuildLeft) =>
111+
matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy()
112+
case _ =>
113+
}
112114
}
113115

114116
matchedRows.iterator
115117
}
116118

119+
// If we're using outer join, find rows on the build side that didn't match anything
120+
// and join them with the null row
121+
lazy val unmatchedBuildRows: Iterator[InternalRow] = {
122+
var i = 0
123+
buildRelation.filter { row =>
124+
val r = !matchedBuildTuples.get(i)
125+
i += 1
126+
r
127+
}.iterator
128+
}
117129
iterator = (joinType, buildSide) match {
118130
case (RightOuter | FullOuter, BuildRight) =>
119-
var i = 0
120-
matchesOrStreamedRowsWithNulls ++ buildRelation.filter { row =>
121-
val r = !includedBuildTuples.get(i)
122-
i += 1
123-
r
124-
}.iterator.map { buildRow =>
125-
joinedRow.withLeft(leftNulls)
126-
resultProj(joinedRow.withRight(buildRow))
127-
}
131+
streamedRowMatches ++
132+
unmatchedBuildRows.map { buildRow => resultProj(joinedRow(leftNulls, buildRow)) }
128133
case (LeftOuter | FullOuter, BuildLeft) =>
129-
var i = 0
130-
matchesOrStreamedRowsWithNulls ++ buildRelation.filter { row =>
131-
val r = !includedBuildTuples.get(i)
132-
i += 1
133-
r
134-
}.iterator.map { buildRow =>
135-
joinedRow.withRight(rightNulls)
136-
resultProj(joinedRow.withLeft(buildRow))
137-
}
138-
case _ => matchesOrStreamedRowsWithNulls
134+
streamedRowMatches ++
135+
unmatchedBuildRows.map { buildRow => resultProj(joinedRow(buildRow, rightNulls)) }
136+
case _ => streamedRowMatches
139137
}
140138
}
141139

sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@ package org.apache.spark.sql.execution.local
1919

2020
import org.apache.spark.sql.SQLConf
2121
import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
22-
import org.apache.spark.sql.execution.joins.BuildRight
22+
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
2323

2424
class NestedLoopJoinNodeSuite extends LocalNodeTest {
2525

2626
import testImplicits._
2727

28-
def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = {
28+
private def joinSuite(
29+
suiteName: String, buildSide: BuildSide, confPairs: (String, String)*): Unit = {
2930
test(s"$suiteName: left outer join") {
3031
withSQLConf(confPairs: _*) {
3132
checkAnswer2(
@@ -36,7 +37,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
3637
conf,
3738
node1,
3839
node2,
39-
BuildRight,
40+
buildSide,
4041
LeftOuter,
4142
Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr))
4243
),
@@ -50,7 +51,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
5051
conf,
5152
node1,
5253
node2,
53-
BuildRight,
54+
buildSide,
5455
LeftOuter,
5556
Some(
5657
(upperCaseData.col("N") === lowerCaseData.col("n") &&
@@ -66,7 +67,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
6667
conf,
6768
node1,
6869
node2,
69-
BuildRight,
70+
buildSide,
7071
LeftOuter,
7172
Some(
7273
(upperCaseData.col("N") === lowerCaseData.col("n") &&
@@ -82,7 +83,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
8283
conf,
8384
node1,
8485
node2,
85-
BuildRight,
86+
buildSide,
8687
LeftOuter,
8788
Some(
8889
(upperCaseData.col("N") === lowerCaseData.col("n") &&
@@ -102,7 +103,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
102103
conf,
103104
node1,
104105
node2,
105-
BuildRight,
106+
buildSide,
106107
RightOuter,
107108
Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr))
108109
),
@@ -116,7 +117,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
116117
conf,
117118
node1,
118119
node2,
119-
BuildRight,
120+
buildSide,
120121
RightOuter,
121122
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
122123
lowerCaseData.col("n") > 1).expr))
@@ -131,7 +132,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
131132
conf,
132133
node1,
133134
node2,
134-
BuildRight,
135+
buildSide,
135136
RightOuter,
136137
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
137138
upperCaseData.col("N") > 1).expr))
@@ -146,7 +147,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
146147
conf,
147148
node1,
148149
node2,
149-
BuildRight,
150+
buildSide,
150151
RightOuter,
151152
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
152153
lowerCaseData.col("l") > upperCaseData.col("L")).expr))
@@ -165,7 +166,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
165166
conf,
166167
node1,
167168
node2,
168-
BuildRight,
169+
buildSide,
169170
FullOuter,
170171
Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr))
171172
),
@@ -179,7 +180,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
179180
conf,
180181
node1,
181182
node2,
182-
BuildRight,
183+
buildSide,
183184
FullOuter,
184185
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
185186
lowerCaseData.col("n") > 1).expr))
@@ -194,7 +195,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
194195
conf,
195196
node1,
196197
node2,
197-
BuildRight,
198+
buildSide,
198199
FullOuter,
199200
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
200201
upperCaseData.col("N") > 1).expr))
@@ -209,7 +210,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
209210
conf,
210211
node1,
211212
node2,
212-
BuildRight,
213+
buildSide,
213214
FullOuter,
214215
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
215216
lowerCaseData.col("l") > upperCaseData.col("L")).expr))
@@ -220,6 +221,19 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
220221
}
221222

222223
joinSuite(
223-
"general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
224-
joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
224+
"general-build-left",
225+
BuildLeft,
226+
SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
227+
joinSuite(
228+
"general-build-right",
229+
BuildRight,
230+
SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
231+
joinSuite(
232+
"tungsten-build-left",
233+
BuildLeft,
234+
SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
235+
joinSuite(
236+
"tungsten-build-right",
237+
BuildRight,
238+
SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
225239
}

0 commit comments

Comments
 (0)