@@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions.{count, window}
 import org.apache.spark.sql.kafka010.KafkaSourceProvider._
 import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest}
+import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
 import org.apache.spark.util.Utils
 
@@ -47,21 +47,22 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock =
   extends TriggerExecutor with Logging {
 
   private val intervalMs = processingTime.intervalMs
+  require(intervalMs >= 0)
 
-  override def execute(batchRunner: () => Boolean): Unit = {
+  override def execute(triggerHandler: () => Boolean): Unit = {
     while (true) {
-      val batchStartTimeMs = clock.getTimeMillis()
-      val terminated = !batchRunner()
+      val triggerTimeMs = clock.getTimeMillis
+      val nextTriggerTimeMs = nextBatchTime(triggerTimeMs)
+      val terminated = !triggerHandler()
       if (intervalMs > 0) {
-        val batchEndTimeMs = clock.getTimeMillis()
-        val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs
+        val batchElapsedTimeMs = clock.getTimeMillis - triggerTimeMs
         if (batchElapsedTimeMs > intervalMs) {
           notifyBatchFallingBehind(batchElapsedTimeMs)
         }
         if (terminated) {
           return
         }
-        clock.waitTillTime(nextBatchTime(batchEndTimeMs))
+        clock.waitTillTime(nextTriggerTimeMs)
       } else {
         if (terminated) {
           return
@@ -70,7 +71,7 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock =
     }
   }
 
-  /** Called when a batch falls behind. Expose for test only */
+  /** Called when a batch falls behind */
  def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = {
    logWarning("Current batch is falling behind. The trigger interval is " +
      s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds")
@@ -83,6 +84,6 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock =
    * an interval of `100 ms`, `nextBatchTime(nextBatchTime(0)) = 200` rather than `0`).
    */
   def nextBatchTime(now: Long): Long = {
-    now / intervalMs * intervalMs + intervalMs
+    if (intervalMs == 0) now else now / intervalMs * intervalMs + intervalMs

Review comments on the changed line:

Contributor: the doc seems wrong btw, mind fixing it? nextBatchTime(nextBatchTime(0)) = 100, or am I understanding it wrong?

Contributor (author): spoken offline, this isn't wrong.

   }
 }
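
To spell out the arithmetic the thread above debates: with intervalMs = 100, integer division gives

    nextBatchTime(0)   = 0   / 100 * 100 + 100 = 100
    nextBatchTime(100) = 100 / 100 * 100 + 100 = 200
    nextBatchTime(150) = 150 / 100 * 100 + 100 = 200

so nextBatchTime(nextBatchTime(0)) = nextBatchTime(100) = 200, as the scaladoc claims; the same values are asserted by the nextBatchTime test below.
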
@@ -17,14 +17,24 @@
 
 package org.apache.spark.sql.execution.streaming
 
-import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.ConcurrentHashMap
 
+import scala.collection.mutable
+
+import org.eclipse.jetty.util.ConcurrentHashSet
+import org.scalatest.concurrent.Eventually
+import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.time.SpanSugar._
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.streaming.ProcessingTime
-import org.apache.spark.util.{Clock, ManualClock, SystemClock}
+import org.apache.spark.sql.streaming.util.StreamManualClock
 
 class ProcessingTimeExecutorSuite extends SparkFunSuite {
 
+  val timeout = 10.seconds
+
   test("nextBatchTime") {
     val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(100))
     assert(processingTimeExecutor.nextBatchTime(0) === 100)
@@ -35,6 +45,57 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite {
     assert(processingTimeExecutor.nextBatchTime(150) === 200)
   }
 
+  test("trigger timing") {
+    val triggerTimes = new ConcurrentHashSet[Int]
+    val clock = new StreamManualClock()
+    @volatile var continueExecuting = true
+    @volatile var clockIncrementInTrigger = 0L
+    val executor = ProcessingTimeExecutor(ProcessingTime("1000 milliseconds"), clock)
+    val executorThread = new Thread() {
+      override def run(): Unit = {
+        executor.execute(() => {
+          // Record the trigger time, increment the clock if needed, and report whether to continue
+          triggerTimes.add(clock.getTimeMillis.toInt)
+          clock.advance(clockIncrementInTrigger)
+          clockIncrementInTrigger = 0 // reset this so that there are no runaway triggers
+          continueExecuting
+        })
+      }
+    }
+    executorThread.start()
+    // First batch should execute immediately, then the executor should wait for the next one
+    eventually {
+      assert(triggerTimes.contains(0))
+      assert(clock.isStreamWaitingAt(0))
+      assert(clock.isStreamWaitingFor(1000))
+    }
+
+    // Second batch should execute when the clock reaches the next trigger time.
+    // If the trigger takes less than the trigger interval, the executor should wait for the next one
+    clockIncrementInTrigger = 500
+    clock.setTime(1000)
+    eventually {
+      assert(triggerTimes.contains(1000))
+      assert(clock.isStreamWaitingAt(1500))
+      assert(clock.isStreamWaitingFor(2000))
+    }
+
+    // If the trigger takes more than the trigger interval, the executor should immediately
+    // execute another one
+    clockIncrementInTrigger = 1500
+    clock.setTime(2000) // allow another trigger by setting clock to 2000
+    eventually {
+      // Since the next trigger will take 1500 ms (more than the trigger interval of 1000 ms),
+      // the executor will immediately execute another trigger
+      assert(triggerTimes.contains(2000) && triggerTimes.contains(3500))
+      assert(clock.isStreamWaitingAt(3500))
+      assert(clock.isStreamWaitingFor(4000))
+    }
+    continueExecuting = false
+    clock.advance(1000)
+    waitForThreadJoin(executorThread)
+  }
+
   test("calling nextBatchTime with the result of a previous call should return the next interval") {
     val intervalMS = 100
     val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMS))
@@ -54,7 +115,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite {
     val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMs))
     processingTimeExecutor.execute(() => {
       batchCounts += 1
-      // If the batch termination works well, batchCounts should be 3 after `execute`
+      // If the batch termination works correctly, batchCounts should be 3 after `execute`
       batchCounts < 3
     })
     assert(batchCounts === 3)
@@ -66,9 +127,8 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite {
   }
 
   test("notifyBatchFallingBehind") {
-    val clock = new ManualClock()
+    val clock = new StreamManualClock()
     @volatile var batchFallingBehindCalled = false
-    val latch = new CountDownLatch(1)
     val t = new Thread() {
       override def run(): Unit = {
         val processingTimeExecutor = new ProcessingTimeExecutor(ProcessingTime(100), clock) {
@@ -77,17 +137,24 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite {
           }
         }
         processingTimeExecutor.execute(() => {
-          latch.countDown()
          clock.waitTillTime(200)
          false
        })
      }
    }
    t.start()
    // Wait until the batch is running so that we don't call `advance` too early
-    assert(latch.await(10, TimeUnit.SECONDS), "the batch has not yet started in 10 seconds")
+    eventually { assert(clock.isStreamWaitingFor(200)) }
    clock.advance(200)
-    t.join()
+    waitForThreadJoin(t)
    assert(batchFallingBehindCalled === true)
  }
+
+  private def eventually(body: => Unit): Unit = {
+    Eventually.eventually(Timeout(timeout)) { body }
+  }
+
+  private def waitForThreadJoin(thread: Thread): Unit = {
+    failAfter(timeout) { thread.join() }
+  }
 }
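
The trigger timing test above pins down the core behavioral change of this patch: the next trigger time is computed from the trigger's start time, before the batch runs, so a batch that overruns the interval is followed immediately by another trigger. A minimal standalone sketch of that rule in plain Scala (no Spark dependencies; the object and method names here are illustrative, not from the patch):

object TriggerTimingSketch {
  val intervalMs = 1000L

  // Same arithmetic as ProcessingTimeExecutor.nextBatchTime above.
  def nextBatchTime(now: Long): Long =
    if (intervalMs == 0) now else now / intervalMs * intervalMs + intervalMs

  // The next trigger fires at the precomputed next trigger time, or immediately
  // if the batch finished after that time had already passed.
  def nextTriggerFiresAt(triggerTimeMs: Long, batchEndMs: Long): Long =
    math.max(nextBatchTime(triggerTimeMs), batchEndMs)

  def main(args: Array[String]): Unit = {
    // Batch at t = 1000 takes 500 ms: executor waits and triggers again at
    // t = 2000 (the test's second case).
    assert(nextTriggerFiresAt(1000L, 1500L) == 2000L)
    // Batch at t = 2000 takes 1500 ms: the wait target (3000) has already passed,
    // so the next trigger fires immediately at t = 3500 (the test's third case).
    assert(nextTriggerFiresAt(2000L, 3500L) == 3500L)
  }
}
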
@@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.ExistsThrowsExceptionFileSystem._
+import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -21,8 +21,6 @@ import java.sql.Date
 import java.util.concurrent.ConcurrentHashMap
 
 import org.scalatest.BeforeAndAfterAll
-import org.scalatest.concurrent.Eventually.eventually
-import org.scalatest.concurrent.PatienceConfiguration.Timeout
 
 import org.apache.spark.SparkException
 import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
@@ -35,6 +33,7 @@ import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
 import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate}
 import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore
+import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
 /** Class to check custom state types */
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.StreamSourceProvider
+import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
 
@@ -214,24 +214,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
     AssertOnQuery(query => { func(query); true })
   }
 
-  class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable {
-    private var waitStartTime: Option[Long] = None
-
-    override def waitTillTime(targetTime: Long): Long = synchronized {
-      try {
-        waitStartTime = Some(getTimeMillis())
-        super.waitTillTime(targetTime)
-      } finally {
-        waitStartTime = None
-      }
-    }
-
-    def isStreamWaitingAt(time: Long): Boolean = synchronized {
-      waitStartTime == Some(time)
-    }
-  }
-
-
   /**
    * Executes the specified actions on the given streaming DataFrame and provides helpful
    * error messages in the case of failures or incorrect answers.
@@ -242,6 +224,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
   def testStream(
       _stream: Dataset[_],
       outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = synchronized {
+    import org.apache.spark.sql.streaming.util.StreamManualClock
+
     // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently
     // because this method assumes there is only one active query in its `StreamingQueryListener`
     // and it may not work correctly when multiple `testStream`s run concurrently.
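
The hunk above removes StreamManualClock from StreamTest; the import org.apache.spark.sql.streaming.util.StreamManualClock lines added throughout this PR show where it moved. The relocated class itself is not shown in this diff. A plausible sketch, reconstructed from the removed code above plus the isStreamWaitingFor calls in the new tests (the waitTargetTime field and the body of isStreamWaitingFor are assumptions, not confirmed by this diff):

package org.apache.spark.sql.streaming.util

import org.apache.spark.util.ManualClock

class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable {
  private var waitStartTime: Option[Long] = None   // time at which the stream began waiting
  private var waitTargetTime: Option[Long] = None  // assumed: target the stream is waiting for

  override def waitTillTime(targetTime: Long): Long = synchronized {
    try {
      waitStartTime = Some(getTimeMillis())
      waitTargetTime = Some(targetTime)
      super.waitTillTime(targetTime)
    } finally {
      waitStartTime = None
      waitTargetTime = None
    }
  }

  // True iff the stream is currently blocked in waitTillTime, having started at `time`
  def isStreamWaitingAt(time: Long): Boolean = synchronized {
    waitStartTime == Some(time)
  }

  // True iff the stream is currently blocked waiting for the clock to reach `target`
  def isStreamWaitingFor(target: Long): Boolean = synchronized {
    waitTargetTime == Some(target)
  }
}
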
@@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.streaming.OutputMode._
+import org.apache.spark.sql.streaming.util.StreamManualClock
 
 object FailureSinglton {
   var firstTime = true
@@ -35,6 +35,7 @@ import org.apache.spark.sql.{Encoder, SparkSession}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamingQueryListener._
+import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.util.JsonProtocol
 
 class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {