
Commit 617ce3b

[SPARK-18758][SS] StreamingQueryListener events from a StreamingQuery should be sent only to the listeners in the same session as the query
## What changes were proposed in this pull request?

Listeners added with `sparkSession.streams.addListener(l)` are added to a SparkSession, so only events from queries in the same session as a listener should be posted to that listener. Currently, all events get rerouted through Spark's main listener bus, that is:

- A StreamingQuery posts its event to the StreamingQueryListenerBus. Only queries associated with the same session as the bus post events to it.
- The StreamingQueryListenerBus posts the event to Spark's main LiveListenerBus as a SparkEvent.
- The StreamingQueryListenerBus also subscribes to LiveListenerBus events, thus getting the posted event back on a different thread.
- The received event is posted to the registered listeners.

The problem is that *all StreamingQueryListenerBuses in all sessions* get the events and post them to their listeners. This is wrong.

In this PR, I solve it by making StreamingQueryListenerBus track active queries (by their runIds) when a query posts the QueryStarted event to the bus. This allows the rerouted events to be filtered using the tracked queries. Note that this list needs to be maintained separately from `StreamingQueryManager.activeQueries`, because a terminated query is cleared from `StreamingQueryManager.activeQueries` as soon as it is stopped, whereas this ListenerBus must clear a query only after the termination event of that query has been posted, which happens lazily, well after the query has terminated. A sketch of this runId-tracking pattern is shown after the commit metadata below.

Credit goes to zsxwing for coming up with the initial idea.

## How was this patch tested?

Updated the test harness code to use the correct session, and added a new unit test.

Author: Tathagata Das <[email protected]>

Closes #16186 from tdas/SPARK-18758.

(cherry picked from commit 9ab725e)
Signed-off-by: Tathagata Das <[email protected]>
1 parent 839c2eb commit 617ce3b
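
As a minimal, self-contained sketch of the runId-tracking pattern described above (not the actual Spark class; the class and method names here are illustrative):

```scala
import java.util.UUID
import scala.collection.mutable

// Illustrative stand-in for the bookkeeping StreamingQueryListenerBus now
// performs: record a runId when a query starts, forward only events whose
// runId was recorded here, and drop the runId only once the termination
// event has been seen.
class RunIdScopedBus {
  private val activeQueryRunIds = new mutable.HashSet[UUID]

  def queryStarted(runId: UUID): Unit =
    activeQueryRunIds.synchronized { activeQueryRunIds += runId }

  def shouldReport(runId: UUID): Boolean =
    activeQueryRunIds.synchronized { activeQueryRunIds.contains(runId) }

  def queryTerminated(runId: UUID): Unit =
    activeQueryRunIds.synchronized { activeQueryRunIds -= runId }
}
```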

File tree

4 files changed: +119 −23 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala

Lines changed: 47 additions & 7 deletions

@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.util.UUID
+
+import scala.collection.mutable
+
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent}
 import org.apache.spark.sql.streaming.StreamingQueryListener
 import org.apache.spark.util.ListenerBus
@@ -25,7 +29,11 @@ import org.apache.spark.util.ListenerBus
  * A bus to forward events to [[StreamingQueryListener]]s. This one will send received
  * [[StreamingQueryListener.Event]]s to the Spark listener bus. It also registers itself with
  * Spark listener bus, so that it can receive [[StreamingQueryListener.Event]]s and dispatch them
- * to StreamingQueryListener.
+ * to StreamingQueryListeners.
+ *
+ * Note that each bus and its registered listeners are associated with a single SparkSession
+ * and StreamingQueryManager. So this bus will dispatch events to registered listeners for only
+ * those queries that were started in the associated SparkSession.
  */
 class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
   extends SparkListener with ListenerBus[StreamingQueryListener, StreamingQueryListener.Event] {
@@ -35,12 +43,30 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
   sparkListenerBus.addListener(this)
 
   /**
-   * Post a StreamingQueryListener event to the Spark listener bus asynchronously. This event will
-   * be dispatched to all StreamingQueryListener in the thread of the Spark listener bus.
+   * RunIds of active queries whose events are supposed to be forwarded by this ListenerBus
+   * to registered `StreamingQueryListeners`.
+   *
+   * Note 1: We need to track runIds instead of ids because the runId is unique for every started
+   * query, even if it is a restart. So even if a query is restarted, this bus will identify them
+   * separately and correctly account for the restart.
+   *
+   * Note 2: This list needs to be maintained separately from
+   * `StreamingQueryManager.activeQueries` because a terminated query is cleared from
+   * `StreamingQueryManager.activeQueries` as soon as it is stopped, but this ListenerBus
+   * must clear a query only after the termination event of that query has been posted.
+   */
+  private val activeQueryRunIds = new mutable.HashSet[UUID]
+
+  /**
+   * Post a StreamingQueryListener event to the added StreamingQueryListeners.
+   * Note that only the QueryStarted event is posted to the listeners synchronously. Other events
+   * are dispatched to the Spark listener bus. This method is guaranteed to be called by queries
+   * in the same SparkSession as this listener.
    */
   def post(event: StreamingQueryListener.Event) {
     event match {
       case s: QueryStartedEvent =>
+        activeQueryRunIds.synchronized { activeQueryRunIds += s.runId }
         sparkListenerBus.post(s)
         // post to local listeners to trigger callbacks
         postToAll(s)
@@ -63,18 +89,32 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
     }
   }
 
+  /**
+   * Dispatch events to registered StreamingQueryListeners. Only events from queries started in
+   * the same SparkSession as this ListenerBus will be dispatched to the listeners.
+   */
   override protected def doPostEvent(
       listener: StreamingQueryListener,
       event: StreamingQueryListener.Event): Unit = {
+    def shouldReport(runId: UUID): Boolean = {
+      activeQueryRunIds.synchronized { activeQueryRunIds.contains(runId) }
+    }
+
     event match {
       case queryStarted: QueryStartedEvent =>
-        listener.onQueryStarted(queryStarted)
+        if (shouldReport(queryStarted.runId)) {
+          listener.onQueryStarted(queryStarted)
+        }
       case queryProgress: QueryProgressEvent =>
-        listener.onQueryProgress(queryProgress)
+        if (shouldReport(queryProgress.progress.runId)) {
+          listener.onQueryProgress(queryProgress)
+        }
       case queryTerminated: QueryTerminatedEvent =>
-        listener.onQueryTerminated(queryTerminated)
+        if (shouldReport(queryTerminated.runId)) {
+          listener.onQueryTerminated(queryTerminated)
+          activeQueryRunIds.synchronized { activeQueryRunIds -= queryTerminated.runId }
+        }
       case _ =>
     }
   }
-
 }
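
With this change, a listener observes only the queries of its own session. A minimal usage sketch, assuming a local session (the master URL, app name, and println bodies are illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

val spark = SparkSession.builder().master("local[2]").appName("listener-demo").getOrCreate()

// Registered on `spark`, so after this patch the callbacks fire only for
// queries started through spark.streams, never for other sessions' queries.
spark.streams.addListener(new StreamingQueryListener {
  override def onQueryStarted(event: QueryStartedEvent): Unit =
    println(s"started: ${event.runId}")
  override def onQueryProgress(event: QueryProgressEvent): Unit =
    println(s"progress: ${event.progress.runId}")
  override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
    println(s"terminated: ${event.runId}")
})
```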

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala

Lines changed: 2 additions & 2 deletions

@@ -70,11 +70,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 
   def schema: StructType = encoder.schema
 
-  def toDS()(implicit sqlContext: SQLContext): Dataset[A] = {
+  def toDS(): Dataset[A] = {
     Dataset(sqlContext.sparkSession, logicalPlan)
   }
 
-  def toDF()(implicit sqlContext: SQLContext): DataFrame = {
+  def toDF(): DataFrame = {
     Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
   }
 
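
Removing the implicit SQLContext parameter means `toDS()`/`toDF()` always bind to the session the stream was created with, rather than to whatever implicit happens to be in scope at the call site. A minimal sketch, assuming an existing session named `spark`:

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.MemoryStream

implicit val sqlContext: SQLContext = spark.sqlContext // MemoryStream.apply still takes this
import spark.implicits._                               // supplies Encoder[Int]

val stream = MemoryStream[Int]
stream.addData(1, 2, 3)
val ds = stream.toDS() // no implicit at the call site; uses the stream's own session
```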

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala

Lines changed: 8 additions & 7 deletions

@@ -231,8 +231,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
       outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = {
 
     val stream = _stream.toDF()
+    val sparkSession = stream.sparkSession // use the session in DF, not the default session
     var pos = 0
-    var currentPlan: LogicalPlan = stream.logicalPlan
     var currentStream: StreamExecution = null
     var lastStream: StreamExecution = null
     val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for
@@ -319,7 +319,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
           """.stripMargin)
     }
 
-    val testThread = Thread.currentThread()
     val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
     var manualClockExpectedTime = -1L
     try {
@@ -337,14 +336,16 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
 
       additionalConfs.foreach(pair => {
         val value =
-          if (spark.conf.contains(pair._1)) Some(spark.conf.get(pair._1)) else None
+          if (sparkSession.conf.contains(pair._1)) {
+            Some(sparkSession.conf.get(pair._1))
+          } else None
         resetConfValues(pair._1) = value
-        spark.conf.set(pair._1, pair._2)
+        sparkSession.conf.set(pair._1, pair._2)
       })
 
       lastStream = currentStream
       currentStream =
-        spark
+        sparkSession
           .streams
           .startQuery(
             None,
@@ -518,8 +519,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
 
       // Rollback prev configuration values
       resetConfValues.foreach {
-        case (key, Some(value)) => spark.conf.set(key, value)
-        case (key, None) => spark.conf.unset(key)
+        case (key, Some(value)) => sparkSession.conf.set(key, value)
+        case (key, None) => sparkSession.conf.unset(key)
       }
     }
   }

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala

Lines changed: 62 additions & 7 deletions

@@ -20,6 +20,7 @@ package org.apache.spark.sql.streaming
 import java.util.UUID
 
 import scala.collection.mutable
+import scala.concurrent.duration._
 
 import org.scalactic.TolerantNumerics
 import org.scalatest.concurrent.AsyncAssertions.Waiter
@@ -30,6 +31,7 @@ import org.scalatest.PrivateMethodTester._
 
 import org.apache.spark.SparkException
 import org.apache.spark.scheduler._
+import org.apache.spark.sql.{Encoder, SparkSession}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamingQueryListener._
@@ -45,7 +47,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
   after {
     spark.streams.active.foreach(_.stop())
     assert(spark.streams.active.isEmpty)
-    assert(addedListeners.isEmpty)
+    assert(addedListeners().isEmpty)
     // Make sure we don't leak any events to the next test
     spark.sparkContext.listenerBus.waitUntilEmpty(10000)
   }
@@ -148,7 +150,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
       assert(isListenerActive(listener1) === false)
       assert(isListenerActive(listener2) === true)
     } finally {
-      addedListeners.foreach(spark.streams.removeListener)
+      addedListeners().foreach(spark.streams.removeListener)
     }
   }
 
@@ -251,6 +253,57 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
     }
   }
 
+  test("listener only posts events from queries started in the related sessions") {
+    val session1 = spark.newSession()
+    val session2 = spark.newSession()
+    val collector1 = new EventCollector
+    val collector2 = new EventCollector
+
+    def runQuery(session: SparkSession): Unit = {
+      collector1.reset()
+      collector2.reset()
+      val mem = MemoryStream[Int](implicitly[Encoder[Int]], session.sqlContext)
+      testStream(mem.toDS)(
+        AddData(mem, 1, 2, 3),
+        CheckAnswer(1, 2, 3)
+      )
+      session.sparkContext.listenerBus.waitUntilEmpty(5000)
+    }
+
+    def assertEventsCollected(collector: EventCollector): Unit = {
+      assert(collector.startEvent !== null)
+      assert(collector.progressEvents.nonEmpty)
+      assert(collector.terminationEvent !== null)
+    }
+
+    def assertEventsNotCollected(collector: EventCollector): Unit = {
+      assert(collector.startEvent === null)
+      assert(collector.progressEvents.isEmpty)
+      assert(collector.terminationEvent === null)
+    }
+
+    assert(session1.ne(session2))
+    assert(session1.streams.ne(session2.streams))
+
+    withListenerAdded(collector1, session1) {
+      assert(addedListeners(session1).nonEmpty)
+
+      withListenerAdded(collector2, session2) {
+        assert(addedListeners(session2).nonEmpty)
+
+        // query on session1 should send events only to collector1
+        runQuery(session1)
+        assertEventsCollected(collector1)
+        assertEventsNotCollected(collector2)
+
+        // query on session2 should send events only to collector2
+        runQuery(session2)
+        assertEventsCollected(collector2)
+        assertEventsNotCollected(collector1)
+      }
+    }
+  }
+
   testQuietly("ReplayListenerBus should ignore broken event jsons generated in 2.0.0") {
     // query-event-logs-version-2.0.0.txt has all types of events generated by
     // Structured Streaming in Spark 2.0.0.
@@ -298,21 +351,23 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
     }
   }
 
-  private def withListenerAdded(listener: StreamingQueryListener)(body: => Unit): Unit = {
+  private def withListenerAdded(
+      listener: StreamingQueryListener,
+      session: SparkSession = spark)(body: => Unit): Unit = {
     try {
       failAfter(streamingTimeout) {
-        spark.streams.addListener(listener)
+        session.streams.addListener(listener)
         body
       }
     } finally {
-      spark.streams.removeListener(listener)
+      session.streams.removeListener(listener)
     }
   }
 
-  private def addedListeners(): Array[StreamingQueryListener] = {
+  private def addedListeners(session: SparkSession = spark): Array[StreamingQueryListener] = {
     val listenerBusMethod =
       PrivateMethod[StreamingQueryListenerBus]('listenerBus)
-    val listenerBus = spark.streams invokePrivate listenerBusMethod()
+    val listenerBus = session.streams invokePrivate listenerBusMethod()
     listenerBus.listeners.toArray.map(_.asInstanceOf[StreamingQueryListener])
   }
 
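
The new test relies on an EventCollector helper that is defined elsewhere in this suite and not shown in this diff. A sketch of the shape the test assumes (an assumption about the helper, not the suite's actual code):

```scala
import scala.collection.mutable
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

// Hypothetical collector: records the start, progress, and termination
// events it receives so the test can assert on them afterwards.
class EventCollector extends StreamingQueryListener {
  @volatile var startEvent: QueryStartedEvent = _
  @volatile var terminationEvent: QueryTerminatedEvent = _
  private val _progressEvents = new mutable.ArrayBuffer[QueryProgressEvent]

  def progressEvents: Seq[QueryProgressEvent] =
    _progressEvents.synchronized { _progressEvents.toList }

  def reset(): Unit = {
    startEvent = null
    terminationEvent = null
    _progressEvents.synchronized { _progressEvents.clear() }
  }

  override def onQueryStarted(event: QueryStartedEvent): Unit = startEvent = event
  override def onQueryProgress(event: QueryProgressEvent): Unit =
    _progressEvents.synchronized { _progressEvents += event }
  override def onQueryTerminated(event: QueryTerminatedEvent): Unit = terminationEvent = event
}
```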
