Skip to content

Commit 635f6fb

Browse files
committed
Full outer shuffled hash join
1 parent 1c6dff7 commit 635f6fb

File tree

8 files changed

+323
-54
lines changed

8 files changed

+323
-54
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala

Lines changed: 21 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -235,8 +235,8 @@ trait JoinSelectionHelper {
235235
canBroadcastBySize(right, conf) && !hintToNotBroadcastRight(hint)
236236
}
237237
getBuildSide(
238-
canBuildLeft(joinType) && buildLeft,
239-
canBuildRight(joinType) && buildRight,
238+
canBuildBroadcastLeft(joinType) && buildLeft,
239+
canBuildBroadcastRight(joinType) && buildRight,
240240
left,
241241
right
242242
)
@@ -260,8 +260,8 @@ trait JoinSelectionHelper {
260260
canBuildLocalHashMapBySize(right, conf) && muchSmaller(right, left)
261261
}
262262
getBuildSide(
263-
canBuildLeft(joinType) && buildLeft,
264-
canBuildRight(joinType) && buildRight,
263+
canBuildShuffledHashJoinLeft(joinType) && buildLeft,
264+
canBuildShuffledHashJoinRight(joinType) && buildRight,
265265
left,
266266
right
267267
)
@@ -278,20 +278,35 @@ trait JoinSelectionHelper {
278278
plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold
279279
}
280280

281-
def canBuildLeft(joinType: JoinType): Boolean = {
281+
def canBuildBroadcastLeft(joinType: JoinType): Boolean = {
282282
joinType match {
283283
case _: InnerLike | RightOuter => true
284284
case _ => false
285285
}
286286
}
287287

288-
def canBuildRight(joinType: JoinType): Boolean = {
288+
def canBuildBroadcastRight(joinType: JoinType): Boolean = {
289289
joinType match {
290290
case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true
291291
case _ => false
292292
}
293293
}
294294

295+
def canBuildShuffledHashJoinLeft(joinType: JoinType): Boolean = {
296+
joinType match {
297+
case _: InnerLike | RightOuter | FullOuter => true
298+
case _ => false
299+
}
300+
}
301+
302+
def canBuildShuffledHashJoinRight(joinType: JoinType): Boolean = {
303+
joinType match {
304+
case _: InnerLike | LeftOuter | FullOuter |
305+
LeftSemi | LeftAnti | _: ExistenceJoin => true
306+
case _ => false
307+
}
308+
}
309+
295310
def hintToBroadcastLeft(hint: JoinHint): Boolean = {
296311
hint.leftHint.exists(_.strategy.contains(BROADCAST))
297312
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
116116
*
117117
* - Shuffle hash join:
118118
* Only supported for equi-joins, while the join keys do not need to be sortable.
119-
* Supported for all join types except full outer joins.
119+
* Supported for all join types.
120+
* Building a hash map from the table is a memory-intensive operation and could cause OOM
121+
* when the build side is big.
120122
*
121123
* - Shuffle sort merge join (SMJ):
122124
* Only supported for equi-joins and the join keys have to be sortable.
@@ -260,7 +262,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
260262
// it's a right join, and broadcast right side if it's a left join.
261263
// TODO: revisit it. If left side is much smaller than the right side, it may be better
262264
// to broadcast the left side even if it's a left join.
263-
if (canBuildLeft(joinType)) BuildLeft else BuildRight
265+
if (canBuildBroadcastLeft(joinType)) BuildLeft else BuildRight
264266
}
265267

266268
def createBroadcastNLJoin(buildLeft: Boolean, buildRight: Boolean) = {

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -104,7 +104,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
104104
}
105105
}
106106

107-
@transient private lazy val (buildOutput, streamedOutput) = {
107+
@transient protected lazy val (buildOutput, streamedOutput) = {
108108
buildSide match {
109109
case BuildLeft => (left.output, right.output)
110110
case BuildRight => (right.output, left.output)
@@ -123,7 +123,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
123123
protected def streamSideKeyGenerator(): UnsafeProjection =
124124
UnsafeProjection.create(streamedBoundKeys)
125125

126-
@transient private[this] lazy val boundCondition = if (condition.isDefined) {
126+
@transient protected[this] lazy val boundCondition = if (condition.isDefined) {
127127
Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _
128128
} else {
129129
(r: InternalRow) => true

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala

Lines changed: 83 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,11 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation {
7676
*/
7777
def keys(): Iterator[InternalRow]
7878

79+
/**
80+
* Returns an iterator for values of InternalRow type.
81+
*/
82+
def values(): Iterator[InternalRow]
83+
7984
/**
8085
* Returns a read-only copy of this, to be safely used in current thread.
8186
*/
@@ -97,7 +102,9 @@ private[execution] object HashedRelation {
97102
key: Seq[Expression],
98103
sizeEstimate: Int = 64,
99104
taskMemoryManager: TaskMemoryManager = null,
100-
isNullAware: Boolean = false): HashedRelation = {
105+
isNullAware: Boolean = false,
106+
isLookupAware: Boolean = false,
107+
value: Option[Seq[Expression]] = None): HashedRelation = {
101108
val mm = Option(taskMemoryManager).getOrElse {
102109
new TaskMemoryManager(
103110
new UnifiedMemoryManager(
@@ -110,10 +117,10 @@ private[execution] object HashedRelation {
110117

111118
if (isNullAware && !input.hasNext) {
112119
EmptyHashedRelation
113-
} else if (key.length == 1 && key.head.dataType == LongType) {
120+
} else if (key.length == 1 && key.head.dataType == LongType && !isLookupAware) {
114121
LongHashedRelation(input, key, sizeEstimate, mm, isNullAware)
115122
} else {
116-
UnsafeHashedRelation(input, key, sizeEstimate, mm, isNullAware)
123+
UnsafeHashedRelation(input, key, sizeEstimate, mm, isNullAware, isLookupAware, value)
117124
}
118125
}
119126
}
@@ -128,15 +135,18 @@ private[execution] object HashedRelation {
128135
private[joins] class UnsafeHashedRelation(
129136
private var numKeys: Int,
130137
private var numFields: Int,
131-
private var binaryMap: BytesToBytesMap)
138+
private var binaryMap: BytesToBytesMap,
139+
private val isLookupAware: Boolean = false)
132140
extends HashedRelation with Externalizable with KryoSerializable {
133141

134-
private[joins] def this() = this(0, 0, null) // Needed for serialization
142+
private[joins] def this() = this(0, 0, null, false) // Needed for serialization
135143

136-
override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues()
144+
override def keyIsUnique: Boolean = {
145+
binaryMap.numKeys() == binaryMap.numValues()
146+
}
137147

138148
override def asReadOnlyCopy(): UnsafeHashedRelation = {
139-
new UnsafeHashedRelation(numKeys, numFields, binaryMap)
149+
new UnsafeHashedRelation(numKeys, numFields, binaryMap, isLookupAware)
140150
}
141151

142152
override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption
@@ -305,6 +315,27 @@ private[joins] class UnsafeHashedRelation(
305315
override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
306316
read(() => in.readInt(), () => in.readLong(), in.readBytes)
307317
}
318+
319+
override def values(): Iterator[InternalRow] = {
320+
if (isLookupAware) {
321+
val iter = binaryMap.iterator()
322+
323+
new Iterator[InternalRow] {
324+
override def hasNext: Boolean = iter.hasNext
325+
326+
override def next(): InternalRow = {
327+
if (!hasNext) {
328+
throw new NoSuchElementException("End of the iterator")
329+
}
330+
val loc = iter.next()
331+
resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
332+
resultRow
333+
}
334+
}
335+
} else {
336+
throw new UnsupportedOperationException
337+
}
338+
}
308339
}
309340

310341
private[joins] object UnsafeHashedRelation {
@@ -314,7 +345,9 @@ private[joins] object UnsafeHashedRelation {
314345
key: Seq[Expression],
315346
sizeEstimate: Int,
316347
taskMemoryManager: TaskMemoryManager,
317-
isNullAware: Boolean = false): HashedRelation = {
348+
isNullAware: Boolean = false,
349+
isLookupAware: Boolean = false,
350+
value: Option[Seq[Expression]] = None): HashedRelation = {
318351

319352
val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
320353
.getOrElse(new SparkConf().get(BUFFER_PAGESIZE).getOrElse(16L * 1024 * 1024))
@@ -327,27 +360,52 @@ private[joins] object UnsafeHashedRelation {
327360
// Create a mapping of buildKeys -> rows
328361
val keyGenerator = UnsafeProjection.create(key)
329362
var numFields = 0
330-
while (input.hasNext) {
331-
val row = input.next().asInstanceOf[UnsafeRow]
332-
numFields = row.numFields()
333-
val key = keyGenerator(row)
334-
if (!key.anyNull) {
363+
364+
if (isLookupAware) {
365+
// Append one extra boolean field to the end of each stored row
366+
// to track whether the corresponding key
367+
// has been looked up. See `ShuffledHashJoin.fullOuterJoin` for an example of usage.
368+
val valueGenerator = UnsafeProjection.create(value.get :+ Literal(false))
369+
370+
while (input.hasNext) {
371+
val row = input.next().asInstanceOf[UnsafeRow]
372+
numFields = row.numFields() + 1
373+
val key = keyGenerator(row)
374+
val value = valueGenerator(row)
335375
val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
336376
val success = loc.append(
337377
key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
338-
row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
378+
value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
339379
if (!success) {
340380
binaryMap.free()
341381
// scalastyle:off throwerror
342382
throw new SparkOutOfMemoryError("There is not enough memory to build hash map")
343383
// scalastyle:on throwerror
344384
}
345-
} else if (isNullAware) {
346-
return EmptyHashedRelationWithAllNullKeys
385+
}
386+
} else {
387+
while (input.hasNext) {
388+
val row = input.next().asInstanceOf[UnsafeRow]
389+
numFields = row.numFields()
390+
val key = keyGenerator(row)
391+
if (!key.anyNull) {
392+
val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
393+
val success = loc.append(
394+
key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
395+
row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
396+
if (!success) {
397+
binaryMap.free()
398+
// scalastyle:off throwerror
399+
throw new SparkOutOfMemoryError("There is not enough memory to build hash map")
400+
// scalastyle:on throwerror
401+
}
402+
} else if (isNullAware) {
403+
return EmptyHashedRelationWithAllNullKeys
404+
}
347405
}
348406
}
349407

350-
new UnsafeHashedRelation(key.size, numFields, binaryMap)
408+
new UnsafeHashedRelation(key.size, numFields, binaryMap, isLookupAware)
351409
}
352410
}
353411

@@ -885,6 +943,10 @@ class LongHashedRelation(
885943
* Returns an iterator for keys of InternalRow type.
886944
*/
887945
override def keys(): Iterator[InternalRow] = map.keys()
946+
947+
override def values(): Iterator[InternalRow] = {
948+
throw new UnsupportedOperationException
949+
}
888950
}
889951

890952
/**
@@ -939,6 +1001,10 @@ trait NullAwareHashedRelation extends HashedRelation with Externalizable {
9391001
throw new UnsupportedOperationException
9401002
}
9411003

1004+
override def values(): Iterator[InternalRow] = {
1005+
throw new UnsupportedOperationException
1006+
}
1007+
9421008
override def close(): Unit = {}
9431009

9441010
override def writeExternal(out: ObjectOutput): Unit = {}

0 commit comments

Comments
 (0)