Skip to content

Commit d4e0084

Browse files
committed
Address all comments
1 parent 01f1f04 commit d4e0084

File tree

3 files changed

+66
-78
lines changed

3 files changed

+66
-78
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ private[execution] object HashedRelation {
9696

9797
/**
9898
* Create a HashedRelation from an Iterator of InternalRow.
99+
*
100+
* @param isLookupAware reserve an extra boolean per value to track whether it was looked up
101+
* @param value the expressions used to generate the values inserted into the HashedRelation
99102
*/
100103
def apply(
101104
input: Iterator[InternalRow],
@@ -118,6 +121,8 @@ private[execution] object HashedRelation {
118121
if (isNullAware && !input.hasNext) {
119122
EmptyHashedRelation
120123
} else if (key.length == 1 && key.head.dataType == LongType && !isLookupAware) {
124+
// NOTE: LongHashedRelation cannot support isLookupAware as it cannot
125+
// handle NULL keys
121126
LongHashedRelation(input, key, sizeEstimate, mm, isNullAware)
122127
} else {
123128
UnsafeHashedRelation(input, key, sizeEstimate, mm, isNullAware, isLookupAware, value)
@@ -148,7 +153,7 @@ private[joins] class UnsafeHashedRelation(
148153

149154
override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption
150155

151-
// re-used in get()/getValue()
156+
// re-used in get()/getValue()/values()
152157
var resultRow = new UnsafeRow(numFields)
153158

154159
override def get(key: InternalRow): Iterator[InternalRow] = {
@@ -186,6 +191,23 @@ private[joins] class UnsafeHashedRelation(
186191
}
187192
}
188193

194+
override def values(): Iterator[InternalRow] = {
195+
val iter = binaryMap.iterator()
196+
197+
new Iterator[InternalRow] {
198+
override def hasNext: Boolean = iter.hasNext
199+
200+
override def next(): InternalRow = {
201+
if (!hasNext) {
202+
throw new NoSuchElementException("End of the iterator")
203+
}
204+
val loc = iter.next()
205+
resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
206+
resultRow
207+
}
208+
}
209+
}
210+
189211
override def keys(): Iterator[InternalRow] = {
190212
val iter = binaryMap.iterator()
191213

@@ -312,23 +334,6 @@ private[joins] class UnsafeHashedRelation(
312334
override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
313335
read(() => in.readInt(), () => in.readLong(), in.readBytes)
314336
}
315-
316-
override def values(): Iterator[InternalRow] = {
317-
val iter = binaryMap.iterator()
318-
319-
new Iterator[InternalRow] {
320-
override def hasNext: Boolean = iter.hasNext
321-
322-
override def next(): InternalRow = {
323-
if (!hasNext) {
324-
throw new NoSuchElementException("End of the iterator")
325-
}
326-
val loc = iter.next()
327-
resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
328-
resultRow
329-
}
330-
}
331-
}
332337
}
333338

334339
private[joins] object UnsafeHashedRelation {
@@ -341,6 +346,10 @@ private[joins] object UnsafeHashedRelation {
341346
isNullAware: Boolean = false,
342347
isLookupAware: Boolean = false,
343348
value: Option[Seq[Expression]] = None): HashedRelation = {
349+
if (isNullAware && isLookupAware) {
350+
throw new SparkException(
351+
"isLookupAware and isNullAware cannot be enabled at same time for UnsafeHashedRelation")
352+
}
344353

345354
val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
346355
.getOrElse(new SparkConf().get(BUFFER_PAGESIZE).getOrElse(16L * 1024 * 1024))
@@ -354,44 +363,36 @@ private[joins] object UnsafeHashedRelation {
354363
val keyGenerator = UnsafeProjection.create(key)
355364
var numFields = 0
356365

366+
val append = (key: UnsafeRow, value: UnsafeRow) => {
367+
val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
368+
val success = loc.append(
369+
key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
370+
value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
371+
if (!success) {
372+
binaryMap.free()
373+
// scalastyle:off throwerror
374+
throw new SparkOutOfMemoryError("There is not enough memory to build hash map")
375+
// scalastyle:on throwerror
376+
}
377+
}
378+
357379
if (isLookupAware) {
358380
// Add one extra boolean value at the end as part of the row,
359381
// to track whether the corresponding key
360382
// has been looked up or not. See `ShuffledHashJoin.fullOuterJoin` for an example of usage.
361383
val valueGenerator = UnsafeProjection.create(value.get :+ Literal(false))
362-
363384
while (input.hasNext) {
364385
val row = input.next().asInstanceOf[UnsafeRow]
365386
numFields = row.numFields() + 1
366-
val key = keyGenerator(row)
367-
val value = valueGenerator(row)
368-
val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
369-
val success = loc.append(
370-
key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
371-
value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
372-
if (!success) {
373-
binaryMap.free()
374-
// scalastyle:off throwerror
375-
throw new SparkOutOfMemoryError("There is not enough memory to build hash map")
376-
// scalastyle:on throwerror
377-
}
387+
append(keyGenerator(row), valueGenerator(row))
378388
}
379389
} else {
380390
while (input.hasNext) {
381391
val row = input.next().asInstanceOf[UnsafeRow]
382392
numFields = row.numFields()
383393
val key = keyGenerator(row)
384394
if (!key.anyNull) {
385-
val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
386-
val success = loc.append(
387-
key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
388-
row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
389-
if (!success) {
390-
binaryMap.free()
391-
// scalastyle:off throwerror
392-
throw new SparkOutOfMemoryError("There is not enough memory to build hash map")
393-
// scalastyle:on throwerror
394-
}
395+
append(key, row)
395396
} else if (isNullAware) {
396397
return EmptyHashedRelationWithAllNullKeys
397398
}

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,11 @@ case class ShuffledHashJoinExec(
6666
val start = System.nanoTime()
6767
val context = TaskContext.get()
6868

69-
val (isLookupAware, value) =
70-
if (joinType == FullOuter) {
71-
(true, Some(BindReferences.bindReferences(buildOutput, buildOutput)))
72-
} else {
73-
(false, None)
74-
}
69+
val (isLookupAware, value) = if (joinType == FullOuter) {
70+
(true, Some(BindReferences.bindReferences(buildOutput, buildOutput)))
71+
} else {
72+
(false, None)
73+
}
7574
val relation = HashedRelation(
7675
iter,
7776
buildBoundKeys,
@@ -110,24 +109,12 @@ case class ShuffledHashJoinExec(
110109
streamIter: Iterator[InternalRow],
111110
hashedRelation: HashedRelation,
112111
numOutputRows: SQLMetric): Iterator[InternalRow] = {
113-
abstract class HashJoinedRow extends JoinedRow {
114-
/** Updates this JoinedRow by updating its stream side row. Returns itself. */
115-
def withStream(newStream: InternalRow): JoinedRow
116-
117-
/** Updates this JoinedRow by updating its build side row. Returns itself. */
118-
def withBuild(newBuild: InternalRow): JoinedRow
119-
}
120-
val joinRow: HashJoinedRow = buildSide match {
121-
case BuildLeft =>
122-
new HashJoinedRow {
123-
override def withStream(newStream: InternalRow): JoinedRow = withRight(newStream)
124-
override def withBuild(newBuild: InternalRow): JoinedRow = withLeft(newBuild)
125-
}
126-
case BuildRight =>
127-
new HashJoinedRow {
128-
override def withStream(newStream: InternalRow): JoinedRow = withLeft(newStream)
129-
override def withBuild(newBuild: InternalRow): JoinedRow = withRight(newBuild)
130-
}
112+
val joinRow = new JoinedRow
113+
val (joinRowWithStream, joinRowWithBuild) = {
114+
buildSide match {
115+
case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
116+
case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
117+
}
131118
}
132119
val joinKeys = streamSideKeyGenerator()
133120
val buildRowGenerator = UnsafeProjection.create(buildOutput, buildOutput)
@@ -141,31 +128,31 @@ case class ShuffledHashJoinExec(
141128
val streamResultIter =
142129
if (hashedRelation.keyIsUnique) {
143130
streamIter.map { srow =>
144-
joinRow.withStream(srow)
131+
joinRowWithStream(srow)
145132
val keys = joinKeys(srow)
146133
if (keys.anyNull) {
147-
joinRow.withBuild(buildNullRow)
134+
joinRowWithBuild(buildNullRow)
148135
} else {
149136
val matched = hashedRelation.getValue(keys)
150137
if (matched != null) {
151138
val buildRow = buildRowGenerator(matched)
152-
if (boundCondition(joinRow.withBuild(buildRow))) {
139+
if (boundCondition(joinRowWithBuild(buildRow))) {
153140
markRowLookedUp(matched.asInstanceOf[UnsafeRow])
154141
joinRow
155142
} else {
156-
joinRow.withBuild(buildNullRow)
143+
joinRowWithBuild(buildNullRow)
157144
}
158145
} else {
159-
joinRow.withBuild(buildNullRow)
146+
joinRowWithBuild(buildNullRow)
160147
}
161148
}
162149
}
163150
} else {
164151
streamIter.flatMap { srow =>
165-
joinRow.withStream(srow)
152+
joinRowWithStream(srow)
166153
val keys = joinKeys(srow)
167154
if (keys.anyNull) {
168-
Iterator.single(joinRow.withBuild(buildNullRow))
155+
Iterator.single(joinRowWithBuild(buildNullRow))
169156
} else {
170157
val buildIter = hashedRelation.get(keys)
171158
new RowIterator {
@@ -174,14 +161,14 @@ case class ShuffledHashJoinExec(
174161
while (buildIter != null && buildIter.hasNext) {
175162
val matched = buildIter.next()
176163
val buildRow = buildRowGenerator(matched)
177-
if (boundCondition(joinRow.withBuild(buildRow))) {
164+
if (boundCondition(joinRowWithBuild(buildRow))) {
178165
markRowLookedUp(matched.asInstanceOf[UnsafeRow])
179166
found = true
180167
return true
181168
}
182169
}
183170
if (!found) {
184-
joinRow.withBuild(buildNullRow)
171+
joinRowWithBuild(buildNullRow)
185172
found = true
186173
return true
187174
}
@@ -199,8 +186,8 @@ case class ShuffledHashJoinExec(
199186
val isLookup = unsafebrow.getBoolean(unsafebrow.numFields() - 1)
200187
if (!isLookup) {
201188
val buildRow = buildRowGenerator(unsafebrow)
202-
joinRow.withBuild(buildRow)
203-
joinRow.withStream(streamNullRow)
189+
joinRowWithBuild(buildRow)
190+
joinRowWithStream(streamNullRow)
204191
Some(joinRow)
205192
} else {
206193
None
@@ -214,7 +201,7 @@ case class ShuffledHashJoinExec(
214201
}
215202
}
216203

217-
// TODO: support full outer shuffled hash join code-gen
204+
// TODO(SPARK-32567): support full outer shuffled hash join code-gen
218205
override def supportCodegen: Boolean = {
219206
joinType != FullOuter
220207
}

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ trait ShuffledJoin extends BaseJoinExec {
5858
left.output
5959
case x =>
6060
throw new IllegalArgumentException(
61-
s"ShuffledJoin not take $x as the JoinType")
61+
s"${getClass.getSimpleName} not take $x as the JoinType")
6262
}
6363
}
6464
}

0 commit comments

Comments
 (0)