Commit 50cf3e6

Fix ForeachSink with watermark + append
1 parent 5a92dc7 commit 50cf3e6

2 files changed: +62, -34 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala

Lines changed: 12 additions & 33 deletions
@@ -18,9 +18,8 @@
 package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.TaskContext
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Encoder, ForeachWriter}
-import org.apache.spark.sql.catalyst.plans.logical.CatalystSerde
+import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 
 /**
  * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
@@ -32,46 +31,26 @@ import org.apache.spark.sql.catalyst.plans.logical.CatalystSerde
 class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {
 
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
-    // TODO: Refine this method when SPARK-16264 is resolved; see comments below.
-
     // This logic should've been as simple as:
     // ```
     //   data.as[T].foreachPartition { iter => ... }
     // ```
     //
     // Unfortunately, doing that would just break the incremental planing. The reason is,
-    // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` just
-    // does not support `IncrementalExecution`.
+    // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will
+    // create a new plan. Because StreamExecution uses the existing plan to collect metrics and
+    // update watermark, we should never create a new plan. Otherwise, metrics and watermark are
+    // updated in the new plan, and StreamExecution cannot retrieval them.
     //
-    // So as a provisional fix, below we've made a special version of `Dataset` with its `rdd()`
-    // method supporting incremental planning. But in the long run, we should generally make newly
-    // created Datasets use `IncrementalExecution` where necessary (which is SPARK-16264 tries to
-    // resolve).
-    val incrementalExecution = data.queryExecution.asInstanceOf[IncrementalExecution]
-    val datasetWithIncrementalExecution =
-      new Dataset(data.sparkSession, incrementalExecution, implicitly[Encoder[T]]) {
-        override lazy val rdd: RDD[T] = {
-          val objectType = exprEnc.deserializer.dataType
-          val deserialized = CatalystSerde.deserialize[T](logicalPlan)
-
-          // was originally: sparkSession.sessionState.executePlan(deserialized) ...
-          val newIncrementalExecution = new IncrementalExecution(
-            this.sparkSession,
-            deserialized,
-            incrementalExecution.outputMode,
-            incrementalExecution.checkpointLocation,
-            incrementalExecution.currentBatchId,
-            incrementalExecution.currentEventTimeWatermark)
-          newIncrementalExecution.toRdd.mapPartitions { rows =>
-            rows.map(_.get(0, objectType))
-          }.asInstanceOf[RDD[T]]
-        }
-      }
-    datasetWithIncrementalExecution.foreachPartition { iter =>
+    // Hence, we need to manually convert internal rows to objects using encoder.
+    val encoder = encoderFor[T].resolveAndBind(
+      data.logicalPlan.output,
+      data.sparkSession.sessionState.analyzer)
+    data.queryExecution.toRdd.foreachPartition { iter =>
       if (writer.open(TaskContext.getPartitionId(), batchId)) {
        try {
          while (iter.hasNext) {
-            writer.process(iter.next())
+            writer.process(encoder.fromRow(iter.next()))
          }
        } catch {
          case e: Throwable =>
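
The new comment in `addBatch` captures the essence of the fix: `StreamExecution` collects metrics and advances the event-time watermark from the plan it already built, so the sink must not trigger a new plan via `Dataset.rdd()`; instead it decodes the `InternalRow`s of `queryExecution.toRdd` itself with a bound encoder. Below is a minimal sketch, not part of this commit, of that decode-without-replanning pattern pulled out into a standalone helper. It reuses the internal calls shown in the diff (`encoderFor`, `resolveAndBind`, `fromRow`, `queryExecution.toRdd`); because those are `private[sql]`, the sketch assumes it is compiled inside Spark's own `org.apache.spark.sql.execution.streaming` package, and the name `TypedRowDecoder` is purely illustrative.

```scala
package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql.{DataFrame, Encoder}
import org.apache.spark.sql.catalyst.encoders.encoderFor

// Illustrative helper (not in the commit): run a function over a DataFrame's rows decoded as T,
// without creating a new query plan, mirroring the pattern used by ForeachSink.addBatch above.
object TypedRowDecoder {
  def foreachTyped[T : Encoder](data: DataFrame)(f: T => Unit): Unit = {
    // Bind the encoder to the DataFrame's existing output attributes so it can decode
    // the InternalRows produced by the already-planned physical execution.
    val encoder = encoderFor[T].resolveAndBind(
      data.logicalPlan.output,
      data.sparkSession.sessionState.analyzer)
    // toRdd reuses data.queryExecution, so the metrics and watermark bookkeeping stay
    // attached to the plan that StreamExecution already knows about.
    data.queryExecution.toRdd.foreachPartition { iter =>
      iter.foreach(row => f(encoder.fromRow(row)))
    }
  }
}
```

The design point is the same as in the commit: binding `encoderFor[T]` to `data.logicalPlan.output` lets the rows of the existing physical plan be decoded directly, so no second `IncrementalExecution` (and no second set of metrics or watermark state) is ever created.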

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala

Lines changed: 50 additions & 1 deletion
@@ -171,7 +171,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
     }
   }
 
-  test("foreach with watermark") {
+  test("foreach with watermark: complete") {
     val inputData = MemoryStream[Int]
 
     val windowedAggregation = inputData.toDF()
@@ -204,6 +204,55 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
       query.stop()
     }
   }
+
+  test("foreach with watermark: append") {
+    val inputData = MemoryStream[Int]
+
+    val windowedAggregation = inputData.toDF()
+      .withColumn("eventTime", $"value".cast("timestamp"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(window($"eventTime", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"count".as[Long])
+      .map(_.toInt)
+      .repartition(1)
+
+    val query = windowedAggregation
+      .writeStream
+      .outputMode(OutputMode.Append)
+      .foreach(new TestForeachWriter())
+      .start()
+    try {
+      inputData.addData(10, 11, 12)
+      query.processAllAvailable()
+      inputData.addData(25) // Advance watermark to 15 seconds
+      query.processAllAvailable()
+      inputData.addData(25) // Evict items less than previous watermark
+      query.processAllAvailable()
+
+      // There should be 3 batches and only does the last batch contain a value.
+      val allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 3)
+      val expectedEvents = Seq(
+        Seq(
+          ForeachSinkSuite.Open(partition = 0, version = 0),
+          ForeachSinkSuite.Close(None)
+        ),
+        Seq(
+          ForeachSinkSuite.Open(partition = 0, version = 1),
+          ForeachSinkSuite.Close(None)
+        ),
+        Seq(
+          ForeachSinkSuite.Open(partition = 0, version = 2),
+          ForeachSinkSuite.Process(value = 3),
+          ForeachSinkSuite.Close(None)
+        )
+      )
+      assert(allEvents === expectedEvents)
+    } finally {
+      query.stop()
+    }
+  }
 }
 
 /** A global object to collect events in the executor */
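
The `TestForeachWriter` and the `ForeachSinkSuite.Open`/`Process`/`Close` events referenced by the new test are defined elsewhere in the suite and do not appear in this diff. The sketch below is a rough, hypothetical illustration of how such an event-recording writer can be built on the public `ForeachWriter` API. The names `RecordedEvent`, `EventLog`, and `RecordingForeachWriter` are illustrative stand-ins, not the suite's actual helpers, and a shared in-JVM buffer like this only works because the tests run in local mode, where the driver and executors share one JVM.

```scala
import scala.collection.mutable

import org.apache.spark.sql.ForeachWriter

// Events recorded per micro-batch partition: open, each processed value, close.
sealed trait RecordedEvent
case class Opened(partition: Long, version: Long) extends RecordedEvent
case class Processed(value: Int) extends RecordedEvent
case class Closed(error: Option[Throwable]) extends RecordedEvent

// Global, JVM-local event log; only valid when executors run in the same JVM (local mode).
object EventLog {
  private val batches = mutable.ArrayBuffer[Seq[RecordedEvent]]()
  def add(events: Seq[RecordedEvent]): Unit = synchronized { batches += events }
  def all(): Seq[Seq[RecordedEvent]] = synchronized { batches.toVector }
  def clear(): Unit = synchronized { batches.clear() }
}

// Records every callback it receives so a test can assert on the exact event sequence.
class RecordingForeachWriter extends ForeachWriter[Int] {
  private val events = mutable.ArrayBuffer[RecordedEvent]()

  override def open(partitionId: Long, version: Long): Boolean = {
    events += Opened(partitionId, version)
    true // returning true asks Spark to call process() for every row of this partition
  }

  override def process(value: Int): Unit = {
    events += Processed(value)
  }

  override def close(errorOrNull: Throwable): Unit = {
    events += Closed(Option(errorOrNull))
    EventLog.add(events.toVector)
  }
}
```

Read against such a writer, the assertions above say that the first two micro-batches open and close without emitting anything (the windowed counts stay buffered until the watermark passes the window end), and only the third batch delivers the finalized count of 3 for the 10-15 second window, which is exactly the append-mode behavior this commit fixes for `ForeachSink`.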
