
Commit 0fdacaf

Author: Davies Liu (authored and committed)
fix broadcast and thread-safety of UnsafeProjection
1 parent 2fc3ef6 commit 0fdacaf

File tree

8 files changed: +33, -63 lines
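
The commit addresses two issues. First, code-generated UnsafeProjection instances write each result into a reused internal buffer, so a single instance shared across threads (for example through a `lazy val` on a plan node) can race; the diffs below turn those fields into `def`s so each task builds its own instance. Second, the broadcast-join operators kicked off their build-and-broadcast work in a future at plan construction time; that work now runs inside doExecute(). Below is a minimal, self-contained sketch of the first hazard, in plain Scala with illustrative names (`ReusingProjection` is a stand-in, not the real UnsafeProjection):

import java.util.concurrent.atomic.AtomicBoolean

object SharedProjectionRace {
  // Stand-in for a code-generated projection: every call writes its result
  // into the same reused output buffer.
  class ReusingProjection {
    private val out = new Array[Long](1)
    def apply(v: Long): Array[Long] = { out(0) = v; out }
  }

  def main(args: Array[String]): Unit = {
    val shared = new ReusingProjection          // like a `lazy val` shared by all tasks
    val corrupted = new AtomicBoolean(false)
    val threads = (1 to 4).map { i =>
      new Thread(new Runnable {
        override def run(): Unit = {
          var n = 0
          while (n < 100000) {
            val row = shared(i)                 // another thread may overwrite `out`
            if (row(0) != i) corrupted.set(true)
            n += 1
          }
        }
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
    // Very likely prints true: the shared buffer is clobbered across threads.
    println(s"corruption observed with shared instance: ${corrupted.get}")

    // The fix mirrored in this commit: a `def` hands each task its own instance.
    def freshProjection = new ReusingProjection
    val mine = freshProjection                  // one per task/partition
    println(s"task-local instance stays consistent: ${mine(7L)(0) == 7L}")
  }
}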

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala

Lines changed: 3 additions & 16 deletions

@@ -17,17 +17,15 @@
 
 package org.apache.spark.sql.execution.joins
 
-import scala.concurrent._
 import scala.concurrent.duration._
 
-import org.apache.spark.{InternalAccumulator, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.{InternalAccumulator, TaskContext}
 
 /**
  * :: DeveloperApi ::
@@ -59,16 +57,11 @@ case class BroadcastHashJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
 
-  @transient
-  private val broadcastFuture = future {
+  protected override def doExecute(): RDD[InternalRow] = {
     // Note that we use .execute().collect() because we don't want to convert data to Scala types
     val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
     val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size)
-    sparkContext.broadcast(hashed)
-  }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val broadcastRelation = Await.result(broadcastFuture, timeout)
+    val broadcastRelation = sparkContext.broadcast(hashed)
 
     streamedPlan.execute().mapPartitions { streamedIter =>
       val hashedRelation = broadcastRelation.value
@@ -82,9 +75,3 @@ case class BroadcastHashJoin(
     }
   }
 }
-
-object BroadcastHashJoin {
-
-  private val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService(
-    ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128))
-}
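
For context, a hedged sketch of the structural change (toy stand-ins, not the real SparkPlan or broadcast APIs): before, the build side was collected, hashed, and broadcast inside a future started by a field initializer, i.e. as soon as the plan object was constructed, on the dedicated 128-thread pool deleted above; doExecute() then blocked on that future. Because the field was `@transient`, the future was also lost when the plan was serialized, which appears to be the broadcast problem the commit message refers to. After the change, the same work runs synchronously inside doExecute():

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

object BroadcastTimingSketch {
  // Toy stand-ins for the Spark pieces involved (assumed names, not real APIs).
  def buildHashedRelation(): Map[Int, String] = Map(1 -> "a", 2 -> "b")
  def broadcast[T](v: T): T = v // stands in for sparkContext.broadcast

  // Before: the future starts running at construction time and doExecute()
  // blocks on it; being @transient in the real code, it does not survive
  // serialization of the plan.
  class Before {
    import ExecutionContext.Implicits.global
    @transient private val broadcastFuture = Future(broadcast(buildHashedRelation()))
    def doExecute(): Map[Int, String] = Await.result(broadcastFuture, 5.minutes)
  }

  // After: build and broadcast synchronously inside doExecute(), so nothing
  // runs until the plan actually executes and no extra thread pool is needed.
  class After {
    def doExecute(): Map[Int, String] = broadcast(buildHashedRelation())
  }

  def main(args: Array[String]): Unit = {
    println(new Before().doExecute())
    println(new After().doExecute())
  }
}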

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala

Lines changed: 7 additions & 19 deletions

@@ -17,18 +17,16 @@
 
 package org.apache.spark.sql.execution.joins
 
-import scala.concurrent._
 import scala.concurrent.duration._
 
-import org.apache.spark.{InternalAccumulator, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
 import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.{InternalAccumulator, TaskContext}
 
 /**
  * :: DeveloperApi ::
@@ -60,16 +58,11 @@ case class BroadcastHashOuterJoin(
 
   override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
 
-  @transient
-  private val broadcastFuture = future {
+  override def doExecute(): RDD[InternalRow] = {
     // Note that we use .execute().collect() because we don't want to convert data to Scala types
     val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
     val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
-    sparkContext.broadcast(hashed)
-  }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)
-
-  override def doExecute(): RDD[InternalRow] = {
-    val broadcastRelation = Await.result(broadcastFuture, timeout)
+    val broadcastRelation = sparkContext.broadcast(hashed)
 
     streamedPlan.execute().mapPartitions { streamedIter =>
       val joinedRow = new JoinedRow()
@@ -83,19 +76,20 @@ case class BroadcastHashOuterJoin(
         case _ =>
       }
 
+      val resultProj = resultProjection
       joinType match {
         case LeftOuter =>
           streamedIter.flatMap(currentRow => {
             val rowKey = keyGenerator(currentRow)
             joinedRow.withLeft(currentRow)
-            leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey))
+            leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj)
           })
 
         case RightOuter =>
          streamedIter.flatMap(currentRow => {
             val rowKey = keyGenerator(currentRow)
             joinedRow.withRight(currentRow)
-            rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow)
+            rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj)
           })
 
         case x =>
@@ -105,9 +99,3 @@ case class BroadcastHashOuterJoin(
     }
   }
 }
-
-object BroadcastHashOuterJoin {
-
-  private val broadcastHashOuterJoinExecutionContext = ExecutionContext.fromExecutorService(
-    ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-outer-join", 128))
-}

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala

Lines changed: 13 additions & 12 deletions

@@ -76,22 +76,22 @@ trait HashOuterJoin {
   override def canProcessUnsafeRows: Boolean = isUnsafeMode
   override def canProcessSafeRows: Boolean = !isUnsafeMode
 
-  @transient protected lazy val buildKeyGenerator: Projection =
+  @transient protected def buildKeyGenerator: Projection =
     if (isUnsafeMode) {
      UnsafeProjection.create(buildKeys, buildPlan.output)
     } else {
       newMutableProjection(buildKeys, buildPlan.output)()
     }
 
-  @transient protected[this] lazy val streamedKeyGenerator: Projection = {
+  @transient protected[this] def streamedKeyGenerator: Projection = {
     if (isUnsafeMode) {
       UnsafeProjection.create(streamedKeys, streamedPlan.output)
     } else {
       newProjection(streamedKeys, streamedPlan.output)
     }
   }
 
-  @transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
+  @transient protected[this] def resultProjection: InternalRow => InternalRow = {
     if (isUnsafeMode) {
       UnsafeProjection.create(self.schema)
     } else {
@@ -113,23 +113,24 @@ trait HashOuterJoin {
   protected[this] def leftOuterIterator(
       key: InternalRow,
       joinedRow: JoinedRow,
-      rightIter: Iterable[InternalRow]): Iterator[InternalRow] = {
+      rightIter: Iterable[InternalRow],
+      resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = {
     val ret: Iterable[InternalRow] = {
       if (!key.anyNull) {
         val temp = if (rightIter != null) {
           rightIter.collect {
-            case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy()
+            case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow)
           }
         } else {
           List.empty
         }
         if (temp.isEmpty) {
-          resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil
+          resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
         } else {
           temp
         }
       } else {
-        resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil
+        resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
      }
     }
     ret.iterator
@@ -138,24 +139,24 @@ trait HashOuterJoin {
   protected[this] def rightOuterIterator(
       key: InternalRow,
       leftIter: Iterable[InternalRow],
-      joinedRow: JoinedRow): Iterator[InternalRow] = {
+      joinedRow: JoinedRow,
+      resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = {
     val ret: Iterable[InternalRow] = {
       if (!key.anyNull) {
         val temp = if (leftIter != null) {
           leftIter.collect {
-            case l if boundCondition(joinedRow.withLeft(l)) =>
-              resultProjection(joinedRow).copy()
+            case l if boundCondition(joinedRow.withLeft(l)) => resultProjection(joinedRow)
           }
         } else {
           List.empty
         }
         if (temp.isEmpty) {
-          resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil
+          resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
         } else {
           temp
         }
       } else {
-        resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil
+        resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
       }
     }
     ret.iterator
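
A sketch of the calling pattern the new signatures encourage (simplified types and illustrative names, not the real trait): `resultProjection` is now a `def`, so every call builds a fresh projection, and callers materialize one instance per partition and pass it into the iterator helpers instead of all tasks sharing a single `lazy val`. The removed `.copy()` calls appear to rely on each projected row being consumed before the projection is asked for the next one; the sketch mirrors that by stringifying rows as the iterator is drained:

object PerPartitionProjection {
  type Row = Array[Int]

  // Stand-in for UnsafeProjection: each instance reuses one output buffer.
  class Projection {
    private val out = new Array[Int](1)
    def apply(r: Row): Row = { out(0) = r(0); out }
  }

  trait HashOuterJoinLike {
    // `def`, not `lazy val`: a fresh, task-local projection per call.
    protected def resultProjection: Row => Row = {
      val p = new Projection
      r => p(r)
    }

    // The projection is passed in, mirroring the new parameter on
    // leftOuterIterator/rightOuterIterator.
    protected def leftOuterIterator(row: Row, resultProjection: Row => Row): Iterator[Row] =
      Iterator(resultProjection(row))
  }

  object Join extends HashOuterJoinLike {
    def runPartition(rows: Seq[Row]): Seq[String] = {
      val resultProj = resultProjection // created once per partition
      rows.iterator
        .flatMap(r => leftOuterIterator(r, resultProj))
        .map(_.mkString) // consume each row before the buffer is reused
        .toList
    }
  }

  def main(args: Array[String]): Unit =
    println(Join.runPartition(Seq(Array(1), Array(2))).mkString(",")) // prints 1,2
}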

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala

Lines changed: 4 additions & 7 deletions

@@ -43,14 +43,14 @@ trait HashSemiJoin {
   override def canProcessUnsafeRows: Boolean = supportUnsafe
   override def canProcessSafeRows: Boolean = !supportUnsafe
 
-  @transient protected lazy val leftKeyGenerator: Projection =
+  @transient protected def leftKeyGenerator: Projection =
     if (supportUnsafe) {
       UnsafeProjection.create(leftKeys, left.output)
     } else {
       newMutableProjection(leftKeys, left.output)()
     }
 
-  @transient protected lazy val rightKeyGenerator: Projection =
+  @transient protected def rightKeyGenerator: Projection =
     if (supportUnsafe) {
       UnsafeProjection.create(rightKeys, right.output)
     } else {
@@ -62,12 +62,11 @@ trait HashSemiJoin {
 
   protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
     val hashSet = new java.util.HashSet[InternalRow]()
-    var currentRow: InternalRow = null
 
     // Create a Hash set of buildKeys
     val rightKey = rightKeyGenerator
     while (buildIter.hasNext) {
-      currentRow = buildIter.next()
+      val currentRow = buildIter.next()
       val rowKey = rightKey(currentRow)
       if (!rowKey.anyNull) {
         val keyExists = hashSet.contains(rowKey)
@@ -76,9 +75,7 @@ trait HashSemiJoin {
       }
     }
-    // scalastyle:off println
-    println(s"Build HashSet with ${hashSet.size()} items")
-    // scalastyle:on println
+
     hashSet
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala

Lines changed: 0 additions & 4 deletions

@@ -359,10 +359,6 @@ private[joins] object UnsafeHashedRelation {
       }
     }
 
-    // scalastyle:off println
-    println(s"Build UnsafeHashedRelation with ${hashTable.size()} items")
-    // scalastyle:on println
-
     new UnsafeHashedRelation(hashTable)
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala

Lines changed: 4 additions & 2 deletions

@@ -60,19 +60,21 @@ case class ShuffledHashOuterJoin(
       case LeftOuter =>
         val hashed = HashedRelation(rightIter, buildKeyGenerator)
         val keyGenerator = streamedKeyGenerator
+        val resultProj = resultProjection
         leftIter.flatMap( currentRow => {
           val rowKey = keyGenerator(currentRow)
           joinedRow.withLeft(currentRow)
-          leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey))
+          leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj)
         })
 
       case RightOuter =>
         val hashed = HashedRelation(leftIter, buildKeyGenerator)
         val keyGenerator = streamedKeyGenerator
+        val resultProj = resultProjection
         rightIter.flatMap ( currentRow => {
           val rowKey = keyGenerator(currentRow)
           joinedRow.withRight(currentRow)
-          rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow)
+          rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj)
         })
 
       case FullOuter =>
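
A note on the hoist above: now that `resultProjection` is a `def`, each call builds (and, in the unsafe path, code-generates) a new projection, so it must be captured once per partition rather than invoked inside the per-row closure. A tiny self-contained sketch of the difference, using an illustrative counter rather than Spark code:

object HoistProjection {
  private var built = 0

  // Stand-in for the `def resultProjection` factory: each call "builds" a
  // new projection.
  def resultProjection: Int => Int = { built += 1; identity }

  def main(args: Array[String]): Unit = {
    val rows = 1 to 1000

    built = 0
    rows.foreach(r => resultProjection(r))  // called per row: 1000 builds
    println(s"per-row calls built $built projections")

    built = 0
    val resultProj = resultProjection       // hoisted: built once per partition
    rows.foreach(r => resultProj(r))
    println(s"hoisted call built $built projection(s)")
  }
}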

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala

Lines changed: 0 additions & 3 deletions

@@ -348,9 +348,6 @@ abstract class HiveComparisonTest
 
     // Run w/ catalyst
     val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
-      // scalastyle:off println
-      println("Run :" + queryString)
-      // scalastyle:on println
      val query = new TestHive.QueryExecution(queryString)
       try { (query, prepareAnswer(query, query.stringResult())) } catch {
         case e: Throwable =>

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala

Lines changed: 2 additions & 0 deletions

@@ -61,6 +61,8 @@ abstract class HiveQueryFileTest extends HiveComparisonTest {
       val queriesString = fileToString(testCaseFile)
       if (testCaseName == "semijoin") {
         (1 to 100).foreach(x => createQueryTest(testCaseName, queriesString))
+      } else {
+        createQueryTest(testCaseName, queriesString)
       }
     } else {
       // Only output warnings for the built in whitelist as this clutters the output when the user
