Skip to content

Commit 8b18f79

Browse files
committed
Address all new comments
1 parent d4e0084 commit 8b18f79

File tree

4 files changed

+61
-38
lines changed

4 files changed

+61
-38
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,18 @@ private[execution] object HashedRelation {
9797
/**
9898
* Create a HashedRelation from an Iterator of InternalRow.
9999
*
100-
* @param isLookupAware reserve one extra boolean in value to track if value being looked up
101-
* @param value the expressions for value inserted into HashedRelation
100+
* @param canMarkRowLookedUp Reserve one extra boolean in value to track if value being looked up.
101+
* This is only used for full outer shuffled hash join.
102+
* @param valueExprs The expressions for value inserted into HashedRelation.
102103
*/
103104
def apply(
104105
input: Iterator[InternalRow],
105106
key: Seq[Expression],
106107
sizeEstimate: Int = 64,
107108
taskMemoryManager: TaskMemoryManager = null,
108109
isNullAware: Boolean = false,
109-
isLookupAware: Boolean = false,
110-
value: Option[Seq[Expression]] = None): HashedRelation = {
110+
canMarkRowLookedUp: Boolean = false,
111+
valueExprs: Option[Seq[Expression]] = None): HashedRelation = {
111112
val mm = Option(taskMemoryManager).getOrElse {
112113
new TaskMemoryManager(
113114
new UnifiedMemoryManager(
@@ -120,12 +121,13 @@ private[execution] object HashedRelation {
120121

121122
if (isNullAware && !input.hasNext) {
122123
EmptyHashedRelation
123-
} else if (key.length == 1 && key.head.dataType == LongType && !isLookupAware) {
124-
// NOTE: LongHashedRelation cannot support isLookupAware as it cannot
124+
} else if (key.length == 1 && key.head.dataType == LongType && !canMarkRowLookedUp) {
125+
// NOTE: LongHashedRelation cannot support canMarkRowLookedUp as it cannot
125126
// handle NULL key
126127
LongHashedRelation(input, key, sizeEstimate, mm, isNullAware)
127128
} else {
128-
UnsafeHashedRelation(input, key, sizeEstimate, mm, isNullAware, isLookupAware, value)
129+
UnsafeHashedRelation(
130+
input, key, sizeEstimate, mm, isNullAware, canMarkRowLookedUp, valueExprs)
129131
}
130132
}
131133
}
@@ -344,12 +346,10 @@ private[joins] object UnsafeHashedRelation {
344346
sizeEstimate: Int,
345347
taskMemoryManager: TaskMemoryManager,
346348
isNullAware: Boolean = false,
347-
isLookupAware: Boolean = false,
348-
value: Option[Seq[Expression]] = None): HashedRelation = {
349-
if (isNullAware && isLookupAware) {
350-
throw new SparkException(
351-
"isLookupAware and isNullAware cannot be enabled at same time for UnsafeHashedRelation")
352-
}
349+
canMarkRowLookedUp: Boolean = false,
350+
valueExprs: Option[Seq[Expression]] = None): HashedRelation = {
351+
require(!(isNullAware && canMarkRowLookedUp),
352+
"isNullAware and canMarkRowLookedUp cannot be enabled at same time")
353353

354354
val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
355355
.getOrElse(new SparkConf().get(BUFFER_PAGESIZE).getOrElse(16L * 1024 * 1024))
@@ -376,11 +376,11 @@ private[joins] object UnsafeHashedRelation {
376376
}
377377
}
378378

379-
if (isLookupAware) {
379+
if (canMarkRowLookedUp) {
380380
// Add one extra boolean value at the end as part of the row,
381381
// to track the information that whether the corresponding key
382382
// has been looked up or not. See `ShuffledHashJoin.fullOuterJoin` for example of usage.
383-
val valueGenerator = UnsafeProjection.create(value.get :+ Literal(false))
383+
val valueGenerator = UnsafeProjection.create(valueExprs.get :+ Literal(false))
384384
while (input.hasNext) {
385385
val row = input.next().asInstanceOf[UnsafeRow]
386386
numFields = row.numFields() + 1

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ case class ShuffledHashJoinExec(
6666
val start = System.nanoTime()
6767
val context = TaskContext.get()
6868

69-
val (isLookupAware, value) = if (joinType == FullOuter) {
69+
val (canMarkRowLookedUp, valueExprs) = if (joinType == FullOuter) {
7070
(true, Some(BindReferences.bindReferences(buildOutput, buildOutput)))
7171
} else {
7272
(false, None)
@@ -75,8 +75,8 @@ case class ShuffledHashJoinExec(
7575
iter,
7676
buildBoundKeys,
7777
taskMemoryManager = context.taskMemoryManager(),
78-
isLookupAware = isLookupAware,
79-
value = value)
78+
canMarkRowLookedUp = canMarkRowLookedUp,
79+
valueExprs = valueExprs)
8080
buildTime += NANOSECONDS.toMillis(System.nanoTime() - start)
8181
buildDataSize += relation.estimatedSize
8282
// This relation is usually used until the end of task.
@@ -103,7 +103,7 @@ case class ShuffledHashJoinExec(
103103
* 2. Process rows from stream side by looking up hash relation,
104104
* and mark the matched rows from build side be looked up.
105105
* 3. Process rows from build side by iterating hash relation,
106-
* and filter out rows from build side being looked up already.
106+
* and filter out rows from build side being matched already.
107107
*/
108108
private def fullOuterJoin(
109109
streamIter: Iterator[InternalRow],
@@ -180,15 +180,26 @@ case class ShuffledHashJoinExec(
180180
}
181181
}
182182

183-
// Process build side with filtering out rows looked up already
183+
// Process build side with filtering out rows looked up and
184+
// passed join condition already
185+
val streamNullJoinRow = new JoinedRow
186+
val streamNullJoinRowWithBuild = {
187+
buildSide match {
188+
case BuildLeft =>
189+
streamNullJoinRow.withRight(streamNullRow)
190+
streamNullJoinRow.withLeft _
191+
case BuildRight =>
192+
streamNullJoinRow.withLeft(streamNullRow)
193+
streamNullJoinRow.withRight _
194+
}
195+
}
184196
val buildResultIter = hashedRelation.values().flatMap { brow =>
185197
val unsafebrow = brow.asInstanceOf[UnsafeRow]
186198
val isLookup = unsafebrow.getBoolean(unsafebrow.numFields() - 1)
187199
if (!isLookup) {
188200
val buildRow = buildRowGenerator(unsafebrow)
189-
joinRowWithBuild(buildRow)
190-
joinRowWithStream(streamNullRow)
191-
Some(joinRow)
201+
streamNullJoinRowWithBuild(buildRow)
202+
Some(streamNullJoinRow)
192203
} else {
193204
None
194205
}

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,31 +1193,42 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
11931193
val inputDFs = Seq(
11941194
// Test unique join key
11951195
(spark.range(10).selectExpr("id as k1"),
1196-
spark.range(30).selectExpr("id as k2")),
1196+
spark.range(30).selectExpr("id as k2"),
1197+
$"k1" === $"k2"),
11971198
// Test non-unique join key
11981199
(spark.range(10).selectExpr("id % 5 as k1"),
1199-
spark.range(30).selectExpr("id % 5 as k2")),
1200+
spark.range(30).selectExpr("id % 5 as k2"),
1201+
$"k1" === $"k2"),
12001202
// Test string join key
12011203
(spark.range(10).selectExpr("cast(id * 3 as string) as k1"),
1202-
spark.range(30).selectExpr("cast(id as string) as k2")),
1204+
spark.range(30).selectExpr("cast(id as string) as k2"),
1205+
$"k1" === $"k2"),
12031206
// Test build side at right
12041207
(spark.range(30).selectExpr("cast(id / 3 as string) as k1"),
1205-
spark.range(10).selectExpr("cast(id as string) as k2")),
1208+
spark.range(10).selectExpr("cast(id as string) as k2"),
1209+
$"k1" === $"k2"),
12061210
// Test NULL join key
12071211
(spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr("value as k1"),
1208-
spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr("value as k2"))
1212+
spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr("value as k2"),
1213+
$"k1" === $"k2"),
1214+
// Test multiple join keys
1215+
(spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr(
1216+
"value as k1", "cast(value % 5 as short) as k2", "cast(value * 3 as long) as k3"),
1217+
spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr(
1218+
"value as k4", "cast(value % 5 as short) as k5", "cast(value * 3 as long) as k6"),
1219+
$"k1" === $"k4" && $"k2" === $"k5" && $"k3" === $"k6")
12091220
)
1210-
inputDFs.foreach { case (df1, df2) =>
1221+
inputDFs.foreach { case (df1, df2, joinExprs) =>
12111222
withSQLConf(
12121223
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
12131224
SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
1214-
val smjDF = df1.join(df2, $"k1" === $"k2", "full")
1225+
val smjDF = df1.join(df2, joinExprs, "full")
12151226
assert(smjDF.queryExecution.executedPlan.collect {
12161227
case _: SortMergeJoinExec => true }.size === 1)
12171228
val smjResult = smjDF.collect()
12181229

12191230
withSQLConf(SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
1220-
val shjDF = df1.join(df2, $"k1" === $"k2", "full")
1231+
val shjDF = df1.join(df2, joinExprs, "full")
12211232
assert(shjDF.queryExecution.executedPlan.collect {
12221233
case _: ShuffledHashJoinExec => true }.size === 1)
12231234
// Same result between shuffled hash join and sort merge join

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -586,22 +586,23 @@ class HashedRelationSuite extends SharedSparkSession {
586586
val value = Seq(BoundReference(0, IntegerType, true))
587587
val unsafeProj = UnsafeProjection.create(value)
588588
val rows = (0 until 100).map(i => unsafeProj(InternalRow(i + 1)).copy())
589+
val expectedValues = (0 until 100).map(i => i + 1)
589590

590591
// test LongHashedRelation
591592
val longRelation = LongHashedRelation(rows.iterator, key, 10, mm)
592593
var values = longRelation.values()
593-
assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === (0 until 100).map(i => i + 1))
594+
assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === expectedValues)
594595

595596
// test UnsafeHashedRelation
596597
val unsafeRelation = UnsafeHashedRelation(rows.iterator, key, 10, mm)
597598
values = unsafeRelation.values()
598-
assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === (0 until 100).map(i => i + 1))
599+
assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === expectedValues)
599600

600-
// test lookup-aware UnsafeHashedRelation
601-
val lookupAwareUnsafeRelation = UnsafeHashedRelation(
602-
rows.iterator, key, 10, mm, isLookupAware = true, value = Some(value))
603-
values = lookupAwareUnsafeRelation.values()
601+
// test UnsafeHashedRelation which can mark row looked up
602+
val markRowUnsafeRelation = UnsafeHashedRelation(
603+
rows.iterator, key, 10, mm, canMarkRowLookedUp = true, valueExprs = Some(value))
604+
values = markRowUnsafeRelation.values()
604605
assert(values.map(v => (v.getInt(0), v.getBoolean(1))).toArray.sortWith(_._1 < _._1)
605-
=== (0 until 100).map(i => (i + 1, false)))
606+
=== expectedValues.map(i => (i, false)))
606607
}
607608
}

0 commit comments

Comments
 (0)