@@ -68,13 +68,9 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
     require(StreamingContext.getActive().isEmpty,
       "Cannot run test with already active streaming context")
 
-    // Current code assumes that:
-    // number of inputs = number of outputs = number of batches to be run
+    // Current code assumes that number of batches to be run = number of inputs
     val totalNumBatches = input.size
-    val nextNumBatches = totalNumBatches - numBatchesBeforeRestart
-    val initialNumExpectedOutputs = numBatchesBeforeRestart
-    val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1
-    // because the last batch will be processed again
+    val batchDurationMillis = batchDuration.milliseconds
 
     // Setup the stream computation
     val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString
@@ -92,20 +88,20 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
     ssc.checkpoint(checkpointDir)
 
     // Do the computation for initial number of batches, create checkpoint file and quit
-    generateAndAssertOutput[V](ssc, batchDuration, checkpointDir, numBatchesBeforeRestart,
-      expectedOutput.take(numBatchesBeforeRestart), stopSparkContextAfterTest)
-
+    val beforeRestartOutput = generateOutput[V](ssc,
+      Time(batchDurationMillis * numBatchesBeforeRestart), checkpointDir, stopSparkContextAfterTest)
+    assertOutput(beforeRestartOutput, expectedOutput, beforeRestart = true)
     // Restart and complete the computation from checkpoint file
-    // scalastyle:off println
-    print(
+    logInfo(
       "\n-------------------------------------------\n" +
       "        Restarting stream computation          " +
       "\n-------------------------------------------\n"
     )
-    // scalastyle:on println
+
     val restartedSsc = new StreamingContext(checkpointDir)
-    generateAndAssertOutput[V](restartedSsc, batchDuration, checkpointDir, nextNumBatches,
-      expectedOutput.takeRight(nextNumExpectedOutputs), stopSparkContextAfterTest)
+    val afterRestartOutput = generateOutput[V](restartedSsc,
+      Time(batchDurationMillis * totalNumBatches), checkpointDir, stopSparkContextAfterTest)
+    assertOutput(afterRestartOutput, expectedOutput, beforeRestart = false)
   }
 
   protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = {
@@ -114,32 +110,30 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
     new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
   }
 
-  private def generateAndAssertOutput[V: ClassTag](
+  private def generateOutput[V: ClassTag](
       ssc: StreamingContext,
-      batchDuration: Duration,
+      targetBatchTime: Time,
       checkpointDir: String,
-      numBatchesToRun: Int,
-      expectedOutput: Seq[Seq[V]],
       stopSparkContext: Boolean
-    ) {
+    ): Seq[Seq[V]] = {
     try {
+      val batchDuration = ssc.graph.batchDuration
       val batchCounter = new BatchCounter(ssc)
       ssc.start()
-      val numBatches = expectedOutput.size
       val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
-      // scalastyle:off println
+      val currentTime = clock.getTimeMillis()
+
       logInfo("Manual clock before advancing = " + clock.getTimeMillis())
-      clock.advance((batchDuration * numBatches).milliseconds)
+      clock.setTime(targetBatchTime.milliseconds)
       logInfo("Manual clock after advancing = " + clock.getTimeMillis())
-      // scalastyle:on println
 
       val outputStream = ssc.graph.getOutputStreams().filter { dstream =>
         dstream.isInstanceOf[TestOutputStreamWithPartitions[V]]
       }.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
 
       eventually(timeout(10 seconds)) {
         ssc.awaitTerminationOrTimeout(10)
-        assert(batchCounter.getNumCompletedBatches === numBatchesToRun)
+        assert(batchCounter.getLastCompletedBatchTime === targetBatchTime)
       }
@@ -150,17 +144,30 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
       // are written to make sure that both of them have been written.
       assert(checkpointFilesOfLatestTime.size === 2)
     }
+      outputStream.output.map(_.flatten)
 
-      val output = outputStream.output.map(_.flatten)
-      val setComparison = output.zip(expectedOutput).forall { case (o, e) => o.toSet === e.toSet }
-      assert(setComparison, s"set comparison failed\n" +
-        s"Expected output (${expectedOutput.size} items):\n${expectedOutput.mkString("\n")}\n" +
-        s"Generated output (${output.size} items): ${output.mkString("\n")}"
-      )
     } finally {
       ssc.stop(stopSparkContext = stopSparkContext)
     }
   }
+
+  private def assertOutput[V: ClassTag](
+      output: Seq[Seq[V]],
+      expectedOutput: Seq[Seq[V]],
+      beforeRestart: Boolean): Unit = {
+    val expectedPartialOutput = if (beforeRestart) {
+      expectedOutput.take(output.size)
+    } else {
+      expectedOutput.takeRight(output.size)
+    }
+    val setComparison = output.zip(expectedPartialOutput).forall {
+      case (o, e) => o.toSet === e.toSet
+    }
+    assert(setComparison, s"set comparison failed\n" +
+      s"Expected output items:\n${expectedPartialOutput.mkString("\n")}\n" +
+      s"Generated output items: ${output.mkString("\n")}"
+    )
+  }
 }
 
 /**
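
Two aspects of this refactoring are easy to miss when reading the diff alone. First, the take/takeRight split in the new assertOutput helper encodes the fact stated in the removed comment ("because the last batch will be processed again"): after recovering from a checkpoint, the last pre-restart batch is reprocessed, so the after-restart output has to be compared against a right-aligned slice of the full expected output. Here is a minimal standalone sketch of that alignment; it is not part of the commit, and the object name and sample values are invented for illustration:

    object OutputAlignmentSketch {
      def main(args: Array[String]): Unit = {
        // Full expected output for five batches (made-up values).
        val expectedOutput = Seq(Seq(1), Seq(2), Seq(3), Seq(4), Seq(5))

        // Before the restart only the first three batches have run, so the
        // generated output is compared against a left-aligned slice.
        val beforeRestartOutput = Seq(Seq(1), Seq(2), Seq(3))
        assert(beforeRestartOutput == expectedOutput.take(beforeRestartOutput.size))

        // After the restart the last pre-restart batch (batch 3) is processed
        // again, so three output batches cover expected batches 3, 4 and 5:
        // the rightmost slice of expectedOutput.
        val afterRestartOutput = Seq(Seq(3), Seq(4), Seq(5))
        assert(afterRestartOutput == expectedOutput.takeRight(afterRestartOutput.size))
      }
    }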
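Second, replacing clock.advance((batchDuration * numBatches).milliseconds) with clock.setTime(targetBatchTime.milliseconds) is what makes the nextNumBatches bookkeeping unnecessary: an absolute target time works no matter where the restored clock currently sits. A toy sketch of that difference, with ToyManualClock as a hypothetical stand-in for Spark's ManualClock and made-up timings:

    object ManualClockSketch {
      // Hypothetical stand-in for org.apache.spark.util.ManualClock, just to
      // contrast advance() with setTime().
      final class ToyManualClock(private var now: Long) {
        def getTimeMillis(): Long = now
        def advance(deltaMillis: Long): Unit = { now += deltaMillis }
        def setTime(timeMillis: Long): Unit = { now = timeMillis }
      }

      def main(args: Array[String]): Unit = {
        val batchDurationMillis = 500L
        val totalNumBatches = 7

        // After recovery the clock does not start from zero; suppose it was
        // restored at the time of the last checkpointed batch.
        val clock = new ToyManualClock(2500L)

        // advance() would force the caller to compute how many batches remain
        // (the old nextNumBatches arithmetic); setTime() needs only the
        // absolute target batch time, wherever the clock currently is.
        clock.setTime(batchDurationMillis * totalNumBatches)
        assert(clock.getTimeMillis() == 3500L)
      }
    }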