
Commit 93085c9

Davies Liu authored and committed
[SPARK-9482] [SQL] Fix thread-safety issue of using UnsafeProjection in join
This PR also changes to use `def` instead of `lazy val` for `UnsafeProjection`, because it is not thread-safe.

TODO: clean up the debug code once the flaky test has passed 100 times.

Author: Davies Liu <[email protected]>

Closes apache#7940 from davies/semijoin and squashes the following commits:

93baac7 [Davies Liu] fix outerjoin
5c40ded [Davies Liu] address comments
aa3de46 [Davies Liu] Merge branch 'master' of github.com:apache/spark into semijoin
7590a25 [Davies Liu] Merge branch 'master' of github.com:apache/spark into semijoin
2d4085b [Davies Liu] use def for resultProjection
0833407 [Davies Liu] Merge branch 'semijoin' of github.com:davies/spark into semijoin
e0d8c71 [Davies Liu] use lazy val
6a59e8f [Davies Liu] Update HashedRelation.scala
0fdacaf [Davies Liu] fix broadcast and thread-safety of UnsafeProjection
2fc3ef6 [Davies Liu] reproduce failure in semijoin
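The hazard described above can be sketched without any Spark machinery. A generated projection writes into one reusable output buffer, so a single instance must never be used by two threads at once; a `lazy val` caches exactly one instance per operator object, while a `def` hands each caller a fresh one. A minimal sketch, using hypothetical stand-ins (`MutableProjectionLike`, `JoinOperatorSketch`) rather than Spark's real classes:

```scala
// Illustrative only: a stand-in for Spark's generated, buffer-reusing projections.
class MutableProjectionLike {
  private val buffer = new Array[Any](1) // one reused output slot => not thread-safe
  def apply(in: Any): Any = { buffer(0) = in; buffer(0) }
}

class JoinOperatorSketch {
  // Before: `@transient lazy val resultProjection = new MutableProjectionLike`
  // cached a single instance that every task on an executor would share.

  // After: a fresh instance per call.
  protected def resultProjection: MutableProjectionLike = new MutableProjectionLike

  def executePartition(rows: Iterator[Any]): Iterator[Any] = {
    val proj = resultProjection // materialized once per task, reused only within it
    rows.map(r => proj(r))
  }
}
```

Materializing the projection into a local `val` at the top of each partition keeps the amortized cost of the old `lazy val` without the cross-task sharing; that is the pattern the diffs below apply at every call site.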
1 parent 5b965d6 · commit 93085c9

8 files changed: +44 −44 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.joins
 import scala.concurrent._
 import scala.concurrent.duration._

-import org.apache.spark.{InternalAccumulator, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan}
 import org.apache.spark.util.ThreadUtils
+import org.apache.spark.{InternalAccumulator, TaskContext}

 /**
  * :: DeveloperApi ::
@@ -102,6 +102,6 @@ case class BroadcastHashJoin(

 object BroadcastHashJoin {

-  private val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService(
+  private[joins] val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService(
     ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128))
 }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala

Lines changed: 7 additions & 13 deletions
```diff
@@ -20,15 +20,14 @@ package org.apache.spark.sql.execution.joins
 import scala.concurrent._
 import scala.concurrent.duration._

-import org.apache.spark.{InternalAccumulator, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
 import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution}
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan}
+import org.apache.spark.{InternalAccumulator, TaskContext}

 /**
  * :: DeveloperApi ::
@@ -76,7 +75,7 @@ case class BroadcastHashOuterJoin(
       val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
       sparkContext.broadcast(hashed)
     }
-  }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)
+  }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
 }

 protected override def doPrepare(): Unit = {
@@ -98,19 +97,20 @@ case class BroadcastHashOuterJoin(
       case _ =>
     }

+    val resultProj = resultProjection
     joinType match {
       case LeftOuter =>
         streamedIter.flatMap(currentRow => {
           val rowKey = keyGenerator(currentRow)
           joinedRow.withLeft(currentRow)
-          leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey))
+          leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj)
         })

       case RightOuter =>
         streamedIter.flatMap(currentRow => {
           val rowKey = keyGenerator(currentRow)
           joinedRow.withRight(currentRow)
-          rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow)
+          rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj)
         })

       case x =>
@@ -120,9 +120,3 @@ case class BroadcastHashOuterJoin(
     }
   }
 }
-
-object BroadcastHashOuterJoin {
-
-  private val broadcastHashOuterJoinExecutionContext = ExecutionContext.fromExecutorService(
-    ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-outer-join", 128))
-}
```
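With `BroadcastHashOuterJoin`'s private pool deleted above, both broadcast operators now submit their futures to the single pool in `BroadcastHashJoin`, widened to `private[joins]`. A minimal sketch of that shared daemon-pool pattern using plain `java.util.concurrent` (Spark itself builds the pool via `ThreadUtils.newDaemonCachedThreadPool`; the `SharedBroadcastPool` name is illustrative):

```scala
import java.util.concurrent.{Executors, ThreadFactory}
import scala.concurrent.ExecutionContext

// One daemon cached thread pool shared across the package, rather than one
// pool per operator class.
object SharedBroadcastPool {
  private val factory = new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "broadcast-hash-join")
      t.setDaemon(true) // daemon threads do not block JVM shutdown
      t
    }
  }

  val executionContext: ExecutionContext =
    ExecutionContext.fromExecutorService(Executors.newCachedThreadPool(factory))
}
```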

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala

Lines changed: 10 additions & 7 deletions
```diff
@@ -47,7 +47,7 @@ case class BroadcastNestedLoopJoin(
   override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows
   override def canProcessUnsafeRows: Boolean = true

-  @transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
+  private[this] def genResultProjection: InternalRow => InternalRow = {
     if (outputsUnsafeRows) {
       UnsafeProjection.create(schema)
     } else {
@@ -88,6 +88,7 @@ case class BroadcastNestedLoopJoin(

     val leftNulls = new GenericMutableRow(left.output.size)
     val rightNulls = new GenericMutableRow(right.output.size)
+    val resultProj = genResultProjection

     streamedIter.foreach { streamedRow =>
       var i = 0
@@ -97,11 +98,11 @@ case class BroadcastNestedLoopJoin(
         val broadcastedRow = broadcastedRelation.value(i)
         buildSide match {
           case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
-            matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy()
+            matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy()
             streamRowMatched = true
             includedBroadcastTuples += i
           case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
-            matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy()
+            matchedRows += resultProj(joinedRow(broadcastedRow, streamedRow)).copy()
             streamRowMatched = true
             includedBroadcastTuples += i
           case _ =>
@@ -111,9 +112,9 @@ case class BroadcastNestedLoopJoin(

       (streamRowMatched, joinType, buildSide) match {
         case (false, LeftOuter | FullOuter, BuildRight) =>
-          matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy()
+          matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy()
         case (false, RightOuter | FullOuter, BuildLeft) =>
-          matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy()
+          matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy()
         case _ =>
       }
     }
@@ -127,6 +128,8 @@ case class BroadcastNestedLoopJoin(

     val leftNulls = new GenericMutableRow(left.output.size)
     val rightNulls = new GenericMutableRow(right.output.size)
+    val resultProj = genResultProjection
+
     /** Rows from broadcasted joined with nulls. */
     val broadcastRowsWithNulls: Seq[InternalRow] = {
       val buf: CompactBuffer[InternalRow] = new CompactBuffer()
@@ -138,7 +141,7 @@ case class BroadcastNestedLoopJoin(
       joinedRow.withLeft(leftNulls)
       while (i < rel.length) {
         if (!allIncludedBroadcastTuples.contains(i)) {
-          buf += resultProjection(joinedRow.withRight(rel(i))).copy()
+          buf += resultProj(joinedRow.withRight(rel(i))).copy()
         }
         i += 1
       }
@@ -147,7 +150,7 @@ case class BroadcastNestedLoopJoin(
       joinedRow.withRight(rightNulls)
       while (i < rel.length) {
         if (!allIncludedBroadcastTuples.contains(i)) {
-          buf += resultProjection(joinedRow.withLeft(rel(i))).copy()
+          buf += resultProj(joinedRow.withLeft(rel(i))).copy()
         }
         i += 1
       }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -52,14 +52,14 @@ trait HashJoin {
   override def canProcessUnsafeRows: Boolean = isUnsafeMode
   override def canProcessSafeRows: Boolean = !isUnsafeMode

-  @transient protected lazy val buildSideKeyGenerator: Projection =
+  protected def buildSideKeyGenerator: Projection =
     if (isUnsafeMode) {
       UnsafeProjection.create(buildKeys, buildPlan.output)
     } else {
       newMutableProjection(buildKeys, buildPlan.output)()
     }

-  @transient protected lazy val streamSideKeyGenerator: Projection =
+  protected def streamSideKeyGenerator: Projection =
     if (isUnsafeMode) {
       UnsafeProjection.create(streamedKeys, streamedPlan.output)
     } else {
```

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala

Lines changed: 12 additions & 11 deletions
```diff
@@ -76,22 +76,22 @@ trait HashOuterJoin {
   override def canProcessUnsafeRows: Boolean = isUnsafeMode
   override def canProcessSafeRows: Boolean = !isUnsafeMode

-  @transient protected lazy val buildKeyGenerator: Projection =
+  protected def buildKeyGenerator: Projection =
     if (isUnsafeMode) {
       UnsafeProjection.create(buildKeys, buildPlan.output)
     } else {
       newMutableProjection(buildKeys, buildPlan.output)()
     }

-  @transient protected[this] lazy val streamedKeyGenerator: Projection = {
+  protected[this] def streamedKeyGenerator: Projection = {
     if (isUnsafeMode) {
       UnsafeProjection.create(streamedKeys, streamedPlan.output)
     } else {
       newProjection(streamedKeys, streamedPlan.output)
     }
   }

-  @transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
+  protected[this] def resultProjection: InternalRow => InternalRow = {
     if (isUnsafeMode) {
       UnsafeProjection.create(self.schema)
     } else {
@@ -113,7 +113,8 @@ trait HashOuterJoin {
   protected[this] def leftOuterIterator(
       key: InternalRow,
       joinedRow: JoinedRow,
-      rightIter: Iterable[InternalRow]): Iterator[InternalRow] = {
+      rightIter: Iterable[InternalRow],
+      resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = {
     val ret: Iterable[InternalRow] = {
       if (!key.anyNull) {
         val temp = if (rightIter != null) {
@@ -124,12 +125,12 @@ trait HashOuterJoin {
           List.empty
         }
         if (temp.isEmpty) {
-          resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil
+          resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
         } else {
           temp
         }
       } else {
-        resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil
+        resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
       }
     }
     ret.iterator
@@ -138,24 +139,24 @@ trait HashOuterJoin {
   protected[this] def rightOuterIterator(
       key: InternalRow,
       leftIter: Iterable[InternalRow],
-      joinedRow: JoinedRow): Iterator[InternalRow] = {
+      joinedRow: JoinedRow,
+      resultProjection: InternalRow => InternalRow): Iterator[InternalRow] = {
     val ret: Iterable[InternalRow] = {
       if (!key.anyNull) {
         val temp = if (leftIter != null) {
           leftIter.collect {
-            case l if boundCondition(joinedRow.withLeft(l)) =>
-              resultProjection(joinedRow).copy()
+            case l if boundCondition(joinedRow.withLeft(l)) => resultProjection(joinedRow).copy()
           }
         } else {
           List.empty
         }
         if (temp.isEmpty) {
-          resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil
+          resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
         } else {
           temp
         }
       } else {
-        resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil
+        resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
      }
    }
    ret.iterator
```

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala

Lines changed: 4 additions & 4 deletions
```diff
@@ -43,14 +43,14 @@ trait HashSemiJoin {
   override def canProcessUnsafeRows: Boolean = supportUnsafe
   override def canProcessSafeRows: Boolean = !supportUnsafe

-  @transient protected lazy val leftKeyGenerator: Projection =
+  protected def leftKeyGenerator: Projection =
     if (supportUnsafe) {
       UnsafeProjection.create(leftKeys, left.output)
     } else {
       newMutableProjection(leftKeys, left.output)()
     }

-  @transient protected lazy val rightKeyGenerator: Projection =
+  protected def rightKeyGenerator: Projection =
     if (supportUnsafe) {
       UnsafeProjection.create(rightKeys, right.output)
     } else {
@@ -62,12 +62,11 @@ trait HashSemiJoin {

   protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
     val hashSet = new java.util.HashSet[InternalRow]()
-    var currentRow: InternalRow = null

     // Create a Hash set of buildKeys
     val rightKey = rightKeyGenerator
     while (buildIter.hasNext) {
-      currentRow = buildIter.next()
+      val currentRow = buildIter.next()
       val rowKey = rightKey(currentRow)
       if (!rowKey.anyNull) {
         val keyExists = hashSet.contains(rowKey)
@@ -76,6 +75,7 @@ trait HashSemiJoin {
         }
       }
     }
+
     hashSet
   }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -17,12 +17,11 @@

 package org.apache.spark.sql.execution.joins

-import java.io.{IOException, Externalizable, ObjectInput, ObjectOutput}
+import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
 import java.nio.ByteOrder
 import java.util.{HashMap => JavaHashMap}

 import org.apache.spark.shuffle.ShuffleMemoryManager
-import org.apache.spark.{SparkConf, SparkEnv, TaskContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkSqlSerializer
@@ -31,6 +30,7 @@ import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.{SparkConf, SparkEnv}


 /**
```

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala

Lines changed: 4 additions & 2 deletions
```diff
@@ -60,19 +60,21 @@ case class ShuffledHashOuterJoin(
       case LeftOuter =>
         val hashed = HashedRelation(rightIter, buildKeyGenerator)
         val keyGenerator = streamedKeyGenerator
+        val resultProj = resultProjection
         leftIter.flatMap( currentRow => {
           val rowKey = keyGenerator(currentRow)
           joinedRow.withLeft(currentRow)
-          leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey))
+          leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj)
         })

       case RightOuter =>
         val hashed = HashedRelation(leftIter, buildKeyGenerator)
         val keyGenerator = streamedKeyGenerator
+        val resultProj = resultProjection
         rightIter.flatMap ( currentRow => {
           val rowKey = keyGenerator(currentRow)
           joinedRow.withRight(currentRow)
-          rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow)
+          rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj)
         })

       case FullOuter =>
```
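Every call site above follows the same convention: build the projection once at the top of the partition, then pass it into the helper iterators rather than letting them read a shared member. A hedged sketch of that calling pattern (`Row` and all names here are placeholders, not Spark's API):

```scala
object PerPartitionProjectionSketch {
  type Row = Array[Any] // placeholder for Spark's InternalRow

  // The helper takes the projection as an explicit parameter, so the caller
  // decides which instance is used.
  def outerIterator(rows: Iterator[Row], resultProjection: Row => Row): Iterator[Row] =
    rows.map(resultProjection)

  def executePartition(rows: Iterator[Row], newProjection: () => Row => Row): Iterator[Row] = {
    val resultProj = newProjection() // created once per partition, never shared across tasks
    outerIterator(rows, resultProj)
  }
}
```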
