
Commit 50f0195

Removed delay from trigger executor
1 parent e7877fd commit 50f0195

10 files changed: 134 additions (+), 37 deletions (−)


sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala

Lines changed: 9 additions & 8 deletions

@@ -47,21 +47,22 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock =
   extends TriggerExecutor with Logging {

   private val intervalMs = processingTime.intervalMs
+  require(intervalMs >= 0)

-  override def execute(batchRunner: () => Boolean): Unit = {
+  override def execute(triggerHandler: () => Boolean): Unit = {
     while (true) {
-      val batchStartTimeMs = clock.getTimeMillis()
-      val terminated = !batchRunner()
+      val triggerTimeMs = clock.getTimeMillis
+      val nextTriggerTimeMs = nextBatchTime(triggerTimeMs)
+      val terminated = !triggerHandler()
       if (intervalMs > 0) {
-        val batchEndTimeMs = clock.getTimeMillis()
-        val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs
+        val batchElapsedTimeMs = clock.getTimeMillis - triggerTimeMs
         if (batchElapsedTimeMs > intervalMs) {
           notifyBatchFallingBehind(batchElapsedTimeMs)
         }
         if (terminated) {
           return
         }
-        clock.waitTillTime(nextBatchTime(batchEndTimeMs))
+        clock.waitTillTime(nextTriggerTimeMs)
       } else {
         if (terminated) {
           return
@@ -70,7 +71,7 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock =
     }
   }

-  /** Called when a batch falls behind. Expose for test only */
+  /** Called when a batch falls behind */
   def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = {
     logWarning("Current batch is falling behind. The trigger interval is " +
       s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds")
@@ -83,6 +84,6 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock =
    * an interval of `100 ms`, `nextBatchTime(nextBatchTime(0)) = 200` rather than `0`).
    */
   def nextBatchTime(now: Long): Long = {
-    now / intervalMs * intervalMs + intervalMs
+    if (intervalMs == 0) now else now / intervalMs * intervalMs + intervalMs
  }
}
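
The behavioral fix is visible in `execute`: the next trigger time is now computed from the trigger's start time, before the batch runs, so a slow batch no longer pushes back when the following trigger fires. A minimal sketch of the schedule this arithmetic produces (a standalone reimplementation for illustration only; `NextBatchTimeSketch` and the 100 ms interval are hypothetical):

// Sketch: fixed-rate trigger boundaries, mirroring nextBatchTime above.
object NextBatchTimeSketch {
  def nextBatchTime(now: Long, intervalMs: Long): Long =
    if (intervalMs == 0) now else now / intervalMs * intervalMs + intervalMs

  def main(args: Array[String]): Unit = {
    val intervalMs = 100L
    assert(nextBatchTime(0, intervalMs) == 100)   // a trigger at t=0 schedules the next at t=100
    assert(nextBatchTime(105, intervalMs) == 200) // a batch running past t=100 skips to t=200
    assert(nextBatchTime(150, intervalMs) == 200) // boundaries stay aligned to the interval
    assert(nextBatchTime(42, 0) == 42)            // zero interval: trigger again immediately
  }
}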

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala

Lines changed: 71 additions & 8 deletions

@@ -17,14 +17,21 @@

 package org.apache.spark.sql.execution.streaming

-import java.util.concurrent.{CountDownLatch, TimeUnit}
+import scala.collection.mutable
+
+import org.scalatest.concurrent.Eventually
+import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.time.SpanSugar._

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.streaming.ProcessingTime
-import org.apache.spark.util.{Clock, ManualClock, SystemClock}
+import org.apache.spark.sql.streaming.util.StreamManualClock

 class ProcessingTimeExecutorSuite extends SparkFunSuite {

+  val timeout = 10.seconds
+
   test("nextBatchTime") {
     val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(100))
     assert(processingTimeExecutor.nextBatchTime(0) === 100)
@@ -35,6 +42,56 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite {
     assert(processingTimeExecutor.nextBatchTime(150) === 200)
   }

+  test("trigger timing") {
+    val executedTimes = new mutable.ArrayBuffer[Long]
+    val manualClock = new StreamManualClock()
+    @volatile var continueExecuting = true
+    @volatile var lastTriggerTime = -1L
+    @volatile var clockIncrementInTrigger = 0L
+    val executor = ProcessingTimeExecutor(ProcessingTime("1000 milliseconds"), manualClock)
+    val executorThread = new Thread() {
+      override def run(): Unit = {
+        executor.execute(() => {
+          // Record the trigger time, advance the clock if needed, and report whether to continue
+          lastTriggerTime = manualClock.getTimeMillis()
+          manualClock.advance(clockIncrementInTrigger)
+          clockIncrementInTrigger = 0 // reset this so that there are no runaway triggers
+          continueExecuting
+        })
+      }
+    }
+    executorThread.start()
+    // First batch should execute immediately, then executor should wait for next one
+    eventually {
+      assert(lastTriggerTime === 0)
+      assert(manualClock.isStreamWaitingAt(0))
+      assert(manualClock.isStreamWaitingFor(1000))
+    }
+
+    // Second batch should execute when clock reaches the next trigger time.
+    // If a trigger takes less than the trigger interval, executor should wait for the next one
+    clockIncrementInTrigger = 500
+    manualClock.setTime(1000)
+    eventually {
+      assert(lastTriggerTime === 1000)
+      assert(manualClock.isStreamWaitingAt(1500))
+      assert(manualClock.isStreamWaitingFor(2000))
+    }
+
+    // If a trigger takes more than the trigger interval, executor should immediately execute
+    // another one
+    clockIncrementInTrigger = 1500
+    manualClock.setTime(2000)
+    eventually {
+      assert(lastTriggerTime === 3500)
+      assert(manualClock.isStreamWaitingAt(3500))
+      assert(manualClock.isStreamWaitingFor(4000))
+    }
+    continueExecuting = false
+    manualClock.advance(1000)
+    waitForThreadJoin(executorThread)
+  }
+
   test("calling nextBatchTime with the result of a previous call should return the next interval") {
     val intervalMS = 100
     val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMS))
@@ -54,7 +111,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite {
     val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMs))
     processingTimeExecutor.execute(() => {
       batchCounts += 1
-      // If the batch termination works well, batchCounts should be 3 after `execute`
+      // If the batch termination works correctly, batchCounts should be 3 after `execute`
       batchCounts < 3
     })
     assert(batchCounts === 3)
@@ -66,9 +123,8 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite {
   }

   test("notifyBatchFallingBehind") {
-    val clock = new ManualClock()
+    val clock = new StreamManualClock()
     @volatile var batchFallingBehindCalled = false
-    val latch = new CountDownLatch(1)
     val t = new Thread() {
       override def run(): Unit = {
         val processingTimeExecutor = new ProcessingTimeExecutor(ProcessingTime(100), clock) {
@@ -77,17 +133,24 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite {
           }
         }
         processingTimeExecutor.execute(() => {
-          latch.countDown()
           clock.waitTillTime(200)
           false
         })
       }
     }
     t.start()
     // Wait until the batch is running so that we don't call `advance` too early
-    assert(latch.await(10, TimeUnit.SECONDS), "the batch has not yet started in 10 seconds")
+    eventually { assert(clock.isStreamWaitingFor(200)) }
     clock.advance(200)
-    t.join()
+    waitForThreadJoin(t)
     assert(batchFallingBehindCalled === true)
   }
+
+  private def eventually(body: => Unit): Unit = {
+    Eventually.eventually(Timeout(timeout)) { body }
+  }
+
+  private def waitForThreadJoin(thread: Thread): Unit = {
+    failAfter(timeout) { thread.join() }
+  }
 }

sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala

Lines changed: 1 addition & 0 deletions

@@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.ExistsThrowsExceptionFileSystem._
+import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils

sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala

Lines changed: 1 addition & 2 deletions

@@ -21,8 +21,6 @@ import java.sql.Date
 import java.util.concurrent.ConcurrentHashMap

 import org.scalatest.BeforeAndAfterAll
-import org.scalatest.concurrent.Eventually.eventually
-import org.scalatest.concurrent.PatienceConfiguration.Timeout

 import org.apache.spark.SparkException
 import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
@@ -35,6 +33,7 @@ import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
 import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate}
 import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore
+import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.{DataType, IntegerType}

 /** Class to check custom state types */

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala

Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.StreamSourceProvider
+import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala

Lines changed: 2 additions & 18 deletions

@@ -214,24 +214,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
     AssertOnQuery(query => { func(query); true })
   }

-  class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable {
-    private var waitStartTime: Option[Long] = None
-
-    override def waitTillTime(targetTime: Long): Long = synchronized {
-      try {
-        waitStartTime = Some(getTimeMillis())
-        super.waitTillTime(targetTime)
-      } finally {
-        waitStartTime = None
-      }
-    }
-
-    def isStreamWaitingAt(time: Long): Boolean = synchronized {
-      waitStartTime == Some(time)
-    }
-  }
-
   /**
    * Executes the specified actions on the given streaming DataFrame and provides helpful
    * error messages in the case of failures or incorrect answers.
@@ -242,6 +224,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
   def testStream(
       _stream: Dataset[_],
       outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = synchronized {
+    import org.apache.spark.sql.streaming.util.StreamManualClock
+
     // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently
     // because this method assumes there is only one active query in its `StreamingQueryListener`
     // and it may not work correctly when multiple `testStream`s run concurrently.

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.streaming.OutputMode._
+import org.apache.spark.sql.streaming.util.StreamManualClock

 object FailureSinglton {
   var firstTime = true

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@ import org.apache.spark.sql.{Encoder, SparkSession}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamingQueryListener._
+import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.util.JsonProtocol

 class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider}
+import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock}
 import org.apache.spark.util.ManualClock

sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming.util
+
+import org.apache.spark.util.ManualClock
+
+/** ManualClock used for streaming tests */
+class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable {
+  private var waitStartTime: Option[Long] = None
+  private var waitTargetTime: Option[Long] = None
+
+  override def waitTillTime(targetTime: Long): Long = synchronized {
+    try {
+      waitStartTime = Some(getTimeMillis())
+      waitTargetTime = Some(targetTime)
+      super.waitTillTime(targetTime)
+    } finally {
+      waitStartTime = None
+      waitTargetTime = None
+    }
+  }
+
+  def isStreamWaitingAt(time: Long): Boolean = synchronized {
+    waitStartTime == Some(time)
+  }
+
+  def isStreamWaitingFor(target: Long): Boolean = synchronized {
+    waitTargetTime == Some(target)
+  }
+}
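
The new clock lets a test observe exactly where a blocked stream thread is waiting: `isStreamWaitingAt` reports the time at which the wait began, `isStreamWaitingFor` the target time it is waiting for. A usage sketch (the waiter thread and polling loop are illustrative, not taken from the commit; tests above use scalatest's `eventually` instead):

// Illustrative: drive a waiting thread deterministically with StreamManualClock.
import org.apache.spark.sql.streaming.util.StreamManualClock

object StreamManualClockSketch {
  def main(args: Array[String]): Unit = {
    val clock = new StreamManualClock() // starts at t = 0
    val waiter = new Thread() {
      override def run(): Unit = clock.waitTillTime(200) // blocks until the clock reaches 200
    }
    waiter.start()
    // Poll until the waiter is actually blocked, then inspect its wait state.
    while (!clock.isStreamWaitingFor(200)) Thread.sleep(10)
    assert(clock.isStreamWaitingAt(0)) // the wait began when the clock read 0
    clock.advance(200)                 // move the clock to 200; releases the waiter
    waiter.join()
  }
}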
