diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala
index 4f124a1356b5a..85c4043495da0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala
@@ -39,6 +39,12 @@ case class StreamInputInfo(
 
   def metadataDescription: Option[String] =
     metadata.get(StreamInputInfo.METADATA_KEY_DESCRIPTION).map(_.toString)
+
+  def merge(other: StreamInputInfo): StreamInputInfo = {
+    require(other.inputStreamId == inputStreamId,
+      "Cannot merge two StreamInputInfos with different inputStreamId")
+    StreamInputInfo(inputStreamId, numRecords + other.numRecords, metadata ++ other.metadata)
+  }
 }
 
 @DeveloperApi
@@ -79,6 +85,28 @@ private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging
     inputInfos.map(_.toMap).getOrElse(Map[Int, StreamInputInfo]())
   }
 
+  /**
+   * Get the input information of all the input streams for the specified batch times and
+   * merge the results together.
+   */
+  def getInfo(batchTimes: Iterable[Time]): Map[Int, StreamInputInfo] = synchronized {
+    val perBatchInputInfos = batchTimes.map { batchTime =>
+      val inputInfos = batchTimeToInputInfos.get(batchTime)
+      inputInfos.getOrElse(mutable.Map[Int, StreamInputInfo]())
+    }
+
+    val aggregatedInputInfos = mutable.Map[Int, StreamInputInfo]()
+    perBatchInputInfos.foreach(inputInfos => inputInfos.foreach { case (id, info) =>
+      val currentInfo = aggregatedInputInfos.get(id)
+      if (currentInfo.isEmpty) {
+        aggregatedInputInfos(id) = info
+      } else {
+        aggregatedInputInfos(id) = currentInfo.get.merge(info)
+      }
+    })
+    aggregatedInputInfos.toMap
+  }
+
   /** Cleanup the tracked input information older than threshold batch time */
   def cleanup(batchThreshTime: Time): Unit = synchronized {
     val timesToCleanup = batchTimeToInputInfos.keys.filter(_ < batchThreshTime)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 8f9421fc098ba..8c6816071a05b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.streaming.scheduler
 
+import scala.collection.mutable
 import scala.util.{Failure, Success, Try}
 
 import org.apache.spark.SparkEnv
@@ -77,6 +78,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
   // last batch whose completion,checkpointing and metadata cleanup has been completed
   private var lastProcessedBatch: Time = null
 
+  // For some batch times a JobSet with no jobs will be submitted. We record such batch times
+  // here so that their input info can be folded into the next JobSet that does have jobs.
+  private val batchTimesWithNoJob = mutable.HashSet[Time]()
+
   /** Start generation of jobs */
   def start(): Unit = synchronized {
     if (eventLoop != null) return // generator has already been started
@@ -249,7 +254,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
       graph.generateJobs(time) // generate jobs using allocated block
     } match {
       case Success(jobs) =>
-        val streamIdToInputInfos = jobScheduler.inputInfoTracker.getInfo(time)
+        val streamIdToInputInfos = if (jobs.isEmpty) {
+          batchTimesWithNoJob.add(time)
+          Map.empty[Int, StreamInputInfo]
+        } else {
+          // Also include the current batch time, so its own input info is merged in.
+          batchTimesWithNoJob.add(time)
+          val inputInfo = jobScheduler.inputInfoTracker.getInfo(batchTimesWithNoJob)
+          batchTimesWithNoJob.clear()
+          inputInfo
+        }
         jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToInputInfos))
       case Failure(e) =>
         jobScheduler.reportError("Error generating jobs for time " + time, e)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
index 0baedaf275d67..351f3a3fa030e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
@@ -50,7 +50,7 @@ case class JobSet(
 
   def hasStarted: Boolean = processingStartTime > 0
 
-  def hasCompleted: Boolean = incompleteJobs.isEmpty
+  def hasCompleted: Boolean = incompleteJobs.isEmpty && processingStartTime >= 0
 
   // Time taken to process all the jobs from the time they started processing
   // (i.e. not including the time they wait in the streaming scheduler queue)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
index a7e365649d3e8..e6cee5cfa121e 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
@@ -76,4 +76,34 @@ class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter {
     assert(inputInfoTracker.getInfo(Time(0)).get(streamId1) === None)
     assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2)
   }
+
+  test("merge two StreamInputInfos") {
+    val inputInfo1_1 = StreamInputInfo(1, 100L, Map("ID" -> 1))
+    val inputInfo1_2 = StreamInputInfo(1, 200L, Map("ID" -> 1))
+    val inputInfo2 = StreamInputInfo(2, 200L, Map("ID" -> 2))
+
+    val mergedInfo = inputInfo1_1.merge(inputInfo1_2)
+    assert(mergedInfo.inputStreamId == 1)
+    assert(mergedInfo.numRecords == 300L)
+    assert(mergedInfo.metadata == Map("ID" -> 1))
+
+    intercept[IllegalArgumentException] {
+      inputInfo1_1.merge(inputInfo2)
+    }
+  }
+
+  test("get merged InputInfo of all specified batch times") {
+    val inputInfoTracker = new InputInfoTracker(ssc)
+
+    val streamId1 = 0
+    val inputInfo1 = StreamInputInfo(streamId1, 100L)
+    val inputInfo2 = StreamInputInfo(streamId1, 300L)
+    inputInfoTracker.reportInfo(Time(0), inputInfo1)
+    inputInfoTracker.reportInfo(Time(1), inputInfo2)
+
+    val times = Seq(Time(0), Time(1))
+    val mergedInfo = inputInfoTracker.getInfo(times)(streamId1)
+    assert(mergedInfo.inputStreamId == 0)
+    assert(mergedInfo.numRecords == 400L)
+  }
 }
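
As a quick illustration of the merge semantics introduced above, here is a minimal, self-contained sketch in plain Scala (no Spark dependency; `InfoSketch`, `aggregate`, and `MergeAcrossBatches` are hypothetical names for this sketch, not Spark API). It shows how the input info recorded for batches that produced no jobs folds into the next batch that does have jobs: per stream id, record counts sum, and metadata maps union with later entries winning on key clashes.

```scala
// Hypothetical stand-in mirroring StreamInputInfo.merge: ids must match,
// record counts add, metadata maps union (later entries win on key clashes).
case class InfoSketch(streamId: Int, numRecords: Long, metadata: Map[String, Any] = Map.empty) {
  def merge(other: InfoSketch): InfoSketch = {
    require(other.streamId == streamId, "Cannot merge infos with different streamId")
    InfoSketch(streamId, numRecords + other.numRecords, metadata ++ other.metadata)
  }
}

object MergeAcrossBatches {
  // Fold per-stream infos across several batch times, as getInfo(batchTimes) does.
  def aggregate(batches: Iterable[Map[Int, InfoSketch]]): Map[Int, InfoSketch] =
    batches.foldLeft(Map.empty[Int, InfoSketch]) { (acc, batch) =>
      batch.foldLeft(acc) { case (m, (id, info)) =>
        m.updated(id, m.get(id).map(_.merge(info)).getOrElse(info))
      }
    }

  def main(args: Array[String]): Unit = {
    // Two batches produced no jobs; their input info is carried forward and
    // merged into the next batch that does have jobs.
    val emptyBatch1 = Map(0 -> InfoSketch(0, 100L))
    val emptyBatch2 = Map(0 -> InfoSketch(0, 50L))
    val currentBatch = Map(0 -> InfoSketch(0, 300L), 1 -> InfoSketch(1, 10L))

    val merged = aggregate(Seq(emptyBatch1, emptyBatch2, currentBatch))
    assert(merged(0).numRecords == 450L) // 100 + 50 + 300 for stream 0
    assert(merged(1).numRecords == 10L)  // stream 1 only appears in the current batch
    println(merged)
  }
}
```

Note that only `numRecords` is additive; for metadata, a key present in several batches keeps the value from the batch merged last, which matches the `metadata ++ other.metadata` behavior in the patch.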